spanner: Stop preparing sessions on error

The background process for preparing sessions for write
transactions should stop on any server error except
'Session not found'. The process should automatically
be re-enabled if a call to BeginTransaction succeeds.

Fixes #1687.

Change-Id: I1bcab4531f869da59ee0d9ecec3ecc2c419f9f72
Reviewed-on: https://code-review.googlesource.com/c/gocloud/+/49030
Reviewed-by: kokoro <noreply+kokoro@google.com>
Reviewed-by: Hengfeng Li <hengfeng@google.com>
Reviewed-by: Jean de Klerk <deklerk@google.com>
diff --git a/spanner/session.go b/spanner/session.go
index 42a8fea..1dc3268 100644
--- a/spanner/session.go
+++ b/spanner/session.go
@@ -308,8 +308,19 @@
 		return nil
 	}
 	tx, err := beginTransaction(contextWithOutgoingMetadata(ctx, s.md), s.getID(), s.client)
+	// Session not found should cause the session to be removed from the pool.
+	if isSessionNotFoundError(err) {
+		s.pool.remove(s, false)
+		s.pool.hc.unregister(s)
+		return err
+	}
+	// Enable/disable background preparing of write sessions depending on
+	// whether the BeginTransaction call succeeded. This will prevent the
+	// session pool workers from going into an infinite loop of trying to
+	// prepare sessions. Any subsequent successful BeginTransaction call, for
+	// example from takeWriteSession, will re-enable the background process.
 	s.pool.mu.Lock()
-	s.pool.disableBackgroundPrepareSessions = isPermissionDeniedError(err) || isDatabaseNotFoundError(err)
+	s.pool.disableBackgroundPrepareSessions = err != nil
 	s.pool.mu.Unlock()
 	if err != nil {
 		return err
@@ -1419,18 +1430,12 @@
 	return a
 }
 
-// isPermissionDeniedError returns true if the given error has code
-// PermissionDenied.
-func isPermissionDeniedError(err error) bool {
-	return ErrCode(err) == codes.PermissionDenied
-}
-
-// isDatabaseNotFoundError returns true if the given error is a
-// `Database not found` error.
-func isDatabaseNotFoundError(err error) bool {
-	// We are checking specifically for the error message `Database not found`,
-	// as the error could also be a `Session not found`. The former should
+// isSessionNotFoundError returns true if the given error is a
+// `Session not found` error.
+func isSessionNotFoundError(err error) bool {
+	// We are checking specifically for the error message `Session not found`,
+	// as the error could also be a `Database not found`. The latter should
 	// cause the session pool to stop preparing sessions for read/write
-	// transactions, while the latter should not.
-	return ErrCode(err) == codes.NotFound && strings.Contains(err.Error(), "Database not found")
+	// transactions, while the former should not.
+	return ErrCode(err) == codes.NotFound && strings.Contains(err.Error(), "Session not found")
 }
diff --git a/spanner/session_test.go b/spanner/session_test.go
index 33941eb..83d0f10 100644
--- a/spanner/session_test.go
+++ b/spanner/session_test.go
@@ -923,16 +923,23 @@
 }
 
 // The session pool should stop trying to create write-prepared sessions if a
