pubsub: Receive does not retry ResourceExhausted errors

The reasoning here is that StreamingPull should be a long-lived RPC. If a user
is seeing ResourceExhausted from StreamingPull, it's likely to be something
that they want to act on (increase quota, avoid over-deploying, etc.).

As part of this CL, retrying in pubsub gets re-wired to use the gax.Retryer
pattern. This allows us to pass gax call options to pullStream.call, which
means that for the StreamingPull call we can pass a custom retryer that does
not retry ResourceExhausted.

Also as part of this CL, pstest's gServer is exported as GServer. This makes
it easier to create fakes with slightly different semantics, which is a bit
more scalable long-term than exposing knobs for everything testers might want
to do.

Fixes #1166

Change-Id: I2073607bcae410e49b0d139859d9b7c48e065a3c
Reviewed-on: https://code-review.googlesource.com/c/36050
Reviewed-by: kokoro <noreply+kokoro@google.com>
Reviewed-by: Eno Compton <enocom@google.com>
diff --git a/internal/kokoro/vet.sh b/internal/kokoro/vet.sh
index f2381e0..30d34c1 100755
--- a/internal/kokoro/vet.sh
+++ b/internal/kokoro/vet.sh
@@ -41,6 +41,7 @@
     grep -vE "exported const AllUsers|AllAuthenticatedUsers|RoleOwner|SSD|HDD|PRODUCTION|DEVELOPMENT should have comment" | \
     grep -v "exported func Value returns unexported type pretty.val, which can be annoying to use" | \
     grep -v "ExecuteStreamingSql" | \
+    grep -vE "pubsub\/pstest\/fake\.go.+should have comment or be unexported" | \
     grep -v "ClusterId" | \
     grep -v "InstanceId" | \
     grep -v "firestore.arrayUnion" | \
diff --git a/pubsub/pstest/fake.go b/pubsub/pstest/fake.go
index 448c4c9..7cc1177 100644
--- a/pubsub/pstest/fake.go
+++ b/pubsub/pstest/fake.go
@@ -60,11 +60,13 @@
 // Server is a fake Pub/Sub server.
 type Server struct {
 	srv     *testutil.Server
-	Addr    string // The address that the server is listening on.
-	gServer gServer
+	Addr    string  // The address that the server is listening on.
+	GServer GServer // Not intended to be used directly.
 }
 
