spanner: Use Go 1.13 error-unwrapping

Use Go 1.13 error-unwrapping and the equivalent xerrors features
for builds on versions < 1.13. This makes it possible to use
wrapped errors with the Spanner client.

This change also deprecates the gRPC code in the Spanner error struct.
All functions that need the gRPC code will extract it from the wrapped
error instead of reading this field.

Fixes #1223 and #1608.
Updates #1310.

Change-Id: Iea914adb5ca78af5e78cc948d8b1180eb3d647d2
Reviewed-on: https://code-review.googlesource.com/c/gocloud/+/48730
Reviewed-by: kokoro <noreply+kokoro@google.com>
Reviewed-by: Hengfeng Li <hengfeng@google.com>
diff --git a/spanner/client_test.go b/spanner/client_test.go
index 9e72358..8f08f38 100644
--- a/spanner/client_test.go
+++ b/spanner/client_test.go
@@ -48,10 +48,10 @@
 				Key: "x-goog-api-client",
 				ValuesValidator: func(token ...string) error {
 					if len(token) != 1 {
-						return spannerErrorf(codes.Internal, "unexpected number of api client token headers: %v", len(token))
+						return status.Errorf(codes.Internal, "unexpected number of api client token headers: %v", len(token))
 					}
 					if !strings.HasPrefix(token[0], "gl-go/") {
-						return spannerErrorf(codes.Internal, "unexpected api client token: %v", token[0])
+						return status.Errorf(codes.Internal, "unexpected api client token: %v", token[0])
 					}
 					return nil
 				},
@@ -144,7 +144,7 @@
 		SelectSingerIDAlbumIDAlbumTitleFromAlbums,
 		PartialResultSetExecutionTime{
 			ResumeToken: EncodeResumeToken(2),
-			Err:         spannerErrorf(codes.Internal, "stream terminated by RST_STREAM"),
+			Err:         status.Errorf(codes.Internal, "stream terminated by RST_STREAM"),
 		},
 	)
 	// When the client is fetching the partial result set with resume token 3,
@@ -154,7 +154,7 @@
 		SelectSingerIDAlbumIDAlbumTitleFromAlbums,
 		PartialResultSetExecutionTime{
 			ResumeToken: EncodeResumeToken(3),
-			Err:         spannerErrorf(codes.Unavailable, "server is unavailable"),
+			Err:         status.Errorf(codes.Unavailable, "server is unavailable"),
 		},
 	)
 	ctx := context.Background()
@@ -177,7 +177,7 @@
 		SelectSingerIDAlbumIDAlbumTitleFromAlbums,
 		PartialResultSetExecutionTime{
 			ResumeToken: EncodeResumeToken(2),
-			Err:         spannerErrorf(codes.Internal, "stream terminated by RST_STREAM"),
+			Err:         status.Errorf(codes.Internal, "stream terminated by RST_STREAM"),
 		},
 	)
 	// 'Session not found' is not retryable and the error will be returned to
@@ -186,7 +186,7 @@
 		SelectSingerIDAlbumIDAlbumTitleFromAlbums,
 		PartialResultSetExecutionTime{
 			ResumeToken: EncodeResumeToken(3),
-			Err:         spannerErrorf(codes.NotFound, "Session not found"),
+			Err:         status.Errorf(codes.NotFound, "Session not found"),
 		},
 	)
 	ctx := context.Background()
@@ -221,14 +221,14 @@
 		SelectSingerIDAlbumIDAlbumTitleFromAlbums,
 		PartialResultSetExecutionTime{
 			ResumeToken: EncodeResumeToken(2),
-			Err:         spannerErrorf(codes.Internal, "stream terminated by RST_STREAM"),
+			Err:         status.Errorf(codes.Internal, "stream terminated by RST_STREAM"),
 		},
 	)
 	server.TestSpanner.AddPartialResultSetError(
 		SelectSingerIDAlbumIDAlbumTitleFromAlbums,
 		PartialResultSetExecutionTime{
 			ResumeToken:   EncodeResumeToken(3),
-			Err:           spannerErrorf(codes.Unavailable, "server is unavailable"),
+			Err:           status.Errorf(codes.Unavailable, "server is unavailable"),
 			ExecutionTime: 50 * time.Millisecond,
 		},
 	)
