// blob: e1aac6d11cfee73fdad19d2fb28377ae05cfb0ef [file] [log] [blame]
// Copyright 2017 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package pstest provides a fake Cloud PubSub service for testing. It implements a
// simplified form of the service, suitable for unit tests. It may behave
// differently from the actual service in ways in which the service is
// non-deterministic or unspecified: timing, delivery order, etc.
//
// This package is EXPERIMENTAL and is subject to change without notice.
//
// See the example for usage.
package pstest
import (
"fmt"
"io"
"path"
"sort"
"strings"
"sync"
"sync/atomic"
"time"
"cloud.google.com/go/internal/testutil"
"github.com/golang/protobuf/ptypes"
durpb "github.com/golang/protobuf/ptypes/duration"
emptypb "github.com/golang/protobuf/ptypes/empty"
"golang.org/x/net/context"
pb "google.golang.org/genproto/googleapis/pubsub/v1"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)
// now holds the clock function used throughout this package so that tests can
// substitute their own. Swapping the stored function is atomic, but a call to a
// previously loaded function is not synchronized with a later swap. Tests that
// run in parallel, or concurrent parts of one test that replace the clock, can
// therefore still race with callers.
var now atomic.Value

func init() {
	var realClock func() time.Time = time.Now
	now.Store(realClock)
}