-// permanent error occurs while trying to begin a transaction. Possible
-// permanent errors are PermissionDenied or `Database not found`.
-func TestPermanentErrorOnPrepareSession(t *testing.T) {
+// non-transient error occurs while trying to begin a transaction. The
+// process for preparing write sessions should automatically be re-enabled if
+// a BeginTransaction call initiated by takeWriteSession succeeds.
+//
+// The only exception to the above is that a 'Session not found' error should
+// cause the session to be removed from the session pool, and it should not
+// affect the background process of preparing sessions.
+func TestErrorOnPrepareSession(t *testing.T) {
 	t.Parallel()
 
-	permanentErrors := []error{
+	serverErrors := []error{
 		status.Errorf(codes.PermissionDenied, "Caller is missing IAM permission spanner.databases.beginOrRollbackReadWriteTransaction on resource"),
 		status.Errorf(codes.NotFound, `Database not found: projects/<project>/instances/<instance>/databases/<database> resource_type: "type.googleapis.com/google.spanner.admin.database.v1.Database" resource_name: "projects/<project>/instances/<instance>/databases/<database>" description: "Database does not exist."`),
+		status.Errorf(codes.FailedPrecondition, "Invalid transaction option"),
+		status.Errorf(codes.Internal, "Unknown server error"),
 	}
-	for _, permanentError := range permanentErrors {
+	for _, serverErr := range serverErrors {
 		ctx := context.Background()
 		server, client, teardown := setupMockedTestServerWithConfig(t,
 			ClientConfig{
@@ -945,7 +952,7 @@
 			})
 		defer teardown()
 		server.TestSpanner.PutExecutionTime(MethodBeginTransaction, SimulatedExecutionTime{
-			Errors:    []error{permanentError},
+			Errors:    []error{serverErr},
 			KeepError: true,
 		})
 		sp := client.idleSessions
@@ -956,10 +963,11 @@
 		waitUntil := time.After(time.Second)
 		var prepareDisabled bool
 		var numOpened int
+	waitForPrepare:
 		for !prepareDisabled || numOpened < 10 {
 			select {
 			case <-waitUntil:
-				break
+				break waitForPrepare
 			default:
 			}
 			sp.mu.Lock()
@@ -986,14 +994,14 @@
 			t.Fatalf("cannot get session from session pool: %v", err)
 		}
 		sh.recycle()
-		// Take a write session should fail with the permanent error.
+		// Taking a write session should fail with the server error.
 		_, err = sp.takeWriteSession(ctx)
-		if ErrCode(err) != ErrCode(permanentError) {
-			t.Fatalf("take write session failed with unexpected error.\nGot: %v\nWant: %v\n", err, permanentError)
+		if ErrCode(err) != ErrCode(serverErr) {
+			t.Fatalf("take write session failed with unexpected error.\nGot: %v\nWant: %v\n", err, serverErr)
 		}
 
-		// Clearing the error on the server (or granting the permission to the
-		// credentials in use) should allow us to take a write session.
+		// Clearing the error on the server should allow us to take a write
+		// session.
 		server.TestSpanner.PutExecutionTime(MethodBeginTransaction, SimulatedExecutionTime{})
 		sh, err = sp.takeWriteSession(ctx)
 		if err != nil {
@@ -1022,6 +1030,83 @@
 	}
 }
 
+// The session pool should continue to try to create write-prepared sessions if
+// a 'Session not found' error occurs. The session that has been deleted by the
+// backend should be removed from the pool, and the maintainer should create a
+// new session if this causes the number of sessions in the pool to fall below
+// MinOpened.
+func TestSessionNotFoundOnPrepareSession(t *testing.T) {
+	t.Parallel()
+
+	// The server will return 'Session not found' for the first 8
+	// BeginTransaction calls.
+	sessionNotFoundErr := status.Errorf(codes.NotFound, `Session not found: projects/<project>/instances/<instance>/databases/<database>/sessions/<session> resource_type: "Session" resource_name: "projects/<project>/instances/<instance>/databases/<database>/sessions/<session>" description: "Session does not exist."`)
+	serverErrors := make([]error, 8)
+	for i := range serverErrors {
+		serverErrors[i] = sessionNotFoundErr
+	}
+	ctx := context.Background()
+	server, client, teardown := setupMockedTestServerWithConfig(t,
+		ClientConfig{
+			SessionPoolConfig: SessionPoolConfig{
+				MinOpened:                 10,
+				MaxOpened:                 10,
+				WriteSessions:             0.5,
+				HealthCheckInterval:       time.Millisecond,
+				healthCheckSampleInterval: time.Millisecond,
+			},
+		})
+	defer teardown()
+	server.TestSpanner.PutExecutionTime(MethodBeginTransaction, SimulatedExecutionTime{
+		Errors: serverErrors,
+	})
+	sp := client.idleSessions
+
+	// Wait until the health checker has tried to write-prepare the sessions.
+	// This will cause the session pool to log errors stating that preparing
+	// sessions failed.
+	waitUntil := time.After(time.Second)
+	var numWriteSessions int
+	var numReadSessions int
+waitForPrepare:
+	for (numWriteSessions+numReadSessions) < 10 || numWriteSessions < 5 {
+		select {
+		case <-waitUntil:
+			break waitForPrepare
+		default:
+		}
+		sp.mu.Lock()
+		numReadSessions = sp.idleList.Len()
+		numWriteSessions = sp.idleWriteList.Len()
+		sp.mu.Unlock()
+	}
+
+	// There should be at least 5 write-prepared sessions.
+	sp.mu.Lock()
+	if g, w := sp.idleWriteList.Len(), 5; g < w {
+		sp.mu.Unlock()
+		t.Fatalf("write-prepared session count mismatch.\nWant at least: %v\nGot: %v", w, g)
+	}
+	// The other sessions should be in the read idle list.
+	if g, w := sp.idleList.Len()+sp.idleWriteList.Len(), 10; g != w {
+		sp.mu.Unlock()
+		t.Fatalf("total session count mismatch:\nWant: %v\nGot: %v", w, g)
+	}
+	sp.mu.Unlock()
+	// Taking a read session should succeed.
+	sh, err := sp.take(ctx)
+	if err != nil {
+		t.Fatalf("cannot get session from session pool: %v", err)
+	}
+	sh.recycle()
+	// Taking a write session should succeed.
+	sh, err = sp.takeWriteSession(ctx)
+	if err != nil {
+		t.Fatalf("take write session failed with unexpected error.\nGot: %v\nWant: %v\n", err, nil)
+	}
+	sh.recycle()
+}
+
 // TestSessionHealthCheck tests healthchecking cases.
 func TestSessionHealthCheck(t *testing.T) {
 	t.Parallel()
diff --git a/spanner/transaction.go b/spanner/transaction.go
index 5eace9e..604c877 100644
--- a/spanner/transaction.go
+++ b/spanner/transaction.go
@@ -813,6 +813,9 @@
 	if err != nil {
 		return nil, err
 	}
+	if res.Id == nil {
+		return nil, spannerErrorf(codes.Unknown, "BeginTransaction returned a transaction with a nil ID.")
+	}
 	return res.Id, nil
 }