// Copyright 2017 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package pubsub

// This file provides a mock in-memory pubsub server for streaming pull testing.

import (
	"context"
	"fmt"
	"io"
	"sync"
	"time"

	"cloud.google.com/go/internal/testutil"
	pb "google.golang.org/genproto/googleapis/pubsub/v1"
	"google.golang.org/protobuf/types/known/emptypb"
)

// mockServer is an in-memory implementation of the Pub/Sub Subscriber gRPC
// service, used to test streaming pull. Tests queue canned responses with the
// add* methods; the RPC handlers consume those queues from the front and
// record the acks and deadline modifications they receive.
type mockServer struct {
	// srv is the test server returned by testutil.NewServerWithPort.
	srv *testutil.Server

	// Embedded interface; newMockServer leaves it nil. Presumably this lets
	// mockServer be registered as a SubscriberServer while defining only the
	// RPCs found in this file — confirm before relying on any other RPC.
	pb.SubscriberServer

	// Addr is copied from srv.Addr by newMockServer.
	Addr string

	// The methods in this file lock mu around every access to Acked,
	// Deadlines, pullResponses, ackErrs and modAckErrs.
	//
	// pullResponses is consumed by StreamingPull, ackErrs by Acknowledge and
	// modAckErrs by ModifyAckDeadline. wg counts running StreamingPull
	// handlers and the receive goroutine each one starts. sub is the value
	// returned by GetSubscription.
	mu            sync.Mutex
	Acked         map[string]bool  // acked message IDs
	Deadlines     map[string]int32 // deadlines by message ID
	pullResponses []*pullResponse
	ackErrs       []error
	modAckErrs    []error
	wg            sync.WaitGroup
	sub           *pb.Subscription
}

// pullResponse is one scripted StreamingPull step: either a batch of messages
// to send to the client (msgs) or an error that ends the stream (err).
//
// The add* helpers in this file build it with positional literals, so the
// field order must not change.
type pullResponse struct {
	msgs []*pb.ReceivedMessage
	err  error
}

// newMockServer creates a testutil server on the given port, registers a fresh
// mockServer as its Subscriber service, starts the server and returns the mock.
//
// The mock starts with empty Acked and Deadlines maps and a subscription whose
// ack deadline is 10 seconds and whose push config is empty. An error from
// testutil.NewServerWithPort is returned unchanged.
func newMockServer(port int) (*mockServer, error) {
	server, err := testutil.NewServerWithPort(port)
	if err != nil {
		return nil, err
	}

	m := new(mockServer)
	m.srv = server
	m.Addr = server.Addr
	m.Acked = make(map[string]bool)
	m.Deadlines = make(map[string]int32)
	m.sub = &pb.Subscription{
		AckDeadlineSeconds: 10,
		PushConfig:         &pb.PushConfig{},
	}

	// Register before starting so the service is in place when serving begins.
	pb.RegisterSubscriberServer(server.Gsrv, m)
	server.Start()
	return m, nil
}

// addStreamingPullMessages appends msgs to the queue of responses that
// StreamingPull sends. Each call accounts for one StreamingPullResponse.
func (s *mockServer) addStreamingPullMessages(msgs []*pb.ReceivedMessage) {
	s.mu.Lock()
	defer s.mu.Unlock()
	s.pullResponses = append(s.pullResponses, &pullResponse{msgs: msgs})
}

// addStreamingPullError appends err to the queue of responses consumed by
// StreamingPull. When StreamingPull reaches a non-nil entry it ends the stream
// with that error (io.EOF ends it without an error).
func (s *mockServer) addStreamingPullError(err error) {
	queued := &pullResponse{err: err}

	s.mu.Lock()
	defer s.mu.Unlock()
	s.pullResponses = append(s.pullResponses, queued)
}

// addAckResponse queues err as the outcome of a later Acknowledge call.
// Acknowledge takes queued outcomes from the front; a nil entry lets that call
// succeed and record its ack IDs.
func (s *mockServer) addAckResponse(err error) {
	s.mu.Lock()
	defer s.mu.Unlock()
	s.ackErrs = append(s.ackErrs, err)
}

// addModAckResponse queues err as the outcome of a later ModifyAckDeadline
// call. ModifyAckDeadline takes queued outcomes from the front; a nil entry
// lets that call succeed and record the new deadlines.
func (s *mockServer) addModAckResponse(err error) {
	s.mu.Lock()
	defer s.mu.Unlock()
	s.modAckErrs = append(s.modAckErrs, err)
}

// wait blocks until every StreamingPull handler, and the receive goroutine
// each handler starts, has returned.
func (s *mockServer) wait() {
	s.wg.Wait()
}

