spanner: return transaction outcome unknown

If a DEADLINE_EXCEEDED or CANCELED error occurs while a COMMIT request
is in flight, the outcome of the transaction is unknown as the request
might have been received and processed by the server. The client
library now returns a Spanner error with a
TransactionOutcomeUnknownError error wrapped when this happens. A user
application can check specifically for this condition by checking for
the presence of such an error in the error chain:

_, err := client.ReadWriteTransaction(ctx, func(ctx context.Context, tx *ReadWriteTransaction) error {
	tx.BufferWrite([]*Mutation{
		Insert("FOO", []string{"ID", "NAME"}, []interface{}{int64(1), "bar"}),
	})
	return nil
})
var outcomeUnknown *TransactionOutcomeUnknownError
if errorAs(err, &outcomeUnknown) {
	// DEADLINE_EXCEEDED or CANCELED occurred during commit.
	// The outcome of the transaction cannot be guaranteed.
	// Do custom error handling for this specific case.
}

Fixes #1781.

Change-Id: Iff5b2eb89b738d23c2a0fd1cc0418f38c736beb4
Reviewed-on: https://code-review.googlesource.com/c/gocloud/+/52370
Reviewed-by: kokoro <noreply+kokoro@google.com>
Reviewed-by: Hengfeng Li <hengfeng@google.com>
diff --git a/spanner/client_test.go b/spanner/client_test.go
index fd294ca..d83122a 100644
--- a/spanner/client_test.go
+++ b/spanner/client_test.go
@@ -1449,3 +1449,35 @@
 		}
 	}
 }
+
+func TestReadWriteTransaction_ContextTimeoutDuringDuringCommit(t *testing.T) {
+	t.Parallel()
+	server, client, teardown := setupMockedTestServer(t)
+	defer teardown()
+	server.TestSpanner.PutExecutionTime(MethodCommitTransaction,
+		SimulatedExecutionTime{
+			MinimumExecutionTime: time.Minute,
+		})
+	ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
+	defer cancel()
+	_, err := client.ReadWriteTransaction(ctx, func(ctx context.Context, tx *ReadWriteTransaction) error {
+		tx.BufferWrite([]*Mutation{Insert("FOO", []string{"ID", "NAME"}, []interface{}{int64(1), "bar"})})
+		return nil
+	})
+	errContext, _ := context.WithTimeout(context.Background(), -time.Second)
+	w := toSpannerErrorWithMetadata(errContext.Err(), nil, true).(*Error)
+	var se *Error
+	if !errorAs(err, &se) {
+		t.Fatalf("Error mismatch\nGot: %v\nWant: %v", err, w)
+	}
+	if se.GRPCStatus().Code() != w.GRPCStatus().Code() {
+		t.Fatalf("Error status mismatch:\nGot: %v\nWant: %v", se.GRPCStatus(), w.GRPCStatus())
+	}
+	if se.Error() != w.Error() {
+		t.Fatalf("Error message mismatch:\nGot %s\nWant: %s", se.Error(), w.Error())
+	}
+	var outcome *TransactionOutcomeUnknownError
+	if !errorAs(err, &outcome) {
+		t.Fatalf("Missing wrapped TransactionOutcomeUnknownError error")
+	}
+}
diff --git a/spanner/errors.go b/spanner/errors.go
index 82f39e8..21c8be0 100644
--- a/spanner/errors.go
+++ b/spanner/errors.go
@@ -46,6 +46,25 @@
 	additionalInformation string
 }
 
