spanner: make ReadWriteTransaction retry on Session not found error
Updates #1527
Ref: https://github.com/googleapis/google-cloud-go/issues/1527
Change-Id: Iea12342ca098c8056abc2206b91edbeda630e718
Reviewed-on: https://code-review.googlesource.com/c/gocloud/+/45910
Reviewed-by: kokoro <noreply+kokoro@google.com>
Reviewed-by: Hengfeng Li <hengfeng@google.com>
diff --git a/spanner/client.go b/spanner/client.go
index 4f0ee8e..f1d18f8 100644
--- a/spanner/client.go
+++ b/spanner/client.go
@@ -435,7 +435,7 @@
ts time.Time
sh *sessionHandle
)
- err = runWithRetryOnAborted(ctx, func(ctx context.Context) error {
+ err = runWithRetryOnAbortedOrSessionNotFound(ctx, func(ctx context.Context) error {
var (
err error
t *ReadWriteTransaction
diff --git a/spanner/client_test.go b/spanner/client_test.go
index 8b83f61..36cf0b9 100644
--- a/spanner/client_test.go
+++ b/spanner/client_test.go
@@ -68,7 +68,7 @@
server, opts, serverTeardown := NewMockedSpannerInMemTestServer(t)
opts = append(opts, clientOptions...)
ctx := context.Background()
- var formattedDatabase = fmt.Sprintf("projects/%s/instances/%s/databases/%s", "[PROJECT]", "[INSTANCE]", "[DATABASE]")
+ formattedDatabase := fmt.Sprintf("projects/%s/instances/%s/databases/%s", "[PROJECT]", "[INSTANCE]", "[DATABASE]")
client, err := NewClientWithConfig(ctx, formattedDatabase, config, opts...)
if err != nil {
t.Fatal(err)
@@ -609,6 +609,165 @@
}
}
+// TestClient_ReadWriteTransaction_SessionNotFoundOnCommit queues a single
+// simulated 'Session not found' error for Commit and passes 2 as the expected
+// number of transaction attempts to the shared helper.
+func TestClient_ReadWriteTransaction_SessionNotFoundOnCommit(t *testing.T) {
+	t.Parallel()
+	sessionNotFound := status.Error(codes.NotFound, "Session not found")
+	simulated := map[string]SimulatedExecutionTime{
+		MethodCommitTransaction: {Errors: []error{sessionNotFound}},
+	}
+	err := testReadWriteTransaction(t, simulated, 2)
+	if err != nil {
+		t.Fatal(err)
+	}
+}
+
+// TestClient_ReadWriteTransaction_SessionNotFoundOnBeginTransaction queues a
+// single simulated 'Session not found' error for BeginTransaction.
+func TestClient_ReadWriteTransaction_SessionNotFoundOnBeginTransaction(t *testing.T) {
+	t.Parallel()
+	sessionNotFound := status.Error(codes.NotFound, "Session not found")
+	simulated := map[string]SimulatedExecutionTime{
+		MethodBeginTransaction: {Errors: []error{sessionNotFound}},
+	}
+	// Only 1 attempt is expected: the 'Session not found' error is already
+	// handled in the session pool where the session is prepared.
+	err := testReadWriteTransaction(t, simulated, 1)
+	if err != nil {
+		t.Fatal(err)
+	}
+}
+
+// TestClient_ReadWriteTransaction_SessionNotFoundOnBeginTransactionWithEmptySessionPool
+// runs the BeginTransaction 'Session not found' scenario with a client config
+// whose session pool has WriteSessions set to 0.
+func TestClient_ReadWriteTransaction_SessionNotFoundOnBeginTransactionWithEmptySessionPool(t *testing.T) {
+	t.Parallel()
+	// With no prepared sessions in the pool, the error occurs when the
+	// transaction tries to get a session from the pool. The session pool
+	// handles that case too, so the transaction itself does not need to
+	// retry and only 1 attempt is expected.
+	// NOTE(review): the helper, as far as this change shows, still calls
+	// setupMockedTestServer(t) - confirm the config actually reaches the client.
+	cfg := ClientConfig{
+		SessionPoolConfig: SessionPoolConfig{WriteSessions: 0.0},
+	}
+	simulated := map[string]SimulatedExecutionTime{
+		MethodBeginTransaction: {Errors: []error{status.Error(codes.NotFound, "Session not found")}},
+	}
+	err := testReadWriteTransactionWithConfig(t, cfg, simulated, 1)
+	if err != nil {
+		t.Fatal(err)
+	}
+}
+
+// TestClient_ReadWriteTransaction_SessionNotFoundOnExecuteStreamingSql queues
+// a single simulated 'Session not found' error for ExecuteStreamingSql and
+// passes 2 as the expected number of transaction attempts to the shared helper.
+func TestClient_ReadWriteTransaction_SessionNotFoundOnExecuteStreamingSql(t *testing.T) {
+	t.Parallel()
+	sessionNotFound := status.Error(codes.NotFound, "Session not found")
+	simulated := map[string]SimulatedExecutionTime{
+		MethodExecuteStreamingSql: {Errors: []error{sessionNotFound}},
+	}
+	err := testReadWriteTransaction(t, simulated, 2)
+	if err != nil {
+		t.Fatal(err)
+	}
+}
+
+// TestClient_ReadWriteTransaction_SessionNotFoundOnExecuteUpdate queues a
+// single simulated 'Session not found' error for ExecuteSql, runs tx.Update
+// inside a read/write transaction and expects the transaction function to be
+// invoked twice.
+func TestClient_ReadWriteTransaction_SessionNotFoundOnExecuteUpdate(t *testing.T) {
+	t.Parallel()
+
+	server, client, teardown := setupMockedTestServer(t)
+	defer teardown()
+	server.TestSpanner.PutExecutionTime(
+		MethodExecuteSql,
+		SimulatedExecutionTime{Errors: []error{status.Error(codes.NotFound, "Session not found")}},
+	)
+	ctx := context.Background()
+	// attempts counts how often the transaction function is invoked.
+	var attempts int
+	_, err := client.ReadWriteTransaction(ctx, func(ctx context.Context, tx *ReadWriteTransaction) error {
+		attempts++
+		rowCount, err := tx.Update(ctx, NewStatement(UpdateBarSetFoo))
+		if err != nil {
+			return err
+		}
+		if g, w := rowCount, int64(UpdateBarSetFooRowCount); g != w {
+			return status.Errorf(codes.FailedPrecondition, "Row count mismatch\nGot: %v\nWant: %v", g, w)
+		}
+		return nil
+	})
+	if err != nil {
+		t.Fatal(err)
+	}
+	if g, w := attempts, 2; g != w {
+		// Use the same "Got: / Want:" layout as the other messages in this file.
+		t.Fatalf("number of attempts mismatch:\nGot: %d\nWant: %d", g, w)
+	}
+}
+
+// TestClient_ReadWriteTransaction_SessionNotFoundOnExecuteBatchUpdate queues a
+// single simulated 'Session not found' error for ExecuteBatchDml, runs
+// tx.BatchUpdate inside a read/write transaction and expects the transaction
+// function to be invoked twice.
+func TestClient_ReadWriteTransaction_SessionNotFoundOnExecuteBatchUpdate(t *testing.T) {
+	t.Parallel()
+
+	server, client, teardown := setupMockedTestServer(t)
+	defer teardown()
+	server.TestSpanner.PutExecutionTime(
+		MethodExecuteBatchDml,
+		SimulatedExecutionTime{Errors: []error{status.Error(codes.NotFound, "Session not found")}},
+	)
+	ctx := context.Background()
+	// attempts counts how often the transaction function is invoked.
+	var attempts int
+	_, err := client.ReadWriteTransaction(ctx, func(ctx context.Context, tx *ReadWriteTransaction) error {
+		attempts++
+		rowCounts, err := tx.BatchUpdate(ctx, []Statement{NewStatement(UpdateBarSetFoo)})
+		if err != nil {
+			return err
+		}
+		// Check the length first so that indexing rowCounts[0] below is safe.
+		if g, w := len(rowCounts), 1; g != w {
+			return status.Errorf(codes.FailedPrecondition, "Row counts length mismatch\nGot: %v\nWant: %v", g, w)
+		}
+		if g, w := rowCounts[0], int64(UpdateBarSetFooRowCount); g != w {
+			return status.Errorf(codes.FailedPrecondition, "Row count mismatch\nGot: %v\nWant: %v", g, w)
+		}
+		return nil
+	})
+	if err != nil {
+		t.Fatal(err)
+	}
+	if g, w := attempts, 2; g != w {
+		// Use the same "Got: / Want:" layout as the other messages in this file.
+		t.Fatalf("number of attempts mismatch:\nGot: %d\nWant: %d", g, w)
+	}
+}
+
+// TestClient_SessionNotFound deletes an idle session on the mock server
+// without informing the session pool, then runs a read/write transaction that
+// queries and counts rows. The transaction is expected to finish without an
+// error.
+func TestClient_SessionNotFound(t *testing.T) {
+	// Ensure we always have at least one session in the pool.
+	sc := SessionPoolConfig{
+		MinOpened: 1,
+	}
+	server, client, teardown := setupMockedTestServerWithConfig(t, ClientConfig{SessionPoolConfig: sc})
+	defer teardown()
+	ctx := context.Background()
+	// Wait until the pool holds an idle session, and read its ID while the
+	// pool mutex is still held so the list is never inspected unlocked.
+	var sessionID string
+	for {
+		client.idleSessions.mu.Lock()
+		front := client.idleSessions.idleList.Front()
+		if front != nil {
+			sessionID = front.Value.(*session).id
+		}
+		client.idleSessions.mu.Unlock()
+		if front != nil {
+			break
+		}
+		// time.Sleep actually blocks; a bare time.After(...) only creates a
+		// channel that nobody receives from, turning this loop into a busy spin.
+		time.Sleep(time.Millisecond)
+	}
+	// Remove the session from the server without the pool knowing it.
+	_, err := server.TestSpanner.DeleteSession(ctx, &sppb.DeleteSessionRequest{Name: sessionID})
+	if err != nil {
+		t.Fatalf("Failed to delete session unexpectedly: %v", err)
+	}
+
+	_, err = client.ReadWriteTransaction(ctx, func(ctx context.Context, tx *ReadWriteTransaction) error {
+		iter := tx.Query(ctx, NewStatement(SelectSingerIDAlbumIDAlbumTitleFromAlbums))
+		defer iter.Stop()
+		rowCount := int64(0)
+		for {
+			row, err := iter.Next()
+			if err == iterator.Done {
+				break
+			}
+			if err != nil {
+				return err
+			}
+			var singerID, albumID int64
+			var albumTitle string
+			if err := row.Columns(&singerID, &albumID, &albumTitle); err != nil {
+				return err
+			}
+			rowCount++
+		}
+		if rowCount != SelectSingerIDAlbumIDAlbumTitleFromAlbumsRowCount {
+			return spannerErrorf(codes.FailedPrecondition, "Row count mismatch, got %v, expected %v", rowCount, SelectSingerIDAlbumIDAlbumTitleFromAlbumsRowCount)
+		}
+		return nil
+	})
+	if err != nil {
+		t.Fatalf("Unexpected error during transaction: %v", err)
+	}
+}
+
func TestClient_ReadWriteTransactionExecuteStreamingSqlAborted(t *testing.T) {
t.Parallel()
if err := testReadWriteTransaction(t, map[string]SimulatedExecutionTime{
@@ -801,6 +960,10 @@
}
func testReadWriteTransaction(t *testing.T, executionTimes map[string]SimulatedExecutionTime, expectedAttempts int) error {
+ return testReadWriteTransactionWithConfig(t, ClientConfig{SessionPoolConfig: DefaultSessionPoolConfig}, executionTimes, expectedAttempts) // delegate, supplying DefaultSessionPoolConfig as the session pool config
+}
+
+func testReadWriteTransactionWithConfig(t *testing.T, config ClientConfig, executionTimes map[string]SimulatedExecutionTime, expectedAttempts int) error {
server, client, teardown := setupMockedTestServer(t)
defer teardown()
for method, exec := range executionTimes {
@@ -966,3 +1129,50 @@
t.Fatalf("Unexpected error\nGot: %v\nWant: %v", err, msg)
}
}
+
+// TestReadWriteTransaction_WrapSessionNotFoundError runs a read/write
+// transaction whose callback wraps query errors in wrappedTestError, with one
+// simulated 'Session not found' error queued for each of BeginTransaction,
+// ExecuteStreamingSql and Commit.
+func TestReadWriteTransaction_WrapSessionNotFoundError(t *testing.T) {
+	t.Parallel()
+	server, client, teardown := setupMockedTestServer(t)
+	defer teardown()
+	// Build a fresh value per method so that no Errors slice is shared.
+	sessionNotFoundOnce := func() SimulatedExecutionTime {
+		return SimulatedExecutionTime{
+			Errors: []error{status.Error(codes.NotFound, "Session not found")},
+		}
+	}
+	server.TestSpanner.PutExecutionTime(MethodBeginTransaction, sessionNotFoundOnce())
+	server.TestSpanner.PutExecutionTime(MethodExecuteStreamingSql, sessionNotFoundOnce())
+	server.TestSpanner.PutExecutionTime(MethodCommitTransaction, sessionNotFoundOnce())
+
+	wrapMsg := "query failed"
+	attempts := 0
+	ctx := context.Background()
+	_, err := client.ReadWriteTransaction(ctx, func(ctx context.Context, tx *ReadWriteTransaction) error {
+		attempts++
+		rows := tx.Query(ctx, NewStatement(SelectSingerIDAlbumIDAlbumTitleFromAlbums))
+		defer rows.Stop()
+		for {
+			_, nextErr := rows.Next()
+			switch {
+			case nextErr == iterator.Done:
+				return nil
+			case nextErr != nil:
+				// Wrap the error in another error that implements the
+				// (xerrors|errors).Wrapper interface.
+				return &wrappedTestError{nextErr, wrapMsg}
+			}
+		}
+	})
+	if err != nil {
+		t.Fatalf("Unexpected error\nGot: %v\nWant: nil", err)
+	}
+	// Three attempts are expected: the 'Session not found' error on
+	// BeginTransaction does not restart the entire transaction, so there are
+	// two failed attempts followed by a successful one.
+	if g, w := attempts, 3; g != w {
+		t.Fatalf("Number of transaction attempts mismatch\nGot: %d\nWant: %d", g, w)
+	}
+}
diff --git a/spanner/internal/testutil/inmem_spanner_server.go b/spanner/internal/testutil/inmem_spanner_server.go
index f107f18..eb1c870 100644
--- a/spanner/internal/testutil/inmem_spanner_server.go
+++ b/spanner/internal/testutil/inmem_spanner_server.go
@@ -431,7 +431,7 @@
defer s.mu.Unlock()
session := s.sessions[name]
if session == nil {
- return nil, gstatus.Error(codes.NotFound, fmt.Sprintf("Session %s not found", name))
+ return nil, gstatus.Error(codes.NotFound, fmt.Sprintf("Session not found: %s", name))
}
return session, nil
}
diff --git a/spanner/retry.go b/spanner/retry.go
index cc2d520..75231af 100644
--- a/spanner/retry.go
+++ b/spanner/retry.go
@@ -69,11 +69,13 @@
return delay, true
}
-// runWithRetryOnAborted executes the given function and retries it if it
-// returns an Aborted error. The delay between retries is the delay returned
-// by Cloud Spanner, and if none is returned, the calculated delay with a
-// minimum of 10ms and maximum of 32s.
-func runWithRetryOnAborted(ctx context.Context, f func(context.Context) error) error {
+// runWithRetryOnAbortedOrSessionNotFound executes the given function and
+// retries it if it returns an Aborted or Session not found error. The retry
+// is delayed if the error was Aborted. The delay between retries is the delay
+// returned by Cloud Spanner, or if none is returned, the calculated delay with
+// a minimum of 10ms and maximum of 32s. There is no delay before the retry if
+// the error was Session not found.
+func runWithRetryOnAbortedOrSessionNotFound(ctx context.Context, f func(context.Context) error) error {
retryer := onCodes(DefaultRetryBackoff, codes.Aborted)
funcWithRetry := func(ctx context.Context) error {
for {
@@ -99,6 +101,10 @@
}
retryErr = err
}
+ if isSessionNotFoundError(retryErr) {
+ trace.TracePrintf(ctx, nil, "Retrying after Session not found")
+ continue
+ }
delay, shouldRetry := retryer.Retry(retryErr)
if !shouldRetry {
return err
diff --git a/spanner/session.go b/spanner/session.go
index 6545c8a..9a3dfae 100644
--- a/spanner/session.go
+++ b/spanner/session.go
@@ -130,22 +130,23 @@
func (sh *sessionHandle) destroy() {
sh.mu.Lock()
s := sh.session
- p := s.pool
tracked := sh.trackedSessionHandle
sh.session = nil
sh.trackedSessionHandle = nil
sh.checkoutTime = time.Time{}
sh.stack = nil
sh.mu.Unlock()
- if tracked != nil {
- p.mu.Lock()
- p.trackedSessionHandles.Remove(tracked)
- p.mu.Unlock()
- }
+
if s == nil {
// sessionHandle has already been destroyed..
return
}
+ if tracked != nil { // untrack the handle only after the nil check on s above
+ p := s.pool // s is non-nil here; the removed code read s.pool before checking s for nil
+ p.mu.Lock()
+ p.trackedSessionHandles.Remove(tracked)
+ p.mu.Unlock()
+ }
s.destroy(false)
}
@@ -764,7 +765,7 @@
func (p *sessionPool) isHealthy(s *session) bool {
if s.getNextCheck().Add(2 * p.hc.getInterval()).Before(time.Now()) {
// TODO: figure out if we need to schedule a new healthcheck worker here.
- if err := s.ping(); shouldDropSession(err) {
+ if err := s.ping(); isSessionNotFoundError(err) {
// The session is already bad, continue to fetch/create a new one.
s.destroy(false)
return false
@@ -923,6 +924,13 @@
}
if !s.isWritePrepared() {
if err = s.prepareForWrite(ctx); err != nil {
+ if isSessionNotFoundError(err) {
+ s.destroy(false)
+ trace.TracePrintf(ctx, map[string]interface{}{"sessionID": s.getID()},
+ "Session not found for write")
+ return nil, toSpannerError(err)
+ }
+
s.recycle()
trace.TracePrintf(ctx, map[string]interface{}{"sessionID": s.getID()},
"Error preparing session for write")
@@ -1230,7 +1238,7 @@
s.destroy(false)
return
}
- if err := s.ping(); shouldDropSession(err) {
+ if err := s.ping(); isSessionNotFoundError(err) {
// Ping failed, destroy the session.
s.destroy(false)
}
@@ -1497,23 +1505,6 @@
}
}
-// shouldDropSession returns true if a particular error leads to the removal of
-// a session
-func shouldDropSession(err error) bool {
- if err == nil {
- return false
- }
- // If a Cloud Spanner can no longer locate the session (for example, if
- // session is garbage collected), then caller should not try to return the
- // session back into the session pool.
- //
- // TODO: once gRPC can return auxiliary error information, stop parsing the error message.
- if ErrCode(err) == codes.NotFound && strings.Contains(ErrDesc(err), "Session not found") {
- return true
- }
- return false
-}
-
// maxUint64 returns the maximum of two uint64.
func maxUint64(a, b uint64) uint64 {
if a > b {
@@ -1533,9 +1524,13 @@
// isSessionNotFoundError returns true if the given error is a
// `Session not found` error.
func isSessionNotFoundError(err error) bool {
+ if err == nil { // a nil error is never `Session not found`; return before inspecting it
+ return false
+ }
// We are checking specifically for the error message `Session not found`,
// as the error could also be a `Database not found`. The latter should
// cause the session pool to stop preparing sessions for read/write
// transactions, while the former should not.
+ // TODO: once gRPC can return auxiliary error information, stop parsing the error message.
return ErrCode(err) == codes.NotFound && strings.Contains(err.Error(), "Session not found")
}
diff --git a/spanner/transaction.go b/spanner/transaction.go
index 39505f2..4b8825e 100644
--- a/spanner/transaction.go
+++ b/spanner/transaction.go
@@ -357,7 +357,7 @@
if err != nil && sh != nil {
// Got a valid session handle, but failed to initialize transaction=
// on Cloud Spanner.
- if shouldDropSession(err) {
+ if isSessionNotFoundError(err) {
sh.destroy()
}
// If sh.destroy was already executed, this becomes a noop.
@@ -527,7 +527,7 @@
sh := t.sh
t.mu.Unlock()
if sh != nil { // sh could be nil if t.acquire() fails.
- if shouldDropSession(err) {
+ if isSessionNotFoundError(err) {
sh.destroy()
}
if t.singleUse {
@@ -795,7 +795,7 @@
t.mu.Lock()
sh := t.sh
t.mu.Unlock()
- if sh != nil && shouldDropSession(err) {
+ if sh != nil && isSessionNotFoundError(err) {
sh.destroy()
}
}
@@ -831,7 +831,7 @@
t.state = txActive
return nil
}
- if shouldDropSession(err) {
+ if isSessionNotFoundError(err) {
t.sh.destroy()
}
return err
@@ -869,7 +869,7 @@
if tstamp := res.GetCommitTimestamp(); tstamp != nil {
ts = time.Unix(tstamp.Seconds, int64(tstamp.Nanos))
}
- if shouldDropSession(err) {
+ if isSessionNotFoundError(err) {
t.sh.destroy()
}
return ts, err
@@ -892,7 +892,7 @@
Session: sid,
TransactionId: t.tx,
})
- if shouldDropSession(err) {
+ if isSessionNotFoundError(err) {
t.sh.destroy()
}
}
@@ -914,6 +914,10 @@
// one's wound-wait priority.
return ts, err
}
+ if isSessionNotFoundError(err) {
+ t.sh.destroy()
+ return ts, err
+ }
// Not going to commit, according to API spec, should rollback the
// transaction.
t.rollback(ctx)
@@ -973,7 +977,7 @@
Mutations: mPb,
}, gax.WithGRPCOptions(grpc.Trailer(&trailers)))
if err != nil && !isAbortErr(err) {
- if shouldDropSession(err) {
+ if isSessionNotFoundError(err) {
// Discard the bad session.
sh.destroy()
}