fix(pubsublite): rebatch messages upon new publish stream (#3694)

Merge in-flight message batches when a new publish stream is connected.
diff --git a/pubsublite/internal/wire/publish_batcher.go b/pubsublite/internal/wire/publish_batcher.go
index 63999b1..e50d3af 100644
--- a/pubsublite/internal/wire/publish_batcher.go
+++ b/pubsublite/internal/wire/publish_batcher.go
@@ -41,6 +41,7 @@
 // MessagePublishRequest.
 type publishBatch struct {
 	msgHolders []*messageHolder
+	totalSize  int
 }
 
 func (b *publishBatch) ToPublishRequest() *pb.PublishRequest {
@@ -93,7 +94,11 @@
 		// singlePartitionPublisher.onNewBatch() receives the new batch from the
 		// Bundler, which calls publishMessageBatcher.AddBatch(). Only the
 		// publisher's mutex is required.
-		onNewBatch(&publishBatch{msgHolders: msgs})
+		batch := &publishBatch{msgHolders: msgs}
+		for _, msg := range batch.msgHolders {
+			batch.totalSize += msg.size
+		}
+		onNewBatch(batch)
 	})
 	msgBundler.DelayThreshold = settings.DelayThreshold
 	msgBundler.BundleCountThreshold = settings.CountThreshold
@@ -164,10 +169,24 @@
 
 func (b *publishMessageBatcher) InFlightBatches() []*publishBatch {
 	var batches []*publishBatch
-	for elem := b.publishQueue.Front(); elem != nil; elem = elem.Next() {
-		if batch, ok := elem.Value.(*publishBatch); ok {
-			batches = append(batches, batch)
+	for elem := b.publishQueue.Front(); elem != nil; {
+		batch := elem.Value.(*publishBatch)
+		if elem.Prev() != nil {
+			// Merge current batch with previous if within max bytes and count limits.
+			prevBatch := elem.Prev().Value.(*publishBatch)
+			totalSize := prevBatch.totalSize + batch.totalSize
+			totalLen := len(prevBatch.msgHolders) + len(batch.msgHolders)
+			if totalSize <= MaxPublishRequestBytes && totalLen <= MaxPublishRequestCount {
+				prevBatch.totalSize = totalSize
+				prevBatch.msgHolders = append(prevBatch.msgHolders, batch.msgHolders...)
+				removeElem := elem
+				elem = elem.Next()
+				b.publishQueue.Remove(removeElem)
+				continue
+			}
 		}
+		batches = append(batches, batch)
+		elem = elem.Next()
 	}
 	return batches
 }
diff --git a/pubsublite/internal/wire/publish_batcher_test.go b/pubsublite/internal/wire/publish_batcher_test.go
index 690c7ed..014bd0d 100644
--- a/pubsublite/internal/wire/publish_batcher_test.go
+++ b/pubsublite/internal/wire/publish_batcher_test.go
@@ -128,8 +128,8 @@
 		}
 	}
 
-	if !testutil.Equal(got, want, cmp.AllowUnexported(publishBatch{}, messageHolder{})) {
-		br.t.Errorf("Batches got: %v\nwant: %v", got, want)
+	if diff := testutil.Diff(got, want, cmp.AllowUnexported(publishBatch{}, messageHolder{})); diff != "" {
+		br.t.Errorf("Batches got: -, want: +\n%s", diff)
 	}
 }
 
@@ -144,6 +144,15 @@
 	return h
 }
 
