pstest: better error handling for subscriptions

In all RPCs that deal with subscriptions, check for the absence of a
subscription in the request as well as a non-existent subscription.

Change-Id: I1ce62f556bfc10705aa2fe69c1da74b6b0b99b5a
Reviewed-on: https://code-review.googlesource.com/33110
Reviewed-by: kokoro <noreply+kokoro@google.com>
Reviewed-by: Jean de Klerk <deklerk@google.com>
diff --git a/pubsub/pstest/fake.go b/pubsub/pstest/fake.go
index 9fca076..21ce6b7 100644
--- a/pubsub/pstest/fake.go
+++ b/pubsub/pstest/fake.go
@@ -370,22 +370,23 @@
 func (s *gServer) GetSubscription(_ context.Context, req *pb.GetSubscriptionRequest) (*pb.Subscription, error) {
 	s.mu.Lock()
 	defer s.mu.Unlock()
-
-	if sub := s.subs[req.Subscription]; sub != nil {
-		return sub.proto, nil
+	sub, err := s.findSubscription(req.Subscription)
+	if err != nil {
+		return nil, err
 	}
-	return nil, status.Errorf(codes.NotFound, "subscription %q", req.Subscription)
+	return sub.proto, nil
 }
 
 func (s *gServer) UpdateSubscription(_ context.Context, req *pb.UpdateSubscriptionRequest) (*pb.Subscription, error) {
+	if req.Subscription == nil {
+		return nil, status.Errorf(codes.InvalidArgument, "missing subscription")
+	}
 	s.mu.Lock()
 	defer s.mu.Unlock()
-
-	sub := s.subs[req.Subscription.Name]
-	if sub == nil {
-		return nil, status.Errorf(codes.NotFound, "subscription %q", req.Subscription.Name)
+	sub, err := s.findSubscription(req.Subscription.Name)
+	if err != nil {
+		return nil, err
 	}
-
 	for _, path := range req.UpdateMask.Paths {
 		switch path {
 		case "push_config":
@@ -442,10 +443,9 @@
 func (s *gServer) DeleteSubscription(_ context.Context, req *pb.DeleteSubscriptionRequest) (*emptypb.Empty, error) {
 	s.mu.Lock()
 	defer s.mu.Unlock()
-
-	sub := s.subs[req.Subscription]
-	if sub == nil {
-		return nil, status.Errorf(codes.NotFound, "subscription %q", req.Subscription)
+	sub, err := s.findSubscription(req.Subscription)
+	if err != nil {
+		return nil, err
 	}
 	sub.stop()
 	delete(s.subs, req.Subscription)
@@ -574,10 +574,11 @@
 func (s *gServer) Acknowledge(_ context.Context, req *pb.AcknowledgeRequest) (*emptypb.Empty, error) {
 	s.mu.Lock()
 	defer s.mu.Unlock()
-	if req.Subscription == "" {
-		return nil, status.Errorf(codes.InvalidArgument, "missing subscription")
+
+	sub, err := s.findSubscription(req.Subscription)
+	if err != nil {
+		return nil, err
 	}
-	sub := s.subs[req.Subscription]
 	for _, id := range req.AckIds {
 		sub.ack(id)
 	}
@@ -587,16 +588,14 @@
 func (s *gServer) ModifyAckDeadline(_ context.Context, req *pb.ModifyAckDeadlineRequest) (*emptypb.Empty, error) {
 	s.mu.Lock()
 	defer s.mu.Unlock()
-
+	sub, err := s.findSubscription(req.Subscription)
+	if err != nil {
+		return nil, err
+	}
 	now := time.Now()
 	for _, id := range req.AckIds {
 		s.msgsByID[id].Modacks = append(s.msgsByID[id].Modacks, Modack{AckID: id, AckDeadline: req.AckDeadlineSeconds, ReceivedAt: now})
 	}
-
-	if req.Subscription == "" {
-		return nil, status.Errorf(codes.InvalidArgument, "missing subscription")
-	}
-	sub := s.subs[req.Subscription]
 	dur := secsToDur(req.AckDeadlineSeconds)
 	for _, id := range req.AckIds {
 		sub.modifyAckDeadline(id, dur)
@@ -610,14 +609,11 @@
 	if err != nil {
 		return err
 	}
-	if req.Subscription == "" {
-		return status.Errorf(codes.InvalidArgument, "missing subscription")
-	}
 	s.mu.Lock()
-	sub := s.subs[req.Subscription]
+	sub, err := s.findSubscription(req.Subscription)
 	s.mu.Unlock()
-	if sub == nil {
-		return status.Errorf(codes.NotFound, "subscription %s", req.Subscription)
+	if err != nil {
+		return err
 	}
 	// Create a new stream to handle the pull.
 	st := sub.newStream(sps, s.streamTimeout)
@@ -631,6 +627,8 @@
 	// This fake doesn't deal with snapshots.
 	var target time.Time
 	switch v := req.Target.(type) {
+	case nil:
+		return nil, status.Errorf(codes.InvalidArgument, "missing Seek target type")
 	case *pb.SeekRequest_Time:
 		var err error
 		target, err = ptypes.Timestamp(v.Time)
@@ -645,12 +643,10 @@
 	// because the messages don't have any other synchronization.
 	s.mu.Lock()
 	defer s.mu.Unlock()
-
-	sub := s.subs[req.Subscription]
-	if sub == nil {
-		return nil, status.Errorf(codes.NotFound, "subscription %s", req.Subscription)
+	sub, err := s.findSubscription(req.Subscription)
+	if err != nil {
+		return nil, err
 	}
-
 	// Drop all messages from sub that were published before the target time.
 	for id, m := range sub.msgs {
 		if m.publishTime.Before(target) {
@@ -679,6 +675,19 @@
 	return &pb.SeekResponse{}, nil
 }
 
+// Gets a subscription that must exist.
+// Must be called with the lock held.
+func (s *gServer) findSubscription(name string) (*subscription, error) {
+	if name == "" {
+		return nil, status.Errorf(codes.InvalidArgument, "missing subscription")
+	}
+	sub := s.subs[name]
+	if sub == nil {
+		return nil, status.Errorf(codes.NotFound, "subscription %s", name)
+	}
+	return sub, nil
+}
+
 var retentionDuration = 10 * time.Minute
 
 func (s *subscription) deliver() {
diff --git a/pubsub/pstest/fake_test.go b/pubsub/pstest/fake_test.go
index a0330e7..faad5b2 100644
--- a/pubsub/pstest/fake_test.go
+++ b/pubsub/pstest/fake_test.go
@@ -26,6 +26,8 @@
 	"golang.org/x/net/context"
 	pb "google.golang.org/genproto/googleapis/pubsub/v1"
 	"google.golang.org/grpc"
+	"google.golang.org/grpc/codes"
+	"google.golang.org/grpc/status"
 )
 
 func TestTopics(t *testing.T) {
@@ -127,6 +129,49 @@
 	}
 }
 
+func TestSubscriptionErrors(t *testing.T) {
+	_, sclient, _ := newFake(t)
+	ctx := context.Background()
+
+	// TODO(jba): Go1.9: use t.Helper()
+	checkCode := func(msg string, err error, want codes.Code) {
+		if status.Code(err) != want {
+			t.Errorf("%s: got %v, want code %s", msg, err, want)
+		}
+	}
+
+	_, err := sclient.GetSubscription(ctx, &pb.GetSubscriptionRequest{})
+	checkCode("GetSubscription", err, codes.InvalidArgument)
+	_, err = sclient.GetSubscription(ctx, &pb.GetSubscriptionRequest{Subscription: "s"})
+	checkCode("GetSubscription", err, codes.NotFound)
+	_, err = sclient.UpdateSubscription(ctx, &pb.UpdateSubscriptionRequest{})
+	checkCode("UpdateSubscription", err, codes.InvalidArgument)
+	_, err = sclient.UpdateSubscription(ctx, &pb.UpdateSubscriptionRequest{Subscription: &pb.Subscription{}})
+	checkCode("UpdateSubscription", err, codes.InvalidArgument)
+	_, err = sclient.UpdateSubscription(ctx, &pb.UpdateSubscriptionRequest{Subscription: &pb.Subscription{Name: "s"}})
+	checkCode("UpdateSubscription", err, codes.NotFound)
+	_, err = sclient.DeleteSubscription(ctx, &pb.DeleteSubscriptionRequest{})
+	checkCode("DeleteSubscription", err, codes.InvalidArgument)
+	_, err = sclient.DeleteSubscription(ctx, &pb.DeleteSubscriptionRequest{Subscription: "s"})
+	checkCode("DeleteSubscription", err, codes.NotFound)
+	_, err = sclient.Acknowledge(ctx, &pb.AcknowledgeRequest{})
+	checkCode("Acknowledge", err, codes.InvalidArgument)
+	_, err = sclient.Acknowledge(ctx, &pb.AcknowledgeRequest{Subscription: "s"})
+	checkCode("Acknowledge", err, codes.NotFound)
+	_, err = sclient.ModifyAckDeadline(ctx, &pb.ModifyAckDeadlineRequest{})
+	checkCode("ModifyAckDeadline", err, codes.InvalidArgument)
+	_, err = sclient.ModifyAckDeadline(ctx, &pb.ModifyAckDeadlineRequest{Subscription: "s"})
+	checkCode("ModifyAckDeadline", err, codes.NotFound)
+
+	_, err = sclient.Seek(ctx, &pb.SeekRequest{})
+	checkCode("Seek", err, codes.InvalidArgument)
+	srt := &pb.SeekRequest_Time{Time: ptypes.TimestampNow()}
+	_, err = sclient.Seek(ctx, &pb.SeekRequest{Target: srt})
+	checkCode("Seek", err, codes.InvalidArgument)
+	_, err = sclient.Seek(ctx, &pb.SeekRequest{Target: srt, Subscription: "s"})
+	checkCode("Seek", err, codes.NotFound)
+}
+
 func TestPublish(t *testing.T) {
 	s := NewServer()
 	var ids []string