spanner: do not rollback after failed commit

A failed commit should not be followed by a Rollback RPC as
that RPC would also fail. While it does not break anything,
it does add noise to logs.

Fixes #1772.

Change-Id: I2cb2e7fdbab2e2d44a4aa9a929443741f3a0d8a3
Reviewed-on: https://code-review.googlesource.com/c/gocloud/+/52770
Reviewed-by: kokoro <noreply+kokoro@google.com>
Reviewed-by: David Symonds <dsymonds@golang.org>
Reviewed-by: Hengfeng Li <hengfeng@google.com>
diff --git a/spanner/client_test.go b/spanner/client_test.go
index d83122a..047d454 100644
--- a/spanner/client_test.go
+++ b/spanner/client_test.go
@@ -1481,3 +1481,65 @@
 		t.Fatalf("Missing wrapped TransactionOutcomeUnknownError error")
 	}
 }
+
+func TestFailedCommit_NoRollback(t *testing.T) {
+	t.Parallel()
+	server, client, teardown := setupMockedTestServerWithConfig(t, ClientConfig{
+		SessionPoolConfig: SessionPoolConfig{
+			MinOpened:     0,
+			MaxOpened:     1,
+			WriteSessions: 0,
+		},
+	})
+	defer teardown()
+	server.TestSpanner.PutExecutionTime(MethodCommitTransaction,
+		SimulatedExecutionTime{
+			Errors: []error{status.Errorf(codes.InvalidArgument, "Invalid mutations")},
+		})
+	_, err := client.Apply(context.Background(), []*Mutation{
+		Insert("FOO", []string{"ID", "BAR"}, []interface{}{1, "value"}),
+	})
+	if got, want := status.Convert(err).Code(), codes.InvalidArgument; got != want {
+		t.Fatalf("Error mismatch\nGot: %v\nWant: %v", got, want)
+	}
+	// The failed commit should not trigger a rollback after the commit.
+	if _, err := shouldHaveReceived(server.TestSpanner, []interface{}{
+		&sppb.CreateSessionRequest{},
+		&sppb.BeginTransactionRequest{},
+		&sppb.CommitRequest{},
+	}); err != nil {
+		t.Fatalf("Received RPCs mismatch: %v", err)
+	}
+}
+
+func TestFailedUpdate_ShouldRollback(t *testing.T) {
+	t.Parallel()
+	server, client, teardown := setupMockedTestServerWithConfig(t, ClientConfig{
+		SessionPoolConfig: SessionPoolConfig{
+			MinOpened:     0,
+			MaxOpened:     1,
+			WriteSessions: 0,
+		},
+	})
+	defer teardown()
+	server.TestSpanner.PutExecutionTime(MethodExecuteSql,
+		SimulatedExecutionTime{
+			Errors: []error{status.Errorf(codes.InvalidArgument, "Invalid update")},
+		})
+	_, err := client.ReadWriteTransaction(context.Background(), func(ctx context.Context, tx *ReadWriteTransaction) error {
+		_, err := tx.Update(ctx, NewStatement("UPDATE FOO SET BAR='value' WHERE ID=1"))
+		return err
+	})
+	if got, want := status.Convert(err).Code(), codes.InvalidArgument; got != want {
+		t.Fatalf("Error mismatch\nGot: %v\nWant: %v", got, want)
+	}
+	// The failed update should trigger a rollback.
+	if _, err := shouldHaveReceived(server.TestSpanner, []interface{}{
+		&sppb.CreateSessionRequest{},
+		&sppb.BeginTransactionRequest{},
+		&sppb.ExecuteSqlRequest{},
+		&sppb.RollbackRequest{},
+	}); err != nil {
+		t.Fatalf("Received RPCs mismatch: %v", err)
+	}
+}
diff --git a/spanner/transaction.go b/spanner/transaction.go
index dc3b436..7b761f7 100644
--- a/spanner/transaction.go
+++ b/spanner/transaction.go
@@ -954,12 +954,14 @@
 // runInTransaction executes f under a read-write transaction context.
 func (t *ReadWriteTransaction) runInTransaction(ctx context.Context, f func(context.Context, *ReadWriteTransaction) error) (time.Time, error) {
 	var (
-		ts  time.Time
-		err error
+		ts              time.Time
+		err             error
+		errDuringCommit bool
 	)
 	if err = f(context.WithValue(ctx, transactionInProgressKey{}, 1), t); err == nil {
 		// Try to commit if transaction body returns no error.
 		ts, err = t.commit(ctx)
+		errDuringCommit = err != nil
 	}
 	if err != nil {
 		if isAbortErr(err) {
@@ -972,9 +974,15 @@
 			t.sh.destroy()
 			return ts, err
 		}
-		// Not going to commit, according to API spec, should rollback the
-		// transaction.
-		t.rollback(ctx)
+		// Rollback the transaction unless the error occurred during the
+		// commit. Executing a rollback after a commit has failed will
+		// otherwise cause an error. Note that transient errors, such as
+		// UNAVAILABLE, are already handled in the gRPC layer and do not show
+		// up here. Context errors (deadline exceeded / canceled) during
+		// commits are also not rolled back.
+		if !errDuringCommit {
+			t.rollback(ctx)
+		}
 		return ts, err
 	}
 	// err == nil, return commit timestamp.