spanner: extract retry info from status

The Spanner client used to extract RetryInfo from the trailers of a
gRPC request. This meant that an extra option had to be added to the
gRPC call to ensure that the trailers were parsed. It also meant that
these trailers needed to be kept in a separate field in spanner.Error.

RetryInfo and other specific error details are however also included in
the wrapped statusError. Instead of getting this information from the
trailers of the request, the Spanner client should get it directly from
the wrapped statusError. This makes it less error prone, as we don't
have to specify extra options for the RPCs where we might want trailers
to be parsed. It also prepares the Spanner client for getting other
additional information from the wrapped statusError, such as additional
information on the type of resource that was not found. This will
allow us to stop parsing textual error messages to determine whether an
error was a 'Session not found' error, and instead use the details from
the statusError.

Fixes #1813.

Change-Id: I9fab63c5f2e3c8d632f136fe3822c170318c5d78
Reviewed-on: https://code-review.googlesource.com/c/gocloud/+/52790
Reviewed-by: kokoro <noreply+kokoro@google.com>
Reviewed-by: Hengfeng Li <hengfeng@google.com>
diff --git a/spanner/client_test.go b/spanner/client_test.go
index 047d454..13e239f 100644
--- a/spanner/client_test.go
+++ b/spanner/client_test.go
@@ -1465,7 +1465,7 @@
 		return nil
 	})
 	errContext, _ := context.WithTimeout(context.Background(), -time.Second)
-	w := toSpannerErrorWithMetadata(errContext.Err(), nil, true).(*Error)
+	w := toSpannerErrorWithCommitInfo(errContext.Err(), true).(*Error)
 	var se *Error
 	if !errorAs(err, &se) {
 		t.Fatalf("Error mismatch\nGot: %v\nWant: %v", err, w)
diff --git a/spanner/cmp_test.go b/spanner/cmp_test.go
index 178ac5e..0e94edc 100644
--- a/spanner/cmp_test.go
+++ b/spanner/cmp_test.go
@@ -24,6 +24,6 @@
 // TODO(deklerk): move this to internal/testutil
 func testEqual(a, b interface{}) bool {
 	return testutil.Equal(a, b,
-		cmp.AllowUnexported(TimestampBound{}, Error{}, Mutation{}, Row{},
-			Partition{}, BatchReadOnlyTransactionID{}))
+		cmp.AllowUnexported(TimestampBound{}, Error{}, TransactionOutcomeUnknownError{},
+			Mutation{}, Row{}, Partition{}, BatchReadOnlyTransactionID{}))
 }
diff --git a/spanner/errors.go b/spanner/errors.go
index 21c8be0..2c6c7ef 100644
--- a/spanner/errors.go
+++ b/spanner/errors.go
@@ -21,7 +21,6 @@
 	"fmt"
 
 	"google.golang.org/grpc/codes"
-	"google.golang.org/grpc/metadata"
 	"google.golang.org/grpc/status"
 )
 
@@ -39,8 +38,6 @@
 	err error
 	// Desc explains more details of the error.
 	Desc string
-	// trailers are the trailers returned in the response, if any.
-	trailers metadata.MD
 	// additionalInformation optionally contains any additional information
 	// about the error.
 	additionalInformation string
@@ -119,22 +116,19 @@
 
 // toSpannerError converts general Go error to *spanner.Error.
 func toSpannerError(err error) error {
-	return toSpannerErrorWithMetadata(err, nil, false)
+	return toSpannerErrorWithCommitInfo(err, false)
 }
 
-// toSpannerErrorWithMetadata converts general Go error and grpc trailers to
-// *spanner.Error.
+// toSpannerErrorWithCommitInfo converts general Go error to *spanner.Error
+// with additional information if the error occurred during a Commit request.
 //