@@ -262,14 +262,14 @@
 		SelectSingerIDAlbumIDAlbumTitleFromAlbums,
 		PartialResultSetExecutionTime{
 			ResumeToken: EncodeResumeToken(2),
-			Err:         spannerErrorf(codes.Internal, "stream terminated by RST_STREAM"),
+			Err:         status.Errorf(codes.Internal, "stream terminated by RST_STREAM"),
 		},
 	)
 	server.TestSpanner.AddPartialResultSetError(
 		SelectSingerIDAlbumIDAlbumTitleFromAlbums,
 		PartialResultSetExecutionTime{
 			ResumeToken: EncodeResumeToken(3),
-			Err:         spannerErrorf(codes.Unavailable, "server is unavailable"),
+			Err:         status.Errorf(codes.Unavailable, "server is unavailable"),
 		},
 	)
 	ctx := context.Background()
@@ -328,7 +328,7 @@
 		}
 	}
 	if rowCount != SelectSingerIDAlbumIDAlbumTitleFromAlbumsRowCount {
-		return spannerErrorf(codes.Internal, "Row count mismatch, got %v, expected %v", rowCount, SelectSingerIDAlbumIDAlbumTitleFromAlbumsRowCount)
+		return status.Errorf(codes.Internal, "Row count mismatch, got %v, expected %v", rowCount, SelectSingerIDAlbumIDAlbumTitleFromAlbumsRowCount)
 	}
 	return nil
 }
@@ -537,7 +537,7 @@
 			rowCount++
 		}
 		if rowCount != SelectSingerIDAlbumIDAlbumTitleFromAlbumsRowCount {
-			return spannerErrorf(codes.FailedPrecondition, "Row count mismatch, got %v, expected %v", rowCount, SelectSingerIDAlbumIDAlbumTitleFromAlbumsRowCount)
+			return status.Errorf(codes.FailedPrecondition, "Row count mismatch, got %v, expected %v", rowCount, SelectSingerIDAlbumIDAlbumTitleFromAlbumsRowCount)
 		}
 		return nil
 	})
@@ -569,6 +569,7 @@
 }
 
 func TestReadWriteTransaction_ErrUnexpectedEOF(t *testing.T) {
+	t.Parallel()
 	_, client, teardown := setupMockedTestServer(t)
 	defer teardown()
 	ctx := context.Background()
@@ -600,3 +601,76 @@
 		t.Fatalf("unexpected number of attempts: %d, expected %d", attempts, 1)
 	}
 }
