| // Copyright 2017 Google Inc. All Rights Reserved. |
| // |
| // Licensed under the Apache License, Version 2.0 (the "License"); |
| // you may not use this file except in compliance with the License. |
| // You may obtain a copy of the License at |
| // |
| // http://www.apache.org/licenses/LICENSE-2.0 |
| // |
| // Unless required by applicable law or agreed to in writing, software |
| // distributed under the License is distributed on an "AS IS" BASIS, |
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| // See the License for the specific language governing permissions and |
| // limitations under the License. |
| |
| package pubsub |
| |
| // This file provides a fake/mock in-memory pubsub server. |
| |
| import ( |
| "io" |
| "sort" |
| "strings" |
| "sync" |
| "time" |
| |
| "cloud.google.com/go/internal/testutil" |
| "github.com/golang/protobuf/proto" |
| "github.com/golang/protobuf/ptypes" |
| durpb "github.com/golang/protobuf/ptypes/duration" |
| emptypb "github.com/golang/protobuf/ptypes/empty" |
| "golang.org/x/net/context" |
| pb "google.golang.org/genproto/googleapis/pubsub/v1" |
| "google.golang.org/grpc" |
| "google.golang.org/grpc/codes" |
| ) |
| |
| type fakeServer struct { |
| pb.PublisherServer |
| pb.SubscriberServer |
| |
| Addr string |
| |
| mu sync.Mutex |
| Acked map[string]bool // acked message IDs |
| Deadlines map[string]int32 // deadlines by message ID |
| pullResponses []*pullResponse |
| wg sync.WaitGroup |
| subs map[string]*pb.Subscription |
| topics map[string]*pb.Topic |
| } |
| |
| type pullResponse struct { |
| msgs []*pb.ReceivedMessage |
| err error |
| } |
| |
| func newFakeServer() (*fakeServer, error) { |
| srv, err := testutil.NewServer() |
| if err != nil { |
| return nil, err |
| } |
| fake := &fakeServer{ |
| Addr: srv.Addr, |
| Acked: map[string]bool{}, |
| Deadlines: map[string]int32{}, |
| subs: map[string]*pb.Subscription{}, |
| topics: map[string]*pb.Topic{}, |
| } |
| pb.RegisterPublisherServer(srv.Gsrv, fake) |
| pb.RegisterSubscriberServer(srv.Gsrv, fake) |
| srv.Start() |
| return fake, nil |
| } |
| |
| // Each call to addStreamingPullMessages results in one StreamingPullResponse. |
| func (s *fakeServer) addStreamingPullMessages(msgs []*pb.ReceivedMessage) { |
| s.pullResponses = append(s.pullResponses, &pullResponse{msgs, nil}) |
| } |
| |
| func (s *fakeServer) addStreamingPullError(err error) { |
| s.pullResponses = append(s.pullResponses, &pullResponse{nil, err}) |
| } |
| |
| func (s *fakeServer) wait() { |
| s.wg.Wait() |
| } |
| |
| func (s *fakeServer) StreamingPull(stream pb.Subscriber_StreamingPullServer) error { |
| s.wg.Add(1) |
| defer s.wg.Done() |
| errc := make(chan error, 1) |
| s.wg.Add(1) |
| go func() { |
| defer s.wg.Done() |
| for { |
| req, err := stream.Recv() |
| if err != nil { |
| errc <- err |
| return |
| } |
| s.mu.Lock() |
| for _, id := range req.AckIds { |
| s.Acked[id] = true |
| } |
| for i, id := range req.ModifyDeadlineAckIds { |
| s.Deadlines[id] = req.ModifyDeadlineSeconds[i] |
| } |
| s.mu.Unlock() |
| } |
| }() |
| // Send responses. |
| for { |
| s.mu.Lock() |
| if len(s.pullResponses) == 0 { |
| s.mu.Unlock() |
| // Nothing to send, so wait for the client to shut down the stream. |
| err := <-errc // a real error, or at least EOF |
| if err == io.EOF { |
| return nil |
| } |
| return err |
| } |
| pr := s.pullResponses[0] |
| s.pullResponses = s.pullResponses[1:] |
| s.mu.Unlock() |
| if pr.err != nil { |
| // Add a slight delay to ensure the server receives any |
| // messages en route from the client before shutting down the stream. |
| // This reduces flakiness of tests involving retry. |
| time.Sleep(200 * time.Millisecond) |
| } |
| if pr.err == io.EOF { |
| return nil |
| } |
| if pr.err != nil { |
| return pr.err |
| } |
| // Return any error from Recv. |
| select { |
| case err := <-errc: |
| return err |
| default: |
| } |
| res := &pb.StreamingPullResponse{ReceivedMessages: pr.msgs} |
| if err := stream.Send(res); err != nil { |
| return err |
| } |
| } |
| } |
| |
// Inclusive bounds that checkMRD enforces on a subscription's message
// retention duration.
const (
	minMessageRetentionDuration = 10 * time.Minute
	maxMessageRetentionDuration = 7 * 24 * time.Hour // 168 hours
)
| |
| var defaultMessageRetentionDuration = ptypes.DurationProto(maxMessageRetentionDuration) |
| |
| func checkMRD(pmrd *durpb.Duration) error { |
| mrd, err := ptypes.Duration(pmrd) |
| if err != nil || mrd < minMessageRetentionDuration || mrd > maxMessageRetentionDuration { |
| return grpc.Errorf(codes.InvalidArgument, "bad message_retention_duration %+v", pmrd) |
| } |
| return nil |
| } |
| |
| func checkAckDeadline(ads int32) error { |
| if ads < 10 || ads > 600 { |
| // PubSub service returns Unknown. |
| return grpc.Errorf(codes.Unknown, "bad ack_deadline_seconds: %d", ads) |
| } |
| return nil |
| } |
| |
| func (s *fakeServer) CreateSubscription(ctx context.Context, sub *pb.Subscription) (*pb.Subscription, error) { |
| if s.subs[sub.Name] != nil { |
| return nil, grpc.Errorf(codes.AlreadyExists, "subscription %q", sub.Name) |
| } |
| sub2 := proto.Clone(sub).(*pb.Subscription) |
| if err := checkAckDeadline(sub.AckDeadlineSeconds); err != nil { |
| return nil, err |
| } |
| if sub.MessageRetentionDuration == nil { |
| sub2.MessageRetentionDuration = defaultMessageRetentionDuration |
| } |
| if err := checkMRD(sub2.MessageRetentionDuration); err != nil { |
| return nil, err |
| } |
| if sub.PushConfig == nil { |
| sub2.PushConfig = &pb.PushConfig{} |
| } |
| s.subs[sub.Name] = sub2 |
| return sub2, nil |
| } |
| |
| func (s *fakeServer) GetSubscription(ctx context.Context, req *pb.GetSubscriptionRequest) (*pb.Subscription, error) { |
| if sub := s.subs[req.Subscription]; sub != nil { |
| return sub, nil |
| } |
| return nil, grpc.Errorf(codes.NotFound, "subscription %q", req.Subscription) |
| } |
| |
| func (s *fakeServer) UpdateSubscription(ctx context.Context, req *pb.UpdateSubscriptionRequest) (*pb.Subscription, error) { |
| sub := s.subs[req.Subscription.Name] |
| if sub == nil { |
| return nil, grpc.Errorf(codes.NotFound, "subscription %q", req.Subscription.Name) |
| } |
| for _, path := range req.UpdateMask.Paths { |
| switch path { |
| case "push_config": |
| sub.PushConfig = req.Subscription.PushConfig |
| |
| case "ack_deadline_seconds": |
| a := req.Subscription.AckDeadlineSeconds |
| if err := checkAckDeadline(a); err != nil { |
| return nil, err |
| } |
| sub.AckDeadlineSeconds = a |
| |
| case "retain_acked_messages": |
| sub.RetainAckedMessages = req.Subscription.RetainAckedMessages |
| |
| case "message_retention_duration": |
| if err := checkMRD(req.Subscription.MessageRetentionDuration); err != nil { |
| return nil, err |
| } |
| sub.MessageRetentionDuration = req.Subscription.MessageRetentionDuration |
| |
| // TODO(jba): labels |
| default: |
| return nil, grpc.Errorf(codes.InvalidArgument, "unknown field name %q", path) |
| } |
| } |
| return sub, nil |
| } |
| |
| func (s *fakeServer) DeleteSubscription(_ context.Context, req *pb.DeleteSubscriptionRequest) (*emptypb.Empty, error) { |
| if s.subs[req.Subscription] == nil { |
| return nil, grpc.Errorf(codes.NotFound, "subscription %q", req.Subscription) |
| } |
| delete(s.subs, req.Subscription) |
| return &emptypb.Empty{}, nil |
| } |
| |
| func (s *fakeServer) CreateTopic(_ context.Context, t *pb.Topic) (*pb.Topic, error) { |
| if s.topics[t.Name] != nil { |
| return nil, grpc.Errorf(codes.AlreadyExists, "topic %q", t.Name) |
| } |
| t2 := proto.Clone(t).(*pb.Topic) |
| s.topics[t.Name] = t2 |
| return t2, nil |
| } |
| |
| func (s *fakeServer) GetTopic(_ context.Context, req *pb.GetTopicRequest) (*pb.Topic, error) { |
| if t := s.topics[req.Topic]; t != nil { |
| return t, nil |
| } |
| return nil, grpc.Errorf(codes.NotFound, "topic %q", req.Topic) |
| } |
| |
| func (s *fakeServer) DeleteTopic(_ context.Context, req *pb.DeleteTopicRequest) (*emptypb.Empty, error) { |
| if s.topics[req.Topic] == nil { |
| return nil, grpc.Errorf(codes.NotFound, "topic %q", req.Topic) |
| } |
| delete(s.topics, req.Topic) |
| return &emptypb.Empty{}, nil |
| } |
| |
| func (s *fakeServer) ListTopics(_ context.Context, req *pb.ListTopicsRequest) (*pb.ListTopicsResponse, error) { |
| var names []string |
| for n := range s.topics { |
| if strings.HasPrefix(n, req.Project) { |
| names = append(names, n) |
| } |
| } |
| sort.Strings(names) |
| from, to, nextToken, err := testutil.PageBounds(int(req.PageSize), req.PageToken, len(names)) |
| if err != nil { |
| return nil, err |
| } |
| res := &pb.ListTopicsResponse{NextPageToken: nextToken} |
| for i := from; i < to; i++ { |
| res.Topics = append(res.Topics, s.topics[names[i]]) |
| } |
| return res, nil |
| } |
| |
| func (s *fakeServer) ListSubscriptions(_ context.Context, req *pb.ListSubscriptionsRequest) (*pb.ListSubscriptionsResponse, error) { |
| var names []string |
| for _, sub := range s.subs { |
| if strings.HasPrefix(sub.Name, req.Project) { |
| names = append(names, sub.Name) |
| } |
| } |
| sort.Strings(names) |
| from, to, nextToken, err := testutil.PageBounds(int(req.PageSize), req.PageToken, len(names)) |
| if err != nil { |
| return nil, err |
| } |
| res := &pb.ListSubscriptionsResponse{NextPageToken: nextToken} |
| for i := from; i < to; i++ { |
| res.Subscriptions = append(res.Subscriptions, s.subs[names[i]]) |
| } |
| return res, nil |
| } |
| |
| func (s *fakeServer) ListTopicSubscriptions(_ context.Context, req *pb.ListTopicSubscriptionsRequest) (*pb.ListTopicSubscriptionsResponse, error) { |
| var names []string |
| for _, sub := range s.subs { |
| if sub.Topic == req.Topic { |
| names = append(names, sub.Name) |
| } |
| } |
| sort.Strings(names) |
| from, to, nextToken, err := testutil.PageBounds(int(req.PageSize), req.PageToken, len(names)) |
| if err != nil { |
| return nil, err |
| } |
| return &pb.ListTopicSubscriptionsResponse{ |
| Subscriptions: names[from:to], |
| NextPageToken: nextToken, |
| }, nil |
| } |