-// Note: modifies original error if trailers aren't nil.
-func toSpannerErrorWithMetadata(err error, trailers metadata.MD, errorDuringCommit bool) error {
+// If err is already a *spanner.Error, err is returned unmodified.
+func toSpannerErrorWithCommitInfo(err error, errorDuringCommit bool) error {
 	if err == nil {
 		return nil
 	}
 	var se *Error
 	if errorAs(err, &se) {
-		if trailers != nil {
-			se.trailers = metadata.Join(se.trailers, trailers)
-		}
 		return se
 	}
 	switch {
@@ -145,9 +139,9 @@
 			desc = fmt.Sprintf("%s, %s", desc, transactionOutcomeUnknownMsg)
 			wrapped = &TransactionOutcomeUnknownError{err: wrapped}
 		}
-		return &Error{status.FromContextError(err).Code(), wrapped, desc, trailers, ""}
+		return &Error{status.FromContextError(err).Code(), wrapped, desc, ""}
 	case status.Code(err) == codes.Unknown:
-		return &Error{codes.Unknown, err, err.Error(), trailers, ""}
+		return &Error{codes.Unknown, err, err.Error(), ""}
 	default:
 		statusErr := status.Convert(err)
 		code, desc := statusErr.Code(), statusErr.Message()
@@ -156,7 +150,7 @@
 			desc = fmt.Sprintf("%s, %s", desc, transactionOutcomeUnknownMsg)
 			wrapped = &TransactionOutcomeUnknownError{err: wrapped}
 		}
-		return &Error{code, wrapped, desc, trailers, ""}
+		return &Error{code, wrapped, desc, ""}
 	}
 }
 
@@ -177,12 +171,3 @@
 	}
 	return se.Desc
 }
-
-// errTrailers extracts the grpc trailers if present from a Go error.
-func errTrailers(err error) metadata.MD {
-	var se *Error
-	if !errorAs(err, &se) {
-		return nil
-	}
-	return se.trailers
-}
diff --git a/spanner/errors_test.go b/spanner/errors_test.go
index f6a3385..c1e90f7 100644
--- a/spanner/errors_test.go
+++ b/spanner/errors_test.go
@@ -64,7 +64,7 @@
 				msg:     "error with wrapped non-gRPC and non-Spanner error"}},
 	} {
 		err := toSpannerError(test.err)
-		errDuringCommit := toSpannerErrorWithMetadata(test.err, nil, true)
+		errDuringCommit := toSpannerErrorWithCommitInfo(test.err, true)
 		if got, want := ErrCode(err), test.wantCode; got != want {
 			t.Errorf("%v: got %s, want %s", test.err, got, want)
 		}
diff --git a/spanner/integration_test.go b/spanner/integration_test.go
index d98d2e1..5fe62d7 100644
--- a/spanner/integration_test.go
+++ b/spanner/integration_test.go
@@ -1635,6 +1635,11 @@
 			if expectAbort && !isAbortErr(e) {
 				t.Errorf("ReadRow got %v, want Abort error.", e)
 			}
+			// Verify that we received and are able to extract retry info from
+			// the aborted error.
+			if _, hasRetryInfo := extractRetryDelay(e); !hasRetryInfo {
+				t.Errorf("Got Abort error without RetryInfo\nGot: %v", e)
+			}
 			return b, e
 		}
 		if ce := r.Column(0, &b); ce != nil {
diff --git a/spanner/internal/testutil/inmem_spanner_server.go b/spanner/internal/testutil/inmem_spanner_server.go
index eb1c870..f9c35fb 100644
--- a/spanner/internal/testutil/inmem_spanner_server.go
+++ b/spanner/internal/testutil/inmem_spanner_server.go
@@ -24,9 +24,11 @@
 	"sync"
 	"time"
 
+	"github.com/golang/protobuf/ptypes"
 	emptypb "github.com/golang/protobuf/ptypes/empty"
 	structpb "github.com/golang/protobuf/ptypes/struct"
 	"github.com/golang/protobuf/ptypes/timestamp"
+	"google.golang.org/genproto/googleapis/rpc/errdetails"
 	"google.golang.org/genproto/googleapis/rpc/status"
 	spannerpb "google.golang.org/genproto/googleapis/spanner/v1"
 	"google.golang.org/grpc/codes"
@@ -495,11 +497,20 @@
 	}
 	aborted, ok := s.abortedTransactions[string(id)]
 	if ok && aborted {
-		return nil, gstatus.Error(codes.Aborted, "Transaction has been aborted")
+		return nil, newAbortedErrorWithMinimalRetryDelay()
 	}
 	return tx, nil
 }
 