-type gServer struct {
+// GServer is the underlying service implementor. It is not intended to be used
+// directly.
+type GServer struct {
 	pb.PublisherServer
 	pb.SubscriberServer
 
@@ -87,14 +89,14 @@
 	s := &Server{
 		srv:  srv,
 		Addr: srv.Addr,
-		gServer: gServer{
+		GServer: GServer{
 			topics:   map[string]*topic{},
 			subs:     map[string]*subscription{},
 			msgsByID: map[string]*Message{},
 		},
 	}
-	pb.RegisterPublisherServer(srv.Gsrv, &s.gServer)
-	pb.RegisterSubscriberServer(srv.Gsrv, &s.gServer)
+	pb.RegisterPublisherServer(srv.Gsrv, &s.GServer)
+	pb.RegisterSubscriberServer(srv.Gsrv, &s.GServer)
 	srv.Start()
 	return s
 }
@@ -113,12 +115,12 @@
 	if !ok {
 		panic(fmt.Sprintf("topic name must be of the form %q", topicPattern))
 	}
-	_, _ = s.gServer.CreateTopic(context.TODO(), &pb.Topic{Name: topic})
+	_, _ = s.GServer.CreateTopic(context.TODO(), &pb.Topic{Name: topic})
 	req := &pb.PublishRequest{
 		Topic:    topic,
 		Messages: []*pb.PubsubMessage{{Data: data, Attributes: attrs}},
 	}
-	res, err := s.gServer.Publish(context.TODO(), req)
+	res, err := s.GServer.Publish(context.TODO(), req)
 	if err != nil {
 		panic(fmt.Sprintf("pstest.Server.Publish: %v", err))
 	}
@@ -130,9 +132,9 @@
 // minutes. If SetStreamTimeout is never called or is passed zero, streams never shut
 // down.
 func (s *Server) SetStreamTimeout(d time.Duration) {
-	s.gServer.mu.Lock()
-	defer s.gServer.mu.Unlock()
-	s.gServer.streamTimeout = d
+	s.GServer.mu.Lock()
+	defer s.GServer.mu.Unlock()
+	s.GServer.streamTimeout = d
 }
 
 // A Message is a message that was published to the server.
@@ -160,11 +162,11 @@
 
 // Messages returns information about all messages ever published.
 func (s *Server) Messages() []*Message {
-	s.gServer.mu.Lock()
-	defer s.gServer.mu.Unlock()
+	s.GServer.mu.Lock()
+	defer s.GServer.mu.Unlock()
 
 	var msgs []*Message
-	for _, m := range s.gServer.msgs {
+	for _, m := range s.GServer.msgs {
 		m.Deliveries = m.deliveries
 		m.Acks = m.acks
 		msgs = append(msgs, m)
@@ -175,10 +177,10 @@
 // Message returns the message with the given ID, or nil if no message
 // with that ID was published.
 func (s *Server) Message(id string) *Message {
-	s.gServer.mu.Lock()
-	defer s.gServer.mu.Unlock()
+	s.GServer.mu.Lock()
+	defer s.GServer.mu.Unlock()
 
-	m := s.gServer.msgsByID[id]
+	m := s.GServer.msgsByID[id]
 	if m != nil {
 		m.Deliveries = m.deliveries
 		m.Acks = m.acks
@@ -188,21 +190,21 @@
 
 // Wait blocks until all server activity has completed.
 func (s *Server) Wait() {
-	s.gServer.wg.Wait()
+	s.GServer.wg.Wait()
 }
 
 // Close shuts down the server and releases all resources.
 func (s *Server) Close() error {
 	s.srv.Close()
-	s.gServer.mu.Lock()
-	defer s.gServer.mu.Unlock()
-	for _, sub := range s.gServer.subs {
+	s.GServer.mu.Lock()
+	defer s.GServer.mu.Unlock()
+	for _, sub := range s.GServer.subs {
 		sub.stop()
 	}
 	return nil
 }
 
-func (s *gServer) CreateTopic(_ context.Context, t *pb.Topic) (*pb.Topic, error) {
+func (s *GServer) CreateTopic(_ context.Context, t *pb.Topic) (*pb.Topic, error) {
 	s.mu.Lock()
 	defer s.mu.Unlock()
 
@@ -214,7 +216,7 @@
 	return top.proto, nil
 }
 
-func (s *gServer) GetTopic(_ context.Context, req *pb.GetTopicRequest) (*pb.Topic, error) {
+func (s *GServer) GetTopic(_ context.Context, req *pb.GetTopicRequest) (*pb.Topic, error) {
 	s.mu.Lock()
 	defer s.mu.Unlock()
 
@@ -224,7 +226,7 @@
 	return nil, status.Errorf(codes.NotFound, "topic %q", req.Topic)
 }
 
-func (s *gServer) UpdateTopic(_ context.Context, req *pb.UpdateTopicRequest) (*pb.Topic, error) {
+func (s *GServer) UpdateTopic(_ context.Context, req *pb.UpdateTopicRequest) (*pb.Topic, error) {
 	s.mu.Lock()
 	defer s.mu.Unlock()
 
@@ -245,7 +247,7 @@
 	return t.proto, nil
 }
 
-func (s *gServer) ListTopics(_ context.Context, req *pb.ListTopicsRequest) (*pb.ListTopicsResponse, error) {
+func (s *GServer) ListTopics(_ context.Context, req *pb.ListTopicsRequest) (*pb.ListTopicsResponse, error) {
 	s.mu.Lock()
 	defer s.mu.Unlock()
 
@@ -267,7 +269,7 @@
 	return res, nil
 }
 
-func (s *gServer) ListTopicSubscriptions(_ context.Context, req *pb.ListTopicSubscriptionsRequest) (*pb.ListTopicSubscriptionsResponse, error) {
+func (s *GServer) ListTopicSubscriptions(_ context.Context, req *pb.ListTopicSubscriptionsRequest) (*pb.ListTopicSubscriptionsResponse, error) {
 	s.mu.Lock()
 	defer s.mu.Unlock()
 
@@ -288,7 +290,7 @@
 	}, nil
 }
 
-func (s *gServer) DeleteTopic(_ context.Context, req *pb.DeleteTopicRequest) (*emptypb.Empty, error) {
+func (s *GServer) DeleteTopic(_ context.Context, req *pb.DeleteTopicRequest) (*emptypb.Empty, error) {
 	s.mu.Lock()
 	defer s.mu.Unlock()
 
@@ -301,7 +303,7 @@
 	return &emptypb.Empty{}, nil
 }
 
-func (s *gServer) CreateSubscription(_ context.Context, ps *pb.Subscription) (*pb.Subscription, error) {
+func (s *GServer) CreateSubscription(_ context.Context, ps *pb.Subscription) (*pb.Subscription, error) {
 	s.mu.Lock()
 	defer s.mu.Unlock()
 
@@ -382,7 +384,7 @@
 	return nil
 }
 
-func (s *gServer) GetSubscription(_ context.Context, req *pb.GetSubscriptionRequest) (*pb.Subscription, error) {
+func (s *GServer) GetSubscription(_ context.Context, req *pb.GetSubscriptionRequest) (*pb.Subscription, error) {
 	s.mu.Lock()
 	defer s.mu.Unlock()
 	sub, err := s.findSubscription(req.Subscription)
@@ -392,7 +394,7 @@
 	return sub.proto, nil
 }
 
-func (s *gServer) UpdateSubscription(_ context.Context, req *pb.UpdateSubscriptionRequest) (*pb.Subscription, error) {
+func (s *GServer) UpdateSubscription(_ context.Context, req *pb.UpdateSubscriptionRequest) (*pb.Subscription, error) {
 	if req.Subscription == nil {
 		return nil, status.Errorf(codes.InvalidArgument, "missing subscription")
 	}
@@ -433,7 +435,7 @@
 	return sub.proto, nil
 }
 
-func (s *gServer) ListSubscriptions(_ context.Context, req *pb.ListSubscriptionsRequest) (*pb.ListSubscriptionsResponse, error) {
+func (s *GServer) ListSubscriptions(_ context.Context, req *pb.ListSubscriptionsRequest) (*pb.ListSubscriptionsResponse, error) {
 	s.mu.Lock()
 	defer s.mu.Unlock()
 
@@ -455,7 +457,7 @@
 	return res, nil
 }
 
-func (s *gServer) DeleteSubscription(_ context.Context, req *pb.DeleteSubscriptionRequest) (*emptypb.Empty, error) {
+func (s *GServer) DeleteSubscription(_ context.Context, req *pb.DeleteSubscriptionRequest) (*emptypb.Empty, error) {
 	s.mu.Lock()
 	defer s.mu.Unlock()
 	sub, err := s.findSubscription(req.Subscription)
@@ -468,7 +470,7 @@
 	return &emptypb.Empty{}, nil
 }
 
-func (s *gServer) Publish(_ context.Context, req *pb.PublishRequest) (*pb.PublishResponse, error) {
+func (s *GServer) Publish(_ context.Context, req *pb.PublishRequest) (*pb.PublishResponse, error) {
 	s.mu.Lock()
 	defer s.mu.Unlock()
 
@@ -586,7 +588,7 @@
 	close(s.done)
 }
 
-func (s *gServer) Acknowledge(_ context.Context, req *pb.AcknowledgeRequest) (*emptypb.Empty, error) {
+func (s *GServer) Acknowledge(_ context.Context, req *pb.AcknowledgeRequest) (*emptypb.Empty, error) {
 	s.mu.Lock()
 	defer s.mu.Unlock()
 
@@ -600,7 +602,7 @@
 	return &emptypb.Empty{}, nil
 }
 
-func (s *gServer) ModifyAckDeadline(_ context.Context, req *pb.ModifyAckDeadlineRequest) (*emptypb.Empty, error) {
+func (s *GServer) ModifyAckDeadline(_ context.Context, req *pb.ModifyAckDeadlineRequest) (*emptypb.Empty, error) {
 	s.mu.Lock()
 	defer s.mu.Unlock()
 	sub, err := s.findSubscription(req.Subscription)
@@ -618,7 +620,7 @@
 	return &emptypb.Empty{}, nil
 }
 
-func (s *gServer) Pull(ctx context.Context, req *pb.PullRequest) (*pb.PullResponse, error) {
+func (s *GServer) Pull(ctx context.Context, req *pb.PullRequest) (*pb.PullResponse, error) {
 	s.mu.Lock()
 	sub, err := s.findSubscription(req.Subscription)
 	if err != nil {
@@ -655,7 +657,7 @@
 	return &pb.PullResponse{ReceivedMessages: msgs}, nil
 }
 
-func (s *gServer) StreamingPull(sps pb.Subscriber_StreamingPullServer) error {
+func (s *GServer) StreamingPull(sps pb.Subscriber_StreamingPullServer) error {
 	// Receive initial message configuring the pull.
 	req, err := sps.Recv()
 	if err != nil {
@@ -674,7 +676,7 @@
 	return err
 }
 
-func (s *gServer) Seek(ctx context.Context, req *pb.SeekRequest) (*pb.SeekResponse, error) {
+func (s *GServer) Seek(ctx context.Context, req *pb.SeekRequest) (*pb.SeekResponse, error) {
 	// Only handle time-based seeking for now.
 	// This fake doesn't deal with snapshots.
 	var target time.Time
@@ -729,7 +731,7 @@
 
 // Gets a subscription that must exist.
 // Must be called with the lock held.
-func (s *gServer) findSubscription(name string) (*subscription, error) {
+func (s *GServer) findSubscription(name string) (*subscription, error) {
 	if name == "" {
 		return nil, status.Errorf(codes.InvalidArgument, "missing subscription")
 	}
diff --git a/pubsub/pstest/fake_test.go b/pubsub/pstest/fake_test.go
index b1c99bc..66c014a 100644
--- a/pubsub/pstest/fake_test.go
+++ b/pubsub/pstest/fake_test.go
@@ -42,7 +42,7 @@
 			Labels: map[string]string{"num": fmt.Sprintf("%d", i)},
 		}))
 	}
-	if got, want := len(server.gServer.topics), len(topics); got != want {
+	if got, want := len(server.GServer.topics), len(topics); got != want {
 		t.Fatalf("got %d topics, want %d", got, want)
 	}
 	for _, top := range topics {
@@ -68,7 +68,7 @@
 			t.Fatal(err)
 		}
 	}
-	if got, want := len(server.gServer.topics), 0; got != want {
+	if got, want := len(server.GServer.topics), 0; got != want {
 		t.Fatalf("got %d topics, want %d", got, want)
 	}
 }
@@ -88,7 +88,7 @@
 		}))
 	}
 
-	if got, want := len(server.gServer.subs), len(subs); got != want {
+	if got, want := len(server.GServer.subs), len(subs); got != want {
 		t.Fatalf("got %d subscriptions, want %d", got, want)
 	}
 	for _, s := range subs {
@@ -128,7 +128,7 @@
 			t.Fatal(err)
 		}
 	}
-	if got, want := len(server.gServer.subs), 0; got != want {
+	if got, want := len(server.GServer.subs), 0; got != want {
 		t.Fatalf("got %d subscriptions, want %d", got, want)
 	}
 }
diff --git a/pubsub/pullstream.go b/pubsub/pullstream.go
index 5a50ff9..2fdb538 100644
--- a/pubsub/pullstream.go
+++ b/pubsub/pullstream.go
@@ -95,13 +95,14 @@
 }
 
 func (s *pullStream) openWithRetry() (pb.Subscriber_StreamingPullClient, error) {
-	var bo gax.Backoff
+	r := defaultRetryer{}
 	for {
 		recordStat(s.ctx, StreamOpenCount, 1)
 		spc, err := s.open()
-		if err != nil && isRetryable(err) {
+		bo, shouldRetry := r.Retry(err)
+		if err != nil && shouldRetry {
 			recordStat(s.ctx, StreamRetryCount, 1)
-			if err := gax.Sleep(s.ctx, bo.Pause()); err != nil {
+			if err := gax.Sleep(s.ctx, bo); err != nil {
 				return nil, err
 			}
 			continue
@@ -110,11 +111,19 @@
 	}
 }
 
-func (s *pullStream) call(f func(pb.Subscriber_StreamingPullClient) error) error {
+func (s *pullStream) call(f func(pb.Subscriber_StreamingPullClient) error, opts ...gax.CallOption) error {
+	var settings gax.CallSettings
+	for _, opt := range opts {
+		opt.Resolve(&settings)
+	}
+	var r gax.Retryer = &defaultRetryer{}
+	if settings.Retry != nil {
+		r = settings.Retry()
+	}
+
 	var (
 		spc *pb.Subscriber_StreamingPullClient
 		err error
-		bo  gax.Backoff
 	)
 	for {
 		spc, err = s.get(spc)
@@ -124,10 +133,11 @@
 		start := time.Now()
 		err = f(*spc)
 		if err != nil {
-			if isRetryable(err) {
+			bo, shouldRetry := r.Retry(err)
+			if shouldRetry {
 				recordStat(s.ctx, StreamRetryCount, 1)
 				if time.Since(start) < 30*time.Second { // don't sleep if we've been blocked for a while
-					if err := gax.Sleep(s.ctx, bo.Pause()); err != nil {
+					if err := gax.Sleep(s.ctx, bo); err != nil {
 						return err
 					}
 				}
@@ -167,7 +177,7 @@
 			recordStat(s.ctx, PullCount, int64(len(res.ReceivedMessages)))
 		}
 		return err
-	})
+	}, gax.WithRetry(func() gax.Retryer { return &streamingPullRetryer{defaultRetryer: &defaultRetryer{}} }))
 	return res, err
 }
 
diff --git a/pubsub/pullstream_test.go b/pubsub/pullstream_test.go
index 0594808..108344d 100644
--- a/pubsub/pullstream_test.go
+++ b/pubsub/pullstream_test.go
@@ -17,9 +17,14 @@
 import (
 	"context"
 	"testing"
+	"time"
 
+	"cloud.google.com/go/internal/testutil"
+	"cloud.google.com/go/pubsub/pstest"
 	gax "github.com/googleapis/gax-go"
+	"google.golang.org/api/option"
 	pb "google.golang.org/genproto/googleapis/pubsub/v1"
+	"google.golang.org/grpc"
 	"google.golang.org/grpc/codes"
 	"google.golang.org/grpc/status"
 )
@@ -70,6 +75,61 @@
 	}
 }
 
+func TestPullStreamGet_ResourceUnavailable(t *testing.T) {
+	ctx := context.Background()
+
+	srv, err := testutil.NewServer()
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer srv.Close()
+
+	ps := pstest.NewServer()
+	defer ps.Close()
+
+	s := ExhaustedServer{ps.GServer}
+	pb.RegisterPublisherServer(srv.Gsrv, &s)
+	pb.RegisterSubscriberServer(srv.Gsrv, &s)
+	srv.Start()
+
+	client, err := NewClient(ctx, "P",
+		option.WithEndpoint(srv.Addr),
+		option.WithoutAuthentication(),
+		option.WithGRPCDialOption(grpc.WithInsecure()))
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer client.Close()
+
+	errc := make(chan error)
+	go func() {
+		errc <- client.Subscription("foo").Receive(ctx, func(context.Context, *Message) {
+			t.Error("should not have received any data")
+		})
+	}()
+
+	select {
+	case <-time.After(5 * time.Second):
+		t.Fatal("Receive should have failed immediately")
+	case err := <-errc:
+		if gerr, ok := status.FromError(err); ok {
+			if gerr.Code() != codes.ResourceExhausted {
+				t.Fatal("expected to receive a grpc ResourceExhausted error")
+			}
+		} else {
+			t.Fatal("expected to receive a grpc ResourceExhausted error")
+		}
+	}
+}
+
+type ExhaustedServer struct {
+	pstest.GServer
+}
+
+func (*ExhaustedServer) StreamingPull(_ pb.Subscriber_StreamingPullServer) error {
+	return status.Errorf(codes.ResourceExhausted, "This server is exhausted!")
+}
+
 type testStreamingPullClient struct {
 	pb.Subscriber_StreamingPullClient
 	sendError error
diff --git a/pubsub/service.go b/pubsub/service.go
index 2cf1d13..0752f3a 100644
--- a/pubsub/service.go
+++ b/pubsub/service.go
@@ -18,7 +18,9 @@
 	"fmt"
 	"math"
 	"strings"
+	"time"
 
+	gax "github.com/googleapis/gax-go"
 	pb "google.golang.org/genproto/googleapis/pubsub/v1"
 	"google.golang.org/grpc/codes"
 	"google.golang.org/grpc/status"
@@ -61,18 +63,45 @@
 	return int32(i)
 }
 
-// Logic from https://github.com/GoogleCloudPlatform/google-cloud-java/blob/master/google-cloud-clients/google-cloud-pubsub/src/main/java/com/google/cloud/pubsub/v1/StatusUtil.java
-func isRetryable(err error) bool {
+type defaultRetryer struct {
+	bo gax.Backoff
+}
+
+// Logic originally from
+// https://github.com/GoogleCloudPlatform/google-cloud-java/blob/master/google-cloud-clients/google-cloud-pubsub/src/main/java/com/google/cloud/pubsub/v1/StatusUtil.java
+func (r *defaultRetryer) Retry(err error) (pause time.Duration, shouldRetry bool) {
 	s, ok := status.FromError(err)
 	if !ok { // includes io.EOF, normal stream close, which causes us to reopen
-		return true
+		return r.bo.Pause(), true
 	}
 	switch s.Code() {
 	case codes.DeadlineExceeded, codes.Internal, codes.ResourceExhausted, codes.Aborted:
-		return true
+		return r.bo.Pause(), true
 	case codes.Unavailable:
-		return !strings.Contains(s.Message(), "Server shutdownNow invoked")
+		c := strings.Contains(s.Message(), "Server shutdownNow invoked")
+		if !c {
+			return r.bo.Pause(), true
+		}
+		return 0, false
 	default:
-		return false
+		return 0, false
+	}
+}
+
+type streamingPullRetryer struct {
+	defaultRetryer gax.Retryer
+}
+
+// Does not retry ResourceExhausted. See: https://github.com/GoogleCloudPlatform/google-cloud-go/issues/1166#issuecomment-443744705
+func (r *streamingPullRetryer) Retry(err error) (pause time.Duration, shouldRetry bool) {
+	s, ok := status.FromError(err)
+	if !ok { // call defaultRetryer so that its backoff can be used
+		return r.defaultRetryer.Retry(err)
+	}
+	switch s.Code() {
+	case codes.ResourceExhausted:
+		return 0, false
+	default:
+		return r.defaultRetryer.Retry(err)
 	}
 }