fix(pubsub): fix memory leak in publish scheduler (#4282)
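
The publish scheduler's per-ordering-key state (bundlers, outstanding)
moves from mutex-guarded maps to sync.Maps. The load test gains an
ordered-publishing mode and a BufferedByteLimit, pstest's
AddPublishResponse no longer takes the server lock (holding it could
deadlock against a Publish call blocked waiting for a queued response),
and the tests now cover publishing after a temporarily blocked flush.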

diff --git a/pubsub/internal/scheduler/publish_scheduler.go b/pubsub/internal/scheduler/publish_scheduler.go
index 9dbf51c..cba172b 100644
--- a/pubsub/internal/scheduler/publish_scheduler.go
+++ b/pubsub/internal/scheduler/publish_scheduler.go
@@ -39,8 +39,8 @@
 	BufferedByteLimit    int
 
 	mu          sync.Mutex
-	bundlers    map[string]*bundler.Bundler
-	outstanding map[string]int
+	bundlers    sync.Map // keys -> *bundler.Bundler
+	outstanding sync.Map // keys -> num outstanding messages
 
 	keysMu sync.RWMutex
 	// keysWithErrors tracks ordering keys that cannot accept new messages.
@@ -76,8 +76,6 @@
 	}
 
 	s := PublishScheduler{
-		bundlers:       make(map[string]*bundler.Bundler),
-		outstanding:    make(map[string]int),
 		keysWithErrors: make(map[string]struct{}),
 		workers:        make(chan struct{}, workers),
 		handle:         handle,
@@ -106,9 +104,11 @@
 
 	s.mu.Lock()
 	defer s.mu.Unlock()
-	b, ok := s.bundlers[key]
+	var b *bundler.Bundler
+	bInterface, ok := s.bundlers.Load(key)
+
 	if !ok {
-		s.outstanding[key] = 1
+		s.outstanding.Store(key, 1)
 		b = bundler.NewBundler(item, func(bundle interface{}) {
 			s.workers <- struct{}{}
 			s.handle(bundle)
@@ -116,10 +116,12 @@
 
 			nlen := reflect.ValueOf(bundle).Len()
 			s.mu.Lock()
-			s.outstanding[key] -= nlen
-			if s.outstanding[key] == 0 {
-				delete(s.outstanding, key)
-				delete(s.bundlers, key)
+			outsInterface, _ := s.outstanding.Load(key) // s.mu serializes this read-modify-write
+			newOutstanding := outsInterface.(int) - nlen
+			s.outstanding.Store(key, newOutstanding)
+			if newOutstanding == 0 {
+				s.outstanding.Delete(key)
+				s.bundlers.Delete(key)
 			}
 			s.mu.Unlock()
 		})
@@ -142,9 +144,13 @@
 			b.HandlerLimit = 1
 		}
 
-		s.bundlers[key] = b
+		s.bundlers.Store(key, b)
+	} else {
+		b = bInterface.(*bundler.Bundler)
+		oi, _ := s.outstanding.Load(key)
+		s.outstanding.Store(key, oi.(int)+1)
 	}
-	s.outstanding[key]++
+
 	return b.Add(item, size)
 }
 
@@ -152,22 +158,24 @@
 // blocks until all items have been flushed.
 func (s *PublishScheduler) FlushAndStop() {
 	close(s.done)
-	for _, b := range s.bundlers {
-		b.Flush()
-	}
+	s.bundlers.Range(func(_, bi interface{}) bool {
+		bi.(*bundler.Bundler).Flush()
+		return true
+	})
 }
 
 // Flush waits until all bundled items have been sent.
 func (s *PublishScheduler) Flush() {
 	var wg sync.WaitGroup
-	for _, b := range s.bundlers {
+	s.bundlers.Range(func(_, bi interface{}) bool {
 		wg.Add(1)
 		go func(b *bundler.Bundler) {
 			defer wg.Done()
 			b.Flush()
-		}(b)
-	}
+		}(bi.(*bundler.Bundler))
+		return true
+	})
 	wg.Wait()
 }
 
 // IsPaused checks if the bundler associated with an ordering key is
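
Note: sync.Map makes each Load, Store, and Delete safe on its own, but the
scheduler still takes s.mu around the Load-then-Store sequences above,
because the pair is not atomic. A minimal standalone sketch of that
pattern (not part of the patch; names are illustrative):

	package main

	import (
		"fmt"
		"sync"
	)

	type refCounter struct {
		mu     sync.Mutex
		counts sync.Map // string -> int
	}

	// add adjusts the count for key by n and deletes the entry when it
	// reaches zero, so the map cannot grow with every key ever seen.
	func (r *refCounter) add(key string, n int) int {
		r.mu.Lock() // serializes the Load-then-Store read-modify-write
		defer r.mu.Unlock()
		cur := 0
		if v, ok := r.counts.Load(key); ok {
			cur = v.(int)
		}
		cur += n
		if cur == 0 {
			r.counts.Delete(key)
		} else {
			r.counts.Store(key, cur)
		}
		return cur
	}

	func main() {
		var r refCounter
		r.add("k", 2)
		fmt.Println(r.add("k", -2)) // 0; the entry for "k" is gone
	}

Deleting the entry at zero keeps per-key state from accumulating across
ordering keys.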
diff --git a/pubsub/loadtest/benchmark_test.go b/pubsub/loadtest/benchmark_test.go
index 885caa2..611002e 100644
--- a/pubsub/loadtest/benchmark_test.go
+++ b/pubsub/loadtest/benchmark_test.go
@@ -46,6 +46,7 @@
 	batchDuration           = 50 * time.Millisecond
 	serverDelay             = 200 * time.Millisecond
 	maxOutstandingPublishes = 1600 // max_outstanding_messages in run.py
+	useOrdered              = true
 )
 
 func BenchmarkPublishThroughput(b *testing.B) {
@@ -53,7 +54,8 @@
 	client := perfClient(serverDelay, 1, b)
 
 	lts := &PubServer{ID: "xxx"}
-	lts.init(client, "t", messageSize, batchSize, batchDuration)
+	lts.init(client, "t", messageSize, batchSize, batchDuration, useOrdered)
+
 	b.ResetTimer()
 	for i := 0; i < b.N; i++ {
 		runOnce(lts)
diff --git a/pubsub/loadtest/loadtest.go b/pubsub/loadtest/loadtest.go
index d6e9c1b..cc1c734 100644
--- a/pubsub/loadtest/loadtest.go
+++ b/pubsub/loadtest/loadtest.go
@@ -22,6 +22,7 @@
 	"bytes"
 	"context"
 	"errors"
+	"fmt"
 	"log"
 	"runtime"
 	"strconv"
@@ -38,6 +39,7 @@
 	topic     *pubsub.Topic
 	msgData   []byte
 	batchSize int32
+	ordered   bool
 }
 
 // PubServer is a dummy Pub/Sub server for load testing.
@@ -56,23 +58,26 @@
 		return nil, err
 	}
 	dur := req.PublishBatchDuration.AsDuration()
-	l.init(c, req.Topic, req.MessageSize, req.PublishBatchSize, dur)
+	l.init(c, req.Topic, req.MessageSize, req.PublishBatchSize, dur, false)
 	log.Println("started")
 	return &pb.StartResponse{}, nil
 }
 
-func (l *PubServer) init(c *pubsub.Client, topicName string, msgSize, batchSize int32, batchDur time.Duration) {
+func (l *PubServer) init(c *pubsub.Client, topicName string, msgSize, batchSize int32, batchDur time.Duration, ordered bool) {
 	topic := c.Topic(topicName)
 	topic.PublishSettings = pubsub.PublishSettings{
-		DelayThreshold: batchDur,
-		CountThreshold: 950,
-		ByteThreshold:  9500000,
+		DelayThreshold:    batchDur,
+		CountThreshold:    950,
+		ByteThreshold:     9500000,
+		BufferedByteLimit: 2e9,
 	}
+	topic.EnableMessageOrdering = ordered
 
 	l.cfg.Store(pubServerConfig{
 		topic:     topic,
 		msgData:   bytes.Repeat([]byte{'A'}, int(msgSize)),
 		batchSize: batchSize,
+		ordered:   ordered,
 	})
 }
 
@@ -101,14 +106,18 @@
 
 	rs := make([]*pubsub.PublishResult, cfg.batchSize)
 	for i := int32(0); i < cfg.batchSize; i++ {
-		rs[i] = cfg.topic.Publish(context.TODO(), &pubsub.Message{
+		msg := &pubsub.Message{
 			Data: cfg.msgData,
 			Attributes: map[string]string{
 				"sendTime":       startStr,
 				"clientId":       l.ID,
 				"sequenceNumber": strconv.Itoa(int(seqNum + i)),
 			},
-		})
+		}
+		if cfg.ordered {
+			msg.OrderingKey = fmt.Sprintf("key-%d", seqNum+i)
+		}
+		rs[i] = cfg.topic.Publish(context.TODO(), msg)
 	}
 	for i, r := range rs {
 		_, err := r.Get(context.Background())
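
Note: with ordered set, the load test exercises the same client API an
application would use for ordered delivery. A minimal standalone sketch
(my-project and my-topic are placeholders):

	package main

	import (
		"context"
		"fmt"
		"log"

		"cloud.google.com/go/pubsub"
	)

	func main() {
		ctx := context.Background()
		client, err := pubsub.NewClient(ctx, "my-project")
		if err != nil {
			log.Fatal(err)
		}
		defer client.Close()

		topic := client.Topic("my-topic")
		// Ordering must be enabled on the topic handle before publishing
		// messages that carry an OrderingKey.
		topic.EnableMessageOrdering = true

		res := topic.Publish(ctx, &pubsub.Message{
			Data:        []byte("hello"),
			OrderingKey: "key-1",
		})
		id, err := res.Get(ctx)
		if err != nil {
			log.Fatal(err)
		}
		fmt.Println("published", id)
	}

The benchmark gives every message its own key (key-%d), so it stresses
per-key scheduler state rather than strict serialization of one stream.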
diff --git a/pubsub/pstest/fake.go b/pubsub/pstest/fake.go
index fcc546b..23b5bc5 100644
--- a/pubsub/pstest/fake.go
+++ b/pubsub/pstest/fake.go
@@ -185,8 +185,6 @@
 // AddPublishResponse adds a new publish response to the channel used for
 // responding to publish requests.
 func (s *Server) AddPublishResponse(pbr *pb.PublishResponse, err error) {
-	s.GServer.mu.Lock()
-	defer s.GServer.mu.Unlock()
 	pr := &publishResponse{}
 	if err != nil {
 		pr.err = err
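
Note: the removed locking matters once auto-publish responses are
disabled: a Publish call can then block waiting for a manually queued
response, and AddPublishResponse must stay callable while it waits. A
standalone sketch of the deadlock shape this avoids (illustrative only,
not pstest's actual internals):

	package main

	import (
		"fmt"
		"sync"
	)

	type fakeServer struct {
		mu        sync.Mutex
		responses chan string
	}

	// handle blocks while holding mu until a response is queued,
	// mirroring a publish that waits for AddPublishResponse.
	func (s *fakeServer) handle() string {
		s.mu.Lock()
		defer s.mu.Unlock()
		return <-s.responses
	}

	// addResponse must not take s.mu: handle already holds it while
	// waiting, so locking here would deadlock. The channel send is
	// goroutine-safe on its own.
	func (s *fakeServer) addResponse(r string) {
		s.responses <- r
	}

	func main() {
		s := &fakeServer{responses: make(chan string, 1)}
		done := make(chan string)
		go func() { done <- s.handle() }()
		s.addResponse("id-1")
		fmt.Println(<-done)
	}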
diff --git a/pubsub/pstest/fake_test.go b/pubsub/pstest/fake_test.go
index 4546f46..279bcc0 100644
--- a/pubsub/pstest/fake_test.go
+++ b/pubsub/pstest/fake_test.go
@@ -1076,4 +1076,18 @@
 	if want := "2"; got != want {
 		t.Fatalf("srv.Publish(): got %v, want %v", got, want)
 	}
+
+	done := make(chan struct{})
+	go func() {
+		defer close(done)
+		got := srv.Publish("projects/p/topics/t", []byte("msg4"), nil)
+		if want := "3"; got != want {
+			t.Errorf("srv.Publish(): got %v, want %v", got, want)
+		}
+	}()
+	time.Sleep(5 * time.Second) // give the goroutine time to block in Publish
+	srv.AddPublishResponse(&pb.PublishResponse{
+		MessageIds: []string{"3"},
+	}, nil)
+	<-done
 }
diff --git a/pubsub/topic_test.go b/pubsub/topic_test.go
index 1707ba1..d216cd3 100644
--- a/pubsub/topic_test.go
+++ b/pubsub/topic_test.go
@@ -321,31 +321,63 @@
 
 	// Subsequent publishes after a flush should succeed.
 	topic.Flush()
-	r := topic.Publish(ctx, &Message{
+	r1 := topic.Publish(ctx, &Message{
 		Data: []byte("hello"),
 	})
-	_, err = r.Get(ctx)
+	_, err = r1.Get(ctx)
 	if err != nil {
 		t.Errorf("got err: %v", err)
 	}
 
 	// Publishing after a flush should succeed.
 	topic.Flush()
-	r = topic.Publish(ctx, &Message{
+	r2 := topic.Publish(ctx, &Message{
 		Data: []byte("world"),
 	})
-	_, err = r.Get(ctx)
+	_, err = r2.Get(ctx)
 	if err != nil {
 		t.Errorf("got err: %v", err)
 	}
 
+	// Publishing after a temporarily blocked flush should succeed.
+	srv.SetAutoPublishResponse(false)
+
+	r3 := topic.Publish(ctx, &Message{
+		Data: []byte("blocking message publish"),
+	})
+	go func() {
+		topic.Flush()
+	}()
+
+	// Wait a second between publishes to ensure messages are not bundled together.
+	time.Sleep(1 * time.Second)
+	r4 := topic.Publish(ctx, &Message{
+		Data: []byte("message published after flush"),
+	})
+
+	// Wait 5 seconds to simulate network delay.
+	time.Sleep(5 * time.Second)
+	srv.AddPublishResponse(&pubsubpb.PublishResponse{
+		MessageIds: []string{"1"},
+	}, nil)
+	srv.AddPublishResponse(&pubsubpb.PublishResponse{
+		MessageIds: []string{"2"},
+	}, nil)
+
+	if _, err = r3.Get(ctx); err != nil {
+		t.Errorf("got err: %v", err)
+	}
+	if _, err = r4.Get(ctx); err != nil {
+		t.Errorf("got err: %v", err)
+	}
+
 	// Publishing after Stop should fail.
+	srv.SetAutoPublishResponse(true)
 	topic.Stop()
-	r = topic.Publish(ctx, &Message{
+	r5 := topic.Publish(ctx, &Message{
 		Data: []byte("this should fail"),
 	})
-	_, err = r.Get(ctx)
-	if err != errTopicStopped {
+	if _, err := r5.Get(ctx); err != errTopicStopped {
 		t.Errorf("got %v, want errTopicStopped", err)
 	}
 }
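
Note: the flush assertions above depend on Flush blocking until every
outstanding message has been sent and on Stop failing later publishes. A
minimal standalone sketch against the pstest fake (project and topic IDs
are placeholders):

	package main

	import (
		"context"
		"fmt"
		"log"

		"cloud.google.com/go/pubsub"
		"cloud.google.com/go/pubsub/pstest"
		"google.golang.org/api/option"
		"google.golang.org/grpc"
	)

	func main() {
		ctx := context.Background()
		srv := pstest.NewServer()
		defer srv.Close()

		conn, err := grpc.Dial(srv.Addr, grpc.WithInsecure())
		if err != nil {
			log.Fatal(err)
		}
		client, err := pubsub.NewClient(ctx, "my-project", option.WithGRPCConn(conn))
		if err != nil {
			log.Fatal(err)
		}
		defer client.Close()

		topic, err := client.CreateTopic(ctx, "t")
		if err != nil {
			log.Fatal(err)
		}

		res := topic.Publish(ctx, &pubsub.Message{Data: []byte("hello")})
		topic.Flush() // blocks until everything outstanding is sent
		id, err := res.Get(ctx)
		if err != nil {
			log.Fatal(err)
		}
		fmt.Println("published", id)

		topic.Stop() // further publishes fail with errTopicStopped
		if _, err := topic.Publish(ctx, &pubsub.Message{Data: []byte("x")}).Get(ctx); err != nil {
			fmt.Println("publish after Stop:", err)
		}
	}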