+func newAbortedErrorWithMinimalRetryDelay() error {
+	st := gstatus.New(codes.Aborted, "Transaction has been aborted")
+	retry := &errdetails.RetryInfo{
+		RetryDelay: ptypes.DurationProto(time.Nanosecond),
+	}
+	st, _ = st.WithDetails(retry)
+	return st.Err()
+}
+
 func (s *inMemSpannerServer) removeTransaction(tx *spannerpb.Transaction) {
 	s.mu.Lock()
 	defer s.mu.Unlock()
diff --git a/spanner/retry.go b/spanner/retry.go
index 75231af..db9ff2a 100644
--- a/spanner/retry.go
+++ b/spanner/retry.go
@@ -21,12 +21,10 @@
 	"time"
 
 	"cloud.google.com/go/internal/trace"
-	"github.com/golang/protobuf/proto"
 	"github.com/golang/protobuf/ptypes"
 	"github.com/googleapis/gax-go/v2"
-	edpb "google.golang.org/genproto/googleapis/rpc/errdetails"
+	"google.golang.org/genproto/googleapis/rpc/errdetails"
 	"google.golang.org/grpc/codes"
-	"google.golang.org/grpc/metadata"
 	"google.golang.org/grpc/status"
 )
 
@@ -118,27 +116,27 @@
 	return funcWithRetry(ctx)
 }
 
