pubsub: unflake TestMultiStreams

- Re-write TestMultiStreams to take into consideration the
fail-fast behavior added in https://code-review.googlesource.com/c/gocloud/+/34191.
Now, TestMultiStreams will only assert that multiple streams
get messages, not that multiple messages are round robined in
a certain manner (which is no longer strictly true - it's best
effort, which caused flakiness). Formerly, TestMultiStreams
would fail within ~100 runs. Now it passes with -count 10000.

- Refactor TestMultiStreams to have zero goroutine leaks, as
tested with leakcheck.Check.

- Refactor many pstest tests to use context.TODO instead of
context.Background for future goroutine cleanups.

- Refactor newFake to return a cleanup function that cleans
up both the server and the conn object (which was being leaked).

Change-Id: I948c73bbb185fc6f0929f7ab61883486e202b764
Reviewed-on: https://code-review.googlesource.com/c/34870
Reviewed-by: kokoro <noreply+kokoro@google.com>
Reviewed-by: Eno Compton <enocom@google.com>
diff --git a/pubsub/pstest/fake_test.go b/pubsub/pstest/fake_test.go
index 3e8cbb1..268ecc9 100644
--- a/pubsub/pstest/fake_test.go
+++ b/pubsub/pstest/fake_test.go
@@ -31,13 +31,13 @@
 )
 
 func TestTopics(t *testing.T) {
-	pclient, _, server := newFake(t)
-	defer server.Close()
+	pclient, _, server, cleanup := newFake(context.TODO(), t)
+	defer cleanup()
 
 	ctx := context.Background()
 	var topics []*pb.Topic
 	for i := 1; i < 3; i++ {
-		topics = append(topics, mustCreateTopic(t, pclient, &pb.Topic{
+		topics = append(topics, mustCreateTopic(context.TODO(), t, pclient, &pb.Topic{
 			Name:   fmt.Sprintf("projects/P/topics/T%d", i),
 			Labels: map[string]string{"num": fmt.Sprintf("%d", i)},
 		}))
@@ -74,14 +74,14 @@
 }
 
 func TestSubscriptions(t *testing.T) {
-	pclient, sclient, server := newFake(t)
-	defer server.Close()
+	pclient, sclient, server, cleanup := newFake(context.TODO(), t)
+	defer cleanup()
 
 	ctx := context.Background()
-	topic := mustCreateTopic(t, pclient, &pb.Topic{Name: "projects/P/topics/T"})
+	topic := mustCreateTopic(context.TODO(), t, pclient, &pb.Topic{Name: "projects/P/topics/T"})
 	var subs []*pb.Subscription
 	for i := 0; i < 3; i++ {
-		subs = append(subs, mustCreateSubscription(t, sclient, &pb.Subscription{
+		subs = append(subs, mustCreateSubscription(context.TODO(), t, sclient, &pb.Subscription{
 			Name:               fmt.Sprintf("projects/P/subscriptions/S%d", i),
 			Topic:              topic.Name,
 			AckDeadlineSeconds: int32(10 * (i + 1)),
@@ -134,8 +134,8 @@
 }
 
 func TestSubscriptionErrors(t *testing.T) {
-	_, sclient, srv := newFake(t)
-	defer srv.Close()
+	_, sclient, _, cleanup := newFake(context.TODO(), t)
+	defer cleanup()
 
 	ctx := context.Background()
 
@@ -236,11 +236,11 @@
 }
 
 func TestPull(t *testing.T) {
-	pclient, sclient, srv := newFake(t)
-	defer srv.Close()
+	pclient, sclient, _, cleanup := newFake(context.TODO(), t)
+	defer cleanup()
 
-	top := mustCreateTopic(t, pclient, &pb.Topic{Name: "projects/P/topics/T"})
-	sub := mustCreateSubscription(t, sclient, &pb.Subscription{
+	top := mustCreateTopic(context.TODO(), t, pclient, &pb.Topic{Name: "projects/P/topics/T"})
+	sub := mustCreateSubscription(context.TODO(), t, sclient, &pb.Subscription{
 		Name:               "projects/P/subscriptions/S",
 		Topic:              top.Name,
 		AckDeadlineSeconds: 10,
@@ -251,7 +251,7 @@
 		{Data: []byte("d2")},
 		{Data: []byte("d3")},
 	})
-	got := pubsubMessages(pullN(t, len(want), sclient, sub))
+	got := pubsubMessages(pullN(context.TODO(), t, len(want), sclient, sub))
 	if diff := testutil.Diff(got, want); diff != "" {
 		t.Error(diff)
 	}
@@ -267,11 +267,11 @@
 
 func TestStreamingPull(t *testing.T) {
 	// A simple test of streaming pull.
-	pclient, sclient, srv := newFake(t)
-	defer srv.Close()
+	pclient, sclient, _, cleanup := newFake(context.TODO(), t)
+	defer cleanup()
 
-	top := mustCreateTopic(t, pclient, &pb.Topic{Name: "projects/P/topics/T"})
-	sub := mustCreateSubscription(t, sclient, &pb.Subscription{
+	top := mustCreateTopic(context.TODO(), t, pclient, &pb.Topic{Name: "projects/P/topics/T"})
+	sub := mustCreateSubscription(context.TODO(), t, sclient, &pb.Subscription{
 		Name:               "projects/P/subscriptions/S",
 		Topic:              top.Name,
 		AckDeadlineSeconds: 10,
@@ -282,7 +282,7 @@
 		{Data: []byte("d2")},
 		{Data: []byte("d3")},
 	})
-	got := pubsubMessages(streamingPullN(t, len(want), sclient, sub))
+	got := pubsubMessages(streamingPullN(context.TODO(), t, len(want), sclient, sub))
 	if diff := testutil.Diff(got, want); diff != "" {
 		t.Error(diff)
 	}
@@ -291,11 +291,11 @@
 func TestStreamingPullAck(t *testing.T) {
 	// Ack each message as it arrives. Make sure we don't see dups.
 	minAckDeadlineSecs = 1
-	pclient, sclient, srv := newFake(t)
-	defer srv.Close()
+	pclient, sclient, _, cleanup := newFake(context.TODO(), t)
+	defer cleanup()
 
-	top := mustCreateTopic(t, pclient, &pb.Topic{Name: "projects/P/topics/T"})
-	sub := mustCreateSubscription(t, sclient, &pb.Subscription{
+	top := mustCreateTopic(context.TODO(), t, pclient, &pb.Topic{Name: "projects/P/topics/T"})
+	sub := mustCreateSubscription(context.TODO(), t, sclient, &pb.Subscription{
 		Name:               "projects/P/subscriptions/S",
 		Topic:              top.Name,
 		AckDeadlineSeconds: 1,
@@ -308,7 +308,7 @@
 	})
 
 	got := map[string]bool{}
-	spc := mustStartStreamingPull(t, sclient, sub)
+	spc := mustStartStreamingPull(context.TODO(), t, sclient, sub)
 	time.AfterFunc(time.Duration(3*minAckDeadlineSecs)*time.Second, func() {
 		if err := spc.CloseSend(); err != nil {
 			t.Errorf("CloseSend: %v", err)
@@ -339,11 +339,11 @@
 
 func TestAcknowledge(t *testing.T) {
 	ctx := context.Background()
-	pclient, sclient, srv := newFake(t)
-	defer srv.Close()
+	pclient, sclient, srv, cleanup := newFake(context.TODO(), t)
+	defer cleanup()
 
-	top := mustCreateTopic(t, pclient, &pb.Topic{Name: "projects/P/topics/T"})
-	sub := mustCreateSubscription(t, sclient, &pb.Subscription{
+	top := mustCreateTopic(context.TODO(), t, pclient, &pb.Topic{Name: "projects/P/topics/T"})
+	sub := mustCreateSubscription(context.TODO(), t, sclient, &pb.Subscription{
 		Name:               "projects/P/subscriptions/S",
 		Topic:              top.Name,
 		AckDeadlineSeconds: 10,
@@ -354,7 +354,7 @@
 		{Data: []byte("d2")},
 		{Data: []byte("d3")},
 	})
-	msgs := streamingPullN(t, 3, sclient, sub)
+	msgs := streamingPullN(context.TODO(), t, 3, sclient, sub)
 	var ackIDs []string
 	for _, m := range msgs {
 		ackIDs = append(ackIDs, m.AckId)
@@ -378,11 +378,11 @@
 
 func TestModAck(t *testing.T) {
 	ctx := context.Background()
-	pclient, sclient, srv := newFake(t)
-	defer srv.Close()
+	pclient, sclient, _, cleanup := newFake(context.TODO(), t)
+	defer cleanup()
 
-	top := mustCreateTopic(t, pclient, &pb.Topic{Name: "projects/P/topics/T"})
-	sub := mustCreateSubscription(t, sclient, &pb.Subscription{
+	top := mustCreateTopic(context.TODO(), t, pclient, &pb.Topic{Name: "projects/P/topics/T"})
+	sub := mustCreateSubscription(context.TODO(), t, sclient, &pb.Subscription{
 		Name:               "projects/P/subscriptions/S",
 		Topic:              top.Name,
 		AckDeadlineSeconds: 10,
@@ -393,7 +393,7 @@
 		{Data: []byte("d2")},
 		{Data: []byte("d3")},
 	})
-	msgs := streamingPullN(t, 3, sclient, sub)
+	msgs := streamingPullN(context.TODO(), t, 3, sclient, sub)
 	var ackIDs []string
 	for _, m := range msgs {
 		ackIDs = append(ackIDs, m.AckId)
@@ -406,7 +406,7 @@
 		t.Fatal(err)
 	}
 	// Having nacked all three messages, we should see them again.
-	msgs = streamingPullN(t, 3, sclient, sub)
+	msgs = streamingPullN(context.TODO(), t, 3, sclient, sub)
 	if got, want := len(msgs), 3; got != want {
 		t.Errorf("got %d messages, want %d", got, want)
 	}
@@ -414,12 +414,12 @@
 
 func TestAckDeadline(t *testing.T) {
 	// Messages should be resent after they expire.
-	pclient, sclient, srv := newFake(t)
-	defer srv.Close()
+	pclient, sclient, _, cleanup := newFake(context.TODO(), t)
+	defer cleanup()
 
 	minAckDeadlineSecs = 2
-	top := mustCreateTopic(t, pclient, &pb.Topic{Name: "projects/P/topics/T"})
-	sub := mustCreateSubscription(t, sclient, &pb.Subscription{
+	top := mustCreateTopic(context.TODO(), t, pclient, &pb.Topic{Name: "projects/P/topics/T"})
+	sub := mustCreateSubscription(context.TODO(), t, sclient, &pb.Subscription{
 		Name:               "projects/P/subscriptions/S",
 		Topic:              top.Name,
 		AckDeadlineSeconds: minAckDeadlineSecs,
@@ -432,7 +432,7 @@
 	})
 
 	got := map[string]int{}
-	spc := mustStartStreamingPull(t, sclient, sub)
+	spc := mustStartStreamingPull(context.TODO(), t, sclient, sub)
 	// In 5 seconds the ack deadline will expire twice, so we should see each message
 	// exactly three times.
 	time.AfterFunc(5*time.Second, func() {
@@ -461,16 +461,16 @@
 
 func TestMultiSubs(t *testing.T) {
 	// Each subscription gets every message.
-	pclient, sclient, srv := newFake(t)
-	defer srv.Close()
+	pclient, sclient, _, cleanup := newFake(context.TODO(), t)
+	defer cleanup()
 
-	top := mustCreateTopic(t, pclient, &pb.Topic{Name: "projects/P/topics/T"})
-	sub1 := mustCreateSubscription(t, sclient, &pb.Subscription{
+	top := mustCreateTopic(context.TODO(), t, pclient, &pb.Topic{Name: "projects/P/topics/T"})
+	sub1 := mustCreateSubscription(context.TODO(), t, sclient, &pb.Subscription{
 		Name:               "projects/P/subscriptions/S1",
 		Topic:              top.Name,
 		AckDeadlineSeconds: 10,
 	})
-	sub2 := mustCreateSubscription(t, sclient, &pb.Subscription{
+	sub2 := mustCreateSubscription(context.TODO(), t, sclient, &pb.Subscription{
 		Name:               "projects/P/subscriptions/S2",
 		Topic:              top.Name,
 		AckDeadlineSeconds: 10,
@@ -481,8 +481,8 @@
 		{Data: []byte("d2")},
 		{Data: []byte("d3")},
 	})
-	got1 := pubsubMessages(streamingPullN(t, len(want), sclient, sub1))
-	got2 := pubsubMessages(streamingPullN(t, len(want), sclient, sub2))
+	got1 := pubsubMessages(streamingPullN(context.TODO(), t, len(want), sclient, sub1))
+	got2 := pubsubMessages(streamingPullN(context.TODO(), t, len(want), sclient, sub2))
 	if diff := testutil.Diff(got1, want); diff != "" {
 		t.Error(diff)
 	}
@@ -491,56 +491,75 @@
 	}
 }
 
+// Messages are handed out to all streams of a subscription in a best-effort
+// round-robin behavior. The fake server prefers to fail-fast onto another
+// stream when one stream is already busy, though, so we're unable to test
+// strict round robin behavior.
 func TestMultiStreams(t *testing.T) {
-	// Messages are handed out to the streams of a subscription in round-robin order.
-	pclient, sclient, srv := newFake(t)
-	defer srv.Close()
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+	pclient, sclient, _, cleanup := newFake(ctx, t)
+	defer cleanup()
 
-	top := mustCreateTopic(t, pclient, &pb.Topic{Name: "projects/P/topics/T"})
-	sub := mustCreateSubscription(t, sclient, &pb.Subscription{
+	top := mustCreateTopic(ctx, t, pclient, &pb.Topic{Name: "projects/P/topics/T"})
+	sub := mustCreateSubscription(ctx, t, sclient, &pb.Subscription{
 		Name:               "projects/P/subscriptions/S",
 		Topic:              top.Name,
 		AckDeadlineSeconds: 10,
 	})
-	want := publish(t, pclient, top, []*pb.PubsubMessage{
+	st1 := mustStartStreamingPull(ctx, t, sclient, sub)
+	defer st1.CloseSend()
+	st1Received := make(chan struct{})
+	go func() {
+		_, err := st1.Recv()
+		if err != nil {
+			t.Error(err)
+		}
+		close(st1Received)
+	}()
+
+	st2 := mustStartStreamingPull(ctx, t, sclient, sub)
+	defer st2.CloseSend()
+	st2Received := make(chan struct{})
+	go func() {
+		_, err := st2.Recv()
+		if err != nil {
+			t.Error(err)
+		}
+		close(st2Received)
+	}()
+
+	publish(t, pclient, top, []*pb.PubsubMessage{
 		{Data: []byte("d1")},
 		{Data: []byte("d2")},
-		{Data: []byte("d3")},
-		{Data: []byte("d4")},
 	})
-	streams := []pb.Subscriber_StreamingPullClient{
-		mustStartStreamingPull(t, sclient, sub),
-		mustStartStreamingPull(t, sclient, sub),
+
+	timeout := time.After(5 * time.Second)
+	select {
+	case <-timeout:
+		t.Fatal("timed out waiting for stream 1 to receive any message")
+	case <-st1Received:
 	}
-	got := map[string]*pb.PubsubMessage{}
-	for i := 0; i < 2; i++ {
-		for _, st := range streams {
-			res, err := st.Recv()
-			if err != nil {
-				t.Fatal(err)
-			}
-			m := res.ReceivedMessages[0]
-			got[m.Message.MessageId] = m.Message
-		}
-	}
-	if diff := testutil.Diff(got, want); diff != "" {
-		t.Error(diff)
+	select {
+	case <-timeout:
+		t.Fatal("timed out waiting for stream 1 to receive any message")
+	case <-st2Received:
 	}
 }
 
 func TestStreamingPullTimeout(t *testing.T) {
-	pclient, sclient, srv := newFake(t)
-	defer srv.Close()
+	pclient, sclient, srv, cleanup := newFake(context.TODO(), t)
+	defer cleanup()
 
 	timeout := 200 * time.Millisecond
 	srv.SetStreamTimeout(timeout)
-	top := mustCreateTopic(t, pclient, &pb.Topic{Name: "projects/P/topics/T"})
-	sub := mustCreateSubscription(t, sclient, &pb.Subscription{
+	top := mustCreateTopic(context.TODO(), t, pclient, &pb.Topic{Name: "projects/P/topics/T"})
+	sub := mustCreateSubscription(context.TODO(), t, sclient, &pb.Subscription{
 		Name:               "projects/P/subscriptions/S",
 		Topic:              top.Name,
 		AckDeadlineSeconds: 10,
 	})
-	stream := mustStartStreamingPull(t, sclient, sub)
+	stream := mustStartStreamingPull(context.TODO(), t, sclient, sub)
 	time.Sleep(2 * timeout)
 	_, err := stream.Recv()
 	if err != io.EOF {
@@ -549,11 +568,11 @@
 }
 
 func TestSeek(t *testing.T) {
-	pclient, sclient, srv := newFake(t)
-	defer srv.Close()
+	pclient, sclient, _, cleanup := newFake(context.TODO(), t)
+	defer cleanup()
 
-	top := mustCreateTopic(t, pclient, &pb.Topic{Name: "projects/P/topics/T"})
-	sub := mustCreateSubscription(t, sclient, &pb.Subscription{
+	top := mustCreateTopic(context.TODO(), t, pclient, &pb.Topic{Name: "projects/P/topics/T"})
+	sub := mustCreateSubscription(context.TODO(), t, sclient, &pb.Subscription{
 		Name:               "projects/P/subscriptions/S",
 		Topic:              top.Name,
 		AckDeadlineSeconds: 10,
@@ -605,8 +624,8 @@
 	}
 }
 
-func mustStartStreamingPull(t *testing.T, sc pb.SubscriberClient, sub *pb.Subscription) pb.Subscriber_StreamingPullClient {
-	spc, err := sc.StreamingPull(context.Background())
+func mustStartStreamingPull(ctx context.Context, t *testing.T, sc pb.SubscriberClient, sub *pb.Subscription) pb.Subscriber_StreamingPullClient {
+	spc, err := sc.StreamingPull(ctx)
 	if err != nil {
 		t.Fatal(err)
 	}
@@ -616,8 +635,7 @@
 	return spc
 }
 
-func pullN(t *testing.T, n int, sc pb.SubscriberClient, sub *pb.Subscription) map[string]*pb.ReceivedMessage {
-	ctx := context.Background()
+func pullN(ctx context.Context, t *testing.T, n int, sc pb.SubscriberClient, sub *pb.Subscription) map[string]*pb.ReceivedMessage {
 	got := map[string]*pb.ReceivedMessage{}
 	for i := 0; len(got) < n; i++ {
 		res, err := sc.Pull(ctx, &pb.PullRequest{Subscription: sub.Name, MaxMessages: int32(n - len(got))})
@@ -631,8 +649,8 @@
 	return got
 }
 
-func streamingPullN(t *testing.T, n int, sc pb.SubscriberClient, sub *pb.Subscription) map[string]*pb.ReceivedMessage {
-	spc := mustStartStreamingPull(t, sc, sub)
+func streamingPullN(ctx context.Context, t *testing.T, n int, sc pb.SubscriberClient, sub *pb.Subscription) map[string]*pb.ReceivedMessage {
+	spc := mustStartStreamingPull(ctx, t, sc, sub)
 	got := map[string]*pb.ReceivedMessage{}
 	for i := 0; i < n; i++ {
 		res, err := spc.Recv()
@@ -661,28 +679,34 @@
 	return ms
 }
 
-func mustCreateTopic(t *testing.T, pc pb.PublisherClient, topic *pb.Topic) *pb.Topic {
-	top, err := pc.CreateTopic(context.Background(), topic)
+func mustCreateTopic(ctx context.Context, t *testing.T, pc pb.PublisherClient, topic *pb.Topic) *pb.Topic {
+	top, err := pc.CreateTopic(ctx, topic)
 	if err != nil {
 		t.Fatal(err)
 	}
 	return top
 }
 
-func mustCreateSubscription(t *testing.T, sc pb.SubscriberClient, sub *pb.Subscription) *pb.Subscription {
-	sub, err := sc.CreateSubscription(context.Background(), sub)
+func mustCreateSubscription(ctx context.Context, t *testing.T, sc pb.SubscriberClient, sub *pb.Subscription) *pb.Subscription {
+	sub, err := sc.CreateSubscription(ctx, sub)
 	if err != nil {
 		t.Fatal(err)
 	}
 	return sub
 }
 
-// Note: be sure to close server!
-func newFake(t *testing.T) (pb.PublisherClient, pb.SubscriberClient, *Server) {
+// newFake creates a new fake server along  with a publisher and subscriber
+// client. Its final return is a cleanup function.
+//
+// Note: be sure to call cleanup!
+func newFake(ctx context.Context, t *testing.T) (pb.PublisherClient, pb.SubscriberClient, *Server, func()) {
 	srv := NewServer()
-	conn, err := grpc.Dial(srv.Addr, grpc.WithInsecure())
+	conn, err := grpc.DialContext(ctx, srv.Addr, grpc.WithInsecure())
 	if err != nil {
 		t.Fatal(err)
 	}
-	return pb.NewPublisherClient(conn), pb.NewSubscriberClient(conn), srv
+	return pb.NewPublisherClient(conn), pb.NewSubscriberClient(conn), srv, func() {
+		srv.Close()
+		conn.Close()
+	}
 }