+
+func TestReadWriteTransaction_WrapError(t *testing.T) {
+	t.Parallel()
+	server, client, teardown := setupMockedTestServer(t)
+	defer teardown()
+	// Abort the transaction on both the query as well as commit.
+	// The first abort error will be wrapped. The client will unwrap the cause
+	// of the error and retry the transaction. The aborted error on commit
+	// will not be wrapped, but will also be recognized by the client as an
+	// abort that should be retried.
+	server.TestSpanner.PutExecutionTime(MethodExecuteStreamingSql,
+		SimulatedExecutionTime{
+			Errors: []error{status.Error(codes.Aborted, "Transaction aborted")},
+		})
+	server.TestSpanner.PutExecutionTime(MethodCommitTransaction,
+		SimulatedExecutionTime{
+			Errors: []error{status.Error(codes.Aborted, "Transaction aborted")},
+		})
+	msg := "query failed"
+	numAttempts := 0
+	ctx := context.Background()
+	_, err := client.ReadWriteTransaction(ctx, func(ctx context.Context, tx *ReadWriteTransaction) error {
+		numAttempts++
+		iter := tx.Query(ctx, NewStatement(SelectSingerIDAlbumIDAlbumTitleFromAlbums))
+		defer iter.Stop()
+		for {
+			_, err := iter.Next()
+			if err == iterator.Done {
+				break
+			}
+			if err != nil {
+				// Wrap the error in another error that implements the
+				// (xerrors|errors).Wrapper interface.
+				return &wrappedTestError{err, msg}
+			}
+		}
+		return nil
+	})
+	if err != nil {
+		t.Fatalf("Unexpected error\nGot: %v\nWant: nil", err)
+	}
+	if g, w := numAttempts, 3; g != w {
+		t.Fatalf("Number of transaction attempts mismatch\nGot: %d\nWant: %d", w, w)
+	}
+
+	// Execute a transaction that returns a non-retryable error that is
+	// wrapped in a custom error. The transaction should return the custom
+	// error.
+	server.TestSpanner.PutExecutionTime(MethodExecuteStreamingSql,
+		SimulatedExecutionTime{
+			Errors: []error{status.Error(codes.NotFound, "Table not found")},
+		})
+	_, err = client.ReadWriteTransaction(ctx, func(ctx context.Context, tx *ReadWriteTransaction) error {
+		numAttempts++
+		iter := tx.Query(ctx, NewStatement(SelectSingerIDAlbumIDAlbumTitleFromAlbums))
+		defer iter.Stop()
+		for {
+			_, err := iter.Next()
+			if err == iterator.Done {
+				break
+			}
+			if err != nil {
+				// Wrap the error in another error that implements the
+				// (xerrors|errors).Wrapper interface.
+				return &wrappedTestError{err, msg}
+			}
+		}
+		return nil
+	})
+	if err == nil || err.Error() != msg {
+		t.Fatalf("Unexpected error\nGot: %v\nWant: %v", err, msg)
+	}
+}
diff --git a/spanner/errors.go b/spanner/errors.go
index cdfb96d..af03c32 100644
--- a/spanner/errors.go
+++ b/spanner/errors.go
@@ -20,7 +20,6 @@
 	"context"
 	"fmt"
 
-	"google.golang.org/grpc"
 	"google.golang.org/grpc/codes"
 	"google.golang.org/grpc/metadata"
 	"google.golang.org/grpc/status"
@@ -30,7 +29,14 @@
 type Error struct {
 	// Code is the canonical error code for describing the nature of a
 	// particular error.
+	//
+	// Deprecated: The error code should be extracted from the wrapped error by
+	// calling ErrCode(err error). This field will be removed in a future
+	// release.
 	Code codes.Code
+	// err is the wrapped error that caused this Spanner error. The wrapped
+	// error can be read with the Unwrap method.
+	err error
 	// Desc explains more details of the error.
 	Desc string
 	// trailers are the trailers returned in the response, if any.
@@ -42,14 +48,32 @@
 	if e == nil {
 		return fmt.Sprintf("spanner: OK")
 	}
-	return fmt.Sprintf("spanner: code = %q, desc = %q", e.Code, e.Desc)
+	code := ErrCode(e)
+	return fmt.Sprintf("spanner: code = %q, desc = %q", code, e.Desc)
+}
+
+// Unwrap returns the wrapped error (if any).
+func (e *Error) Unwrap() error {
+	return e.err
 }
 
 // GRPCStatus returns the corresponding gRPC Status of this Spanner error.
 // This allows the error to be converted to a gRPC status using
 // `status.Convert(error)`.
 func (e *Error) GRPCStatus() *status.Status {
-	return status.New(e.Code, e.Desc)
+	err := unwrap(e)
+	for {
+		// No gRPC Status found in the chain of errors. Return 'Unknown' with
+		// the message of the original error.
+		if err == nil {
+			return status.New(codes.Unknown, e.Desc)
+		}
+		code := status.Code(err)
+		if code != codes.Unknown {
+			return status.New(code, e.Desc)
+		}
+		err = unwrap(err)
+	}
 }
 
 // decorate decorates an existing spanner.Error with more information.
@@ -57,12 +81,15 @@
 	e.Desc = fmt.Sprintf("%v, %v", info, e.Desc)
 }
 