// timeNow returns the current time as reported by the function stored in now.
func timeNow() time.Time {
	clock := now.Load().(func() time.Time)
	return clock()
}
// Server is a fake Pub/Sub server. Create one with NewServer.
type Server struct {
Addr string // The address that the server is listening on.
// gServer holds all server state. NewServer registers it as both the
// Publisher and the Subscriber gRPC service.
gServer gServer
}
// gServer is the gRPC-facing implementation of the fake service.
type gServer struct {
// The generated service interfaces are embedded; presumably this lets gServer
// satisfy them without defining every RPC method (confirm against the
// generated code).
pb.PublisherServer
pb.SubscriberServer
// mu is locked by the RPC handlers and by Server's accessors before they
// touch topics, subs, msgs, msgsByID or nextID. Each subscription also holds
// a pointer to this same mutex.
mu sync.Mutex
topics map[string]*topic // topics by name
subs map[string]*subscription // subscriptions by name
msgs []*Message // all messages ever published
msgsByID map[string]*Message // the same messages, keyed by message ID
// wg counts the goroutines started by subscription.start and stream.pull;
// Server.Wait blocks on it.
wg sync.WaitGroup
nextID int // numeric suffix of the next message ID ("m<nextID>")
// streamTimeout is handed to each new stream; SetStreamTimeout writes it
// while holding mu.
streamTimeout time.Duration
}
// NewServer starts a fake server inside the current process and returns it.
// It panics if the underlying test gRPC server cannot be created.
func NewServer() *Server {
	grpcSrv, err := testutil.NewServer()
	if err != nil {
		panic(fmt.Sprintf("pstest.NewServer: %v", err))
	}

	fake := &Server{Addr: grpcSrv.Addr}
	state := &fake.gServer
	state.topics = map[string]*topic{}
	state.subs = map[string]*subscription{}
	state.msgsByID = map[string]*Message{}

	// The same gServer value backs both services.
	pb.RegisterPublisherServer(grpcSrv.Gsrv, state)
	pb.RegisterSubscriberServer(grpcSrv.Gsrv, state)
	grpcSrv.Start()
	return fake
}
// Publish behaves as if the Publish RPC was called with a message with the given
// data and attrs. It returns the ID of the message.
// The topic will be created if it doesn't exist.
//
// The topic name must match "projects/*/topics/*".
//
// Publish panics if there is an error, which is appropriate for testing.
func (s *Server) Publish(topic string, data []byte, attrs map[string]string) string {
	const topicPattern = "projects/*/topics/*"
	ok, err := path.Match(topicPattern, topic)
	if err != nil {
		panic(err)
	}
	if !ok {
		panic(fmt.Sprintf("topic name must be of the form %q", topicPattern))
	}

	// The RPC methods ignore their context, but a nil Context must never be
	// passed; use a placeholder instead.
	ctx := context.TODO()

	// Create the topic on demand. The only error CreateTopic returns is
	// AlreadyExists, which is expected here and deliberately ignored.
	_, _ = s.gServer.CreateTopic(ctx, &pb.Topic{Name: topic})

	req := &pb.PublishRequest{
		Topic:    topic,
		Messages: []*pb.PubsubMessage{{Data: data, Attributes: attrs}},
	}
	res, err := s.gServer.Publish(ctx, req)
	if err != nil {
		panic(fmt.Sprintf("pstest.Server.Publish: %v", err))
	}
	return res.MessageIds[0]
}
// SetStreamTimeout configures how long a stream stays active before shutting
// itself down, mimicking the real service's habit of closing streams after 30
// minutes. If it is never called, or is called with zero, streams never shut
// down on their own.
func (s *Server) SetStreamTimeout(d time.Duration) {
	state := &s.gServer
	state.mu.Lock()
	state.streamTimeout = d
	state.mu.Unlock()
}
// A Message is a message that was published to the server.
//
// Deliveries and Acks are snapshots: Server.Messages and Server.Message copy
// them from the unexported counters each time they are called.
type Message struct {
ID string
Data []byte
Attributes map[string]string
PublishTime time.Time
Deliveries int // number of times delivery of the message was attempted
Acks int // number of acks received from clients
// protected by server mutex
// The live counters; subscriptions update them through pointers.
deliveries int
acks int
}
// Messages returns every message that has ever been published, in publish
// order, refreshing each message's exported Deliveries and Acks counters
// first. The result is nil if nothing has been published.
func (s *Server) Messages() []*Message {
	state := &s.gServer
	state.mu.Lock()
	defer state.mu.Unlock()

	var out []*Message
	for i := range state.msgs {
		msg := state.msgs[i]
		msg.Deliveries, msg.Acks = msg.deliveries, msg.acks
		out = append(out, msg)
	}
	return out
}
// Message looks up a published message by ID. It returns nil if no message
// with that ID was published; otherwise the message's exported Deliveries and
// Acks counters are refreshed before it is returned.
func (s *Server) Message(id string) *Message {
	state := &s.gServer
	state.mu.Lock()
	defer state.mu.Unlock()

	msg := state.msgsByID[id]
	if msg == nil {
		return nil
	}
	msg.Deliveries, msg.Acks = msg.deliveries, msg.acks
	return msg
}
// Wait blocks until every goroutine tracked by the server's WaitGroup has
// finished, i.e. until all server activity has completed.
func (s *Server) Wait() {
	pending := &s.gServer.wg
	pending.Wait()
}
// CreateTopic registers a new topic under t.Name and returns its proto.
// It returns AlreadyExists if a topic with that name is already registered.
func (s *gServer) CreateTopic(_ context.Context, t *pb.Topic) (*pb.Topic, error) {
	s.mu.Lock()
	defer s.mu.Unlock()

	name := t.Name
	if existing := s.topics[name]; existing != nil {
		return nil, status.Errorf(codes.AlreadyExists, "topic %q", name)
	}
	created := newTopic(t)
	s.topics[name] = created
	return created.proto, nil
}
// GetTopic returns the proto of the topic named req.Topic, or NotFound if no
// such topic exists.
func (s *gServer) GetTopic(_ context.Context, req *pb.GetTopicRequest) (*pb.Topic, error) {
	s.mu.Lock()
	defer s.mu.Unlock()

	found := s.topics[req.Topic]
	if found == nil {
		return nil, status.Errorf(codes.NotFound, "topic %q", req.Topic)
	}
	return found.proto, nil
}
// UpdateTopic is not supported by the fake; it always returns Unimplemented.
func (s *gServer) UpdateTopic(_ context.Context, req *pb.UpdateTopicRequest) (*pb.Topic, error) {
	err := status.Errorf(codes.Unimplemented, "unimplemented")
	return nil, err
}
// ListTopics returns one page of the topics that belong to req.Project, sorted
// by name. Paging is controlled by req.PageSize and req.PageToken; the
// response carries the token for the next page.
//
// An empty req.Project matches every topic.
func (s *gServer) ListTopics(_ context.Context, req *pb.ListTopicsRequest) (*pb.ListTopicsResponse, error) {
	s.mu.Lock()
	defer s.mu.Unlock()

	// Match on a whole path segment: a bare prefix test would let project
	// "projects/p1" also select topics under "projects/p10/...".
	prefix := req.Project
	if prefix != "" && !strings.HasSuffix(prefix, "/") {
		prefix += "/"
	}

	var names []string
	for n := range s.topics {
		if strings.HasPrefix(n, prefix) {
			names = append(names, n)
		}
	}
	// Map iteration order is random; sort so that paging is stable.
	sort.Strings(names)
	from, to, nextToken, err := testutil.PageBounds(int(req.PageSize), req.PageToken, len(names))
	if err != nil {
		return nil, err
	}
	res := &pb.ListTopicsResponse{NextPageToken: nextToken}
	for i := from; i < to; i++ {
		res.Topics = append(res.Topics, s.topics[names[i]].proto)
	}
	return res, nil
}
// ListTopicSubscriptions returns one page of the names of the subscriptions
// whose topic is named req.Topic, sorted by name.
func (s *gServer) ListTopicSubscriptions(_ context.Context, req *pb.ListTopicSubscriptionsRequest) (*pb.ListTopicSubscriptionsResponse, error) {
	s.mu.Lock()
	defer s.mu.Unlock()

	var matched []string
	for subName, sub := range s.subs {
		if sub.topic.proto.Name != req.Topic {
			continue
		}
		matched = append(matched, subName)
	}
	// Sort so that page boundaries are deterministic.
	sort.Strings(matched)

	lo, hi, next, err := testutil.PageBounds(int(req.PageSize), req.PageToken, len(matched))
	if err != nil {
		return nil, err
	}
	res := &pb.ListTopicSubscriptionsResponse{NextPageToken: next}
	res.Subscriptions = matched[lo:hi]
	return res, nil
}
// DeleteTopic stops the topic named req.Topic and removes it from the server.
// It returns NotFound if the topic does not exist.
func (s *gServer) DeleteTopic(_ context.Context, req *pb.DeleteTopicRequest) (*emptypb.Empty, error) {
	s.mu.Lock()
	defer s.mu.Unlock()

	name := req.Topic
	doomed := s.topics[name]
	if doomed == nil {
		return nil, status.Errorf(codes.NotFound, "topic %q", name)
	}
	doomed.stop()
	delete(s.topics, name)
	return new(emptypb.Empty), nil
}
// CreateSubscription validates ps, fills in defaults for its retention
// duration and push config, attaches a new subscription to its topic, and
// starts the subscription's delivery goroutine. The (possibly modified) ps is
// returned.
//
// Errors, in the order they are checked: InvalidArgument for a missing name,
// AlreadyExists for a duplicate name, InvalidArgument for a missing topic,
// NotFound for an unknown topic, then whatever checkAckDeadline and checkMRD
// report.
func (s *gServer) CreateSubscription(_ context.Context, ps *pb.Subscription) (*pb.Subscription, error) {
	s.mu.Lock()
	defer s.mu.Unlock()

	switch {
	case ps.Name == "":
		return nil, status.Errorf(codes.InvalidArgument, "missing name")
	case s.subs[ps.Name] != nil:
		return nil, status.Errorf(codes.AlreadyExists, "subscription %q", ps.Name)
	case ps.Topic == "":
		return nil, status.Errorf(codes.InvalidArgument, "missing topic")
	}
	parent := s.topics[ps.Topic]
	if parent == nil {
		return nil, status.Errorf(codes.NotFound, "topic %q", ps.Topic)
	}

	err := checkAckDeadline(ps.AckDeadlineSeconds)
	if err != nil {
		return nil, err
	}
	// Apply the default retention before validating it.
	if ps.MessageRetentionDuration == nil {
		ps.MessageRetentionDuration = defaultMessageRetentionDuration
	}
	err = checkMRD(ps.MessageRetentionDuration)
	if err != nil {
		return nil, err
	}
	if ps.PushConfig == nil {
		ps.PushConfig = &pb.PushConfig{}
	}

	// The subscription shares the server mutex.
	created := newSubscription(parent, &s.mu, ps)
	parent.subs[ps.Name] = created
	s.subs[ps.Name] = created
	created.start(&s.wg)
	return ps, nil
}
// minAckDeadlineSecs is the smallest ack deadline accepted, in seconds.
// It is a variable so that tests can lower it.
var minAckDeadlineSecs int32 = 10