+// TransactionOutcomeUnknownError is wrapped in a Spanner error when the error
+// occurred during a transaction, and the outcome of the transaction is
+// unknown as a result of the error. This could be the case if a timeout or
+// canceled error occurs after a Commit request has been sent, but before the
+// client has received a response from the server.
+type TransactionOutcomeUnknownError struct {
+	// err is the wrapped error that caused this TransactionOutcomeUnknownError
+	// error. The wrapped error can be read with the Unwrap method.
+	err error
+}
+
+const transactionOutcomeUnknownMsg = "transaction outcome unknown"
+
+// Error implements error.Error.
+func (*TransactionOutcomeUnknownError) Error() string { return transactionOutcomeUnknownMsg }
+
+// Unwrap returns the wrapped error (if any).
+func (e *TransactionOutcomeUnknownError) Unwrap() error { return e.err }
+
 // Error implements error.Error.
 func (e *Error) Error() string {
 	if e == nil {
@@ -100,14 +119,14 @@
 
 // toSpannerError converts general Go error to *spanner.Error.
 func toSpannerError(err error) error {
-	return toSpannerErrorWithMetadata(err, nil)
+	return toSpannerErrorWithMetadata(err, nil, false)
 }
 
 // toSpannerErrorWithMetadata converts general Go error and grpc trailers to
 // *spanner.Error.
 //
 // Note: modifies original error if trailers aren't nil.
-func toSpannerErrorWithMetadata(err error, trailers metadata.MD) error {
+func toSpannerErrorWithMetadata(err error, trailers metadata.MD, errorDuringCommit bool) error {
 	if err == nil {
 		return nil
 	}
@@ -120,11 +139,24 @@
 	}
 	switch {
 	case err == context.DeadlineExceeded || err == context.Canceled:
-		return &Error{status.FromContextError(err).Code(), status.FromContextError(err).Err(), err.Error(), trailers, ""}
+		desc := err.Error()
+		wrapped := status.FromContextError(err).Err()
+		if errorDuringCommit {
+			desc = fmt.Sprintf("%s, %s", desc, transactionOutcomeUnknownMsg)
+			wrapped = &TransactionOutcomeUnknownError{err: wrapped}
+		}
+		return &Error{status.FromContextError(err).Code(), wrapped, desc, trailers, ""}
 	case status.Code(err) == codes.Unknown:
 		return &Error{codes.Unknown, err, err.Error(), trailers, ""}
 	default:
-		return &Error{status.Convert(err).Code(), err, status.Convert(err).Message(), trailers, ""}
+		statusErr := status.Convert(err)
+		code, desc := statusErr.Code(), statusErr.Message()
+		wrapped := err
+		if errorDuringCommit && (code == codes.DeadlineExceeded || code == codes.Canceled) {
+			desc = fmt.Sprintf("%s, %s", desc, transactionOutcomeUnknownMsg)
+			wrapped = &TransactionOutcomeUnknownError{err: wrapped}
+		}
+		return &Error{code, wrapped, desc, trailers, ""}
 	}
 }
 
diff --git a/spanner/errors_test.go b/spanner/errors_test.go
index 31ebf46..f6a3385 100644
--- a/spanner/errors_test.go
+++ b/spanner/errors_test.go
@@ -19,6 +19,7 @@
 import (
 	"context"
 	"errors"
+	"strings"
 	"testing"
 
 	"google.golang.org/grpc/codes"
@@ -40,25 +41,30 @@
 
 func TestToSpannerError(t *testing.T) {
 	for _, test := range []struct {
-		err      error
-		wantCode codes.Code
-		wantMsg  string
+		err              error
+		wantCode         codes.Code
+		wantMsg          string
+		wantWrappedError error
 	}{
-		{errors.New("wha?"), codes.Unknown, `spanner: code = "Unknown", desc = "wha?"`},
-		{context.Canceled, codes.Canceled, `spanner: code = "Canceled", desc = "context canceled"`},
-		{context.DeadlineExceeded, codes.DeadlineExceeded, `spanner: code = "DeadlineExceeded", desc = "context deadline exceeded"`},
-		{status.Errorf(codes.ResourceExhausted, "so tired"), codes.ResourceExhausted, `spanner: code = "ResourceExhausted", desc = "so tired"`},
-		{spannerErrorf(codes.InvalidArgument, "bad"), codes.InvalidArgument, `spanner: code = "InvalidArgument", desc = "bad"`},
+		{errors.New("wha?"), codes.Unknown, `spanner: code = "Unknown", desc = "wha?"`, errors.New("wha?")},
+		{context.Canceled, codes.Canceled, `spanner: code = "Canceled", desc = "context canceled"`, status.Errorf(codes.Canceled, "context canceled")},
+		{context.DeadlineExceeded, codes.DeadlineExceeded, `spanner: code = "DeadlineExceeded", desc = "context deadline exceeded"`, status.Errorf(codes.DeadlineExceeded, "context deadline exceeded")},
+		{status.Errorf(codes.ResourceExhausted, "so tired"), codes.ResourceExhausted, `spanner: code = "ResourceExhausted", desc = "so tired"`, status.Errorf(codes.ResourceExhausted, "so tired")},
+		{spannerErrorf(codes.InvalidArgument, "bad"), codes.InvalidArgument, `spanner: code = "InvalidArgument", desc = "bad"`, status.Errorf(codes.InvalidArgument, "bad")},
 		{&wrappedTestError{
 			wrapped: spannerErrorf(codes.Aborted, "Transaction aborted"),
 			msg:     "error with wrapped Spanner error",
-		}, codes.Aborted, `spanner: code = "Aborted", desc = "Transaction aborted"`},
+		}, codes.Aborted, `spanner: code = "Aborted", desc = "Transaction aborted"`, status.Errorf(codes.Aborted, "Transaction aborted")},
 		{&wrappedTestError{
 			wrapped: errors.New("wha?"),
 			msg:     "error with wrapped non-gRPC and non-Spanner error",
-		}, codes.Unknown, `spanner: code = "Unknown", desc = "error with wrapped non-gRPC and non-Spanner error"`},
+		}, codes.Unknown, `spanner: code = "Unknown", desc = "error with wrapped non-gRPC and non-Spanner error"`,
+			&wrappedTestError{
+				wrapped: errors.New("wha?"),
+				msg:     "error with wrapped non-gRPC and non-Spanner error"}},
 	} {
 		err := toSpannerError(test.err)
+		errDuringCommit := toSpannerErrorWithMetadata(test.err, nil, true)
 		if got, want := ErrCode(err), test.wantCode; got != want {
 			t.Errorf("%v: got %s, want %s", test.err, got, want)
 		}
@@ -69,5 +75,29 @@
 		if got, want := err.Error(), test.wantMsg; got != want {
 			t.Errorf("%v: got msg %s, want mgs %s", test.err, got, want)
 		}
+		if got, want := err.(*Error).err, test.wantWrappedError; got.Error() != want.Error() {
+			t.Errorf("%v: Wrapped mismatch\nGot: %v\nWant: %v", test.err, got, want)
+		}
+		code := status.Code(errDuringCommit)
+		gotWrappedDuringCommit := errDuringCommit.(*Error).err
+		// Only DEADLINE_EXCEEDED and CANCELED should indicate that the
+		// transaction outcome is unknown.
+		if code == codes.DeadlineExceeded || code == codes.Canceled {
+			if !strings.Contains(errDuringCommit.Error(), transactionOutcomeUnknownMsg) {
+				t.Errorf(`%v: Missing %q from error during commit.\nGot: %v`, test.err, transactionOutcomeUnknownMsg, errDuringCommit)
+			}
+			wantWrappedDuringCommit := &TransactionOutcomeUnknownError{}
+			if gotWrappedDuringCommit.Error() != wantWrappedDuringCommit.Error() {
+				t.Errorf("%v: Wrapped commit error mismatch\nGot: %v\nWant: %v", test.err, gotWrappedDuringCommit, wantWrappedDuringCommit)
+			}
+		} else {
+			if strings.Contains(errDuringCommit.Error(), transactionOutcomeUnknownMsg) {
+				t.Errorf(`%v: Got unexpected %q in error during commit.\nGot: %v`, test.err, transactionOutcomeUnknownMsg, errDuringCommit)
+			}
+			wantWrappedDuringCommit := test.wantWrappedError
+			if gotWrappedDuringCommit.Error() != wantWrappedDuringCommit.Error() {
+				t.Errorf("%v: Wrapped commit error mismatch\nGot: %v\nWant: %v", test.err, gotWrappedDuringCommit, wantWrappedDuringCommit)
+			}
+		}
 	}
 }