+func makePublishBatch(msgs ...*messageHolder) *publishBatch {
+	batch := new(publishBatch)
+	for _, msg := range msgs {
+		batch.msgHolders = append(batch.msgHolders, msg)
+		batch.totalSize += msg.size
+	}
+	return batch
+}
+
 func TestPublishBatcherAddMessage(t *testing.T) {
 	const initAvailableBytes = MaxPublishRequestBytes
 	settings := DefaultPublishSettings
@@ -199,22 +208,16 @@
 	// Batch 1
 	msg1 := &pb.PubSubMessage{Data: []byte{'1'}}
 	msg2 := &pb.PubSubMessage{Data: []byte{'2'}}
-	wantBatch1 := &publishBatch{
-		[]*messageHolder{makeMsgHolder(msg1), makeMsgHolder(msg2)},
-	}
+	wantBatch1 := makePublishBatch(makeMsgHolder(msg1), makeMsgHolder(msg2))
 
 	// Batch 2
 	msg3 := &pb.PubSubMessage{Data: []byte{'3'}}
 	msg4 := &pb.PubSubMessage{Data: []byte{'4'}}
-	wantBatch2 := &publishBatch{
-		[]*messageHolder{makeMsgHolder(msg3), makeMsgHolder(msg4)},
-	}
+	wantBatch2 := makePublishBatch(makeMsgHolder(msg3), makeMsgHolder(msg4))
 
 	// Batch 3
 	msg5 := &pb.PubSubMessage{Data: []byte{'5'}}
-	wantBatch3 := &publishBatch{
-		[]*messageHolder{makeMsgHolder(msg5)},
-	}
+	wantBatch3 := makePublishBatch(makeMsgHolder(msg5))
 
 	receiver := newTestPublishBatchReceiver(t)
 	batcher := newPublishMessageBatcher(&settings, 0, receiver.onNewBatch)
@@ -236,15 +239,11 @@
 
 	// Batch 1
 	msg1 := &pb.PubSubMessage{Data: []byte{'1'}}
-	wantBatch1 := &publishBatch{
-		[]*messageHolder{makeMsgHolder(msg1)},
-	}
+	wantBatch1 := makePublishBatch(makeMsgHolder(msg1))
 
 	// Batch 2
 	msg2 := &pb.PubSubMessage{Data: []byte{'2'}}
-	wantBatch2 := &publishBatch{
-		[]*messageHolder{makeMsgHolder(msg2)},
-	}
+	wantBatch2 := makePublishBatch(makeMsgHolder(msg2))
 
 	receiver := newTestPublishBatchReceiver(t)
 	batcher := newPublishMessageBatcher(&settings, 0, receiver.onNewBatch)
@@ -271,12 +270,7 @@
 	msg2 := &pb.PubSubMessage{Data: []byte{'2'}}
 	pubResult1 := newTestPublishResultReceiver(t, msg1)
 	pubResult2 := newTestPublishResultReceiver(t, msg2)
-	batcher.AddBatch(&publishBatch{
-		[]*messageHolder{
-			makeMsgHolder(msg1, pubResult1),
-			makeMsgHolder(msg2, pubResult2),
-		},
-	})
+	batcher.AddBatch(makePublishBatch(makeMsgHolder(msg1, pubResult1), makeMsgHolder(msg2, pubResult2)))
 
 	wantErr := status.Error(codes.FailedPrecondition, "failed")
 	batcher.OnPermanentError(wantErr)
@@ -306,17 +300,8 @@
 		pubResult2 := newTestPublishResultReceiver(t, msg2)
 		pubResult3 := newTestPublishResultReceiver(t, msg3)
 
-		batcher.AddBatch(&publishBatch{
-			[]*messageHolder{
-				makeMsgHolder(msg1, pubResult1),
-				makeMsgHolder(msg2, pubResult2),
-			},
-		})
-		batcher.AddBatch(&publishBatch{
-			[]*messageHolder{
-				makeMsgHolder(msg3, pubResult3),
-			},
-		})
+		batcher.AddBatch(makePublishBatch(makeMsgHolder(msg1, pubResult1), makeMsgHolder(msg2, pubResult2)))
+		batcher.AddBatch(makePublishBatch(makeMsgHolder(msg3, pubResult3)))
 		if err := batcher.OnPublishResponse(70); err != nil {
 			t.Errorf("OnPublishResponse() got err: %v", err)
 		}
@@ -332,14 +317,126 @@
 	t.Run("inconsistent offset", func(t *testing.T) {
 		msg := &pb.PubSubMessage{Data: []byte{'4'}}
 		pubResult := newTestPublishResultReceiver(t, msg)
-		batcher.AddBatch(&publishBatch{
-			[]*messageHolder{
-				makeMsgHolder(msg, pubResult),
-			},
-		})
+		batcher.AddBatch(makePublishBatch(makeMsgHolder(msg, pubResult)))
 
 		if gotErr, wantMsg := batcher.OnPublishResponse(80), "inconsistent start offset = 80"; !test.ErrorHasMsg(gotErr, wantMsg) {
 			t.Errorf("OnPublishResponse() got err: %v, want err msg: %q", gotErr, wantMsg)
 		}
 	})
 }
