feat(pubsublite): notify subscriber clients on partition reassignment (#4777)
Adds ReceiveSettings.ReassignmentHandler for the SubscriberClient to receive notifications when the server sends a new partition reassignment to the client.
diff --git a/pubsublite/internal/wire/assigner.go b/pubsublite/internal/wire/assigner.go
index ea4be1b..8f0ad4d 100644
--- a/pubsublite/internal/wire/assigner.go
+++ b/pubsublite/internal/wire/assigner.go
@@ -18,6 +18,7 @@
"errors"
"fmt"
"reflect"
+ "sort"
"github.com/google/uuid"
"google.golang.org/grpc"
@@ -26,26 +27,45 @@
pb "google.golang.org/genproto/googleapis/cloud/pubsublite/v1"
)
-// partitionSet is a set of partition numbers.
-type partitionSet map[int]struct{}
+// PartitionSet is a set of partition numbers.
+type PartitionSet map[int]struct{}
-func newPartitionSet(assignmentpb *pb.PartitionAssignment) partitionSet {
+// NewPartitionSet creates a partition set initialized from the given partition
+// numbers.
+func NewPartitionSet(partitions []int) PartitionSet {
var void struct{}
- partitions := make(map[int]struct{})
- for _, p := range assignmentpb.GetPartitions() {
- partitions[int(p)] = void
+ partitionSet := make(map[int]struct{})
+ for _, p := range partitions {
+ partitionSet[p] = void
}
- return partitionSet(partitions)
+ return partitionSet
}
-func (ps partitionSet) Ints() (partitions []int) {
+func newPartitionSet(assignmentpb *pb.PartitionAssignment) PartitionSet {
+ var partitions []int
+ for _, p := range assignmentpb.GetPartitions() {
+ partitions = append(partitions, int(p))
+ }
+ return NewPartitionSet(partitions)
+}
+
+// Ints returns the partitions contained in this set as an unsorted slice.
+func (ps PartitionSet) Ints() (partitions []int) {
for p := range ps {
partitions = append(partitions, p)
}
return
}
-func (ps partitionSet) Contains(partition int) bool {
+// SortedInts returns the partitions contained in this set as a sorted slice.
+func (ps PartitionSet) SortedInts() (partitions []int) {
+ partitions = ps.Ints()
+ sort.Ints(partitions)
+ return
+}
+
+// Contains returns true if this set contains the specified partition.
+func (ps PartitionSet) Contains(partition int) bool {
_, exists := ps[partition]
return exists
}
@@ -54,9 +74,8 @@
type generateUUIDFunc func() (uuid.UUID, error)
// partitionAssignmentReceiver must enact the received partition assignment from
-// the server, or otherwise return an error, which will break the stream. The
-// receiver must not call the assigner, as this would result in a deadlock.
-type partitionAssignmentReceiver func(partitionSet) error
+// the server, or otherwise return an error, which will break the stream.
+type partitionAssignmentReceiver func(PartitionSet) error
// assigner wraps the partition assignment stream and notifies a receiver when
// the server sends a new set of partition assignments for a subscriber.
diff --git a/pubsublite/internal/wire/assigner_test.go b/pubsublite/internal/wire/assigner_test.go
index ad3ae01..8d98c8a 100644
--- a/pubsublite/internal/wire/assigner_test.go
+++ b/pubsublite/internal/wire/assigner_test.go
@@ -16,7 +16,6 @@
import (
"context"
"errors"
- "sort"
"testing"
"time"
@@ -46,9 +45,7 @@
}
}
- gotPartitions := partitions.Ints()
- sort.Ints(gotPartitions)
- if !testutil.Equal(gotPartitions, wantPartitions) {
+ if gotPartitions := partitions.SortedInts(); !testutil.Equal(gotPartitions, wantPartitions) {
t.Errorf("Ints() got %v, want %v", gotPartitions, wantPartitions)
}
}
@@ -91,9 +88,8 @@
return ta
}
-func (ta *testAssigner) receiveAssignment(partitions partitionSet) error {
- p := partitions.Ints()
- sort.Ints(p)
+func (ta *testAssigner) receiveAssignment(partitions PartitionSet) error {
+ p := partitions.SortedInts()
ta.partitions <- p
if ta.recvError != nil {
diff --git a/pubsublite/internal/wire/subscriber.go b/pubsublite/internal/wire/subscriber.go
index b6860a6..d8b84b4 100644
--- a/pubsublite/internal/wire/subscriber.go
+++ b/pubsublite/internal/wire/subscriber.go
@@ -415,13 +415,14 @@
// partitions.
type multiPartitionSubscriber struct {
// Immutable after creation.
- subscribers []*singlePartitionSubscriber
+ subscribers map[int]*singlePartitionSubscriber
apiClientService
}
func newMultiPartitionSubscriber(allClients apiClients, subFactory *singlePartitionSubscriberFactory) *multiPartitionSubscriber {
ms := &multiPartitionSubscriber{
+ subscribers: make(map[int]*singlePartitionSubscriber),
apiClientService: apiClientService{clients: allClients},
}
ms.init()
@@ -429,7 +430,7 @@
for _, partition := range subFactory.settings.Partitions {
subscriber := subFactory.New(partition)
ms.unsafeAddServices(subscriber)
- ms.subscribers = append(ms.subscribers, subscriber)
+ ms.subscribers[partition] = subscriber
}
return ms
}
@@ -445,13 +446,23 @@
}
}
+// PartitionActive returns whether the partition is active.
+func (ms *multiPartitionSubscriber) PartitionActive(partition int) bool {
+ _, exists := ms.subscribers[partition]
+ return exists
+}
+
+// ReassignmentHandlerFunc receives a partition assignment change.
+type ReassignmentHandlerFunc func(before, after PartitionSet) error
+
// assigningSubscriber uses the Pub/Sub Lite partition assignment service to
// listen to its assigned partition numbers and dynamically add/remove
// singlePartitionSubscribers.
type assigningSubscriber struct {
// Immutable after creation.
- subFactory *singlePartitionSubscriberFactory
- assigner *assigner
+ reassignmentHandler ReassignmentHandlerFunc
+ subFactory *singlePartitionSubscriberFactory
+ assigner *assigner
// Fields below must be guarded with mu.
// Subscribers keyed by partition number. Updated as assignments change.
@@ -460,11 +471,13 @@
apiClientService
}
-func newAssigningSubscriber(allClients apiClients, assignmentClient *vkit.PartitionAssignmentClient, genUUID generateUUIDFunc, subFactory *singlePartitionSubscriberFactory) (*assigningSubscriber, error) {
+func newAssigningSubscriber(allClients apiClients, assignmentClient *vkit.PartitionAssignmentClient, reassignmentHandler ReassignmentHandlerFunc,
+ genUUID generateUUIDFunc, subFactory *singlePartitionSubscriberFactory) (*assigningSubscriber, error) {
as := &assigningSubscriber{
- apiClientService: apiClientService{clients: allClients},
- subFactory: subFactory,
- subscribers: make(map[int]*singlePartitionSubscriber),
+ apiClientService: apiClientService{clients: allClients},
+ reassignmentHandler: reassignmentHandler,
+ subFactory: subFactory,
+ subscribers: make(map[int]*singlePartitionSubscriber),
}
as.init()
@@ -477,12 +490,17 @@
return as, nil
}
-func (as *assigningSubscriber) handleAssignment(partitions partitionSet) error {
- removedSubscribers, err := as.doHandleAssignment(partitions)
+func (as *assigningSubscriber) handleAssignment(nextPartitions PartitionSet) error {
+ previousPartitions, removedSubscribers, err := as.doHandleAssignment(nextPartitions)
if err != nil {
return err
}
+ // Notify the user reassignment handler.
+ if err := as.reassignmentHandler(previousPartitions, nextPartitions); err != nil {
+ return err
+ }
+
// Wait for removed subscribers to completely stop (which waits for commit
// acknowledgments from the server) before acking the assignment. This avoids
// commits racing with the new assigned client.
@@ -492,17 +510,23 @@
return nil
}
-func (as *assigningSubscriber) doHandleAssignment(partitions partitionSet) ([]*singlePartitionSubscriber, error) {
+// Returns the previous set of partitions and removed subscribers.
+func (as *assigningSubscriber) doHandleAssignment(nextPartitions PartitionSet) (PartitionSet, []*singlePartitionSubscriber, error) {
as.mu.Lock()
defer as.mu.Unlock()
+ var previousPartitions []int
+ for partition := range as.subscribers {
+ previousPartitions = append(previousPartitions, partition)
+ }
+
// Handle new partitions.
- for _, partition := range partitions.Ints() {
+ for _, partition := range nextPartitions.Ints() {
if _, exists := as.subscribers[partition]; !exists {
subscriber := as.subFactory.New(partition)
if err := as.unsafeAddServices(subscriber); err != nil {
// Occurs when the assigningSubscriber is stopping/stopped.
- return nil, err
+ return nil, nil, err
}
as.subscribers[partition] = subscriber
}
@@ -511,7 +535,7 @@
// Handle removed partitions.
var removedSubscribers []*singlePartitionSubscriber
for partition, subscriber := range as.subscribers {
- if !partitions.Contains(partition) {
+ if !nextPartitions.Contains(partition) {
// Ignore unacked messages from this point on to avoid conflicting with
// the commits of the new subscriber that will be assigned this partition.
subscriber.Terminate()
@@ -523,7 +547,7 @@
delete(as.subscribers, partition)
}
}
- return removedSubscribers, nil
+ return NewPartitionSet(previousPartitions), removedSubscribers, nil
}
// Terminate shuts down all singlePartitionSubscribers without waiting for
@@ -537,6 +561,15 @@
}
}
+// PartitionActive returns whether the partition is still active.
+func (as *assigningSubscriber) PartitionActive(partition int) bool {
+ as.mu.Lock()
+ defer as.mu.Unlock()
+
+ _, exists := as.subscribers[partition]
+ return exists
+}
+
// Subscriber is the client interface exported from this package for receiving
// messages.
type Subscriber interface {
@@ -545,10 +578,12 @@
Stop()
WaitStopped() error
Terminate()
+ PartitionActive(int) bool
}
// NewSubscriber creates a new client for receiving messages.
-func NewSubscriber(ctx context.Context, settings ReceiveSettings, receiver MessageReceiverFunc, region, subscriptionPath string, opts ...option.ClientOption) (Subscriber, error) {
+func NewSubscriber(ctx context.Context, settings ReceiveSettings, receiver MessageReceiverFunc, reassignmentHandler ReassignmentHandlerFunc,
+ region, subscriptionPath string, opts ...option.ClientOption) (Subscriber, error) {
if err := ValidateRegion(region); err != nil {
return nil, err
}
@@ -588,5 +623,5 @@
return nil, err
}
allClients = append(allClients, partitionClient)
- return newAssigningSubscriber(allClients, partitionClient, uuid.NewRandom, subFactory)
+ return newAssigningSubscriber(allClients, partitionClient, reassignmentHandler, uuid.NewRandom, subFactory)
}
diff --git a/pubsublite/internal/wire/subscriber_test.go b/pubsublite/internal/wire/subscriber_test.go
index 6aa5fe7..f7dca03 100644
--- a/pubsublite/internal/wire/subscriber_test.go
+++ b/pubsublite/internal/wire/subscriber_test.go
@@ -15,6 +15,7 @@
import (
"context"
+ "errors"
"sort"
"sync"
"testing"
@@ -30,17 +31,22 @@
pb "google.golang.org/genproto/googleapis/cloud/pubsublite/v1"
)
+const (
+ maxMessages int = 10
+ maxBytes int = 1000
+)
+
func testSubscriberSettings() ReceiveSettings {
settings := testReceiveSettings()
- settings.MaxOutstandingMessages = 10
- settings.MaxOutstandingBytes = 1000
+ settings.MaxOutstandingMessages = maxMessages
+ settings.MaxOutstandingBytes = maxBytes
return settings
}
// initFlowControlReq returns the first expected flow control request when
// testSubscriberSettings are used.
func initFlowControlReq() *pb.SubscribeRequest {
- return flowControlSubReq(flowControlTokens{Bytes: 1000, Messages: 10})
+ return flowControlSubReq(flowControlTokens{Bytes: int64(maxBytes), Messages: int64(maxMessages)})
}
func partitionMsgs(partition int, msgs ...*pb.SequencedMessage) []*ReceivedMessage {
@@ -929,6 +935,15 @@
}
}
+func verifyPartitionsActive(t *testing.T, sub Subscriber, want bool, partitions ...int) {
+ t.Helper()
+ for _, p := range partitions {
+ if got := sub.PartitionActive(p); got != want {
+ t.Errorf("PartitionActive(%d) got %v, want %v", p, got, want)
+ }
+ }
+}
+
func newTestMultiPartitionSubscriber(t *testing.T, receiverFunc MessageReceiverFunc, subscriptionPath string, partitions []int) *multiPartitionSubscriber {
ctx := context.Background()
subClient, err := newSubscriberClient(ctx, "ignored", testServer.ClientConn())
@@ -994,6 +1009,9 @@
defer mockServer.OnTestEnd()
sub := newTestMultiPartitionSubscriber(t, receiver.onMessage, subscription, []int{1, 2})
+ verifyPartitionsActive(t, sub, true, 1, 2)
+ verifyPartitionsActive(t, sub, false, 0, 3)
+
if gotErr := sub.WaitStarted(); gotErr != nil {
t.Errorf("Start() got err: (%v)", gotErr)
}
@@ -1056,18 +1074,6 @@
receiver.VerifyNoMsgs()
}
-func (as *assigningSubscriber) Partitions() []int {
- as.mu.Lock()
- defer as.mu.Unlock()
-
- var partitions []int
- for p := range as.subscribers {
- partitions = append(partitions, p)
- }
- sort.Ints(partitions)
- return partitions
-}
-
func (as *assigningSubscriber) Subscribers() []*singlePartitionSubscriber {
as.mu.Lock()
defer as.mu.Unlock()
@@ -1088,7 +1094,11 @@
}
}
-func newTestAssigningSubscriber(t *testing.T, receiverFunc MessageReceiverFunc, subscriptionPath string) *assigningSubscriber {
+func noopReassignmentHandler(_, _ PartitionSet) error {
+ return nil
+}
+
+func newTestAssigningSubscriber(t *testing.T, receiverFunc MessageReceiverFunc, reassignmentHandler ReassignmentHandlerFunc, subscriptionPath string) *assigningSubscriber {
ctx := context.Background()
subClient, err := newSubscriberClient(ctx, "ignored", testServer.ClientConn())
if err != nil {
@@ -1113,7 +1123,7 @@
receiver: receiverFunc,
disableTasks: true, // Background tasks disabled to control event order
}
- sub, err := newAssigningSubscriber(allClients, assignmentClient, fakeGenerateUUID, f)
+ sub, err := newAssigningSubscriber(allClients, assignmentClient, reassignmentHandler, fakeGenerateUUID, f)
if err != nil {
t.Fatal(err)
}
@@ -1179,23 +1189,21 @@
mockServer.OnTestStart(verifiers)
defer mockServer.OnTestEnd()
- sub := newTestAssigningSubscriber(t, receiver.onMessage, subscription)
+ sub := newTestAssigningSubscriber(t, receiver.onMessage, noopReassignmentHandler, subscription)
if gotErr := sub.WaitStarted(); gotErr != nil {
t.Errorf("Start() got err: (%v)", gotErr)
}
// Partition assignments are initially {3, 6}.
receiver.ValidateMsgs(join(partitionMsgs(3, msg1), partitionMsgs(6, msg3)))
- if got, want := sub.Partitions(), []int{3, 6}; !testutil.Equal(got, want) {
- t.Errorf("subscriber partitions: got %d, want %d", got, want)
- }
+ verifyPartitionsActive(t, sub, true, 3, 6)
+ verifyPartitionsActive(t, sub, false, 1, 8)
// Partition assignments will now be {3, 8}.
assignmentBarrier1.Release()
receiver.ValidateMsgs(partitionMsgs(8, msg5))
- if got, want := sub.Partitions(), []int{3, 8}; !testutil.Equal(got, want) {
- t.Errorf("subscriber partitions: got %d, want %d", got, want)
- }
+ verifyPartitionsActive(t, sub, true, 3, 8)
+ verifyPartitionsActive(t, sub, false, 2, 6)
// msg2 is from partition 3 and should be received. msg4 is from partition 6
// (removed) and should be discarded.
@@ -1255,7 +1263,7 @@
mockServer.OnTestStart(verifiers)
defer mockServer.OnTestEnd()
- sub := newTestAssigningSubscriber(t, receiver.onMessage, subscription)
+ sub := newTestAssigningSubscriber(t, receiver.onMessage, noopReassignmentHandler, subscription)
if gotErr := sub.WaitStarted(); gotErr != nil {
t.Errorf("Start() got err: (%v)", gotErr)
}
@@ -1298,7 +1306,7 @@
mockServer.OnTestStart(verifiers)
defer mockServer.OnTestEnd()
- sub := newTestAssigningSubscriber(t, receiver.onMessage, subscription)
+ sub := newTestAssigningSubscriber(t, receiver.onMessage, noopReassignmentHandler, subscription)
if gotErr := sub.WaitStarted(); gotErr != nil {
t.Errorf("Start() got err: (%v)", gotErr)
}
@@ -1330,6 +1338,116 @@
}
}
+func TestAssigningSubscriberStoppedWhileReassignmentHandlerActive(t *testing.T) {
+ const subscription = "projects/123456/locations/us-central1-b/subscriptions/my-sub"
+ receiver := newTestMessageReceiver(t)
+
+ verifiers := test.NewVerifiers(t)
+
+ // Assignment stream
+ asnStream := test.NewRPCVerifier(t)
+ asnStream.Push(initAssignmentReq(subscription, fakeUUID[:]), assignmentResp([]int64{1}), nil)
+ verifiers.AddAssignmentStream(subscription, asnStream)
+
+ // Partition 1
+ subStream := test.NewRPCVerifier(t)
+ subStream.Push(initSubReqCommit(subscriptionPartition{Path: subscription, Partition: 1}), initSubResp(), nil)
+ subBarrier := subStream.PushWithBarrier(initFlowControlReq(), nil, nil)
+ verifiers.AddSubscribeStream(subscription, 1, subStream)
+
+ cmtStream := test.NewRPCVerifier(t)
+ cmtBarrier := cmtStream.PushWithBarrier(initCommitReq(subscriptionPartition{Path: subscription, Partition: 1}), initCommitResp(), nil)
+ verifiers.AddCommitStream(subscription, 1, cmtStream)
+
+ mockServer.OnTestStart(verifiers)
+ defer mockServer.OnTestEnd()
+
+ reassignmentHandlerCalled := test.NewCondition("reassignment handler called")
+ returnReassignmentHandler := test.NewCondition("return reassignment handler")
+ onReassignment := func(before, after PartitionSet) error {
+ if got, want := len(before.SortedInts()), 0; got != want {
+ t.Errorf("len(before): got %v, want %v", got, want)
+ }
+ if got, want := after.SortedInts(), []int{1}; !testutil.Equal(got, want) {
+ t.Errorf("after: got %v, want %v", got, want)
+ }
+ reassignmentHandlerCalled.SetDone()
+ returnReassignmentHandler.WaitUntilDone(t, serviceTestWaitTimeout)
+ return nil
+ }
+
+ sub := newTestAssigningSubscriber(t, receiver.onMessage, onReassignment, subscription)
+ if gotErr := sub.WaitStarted(); gotErr != nil {
+ t.Errorf("Start() got err: (%v)", gotErr)
+ }
+
+ // Used to control order of execution to ensure the test is deterministic.
+ subBarrier.Release()
+ cmtBarrier.Release()
+
+ // Ensure there are no deadlocks if the reassignment handler blocks and the
+ // subscriber is stopped.
+ reassignmentHandlerCalled.WaitUntilDone(t, serviceTestWaitTimeout)
+ sub.Stop()
+ returnReassignmentHandler.SetDone()
+
+ if gotErr := sub.WaitStopped(); gotErr != nil {
+ t.Errorf("WaitStopped() got err: (%v)", gotErr)
+ }
+}
+
+func TestAssigningSubscriberReassignmentHandlerReturnsError(t *testing.T) {
+ const subscription = "projects/123456/locations/us-central1-b/subscriptions/my-sub"
+ receiver := newTestMessageReceiver(t)
+
+ verifiers := test.NewVerifiers(t)
+
+ // Assignment stream
+ asnStream := test.NewRPCVerifier(t)
+ asnStream.Push(initAssignmentReq(subscription, fakeUUID[:]), assignmentResp([]int64{1}), nil)
+ verifiers.AddAssignmentStream(subscription, asnStream)
+
+ // Partition 1
+ subStream := test.NewRPCVerifier(t)
+ subStream.Push(initSubReqCommit(subscriptionPartition{Path: subscription, Partition: 1}), initSubResp(), nil)
+ subBarrier := subStream.PushWithBarrier(initFlowControlReq(), nil, nil)
+ verifiers.AddSubscribeStream(subscription, 1, subStream)
+
+ cmtStream := test.NewRPCVerifier(t)
+ cmtBarrier := cmtStream.PushWithBarrier(initCommitReq(subscriptionPartition{Path: subscription, Partition: 1}), initCommitResp(), nil)
+ verifiers.AddCommitStream(subscription, 1, cmtStream)
+
+ mockServer.OnTestStart(verifiers)
+ defer mockServer.OnTestEnd()
+
+ reassignmentErr := errors.New("reassignment handler error")
+ returnReassignmentErr := test.NewCondition("return reassignment error")
+ onAssignment := func(before, after PartitionSet) error {
+ if got, want := len(before.SortedInts()), 0; got != want {
+ t.Errorf("len(before): got %v, want %v", got, want)
+ }
+ if got, want := after.SortedInts(), []int{1}; !testutil.Equal(got, want) {
+ t.Errorf("after: got %v, want %v", got, want)
+ }
+ returnReassignmentErr.WaitUntilDone(t, serviceTestWaitTimeout)
+ return reassignmentErr
+ }
+
+ sub := newTestAssigningSubscriber(t, receiver.onMessage, onAssignment, subscription)
+ if gotErr := sub.WaitStarted(); gotErr != nil {
+ t.Errorf("Start() got err: (%v)", gotErr)
+ }
+
+ // Used to control order of execution to ensure the test is deterministic.
+ subBarrier.Release()
+ cmtBarrier.Release()
+ returnReassignmentErr.SetDone()
+
+ if gotErr := sub.WaitStopped(); !test.ErrorEqual(gotErr, reassignmentErr) {
+ t.Errorf("WaitStopped() got err: (%v), want err: (%v)", gotErr, reassignmentErr)
+ }
+}
+
func TestNewSubscriberValidatesSettings(t *testing.T) {
const subscription = "projects/123456/locations/us-central1-b/subscriptions/my-sub"
const region = "us-central1"
@@ -1337,7 +1455,7 @@
settings := DefaultReceiveSettings
settings.MaxOutstandingMessages = 0
- if _, err := NewSubscriber(context.Background(), settings, receiver.onMessage, region, subscription); err == nil {
+ if _, err := NewSubscriber(context.Background(), settings, receiver.onMessage, noopReassignmentHandler, region, subscription); err == nil {
t.Error("NewSubscriber() did not return error")
}
}
diff --git a/pubsublite/pscompat/integration_test.go b/pubsublite/pscompat/integration_test.go
index 59aa779..86da439 100644
--- a/pubsublite/pscompat/integration_test.go
+++ b/pubsublite/pscompat/integration_test.go
@@ -17,6 +17,7 @@
"context"
"errors"
"fmt"
+ "sort"
"strings"
"sync"
"sync/atomic"
@@ -729,12 +730,29 @@
}
}
+ // Verify partition reassignment notifications.
+ var allPartitions []int
+ var mu sync.Mutex
+ reassignmentHandler := func(before, after []int) error {
+ t.Logf("Partition assignments: before %v, after %v", before, after)
+ if got, want := len(before), 0; got != want {
+ t.Errorf("Partition assignments len(before): got %d, want %d", got, want)
+ }
+ mu.Lock()
+ allPartitions = append(allPartitions, after...)
+ mu.Unlock()
+ return nil
+ }
+
+ receiveSettings := DefaultReceiveSettings
+ receiveSettings.ReassignmentHandler = reassignmentHandler
+
cctx, stopSubscribers := context.WithTimeout(context.Background(), defaultTestTimeout)
g, _ := errgroup.WithContext(ctx)
for i := 0; i < subscriberCount; i++ {
// Subscribers must be started in a goroutine as Receive() blocks.
g.Go(func() error {
- subscriber := subscriberClient(context.Background(), t, DefaultReceiveSettings, subscriptionPath)
+ subscriber := subscriberClient(context.Background(), t, receiveSettings, subscriptionPath)
err := subscriber.Receive(cctx, messageReceiver)
if err != nil {
t.Errorf("Receive() got err: %v", err)
@@ -748,6 +766,13 @@
stopSubscribers()
// Wait until all subscribers have terminated.
g.Wait()
+
+ mu.Lock()
+ sort.Ints(allPartitions)
+ if got, want := allPartitions, partitionNumbers(partitionCount); !testutil.Equal(got, want) {
+ t.Errorf("Assigned partition numbers: got %v, want %v", got, want)
+ }
+ mu.Unlock()
})
}
diff --git a/pubsublite/pscompat/settings.go b/pubsublite/pscompat/settings.go
index 59b84c8..a4b8b51 100644
--- a/pubsublite/pscompat/settings.go
+++ b/pubsublite/pscompat/settings.go
@@ -164,6 +164,30 @@
// error and terminate.
type ReceiveMessageTransformerFunc func(*pb.SequencedMessage, *pubsub.Message) error
+// ReassignmentHandlerFunc is called any time a new partition assignment is
+// received from the server. It will be called with both the previous and new
+// partition numbers as decided by the server. Both slices of partition numbers
+// are sorted in ascending order.
+//
+// When this handler is called, partitions that are being assigned away are
+// stopping and new partitions are starting. Acks and nacks for messages from
+// partitions that are being assigned away will have no effect, but message
+// deliveries may still be in flight.
+//
+// The client library will not acknowledge the assignment until this handler
+// returns. The server will not assign any of the partitions in
+// `previousPartitions` to another client until the assignment is acknowledged
+// or the client takes too long to acknowledge it (currently 30 seconds from
+// the time the assignment is sent, from the server's point of view).
+//
+// Because of the above, as long as reassignment handling is processed quickly,
+// it can be used to abort outstanding operations on partitions which are being
+// assigned away from this client.
+//
+// If this handler returns an error, the SubscriberClient will consider this a
+// fatal error and terminate.
+type ReassignmentHandlerFunc func(previousPartitions, nextPartitions []int) error
+
// ReceiveSettings configure the SubscriberClient. Flow control settings
// (MaxOutstandingMessages, MaxOutstandingBytes) apply per partition.
//
@@ -210,6 +234,10 @@
// Optional custom function that transforms a SequencedMessage API proto to a
// pubsub.Message.
MessageTransformer ReceiveMessageTransformerFunc
+
+ // Optional custom function that is called when a new partition assignment has
+ // been delivered to the client.
+ ReassignmentHandler ReassignmentHandlerFunc
}
// DefaultReceiveSettings holds the default values for ReceiveSettings.
diff --git a/pubsublite/pscompat/subscriber.go b/pubsublite/pscompat/subscriber.go
index 368d735..d0d00e0 100644
--- a/pubsublite/pscompat/subscriber.go
+++ b/pubsublite/pscompat/subscriber.go
@@ -40,6 +40,7 @@
type pslAckHandler struct {
ackh wire.AckConsumer
msg *pubsub.Message
+ partition int
nackh NackHandler
subInstance *subscriberInstance
}
@@ -58,6 +59,11 @@
return
}
+ // Ignore nacks for partitions that have been assigned away.
+ if !ah.subInstance.wireSub.PartitionActive(ah.partition) {
+ return
+ }
+
err := ah.nackh(ah.msg)
if err != nil {
// If the NackHandler returns an error, shut down the subscriber client.
@@ -72,7 +78,7 @@
// wireSubscriberFactory is a factory for creating wire subscribers, which can
// be overridden with a mock in unit tests.
type wireSubscriberFactory interface {
- New(context.Context, wire.MessageReceiverFunc) (wire.Subscriber, error)
+ New(context.Context, wire.MessageReceiverFunc, wire.ReassignmentHandlerFunc) (wire.Subscriber, error)
}
type wireSubscriberFactoryImpl struct {
@@ -82,8 +88,8 @@
options []option.ClientOption
}
-func (f *wireSubscriberFactoryImpl) New(ctx context.Context, receiver wire.MessageReceiverFunc) (wire.Subscriber, error) {
- return wire.NewSubscriber(ctx, f.settings, receiver, f.region, f.subscription.String(), f.options...)
+func (f *wireSubscriberFactoryImpl) New(ctx context.Context, receiver wire.MessageReceiverFunc, onReassignment wire.ReassignmentHandlerFunc) (wire.Subscriber, error) {
+ return wire.NewSubscriber(ctx, f.settings, receiver, onReassignment, f.region, f.subscription.String(), f.options...)
}
type messageReceiverFunc = func(context.Context, *pubsub.Message)
@@ -116,7 +122,7 @@
// cancelled, the gRPC streams will be disconnected and the subscriber will
// not be able to process acks and commit the final cursor offset. Use the
// context from NewSubscriberClient (clientCtx) instead.
- wireSub, err := factory.New(clientCtx, subInstance.onMessage)
+ wireSub, err := factory.New(clientCtx, subInstance.onMessage, subInstance.onReassignment)
if err != nil {
return nil, err
}
@@ -131,6 +137,13 @@
return subInstance, nil
}
+func (si *subscriberInstance) onReassignment(before, after wire.PartitionSet) error {
+ if si.settings.ReassignmentHandler != nil {
+ return si.settings.ReassignmentHandler(before.SortedInts(), after.SortedInts())
+ }
+ return nil
+}
+
func (si *subscriberInstance) transformMessage(in *wire.ReceivedMessage, out *pubsub.Message) error {
if err := si.settings.MessageTransformer(in.Msg, out); err != nil {
return err
@@ -147,6 +160,7 @@
pslAckh := &pslAckHandler{
ackh: msg.Ack,
nackh: si.settings.NackHandler,
+ partition: msg.Partition,
subInstance: si,
}
psMsg := ipubsub.NewMessage(pslAckh)
diff --git a/pubsublite/pscompat/subscriber_test.go b/pubsublite/pscompat/subscriber_test.go
index 5c737f8..6bd4ac8 100644
--- a/pubsublite/pscompat/subscriber_test.go
+++ b/pubsublite/pscompat/subscriber_test.go
@@ -30,7 +30,10 @@
pb "google.golang.org/genproto/googleapis/cloud/pubsublite/v1"
)
-const defaultSubscriberTestTimeout = 10 * time.Second
+const (
+ defaultSubscriberTestTimeout = 10 * time.Second
+ activePartition = 1
+)
// mockAckConsumer is a mock implementation of the wire.AckConsumer interface.
type mockAckConsumer struct {
@@ -43,12 +46,14 @@
// mockWireSubscriber is a mock implementation of the wire.Subscriber interface.
type mockWireSubscriber struct {
- receiver wire.MessageReceiverFunc
- msgsC chan *wire.ReceivedMessage
- stopC chan struct{}
- err error
- Stopped bool
- Terminated bool
+ receiver wire.MessageReceiverFunc
+ onReassignment wire.ReassignmentHandlerFunc
+ activePartitions wire.PartitionSet
+ msgsC chan *wire.ReceivedMessage
+ stopC chan struct{}
+ err error
+ Stopped bool
+ Terminated bool
}
// DeliverMessages should be called from the test to simulate a message
@@ -59,6 +64,14 @@
}
}
+// DeliverReassignment should be called from the test to simulate a partition
+// reassignment.
+func (ms *mockWireSubscriber) DeliverReassignment(before, after wire.PartitionSet) {
+ if err := ms.onReassignment(before, after); err != nil {
+ ms.SimulateFatalError(err)
+ }
+}
+
// SimulateFatalError should be called from the test to simulate a fatal error
// occurring in the wire subscriber.
func (ms *mockWireSubscriber) SimulateFatalError(err error) {
@@ -111,13 +124,19 @@
return ms.err
}
+func (ms *mockWireSubscriber) PartitionActive(partition int) bool {
+ return ms.activePartitions.Contains(partition)
+}
+
type mockWireSubscriberFactory struct{}
-func (f *mockWireSubscriberFactory) New(ctx context.Context, receiver wire.MessageReceiverFunc) (wire.Subscriber, error) {
+func (f *mockWireSubscriberFactory) New(ctx context.Context, receiver wire.MessageReceiverFunc, onReassignment wire.ReassignmentHandlerFunc) (wire.Subscriber, error) {
return &mockWireSubscriber{
- receiver: receiver,
- msgsC: make(chan *wire.ReceivedMessage, 10),
- stopC: make(chan struct{}),
+ receiver: receiver,
+ onReassignment: onReassignment,
+ activePartitions: wire.NewPartitionSet([]int{activePartition}),
+ msgsC: make(chan *wire.ReceivedMessage, 10),
+ stopC: make(chan struct{}),
}, nil
}
@@ -289,6 +308,7 @@
desc string
// mutateSettings is passed a copy of DefaultReceiveSettings to mutate.
mutateSettings func(settings *ReceiveSettings)
+ msgPartition int
wantErr error
wantAckCount int
wantStopped bool
@@ -297,17 +317,27 @@
{
desc: "default settings",
mutateSettings: func(settings *ReceiveSettings) {},
+ msgPartition: activePartition,
wantErr: errNackCalled,
wantAckCount: 0,
wantTerminated: true,
},
{
+ desc: "message partition inactive",
+ mutateSettings: func(settings *ReceiveSettings) {},
+ msgPartition: activePartition + 1,
+ wantErr: nil,
+ wantAckCount: 0,
+ wantStopped: true,
+ },
+ {
desc: "nack handler returns nil",
mutateSettings: func(settings *ReceiveSettings) {
settings.NackHandler = func(_ *pubsub.Message) error {
return nil
}
},
+ msgPartition: activePartition,
wantErr: nil,
wantAckCount: 1,
wantStopped: true,
@@ -319,6 +349,7 @@
return nackErr
}
},
+ msgPartition: activePartition,
wantErr: nackErr,
wantAckCount: 0,
wantTerminated: true,
@@ -329,7 +360,7 @@
tc.mutateSettings(&settings)
ack := &mockAckConsumer{}
- msg := &wire.ReceivedMessage{Msg: msg, Ack: ack}
+ msg := &wire.ReceivedMessage{Msg: msg, Ack: ack, Partition: tc.msgPartition}
cctx, stopSubscriber := context.WithTimeout(ctx, defaultSubscriberTestTimeout)
messageReceiver := func(ctx context.Context, got *pubsub.Message) {
@@ -363,6 +394,67 @@
}
}
+func TestSubscriberInstanceReassignmentHandler(t *testing.T) {
+ reassignmentErr := errors.New("reassignment failure")
+ before := wire.NewPartitionSet([]int{3, 2, 1})
+ after := wire.NewPartitionSet([]int{4, 5, 3})
+ ctx := context.Background()
+
+ for _, tc := range []struct {
+ desc string
+ // mutateSettings is passed a copy of DefaultReceiveSettings to mutate.
+ mutateSettings func(settings *ReceiveSettings)
+ wantErr error
+ }{
+ {
+ desc: "default settings",
+ mutateSettings: func(settings *ReceiveSettings) {},
+ },
+ {
+ desc: "reassignment handler returns nil",
+ mutateSettings: func(settings *ReceiveSettings) {
+ settings.ReassignmentHandler = func(before, after []int) error {
+ if got, want := before, []int{1, 2, 3}; !testutil.Equal(got, want) {
+ t.Errorf("before: got %d, want %d", got, want)
+ }
+ if got, want := after, []int{3, 4, 5}; !testutil.Equal(got, want) {
+ t.Errorf("after: got %d, want %d", got, want)
+ }
+ return nil
+ }
+ },
+ },
+ {
+ desc: "reassignment handler returns error",
+ mutateSettings: func(settings *ReceiveSettings) {
+ settings.ReassignmentHandler = func(before, after []int) error {
+ return reassignmentErr
+ }
+ },
+ wantErr: reassignmentErr,
+ },
+ } {
+ t.Run(tc.desc, func(t *testing.T) {
+ settings := DefaultReceiveSettings
+ tc.mutateSettings(&settings)
+
+ cctx, stopSubscriber := context.WithTimeout(ctx, defaultSubscriberTestTimeout)
+ messageReceiver := func(ctx context.Context, got *pubsub.Message) {
+ t.Error("Message receiver should not be called")
+ }
+ subInstance := newTestSubscriberInstance(cctx, settings, messageReceiver)
+ subInstance.wireSub.(*mockWireSubscriber).DeliverReassignment(before, after)
+ if tc.wantErr == nil {
+ stopSubscriber()
+ }
+
+ if gotErr := subInstance.Wait(cctx); !test.ErrorEqual(gotErr, tc.wantErr) {
+ t.Errorf("subscriberInstance.Wait() got err: (%v), want err: (%v)", gotErr, tc.wantErr)
+ }
+ })
+ }
+}
+
func TestSubscriberInstanceWireSubscriberFails(t *testing.T) {
fatalErr := errors.New("server error")