pstest: better error handling for subscriptions
In all RPCs that deal with subscriptions, check for the absence of a
subscription in the request as well as a non-existent subscription.
Change-Id: I1ce62f556bfc10705aa2fe69c1da74b6b0b99b5a
Reviewed-on: https://code-review.googlesource.com/33110
Reviewed-by: kokoro <noreply+kokoro@google.com>
Reviewed-by: Jean de Klerk <deklerk@google.com>
diff --git a/pubsub/pstest/fake.go b/pubsub/pstest/fake.go
index 9fca076..21ce6b7 100644
--- a/pubsub/pstest/fake.go
+++ b/pubsub/pstest/fake.go
@@ -370,22 +370,23 @@
func (s *gServer) GetSubscription(_ context.Context, req *pb.GetSubscriptionRequest) (*pb.Subscription, error) {
s.mu.Lock()
defer s.mu.Unlock()
-
- if sub := s.subs[req.Subscription]; sub != nil {
- return sub.proto, nil
+ sub, err := s.findSubscription(req.Subscription)
+ if err != nil {
+ return nil, err
}
- return nil, status.Errorf(codes.NotFound, "subscription %q", req.Subscription)
+ return sub.proto, nil
}
func (s *gServer) UpdateSubscription(_ context.Context, req *pb.UpdateSubscriptionRequest) (*pb.Subscription, error) {
+ if req.Subscription == nil {
+ return nil, status.Errorf(codes.InvalidArgument, "missing subscription")
+ }
s.mu.Lock()
defer s.mu.Unlock()
-
- sub := s.subs[req.Subscription.Name]
- if sub == nil {
- return nil, status.Errorf(codes.NotFound, "subscription %q", req.Subscription.Name)
+ sub, err := s.findSubscription(req.Subscription.Name)
+ if err != nil {
+ return nil, err
}
-
for _, path := range req.UpdateMask.Paths {
switch path {
case "push_config":
@@ -442,10 +443,9 @@
func (s *gServer) DeleteSubscription(_ context.Context, req *pb.DeleteSubscriptionRequest) (*emptypb.Empty, error) {
s.mu.Lock()
defer s.mu.Unlock()
-
- sub := s.subs[req.Subscription]
- if sub == nil {
- return nil, status.Errorf(codes.NotFound, "subscription %q", req.Subscription)
+ sub, err := s.findSubscription(req.Subscription)
+ if err != nil {
+ return nil, err
}
sub.stop()
delete(s.subs, req.Subscription)
@@ -574,10 +574,11 @@
func (s *gServer) Acknowledge(_ context.Context, req *pb.AcknowledgeRequest) (*emptypb.Empty, error) {
s.mu.Lock()
defer s.mu.Unlock()
- if req.Subscription == "" {
- return nil, status.Errorf(codes.InvalidArgument, "missing subscription")
+
+ sub, err := s.findSubscription(req.Subscription)
+ if err != nil {
+ return nil, err
}
- sub := s.subs[req.Subscription]
for _, id := range req.AckIds {
sub.ack(id)
}
@@ -587,16 +588,14 @@
func (s *gServer) ModifyAckDeadline(_ context.Context, req *pb.ModifyAckDeadlineRequest) (*emptypb.Empty, error) {
s.mu.Lock()
defer s.mu.Unlock()
-
+ sub, err := s.findSubscription(req.Subscription)
+ if err != nil {
+ return nil, err
+ }
now := time.Now()
for _, id := range req.AckIds {
s.msgsByID[id].Modacks = append(s.msgsByID[id].Modacks, Modack{AckID: id, AckDeadline: req.AckDeadlineSeconds, ReceivedAt: now})
}
-
- if req.Subscription == "" {
- return nil, status.Errorf(codes.InvalidArgument, "missing subscription")
- }
- sub := s.subs[req.Subscription]
dur := secsToDur(req.AckDeadlineSeconds)
for _, id := range req.AckIds {
sub.modifyAckDeadline(id, dur)
@@ -610,14 +609,11 @@
if err != nil {
return err
}
- if req.Subscription == "" {
- return status.Errorf(codes.InvalidArgument, "missing subscription")
- }
s.mu.Lock()
- sub := s.subs[req.Subscription]
+ sub, err := s.findSubscription(req.Subscription)
s.mu.Unlock()
- if sub == nil {
- return status.Errorf(codes.NotFound, "subscription %s", req.Subscription)
+ if err != nil {
+ return err
}
// Create a new stream to handle the pull.
st := sub.newStream(sps, s.streamTimeout)
@@ -631,6 +627,8 @@
// This fake doesn't deal with snapshots.
var target time.Time
switch v := req.Target.(type) {
+ case nil:
+ return nil, status.Errorf(codes.InvalidArgument, "missing Seek target type")
case *pb.SeekRequest_Time:
var err error
target, err = ptypes.Timestamp(v.Time)
@@ -645,12 +643,10 @@
// because the messages don't have any other synchronization.
s.mu.Lock()
defer s.mu.Unlock()
-
- sub := s.subs[req.Subscription]
- if sub == nil {
- return nil, status.Errorf(codes.NotFound, "subscription %s", req.Subscription)
+ sub, err := s.findSubscription(req.Subscription)
+ if err != nil {
+ return nil, err
}
-
// Drop all messages from sub that were published before the target time.
for id, m := range sub.msgs {
if m.publishTime.Before(target) {
@@ -679,6 +675,19 @@
return &pb.SeekResponse{}, nil
}
+// Gets a subscription that must exist.
+// Must be called with the lock held.
+func (s *gServer) findSubscription(name string) (*subscription, error) {
+ if name == "" {
+ return nil, status.Errorf(codes.InvalidArgument, "missing subscription")
+ }
+ sub := s.subs[name]
+ if sub == nil {
+ return nil, status.Errorf(codes.NotFound, "subscription %s", name)
+ }
+ return sub, nil
+}
+
var retentionDuration = 10 * time.Minute
func (s *subscription) deliver() {
diff --git a/pubsub/pstest/fake_test.go b/pubsub/pstest/fake_test.go
index a0330e7..faad5b2 100644
--- a/pubsub/pstest/fake_test.go
+++ b/pubsub/pstest/fake_test.go
@@ -26,6 +26,8 @@
"golang.org/x/net/context"
pb "google.golang.org/genproto/googleapis/pubsub/v1"
"google.golang.org/grpc"
+ "google.golang.org/grpc/codes"
+ "google.golang.org/grpc/status"
)
func TestTopics(t *testing.T) {
@@ -127,6 +129,49 @@
}
}
+func TestSubscriptionErrors(t *testing.T) {
+ _, sclient, _ := newFake(t)
+ ctx := context.Background()
+
+ // TODO(jba): Go1.9: use t.Helper()
+ checkCode := func(msg string, err error, want codes.Code) {
+ if status.Code(err) != want {
+ t.Errorf("%s: got %v, want code %s", msg, err, want)
+ }
+ }
+
+ _, err := sclient.GetSubscription(ctx, &pb.GetSubscriptionRequest{})
+ checkCode("GetSubscription", err, codes.InvalidArgument)
+ _, err = sclient.GetSubscription(ctx, &pb.GetSubscriptionRequest{Subscription: "s"})
+ checkCode("GetSubscription", err, codes.NotFound)
+ _, err = sclient.UpdateSubscription(ctx, &pb.UpdateSubscriptionRequest{})
+ checkCode("UpdateSubscription", err, codes.InvalidArgument)
+ _, err = sclient.UpdateSubscription(ctx, &pb.UpdateSubscriptionRequest{Subscription: &pb.Subscription{}})
+ checkCode("UpdateSubscription", err, codes.InvalidArgument)
+ _, err = sclient.UpdateSubscription(ctx, &pb.UpdateSubscriptionRequest{Subscription: &pb.Subscription{Name: "s"}})
+ checkCode("UpdateSubscription", err, codes.NotFound)
+ _, err = sclient.DeleteSubscription(ctx, &pb.DeleteSubscriptionRequest{})
+ checkCode("DeleteSubscription", err, codes.InvalidArgument)
+ _, err = sclient.DeleteSubscription(ctx, &pb.DeleteSubscriptionRequest{Subscription: "s"})
+ checkCode("DeleteSubscription", err, codes.NotFound)
+ _, err = sclient.Acknowledge(ctx, &pb.AcknowledgeRequest{})
+ checkCode("Acknowledge", err, codes.InvalidArgument)
+ _, err = sclient.Acknowledge(ctx, &pb.AcknowledgeRequest{Subscription: "s"})
+ checkCode("Acknowledge", err, codes.NotFound)
+ _, err = sclient.ModifyAckDeadline(ctx, &pb.ModifyAckDeadlineRequest{})
+ checkCode("ModifyAckDeadline", err, codes.InvalidArgument)
+ _, err = sclient.ModifyAckDeadline(ctx, &pb.ModifyAckDeadlineRequest{Subscription: "s"})
+ checkCode("ModifyAckDeadline", err, codes.NotFound)
+
+ _, err = sclient.Seek(ctx, &pb.SeekRequest{})
+ checkCode("Seek", err, codes.InvalidArgument)
+ srt := &pb.SeekRequest_Time{Time: ptypes.TimestampNow()}
+ _, err = sclient.Seek(ctx, &pb.SeekRequest{Target: srt})
+ checkCode("Seek", err, codes.InvalidArgument)
+ _, err = sclient.Seek(ctx, &pb.SeekRequest{Target: srt, Subscription: "s"})
+ checkCode("Seek", err, codes.NotFound)
+}
+
func TestPublish(t *testing.T) {
s := NewServer()
var ids []string