+
+func TestPublishBatcherRebatching(t *testing.T) {
+	const partition = 2
+	receiver := newTestPublishBatchReceiver(t)
+
+	t.Run("single batch", func(t *testing.T) {
+		msg1 := &pb.PubSubMessage{Data: []byte{'1'}}
+
+		batcher := newPublishMessageBatcher(&DefaultPublishSettings, partition, receiver.onNewBatch)
+		batcher.AddBatch(makePublishBatch(makeMsgHolder(msg1)))
+
+		got := batcher.InFlightBatches()
+		want := []*publishBatch{
+			makePublishBatch(makeMsgHolder(msg1)),
+		}
+		if diff := testutil.Diff(got, want, cmp.AllowUnexported(publishBatch{}, messageHolder{})); diff != "" {
+			t.Errorf("Batches got: -, want: +\n%s", diff)
+		}
+	})
+
+	t.Run("merge into single batch", func(t *testing.T) {
+		msg1 := &pb.PubSubMessage{Data: bytes.Repeat([]byte{1}, 100)}
+		msg2 := &pb.PubSubMessage{Data: bytes.Repeat([]byte{2}, 200)}
+		msg3 := &pb.PubSubMessage{Data: bytes.Repeat([]byte{3}, 300)}
+		msg4 := &pb.PubSubMessage{Data: bytes.Repeat([]byte{4}, 400)}
+
+		batcher := newPublishMessageBatcher(&DefaultPublishSettings, partition, receiver.onNewBatch)
+		batcher.AddBatch(makePublishBatch(makeMsgHolder(msg1)))
+		batcher.AddBatch(makePublishBatch(makeMsgHolder(msg2), makeMsgHolder(msg3)))
+		batcher.AddBatch(makePublishBatch(makeMsgHolder(msg4)))
+
+		got := batcher.InFlightBatches()
+		want := []*publishBatch{
+			makePublishBatch(makeMsgHolder(msg1), makeMsgHolder(msg2), makeMsgHolder(msg3), makeMsgHolder(msg4)),
+		}
+		if diff := testutil.Diff(got, want, cmp.AllowUnexported(publishBatch{}, messageHolder{})); diff != "" {
+			t.Errorf("Batches got: -, want: +\n%s", diff)
+		}
+	})
+
+	t.Run("no rebatching", func(t *testing.T) {
+		msg1 := &pb.PubSubMessage{Data: bytes.Repeat([]byte{1}, MaxPublishRequestBytes-10)}
+		msg2 := &pb.PubSubMessage{Data: bytes.Repeat([]byte{2}, MaxPublishRequestBytes/2)}
+		msg3 := &pb.PubSubMessage{Data: bytes.Repeat([]byte{3}, MaxPublishRequestBytes/2)}
+
+		batcher := newPublishMessageBatcher(&DefaultPublishSettings, partition, receiver.onNewBatch)
+		batcher.AddBatch(makePublishBatch(makeMsgHolder(msg1)))
+		batcher.AddBatch(makePublishBatch(makeMsgHolder(msg2)))
+		batcher.AddBatch(makePublishBatch(makeMsgHolder(msg3)))
+
+		got := batcher.InFlightBatches()
+		want := []*publishBatch{
+			makePublishBatch(makeMsgHolder(msg1)),
+			makePublishBatch(makeMsgHolder(msg2)),
+			makePublishBatch(makeMsgHolder(msg3)),
+		}
+		if diff := testutil.Diff(got, want, cmp.AllowUnexported(publishBatch{}, messageHolder{})); diff != "" {
+			t.Errorf("Batches got: -, want: +\n%s", diff)
+		}
+	})
+
+	t.Run("mixed rebatching", func(t *testing.T) {
+		// Should be merged into a single batch.
+		msg1 := &pb.PubSubMessage{Data: bytes.Repeat([]byte{1}, MaxPublishRequestBytes/2)}
+		msg2 := &pb.PubSubMessage{Data: bytes.Repeat([]byte{2}, 200)}
+		msg3 := &pb.PubSubMessage{Data: bytes.Repeat([]byte{3}, 300)}
+		// Not merged due to byte limit.
+		msg4 := &pb.PubSubMessage{Data: bytes.Repeat([]byte{4}, MaxPublishRequestBytes-500)}
+		msg5 := &pb.PubSubMessage{Data: bytes.Repeat([]byte{5}, 500)}
+
+		batcher := newPublishMessageBatcher(&DefaultPublishSettings, partition, receiver.onNewBatch)
+		batcher.AddBatch(makePublishBatch(makeMsgHolder(msg1)))
+		batcher.AddBatch(makePublishBatch(makeMsgHolder(msg2)))
+		batcher.AddBatch(makePublishBatch(makeMsgHolder(msg3)))
+		batcher.AddBatch(makePublishBatch(makeMsgHolder(msg4)))
+		batcher.AddBatch(makePublishBatch(makeMsgHolder(msg5)))
+
+		got := batcher.InFlightBatches()
+		want := []*publishBatch{
+			makePublishBatch(makeMsgHolder(msg1), makeMsgHolder(msg2), makeMsgHolder(msg3)),
+			makePublishBatch(makeMsgHolder(msg4)),
+			makePublishBatch(makeMsgHolder(msg5)),
+		}
+		if diff := testutil.Diff(got, want, cmp.AllowUnexported(publishBatch{}, messageHolder{})); diff != "" {
+			t.Errorf("Batches got: -, want: +\n%s", diff)
+		}
+	})
+
+	t.Run("max count", func(t *testing.T) {
+		var msgs []*pb.PubSubMessage
+		var batch1 []*messageHolder
+		var batch2 []*messageHolder
+		batcher := newPublishMessageBatcher(&DefaultPublishSettings, partition, receiver.onNewBatch)
+		for i := 0; i <= MaxPublishRequestCount; i++ {
+			msg := &pb.PubSubMessage{Data: []byte{'0'}}
+			msgs = append(msgs, msg)
+
+			msgHolder := makeMsgHolder(msg)
+			if i < MaxPublishRequestCount {
+				batch1 = append(batch1, msgHolder)
+			} else {
+				batch2 = append(batch2, msgHolder)
+			}
+			batcher.AddBatch(makePublishBatch(msgHolder))
+		}
+
+		got := batcher.InFlightBatches()
+		want := []*publishBatch{
+			makePublishBatch(batch1...),
+			makePublishBatch(batch2...),
+		}
+		if diff := testutil.Diff(got, want, cmp.AllowUnexported(publishBatch{}, messageHolder{})); diff != "" {
+			t.Errorf("Batches got: -, want: +\n%s", diff)
+		}
+	})
+}
diff --git a/pubsublite/internal/wire/publisher_test.go b/pubsublite/internal/wire/publisher_test.go
index 1f58d3a..600d0dd 100644
--- a/pubsublite/internal/wire/publisher_test.go
+++ b/pubsublite/internal/wire/publisher_test.go
@@ -196,8 +196,7 @@
 	// The publisher should resend all in-flight batches to the second stream.
 	stream2 := test.NewRPCVerifier(t)
 	stream2.Push(initPubReq(topic), initPubResp(), nil)
-	stream2.Push(msgPubReq(msg1), msgPubResp(0), nil)
-	stream2.Push(msgPubReq(msg2), msgPubResp(1), nil)
+	stream2.Push(msgPubReq(msg1, msg2), msgPubResp(0), nil)
 	stream2.Push(msgPubReq(msg3), msgPubResp(2), nil)
 	verifiers.AddPublishStream(topic.Path, topic.Partition, stream2)