diff --git a/spanner/retry_test.go b/spanner/retry_test.go
index 40ca438..88f382a 100644
--- a/spanner/retry_test.go
+++ b/spanner/retry_test.go
@@ -36,7 +36,7 @@
 	trailers := map[string]string{
 		retryInfoKey: string(b),
 	}
-	gotDelay, ok := extractRetryDelay(toSpannerErrorWithMetadata(status.Errorf(codes.Aborted, ""), metadata.New(trailers)))
+	gotDelay, ok := extractRetryDelay(toSpannerErrorWithMetadata(status.Errorf(codes.Aborted, ""), metadata.New(trailers), true))
 	if !ok || !testEqual(time.Second, gotDelay) {
 		t.Errorf("<ok, retryDelay> = <%t, %v>, want <true, %v>", ok, gotDelay, time.Second)
 	}
@@ -52,7 +52,7 @@
 		retryInfoKey: string(b),
 	}
 	retryer := onCodes(gax.Backoff{}, codes.Aborted)
-	err := toSpannerErrorWithMetadata(status.Errorf(codes.Aborted, "transaction was aborted"), metadata.New(trailers))
+	err := toSpannerErrorWithMetadata(status.Errorf(codes.Aborted, "transaction was aborted"), metadata.New(trailers), true)
 	maxSeenDelay, shouldRetry := retryer.Retry(err)
 	if !shouldRetry {
 		t.Fatalf("expected shouldRetry to be true")
diff --git a/spanner/transaction.go b/spanner/transaction.go
index a42916b..dc3b436 100644
--- a/spanner/transaction.go
+++ b/spanner/transaction.go
@@ -918,7 +918,7 @@
 		Mutations: mPb,
 	}, gax.WithGRPCOptions(grpc.Trailer(&trailer)))
 	if e != nil {
-		return ts, toSpannerErrorWithMetadata(e, trailer)
+		return ts, toSpannerErrorWithMetadata(e, trailer, true)
 	}
 	if tstamp := res.GetCommitTimestamp(); tstamp != nil {
 		ts = time.Unix(tstamp.Seconds, int64(tstamp.Nanos))
@@ -1036,7 +1036,7 @@
 				// Discard the bad session.
 				sh.destroy()
 			}
