fix(pubsub): fix memory leak issue in publish scheduler (#4282)
diff --git a/pubsub/internal/scheduler/publish_scheduler.go b/pubsub/internal/scheduler/publish_scheduler.go
index 9dbf51c..cba172b 100644
--- a/pubsub/internal/scheduler/publish_scheduler.go
+++ b/pubsub/internal/scheduler/publish_scheduler.go
@@ -39,8 +39,8 @@
BufferedByteLimit int
mu sync.Mutex
- bundlers map[string]*bundler.Bundler
- outstanding map[string]int
+ bundlers sync.Map // ordering key -> *bundler.Bundler
+ outstanding sync.Map // ordering key -> number of outstanding messages
keysMu sync.RWMutex
// keysWithErrors tracks ordering keys that cannot accept new messages.
@@ -76,8 +76,6 @@
}
s := PublishScheduler{
- bundlers: make(map[string]*bundler.Bundler),
- outstanding: make(map[string]int),
keysWithErrors: make(map[string]struct{}),
workers: make(chan struct{}, workers),
handle: handle,
@@ -106,9 +104,11 @@
s.mu.Lock()
defer s.mu.Unlock()
- b, ok := s.bundlers[key]
+ var b *bundler.Bundler
+ bInterface, ok := s.bundlers.Load(key)
+
if !ok {
- s.outstanding[key] = 1
+ s.outstanding.Store(key, 1)
b = bundler.NewBundler(item, func(bundle interface{}) {
s.workers <- struct{}{}
s.handle(bundle)
@@ -116,10 +116,11 @@
nlen := reflect.ValueOf(bundle).Len()
s.mu.Lock()
- s.outstanding[key] -= nlen
- if s.outstanding[key] == 0 {
- delete(s.outstanding, key)
- delete(s.bundlers, key)
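+ // Decrement this key's outstanding count; when it reaches zero, both the
+ // counter and the bundler entry are removed so completed keys are not retained.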
+ outsInterface, _ := s.outstanding.Load(key)
+ newOutstanding := outsInterface.(int) - nlen
+ s.outstanding.Store(key, newOutstanding)
+ if newOutstanding == 0 {
+ s.outstanding.Delete(key)
+ s.bundlers.Delete(key)
}
s.mu.Unlock()
})
@@ -142,9 +143,13 @@
b.HandlerLimit = 1
}
- s.bundlers[key] = b
+ s.bundlers.Store(key, b)
+ } else {
+ b = bInterface.(*bundler.Bundler)
+ oi, _ := s.outstanding.Load(key)
+ s.outstanding.Store(key, oi.(int)+1)
}
- s.outstanding[key]++
+
return b.Add(item, size)
}
@@ -152,22 +157,25 @@
// blocks until all items have been flushed.
func (s *PublishScheduler) FlushAndStop() {
close(s.done)
- for _, b := range s.bundlers {
- b.Flush()
- }
+ s.bundlers.Range(func(_, bi interface{}) bool {
+ bi.(*bundler.Bundler).Flush()
+ return true
+ })
}
// Flush waits until all bundlers are sent.
func (s *PublishScheduler) Flush() {
var wg sync.WaitGroup
- for _, b := range s.bundlers {
+ s.bundlers.Range(func(_, bi interface{}) bool {
wg.Add(1)
go func(b *bundler.Bundler) {
defer wg.Done()
b.Flush()
- }(b)
- }
+ }(bi.(*bundler.Bundler))
+ return true
+ })
wg.Wait()
}
// IsPaused checks if the bundler associated with an ordering key is
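For readers unfamiliar with the pattern used above, here is a minimal, self-contained sketch of per-key counting with sync.Map, where the entry is deleted once the count returns to zero. The type and method names (keyTracker, Inc, Done) are invented for illustration and are not part of this package. Note that Load-then-Store is not atomic on its own; the scheduler serializes these updates under s.mu, and this sketch assumes the caller does the same.

```go
package main

import (
	"fmt"
	"sync"
)

// keyTracker counts outstanding items per ordering key, deleting the
// entry when the count returns to zero so idle keys are not retained.
type keyTracker struct {
	counts sync.Map // key -> int
}

// Inc records one more outstanding item for key.
// Callers must serialize Inc/Done for the same key (the scheduler does
// this under s.mu); sync.Map does not make load-modify-store atomic.
func (t *keyTracker) Inc(key string) {
	if v, ok := t.counts.Load(key); ok {
		t.counts.Store(key, v.(int)+1)
		return
	}
	t.counts.Store(key, 1)
}

// Done subtracts n finished items and removes the key at zero.
func (t *keyTracker) Done(key string, n int) {
	v, ok := t.counts.Load(key)
	if !ok {
		return
	}
	if rem := v.(int) - n; rem > 0 {
		t.counts.Store(key, rem)
		return
	}
	t.counts.Delete(key)
}

func main() {
	var t keyTracker
	t.Inc("a")
	t.Inc("a")
	t.Done("a", 2)
	_, stillTracked := t.counts.Load("a")
	fmt.Println("key still tracked:", stillTracked) // false: entry removed at zero
}
```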
diff --git a/pubsub/loadtest/benchmark_test.go b/pubsub/loadtest/benchmark_test.go
index 885caa2..611002e 100644
--- a/pubsub/loadtest/benchmark_test.go
+++ b/pubsub/loadtest/benchmark_test.go
@@ -46,6 +46,7 @@
batchDuration = 50 * time.Millisecond
serverDelay = 200 * time.Millisecond
maxOutstandingPublishes = 1600 // max_outstanding_messages in run.py
+ useOrdered = true
)
func BenchmarkPublishThroughput(b *testing.B) {
@@ -53,7 +54,8 @@
client := perfClient(serverDelay, 1, b)
lts := &PubServer{ID: "xxx"}
- lts.init(client, "t", messageSize, batchSize, batchDuration)
+ lts.init(client, "t", messageSize, batchSize, batchDuration, useOrdered)
+
b.ResetTimer()
for i := 0; i < b.N; i++ {
runOnce(lts)
diff --git a/pubsub/loadtest/loadtest.go b/pubsub/loadtest/loadtest.go
index d6e9c1b..cc1c734 100644
--- a/pubsub/loadtest/loadtest.go
+++ b/pubsub/loadtest/loadtest.go
@@ -22,6 +22,7 @@
"bytes"
"context"
"errors"
+ "fmt"
"log"
"runtime"
"strconv"
@@ -38,6 +39,7 @@
topic *pubsub.Topic
msgData []byte
batchSize int32
+ ordered bool
}
// PubServer is a dummy Pub/Sub server for load testing.
@@ -56,23 +58,26 @@
return nil, err
}
dur := req.PublishBatchDuration.AsDuration()
- l.init(c, req.Topic, req.MessageSize, req.PublishBatchSize, dur)
+ l.init(c, req.Topic, req.MessageSize, req.PublishBatchSize, dur, false)
log.Println("started")
return &pb.StartResponse{}, nil
}
-func (l *PubServer) init(c *pubsub.Client, topicName string, msgSize, batchSize int32, batchDur time.Duration) {
+func (l *PubServer) init(c *pubsub.Client, topicName string, msgSize, batchSize int32, batchDur time.Duration, ordered bool) {
topic := c.Topic(topicName)
topic.PublishSettings = pubsub.PublishSettings{
- DelayThreshold: batchDur,
- CountThreshold: 950,
- ByteThreshold: 9500000,
+ DelayThreshold: batchDur,
+ CountThreshold: 950,
+ ByteThreshold: 9500000,
+ BufferedByteLimit: 2e9,
}
+ topic.EnableMessageOrdering = ordered
l.cfg.Store(pubServerConfig{
topic: topic,
msgData: bytes.Repeat([]byte{'A'}, int(msgSize)),
batchSize: batchSize,
+ ordered: ordered,
})
}
@@ -101,14 +106,18 @@
rs := make([]*pubsub.PublishResult, cfg.batchSize)
for i := int32(0); i < cfg.batchSize; i++ {
- rs[i] = cfg.topic.Publish(context.TODO(), &pubsub.Message{
+ msg := &pubsub.Message{
Data: cfg.msgData,
Attributes: map[string]string{
"sendTime": startStr,
"clientId": l.ID,
"sequenceNumber": strconv.Itoa(int(seqNum + i)),
},
- })
+ }
+ if cfg.ordered {
+ msg.OrderingKey = fmt.Sprintf("key-%d", seqNum+i)
+ }
+ rs[i] = cfg.topic.Publish(context.TODO(), msg)
}
for i, r := range rs {
_, err := r.Get(context.Background())
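The load test now drives ordered publishing end to end. Below is a minimal sketch of the same client-side pattern against cloud.google.com/go/pubsub, with a placeholder project and topic; it illustrates the API the load test exercises and is not part of this change.

```go
package main

import (
	"context"
	"fmt"
	"log"

	"cloud.google.com/go/pubsub"
)

func main() {
	ctx := context.Background()
	// "my-project" and "my-topic" are placeholders.
	client, err := pubsub.NewClient(ctx, "my-project")
	if err != nil {
		log.Fatal(err)
	}
	defer client.Close()

	topic := client.Topic("my-topic")
	// Ordering must be enabled on the topic handle before publishing
	// messages that carry an OrderingKey.
	topic.EnableMessageOrdering = true

	res := topic.Publish(ctx, &pubsub.Message{
		Data:        []byte("hello"),
		OrderingKey: "key-1",
	})
	id, err := res.Get(ctx)
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println("published message", id)
}
```

Messages that share an OrderingKey are published in order relative to each other; the load test gives every message its own key, presumably to spread work across many per-key bundlers in the scheduler.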
diff --git a/pubsub/pstest/fake.go b/pubsub/pstest/fake.go
index fcc546b..23b5bc5 100644
--- a/pubsub/pstest/fake.go
+++ b/pubsub/pstest/fake.go
@@ -185,8 +185,6 @@
// AddPublishResponse adds a new publish response to the channel used for
// responding to publish requests.
func (s *Server) AddPublishResponse(pbr *pb.PublishResponse, err error) {
- s.GServer.mu.Lock()
- defer s.GServer.mu.Unlock()
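+ // Intentionally does not hold GServer.mu, so this can be called while a
+ // Publish is blocked waiting for a response (see the new test in fake_test.go).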
pr := &publishResponse{}
if err != nil {
pr.err = err
diff --git a/pubsub/pstest/fake_test.go b/pubsub/pstest/fake_test.go
index 4546f46..279bcc0 100644
--- a/pubsub/pstest/fake_test.go
+++ b/pubsub/pstest/fake_test.go
@@ -1076,4 +1076,15 @@
if want := "2"; got != want {
t.Fatalf("srv.Publish(): got %v, want %v", got, want)
}
+
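+ // Publish in a goroutine: it blocks until a matching response is
+ // added via AddPublishResponse below.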
+ go func() {
+ got = srv.Publish("projects/p/topics/t", []byte("msg4"), nil)
+ if want := "3"; got != want {
+ t.Errorf("srv.Publish(): got %v, want %v", got, want)
+ }
+ }()
+ time.Sleep(5 * time.Second)
+ srv.AddPublishResponse(&pb.PublishResponse{
+ MessageIds: []string{"3"},
+ }, nil)
}
diff --git a/pubsub/topic_test.go b/pubsub/topic_test.go
index 1707ba1..d216cd3 100644
--- a/pubsub/topic_test.go
+++ b/pubsub/topic_test.go
@@ -321,31 +321,63 @@
// Subsequent publishes after a flush should succeed.
topic.Flush()
- r := topic.Publish(ctx, &Message{
+ r1 := topic.Publish(ctx, &Message{
Data: []byte("hello"),
})
- _, err = r.Get(ctx)
+ _, err = r1.Get(ctx)
if err != nil {
t.Errorf("got err: %v", err)
}
// Publishing after a flush should succeed.
topic.Flush()
- r = topic.Publish(ctx, &Message{
+ r2 := topic.Publish(ctx, &Message{
Data: []byte("world"),
})
- _, err = r.Get(ctx)
+ _, err = r2.Get(ctx)
if err != nil {
t.Errorf("got err: %v", err)
}
+ // Publishing after a temporarily blocked flush should succeed.
+ srv.SetAutoPublishResponse(false)
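+ // With auto-response disabled, the publishes below will not complete
+ // until AddPublishResponse supplies a response.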
+
+ r3 := topic.Publish(ctx, &Message{
+ Data: []byte("blocking message publish"),
+ })
+ go func() {
+ topic.Flush()
+ }()
+
+ // Wait a second between publishes to ensure messages are not bundled together.
+ time.Sleep(1 * time.Second)
+ r4 := topic.Publish(ctx, &Message{
+ Data: []byte("message published after flush"),
+ })
+
+ // Wait 5 seconds to simulate network delay.
+ time.Sleep(5 * time.Second)
+ srv.AddPublishResponse(&pubsubpb.PublishResponse{
+ MessageIds: []string{"1"},
+ }, nil)
+ srv.AddPublishResponse(&pubsubpb.PublishResponse{
+ MessageIds: []string{"2"},
+ }, nil)
+
+ if _, err = r3.Get(ctx); err != nil {
+ t.Errorf("got err: %v", err)
+ }
+ if _, err = r4.Get(ctx); err != nil {
+ t.Errorf("got err: %v", err)
+ }
+
// Publishing after Stop should fail.
+ srv.SetAutoPublishResponse(true)
topic.Stop()
- r = topic.Publish(ctx, &Message{
+ r5 := topic.Publish(ctx, &Message{
Data: []byte("this should fail"),
})
- _, err = r.Get(ctx)
- if err != errTopicStopped {
+ if _, err := r5.Get(ctx); err != errTopicStopped {
t.Errorf("got %v, want errTopicStopped", err)
}
}