| // Copyright 2017 Google LLC |
| // |
| // Licensed under the Apache License, Version 2.0 (the "License"); |
| // you may not use this file except in compliance with the License. |
| // You may obtain a copy of the License at |
| // |
| // http://www.apache.org/licenses/LICENSE-2.0 |
| // |
| // Unless required by applicable law or agreed to in writing, software |
| // distributed under the License is distributed on an "AS IS" BASIS, |
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| // See the License for the specific language governing permissions and |
| // limitations under the License. |
| |
| package pubsub |
| |
| // This file provides a mock in-memory pubsub server for streaming pull testing. |
| |
| import ( |
| "context" |
| "io" |
| "sync" |
| "time" |
| |
| "cloud.google.com/go/internal/testutil" |
| emptypb "github.com/golang/protobuf/ptypes/empty" |
| pb "google.golang.org/genproto/googleapis/pubsub/v1" |
| ) |
| |
// mockServer is a fake Pub/Sub Subscriber gRPC server for streaming-pull
// tests. Tests script its behavior by queueing pull responses and RPC errors,
// then inspect Acked and Deadlines to see what the client sent.
type mockServer struct {
	// srv is the test gRPC server that this mock is registered on.
	srv *testutil.Server

	// Embedded so that *mockServer satisfies pb.SubscriberServer. newMockServer
	// leaves it nil, so calling any Subscriber method not defined on mockServer
	// dereferences a nil interface and panics.
	pb.SubscriberServer

	// Addr is the listen address of srv (copied from srv.Addr).
	Addr string

	// mu guards Acked, Deadlines, pullResponses, ackErrs and modAckErrs.
	// NOTE(review): tests reading Acked/Deadlines directly should hold mu or
	// call wait() first, since StreamingPull writes them from a goroutine.
	mu sync.Mutex
	// Acked is keyed by the ack IDs the client sends (in tests these are
	// presumably the message IDs — confirm against callers).
	Acked map[string]bool // acked message IDs
	// Deadlines maps ack ID to the most recent ack deadline, in seconds,
	// requested for it.
	Deadlines map[string]int32 // deadlines by message ID
	// pullResponses is consumed oldest-first by StreamingPull; each entry
	// becomes one response or ends the stream with its error.
	pullResponses []*pullResponse
	// ackErrs / modAckErrs are consumed oldest-first by Acknowledge and
	// ModifyAckDeadline respectively; a nil entry means "succeed".
	ackErrs    []error
	modAckErrs []error
	// wg counts running StreamingPull handlers and their receiver goroutines.
	wg sync.WaitGroup
	// sub is the subscription returned by GetSubscription.
	sub *pb.Subscription
}
| |
// pullResponse is one scripted step of a StreamingPull stream: either a batch
// of messages to send (err == nil) or an error that ends the stream. An err of
// io.EOF ends the stream cleanly.
type pullResponse struct {
	msgs []*pb.ReceivedMessage
	err  error
}
| |
| func newMockServer(port int) (*mockServer, error) { |
| srv, err := testutil.NewServerWithPort(port) |
| if err != nil { |
| return nil, err |
| } |
| mock := &mockServer{ |
| srv: srv, |
| Addr: srv.Addr, |
| Acked: map[string]bool{}, |
| Deadlines: map[string]int32{}, |
| sub: &pb.Subscription{ |
| AckDeadlineSeconds: 10, |
| PushConfig: &pb.PushConfig{}, |
| }, |
| } |
| pb.RegisterSubscriberServer(srv.Gsrv, mock) |
| srv.Start() |
| return mock, nil |
| } |
| |
| // Each call to addStreamingPullMessages results in one StreamingPullResponse. |
| func (s *mockServer) addStreamingPullMessages(msgs []*pb.ReceivedMessage) { |
| s.mu.Lock() |
| s.pullResponses = append(s.pullResponses, &pullResponse{msgs, nil}) |
| s.mu.Unlock() |
| } |
| |
| func (s *mockServer) addStreamingPullError(err error) { |
| s.mu.Lock() |
| s.pullResponses = append(s.pullResponses, &pullResponse{nil, err}) |
| s.mu.Unlock() |
| } |
| |
| func (s *mockServer) addAckResponse(err error) { |
| s.mu.Lock() |
| s.ackErrs = append(s.ackErrs, err) |
| s.mu.Unlock() |
| } |
| |
| func (s *mockServer) addModAckResponse(err error) { |
| s.mu.Lock() |
| s.modAckErrs = append(s.modAckErrs, err) |
| s.mu.Unlock() |
| } |
| |
// wait blocks until every StreamingPull handler, and the receiver goroutine
// each one starts, has returned.
// NOTE(review): StreamingPull calls s.wg.Add from inside the handler. If a new
// stream can open while wait is blocked on a zero counter, that violates the
// sync.WaitGroup contract. Callers should ensure the client has stopped first.
func (s *mockServer) wait() {
	s.wg.Wait()
}
| |
| func (s *mockServer) StreamingPull(stream pb.Subscriber_StreamingPullServer) error { |
| s.wg.Add(1) |
| defer s.wg.Done() |
| errc := make(chan error, 1) |
| s.wg.Add(1) |
| go func() { |
| defer s.wg.Done() |
| for { |
| req, err := stream.Recv() |
| if err != nil { |
| errc <- err |
| return |
| } |
| s.mu.Lock() |
| for _, id := range req.AckIds { |
| s.Acked[id] = true |
| } |
| for i, id := range req.ModifyDeadlineAckIds { |
| s.Deadlines[id] = req.ModifyDeadlineSeconds[i] |
| } |
| s.mu.Unlock() |
| } |
| }() |
| // Send responses. |
| for { |
| s.mu.Lock() |
| if len(s.pullResponses) == 0 { |
| s.mu.Unlock() |
| // Nothing to send, so wait for the client to shut down the stream. |
| err := <-errc // a real error, or at least EOF |
| if err == io.EOF { |
| return nil |
| } |
| return err |
| } |
| pr := s.pullResponses[0] |
| s.pullResponses = s.pullResponses[1:] |
| s.mu.Unlock() |
| if pr.err != nil { |
| // Add a slight delay to ensure the server receives any |
| // messages en route from the client before shutting down the stream. |
| // This reduces flakiness of tests involving retry. |
| time.Sleep(200 * time.Millisecond) |
| } |
| if pr.err == io.EOF { |
| return nil |
| } |
| if pr.err != nil { |
| return pr.err |
| } |
| // Return any error from Recv. |
| select { |
| case err := <-errc: |
| return err |
| default: |
| } |
| res := &pb.StreamingPullResponse{ReceivedMessages: pr.msgs} |
| if err := stream.Send(res); err != nil { |
| return err |
| } |
| } |
| } |
| |
| func (s *mockServer) Acknowledge(ctx context.Context, req *pb.AcknowledgeRequest) (*emptypb.Empty, error) { |
| var err error |
| s.mu.Lock() |
| if len(s.ackErrs) > 0 { |
| err = s.ackErrs[0] |
| s.ackErrs = s.ackErrs[1:] |
| } |
| s.mu.Unlock() |
| if err != nil { |
| return nil, err |
| } |
| for _, id := range req.AckIds { |
| s.Acked[id] = true |
| } |
| return &emptypb.Empty{}, nil |
| } |
| |
| func (s *mockServer) ModifyAckDeadline(ctx context.Context, req *pb.ModifyAckDeadlineRequest) (*emptypb.Empty, error) { |
| var err error |
| s.mu.Lock() |
| if len(s.modAckErrs) > 0 { |
| err = s.modAckErrs[0] |
| s.modAckErrs = s.modAckErrs[1:] |
| } |
| s.mu.Unlock() |
| if err != nil { |
| return nil, err |
| } |
| for _, id := range req.AckIds { |
| s.Deadlines[id] = req.AckDeadlineSeconds |
| } |
| return &emptypb.Empty{}, nil |
| } |
| |
// GetSubscription implements pb.SubscriberServer. It ignores req and always
// returns the single subscription stored in s.sub (set up by newMockServer).
// The same pointer is returned on every call, so mutating the result changes
// what later calls see.
func (s *mockServer) GetSubscription(ctx context.Context, req *pb.GetSubscriptionRequest) (*pb.Subscription, error) {
	return s.sub, nil
}