fix(pubsublite): rebatch messages upon new publish stream (#3694)
Merge in-flight message batches when a new publish stream is connected.
diff --git a/pubsublite/internal/wire/publish_batcher.go b/pubsublite/internal/wire/publish_batcher.go
index 63999b1..e50d3af 100644
--- a/pubsublite/internal/wire/publish_batcher.go
+++ b/pubsublite/internal/wire/publish_batcher.go
@@ -41,6 +41,7 @@
// MessagePublishRequest.
type publishBatch struct {
msgHolders []*messageHolder
+ totalSize int // Sum of the size fields of the holders in msgHolders; set when a batch is built and updated when batches are merged.
}
func (b *publishBatch) ToPublishRequest() *pb.PublishRequest {
@@ -93,7 +94,11 @@
// singlePartitionPublisher.onNewBatch() receives the new batch from the
// Bundler, which calls publishMessageBatcher.AddBatch(). Only the
// publisher's mutex is required.
- onNewBatch(&publishBatch{msgHolders: msgs})
+ batch := &publishBatch{msgHolders: msgs}
+ for _, msg := range batch.msgHolders {
+ batch.totalSize += msg.size
+ }
+ onNewBatch(batch)
})
msgBundler.DelayThreshold = settings.DelayThreshold
msgBundler.BundleCountThreshold = settings.CountThreshold
@@ -164,10 +169,24 @@
func (b *publishMessageBatcher) InFlightBatches() []*publishBatch { // Returns the batches in publishQueue after folding each batch into its predecessor where limits allow; this mutates the queue and the surviving batches.
var batches []*publishBatch
- for elem := b.publishQueue.Front(); elem != nil; elem = elem.Next() {
- if batch, ok := elem.Value.(*publishBatch); ok {
- batches = append(batches, batch)
+ for elem := b.publishQueue.Front(); elem != nil; { // No post statement: elem is advanced in the body because the current element may be removed.
+ batch := elem.Value.(*publishBatch) // Unchecked assertion (the replaced loop used the comma-ok form); panics if the value is not a *publishBatch.
+ if elem.Prev() != nil {
+ // Merge current batch with previous if within max bytes and count limits. NOTE(review): presumably publishQueue is a container/list.List, so after a removal Prev() is the last batch kept and several batches can accumulate into it - confirm.
+ prevBatch := elem.Prev().Value.(*publishBatch)
+ totalSize := prevBatch.totalSize + batch.totalSize
+ totalLen := len(prevBatch.msgHolders) + len(batch.msgHolders)
+ if totalSize <= MaxPublishRequestBytes && totalLen <= MaxPublishRequestCount {
+ prevBatch.totalSize = totalSize
+ prevBatch.msgHolders = append(prevBatch.msgHolders, batch.msgHolders...)
+ removeElem := elem
+ elem = elem.Next() // Step to the successor before the current element is removed from the queue.
+ b.publishQueue.Remove(removeElem)
+ continue // The merged batch is not returned on its own; its messages now belong to prevBatch.
+ }
}
+ batches = append(batches, batch)
+ elem = elem.Next()
}
return batches
}
diff --git a/pubsublite/internal/wire/publish_batcher_test.go b/pubsublite/internal/wire/publish_batcher_test.go
index 690c7ed..014bd0d 100644
--- a/pubsublite/internal/wire/publish_batcher_test.go
+++ b/pubsublite/internal/wire/publish_batcher_test.go
@@ -128,8 +128,8 @@
}
}
- if !testutil.Equal(got, want, cmp.AllowUnexported(publishBatch{}, messageHolder{})) {
- br.t.Errorf("Batches got: %v\nwant: %v", got, want)
+ if diff := testutil.Diff(got, want, cmp.AllowUnexported(publishBatch{}, messageHolder{})); diff != "" {
+ br.t.Errorf("Batches got: -, want: +\n%s", diff)
}
}
@@ -144,6 +144,15 @@
return h
}
+func makePublishBatch(msgs ...*messageHolder) *publishBatch { // Test helper: builds a publishBatch holding msgs in order, with totalSize set to the sum of their size fields.
+ batch := new(publishBatch)
+ for _, msg := range msgs {
+ batch.msgHolders = append(batch.msgHolders, msg)
+ batch.totalSize += msg.size
+ }
+ return batch // With no msgs, msgHolders stays nil and totalSize stays 0.
+}
+
func TestPublishBatcherAddMessage(t *testing.T) {
const initAvailableBytes = MaxPublishRequestBytes
settings := DefaultPublishSettings
@@ -199,22 +208,16 @@
// Batch 1
msg1 := &pb.PubSubMessage{Data: []byte{'1'}}
msg2 := &pb.PubSubMessage{Data: []byte{'2'}}
- wantBatch1 := &publishBatch{
- []*messageHolder{makeMsgHolder(msg1), makeMsgHolder(msg2)},
- }
+ wantBatch1 := makePublishBatch(makeMsgHolder(msg1), makeMsgHolder(msg2))
// Batch 2
msg3 := &pb.PubSubMessage{Data: []byte{'3'}}
msg4 := &pb.PubSubMessage{Data: []byte{'4'}}
- wantBatch2 := &publishBatch{
- []*messageHolder{makeMsgHolder(msg3), makeMsgHolder(msg4)},
- }
+ wantBatch2 := makePublishBatch(makeMsgHolder(msg3), makeMsgHolder(msg4))
// Batch 3
msg5 := &pb.PubSubMessage{Data: []byte{'5'}}
- wantBatch3 := &publishBatch{
- []*messageHolder{makeMsgHolder(msg5)},
- }
+ wantBatch3 := makePublishBatch(makeMsgHolder(msg5))
receiver := newTestPublishBatchReceiver(t)
batcher := newPublishMessageBatcher(&settings, 0, receiver.onNewBatch)
@@ -236,15 +239,11 @@
// Batch 1
msg1 := &pb.PubSubMessage{Data: []byte{'1'}}
- wantBatch1 := &publishBatch{
- []*messageHolder{makeMsgHolder(msg1)},
- }
+ wantBatch1 := makePublishBatch(makeMsgHolder(msg1))
// Batch 2
msg2 := &pb.PubSubMessage{Data: []byte{'2'}}
- wantBatch2 := &publishBatch{
- []*messageHolder{makeMsgHolder(msg2)},
- }
+ wantBatch2 := makePublishBatch(makeMsgHolder(msg2))
receiver := newTestPublishBatchReceiver(t)
batcher := newPublishMessageBatcher(&settings, 0, receiver.onNewBatch)
@@ -271,12 +270,7 @@
msg2 := &pb.PubSubMessage{Data: []byte{'2'}}
pubResult1 := newTestPublishResultReceiver(t, msg1)
pubResult2 := newTestPublishResultReceiver(t, msg2)
- batcher.AddBatch(&publishBatch{
- []*messageHolder{
- makeMsgHolder(msg1, pubResult1),
- makeMsgHolder(msg2, pubResult2),
- },
- })
+ batcher.AddBatch(makePublishBatch(makeMsgHolder(msg1, pubResult1), makeMsgHolder(msg2, pubResult2)))
wantErr := status.Error(codes.FailedPrecondition, "failed")
batcher.OnPermanentError(wantErr)
@@ -306,17 +300,8 @@
pubResult2 := newTestPublishResultReceiver(t, msg2)
pubResult3 := newTestPublishResultReceiver(t, msg3)
- batcher.AddBatch(&publishBatch{
- []*messageHolder{
- makeMsgHolder(msg1, pubResult1),
- makeMsgHolder(msg2, pubResult2),
- },
- })
- batcher.AddBatch(&publishBatch{
- []*messageHolder{
- makeMsgHolder(msg3, pubResult3),
- },
- })
+ batcher.AddBatch(makePublishBatch(makeMsgHolder(msg1, pubResult1), makeMsgHolder(msg2, pubResult2)))
+ batcher.AddBatch(makePublishBatch(makeMsgHolder(msg3, pubResult3)))
if err := batcher.OnPublishResponse(70); err != nil {
t.Errorf("OnPublishResponse() got err: %v", err)
}
@@ -332,14 +317,126 @@
t.Run("inconsistent offset", func(t *testing.T) {
msg := &pb.PubSubMessage{Data: []byte{'4'}}
pubResult := newTestPublishResultReceiver(t, msg)
- batcher.AddBatch(&publishBatch{
- []*messageHolder{
- makeMsgHolder(msg, pubResult),
- },
- })
+ batcher.AddBatch(makePublishBatch(makeMsgHolder(msg, pubResult)))
if gotErr, wantMsg := batcher.OnPublishResponse(80), "inconsistent start offset = 80"; !test.ErrorHasMsg(gotErr, wantMsg) {
t.Errorf("OnPublishResponse() got err: %v, want err msg: %q", gotErr, wantMsg)
}
})
}
+
+func TestPublishBatcherRebatching(t *testing.T) { // Checks how InFlightBatches merges batches that were queued with AddBatch; each subtest uses a fresh batcher.
+ const partition = 2
+ receiver := newTestPublishBatchReceiver(t)
+
+ t.Run("single batch", func(t *testing.T) { // A lone batch has no predecessor and is expected back unchanged.
+ msg1 := &pb.PubSubMessage{Data: []byte{'1'}}
+
+ batcher := newPublishMessageBatcher(&DefaultPublishSettings, partition, receiver.onNewBatch)
+ batcher.AddBatch(makePublishBatch(makeMsgHolder(msg1)))
+
+ got := batcher.InFlightBatches()
+ want := []*publishBatch{
+ makePublishBatch(makeMsgHolder(msg1)),
+ }
+ if diff := testutil.Diff(got, want, cmp.AllowUnexported(publishBatch{}, messageHolder{})); diff != "" {
+ t.Errorf("Batches got: -, want: +\n%s", diff)
+ }
+ })
+
+ t.Run("merge into single batch", func(t *testing.T) { // Three small batches are expected to collapse into one, keeping message order.
+ msg1 := &pb.PubSubMessage{Data: bytes.Repeat([]byte{1}, 100)}
+ msg2 := &pb.PubSubMessage{Data: bytes.Repeat([]byte{2}, 200)}
+ msg3 := &pb.PubSubMessage{Data: bytes.Repeat([]byte{3}, 300)}
+ msg4 := &pb.PubSubMessage{Data: bytes.Repeat([]byte{4}, 400)}
+
+ batcher := newPublishMessageBatcher(&DefaultPublishSettings, partition, receiver.onNewBatch)
+ batcher.AddBatch(makePublishBatch(makeMsgHolder(msg1)))
+ batcher.AddBatch(makePublishBatch(makeMsgHolder(msg2), makeMsgHolder(msg3)))
+ batcher.AddBatch(makePublishBatch(makeMsgHolder(msg4)))
+
+ got := batcher.InFlightBatches()
+ want := []*publishBatch{
+ makePublishBatch(makeMsgHolder(msg1), makeMsgHolder(msg2), makeMsgHolder(msg3), makeMsgHolder(msg4)),
+ }
+ if diff := testutil.Diff(got, want, cmp.AllowUnexported(publishBatch{}, messageHolder{})); diff != "" {
+ t.Errorf("Batches got: -, want: +\n%s", diff)
+ }
+ })
+
+ t.Run("no rebatching", func(t *testing.T) { // No adjacent pair is expected to be merged.
+ msg1 := &pb.PubSubMessage{Data: bytes.Repeat([]byte{1}, MaxPublishRequestBytes-10)}
+ msg2 := &pb.PubSubMessage{Data: bytes.Repeat([]byte{2}, MaxPublishRequestBytes/2)}
+ msg3 := &pb.PubSubMessage{Data: bytes.Repeat([]byte{3}, MaxPublishRequestBytes/2)} // NOTE(review): the payloads of msg2 and msg3 alone sum to at most the byte limit, so they stay separate only if messageHolder.size exceeds the raw data length - confirm against makeMsgHolder.
+
+ batcher := newPublishMessageBatcher(&DefaultPublishSettings, partition, receiver.onNewBatch)
+ batcher.AddBatch(makePublishBatch(makeMsgHolder(msg1)))
+ batcher.AddBatch(makePublishBatch(makeMsgHolder(msg2)))
+ batcher.AddBatch(makePublishBatch(makeMsgHolder(msg3)))
+
+ got := batcher.InFlightBatches()
+ want := []*publishBatch{
+ makePublishBatch(makeMsgHolder(msg1)),
+ makePublishBatch(makeMsgHolder(msg2)),
+ makePublishBatch(makeMsgHolder(msg3)),
+ }
+ if diff := testutil.Diff(got, want, cmp.AllowUnexported(publishBatch{}, messageHolder{})); diff != "" {
+ t.Errorf("Batches got: -, want: +\n%s", diff)
+ }
+ })
+
+ t.Run("mixed rebatching", func(t *testing.T) {
+ // Should be merged into a single batch.
+ msg1 := &pb.PubSubMessage{Data: bytes.Repeat([]byte{1}, MaxPublishRequestBytes/2)}
+ msg2 := &pb.PubSubMessage{Data: bytes.Repeat([]byte{2}, 200)}
+ msg3 := &pb.PubSubMessage{Data: bytes.Repeat([]byte{3}, 300)}
+ // Not merged due to byte limit.
+ msg4 := &pb.PubSubMessage{Data: bytes.Repeat([]byte{4}, MaxPublishRequestBytes-500)}
+ msg5 := &pb.PubSubMessage{Data: bytes.Repeat([]byte{5}, 500)} // NOTE(review): the payloads of msg4 and msg5 sum to exactly MaxPublishRequestBytes; the expected split presumably relies on messageHolder.size including per-message overhead - confirm.
+
+ batcher := newPublishMessageBatcher(&DefaultPublishSettings, partition, receiver.onNewBatch)
+ batcher.AddBatch(makePublishBatch(makeMsgHolder(msg1)))
+ batcher.AddBatch(makePublishBatch(makeMsgHolder(msg2)))
+ batcher.AddBatch(makePublishBatch(makeMsgHolder(msg3)))
+ batcher.AddBatch(makePublishBatch(makeMsgHolder(msg4)))
+ batcher.AddBatch(makePublishBatch(makeMsgHolder(msg5)))
+
+ got := batcher.InFlightBatches()
+ want := []*publishBatch{
+ makePublishBatch(makeMsgHolder(msg1), makeMsgHolder(msg2), makeMsgHolder(msg3)),
+ makePublishBatch(makeMsgHolder(msg4)),
+ makePublishBatch(makeMsgHolder(msg5)),
+ }
+ if diff := testutil.Diff(got, want, cmp.AllowUnexported(publishBatch{}, messageHolder{})); diff != "" {
+ t.Errorf("Batches got: -, want: +\n%s", diff)
+ }
+ })
+
+ t.Run("max count", func(t *testing.T) { // The count limit alone is expected to stop merging once a batch holds MaxPublishRequestCount messages.
+ var msgs []*pb.PubSubMessage
+ var batch1 []*messageHolder
+ var batch2 []*messageHolder
+ batcher := newPublishMessageBatcher(&DefaultPublishSettings, partition, receiver.onNewBatch)
+ for i := 0; i <= MaxPublishRequestCount; i++ { // <= queues MaxPublishRequestCount+1 single-message batches.
+ msg := &pb.PubSubMessage{Data: []byte{'0'}}
+ msgs = append(msgs, msg) // NOTE(review): msgs is only appended to in this subtest and never read.
+
+ msgHolder := makeMsgHolder(msg)
+ if i < MaxPublishRequestCount {
+ batch1 = append(batch1, msgHolder)
+ } else {
+ batch2 = append(batch2, msgHolder) // Only the final message lands in the second expected batch.
+ }
+ batcher.AddBatch(makePublishBatch(msgHolder))
+ }
+
+ got := batcher.InFlightBatches()
+ want := []*publishBatch{
+ makePublishBatch(batch1...),
+ makePublishBatch(batch2...),
+ }
+ if diff := testutil.Diff(got, want, cmp.AllowUnexported(publishBatch{}, messageHolder{})); diff != "" {
+ t.Errorf("Batches got: -, want: +\n%s", diff)
+ }
+ })
+}
diff --git a/pubsublite/internal/wire/publisher_test.go b/pubsublite/internal/wire/publisher_test.go
index 1f58d3a..600d0dd 100644
--- a/pubsublite/internal/wire/publisher_test.go
+++ b/pubsublite/internal/wire/publisher_test.go
@@ -196,8 +196,7 @@
// The publisher should resend all in-flight batches to the second stream.
stream2 := test.NewRPCVerifier(t)
stream2.Push(initPubReq(topic), initPubResp(), nil)
- stream2.Push(msgPubReq(msg1), msgPubResp(0), nil)
- stream2.Push(msgPubReq(msg2), msgPubResp(1), nil)
+ stream2.Push(msgPubReq(msg1, msg2), msgPubResp(0), nil)
stream2.Push(msgPubReq(msg3), msgPubResp(2), nil)
verifiers.AddPublishStream(topic.Path, topic.Partition, stream2)