-			return ts, toSpannerError(err)
+			return ts, toSpannerErrorWithMetadata(err, trailers, true)
 		} else if err == nil {
 			if tstamp := res.GetCommitTimestamp(); tstamp != nil {
 				ts = time.Unix(tstamp.Seconds, int64(tstamp.Nanos))
diff --git a/spanner/transaction_test.go b/spanner/transaction_test.go
index 3fb3955..45a8121 100644
--- a/spanner/transaction_test.go
+++ b/spanner/transaction_test.go
@@ -239,7 +239,11 @@
 		Insert("Accounts", []string{"AccountId", "Nickname", "Balance"}, []interface{}{int64(1), "Foo", int64(50)}),
 		Insert("Accounts", []string{"AccountId", "Nickname", "Balance"}, []interface{}{int64(2), "Bar", int64(1)}),
 	}
-	if _, got := client.Apply(ctx, ms, ApplyAtLeastOnce()); !testEqual(wantErr, got) {
+	_, got := client.Apply(ctx, ms, ApplyAtLeastOnce())
+	// Remove any trailers sent by the mock server to prevent the comparison to
+	// fail on that.
+	got.(*Error).trailers = nil
+	if !testEqual(wantErr, got) {
 		t.Fatalf("Expect Apply to fail, got %v, want %v.", got, wantErr)
 	}
 }