-// spannerErrorf generates a *spanner.Error with the given error code and
-// description.
-func spannerErrorf(ec codes.Code, format string, args ...interface{}) error {
+// spannerErrorf generates a *spanner.Error with the given description and a
+// status error with the given error code as its wrapped error.
+func spannerErrorf(code codes.Code, format string, args ...interface{}) error {
+	msg := fmt.Sprintf(format, args...)
+	wrapped := status.Error(code, msg)
 	return &Error{
-		Code: ec,
-		Desc: fmt.Sprintf(format, args...),
+		Code: code,
+		err:  wrapped,
+		Desc: msg,
 	}
 }
 
@@ -79,37 +106,36 @@
 	if err == nil {
 		return nil
 	}
-	if se, ok := err.(*Error); ok {
+	var se *Error
+	if errorAs(err, &se) {
 		if trailers != nil {
 			se.trailers = metadata.Join(se.trailers, trailers)
 		}
 		return se
 	}
 	switch {
-	case err == context.DeadlineExceeded:
-		return &Error{codes.DeadlineExceeded, err.Error(), trailers}
-	case err == context.Canceled:
-		return &Error{codes.Canceled, err.Error(), trailers}
+	case err == context.DeadlineExceeded || err == context.Canceled:
+		return &Error{status.FromContextError(err).Code(), status.FromContextError(err).Err(), err.Error(), trailers}
 	case status.Code(err) == codes.Unknown:
-		return &Error{codes.Unknown, err.Error(), trailers}
+		return &Error{codes.Unknown, err, err.Error(), trailers}
 	default:
-		return &Error{status.Code(err), grpc.ErrorDesc(err), trailers}
+		return &Error{status.Convert(err).Code(), err, status.Convert(err).Message(), trailers}
 	}
 }
 
 // ErrCode extracts the canonical error code from a Go error.
 func ErrCode(err error) codes.Code {
-	se, ok := toSpannerError(err).(*Error)
+	s, ok := status.FromError(err)
 	if !ok {
 		return codes.Unknown
 	}
-	return se.Code
+	return s.Code()
 }
 
 // ErrDesc extracts the Cloud Spanner error description from a Go error.
 func ErrDesc(err error) string {
-	se, ok := toSpannerError(err).(*Error)
-	if !ok {
+	var se *Error
+	if !errorAs(err, &se) {
 		return err.Error()
 	}
 	return se.Desc
@@ -117,8 +143,8 @@
 
 // errTrailers extracts the grpc trailers if present from a Go error.
 func errTrailers(err error) metadata.MD {
-	se, ok := err.(*Error)
-	if !ok {
+	var se *Error
+	if !errorAs(err, &se) {
 		return nil
 	}
 	return se.trailers
diff --git a/spanner/errors112.go b/spanner/errors112.go
new file mode 100644
index 0000000..318005e
--- /dev/null
+++ b/spanner/errors112.go
@@ -0,0 +1,33 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// TODO: Remove entire file when support for Go1.12 and lower has been dropped.
+// +build !go1.13
+
+package spanner
+
+import "golang.org/x/xerrors"
+
+// unwrap is a generic implementation of (errors|xerrors).Unwrap(error). This
+// implementation uses xerrors and is included in Go 1.12 and earlier builds.
+func unwrap(err error) error {
+	return xerrors.Unwrap(err)
+}
+
+// errorAs is a generic implementation of
+// (errors|xerrors).As(error, interface{}). This implementation uses xerrors
+// and is included in Go 1.12 and earlier builds.
+func errorAs(err error, target interface{}) bool {
+	return xerrors.As(err, target)
+}
diff --git a/spanner/errors113.go b/spanner/errors113.go
new file mode 100644
index 0000000..41a6ea9
--- /dev/null
+++ b/spanner/errors113.go
@@ -0,0 +1,33 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// TODO: Remove entire file when support for Go1.12 and lower has been dropped.
+// +build go1.13
+
+package spanner
+
+import "errors"
+
+// unwrap is a generic implementation of (errors|xerrors).Unwrap(error). This
+// implementation uses errors and is included in Go 1.13 and later builds.
+func unwrap(err error) error {
+	return errors.Unwrap(err)
+}
+
+// errorAs is a generic implementation of
+// (errors|xerrors).As(error, interface{}). This implementation uses errors and
+// is included in Go 1.13 and later builds.
+func errorAs(err error, target interface{}) bool {
+	return errors.As(err, target)
+}
diff --git a/spanner/errors_test.go b/spanner/errors_test.go
index d8484d3..31ebf46 100644
--- a/spanner/errors_test.go
+++ b/spanner/errors_test.go
@@ -25,24 +25,49 @@
 	"google.golang.org/grpc/status"
 )
 
+type wrappedTestError struct {
+	wrapped error
+	msg     string
+}
+
+func (w *wrappedTestError) Error() string {
+	return w.msg
+}
+
+func (w *wrappedTestError) Unwrap() error {
+	return w.wrapped
+}
+
 func TestToSpannerError(t *testing.T) {
 	for _, test := range []struct {
 		err      error
 		wantCode codes.Code
+		wantMsg  string
 	}{
-		{errors.New("wha?"), codes.Unknown},
-		{context.Canceled, codes.Canceled},
-		{context.DeadlineExceeded, codes.DeadlineExceeded},
-		{status.Errorf(codes.ResourceExhausted, "so tired"), codes.ResourceExhausted},
-		{spannerErrorf(codes.InvalidArgument, "bad"), codes.InvalidArgument},
+		{errors.New("wha?"), codes.Unknown, `spanner: code = "Unknown", desc = "wha?"`},
+		{context.Canceled, codes.Canceled, `spanner: code = "Canceled", desc = "context canceled"`},
+		{context.DeadlineExceeded, codes.DeadlineExceeded, `spanner: code = "DeadlineExceeded", desc = "context deadline exceeded"`},
+		{status.Errorf(codes.ResourceExhausted, "so tired"), codes.ResourceExhausted, `spanner: code = "ResourceExhausted", desc = "so tired"`},
+		{spannerErrorf(codes.InvalidArgument, "bad"), codes.InvalidArgument, `spanner: code = "InvalidArgument", desc = "bad"`},
+		{&wrappedTestError{
+			wrapped: spannerErrorf(codes.Aborted, "Transaction aborted"),
+			msg:     "error with wrapped Spanner error",
+		}, codes.Aborted, `spanner: code = "Aborted", desc = "Transaction aborted"`},
+		{&wrappedTestError{
+			wrapped: errors.New("wha?"),
+			msg:     "error with wrapped non-gRPC and non-Spanner error",
+		}, codes.Unknown, `spanner: code = "Unknown", desc = "error with wrapped non-gRPC and non-Spanner error"`},
 	} {
 		err := toSpannerError(test.err)
-		if got, want := err.(*Error).Code, test.wantCode; got != want {
+		if got, want := ErrCode(err), test.wantCode; got != want {
 			t.Errorf("%v: got %s, want %s", test.err, got, want)
 		}
 		converted := status.Convert(err)
 		if converted.Code() != test.wantCode {
 			t.Errorf("%v: got status %v, want status %v", test.err, converted.Code(), test.wantCode)
 		}
+		if got, want := err.Error(), test.wantMsg; got != want {
+			t.Errorf("%v: got msg %s, want mgs %s", test.err, got, want)
+		}
 	}
 }
diff --git a/spanner/go.mod b/spanner/go.mod
index c342fa6..499617d 100644
--- a/spanner/go.mod
+++ b/spanner/go.mod
@@ -17,6 +17,7 @@
 	golang.org/x/oauth2 v0.0.0-20191202225959-858c2ad4c8b6 // indirect
 	golang.org/x/sys v0.0.0-20191220220014-0732a990476f // indirect
 	golang.org/x/tools v0.0.0-20191223184912-a7b3459f0428 // indirect
+	golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7
 	google.golang.org/api v0.15.0
 	google.golang.org/appengine v1.6.5 // indirect
 	google.golang.org/genproto v0.0.0-20191223191004-3caeed10a8bf
diff --git a/spanner/go.sum b/spanner/go.sum
index 04bd2e1..8273bc9 100644
--- a/spanner/go.sum
+++ b/spanner/go.sum
@@ -83,6 +83,7 @@
 go.opencensus.io v0.22.2/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw=
 golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
 golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
+golang.org/x/crypto v0.0.0-20190605123033-f99c8df09eb5 h1:58fnuSXlxZmFdJyvtTFVmVhcMLU6v5fEb/ok4wyqtNU=
 golang.org/x/crypto v0.0.0-20190605123033-f99c8df09eb5/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
 golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
 golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
diff --git a/spanner/pdml_test.go b/spanner/pdml_test.go
index 5cd8b1f..ff6e54a 100644
--- a/spanner/pdml_test.go
+++ b/spanner/pdml_test.go
@@ -55,8 +55,12 @@
 	stmt := NewStatement(SelectFooFromBar)
 	_, err := client.PartitionedUpdate(ctx, stmt)
 	wantCode := codes.InvalidArgument
-	if serr, ok := err.(*Error); !ok || serr.Code != wantCode {
-		t.Errorf("got error %v, want code %s", err, wantCode)
+	var serr *Error
+	if !errorAs(err, &serr) {
+		t.Errorf("got error %v, want spanner.Error", err)
+	}
+	if ErrCode(serr) != wantCode {
+		t.Errorf("got error %v, want code %s", serr, wantCode)
 	}
 }
 
diff --git a/spanner/read_test.go b/spanner/read_test.go
index 0bf36db..26f72e0 100644
--- a/spanner/read_test.go
+++ b/spanner/read_test.go
@@ -1426,7 +1426,7 @@
 	case <-time.After(10 * time.Second):
 		t.Fatalf("timeout in waiting for failed query to return.")
 	}
-	if wantErr := toSpannerError(status.Errorf(codes.Unavailable, "mock server wants some sleep")); !testEqual(gotErr, wantErr) {
+	if wantErr := spannerErrorf(codes.Unavailable, "mock server wants some sleep"); !testEqual(gotErr, wantErr) {
 		t.Fatalf("stream() returns error: %v, but want error: %v", gotErr, wantErr)
 	}
 
diff --git a/spanner/retry.go b/spanner/retry.go
index 6c89215..3ee154e 100644
--- a/spanner/retry.go
+++ b/spanner/retry.go
@@ -80,7 +80,12 @@
 			if err == nil {
 				return nil
 			}
-			delay, shouldRetry := retryer.Retry(err)
+			// Get Spanner error.
+			var se *Error
+			if !errorAs(err, &se) {
+				return err
+			}
+			delay, shouldRetry := retryer.Retry(se)
 			if !shouldRetry {
 				return err
 			}
diff --git a/spanner/retry_test.go b/spanner/retry_test.go
index 27b8922..40ca438 100644
--- a/spanner/retry_test.go
+++ b/spanner/retry_test.go
@@ -52,7 +52,7 @@
 		retryInfoKey: string(b),
 	}
 	retryer := onCodes(gax.Backoff{}, codes.Aborted)
-	err := toSpannerErrorWithMetadata(spannerErrorf(codes.Aborted, "transaction was aborted"), metadata.New(trailers))
+	err := toSpannerErrorWithMetadata(status.Errorf(codes.Aborted, "transaction was aborted"), metadata.New(trailers))
 	maxSeenDelay, shouldRetry := retryer.Retry(err)
 	if !shouldRetry {
 		t.Fatalf("expected shouldRetry to be true")
diff --git a/spanner/row.go b/spanner/row.go
index 0c2337d..0e3f216 100644
--- a/spanner/row.go
+++ b/spanner/row.go
@@ -176,8 +176,8 @@
 	if err == nil {
 		return nil
 	}
-	se, ok := toSpannerError(err).(*Error)
-	if !ok {
+	var se *Error
+	if !errorAs(err, &se) {
 		return spannerErrorf(codes.InvalidArgument, "failed to decode column %v, error = <%v>", i, err)
 	}
 	se.decorate(fmt.Sprintf("failed to decode column %v", i))
diff --git a/spanner/session.go b/spanner/session.go
index 1dc3268..7fc0265 100644
--- a/spanner/session.go
+++ b/spanner/session.go
@@ -614,16 +614,12 @@
 	}
 }
 
-// errInvalidSessionPool returns error for using an invalid session pool.
-func errInvalidSessionPool() error {
-	return spannerErrorf(codes.InvalidArgument, "invalid session pool")
-}
+// errInvalidSessionPool is the error for using an invalid session pool.
+var errInvalidSessionPool = spannerErrorf(codes.InvalidArgument, "invalid session pool")
 
 // errGetSessionTimeout returns error for context timeout during
 // sessionPool.take().
-func errGetSessionTimeout() error {
-	return spannerErrorf(codes.Canceled, "timeout / context canceled during getting session")
-}
+var errGetSessionTimeout = spannerErrorf(codes.Canceled, "timeout / context canceled during getting session")
 
 // shouldPrepareWriteLocked returns true if we should prepare more sessions for write.
 func (p *sessionPool) shouldPrepareWriteLocked() bool {
@@ -688,7 +684,7 @@
 		p.mu.Lock()
 		if !p.valid {
 			p.mu.Unlock()
-			return nil, errInvalidSessionPool()
+			return nil, errInvalidSessionPool
 		}
 		if p.idleList.Len() > 0 {
 			// Idle sessions are available, get one from the top of the idle
@@ -726,7 +722,7 @@
 			select {
 			case <-ctx.Done():
 				trace.TracePrintf(ctx, nil, "Context done waiting for session")
-				return nil, errGetSessionTimeout()
+				return nil, errGetSessionTimeout
 			case <-mayGetSession:
 			}
 			continue
@@ -765,7 +761,7 @@
 		p.mu.Lock()
 		if !p.valid {
 			p.mu.Unlock()
-			return nil, errInvalidSessionPool()
+			return nil, errInvalidSessionPool
 		}
 		if p.idleWriteList.Len() > 0 {
 			// Idle sessions are available, get one from the top of the idle
@@ -799,7 +795,7 @@
 				select {
 				case <-ctx.Done():
 					trace.TracePrintf(ctx, nil, "Context done waiting for session")
-					return nil, errGetSessionTimeout()
+					return nil, errGetSessionTimeout
 				case <-mayGetSession:
 				}
 				continue
diff --git a/spanner/session_test.go b/spanner/session_test.go
index 83d0f10..98a6420 100644
--- a/spanner/session_test.go
+++ b/spanner/session_test.go
@@ -483,7 +483,7 @@
 	ctx2, cancel := context.WithTimeout(ctx, 10*time.Millisecond)
 	defer cancel()
 	_, gotErr := sp.take(ctx2)
-	if wantErr := errGetSessionTimeout(); !testEqual(gotErr, wantErr) {
+	if wantErr := errGetSessionTimeout; gotErr != wantErr {
 		t.Fatalf("the second session retrival returns error %v, want %v", gotErr, wantErr)
 	}
 	doneWaiting := make(chan struct{})
@@ -616,7 +616,7 @@
 	_, gotErr := sp.take(ctx2)
 
 	// Since MaxBurst == 1, the second session request should block.
-	if wantErr := errGetSessionTimeout(); !testEqual(gotErr, wantErr) {
+	if wantErr := errGetSessionTimeout; gotErr != wantErr {
 		t.Fatalf("session retrival returns error %v, want %v", gotErr, wantErr)
 	}
 
@@ -1305,7 +1305,7 @@
 			// If the session pool was closed between the take() and now (or
 			// even during a take()) then an error is ok.
 			if !wasValid {
-				if wantErr := errInvalidSessionPool(); !testEqual(gotErr, wantErr) {
+				if wantErr := errInvalidSessionPool; gotErr != wantErr {
 					t.Fatalf("%v.%v: got error when pool is closed: %v, want %v", ti, idx, gotErr, wantErr)
 				}
 			}
diff --git a/spanner/sessionclient_test.go b/spanner/sessionclient_test.go
index 82be7fd..13b61c7 100644
--- a/spanner/sessionclient_test.go
+++ b/spanner/sessionclient_test.go
@@ -178,7 +178,7 @@
 			// Register the errors on the server.
 			errors := make([]error, numErrors+firstErrorAt)
 			for i := firstErrorAt; i < numErrors+firstErrorAt; i++ {
-				errors[i] = spannerErrorf(codes.FailedPrecondition, "session creation failed")
+				errors[i] = status.Errorf(codes.FailedPrecondition, "session creation failed")
 			}
 			server.TestSpanner.PutExecutionTime(MethodBatchCreateSession, SimulatedExecutionTime{
 				Errors: errors,
diff --git a/spanner/statement.go b/spanner/statement.go
index eff4d1d..be1547f 100644
--- a/spanner/statement.go
+++ b/spanner/statement.go
@@ -84,8 +84,8 @@
 	if err == nil {
 		return nil
 	}
-	se, ok := toSpannerError(err).(*Error)
-	if !ok {
+	var se *Error
+	if !errorAs(err, &se) {
 		return spannerErrorf(codes.InvalidArgument, "failed to bind query parameter(name: %q, value: %v), error = <%v>", k, v, err)
 	}
 	se.decorate(fmt.Sprintf("failed to bind query parameter(name: %q, value: %v)", k, v))
diff --git a/spanner/transaction.go b/spanner/transaction.go
index 604c877..16371ae 100644
--- a/spanner/transaction.go
+++ b/spanner/transaction.go
@@ -22,10 +22,9 @@
 	"sync/atomic"
 	"time"
 
-	"github.com/googleapis/gax-go/v2"
-
 	"cloud.google.com/go/internal/trace"
 	vkit "cloud.google.com/go/spanner/apiv1"
+	"github.com/googleapis/gax-go/v2"
 	"google.golang.org/api/iterator"
 	sppb "google.golang.org/genproto/googleapis/spanner/v1"
 	"google.golang.org/grpc"
diff --git a/spanner/transaction_test.go b/spanner/transaction_test.go
index 2976c96..fbd8bb4 100644
--- a/spanner/transaction_test.go
+++ b/spanner/transaction_test.go
@@ -170,7 +170,7 @@
 	defer teardown()
 
 	// First commit will fail, and the retry will begin a new transaction.
-	errAbrt := spannerErrorf(codes.Aborted, "")
+	errAbrt := gstatus.Errorf(codes.Aborted, "")
 	server.TestSpanner.PutExecutionTime(MethodCommitTransaction,
 		SimulatedExecutionTime{
 			Errors: []error{errAbrt},
@@ -202,15 +202,16 @@
 	server, client, teardown := setupMockedTestServer(t)
 	defer teardown()
 
-	wantErr := spannerErrorf(codes.NotFound, "Session not found")
+	serverErr := gstatus.Errorf(codes.NotFound, "Session not found")
 	server.TestSpanner.PutExecutionTime(MethodBeginTransaction,
 		SimulatedExecutionTime{
-			Errors: []error{wantErr, wantErr, wantErr},
+			Errors: []error{serverErr, serverErr, serverErr},
 		})
 	server.TestSpanner.PutExecutionTime(MethodCommitTransaction,
 		SimulatedExecutionTime{
-			Errors: []error{wantErr, wantErr, wantErr},
+			Errors: []error{serverErr, serverErr, serverErr},
 		})
+	wantErr := toSpannerError(serverErr)
 
 	txn := client.ReadOnlyTransaction()
 	defer txn.Close()
diff --git a/spanner/value.go b/spanner/value.go
index 249a56c..74a0d70 100644
--- a/spanner/value.go
+++ b/spanner/value.go
@@ -1334,8 +1334,8 @@
 
 // errDecodeArrayElement returns error for failure in decoding single array element.
 func errDecodeArrayElement(i int, v proto.Message, sqlType string, err error) error {
-	se, ok := toSpannerError(err).(*Error)
-	if !ok {
+	var se *Error
+	if !errorAs(err, &se) {
 		return spannerErrorf(codes.Unknown,
 			"cannot decode %v(array element %v) as %v, error = <%v>", v, i, sqlType, err)
 	}
@@ -1602,8 +1602,8 @@
 // errDecodeStructField returns error for failure in decoding a single field of
 // a Cloud Spanner STRUCT.
 func errDecodeStructField(ty *sppb.StructType, f string, err error) error {
-	se, ok := toSpannerError(err).(*Error)
-	if !ok {
+	var se *Error
+	if !errorAs(err, &se) {
 		return spannerErrorf(codes.Unknown,
 			"cannot decode field %v of Cloud Spanner STRUCT %+v, error = <%v>", f, ty, err)
 	}