spanner: make ReadWriteTransaction retry on Session not found error

Updates #1527

Ref: https://github.com/googleapis/google-cloud-go/issues/1527
Change-Id: Iea12342ca098c8056abc2206b91edbeda630e718
Reviewed-on: https://code-review.googlesource.com/c/gocloud/+/45910
Reviewed-by: kokoro <noreply+kokoro@google.com>
Reviewed-by: Hengfeng Li <hengfeng@google.com>
diff --git a/spanner/client.go b/spanner/client.go
index 4f0ee8e..f1d18f8 100644
--- a/spanner/client.go
+++ b/spanner/client.go
@@ -435,7 +435,7 @@
 		ts time.Time
 		sh *sessionHandle
 	)
-	err = runWithRetryOnAborted(ctx, func(ctx context.Context) error {
+	err = runWithRetryOnAbortedOrSessionNotFound(ctx, func(ctx context.Context) error {
 		var (
 			err error
 			t   *ReadWriteTransaction
diff --git a/spanner/client_test.go b/spanner/client_test.go
index 8b83f61..36cf0b9 100644
--- a/spanner/client_test.go
+++ b/spanner/client_test.go
@@ -68,7 +68,7 @@
 	server, opts, serverTeardown := NewMockedSpannerInMemTestServer(t)
 	opts = append(opts, clientOptions...)
 	ctx := context.Background()
-	var formattedDatabase = fmt.Sprintf("projects/%s/instances/%s/databases/%s", "[PROJECT]", "[INSTANCE]", "[DATABASE]")
+	formattedDatabase := fmt.Sprintf("projects/%s/instances/%s/databases/%s", "[PROJECT]", "[INSTANCE]", "[DATABASE]")
 	client, err := NewClientWithConfig(ctx, formattedDatabase, config, opts...)
 	if err != nil {
 		t.Fatal(err)
@@ -609,6 +609,165 @@
 	}
 }
 
+func TestClient_ReadWriteTransaction_SessionNotFoundOnCommit(t *testing.T) {
+	t.Parallel()
+	// The mocked Commit is given one 'Session not found' error: expect 2 attempts.
+	errs := []error{status.Error(codes.NotFound, "Session not found")}
+	if err := testReadWriteTransaction(t, map[string]SimulatedExecutionTime{MethodCommitTransaction: {Errors: errs}}, 2); err != nil {
+		t.Fatal(err)
+	}
+}
+
+func TestClient_ReadWriteTransaction_SessionNotFoundOnBeginTransaction(t *testing.T) {
+	t.Parallel()
+	// Only 1 attempt is expected: the session pool prepares the session and
+	// already handles the 'Session not found' error from BeginTransaction.
+	errs := []error{status.Error(codes.NotFound, "Session not found")}
+	execTimes := map[string]SimulatedExecutionTime{MethodBeginTransaction: {Errors: errs}}
+	if err := testReadWriteTransaction(t, execTimes, 1); err != nil {
+		t.Fatal(err)
+	}
+}
+
+func TestClient_ReadWriteTransaction_SessionNotFoundOnBeginTransactionWithEmptySessionPool(t *testing.T) {
+	t.Parallel()
+	// With WriteSessions set to 0 the pool holds no prepared sessions, so the
+	// error surfaces when the transaction asks the pool for a session. The
+	// session pool handles that case as well, so the transaction itself is
+	// not retried and expectedAttempts is 1.
+	config := ClientConfig{SessionPoolConfig: SessionPoolConfig{WriteSessions: 0.0}}
+	execTimes := map[string]SimulatedExecutionTime{
+		MethodBeginTransaction: {Errors: []error{status.Error(codes.NotFound, "Session not found")}},
+	}
+	if err := testReadWriteTransactionWithConfig(t, config, execTimes, 1); err != nil {
+		t.Fatal(err)
+	}
+}
+
+func TestClient_ReadWriteTransaction_SessionNotFoundOnExecuteStreamingSql(t *testing.T) {
+	t.Parallel()
+	// The mocked ExecuteStreamingSql is given one 'Session not found' error: expect 2 attempts.
+	errs := []error{status.Error(codes.NotFound, "Session not found")}
+	if err := testReadWriteTransaction(t, map[string]SimulatedExecutionTime{MethodExecuteStreamingSql: {Errors: errs}}, 2); err != nil {
+		t.Fatal(err)
+	}
+}
+
+func TestClient_ReadWriteTransaction_SessionNotFoundOnExecuteUpdate(t *testing.T) {
+	t.Parallel()
+	// The mocked ExecuteSql is given one 'Session not found' error; the update must succeed on attempt 2.
+	server, client, teardown := setupMockedTestServer(t)
+	defer teardown()
+	server.TestSpanner.PutExecutionTime(
+		MethodExecuteSql,
+		SimulatedExecutionTime{Errors: []error{status.Error(codes.NotFound, "Session not found")}},
+	)
+	ctx := context.Background()
+	var attempts int
+	_, err := client.ReadWriteTransaction(ctx, func(ctx context.Context, tx *ReadWriteTransaction) error {
+		attempts++
+		rowCount, err := tx.Update(ctx, NewStatement(UpdateBarSetFoo))
+		if err != nil {
+			return err
+		}
+		if g, w := rowCount, int64(UpdateBarSetFooRowCount); g != w {
+			return status.Errorf(codes.FailedPrecondition, "Row count mismatch\nGot: %v\nWant: %v", g, w)
+		}
+		return nil
+	})
+	if err != nil {
+		t.Fatal(err)
+	}
+	if g, w := attempts, 2; g != w {
+		t.Fatalf("number of attempts mismatch:\nGot: %d\nWant: %d", g, w)
+	}
+}
+
+func TestClient_ReadWriteTransaction_SessionNotFoundOnExecuteBatchUpdate(t *testing.T) {
+	t.Parallel()
+	// The mocked ExecuteBatchDml is given one 'Session not found' error; the batch must succeed on attempt 2.
+	server, client, teardown := setupMockedTestServer(t)
+	defer teardown()
+	server.TestSpanner.PutExecutionTime(
+		MethodExecuteBatchDml,
+		SimulatedExecutionTime{Errors: []error{status.Error(codes.NotFound, "Session not found")}},
+	)
+	ctx := context.Background()
+	var attempts int
+	_, err := client.ReadWriteTransaction(ctx, func(ctx context.Context, tx *ReadWriteTransaction) error {
+		attempts++
+		rowCounts, err := tx.BatchUpdate(ctx, []Statement{NewStatement(UpdateBarSetFoo)})
+		if err != nil {
+			return err
+		}
+		if g, w := len(rowCounts), 1; g != w {
+			return status.Errorf(codes.FailedPrecondition, "Row counts length mismatch\nGot: %v\nWant: %v", g, w)
+		}
+		if g, w := rowCounts[0], int64(UpdateBarSetFooRowCount); g != w {
+			return status.Errorf(codes.FailedPrecondition, "Row count mismatch\nGot: %v\nWant: %v", g, w)
+		}
+		return nil
+	})
+	if err != nil {
+		t.Fatal(err)
+	}
+	if g, w := attempts, 2; g != w {
+		t.Fatalf("number of attempts mismatch:\nGot: %d\nWant: %d", g, w)
+	}
+}
+
+func TestClient_SessionNotFound(t *testing.T) {
+	// Ensure we always have at least one session in the pool.
+	sc := SessionPoolConfig{
+		MinOpened: 1,
+	}
+	server, client, teardown := setupMockedTestServerWithConfig(t, ClientConfig{SessionPoolConfig: sc})
+	defer teardown()
+	ctx := context.Background()
+	for {
+		client.idleSessions.mu.Lock()
+		numSessions := client.idleSessions.idleList.Len()
+		client.idleSessions.mu.Unlock()
+		if numSessions > 0 {
+			break
+		}
+		time.Sleep(time.Millisecond)
+	}
+	// Remove the session from the server without the pool knowing it.
+	_, err := server.TestSpanner.DeleteSession(ctx, &sppb.DeleteSessionRequest{Name: client.idleSessions.idleList.Front().Value.(*session).id})
+	if err != nil {
+		t.Fatalf("Failed to delete session unexpectedly: %v", err)
+	}
+	// The transaction below must still succeed although a pooled session was deleted on the server.
+	_, err = client.ReadWriteTransaction(ctx, func(ctx context.Context, tx *ReadWriteTransaction) error {
+		iter := tx.Query(ctx, NewStatement(SelectSingerIDAlbumIDAlbumTitleFromAlbums))
+		defer iter.Stop()
+		rowCount := int64(0)
+		for {
+			row, err := iter.Next()
+			if err == iterator.Done {
+				break
+			}
+			if err != nil {
+				return err
+			}
+			var singerID, albumID int64
+			var albumTitle string
+			if err := row.Columns(&singerID, &albumID, &albumTitle); err != nil {
+				return err
+			}
+			rowCount++
+		}
+		if rowCount != SelectSingerIDAlbumIDAlbumTitleFromAlbumsRowCount {
+			return spannerErrorf(codes.FailedPrecondition, "Row count mismatch, got %v, expected %v", rowCount, SelectSingerIDAlbumIDAlbumTitleFromAlbumsRowCount)
+		}
+		return nil
+	})
+	if err != nil {
+		t.Fatalf("Unexpected error during transaction: %v", err)
+	}
+}
+
 func TestClient_ReadWriteTransactionExecuteStreamingSqlAborted(t *testing.T) {
 	t.Parallel()
 	if err := testReadWriteTransaction(t, map[string]SimulatedExecutionTime{
@@ -801,6 +960,10 @@
 }
 
 func testReadWriteTransaction(t *testing.T, executionTimes map[string]SimulatedExecutionTime, expectedAttempts int) error {
+	return testReadWriteTransactionWithConfig(t, ClientConfig{SessionPoolConfig: DefaultSessionPoolConfig}, executionTimes, expectedAttempts)
+}
+
+func testReadWriteTransactionWithConfig(t *testing.T, config ClientConfig, executionTimes map[string]SimulatedExecutionTime, expectedAttempts int) error {
 	server, client, teardown := setupMockedTestServer(t)
 	defer teardown()
 	for method, exec := range executionTimes {
@@ -966,3 +1129,50 @@
 		t.Fatalf("Unexpected error\nGot: %v\nWant: %v", err, msg)
 	}
 }
+
+func TestReadWriteTransaction_WrapSessionNotFoundError(t *testing.T) {
+	t.Parallel()
+	server, client, teardown := setupMockedTestServer(t)
+	defer teardown()
+	// Register one 'Session not found' error for each of the three mocked
+	// methods exercised by this test: BeginTransaction, ExecuteStreamingSql
+	// and CommitTransaction.
+	for _, method := range []string{
+		MethodBeginTransaction,
+		MethodExecuteStreamingSql,
+		MethodCommitTransaction,
+	} {
+		server.TestSpanner.PutExecutionTime(method, SimulatedExecutionTime{
+			Errors: []error{status.Error(codes.NotFound, "Session not found")},
+		})
+	}
+	wrapMsg := "query failed"
+	attempts := 0
+	ctx := context.Background()
+	_, err := client.ReadWriteTransaction(ctx, func(ctx context.Context, tx *ReadWriteTransaction) error {
+		attempts++
+		iter := tx.Query(ctx, NewStatement(SelectSingerIDAlbumIDAlbumTitleFromAlbums))
+		defer iter.Stop()
+		// Drain the iterator. The loop ends either when the iterator is
+		// exhausted or when it reports an error.
+		for {
+			switch _, err := iter.Next(); {
+			case err == iterator.Done:
+				return nil
+			case err != nil:
+				// Wrap the error in another error that implements the
+				// (xerrors|errors).Wrapper interface.
+				return &wrappedTestError{err, wrapMsg}
+			}
+		}
+	})
+	if err != nil {
+		t.Fatalf("Unexpected error\nGot: %v\nWant: nil", err)
+	}
+	// Expect 3 attempts: the 'Session not found' error on BeginTransaction
+	// does not retry the entire transaction, which leaves two failed attempts
+	// followed by one successful attempt.
+	if attempts != 3 {
+		t.Fatalf("Number of transaction attempts mismatch\nGot: %d\nWant: %d", attempts, 3)
+	}
+}
diff --git a/spanner/internal/testutil/inmem_spanner_server.go b/spanner/internal/testutil/inmem_spanner_server.go
index f107f18..eb1c870 100644
--- a/spanner/internal/testutil/inmem_spanner_server.go
+++ b/spanner/internal/testutil/inmem_spanner_server.go
@@ -431,7 +431,7 @@
 	defer s.mu.Unlock()
 	session := s.sessions[name]
 	if session == nil {
-		return nil, gstatus.Error(codes.NotFound, fmt.Sprintf("Session %s not found", name))
+		return nil, gstatus.Error(codes.NotFound, fmt.Sprintf("Session not found: %s", name))
 	}
 	return session, nil
 }
diff --git a/spanner/retry.go b/spanner/retry.go
index cc2d520..75231af 100644
--- a/spanner/retry.go
+++ b/spanner/retry.go
@@ -69,11 +69,13 @@
 	return delay, true
 }
 
-// runWithRetryOnAborted executes the given function and retries it if it
-// returns an Aborted error. The delay between retries is the delay returned
-// by Cloud Spanner, and if none is returned, the calculated delay with a
-// minimum of 10ms and maximum of 32s.
-func runWithRetryOnAborted(ctx context.Context, f func(context.Context) error) error {
+// runWithRetryOnAbortedOrSessionNotFound executes the given function and
+// retries it if it returns an Aborted or Session not found error. The retry
+// is delayed if the error was Aborted. The delay between retries is the delay
+// returned by Cloud Spanner, or if none is returned, the calculated delay with
+// a minimum of 10ms and maximum of 32s. There is no delay before the retry if
+// the error was Session not found.
+func runWithRetryOnAbortedOrSessionNotFound(ctx context.Context, f func(context.Context) error) error {
 	retryer := onCodes(DefaultRetryBackoff, codes.Aborted)
 	funcWithRetry := func(ctx context.Context) error {
 		for {
@@ -99,6 +101,10 @@
 				}
 				retryErr = err
 			}
+			if isSessionNotFoundError(retryErr) {
+				trace.TracePrintf(ctx, nil, "Retrying after Session not found")
+				continue
+			}
 			delay, shouldRetry := retryer.Retry(retryErr)
 			if !shouldRetry {
 				return err
diff --git a/spanner/session.go b/spanner/session.go
index 6545c8a..9a3dfae 100644
--- a/spanner/session.go
+++ b/spanner/session.go
@@ -130,22 +130,23 @@
 func (sh *sessionHandle) destroy() {
 	sh.mu.Lock()
 	s := sh.session
-	p := s.pool
 	tracked := sh.trackedSessionHandle
 	sh.session = nil
 	sh.trackedSessionHandle = nil
 	sh.checkoutTime = time.Time{}
 	sh.stack = nil
 	sh.mu.Unlock()
-	if tracked != nil {
-		p.mu.Lock()
-		p.trackedSessionHandles.Remove(tracked)
-		p.mu.Unlock()
-	}
+
 	if s == nil {
 		// sessionHandle has already been destroyed..
 		return
 	}
+	if tracked != nil {
+		p := s.pool
+		p.mu.Lock()
+		p.trackedSessionHandles.Remove(tracked)
+		p.mu.Unlock()
+	}
 	s.destroy(false)
 }
 
@@ -764,7 +765,7 @@
 func (p *sessionPool) isHealthy(s *session) bool {
 	if s.getNextCheck().Add(2 * p.hc.getInterval()).Before(time.Now()) {
 		// TODO: figure out if we need to schedule a new healthcheck worker here.
-		if err := s.ping(); shouldDropSession(err) {
+		if err := s.ping(); isSessionNotFoundError(err) {
 			// The session is already bad, continue to fetch/create a new one.
 			s.destroy(false)
 			return false
@@ -923,6 +924,13 @@
 		}
 		if !s.isWritePrepared() {
 			if err = s.prepareForWrite(ctx); err != nil {
+				if isSessionNotFoundError(err) {
+					s.destroy(false)
+					trace.TracePrintf(ctx, map[string]interface{}{"sessionID": s.getID()},
+						"Session not found for write")
+					return nil, toSpannerError(err)
+				}
+
 				s.recycle()
 				trace.TracePrintf(ctx, map[string]interface{}{"sessionID": s.getID()},
 					"Error preparing session for write")
@@ -1230,7 +1238,7 @@
 		s.destroy(false)
 		return
 	}
-	if err := s.ping(); shouldDropSession(err) {
+	if err := s.ping(); isSessionNotFoundError(err) {
 		// Ping failed, destroy the session.
 		s.destroy(false)
 	}
@@ -1497,23 +1505,6 @@
 	}
 }
 
-// shouldDropSession returns true if a particular error leads to the removal of
-// a session
-func shouldDropSession(err error) bool {
-	if err == nil {
-		return false
-	}
-	// If a Cloud Spanner can no longer locate the session (for example, if
-	// session is garbage collected), then caller should not try to return the
-	// session back into the session pool.
-	//
-	// TODO: once gRPC can return auxiliary error information, stop parsing the error message.
-	if ErrCode(err) == codes.NotFound && strings.Contains(ErrDesc(err), "Session not found") {
-		return true
-	}
-	return false
-}
-
 // maxUint64 returns the maximum of two uint64.
 func maxUint64(a, b uint64) uint64 {
 	if a > b {
@@ -1533,9 +1524,13 @@
 // isSessionNotFoundError returns true if the given error is a
 // `Session not found` error.
 func isSessionNotFoundError(err error) bool {
+	if err == nil {
+		return false
+	}
 	// We are checking specifically for the error message `Session not found`,
 	// as the error could also be a `Database not found`. The latter should
 	// cause the session pool to stop preparing sessions for read/write
 	// transactions, while the former should not.
+	// TODO: once gRPC can return auxiliary error information, stop parsing the error message.
 	return ErrCode(err) == codes.NotFound && strings.Contains(err.Error(), "Session not found")
 }
diff --git a/spanner/transaction.go b/spanner/transaction.go
index 39505f2..4b8825e 100644
--- a/spanner/transaction.go
+++ b/spanner/transaction.go
@@ -357,7 +357,7 @@
 		if err != nil && sh != nil {
 			// Got a valid session handle, but failed to initialize transaction=
 			// on Cloud Spanner.
-			if shouldDropSession(err) {
+			if isSessionNotFoundError(err) {
 				sh.destroy()
 			}
 			// If sh.destroy was already executed, this becomes a noop.
@@ -527,7 +527,7 @@
 	sh := t.sh
 	t.mu.Unlock()
 	if sh != nil { // sh could be nil if t.acquire() fails.
-		if shouldDropSession(err) {
+		if isSessionNotFoundError(err) {
 			sh.destroy()
 		}
 		if t.singleUse {
@@ -795,7 +795,7 @@
 	t.mu.Lock()
 	sh := t.sh
 	t.mu.Unlock()
-	if sh != nil && shouldDropSession(err) {
+	if sh != nil && isSessionNotFoundError(err) {
 		sh.destroy()
 	}
 }
@@ -831,7 +831,7 @@
 		t.state = txActive
 		return nil
 	}
-	if shouldDropSession(err) {
+	if isSessionNotFoundError(err) {
 		t.sh.destroy()
 	}
 	return err
@@ -869,7 +869,7 @@
 	if tstamp := res.GetCommitTimestamp(); tstamp != nil {
 		ts = time.Unix(tstamp.Seconds, int64(tstamp.Nanos))
 	}
-	if shouldDropSession(err) {
+	if isSessionNotFoundError(err) {
 		t.sh.destroy()
 	}
 	return ts, err
@@ -892,7 +892,7 @@
 		Session:       sid,
 		TransactionId: t.tx,
 	})
-	if shouldDropSession(err) {
+	if isSessionNotFoundError(err) {
 		t.sh.destroy()
 	}
 }
@@ -914,6 +914,10 @@
 			// one's wound-wait priority.
 			return ts, err
 		}
+		if isSessionNotFoundError(err) {
+			t.sh.destroy()
+			return ts, err
+		}
 		// Not going to commit, according to API spec, should rollback the
 		// transaction.
 		t.rollback(ctx)
@@ -973,7 +977,7 @@
 			Mutations: mPb,
 		}, gax.WithGRPCOptions(grpc.Trailer(&trailers)))
 		if err != nil && !isAbortErr(err) {
-			if shouldDropSession(err) {
+			if isSessionNotFoundError(err) {
 				// Discard the bad session.
 				sh.destroy()
 			}