// checkAckDeadline returns nil if ads lies in [minAckDeadlineSecs, 600], and
// an Unknown error otherwise.
func checkAckDeadline(ads int32) error {
	if ads >= minAckDeadlineSecs && ads <= 600 {
		return nil
	}
	// PubSub service returns Unknown.
	return status.Errorf(codes.Unknown, "bad ack_deadline_seconds: %d", ads)
}
// Bounds on a subscription's message retention duration.
const (
	minMessageRetentionDuration = 10 * time.Minute
	maxMessageRetentionDuration = 168 * time.Hour
)

// defaultMessageRetentionDuration is applied when a subscription is created
// without a retention duration; it equals the maximum.
var defaultMessageRetentionDuration = ptypes.DurationProto(maxMessageRetentionDuration)

// checkMRD returns InvalidArgument if pmrd cannot be converted to a
// time.Duration or falls outside the allowed retention range, and nil
// otherwise.
func checkMRD(pmrd *durpb.Duration) error {
	mrd, err := ptypes.Duration(pmrd)
	switch {
	case err != nil, mrd < minMessageRetentionDuration, mrd > maxMessageRetentionDuration:
		return status.Errorf(codes.InvalidArgument, "bad message_retention_duration %+v", pmrd)
	}
	return nil
}
// GetSubscription returns the proto of the subscription named
// req.Subscription, or NotFound if there is none.
func (s *gServer) GetSubscription(_ context.Context, req *pb.GetSubscriptionRequest) (*pb.Subscription, error) {
	s.mu.Lock()
	defer s.mu.Unlock()

	found := s.subs[req.Subscription]
	if found == nil {
		return nil, status.Errorf(codes.NotFound, "subscription %q", req.Subscription)
	}
	return found.proto, nil
}
// UpdateSubscription applies the fields named in req.UpdateMask from
// req.Subscription to the stored subscription and returns the updated proto.
//
// Supported mask paths are "push_config", "ack_deadline_seconds",
// "retain_acked_messages" and "message_retention_duration"; any other path
// yields InvalidArgument. A missing update mask is treated like an empty one.
// Fields are applied in mask order, so an error on a later path leaves earlier
// paths already applied.
//
// It returns InvalidArgument if req.Subscription is missing and NotFound if
// the named subscription does not exist.
func (s *gServer) UpdateSubscription(_ context.Context, req *pb.UpdateSubscriptionRequest) (*pb.Subscription, error) {
	s.mu.Lock()
	defer s.mu.Unlock()

	// Reject a request without a subscription instead of dereferencing nil.
	update := req.Subscription
	if update == nil {
		return nil, status.Errorf(codes.InvalidArgument, "missing subscription")
	}
	sub := s.subs[update.Name]
	if sub == nil {
		return nil, status.Errorf(codes.NotFound, "subscription %q", update.Name)
	}

	var fields []string
	if req.UpdateMask != nil {
		fields = req.UpdateMask.Paths
	}
	// The loop variable is not called "path" so that it does not shadow the
	// imported path package.
	for _, field := range fields {
		switch field {
		case "push_config":
			sub.proto.PushConfig = update.PushConfig

		case "ack_deadline_seconds":
			a := update.AckDeadlineSeconds
			if err := checkAckDeadline(a); err != nil {
				return nil, err
			}
			sub.proto.AckDeadlineSeconds = a

		case "retain_acked_messages":
			sub.proto.RetainAckedMessages = update.RetainAckedMessages

		case "message_retention_duration":
			if err := checkMRD(update.MessageRetentionDuration); err != nil {
				return nil, err
			}
			sub.proto.MessageRetentionDuration = update.MessageRetentionDuration

		// TODO(jba): labels
		default:
			return nil, status.Errorf(codes.InvalidArgument, "unknown field name %q", field)
		}
	}
	return sub.proto, nil
}
// ListSubscriptions returns one page of the subscriptions that belong to
// req.Project, sorted by name. Paging is controlled by req.PageSize and
// req.PageToken; the response carries the token for the next page.
//
// An empty req.Project matches every subscription.
func (s *gServer) ListSubscriptions(_ context.Context, req *pb.ListSubscriptionsRequest) (*pb.ListSubscriptionsResponse, error) {
	s.mu.Lock()
	defer s.mu.Unlock()

	// Match on a whole path segment: a bare prefix test would let project
	// "projects/p1" also select subscriptions under "projects/p10/...".
	prefix := req.Project
	if prefix != "" && !strings.HasSuffix(prefix, "/") {
		prefix += "/"
	}

	var names []string
	for name := range s.subs {
		if strings.HasPrefix(name, prefix) {
			names = append(names, name)
		}
	}
	// Map iteration order is random; sort so that paging is stable.
	sort.Strings(names)
	from, to, nextToken, err := testutil.PageBounds(int(req.PageSize), req.PageToken, len(names))
	if err != nil {
		return nil, err
	}
	res := &pb.ListSubscriptionsResponse{NextPageToken: nextToken}
	for i := from; i < to; i++ {
		res.Subscriptions = append(res.Subscriptions, s.subs[names[i]].proto)
	}
	return res, nil
}
// DeleteSubscription stops the subscription named req.Subscription and removes
// it from both the server and its topic. It returns NotFound if the
// subscription does not exist.
func (s *gServer) DeleteSubscription(_ context.Context, req *pb.DeleteSubscriptionRequest) (*emptypb.Empty, error) {
	s.mu.Lock()
	defer s.mu.Unlock()

	name := req.Subscription
	doomed := s.subs[name]
	if doomed == nil {
		return nil, status.Errorf(codes.NotFound, "subscription %q", name)
	}
	doomed.stop()
	delete(s.subs, name)
	doomed.topic.deleteSub(doomed)
	return new(emptypb.Empty), nil
}
// Publish assigns each message in req an ID of the form "m<N>" and the current
// time as its publish time, hands it to every subscription of the topic, and
// records it in the server's message log. It returns the assigned IDs in
// request order.
//
// The request's messages are modified in place (MessageId and PublishTime are
// set). It returns InvalidArgument if req.Topic is empty, NotFound if the
// topic does not exist, and Internal if the publish time cannot be converted
// to a timestamp proto. In the last case messages earlier in the request have
// already been published.
func (s *gServer) Publish(_ context.Context, req *pb.PublishRequest) (*pb.PublishResponse, error) {
	s.mu.Lock()
	defer s.mu.Unlock()

	if req.Topic == "" {
		return nil, status.Errorf(codes.InvalidArgument, "missing topic")
	}
	top := s.topics[req.Topic]
	if top == nil {
		return nil, status.Errorf(codes.NotFound, "topic %q", req.Topic)
	}
	var ids []string
	for _, pm := range req.Messages {
		id := fmt.Sprintf("m%d", s.nextID)
		s.nextID++
		pm.MessageId = id
		pubTime := timeNow()
		tsPubTime, err := ptypes.TimestampProto(pubTime)
		if err != nil {
			// Use Error, not Errorf: the message is not a format string and
			// may itself contain '%'.
			return nil, status.Error(codes.Internal, err.Error())
		}
		pm.PublishTime = tsPubTime
		m := &Message{
			ID:          id,
			Data:        pm.Data,
			Attributes:  pm.Attributes,
			PublishTime: pubTime,
		}
		top.publish(pm, m)
		ids = append(ids, id)
		s.msgs = append(s.msgs, m)
		s.msgsByID[id] = m
	}
	return &pb.PublishResponse{MessageIds: ids}, nil
}
// topic is the server-side state of one topic.
type topic struct {
proto *pb.Topic // the proto the topic was created from
subs map[string]*subscription // subscriptions attached to this topic, by name
}
// newTopic wraps pt in a topic that has no subscriptions yet.
func newTopic(pt *pb.Topic) *topic {
	t := new(topic)
	t.proto = pt
	t.subs = make(map[string]*subscription)
	return t
}
// stop marks every subscription of the topic as belonging to a deleted topic
// and stops it.
func (t *topic) stop() {
	const deletedTopic = "_deleted-topic_"
	for name := range t.subs {
		sub := t.subs[name]
		sub.proto.Topic = deletedTopic
		sub.stop()
	}
}
// deleteSub detaches sub from the topic; it is a no-op if sub is not attached.
func (t *topic) deleteSub(sub *subscription) {
	name := sub.proto.Name
	delete(t.subs, name)
}
// publish adds pm to the unacked messages of every subscription of the topic.
// The message ID doubles as the ack ID. Each subscription gets its own message
// record, but all of them point at m's delivery and ack counters.
func (t *topic) publish(pm *pb.PubsubMessage, m *Message) {
	for _, sub := range t.subs {
		id := pm.MessageId
		received := &pb.ReceivedMessage{AckId: id, Message: pm}
		pending := &message{
			proto:       received,
			publishTime: m.PublishTime,
			deliveries:  &m.deliveries,
			acks:        &m.acks,
			streamIndex: -1, // not yet delivered to any stream
		}
		sub.msgs[id] = pending
	}
}
// subscription is the server-side state of one subscription.
type subscription struct {
topic *topic // the topic this subscription is attached to
mu *sync.Mutex // the server's mutex; CreateSubscription passes &gServer.mu
proto *pb.Subscription
ackTimeout time.Duration // initial ack timeout given to new streams
msgs map[string]*message // unacked messages by message ID
streams []*stream // streams currently pulling from this subscription
done chan struct{} // closed by stop to end the delivery goroutine
}
// newSubscription builds a subscription on topic t that is guarded by mu and
// described by ps. Its ack timeout is ps.AckDeadlineSeconds, or 10 seconds if
// that is zero. The delivery goroutine is not started here.
func newSubscription(t *topic, mu *sync.Mutex, ps *pb.Subscription) *subscription {
	sub := &subscription{
		topic:      t,
		mu:         mu,
		proto:      ps,
		ackTimeout: 10 * time.Second,
		msgs:       map[string]*message{},
		done:       make(chan struct{}),
	}
	if d := time.Duration(ps.AckDeadlineSeconds) * time.Second; d != 0 {
		sub.ackTimeout = d
	}
	return sub
}
// start launches the subscription's delivery goroutine, which calls deliver
// roughly every 10 milliseconds until the done channel is closed. The
// goroutine is registered with wg.
func (s *subscription) start(wg *sync.WaitGroup) {
	const pollInterval = 10 * time.Millisecond

	wg.Add(1)
	go func() {
		defer wg.Done()
		for {
			select {
			case <-time.After(pollInterval):
				s.deliver()
			case <-s.done:
				return
			}
		}
	}()
}
// stop signals the subscription's delivery goroutine to exit by closing the
// done channel.
//
// stop is idempotent. Deleting a topic stops each of its subscriptions but
// leaves them registered with the server, so a later DeleteSubscription stops
// the same subscription again; closing the channel unconditionally would panic
// on that second call. Callers hold the server mutex, so the check and the
// close do not race with each other.
func (s *subscription) stop() {
	select {
	case <-s.done:
		// Already stopped.
	default:
		close(s.done)
	}
}
// Acknowledge acks each ID in req.AckIds on the subscription named
// req.Subscription. IDs that are unknown or already acked are ignored.
//
// It returns InvalidArgument if req.Subscription is empty and NotFound if the
// subscription does not exist.
func (s *gServer) Acknowledge(_ context.Context, req *pb.AcknowledgeRequest) (*emptypb.Empty, error) {
	s.mu.Lock()
	defer s.mu.Unlock()

	if req.Subscription == "" {
		return nil, status.Errorf(codes.InvalidArgument, "missing subscription")
	}
	sub := s.subs[req.Subscription]
	// Without this check an unknown subscription name would dereference a nil
	// *subscription below.
	if sub == nil {
		return nil, status.Errorf(codes.NotFound, "subscription %q", req.Subscription)
	}
	for _, id := range req.AckIds {
		sub.ack(id)
	}
	return &emptypb.Empty{}, nil
}
// ModifyAckDeadline sets the ack deadline of each message in req.AckIds on the
// subscription named req.Subscription to req.AckDeadlineSeconds from now. A
// deadline of zero makes the message available for redelivery. IDs that are
// unknown or already acked are ignored.
//
// It returns InvalidArgument if req.Subscription is empty and NotFound if the
// subscription does not exist.
func (s *gServer) ModifyAckDeadline(_ context.Context, req *pb.ModifyAckDeadlineRequest) (*emptypb.Empty, error) {
	s.mu.Lock()
	defer s.mu.Unlock()

	if req.Subscription == "" {
		return nil, status.Errorf(codes.InvalidArgument, "missing subscription")
	}
	sub := s.subs[req.Subscription]
	// Without this check an unknown subscription name would dereference a nil
	// *subscription below.
	if sub == nil {
		return nil, status.Errorf(codes.NotFound, "subscription %q", req.Subscription)
	}
	dur := secsToDur(req.AckDeadlineSeconds)
	for _, id := range req.AckIds {
		sub.modifyAckDeadline(id, dur)
	}
	return &emptypb.Empty{}, nil
}
// StreamingPull serves one streaming-pull RPC. The first request on the stream
// must name the subscription; a stream is then attached to that subscription
// and served until it ends, after which it is detached again.
//
// It returns InvalidArgument if the first request names no subscription,
// NotFound if the subscription does not exist, and otherwise whatever error
// ended the stream (nil for a clean end or a stream timeout).
func (s *gServer) StreamingPull(sps pb.Subscriber_StreamingPullServer) error {
	// Receive initial message configuring the pull.
	req, err := sps.Recv()
	if err != nil {
		return err
	}
	if req.Subscription == "" {
		return status.Errorf(codes.InvalidArgument, "missing subscription")
	}
	// Read the stream timeout under the same lock SetStreamTimeout writes it
	// under, so the two do not race.
	s.mu.Lock()
	sub := s.subs[req.Subscription]
	timeout := s.streamTimeout
	s.mu.Unlock()
	if sub == nil {
		return status.Errorf(codes.NotFound, "subscription %s", req.Subscription)
	}
	// Create a new stream to handle the pull.
	st := sub.newStream(sps, timeout)
	err = st.pull(&s.wg)
	sub.deleteStream(st)
	return err
}
// retentionDuration is how long after its publish time a message that is not
// outstanding may remain in a subscription before deliver discards it. deliver
// consults this package variable, not the subscription proto's
// MessageRetentionDuration field.
var retentionDuration = 10 * time.Minute
// deliver performs one delivery pass over the subscription's unacked messages
// while holding the server mutex. It first expires ack deadlines and drops
// messages older than retentionDuration, then tries to hand every message that
// is not outstanding to one of the subscription's streams.
//
// It panics if a message's publish timestamp cannot be converted to a
// time.Time.
func (s *subscription) deliver() {
s.mu.Lock()
defer s.mu.Unlock()
tNow := timeNow()
for id, m := range s.msgs {
// Mark a message as re-deliverable if its ack deadline has expired.
if m.outstanding() && tNow.After(m.ackDeadline) {
m.makeAvailable()
}
pubTime, err := ptypes.Timestamp(m.proto.Message.PublishTime)
if err != nil {
panic(err)
}
// Remove messages that have been undelivered for a long time.
if !m.outstanding() && tNow.Sub(pubTime) > retentionDuration {
delete(s.msgs, id)
}
}
// Try to deliver each remaining message.
// curIndex is the stream at which the next never-delivered message starts.
curIndex := 0
for _, m := range s.msgs {
if m.outstanding() {
continue
}
// If the message was never delivered before, start with the stream at
// curIndex. If it was delivered before, start with the stream after the one
// that owned it.
if m.streamIndex < 0 {
delIndex, ok := s.deliverMessage(m, curIndex, tNow)
if !ok {
// deliverMessage reports false only when the subscription has no
// streams left, so there is no point in trying further messages.
break
}
// Note: the index recorded here is one past the stream that took the
// message, and it may equal len(s.streams); deliverMessage wraps it.
curIndex = delIndex + 1
m.streamIndex = curIndex
} else {
delIndex, ok := s.deliverMessage(m, m.streamIndex, tNow)
if !ok {
break
}
// Here the recorded index is that of the stream that took the message.
m.streamIndex = delIndex
}
}
}
// deliverMessage tries to hand m to the stream at index i, moving on to the
// following streams (wrapping to index 0) as needed. A stream found to be done
// is removed from the subscription along the way. On success it bumps the
// message's delivery counter, sets its ack deadline to tNow plus the stream's
// ack timeout, and returns the index of the receiving stream and true. If the
// subscription has no streams left it returns 0, false.
//
// The send blocks until the chosen stream either accepts the message or is
// done.
func (s *subscription) deliverMessage(m *message, i int, tNow time.Time) (int, bool) {
	for {
		n := len(s.streams)
		if n == 0 {
			return 0, false
		}
		if i >= n {
			i = 0
		}
		target := s.streams[i]
		select {
		case target.msgc <- m.proto:
			(*m.deliveries)++
			m.ackDeadline = tNow.Add(target.ackTimeout)
			return i, true
		case <-target.done:
			// Dropping the dead stream shifts its successor into slot i, so i
			// is left unchanged for the next attempt.
			s.streams = deleteStreamAt(s.streams, i)
		}
	}
}
// newStream creates a stream that serves gs with the given timeout and the
// subscription's current ack timeout, appends it to the subscription's stream
// list under the mutex, and returns it.
func (s *subscription) newStream(gs pb.Subscriber_StreamingPullServer, timeout time.Duration) *stream {
	st := new(stream)
	st.sub = s
	st.gstream = gs
	st.timeout = timeout
	st.ackTimeout = s.ackTimeout
	st.done = make(chan struct{})
	st.msgc = make(chan *pb.ReceivedMessage) // unbuffered

	s.mu.Lock()
	s.streams = append(s.streams, st)
	s.mu.Unlock()
	return st
}
// deleteStream removes st from the subscription's stream list, if it is
// present, while holding the mutex.
func (s *subscription) deleteStream(st *stream) {
	s.mu.Lock()
	defer s.mu.Unlock()

	for idx, candidate := range s.streams {
		if candidate == st {
			s.streams = deleteStreamAt(s.streams, idx)
			return
		}
	}
}
// deleteStreamAt removes the element at index i from s in place and returns
// the shortened slice. The remaining elements keep their relative order, which
// round-robin delivery relies on.
func deleteStreamAt(s []*stream, i int) []*stream {
	copy(s[i:], s[i+1:])
	return s[:len(s)-1]
}
// message is one subscription's record of a published message that has not
// been acked yet.
type message struct {
proto *pb.ReceivedMessage
publishTime time.Time
ackDeadline time.Time // zero while the message is not outstanding
deliveries *int // points at the delivery counter of the server's Message
acks *int // points at the ack counter of the server's Message
streamIndex int // index of stream that currently owns msg, for round-robin delivery
}
// outstanding reports whether the message is currently owned by some stream,
// which is the case exactly when its ack deadline is set (non-zero).
func (m *message) outstanding() bool {
	if m.ackDeadline.IsZero() {
		return false
	}
	return true
}
// makeAvailable clears the message's ack deadline so that it is no longer
// outstanding and can be delivered again.
func (m *message) makeAvailable() {
	var unset time.Time
	m.ackDeadline = unset
}
// stream is the server-side state of one StreamingPull RPC.
type stream struct {
sub *subscription // the subscription being pulled from
done chan struct{} // closed when the stream is finished
msgc chan *pb.ReceivedMessage // carries messages from the subscription to sendLoop
gstream pb.Subscriber_StreamingPullServer // the underlying gRPC stream
ackTimeout time.Duration // ack deadline given to messages delivered on this stream
timeout time.Duration // how long pull runs before ending the stream; zero means no limit
}
// pull manages the StreamingPull interaction for the life of the stream. It
// starts the send and receive loops (both registered with wg) and blocks until
// one of them returns or, if the stream has a positive timeout, until that
// timeout expires. It then closes the done channel and returns.
//
// The result is the first loop's error, with io.EOF mapped to nil; it is nil
// when the stream timed out.
func (st *stream) pull(wg *sync.WaitGroup) error {
	// Buffered so that the loop that finishes second never blocks on its send.
	errc := make(chan error, 2)
	wg.Add(2)
	go func() {
		defer wg.Done()
		errc <- st.sendLoop()
	}()
	go func() {
		defer wg.Done()
		errc <- st.recvLoop()
	}()
	// A nil channel blocks forever, which is the "no timeout" case.
	var tchan <-chan time.Time
	if st.timeout > 0 {
		// Use an explicit timer so that it can be released as soon as the
		// stream ends rather than lingering for the full timeout.
		timer := time.NewTimer(st.timeout)
		defer timer.Stop()
		tchan = timer.C
	}
	// Wait until one of the goroutines returns an error, or we time out.
	var err error
	select {
	case err = <-errc:
		if err == io.EOF {
			err = nil
		}
	case <-tchan:
	}
	close(st.done) // stop the other goroutine
	return err
}
// sendLoop forwards each message arriving on msgc to the client, one message
// per response. It returns nil once the done channel is closed, or the first
// error from Send.
func (st *stream) sendLoop() error {
	for {
		var rm *pb.ReceivedMessage
		select {
		case rm = <-st.msgc:
		case <-st.done:
			return nil
		}
		res := &pb.StreamingPullResponse{
			ReceivedMessages: []*pb.ReceivedMessage{rm},
		}
		err := st.gstream.Send(res)
		if err != nil {
			return err
		}
	}
}
// recvLoop reads requests from the client and passes each one to the
// subscription until Recv fails; that error is returned.
func (st *stream) recvLoop() error {
	for {
		req, err := st.gstream.Recv()
		if err == nil {
			st.sub.handleStreamingPullRequest(st, req)
			continue
		}
		return err
	}
}
// handleStreamingPullRequest applies one follow-up request received on stream
// st: it acks req.AckIds, applies each deadline in req.ModifyDeadlineSeconds to
// the ack ID at the same position in req.ModifyDeadlineAckIds, and, if
// req.StreamAckDeadlineSeconds is positive, updates the stream's ack timeout.
//
// Ack IDs without a matching deadline are ignored.
func (s *subscription) handleStreamingPullRequest(st *stream, req *pb.StreamingPullRequest) {
	// Lock the entire server.
	s.mu.Lock()
	defer s.mu.Unlock()

	for _, ackID := range req.AckIds {
		s.ack(ackID)
	}
	deadlines := req.ModifyDeadlineSeconds
	for i, id := range req.ModifyDeadlineAckIds {
		if i >= len(deadlines) {
			// Malformed request: more ack IDs than deadlines. Stop here rather
			// than index out of range and bring the fake server down.
			break
		}
		s.modifyAckDeadline(id, secsToDur(deadlines[i]))
	}
	if req.StreamAckDeadlineSeconds > 0 {
		st.ackTimeout = secsToDur(req.StreamAckDeadlineSeconds)
	}
}
// ack records an acknowledgement of the message with the given ID and drops it
// from the subscription's unacked messages. Unknown IDs are ignored.
func (s *subscription) ack(id string) {
	pending := s.msgs[id]
	if pending == nil {
		return
	}
	(*pending.acks)++
	delete(s.msgs, id)
}
// modifyAckDeadline changes the ack deadline of the message with the given ID.
// A zero d is a nack: the message becomes available for redelivery. Any other
// d moves the deadline to d from now. IDs not among the unacked messages
// (for example, already acked) are ignored.
func (s *subscription) modifyAckDeadline(id string, d time.Duration) {
	pending := s.msgs[id]
	switch {
	case pending == nil:
		// Already acked: ignore.
	case d == 0:
		pending.makeAvailable()
	default:
		pending.ackDeadline = timeNow().Add(d)
	}
}
// secsToDur converts a whole number of seconds to a time.Duration.
func secsToDur(secs int32) time.Duration {
	return time.Second * time.Duration(secs)
}