// StreamingPull implements the StreamingPull RPC for the mock.
//
// A goroutine receives client requests and records them: every ack ID is
// marked in s.Acked and every deadline modification is stored in s.Deadlines.
// A request with fewer ModifyDeadlineSeconds than ModifyDeadlineAckIds cannot
// be recorded; it ends the stream with an error rather than panicking on an
// out-of-range index.
//
// Meanwhile the handler sends the queued pull responses in order, one
// StreamingPullResponse per entry. A queued error ends the stream, with io.EOF
// ending it without an error. Once the queue is empty the handler blocks until
// the receive goroutine reports an error, and treats io.EOF as a clean
// shutdown by the client.
func (s *mockServer) StreamingPull(stream pb.Subscriber_StreamingPullServer) error {
	s.wg.Add(1)
	defer s.wg.Done()
	// Buffered so the receive goroutine can always deliver its single error
	// and exit, even if the handler has already returned.
	errc := make(chan error, 1)
	s.wg.Add(1)
	go func() {
		defer s.wg.Done()
		for {
			req, err := stream.Recv()
			if err != nil {
				errc <- err
				return
			}
			// The two modify-deadline lists are parallel: entry i of one
			// belongs to entry i of the other. Reject a request that has too
			// few deadlines before touching any state, instead of indexing
			// past the end of the slice while holding the lock.
			if got, want := len(req.ModifyDeadlineSeconds), len(req.ModifyDeadlineAckIds); got < want {
				errc <- fmt.Errorf("pubsub mock: got %d modify_deadline_seconds for %d modify_deadline_ack_ids", got, want)
				return
			}
			s.mu.Lock()
			for _, id := range req.AckIds {
				s.Acked[id] = true
			}
			for i, id := range req.ModifyDeadlineAckIds {
				s.Deadlines[id] = req.ModifyDeadlineSeconds[i]
			}
			s.mu.Unlock()
		}
	}()
	// Send responses.
	for {
		s.mu.Lock()
		if len(s.pullResponses) == 0 {
			s.mu.Unlock()
			// Nothing to send, so wait for the client to shut down the stream.
			err := <-errc // a real error, or at least EOF
			if err == io.EOF {
				return nil
			}
			return err
		}
		// Pop the next scripted response while still holding the lock.
		pr := s.pullResponses[0]
		s.pullResponses = s.pullResponses[1:]
		s.mu.Unlock()
		if pr.err != nil {
			// Add a slight delay to ensure the server receives any
			// messages en route from the client before shutting down the stream.
			// This reduces flakiness of tests involving retry.
			time.Sleep(200 * time.Millisecond)
		}
		if pr.err == io.EOF {
			return nil
		}
		if pr.err != nil {
			return pr.err
		}
		// Return any error from Recv without blocking if there is none yet.
		// NOTE(review): unlike the idle path above, io.EOF is returned as-is
		// here rather than mapped to nil — confirm this is intended.
		select {
		case err := <-errc:
			return err
		default:
		}
		res := &pb.StreamingPullResponse{ReceivedMessages: pr.msgs}
		if err := stream.Send(res); err != nil {
			return err
		}
	}
}

// Acknowledge implements the unary Acknowledge RPC for the mock.
//
// If an outcome was queued with addAckResponse, the oldest one is consumed.
// A non-nil outcome is returned as the RPC error and no ack IDs are recorded.
// Otherwise every ID in req.AckIds is marked in s.Acked and an empty response
// is returned.
func (s *mockServer) Acknowledge(ctx context.Context, req *pb.AcknowledgeRequest) (*emptypb.Empty, error) {
	s.mu.Lock()
	defer s.mu.Unlock()

	if len(s.ackErrs) > 0 {
		queued := s.ackErrs[0]
		s.ackErrs = s.ackErrs[1:]
		if queued != nil {
			return nil, queued
		}
	}

	for _, ackID := range req.AckIds {
		s.Acked[ackID] = true
	}
	return &emptypb.Empty{}, nil
}

// ModifyAckDeadline implements the unary ModifyAckDeadline RPC for the mock.
//
// If an outcome was queued with addModAckResponse, the oldest one is consumed.
// A non-nil outcome is returned as the RPC error and no deadlines are
// recorded. Otherwise req.AckDeadlineSeconds is stored in s.Deadlines for
// every ID in req.AckIds and an empty response is returned.
func (s *mockServer) ModifyAckDeadline(ctx context.Context, req *pb.ModifyAckDeadlineRequest) (*emptypb.Empty, error) {
	s.mu.Lock()
	defer s.mu.Unlock()

	if len(s.modAckErrs) > 0 {
		queued := s.modAckErrs[0]
		s.modAckErrs = s.modAckErrs[1:]
		if queued != nil {
			return nil, queued
		}
	}

	deadline := req.AckDeadlineSeconds
	for _, ackID := range req.AckIds {
		s.Deadlines[ackID] = deadline
	}
	return &emptypb.Empty{}, nil
}

// GetSubscription implements the GetSubscription RPC for the mock. It ignores
// the request and always returns the subscription built in newMockServer.
// s.sub is read without locking s.mu.
func (s *mockServer) GetSubscription(ctx context.Context, req *pb.GetSubscriptionRequest) (*pb.Subscription, error) {
	return s.sub, nil
}
