feat(pubsublite): notify subscriber clients on partition reassignment (#4777)

Adds ReceiveSettings.ReassignmentHandler for the SubscriberClient to receive notifications when the server sends a new partition reassignment to the client.
diff --git a/pubsublite/internal/wire/assigner.go b/pubsublite/internal/wire/assigner.go
index ea4be1b..8f0ad4d 100644
--- a/pubsublite/internal/wire/assigner.go
+++ b/pubsublite/internal/wire/assigner.go
@@ -18,6 +18,7 @@
 	"errors"
 	"fmt"
 	"reflect"
+	"sort"
 
 	"github.com/google/uuid"
 	"google.golang.org/grpc"
@@ -26,26 +27,45 @@
 	pb "google.golang.org/genproto/googleapis/cloud/pubsublite/v1"
 )
 
-// partitionSet is a set of partition numbers.
-type partitionSet map[int]struct{}
+// PartitionSet is a set of partition numbers.
+type PartitionSet map[int]struct{}
 
-func newPartitionSet(assignmentpb *pb.PartitionAssignment) partitionSet {
+// NewPartitionSet creates a partition set initialized from the given partition
+// numbers.
+func NewPartitionSet(partitions []int) PartitionSet {
 	var void struct{}
-	partitions := make(map[int]struct{})
-	for _, p := range assignmentpb.GetPartitions() {
-		partitions[int(p)] = void
+	partitionSet := make(map[int]struct{})
+	for _, p := range partitions {
+		partitionSet[p] = void
 	}
-	return partitionSet(partitions)
+	return partitionSet
 }
 
-func (ps partitionSet) Ints() (partitions []int) {
+func newPartitionSet(assignmentpb *pb.PartitionAssignment) PartitionSet {
+	var partitions []int
+	for _, p := range assignmentpb.GetPartitions() {
+		partitions = append(partitions, int(p))
+	}
+	return NewPartitionSet(partitions)
+}
+
+// Ints returns the partitions contained in this set as an unsorted slice.
+func (ps PartitionSet) Ints() (partitions []int) {
 	for p := range ps {
 		partitions = append(partitions, p)
 	}
 	return
 }
 
-func (ps partitionSet) Contains(partition int) bool {
+// SortedInts returns the partitions contained in this set as a sorted slice.
+func (ps PartitionSet) SortedInts() (partitions []int) {
+	partitions = ps.Ints()
+	sort.Ints(partitions)
+	return
+}
+
+// Contains returns true if this set contains the specified partition.
+func (ps PartitionSet) Contains(partition int) bool {
 	_, exists := ps[partition]
 	return exists
 }
@@ -54,9 +74,8 @@
 type generateUUIDFunc func() (uuid.UUID, error)
 
 // partitionAssignmentReceiver must enact the received partition assignment from
-// the server, or otherwise return an error, which will break the stream. The
-// receiver must not call the assigner, as this would result in a deadlock.
-type partitionAssignmentReceiver func(partitionSet) error
+// the server, or otherwise return an error, which will break the stream.
+type partitionAssignmentReceiver func(PartitionSet) error
 
 // assigner wraps the partition assignment stream and notifies a receiver when
 // the server sends a new set of partition assignments for a subscriber.
diff --git a/pubsublite/internal/wire/assigner_test.go b/pubsublite/internal/wire/assigner_test.go
index ad3ae01..8d98c8a 100644
--- a/pubsublite/internal/wire/assigner_test.go
+++ b/pubsublite/internal/wire/assigner_test.go
@@ -16,7 +16,6 @@
 import (
 	"context"
 	"errors"
-	"sort"
 	"testing"
 	"time"
 
@@ -46,9 +45,7 @@
 		}
 	}
 
-	gotPartitions := partitions.Ints()
-	sort.Ints(gotPartitions)
-	if !testutil.Equal(gotPartitions, wantPartitions) {
+	if gotPartitions := partitions.SortedInts(); !testutil.Equal(gotPartitions, wantPartitions) {
 		t.Errorf("Ints() got %v, want %v", gotPartitions, wantPartitions)
 	}
 }
@@ -91,9 +88,8 @@
 	return ta
 }
 
-func (ta *testAssigner) receiveAssignment(partitions partitionSet) error {
-	p := partitions.Ints()
-	sort.Ints(p)
+func (ta *testAssigner) receiveAssignment(partitions PartitionSet) error {
+	p := partitions.SortedInts()
 	ta.partitions <- p
 
 	if ta.recvError != nil {
diff --git a/pubsublite/internal/wire/subscriber.go b/pubsublite/internal/wire/subscriber.go
index b6860a6..d8b84b4 100644
--- a/pubsublite/internal/wire/subscriber.go
+++ b/pubsublite/internal/wire/subscriber.go
@@ -415,13 +415,14 @@
 // partitions.
 type multiPartitionSubscriber struct {
 	// Immutable after creation.
-	subscribers []*singlePartitionSubscriber
+	subscribers map[int]*singlePartitionSubscriber
 
 	apiClientService
 }
 
 func newMultiPartitionSubscriber(allClients apiClients, subFactory *singlePartitionSubscriberFactory) *multiPartitionSubscriber {
 	ms := &multiPartitionSubscriber{
+		subscribers:      make(map[int]*singlePartitionSubscriber),
 		apiClientService: apiClientService{clients: allClients},
 	}
 	ms.init()
@@ -429,7 +430,7 @@
 	for _, partition := range subFactory.settings.Partitions {
 		subscriber := subFactory.New(partition)
 		ms.unsafeAddServices(subscriber)
-		ms.subscribers = append(ms.subscribers, subscriber)
+		ms.subscribers[partition] = subscriber
 	}
 	return ms
 }
@@ -445,13 +446,23 @@
 	}
 }
 
+// PartitionActive returns whether the partition is active.
+func (ms *multiPartitionSubscriber) PartitionActive(partition int) bool {
+	_, exists := ms.subscribers[partition]
+	return exists
+}
+
+// ReassignmentHandlerFunc receives a partition assignment change.
+type ReassignmentHandlerFunc func(before, after PartitionSet) error
+
 // assigningSubscriber uses the Pub/Sub Lite partition assignment service to
 // listen to its assigned partition numbers and dynamically add/remove
 // singlePartitionSubscribers.
 type assigningSubscriber struct {
 	// Immutable after creation.
-	subFactory *singlePartitionSubscriberFactory
-	assigner   *assigner
+	reassignmentHandler ReassignmentHandlerFunc
+	subFactory          *singlePartitionSubscriberFactory
+	assigner            *assigner
 
 	// Fields below must be guarded with mu.
 	// Subscribers keyed by partition number. Updated as assignments change.
@@ -460,11 +471,13 @@
 	apiClientService
 }
 
-func newAssigningSubscriber(allClients apiClients, assignmentClient *vkit.PartitionAssignmentClient, genUUID generateUUIDFunc, subFactory *singlePartitionSubscriberFactory) (*assigningSubscriber, error) {
+func newAssigningSubscriber(allClients apiClients, assignmentClient *vkit.PartitionAssignmentClient, reassignmentHandler ReassignmentHandlerFunc,
+	genUUID generateUUIDFunc, subFactory *singlePartitionSubscriberFactory) (*assigningSubscriber, error) {
 	as := &assigningSubscriber{
-		apiClientService: apiClientService{clients: allClients},
-		subFactory:       subFactory,
-		subscribers:      make(map[int]*singlePartitionSubscriber),
+		apiClientService:    apiClientService{clients: allClients},
+		reassignmentHandler: reassignmentHandler,
+		subFactory:          subFactory,
+		subscribers:         make(map[int]*singlePartitionSubscriber),
 	}
 	as.init()
 
@@ -477,12 +490,17 @@
 	return as, nil
 }
 
-func (as *assigningSubscriber) handleAssignment(partitions partitionSet) error {
-	removedSubscribers, err := as.doHandleAssignment(partitions)
+func (as *assigningSubscriber) handleAssignment(nextPartitions PartitionSet) error {
+	previousPartitions, removedSubscribers, err := as.doHandleAssignment(nextPartitions)
 	if err != nil {
 		return err
 	}
 
+	// Notify the user reassignment handler.
+	if err := as.reassignmentHandler(previousPartitions, nextPartitions); err != nil {
+		return err
+	}
+
 	// Wait for removed subscribers to completely stop (which waits for commit
 	// acknowledgments from the server) before acking the assignment. This avoids
 	// commits racing with the new assigned client.
@@ -492,17 +510,23 @@
 	return nil
 }
 
-func (as *assigningSubscriber) doHandleAssignment(partitions partitionSet) ([]*singlePartitionSubscriber, error) {
+// Returns the previous set of partitions and removed subscribers.
+func (as *assigningSubscriber) doHandleAssignment(nextPartitions PartitionSet) (PartitionSet, []*singlePartitionSubscriber, error) {
 	as.mu.Lock()
 	defer as.mu.Unlock()
 
+	var previousPartitions []int
+	for partition := range as.subscribers {
+		previousPartitions = append(previousPartitions, partition)
+	}
+
 	// Handle new partitions.
-	for _, partition := range partitions.Ints() {
+	for _, partition := range nextPartitions.Ints() {
 		if _, exists := as.subscribers[partition]; !exists {
 			subscriber := as.subFactory.New(partition)
 			if err := as.unsafeAddServices(subscriber); err != nil {
 				// Occurs when the assigningSubscriber is stopping/stopped.
-				return nil, err
+				return nil, nil, err
 			}
 			as.subscribers[partition] = subscriber
 		}
@@ -511,7 +535,7 @@
 	// Handle removed partitions.
 	var removedSubscribers []*singlePartitionSubscriber
 	for partition, subscriber := range as.subscribers {
-		if !partitions.Contains(partition) {
+		if !nextPartitions.Contains(partition) {
 			// Ignore unacked messages from this point on to avoid conflicting with
 			// the commits of the new subscriber that will be assigned this partition.
 			subscriber.Terminate()
@@ -523,7 +547,7 @@
 			delete(as.subscribers, partition)
 		}
 	}
-	return removedSubscribers, nil
+	return NewPartitionSet(previousPartitions), removedSubscribers, nil
 }
 
 // Terminate shuts down all singlePartitionSubscribers without waiting for
@@ -537,6 +561,15 @@
 	}
 }
 
+// PartitionActive returns whether the partition is still active.
+func (as *assigningSubscriber) PartitionActive(partition int) bool {
+	as.mu.Lock()
+	defer as.mu.Unlock()
+
+	_, exists := as.subscribers[partition]
+	return exists
+}
+
 // Subscriber is the client interface exported from this package for receiving
 // messages.
 type Subscriber interface {
@@ -545,10 +578,12 @@
 	Stop()
 	WaitStopped() error
 	Terminate()
+	PartitionActive(int) bool
 }
 
 // NewSubscriber creates a new client for receiving messages.
-func NewSubscriber(ctx context.Context, settings ReceiveSettings, receiver MessageReceiverFunc, region, subscriptionPath string, opts ...option.ClientOption) (Subscriber, error) {
+func NewSubscriber(ctx context.Context, settings ReceiveSettings, receiver MessageReceiverFunc, reassignmentHandler ReassignmentHandlerFunc,
+	region, subscriptionPath string, opts ...option.ClientOption) (Subscriber, error) {
 	if err := ValidateRegion(region); err != nil {
 		return nil, err
 	}
@@ -588,5 +623,5 @@
 		return nil, err
 	}
 	allClients = append(allClients, partitionClient)
-	return newAssigningSubscriber(allClients, partitionClient, uuid.NewRandom, subFactory)
+	return newAssigningSubscriber(allClients, partitionClient, reassignmentHandler, uuid.NewRandom, subFactory)
 }
diff --git a/pubsublite/internal/wire/subscriber_test.go b/pubsublite/internal/wire/subscriber_test.go
index 6aa5fe7..f7dca03 100644
--- a/pubsublite/internal/wire/subscriber_test.go
+++ b/pubsublite/internal/wire/subscriber_test.go
@@ -15,6 +15,7 @@
 
 import (
 	"context"
+	"errors"
 	"sort"
 	"sync"
 	"testing"
@@ -30,17 +31,22 @@
 	pb "google.golang.org/genproto/googleapis/cloud/pubsublite/v1"
 )
 
+const (
+	maxMessages int = 10
+	maxBytes    int = 1000
+)
+
 func testSubscriberSettings() ReceiveSettings {
 	settings := testReceiveSettings()
-	settings.MaxOutstandingMessages = 10
-	settings.MaxOutstandingBytes = 1000
+	settings.MaxOutstandingMessages = maxMessages
+	settings.MaxOutstandingBytes = maxBytes
 	return settings
 }
 
 // initFlowControlReq returns the first expected flow control request when
 // testSubscriberSettings are used.
 func initFlowControlReq() *pb.SubscribeRequest {
-	return flowControlSubReq(flowControlTokens{Bytes: 1000, Messages: 10})
+	return flowControlSubReq(flowControlTokens{Bytes: int64(maxBytes), Messages: int64(maxMessages)})
 }
 
 func partitionMsgs(partition int, msgs ...*pb.SequencedMessage) []*ReceivedMessage {
@@ -929,6 +935,15 @@
 	}
 }
 
+func verifyPartitionsActive(t *testing.T, sub Subscriber, want bool, partitions ...int) {
+	t.Helper()
+	for _, p := range partitions {
+		if got := sub.PartitionActive(p); got != want {
+			t.Errorf("PartitionActive(%d) got %v, want %v", p, got, want)
+		}
+	}
+}
+
 func newTestMultiPartitionSubscriber(t *testing.T, receiverFunc MessageReceiverFunc, subscriptionPath string, partitions []int) *multiPartitionSubscriber {
 	ctx := context.Background()
 	subClient, err := newSubscriberClient(ctx, "ignored", testServer.ClientConn())
@@ -994,6 +1009,9 @@
 	defer mockServer.OnTestEnd()
 
 	sub := newTestMultiPartitionSubscriber(t, receiver.onMessage, subscription, []int{1, 2})
+	verifyPartitionsActive(t, sub, true, 1, 2)
+	verifyPartitionsActive(t, sub, false, 0, 3)
+
 	if gotErr := sub.WaitStarted(); gotErr != nil {
 		t.Errorf("Start() got err: (%v)", gotErr)
 	}
@@ -1056,18 +1074,6 @@
 	receiver.VerifyNoMsgs()
 }
 
-func (as *assigningSubscriber) Partitions() []int {
-	as.mu.Lock()
-	defer as.mu.Unlock()
-
-	var partitions []int
-	for p := range as.subscribers {
-		partitions = append(partitions, p)
-	}
-	sort.Ints(partitions)
-	return partitions
-}
-
 func (as *assigningSubscriber) Subscribers() []*singlePartitionSubscriber {
 	as.mu.Lock()
 	defer as.mu.Unlock()
@@ -1088,7 +1094,11 @@
 	}
 }
 
-func newTestAssigningSubscriber(t *testing.T, receiverFunc MessageReceiverFunc, subscriptionPath string) *assigningSubscriber {
+func noopReassignmentHandler(_, _ PartitionSet) error {
+	return nil
+}
+
+func newTestAssigningSubscriber(t *testing.T, receiverFunc MessageReceiverFunc, reassignmentHandler ReassignmentHandlerFunc, subscriptionPath string) *assigningSubscriber {
 	ctx := context.Background()
 	subClient, err := newSubscriberClient(ctx, "ignored", testServer.ClientConn())
 	if err != nil {
@@ -1113,7 +1123,7 @@
 		receiver:         receiverFunc,
 		disableTasks:     true, // Background tasks disabled to control event order
 	}
-	sub, err := newAssigningSubscriber(allClients, assignmentClient, fakeGenerateUUID, f)
+	sub, err := newAssigningSubscriber(allClients, assignmentClient, reassignmentHandler, fakeGenerateUUID, f)
 	if err != nil {
 		t.Fatal(err)
 	}
@@ -1179,23 +1189,21 @@
 	mockServer.OnTestStart(verifiers)
 	defer mockServer.OnTestEnd()
 
-	sub := newTestAssigningSubscriber(t, receiver.onMessage, subscription)
+	sub := newTestAssigningSubscriber(t, receiver.onMessage, noopReassignmentHandler, subscription)
 	if gotErr := sub.WaitStarted(); gotErr != nil {
 		t.Errorf("Start() got err: (%v)", gotErr)
 	}
 
 	// Partition assignments are initially {3, 6}.
 	receiver.ValidateMsgs(join(partitionMsgs(3, msg1), partitionMsgs(6, msg3)))
-	if got, want := sub.Partitions(), []int{3, 6}; !testutil.Equal(got, want) {
-		t.Errorf("subscriber partitions: got %d, want %d", got, want)
-	}
+	verifyPartitionsActive(t, sub, true, 3, 6)
+	verifyPartitionsActive(t, sub, false, 1, 8)
 
 	// Partition assignments will now be {3, 8}.
 	assignmentBarrier1.Release()
 	receiver.ValidateMsgs(partitionMsgs(8, msg5))
-	if got, want := sub.Partitions(), []int{3, 8}; !testutil.Equal(got, want) {
-		t.Errorf("subscriber partitions: got %d, want %d", got, want)
-	}
+	verifyPartitionsActive(t, sub, true, 3, 8)
+	verifyPartitionsActive(t, sub, false, 2, 6)
 
 	// msg2 is from partition 3 and should be received. msg4 is from partition 6
 	// (removed) and should be discarded.
@@ -1255,7 +1263,7 @@
 	mockServer.OnTestStart(verifiers)
 	defer mockServer.OnTestEnd()
 
-	sub := newTestAssigningSubscriber(t, receiver.onMessage, subscription)
+	sub := newTestAssigningSubscriber(t, receiver.onMessage, noopReassignmentHandler, subscription)
 	if gotErr := sub.WaitStarted(); gotErr != nil {
 		t.Errorf("Start() got err: (%v)", gotErr)
 	}
@@ -1298,7 +1306,7 @@
 	mockServer.OnTestStart(verifiers)
 	defer mockServer.OnTestEnd()
 
-	sub := newTestAssigningSubscriber(t, receiver.onMessage, subscription)
+	sub := newTestAssigningSubscriber(t, receiver.onMessage, noopReassignmentHandler, subscription)
 	if gotErr := sub.WaitStarted(); gotErr != nil {
 		t.Errorf("Start() got err: (%v)", gotErr)
 	}
@@ -1330,6 +1338,116 @@
 	}
 }
 
+func TestAssigningSubscriberStoppedWhileReassignmentHandlerActive(t *testing.T) {
+	const subscription = "projects/123456/locations/us-central1-b/subscriptions/my-sub"
+	receiver := newTestMessageReceiver(t)
+
+	verifiers := test.NewVerifiers(t)
+
+	// Assignment stream
+	asnStream := test.NewRPCVerifier(t)
+	asnStream.Push(initAssignmentReq(subscription, fakeUUID[:]), assignmentResp([]int64{1}), nil)
+	verifiers.AddAssignmentStream(subscription, asnStream)
+
+	// Partition 1
+	subStream := test.NewRPCVerifier(t)
+	subStream.Push(initSubReqCommit(subscriptionPartition{Path: subscription, Partition: 1}), initSubResp(), nil)
+	subBarrier := subStream.PushWithBarrier(initFlowControlReq(), nil, nil)
+	verifiers.AddSubscribeStream(subscription, 1, subStream)
+
+	cmtStream := test.NewRPCVerifier(t)
+	cmtBarrier := cmtStream.PushWithBarrier(initCommitReq(subscriptionPartition{Path: subscription, Partition: 1}), initCommitResp(), nil)
+	verifiers.AddCommitStream(subscription, 1, cmtStream)
+
+	mockServer.OnTestStart(verifiers)
+	defer mockServer.OnTestEnd()
+
+	reassignmentHandlerCalled := test.NewCondition("reassignment handler called")
+	returnReassignmentHandler := test.NewCondition("return reassignment handler")
+	onReassignment := func(before, after PartitionSet) error {
+		if got, want := len(before.SortedInts()), 0; got != want {
+			t.Errorf("len(before): got %v, want %v", got, want)
+		}
+		if got, want := after.SortedInts(), []int{1}; !testutil.Equal(got, want) {
+			t.Errorf("after: got %v, want %v", got, want)
+		}
+		reassignmentHandlerCalled.SetDone()
+		returnReassignmentHandler.WaitUntilDone(t, serviceTestWaitTimeout)
+		return nil
+	}
+
+	sub := newTestAssigningSubscriber(t, receiver.onMessage, onReassignment, subscription)
+	if gotErr := sub.WaitStarted(); gotErr != nil {
+		t.Errorf("Start() got err: (%v)", gotErr)
+	}
+
+	// Used to control order of execution to ensure the test is deterministic.
+	subBarrier.Release()
+	cmtBarrier.Release()
+
+	// Ensure there are no deadlocks if the reassignment handler blocks and the
+	// subscriber is stopped.
+	reassignmentHandlerCalled.WaitUntilDone(t, serviceTestWaitTimeout)
+	sub.Stop()
+	returnReassignmentHandler.SetDone()
+
+	if gotErr := sub.WaitStopped(); gotErr != nil {
+		t.Errorf("WaitStopped() got err: (%v)", gotErr)
+	}
+}
+
+func TestAssigningSubscriberReassignmentHandlerReturnsError(t *testing.T) {
+	const subscription = "projects/123456/locations/us-central1-b/subscriptions/my-sub"
+	receiver := newTestMessageReceiver(t)
+
+	verifiers := test.NewVerifiers(t)
+
+	// Assignment stream
+	asnStream := test.NewRPCVerifier(t)
+	asnStream.Push(initAssignmentReq(subscription, fakeUUID[:]), assignmentResp([]int64{1}), nil)
+	verifiers.AddAssignmentStream(subscription, asnStream)
+
+	// Partition 1
+	subStream := test.NewRPCVerifier(t)
+	subStream.Push(initSubReqCommit(subscriptionPartition{Path: subscription, Partition: 1}), initSubResp(), nil)
+	subBarrier := subStream.PushWithBarrier(initFlowControlReq(), nil, nil)
+	verifiers.AddSubscribeStream(subscription, 1, subStream)
+
+	cmtStream := test.NewRPCVerifier(t)
+	cmtBarrier := cmtStream.PushWithBarrier(initCommitReq(subscriptionPartition{Path: subscription, Partition: 1}), initCommitResp(), nil)
+	verifiers.AddCommitStream(subscription, 1, cmtStream)
+
+	mockServer.OnTestStart(verifiers)
+	defer mockServer.OnTestEnd()
+
+	reassignmentErr := errors.New("reassignment handler error")
+	returnReassignmentErr := test.NewCondition("return reassignment error")
+	onAssignment := func(before, after PartitionSet) error {
+		if got, want := len(before.SortedInts()), 0; got != want {
+			t.Errorf("len(before): got %v, want %v", got, want)
+		}
+		if got, want := after.SortedInts(), []int{1}; !testutil.Equal(got, want) {
+			t.Errorf("after: got %v, want %v", got, want)
+		}
+		returnReassignmentErr.WaitUntilDone(t, serviceTestWaitTimeout)
+		return reassignmentErr
+	}
+
+	sub := newTestAssigningSubscriber(t, receiver.onMessage, onAssignment, subscription)
+	if gotErr := sub.WaitStarted(); gotErr != nil {
+		t.Errorf("Start() got err: (%v)", gotErr)
+	}
+
+	// Used to control order of execution to ensure the test is deterministic.
+	subBarrier.Release()
+	cmtBarrier.Release()
+	returnReassignmentErr.SetDone()
+
+	if gotErr := sub.WaitStopped(); !test.ErrorEqual(gotErr, reassignmentErr) {
+		t.Errorf("WaitStopped() got err: (%v), want err: (%v)", gotErr, reassignmentErr)
+	}
+}
+
 func TestNewSubscriberValidatesSettings(t *testing.T) {
 	const subscription = "projects/123456/locations/us-central1-b/subscriptions/my-sub"
 	const region = "us-central1"
@@ -1337,7 +1455,7 @@
 
 	settings := DefaultReceiveSettings
 	settings.MaxOutstandingMessages = 0
-	if _, err := NewSubscriber(context.Background(), settings, receiver.onMessage, region, subscription); err == nil {
+	if _, err := NewSubscriber(context.Background(), settings, receiver.onMessage, noopReassignmentHandler, region, subscription); err == nil {
 		t.Error("NewSubscriber() did not return error")
 	}
 }
diff --git a/pubsublite/pscompat/integration_test.go b/pubsublite/pscompat/integration_test.go
index 59aa779..86da439 100644
--- a/pubsublite/pscompat/integration_test.go
+++ b/pubsublite/pscompat/integration_test.go
@@ -17,6 +17,7 @@
 	"context"
 	"errors"
 	"fmt"
+	"sort"
 	"strings"
 	"sync"
 	"sync/atomic"
@@ -729,12 +730,29 @@
 			}
 		}
 
+		// Verify partition reassignment notifications.
+		var allPartitions []int
+		var mu sync.Mutex
+		reassignmentHandler := func(before, after []int) error {
+			t.Logf("Partition assignments: before %v, after %v", before, after)
+			if got, want := len(before), 0; got != want {
+				t.Errorf("Partition assignments len(before): got %d, want %d", got, want)
+			}
+			mu.Lock()
+			allPartitions = append(allPartitions, after...)
+			mu.Unlock()
+			return nil
+		}
+
+		receiveSettings := DefaultReceiveSettings
+		receiveSettings.ReassignmentHandler = reassignmentHandler
+
 		cctx, stopSubscribers := context.WithTimeout(context.Background(), defaultTestTimeout)
 		g, _ := errgroup.WithContext(ctx)
 		for i := 0; i < subscriberCount; i++ {
 			// Subscribers must be started in a goroutine as Receive() blocks.
 			g.Go(func() error {
-				subscriber := subscriberClient(context.Background(), t, DefaultReceiveSettings, subscriptionPath)
+				subscriber := subscriberClient(context.Background(), t, receiveSettings, subscriptionPath)
 				err := subscriber.Receive(cctx, messageReceiver)
 				if err != nil {
 					t.Errorf("Receive() got err: %v", err)
@@ -748,6 +766,13 @@
 		stopSubscribers()
 		// Wait until all subscribers have terminated.
 		g.Wait()
+
+		mu.Lock()
+		sort.Ints(allPartitions)
+		if got, want := allPartitions, partitionNumbers(partitionCount); !testutil.Equal(got, want) {
+			t.Errorf("Assigned partition numbers: got %v, want %v", got, want)
+		}
+		mu.Unlock()
 	})
 }
 
diff --git a/pubsublite/pscompat/settings.go b/pubsublite/pscompat/settings.go
index 59b84c8..a4b8b51 100644
--- a/pubsublite/pscompat/settings.go
+++ b/pubsublite/pscompat/settings.go
@@ -164,6 +164,30 @@
 // error and terminate.
 type ReceiveMessageTransformerFunc func(*pb.SequencedMessage, *pubsub.Message) error
 
+// ReassignmentHandlerFunc is called any time a new partition assignment is
+// received from the server. It will be called with both the previous and new
+// partition numbers as decided by the server. Both slices of partition numbers
+// are sorted in ascending order.
+//
+// When this handler is called, partitions that are being assigned away are
+// stopping and new partitions are starting. Acks and nacks for messages from
+// partitions that are being assigned away will have no effect, but message
+// deliveries may still be in flight.
+//
+// The client library will not acknowledge the assignment until this handler
+// returns. The server will not assign any of the partitions in
+// `previousPartitions` to another client unless the assignment is acknowledged,
+// or a client takes too long to acknowledge (currently 30 seconds from the time
+// the assignment is sent from server's point of view).
+//
+// Because of the above, as long as reassignment handling is processed quickly,
+// it can be used to abort outstanding operations on partitions which are being
+// assigned away from this client.
+//
+// If this handler returns an error, the SubscriberClient will consider this a
+// fatal error and terminate.
+type ReassignmentHandlerFunc func(previousPartitions, nextPartitions []int) error
+
 // ReceiveSettings configure the SubscriberClient. Flow control settings
 // (MaxOutstandingMessages, MaxOutstandingBytes) apply per partition.
 //
@@ -210,6 +234,10 @@
 	// Optional custom function that transforms a SequencedMessage API proto to a
 	// pubsub.Message.
 	MessageTransformer ReceiveMessageTransformerFunc
+
+	// Optional custom function that is called when a new partition assignment has
+	// been delivered to the client.
+	ReassignmentHandler ReassignmentHandlerFunc
 }
 
 // DefaultReceiveSettings holds the default values for ReceiveSettings.
diff --git a/pubsublite/pscompat/subscriber.go b/pubsublite/pscompat/subscriber.go
index 368d735..d0d00e0 100644
--- a/pubsublite/pscompat/subscriber.go
+++ b/pubsublite/pscompat/subscriber.go
@@ -40,6 +40,7 @@
 type pslAckHandler struct {
 	ackh        wire.AckConsumer
 	msg         *pubsub.Message
+	partition   int
 	nackh       NackHandler
 	subInstance *subscriberInstance
 }
@@ -58,6 +59,11 @@
 		return
 	}
 
+	// Ignore nacks for partitions that have been assigned away.
+	if !ah.subInstance.wireSub.PartitionActive(ah.partition) {
+		return
+	}
+
 	err := ah.nackh(ah.msg)
 	if err != nil {
 		// If the NackHandler returns an error, shut down the subscriber client.
@@ -72,7 +78,7 @@
 // wireSubscriberFactory is a factory for creating wire subscribers, which can
 // be overridden with a mock in unit tests.
 type wireSubscriberFactory interface {
-	New(context.Context, wire.MessageReceiverFunc) (wire.Subscriber, error)
+	New(context.Context, wire.MessageReceiverFunc, wire.ReassignmentHandlerFunc) (wire.Subscriber, error)
 }
 
 type wireSubscriberFactoryImpl struct {
@@ -82,8 +88,8 @@
 	options      []option.ClientOption
 }
 
-func (f *wireSubscriberFactoryImpl) New(ctx context.Context, receiver wire.MessageReceiverFunc) (wire.Subscriber, error) {
-	return wire.NewSubscriber(ctx, f.settings, receiver, f.region, f.subscription.String(), f.options...)
+func (f *wireSubscriberFactoryImpl) New(ctx context.Context, receiver wire.MessageReceiverFunc, onReassignment wire.ReassignmentHandlerFunc) (wire.Subscriber, error) {
+	return wire.NewSubscriber(ctx, f.settings, receiver, onReassignment, f.region, f.subscription.String(), f.options...)
 }
 
 type messageReceiverFunc = func(context.Context, *pubsub.Message)
@@ -116,7 +122,7 @@
 	// cancelled, the gRPC streams will be disconnected and the subscriber will
 	// not be able to process acks and commit the final cursor offset. Use the
 	// context from NewSubscriberClient (clientCtx) instead.
-	wireSub, err := factory.New(clientCtx, subInstance.onMessage)
+	wireSub, err := factory.New(clientCtx, subInstance.onMessage, subInstance.onReassignment)
 	if err != nil {
 		return nil, err
 	}
@@ -131,6 +137,13 @@
 	return subInstance, nil
 }
 
+func (si *subscriberInstance) onReassignment(before, after wire.PartitionSet) error {
+	if si.settings.ReassignmentHandler != nil {
+		return si.settings.ReassignmentHandler(before.SortedInts(), after.SortedInts())
+	}
+	return nil
+}
+
 func (si *subscriberInstance) transformMessage(in *wire.ReceivedMessage, out *pubsub.Message) error {
 	if err := si.settings.MessageTransformer(in.Msg, out); err != nil {
 		return err
@@ -147,6 +160,7 @@
 	pslAckh := &pslAckHandler{
 		ackh:        msg.Ack,
 		nackh:       si.settings.NackHandler,
+		partition:   msg.Partition,
 		subInstance: si,
 	}
 	psMsg := ipubsub.NewMessage(pslAckh)
diff --git a/pubsublite/pscompat/subscriber_test.go b/pubsublite/pscompat/subscriber_test.go
index 5c737f8..6bd4ac8 100644
--- a/pubsublite/pscompat/subscriber_test.go
+++ b/pubsublite/pscompat/subscriber_test.go
@@ -30,7 +30,10 @@
 	pb "google.golang.org/genproto/googleapis/cloud/pubsublite/v1"
 )
 
-const defaultSubscriberTestTimeout = 10 * time.Second
+const (
+	defaultSubscriberTestTimeout = 10 * time.Second
+	activePartition              = 1
+)
 
 // mockAckConsumer is a mock implementation of the wire.AckConsumer interface.
 type mockAckConsumer struct {
@@ -43,12 +46,14 @@
 
 // mockWireSubscriber is a mock implementation of the wire.Subscriber interface.
 type mockWireSubscriber struct {
-	receiver   wire.MessageReceiverFunc
-	msgsC      chan *wire.ReceivedMessage
-	stopC      chan struct{}
-	err        error
-	Stopped    bool
-	Terminated bool
+	receiver         wire.MessageReceiverFunc
+	onReassignment   wire.ReassignmentHandlerFunc
+	activePartitions wire.PartitionSet
+	msgsC            chan *wire.ReceivedMessage
+	stopC            chan struct{}
+	err              error
+	Stopped          bool
+	Terminated       bool
 }
 
 // DeliverMessages should be called from the test to simulate a message
@@ -59,6 +64,14 @@
 	}
 }
 
+// OnReassignment should be called from the test to simulate a partition
+// reassignment.
+func (ms *mockWireSubscriber) DeliverReassignment(before, after wire.PartitionSet) {
+	if err := ms.onReassignment(before, after); err != nil {
+		ms.SimulateFatalError(err)
+	}
+}
+
 // SimulateFatalError should be called from the test to simulate a fatal error
 // occurring in the wire subscriber.
 func (ms *mockWireSubscriber) SimulateFatalError(err error) {
@@ -111,13 +124,19 @@
 	return ms.err
 }
 
+func (ms *mockWireSubscriber) PartitionActive(partition int) bool {
+	return ms.activePartitions.Contains(partition)
+}
+
 type mockWireSubscriberFactory struct{}
 
-func (f *mockWireSubscriberFactory) New(ctx context.Context, receiver wire.MessageReceiverFunc) (wire.Subscriber, error) {
+func (f *mockWireSubscriberFactory) New(ctx context.Context, receiver wire.MessageReceiverFunc, onReassignment wire.ReassignmentHandlerFunc) (wire.Subscriber, error) {
 	return &mockWireSubscriber{
-		receiver: receiver,
-		msgsC:    make(chan *wire.ReceivedMessage, 10),
-		stopC:    make(chan struct{}),
+		receiver:         receiver,
+		onReassignment:   onReassignment,
+		activePartitions: wire.NewPartitionSet([]int{activePartition}),
+		msgsC:            make(chan *wire.ReceivedMessage, 10),
+		stopC:            make(chan struct{}),
 	}, nil
 }
 
@@ -289,6 +308,7 @@
 		desc string
 		// mutateSettings is passed a copy of DefaultReceiveSettings to mutate.
 		mutateSettings func(settings *ReceiveSettings)
+		msgPartition   int
 		wantErr        error
 		wantAckCount   int
 		wantStopped    bool
@@ -297,17 +317,27 @@
 		{
 			desc:           "default settings",
 			mutateSettings: func(settings *ReceiveSettings) {},
+			msgPartition:   activePartition,
 			wantErr:        errNackCalled,
 			wantAckCount:   0,
 			wantTerminated: true,
 		},
 		{
+			desc:           "message partition inactive",
+			mutateSettings: func(settings *ReceiveSettings) {},
+			msgPartition:   activePartition + 1,
+			wantErr:        nil,
+			wantAckCount:   0,
+			wantStopped:    true,
+		},
+		{
 			desc: "nack handler returns nil",
 			mutateSettings: func(settings *ReceiveSettings) {
 				settings.NackHandler = func(_ *pubsub.Message) error {
 					return nil
 				}
 			},
+			msgPartition: activePartition,
 			wantErr:      nil,
 			wantAckCount: 1,
 			wantStopped:  true,
@@ -319,6 +349,7 @@
 					return nackErr
 				}
 			},
+			msgPartition:   activePartition,
 			wantErr:        nackErr,
 			wantAckCount:   0,
 			wantTerminated: true,
@@ -329,7 +360,7 @@
 			tc.mutateSettings(&settings)
 
 			ack := &mockAckConsumer{}
-			msg := &wire.ReceivedMessage{Msg: msg, Ack: ack}
+			msg := &wire.ReceivedMessage{Msg: msg, Ack: ack, Partition: tc.msgPartition}
 
 			cctx, stopSubscriber := context.WithTimeout(ctx, defaultSubscriberTestTimeout)
 			messageReceiver := func(ctx context.Context, got *pubsub.Message) {
@@ -363,6 +394,67 @@
 	}
 }
 
+func TestSubscriberInstanceReassignmentHandler(t *testing.T) {
+	reassignmentErr := errors.New("reassignment failure")
+	before := wire.NewPartitionSet([]int{3, 2, 1})
+	after := wire.NewPartitionSet([]int{4, 5, 3})
+	ctx := context.Background()
+
+	for _, tc := range []struct {
+		desc string
+		// mutateSettings is passed a copy of DefaultReceiveSettings to mutate.
+		mutateSettings func(settings *ReceiveSettings)
+		wantErr        error
+	}{
+		{
+			desc:           "default settings",
+			mutateSettings: func(settings *ReceiveSettings) {},
+		},
+		{
+			desc: "reassignment handler returns nil",
+			mutateSettings: func(settings *ReceiveSettings) {
+				settings.ReassignmentHandler = func(before, after []int) error {
+					if got, want := before, []int{1, 2, 3}; !testutil.Equal(got, want) {
+						t.Errorf("before: got %d, want %d", got, want)
+					}
+					if got, want := after, []int{3, 4, 5}; !testutil.Equal(got, want) {
+						t.Errorf("after: got %d, want %d", got, want)
+					}
+					return nil
+				}
+			},
+		},
+		{
+			desc: "reassignment handler returns error",
+			mutateSettings: func(settings *ReceiveSettings) {
+				settings.ReassignmentHandler = func(before, after []int) error {
+					return reassignmentErr
+				}
+			},
+			wantErr: reassignmentErr,
+		},
+	} {
+		t.Run(tc.desc, func(t *testing.T) {
+			settings := DefaultReceiveSettings
+			tc.mutateSettings(&settings)
+
+			cctx, stopSubscriber := context.WithTimeout(ctx, defaultSubscriberTestTimeout)
+			messageReceiver := func(ctx context.Context, got *pubsub.Message) {
+				t.Error("Message receiver should not be called")
+			}
+			subInstance := newTestSubscriberInstance(cctx, settings, messageReceiver)
+			subInstance.wireSub.(*mockWireSubscriber).DeliverReassignment(before, after)
+			if tc.wantErr == nil {
+				stopSubscriber()
+			}
+
+			if gotErr := subInstance.Wait(cctx); !test.ErrorEqual(gotErr, tc.wantErr) {
+				t.Errorf("subscriberInstance.Wait() got err: (%v), want err: (%v)", gotErr, tc.wantErr)
+			}
+		})
+	}
+}
+
 func TestSubscriberInstanceWireSubscriberFails(t *testing.T) {
 	fatalErr := errors.New("server error")