-// extractRetryDelay extracts retry backoff if present.
+// extractRetryDelay extracts retry backoff from a *spanner.Error if present.
 func extractRetryDelay(err error) (time.Duration, bool) {
-	trailers := errTrailers(err)
-	if trailers == nil {
+	var se *Error
+	var s *status.Status
+	// Unwrap status error.
+	if errorAs(err, &se) {
+		s = status.Convert(se.Unwrap())
+	} else {
+		s = status.Convert(err)
+	}
+	if s == nil {
 		return 0, false
 	}
-	elem, ok := trailers[retryInfoKey]
-	if !ok || len(elem) <= 0 {
-		return 0, false
+	for _, detail := range s.Details() {
+		if retryInfo, ok := detail.(*errdetails.RetryInfo); ok {
+			delay, err := ptypes.Duration(retryInfo.RetryDelay)
+			if err != nil {
+				return 0, false
+			}
+			return delay, true
+		}
 	}
-	_, b, err := metadata.DecodeKeyValue(retryInfoKey, elem[0])
-	if err != nil {
-		return 0, false
-	}
-	var retryInfo edpb.RetryInfo
-	if proto.Unmarshal([]byte(b), &retryInfo) != nil {
-		return 0, false
-	}
-	delay, err := ptypes.Duration(retryInfo.RetryDelay)
-	if err != nil {
-		return 0, false
-	}
-	return delay, true
+	return 0, false
 }
diff --git a/spanner/retry_test.go b/spanner/retry_test.go
index 88f382a..af9c9b2 100644
--- a/spanner/retry_test.go
+++ b/spanner/retry_test.go
@@ -17,42 +17,69 @@
 package spanner
 
 import (
+	"context"
 	"testing"
 	"time"
 
-	"github.com/golang/protobuf/proto"
 	"github.com/golang/protobuf/ptypes"
 	"github.com/googleapis/gax-go/v2"
 	edpb "google.golang.org/genproto/googleapis/rpc/errdetails"
 	"google.golang.org/grpc/codes"
-	"google.golang.org/grpc/metadata"
 	"google.golang.org/grpc/status"
 )
 
 func TestRetryInfo(t *testing.T) {
-	b, _ := proto.Marshal(&edpb.RetryInfo{
+	s := status.New(codes.Aborted, "")
+	s, err := s.WithDetails(&edpb.RetryInfo{
 		RetryDelay: ptypes.DurationProto(time.Second),
 	})
-	trailers := map[string]string{
-		retryInfoKey: string(b),
+	if err != nil {
+		t.Fatalf("Error setting retry details: %v", err)
 	}
-	gotDelay, ok := extractRetryDelay(toSpannerErrorWithMetadata(status.Errorf(codes.Aborted, ""), metadata.New(trailers), true))
+	gotDelay, ok := extractRetryDelay(toSpannerErrorWithCommitInfo(s.Err(), true))
 	if !ok || !testEqual(time.Second, gotDelay) {
 		t.Errorf("<ok, retryDelay> = <%t, %v>, want <true, %v>", ok, gotDelay, time.Second)
 	}
 }
 
+func TestRetryInfoInWrappedError(t *testing.T) {
+	s := status.New(codes.Aborted, "")
+	s, err := s.WithDetails(&edpb.RetryInfo{
+		RetryDelay: ptypes.DurationProto(time.Second),
+	})
+	if err != nil {
+		t.Fatalf("Error setting retry details: %v", err)
+	}
+	gotDelay, ok := extractRetryDelay(
+		&wrappedTestError{wrapped: toSpannerErrorWithCommitInfo(s.Err(), true), msg: "Error that is wrapping a Spanner error"},
+	)
+	if !ok || !testEqual(time.Second, gotDelay) {
+		t.Errorf("<ok, retryDelay> = <%t, %v>, want <true, %v>", ok, gotDelay, time.Second)
+	}
+}
+
+func TestRetryInfoTransactionOutcomeUnknownError(t *testing.T) {
+	err := toSpannerErrorWithCommitInfo(context.DeadlineExceeded, true)
+	if gotDelay, ok := extractRetryDelay(err); ok {
+		t.Errorf("Got unexpected delay\nGot: %v\nWant: %v", gotDelay, 0)
+	}
+	if !testEqual(err.(*Error).err, &TransactionOutcomeUnknownError{status.FromContextError(context.DeadlineExceeded).Err()}) {
+		t.Errorf("Missing expected TransactionOutcomeUnknownError wrapped error")
+	}
+}
+
 func TestRetryerRespectsServerDelay(t *testing.T) {
 	t.Parallel()
 	serverDelay := 50 * time.Millisecond
-	b, _ := proto.Marshal(&edpb.RetryInfo{
+	s := status.New(codes.Aborted, "transaction was aborted")
+	s, err := s.WithDetails(&edpb.RetryInfo{
 		RetryDelay: ptypes.DurationProto(serverDelay),
 	})
-	trailers := map[string]string{
-		retryInfoKey: string(b),
+	if err != nil {
+		t.Fatalf("Error setting retry details: %v", err)
 	}
 	retryer := onCodes(gax.Backoff{}, codes.Aborted)
-	err := toSpannerErrorWithMetadata(status.Errorf(codes.Aborted, "transaction was aborted"), metadata.New(trailers), true)
+	err = toSpannerErrorWithCommitInfo(s.Err(), true)
 	maxSeenDelay, shouldRetry := retryer.Retry(err)
 	if !shouldRetry {
 		t.Fatalf("expected shouldRetry to be true")
diff --git a/spanner/transaction.go b/spanner/transaction.go
index 7b761f7..fb65f0e 100644
--- a/spanner/transaction.go
+++ b/spanner/transaction.go
@@ -24,12 +24,9 @@
 
 	"cloud.google.com/go/internal/trace"
 	vkit "cloud.google.com/go/spanner/apiv1"
-	"github.com/googleapis/gax-go/v2"
 	"google.golang.org/api/iterator"
 	sppb "google.golang.org/genproto/googleapis/spanner/v1"
-	"google.golang.org/grpc"
 	"google.golang.org/grpc/codes"
-	"google.golang.org/grpc/metadata"
 )
 
 // transactionID stores a transaction ID which uniquely identifies a transaction
@@ -909,16 +906,15 @@
 		return ts, errSessionClosed(t.sh)
 	}
 
-	var trailer metadata.MD
 	res, e := client.Commit(contextWithOutgoingMetadata(ctx, t.sh.getMetadata()), &sppb.CommitRequest{
 		Session: sid,
 		Transaction: &sppb.CommitRequest_TransactionId{
 			TransactionId: t.tx,
 		},
 		Mutations: mPb,
-	}, gax.WithGRPCOptions(grpc.Trailer(&trailer)))
+	})
 	if e != nil {
-		return ts, toSpannerErrorWithMetadata(e, trailer, true)
+		return ts, toSpannerErrorWithCommitInfo(e, true)
 	}
 	if tstamp := res.GetCommitTimestamp(); tstamp != nil {
 		ts = time.Unix(tstamp.Seconds, int64(tstamp.Nanos))
@@ -1014,7 +1010,6 @@
 		return ts, err
 	}
 
-	var trailers metadata.MD
 	// Retry-loop for aborted transactions.
 	// TODO: Replace with generic retryer.
 	for {
@@ -1038,13 +1033,13 @@
 				},
 			},
 			Mutations: mPb,
-		}, gax.WithGRPCOptions(grpc.Trailer(&trailers)))
+		})
 		if err != nil && !isAbortErr(err) {
 			if isSessionNotFoundError(err) {
 				// Discard the bad session.
 				sh.destroy()
 			}
-			return ts, toSpannerErrorWithMetadata(err, trailers, true)
+			return ts, toSpannerErrorWithCommitInfo(err, true)
 		} else if err == nil {
 			if tstamp := res.GetCommitTimestamp(); tstamp != nil {
 				ts = time.Unix(tstamp.Seconds, int64(tstamp.Nanos))
diff --git a/spanner/transaction_test.go b/spanner/transaction_test.go
index 45a8121..88448fd 100644
--- a/spanner/transaction_test.go
+++ b/spanner/transaction_test.go
@@ -26,6 +26,8 @@
 	"time"
 
 	. "cloud.google.com/go/spanner/internal/testutil"
+	"github.com/golang/protobuf/ptypes"
+	"google.golang.org/genproto/googleapis/rpc/errdetails"
 	sppb "google.golang.org/genproto/googleapis/spanner/v1"
 	"google.golang.org/grpc/codes"
 	gstatus "google.golang.org/grpc/status"
@@ -170,10 +172,9 @@
 	defer teardown()
 
 	// First commit will fail, and the retry will begin a new transaction.
-	errAbrt := gstatus.Errorf(codes.Aborted, "")
 	server.TestSpanner.PutExecutionTime(MethodCommitTransaction,
 		SimulatedExecutionTime{
-			Errors: []error{errAbrt},
+			Errors: []error{newAbortedErrorWithMinimalRetryDelay()},
 		})
 
 	ms := []*Mutation{
@@ -240,9 +241,6 @@
 		Insert("Accounts", []string{"AccountId", "Nickname", "Balance"}, []interface{}{int64(2), "Bar", int64(1)}),
 	}
 	_, got := client.Apply(ctx, ms, ApplyAtLeastOnce())
-	// Remove any trailers sent by the mock server to prevent the comparison to
-	// fail on that.
-	got.(*Error).trailers = nil
 	if !testEqual(wantErr, got) {
 		t.Fatalf("Expect Apply to fail, got %v, want %v.", got, wantErr)
 	}
@@ -381,3 +379,12 @@
 	}
 	return reqs
 }
+
+func newAbortedErrorWithMinimalRetryDelay() error {
+	st := gstatus.New(codes.Aborted, "Transaction has been aborted")
+	retry := &errdetails.RetryInfo{
+		RetryDelay: ptypes.DurationProto(time.Nanosecond),
+	}
+	st, _ = st.WithDetails(retry)
+	return st.Err()
+}