// blob: 52da17407cf684fe1b29c154a47a5352d2c8f8ae [file] [log] [blame]
// Copyright 2017 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package pubsub
// This file provides a mock in-memory pubsub server for streaming pull testing.
import (
"context"
"io"
"sync"
"time"
"cloud.google.com/go/internal/testutil"
emptypb "github.com/golang/protobuf/ptypes/empty"
pb "google.golang.org/genproto/googleapis/pubsub/v1"
)
// mockServer is an in-memory fake of the Pub/Sub subscriber gRPC service used
// to exercise streaming pull in tests. Tests script its behavior in advance
// with the add* helpers and afterwards inspect Acked and Deadlines to see what
// the client sent.
type mockServer struct {
	// srv is the test server returned by testutil.NewServerWithPort.
	srv *testutil.Server
	// Embedding the interface lets mockServer be registered as a
	// pb.SubscriberServer while defining only some of its methods.
	// newMockServer does not assign the embedded value, so an RPC that
	// mockServer does not define itself would be dispatched to a nil interface.
	pb.SubscriberServer
	// Addr is a copy of srv.Addr (presumably the address clients should dial).
	Addr string
	// mu is held by the add* helpers and the RPC handlers while they touch the
	// scripted response queues below, and by StreamingPull while it updates
	// Acked and Deadlines.
	mu sync.Mutex
	// Acked has an entry set to true for every ack ID received.
	Acked map[string]bool
	// Deadlines holds the most recent ack deadline, in seconds, received for
	// each ack ID.
	Deadlines map[string]int32
	// pullResponses is the FIFO queue of scripted StreamingPull outcomes.
	pullResponses []*pullResponse
	// ackErrs is the FIFO queue of scripted results for Acknowledge.
	ackErrs []error
	// modAckErrs is the FIFO queue of scripted results for ModifyAckDeadline.
	modAckErrs []error
	// wg counts running StreamingPull handlers and their receive goroutines;
	// wait blocks on it.
	wg sync.WaitGroup
	// sub is the subscription returned by GetSubscription.
	sub *pb.Subscription
}
// pullResponse is one scripted outcome for StreamingPull: either a batch of
// messages to send to the client, or an error that terminates the stream.
type pullResponse struct {
	// msgs is delivered in a single StreamingPullResponse when err is nil.
	msgs []*pb.ReceivedMessage
	// err, when non-nil, makes StreamingPull return instead of sending.
	// io.EOF ends the stream with a nil error; any other value is returned
	// as-is.
	err error
}
// newMockServer creates a test server via testutil.NewServerWithPort(port),
// registers a fresh mockServer on it as the subscriber service, and starts it.
//
// The returned mock has empty Acked and Deadlines maps and a default
// subscription with a 10-second ack deadline and an empty push config. Any
// error from creating the underlying test server is returned unchanged.
func newMockServer(port int) (*mockServer, error) {
	testSrv, err := testutil.NewServerWithPort(port)
	if err != nil {
		return nil, err
	}

	defaultSub := &pb.Subscription{
		AckDeadlineSeconds: 10,
		PushConfig:         &pb.PushConfig{},
	}
	m := &mockServer{
		srv:       testSrv,
		Addr:      testSrv.Addr,
		Acked:     make(map[string]bool),
		Deadlines: make(map[string]int32),
		sub:       defaultSub,
	}

	// The service must be registered before the server is started.
	pb.RegisterSubscriberServer(testSrv.Gsrv, m)
	testSrv.Start()
	return m, nil
}
// addStreamingPullMessages queues msgs for delivery by StreamingPull. Every
// call produces exactly one StreamingPullResponse carrying that batch.
func (s *mockServer) addStreamingPullMessages(msgs []*pb.ReceivedMessage) {
	batch := &pullResponse{msgs: msgs}

	s.mu.Lock()
	defer s.mu.Unlock()
	s.pullResponses = append(s.pullResponses, batch)
}
// addStreamingPullError queues err as the next StreamingPull outcome. When
// StreamingPull reaches this entry it stops sending and returns: io.EOF ends
// the stream with a nil error, any other non-nil error is returned as-is.
func (s *mockServer) addStreamingPullError(err error) {
	failure := &pullResponse{err: err}

	s.mu.Lock()
	defer s.mu.Unlock()
	s.pullResponses = append(s.pullResponses, failure)
}
// addAckResponse queues the result of a future Acknowledge call. Results are
// consumed in the order they were added; a nil err scripts a successful call.
func (s *mockServer) addAckResponse(err error) {
	s.mu.Lock()
	defer s.mu.Unlock()
	s.ackErrs = append(s.ackErrs, err)
}
// addModAckResponse queues the result of a future ModifyAckDeadline call.
// Results are consumed in the order they were added; a nil err scripts a
// successful call.
func (s *mockServer) addModAckResponse(err error) {
	s.mu.Lock()
	defer s.mu.Unlock()
	s.modAckErrs = append(s.modAckErrs, err)
}
// wait blocks until every StreamingPull handler, and the receive goroutine
// each one starts, has returned.
func (s *mockServer) wait() {
	s.wg.Wait()
}
// StreamingPull implements the StreamingPull RPC of pb.SubscriberServer.
//
// A background goroutine reads client requests until Recv fails, recording
// every ack ID in Acked and every modified deadline in Deadlines. Meanwhile
// this handler replays the scripted pullResponses in FIFO order:
//
//   - a message batch is sent as one StreamingPullResponse;
//   - a scripted error ends the stream after a short delay (io.EOF becomes a
//     nil return, anything else is returned as-is);
//   - once the queue is empty the handler blocks until the client side of the
//     stream terminates.
//
// Both the handler and the receive goroutine are tracked in s.wg.
func (s *mockServer) StreamingPull(stream pb.Subscriber_StreamingPullServer) error {
	// One count for this handler, one for the receive goroutine below.
	s.wg.Add(2)
	defer s.wg.Done()

	// Buffered so the receiver can deliver its terminal error and exit even
	// when this handler has already returned.
	recvErrs := make(chan error, 1)
	go func() {
		defer s.wg.Done()
		for {
			in, recvErr := stream.Recv()
			if recvErr != nil {
				recvErrs <- recvErr
				return
			}
			s.mu.Lock()
			for _, ackID := range in.AckIds {
				s.Acked[ackID] = true
			}
			// ModifyDeadlineSeconds is indexed in parallel with
			// ModifyDeadlineAckIds.
			for idx, ackID := range in.ModifyDeadlineAckIds {
				s.Deadlines[ackID] = in.ModifyDeadlineSeconds[idx]
			}
			s.mu.Unlock()
		}
	}()

	// Send responses.
	for {
		// Pop the next scripted response, if there is one, under the lock.
		var next *pullResponse
		s.mu.Lock()
		queued := len(s.pullResponses) > 0
		if queued {
			next = s.pullResponses[0]
			s.pullResponses = s.pullResponses[1:]
		}
		s.mu.Unlock()

		if !queued {
			// Nothing to send, so wait for the client to shut down the
			// stream. The value is a real error, or at least io.EOF.
			recvErr := <-recvErrs
			if recvErr == io.EOF {
				return nil
			}
			return recvErr
		}

		if next.err != nil {
			// Add a slight delay to ensure the server receives any messages
			// en route from the client before shutting down the stream.
			// This reduces flakiness of tests involving retry.
			time.Sleep(200 * time.Millisecond)
			if next.err == io.EOF {
				return nil
			}
			return next.err
		}

		// Return any error already reported by Recv. Unlike the idle path
		// above, io.EOF is returned unchanged here.
		select {
		case recvErr := <-recvErrs:
			return recvErr
		default:
		}

		out := &pb.StreamingPullResponse{ReceivedMessages: next.msgs}
		if sendErr := stream.Send(out); sendErr != nil {
			return sendErr
		}
	}
}
// Acknowledge implements the unary Acknowledge RPC of pb.SubscriberServer.
//
// If results have been scripted with addAckResponse, the oldest one is
// consumed. A non-nil scripted error is returned and nothing is recorded; a
// nil entry (or an empty queue) lets the call succeed, in which case every ID
// in req.AckIds is marked in Acked.
//
// s.mu is held for the whole call. Acked is also written, under s.mu, by the
// receive goroutine in StreamingPull, so updating the map here without the
// lock would be a data race (and can abort the process with "concurrent map
// writes").
func (s *mockServer) Acknowledge(ctx context.Context, req *pb.AcknowledgeRequest) (*emptypb.Empty, error) {
	s.mu.Lock()
	defer s.mu.Unlock()

	if len(s.ackErrs) > 0 {
		err := s.ackErrs[0]
		s.ackErrs = s.ackErrs[1:]
		if err != nil {
			return nil, err
		}
	}
	for _, id := range req.AckIds {
		s.Acked[id] = true
	}
	return &emptypb.Empty{}, nil
}
// ModifyAckDeadline implements the unary ModifyAckDeadline RPC of
// pb.SubscriberServer.
//
// If results have been scripted with addModAckResponse, the oldest one is
// consumed. A non-nil scripted error is returned and nothing is recorded; a
// nil entry (or an empty queue) lets the call succeed, in which case
// req.AckDeadlineSeconds is stored in Deadlines for every ID in req.AckIds.
//
// s.mu is held for the whole call. Deadlines is also written, under s.mu, by
// the receive goroutine in StreamingPull, so updating the map here without
// the lock would be a data race (and can abort the process with "concurrent
// map writes").
func (s *mockServer) ModifyAckDeadline(ctx context.Context, req *pb.ModifyAckDeadlineRequest) (*emptypb.Empty, error) {
	s.mu.Lock()
	defer s.mu.Unlock()

	if len(s.modAckErrs) > 0 {
		err := s.modAckErrs[0]
		s.modAckErrs = s.modAckErrs[1:]
		if err != nil {
			return nil, err
		}
	}
	for _, id := range req.AckIds {
		s.Deadlines[id] = req.AckDeadlineSeconds
	}
	return &emptypb.Empty{}, nil
}
// GetSubscription implements the GetSubscription RPC of pb.SubscriberServer.
// It ignores both ctx and req and always returns the mock's single
// subscription, s.sub, with a nil error.
func (s *mockServer) GetSubscription(ctx context.Context, req *pb.GetSubscriptionRequest) (*pb.Subscription, error) {
	return s.sub, nil
}