spanner: refactor mockclient

This CL refactors MockCloudSpannerClient, removing the Actions slice (per-call
expectations and canned return values), the nice bool (which controlled whether
the mock failed unexpected requests), and the injErr map (which failed specific
methods). These are replaced by a simpler ReceivedRequests channel that tests
can read to introspect the requests sent to the stub, and by FuncMock, a
function-injectable wrapper used for test-specific overrides of individual
methods, e.g. to inject errors.
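
To illustrate the new pattern, a minimal sketch of a hypothetical unit test
follows. TestHypothetical_CommitAborted and its direct mock.Commit call are
illustrative only (a real test would drive the spanner client under test);
FuncMock, ReceivedRequests, and the sppb/grpc types are from this CL:

    package spanner

    import (
        "context"
        "testing"

        "cloud.google.com/go/spanner/internal/testutil"
        sppb "google.golang.org/genproto/googleapis/spanner/v1"
        "google.golang.org/grpc"
        "google.golang.org/grpc/codes"
        "google.golang.org/grpc/status"
    )

    func TestHypothetical_CommitAborted(t *testing.T) {
        mock := &testutil.FuncMock{MockCloudSpannerClient: testutil.NewMockCloudSpannerClient(t)}

        // Override a single method to inject a test-specific error. The
        // override is responsible for forwarding the request to
        // ReceivedRequests.
        mock.CommitFn = func(c context.Context, r *sppb.CommitRequest, opts ...grpc.CallOption) (*sppb.CommitResponse, error) {
            mock.MockCloudSpannerClient.ReceivedRequests <- r
            return nil, status.Errorf(codes.Aborted, "injected abort")
        }

        // Exercise the stub (stand-in for the code under test).
        if _, err := mock.Commit(context.Background(), &sppb.CommitRequest{}); err == nil {
            t.Fatal("expected injected abort")
        }

        // Introspect the request that reached the stub.
        got := <-mock.ReceivedRequests
        if _, ok := got.(*sppb.CommitRequest); !ok {
            t.Fatalf("received %T, want *sppb.CommitRequest", got)
        }
    }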

This CL is in preparation for a future unit test that asserts on the contents
of a request (which we can now introspect using ReceivedRequests). It also
tries to make these tests more readable by moving assertions and custom
response logic directly into each test, instead of forcing readers to work out
how the different mock methods behave.

This CL also does a large amount of general clean-up:

- Merged setup and mockClient into a single function, serverClientMock, which
initializes both the client and the mock (see the sketch after this list).
Also consolidated all test cleanup into a callback and made sure the callback
gets called everywhere.
- Split TestReadOnlyAcquire into several tests.
- Amazingly, integration_test.go's TestMain was causing -short to skip ALL
tests. Its preconditions have been split into a new function,
initIntegrationTests, so that -short once again works as expected.
- Renamed prepare to prepareIntegrationTest.
- Renamed all integration tests to TestIntegration_*.
- Moved TestStructParametersBind to integration_test.go and renamed it
TestIntegration_StructParametersBind.
- Moved test-specific variables such as errAbrt from global scope into the one
test that uses them.
- Called context.Background once and assigned the result to a variable,
instead of re-calling it everywhere a context is needed.
- Converted to Fatalf instead of Errorf. This is slightly overkill, since
Errorf should generally be preferred, but these tests tend to fail in a
cascade: the first failure causes all subsequent assertions to fail, so
failing fast is clearer. We may want to revisit this in the future.
- Rewrote several test descriptions to be clearer.
- Rewrote several test names to better indicate what is being tested.
- Rewrote several test assertions to use better failure descriptions, as well
as %+v.
- Removed several testing.Short->t.Skip statements from unit tests (these
should only appear in integration tests and the like).
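
For reference, the consolidated helper reads roughly like this in use. This is
a hypothetical sketch inside the spanner package's tests
(TestHypothetical_PoolSmoke is not part of this CL); as in the real tests,
serverClientMock's first return value is ignored:

    func TestHypothetical_PoolSmoke(t *testing.T) {
        ctx := context.Background()
        _, sp, mock, cleanup := serverClientMock(t, SessionPoolConfig{MaxIdle: 10})
        defer cleanup()

        sh, err := sp.take(ctx)
        if err != nil {
            t.Fatalf("failed to take a session: %v", err)
        }
        sh.recycle()

        // Taking the first session forces the pool to create one, so a
        // CreateSessionRequest should be visible on the mock.
        if _, ok := (<-mock.ReceivedRequests).(*sppb.CreateSessionRequest); !ok {
            t.Fatal("expected a *sppb.CreateSessionRequest to reach the stub")
        }
    }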

Change-Id: I100319acad62db08da1f168bf5087bdfccd3b6b2
Reviewed-on: https://code-review.googlesource.com/c/36270
Reviewed-by: kokoro <noreply+kokoro@google.com>
Reviewed-by: Eno Compton <enocom@google.com>
diff --git a/internal/kokoro/vet.sh b/internal/kokoro/vet.sh
index 30d34c1..9ef25ef 100755
--- a/internal/kokoro/vet.sh
+++ b/internal/kokoro/vet.sh
@@ -52,6 +52,7 @@
     grep -v "ALL_CAPS" | \
     grep -v "go-cloud-debug-agent" | \
     grep -v "mock_test" | \
+    grep -v "internal/testutil/funcmock.go" | \
     grep -v "a blank import should be only in a main or test package" | \
     grep -vE "\.pb\.go:" || true) | tee /dev/stderr | (! read)
 
diff --git a/spanner/big_pdml_test.go b/spanner/big_pdml_test.go
index 262eded..e098337 100644
--- a/spanner/big_pdml_test.go
+++ b/spanner/big_pdml_test.go
@@ -24,11 +24,11 @@
 	"testing"
 )
 
-func TestBigPDML(t *testing.T) {
+func TestIntegration_BigPDML(t *testing.T) {
 	const nRows int = 1e4
 
 	ctx := context.Background()
-	client, _, cleanup := prepare(ctx, t, singerDBStatements)
+	client, _, cleanup := prepareIntegrationTest(ctx, t, singerDBStatements)
 	defer cleanup()
 
 	columns := []string{"SingerId", "FirstName", "LastName"}
diff --git a/spanner/integration_test.go b/spanner/integration_test.go
index 694f9e0..76f8a94 100644
--- a/spanner/integration_test.go
+++ b/spanner/integration_test.go
@@ -123,21 +123,31 @@
 )
 
 func TestMain(m *testing.M) {
+	cleanup := initIntegrationTests()
+	res := m.Run()
+	cleanup()
+	os.Exit(res)
+}
+
+func initIntegrationTests() func() {
+	ctx := context.Background()
 	flag.Parse() // needed for testing.Short()
+	noop := func() {}
+
 	if testing.Short() {
 		log.Println("Integration tests skipped in -short mode.")
-		return
+		return noop
 	}
+
 	if testProjectID == "" {
 		log.Println("Integration tests skipped: GCLOUD_TESTS_GOLANG_PROJECT_ID is missing")
-		return
+		return noop
 	}
-	ctx := context.Background()
 
 	ts := testutil.TokenSource(ctx, AdminScope, Scope)
 	if ts == nil {
 		log.Printf("Integration test skipped: cannot get service account credential from environment variable %v", "GCLOUD_TESTS_GOLANG_KEY")
-		return
+		return noop
 	}
 	var err error
 
@@ -147,17 +157,18 @@
 		log.Fatalf("cannot create admin client: %v", err)
 	}
 
-	res := m.Run()
-	cleanupDatabases()
-	os.Exit(res)
+	return func() {
+		cleanupDatabases()
+		admin.Close()
+	}
 }
 
 // Test SingleUse transaction.
-func TestSingleUse(t *testing.T) {
+func TestIntegration_SingleUse(t *testing.T) {
 	ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
 	defer cancel()
 	// Set up testing environment.
-	client, _, cleanup := prepare(ctx, t, singerDBStatements)
+	client, _, cleanup := prepareIntegrationTest(ctx, t, singerDBStatements)
 	defer cleanup()
 
 	writes := []struct {
@@ -363,11 +374,11 @@
 
 // Test ReadOnlyTransaction. The testsuite is mostly like SingleUse, except it
 // also tests for a single timestamp across multiple reads.
-func TestReadOnlyTransaction(t *testing.T) {
+func TestIntegration_ReadOnlyTransaction(t *testing.T) {
 	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
 	defer cancel()
 	// Set up testing environment.
-	client, _, cleanup := prepare(ctx, t, singerDBStatements)
+	client, _, cleanup := prepareIntegrationTest(ctx, t, singerDBStatements)
 	defer cleanup()
 
 	writes := []struct {
@@ -546,10 +557,10 @@
 }
 
 // Test ReadOnlyTransaction with different timestamp bound when there's an update at the same time.
-func TestUpdateDuringRead(t *testing.T) {
+func TestIntegration_UpdateDuringRead(t *testing.T) {
 	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
 	defer cancel()
-	client, _, cleanup := prepare(ctx, t, singerDBStatements)
+	client, _, cleanup := prepareIntegrationTest(ctx, t, singerDBStatements)
 	defer cleanup()
 
 	for i, tb := range []TimestampBound{
@@ -576,11 +587,11 @@
 }
 
 // Test ReadWriteTransaction.
-func TestReadWriteTransaction(t *testing.T) {
+func TestIntegration_ReadWriteTransaction(t *testing.T) {
 	// Give a longer deadline because of transaction backoffs.
 	ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
 	defer cancel()
-	client, _, cleanup := prepare(ctx, t, singerDBStatements)
+	client, _, cleanup := prepareIntegrationTest(ctx, t, singerDBStatements)
 	defer cleanup()
 
 	// Set up two accounts
@@ -665,11 +676,11 @@
 	}
 }
 
-func TestReads(t *testing.T) {
+func TestIntegration_Reads(t *testing.T) {
 	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
 	defer cancel()
 	// Set up testing environment.
-	client, _, cleanup := prepare(ctx, t, readDBStatements)
+	client, _, cleanup := prepareIntegrationTest(ctx, t, readDBStatements)
 	defer cleanup()
 
 	// Includes k0..k14. Strings sort lexically, eg "k1" < "k10" < "k2".
@@ -729,13 +740,13 @@
 	indexRangeReads(ctx, t, client)
 }
 
-func TestEarlyTimestamp(t *testing.T) {
+func TestIntegration_EarlyTimestamp(t *testing.T) {
 	// Test that we can get the timestamp from a read-only transaction as
 	// soon as we have read at least one row.
 	ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
 	defer cancel()
 	// Set up testing environment.
-	client, _, cleanup := prepare(ctx, t, readDBStatements)
+	client, _, cleanup := prepareIntegrationTest(ctx, t, readDBStatements)
 	defer cleanup()
 
 	var ms []*Mutation
@@ -775,10 +786,10 @@
 	}
 }
 
-func TestNestedTransaction(t *testing.T) {
+func TestIntegration_NestedTransaction(t *testing.T) {
 	// You cannot use a transaction from inside a read-write transaction.
 	ctx := context.Background()
-	client, _, cleanup := prepare(ctx, t, singerDBStatements)
+	client, _, cleanup := prepareIntegrationTest(ctx, t, singerDBStatements)
 	defer cleanup()
 
 	_, err := client.ReadWriteTransaction(ctx, func(ctx context.Context, tx *ReadWriteTransaction) error {
@@ -805,10 +816,10 @@
 }
 
 // Test client recovery on database recreation.
-func TestDbRemovalRecovery(t *testing.T) {
+func TestIntegration_DbRemovalRecovery(t *testing.T) {
 	ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
 	defer cancel()
-	client, dbPath, cleanup := prepare(ctx, t, singerDBStatements)
+	client, dbPath, cleanup := prepareIntegrationTest(ctx, t, singerDBStatements)
 	defer cleanup()
 
 	// Drop the testing database.
@@ -854,10 +865,10 @@
 }
 
 // Test encoding/decoding non-struct Cloud Spanner types.
-func TestBasicTypes(t *testing.T) {
+func TestIntegration_BasicTypes(t *testing.T) {
 	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
 	defer cancel()
-	client, _, cleanup := prepare(ctx, t, singerDBStatements)
+	client, _, cleanup := prepareIntegrationTest(ctx, t, singerDBStatements)
 	defer cleanup()
 
 	t1, _ := time.Parse(time.RFC3339Nano, "2016-11-15T15:04:05.999999999Z")
@@ -1000,10 +1011,10 @@
 }
 
 // Test decoding Cloud Spanner STRUCT type.
-func TestStructTypes(t *testing.T) {
+func TestIntegration_StructTypes(t *testing.T) {
 	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
 	defer cancel()
-	client, _, cleanup := prepare(ctx, t, singerDBStatements)
+	client, _, cleanup := prepareIntegrationTest(ctx, t, singerDBStatements)
 	defer cleanup()
 
 	tests := []struct {
@@ -1085,9 +1096,9 @@
 	}
 }
 
-func TestStructParametersUnsupported(t *testing.T) {
+func TestIntegration_StructParametersUnsupported(t *testing.T) {
 	ctx := context.Background()
-	client, _, cleanup := prepare(ctx, t, nil)
+	client, _, cleanup := prepareIntegrationTest(ctx, t, nil)
 	defer cleanup()
 
 	for _, test := range []struct {
@@ -1127,9 +1138,9 @@
 }
 
 // Test queries of the form "SELECT expr".
-func TestQueryExpressions(t *testing.T) {
+func TestIntegration_QueryExpressions(t *testing.T) {
 	ctx := context.Background()
-	client, _, cleanup := prepare(ctx, t, nil)
+	client, _, cleanup := prepareIntegrationTest(ctx, t, nil)
 	defer cleanup()
 
 	newRow := func(vals []interface{}) *Row {
@@ -1180,9 +1191,9 @@
 	}
 }
 
-func TestQueryStats(t *testing.T) {
+func TestIntegration_QueryStats(t *testing.T) {
 	ctx := context.Background()
-	client, _, cleanup := prepare(ctx, t, singerDBStatements)
+	client, _, cleanup := prepareIntegrationTest(ctx, t, singerDBStatements)
 	defer cleanup()
 
 	accounts := []*Mutation{
@@ -1221,7 +1232,7 @@
 	}
 }
 
-func TestInvalidDatabase(t *testing.T) {
+func TestIntegration_InvalidDatabase(t *testing.T) {
 	if testProjectID == "" {
 		t.Skip("Integration tests skipped: GCLOUD_TESTS_GOLANG_PROJECT_ID is missing")
 	}
@@ -1238,9 +1249,9 @@
 	}
 }
 
-func TestReadErrors(t *testing.T) {
+func TestIntegration_ReadErrors(t *testing.T) {
 	ctx := context.Background()
-	client, _, cleanup := prepare(ctx, t, readDBStatements)
+	client, _, cleanup := prepareIntegrationTest(ctx, t, readDBStatements)
 	defer cleanup()
 
 	// Read over invalid table fails
@@ -1280,10 +1291,10 @@
 }
 
 // Test TransactionRunner. Test that transactions are aborted and retried as expected.
-func TestTransactionRunner(t *testing.T) {
+func TestIntegration_TransactionRunner(t *testing.T) {
 	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
 	defer cancel()
-	client, _, cleanup := prepare(ctx, t, singerDBStatements)
+	client, _, cleanup := prepareIntegrationTest(ctx, t, singerDBStatements)
 	defer cleanup()
 
 	// Test 1: User error should abort the transaction.
@@ -1412,7 +1423,7 @@
 // Test PartitionQuery of BatchReadOnlyTransaction, create partitions then
 // serialize and deserialize both transaction and partition to be used in
 // execution on another client, and compare results.
-func TestBatchQuery(t *testing.T) {
+func TestIntegration_BatchQuery(t *testing.T) {
 	// Set up testing environment.
 	var (
 		client2 *Client
@@ -1420,7 +1431,7 @@
 	)
 	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
 	defer cancel()
-	client, dbPath, cleanup := prepare(ctx, t, simpleDBStatements)
+	client, dbPath, cleanup := prepareIntegrationTest(ctx, t, simpleDBStatements)
 	defer cleanup()
 
 	if err = populate(ctx, client); err != nil {
@@ -1496,7 +1507,7 @@
 }
 
 // Test PartitionRead of BatchReadOnlyTransaction, similar to TestBatchQuery
-func TestBatchRead(t *testing.T) {
+func TestIntegration_BatchRead(t *testing.T) {
 	// Set up testing environment.
 	var (
 		client2 *Client
@@ -1504,7 +1515,7 @@
 	)
 	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
 	defer cancel()
-	client, dbPath, cleanup := prepare(ctx, t, simpleDBStatements)
+	client, dbPath, cleanup := prepareIntegrationTest(ctx, t, simpleDBStatements)
 	defer cleanup()
 
 	if err = populate(ctx, client); err != nil {
@@ -1579,7 +1590,7 @@
 }
 
 // Test normal txReadEnv method on BatchReadOnlyTransaction.
-func TestBROTNormal(t *testing.T) {
+func TestIntegration_BROTNormal(t *testing.T) {
 	// Set up testing environment and create txn.
 	var (
 		txn *BatchReadOnlyTransaction
@@ -1589,7 +1600,7 @@
 	)
 	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
 	defer cancel()
-	client, _, cleanup := prepare(ctx, t, simpleDBStatements)
+	client, _, cleanup := prepareIntegrationTest(ctx, t, simpleDBStatements)
 	defer cleanup()
 
 	if txn, err = client.BatchReadOnlyTransaction(ctx, StrongRead()); err != nil {
@@ -1613,10 +1624,10 @@
 	}
 }
 
-func TestCommitTimestamp(t *testing.T) {
+func TestIntegration_CommitTimestamp(t *testing.T) {
 	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
 	defer cancel()
-	client, _, cleanup := prepare(ctx, t, ctsDBStatements)
+	client, _, cleanup := prepareIntegrationTest(ctx, t, ctsDBStatements)
 	defer cleanup()
 
 	type testTableRow struct {
@@ -1680,10 +1691,10 @@
 	}
 }
 
-func TestDML(t *testing.T) {
+func TestIntegration_DML(t *testing.T) {
 	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
 	defer cancel()
-	client, _, cleanup := prepare(ctx, t, singerDBStatements)
+	client, _, cleanup := prepareIntegrationTest(ctx, t, singerDBStatements)
 	defer cleanup()
 
 	// Function that reads a single row's first name from within a transaction.
@@ -1845,10 +1856,178 @@
 	}
 }
 
-func TestPDML(t *testing.T) {
+func TestIntegration_StructParametersBind(t *testing.T) {
+	t.Parallel()
+	ctx := context.Background()
+	client, _, cleanup := prepareIntegrationTest(ctx, t, nil)
+	defer cleanup()
+
+	type tRow []interface{}
+	type tRows []struct{ trow tRow }
+
+	type allFields struct {
+		Stringf string
+		Intf    int
+		Boolf   bool
+		Floatf  float64
+		Bytef   []byte
+		Timef   time.Time
+		Datef   civil.Date
+	}
+	allColumns := []string{
+		"Stringf",
+		"Intf",
+		"Boolf",
+		"Floatf",
+		"Bytef",
+		"Timef",
+		"Datef",
+	}
+	s1 := allFields{"abc", 300, false, 3.45, []byte("foo"), t1, d1}
+	s2 := allFields{"def", -300, false, -3.45, []byte("bar"), t2, d2}
+
+	dynamicStructType := reflect.StructOf([]reflect.StructField{
+		{Name: "A", Type: reflect.TypeOf(t1), Tag: `spanner:"ff1"`},
+	})
+	s3 := reflect.New(dynamicStructType)
+	s3.Elem().Field(0).Set(reflect.ValueOf(t1))
+
+	for i, test := range []struct {
+		param interface{}
+		sql   string
+		cols  []string
+		trows tRows
+	}{
+		// Struct value.
+		{
+			s1,
+			"SELECT" +
+				" @p.Stringf," +
+				" @p.Intf," +
+				" @p.Boolf," +
+				" @p.Floatf," +
+				" @p.Bytef," +
+				" @p.Timef," +
+				" @p.Datef",
+			allColumns,
+			tRows{
+				{tRow{"abc", 300, false, 3.45, []byte("foo"), t1, d1}},
+			},
+		},
+		// Array of struct value.
+		{
+			[]allFields{s1, s2},
+			"SELECT * FROM UNNEST(@p)",
+			allColumns,
+			tRows{
+				{tRow{"abc", 300, false, 3.45, []byte("foo"), t1, d1}},
+				{tRow{"def", -300, false, -3.45, []byte("bar"), t2, d2}},
+			},
+		},
+		// Null struct.
+		{
+			(*allFields)(nil),
+			"SELECT @p IS NULL",
+			[]string{""},
+			tRows{
+				{tRow{true}},
+			},
+		},
+		// Null Array of struct.
+		{
+			[]allFields(nil),
+			"SELECT @p IS NULL",
+			[]string{""},
+			tRows{
+				{tRow{true}},
+			},
+		},
+		// Empty struct.
+		{
+			struct{}{},
+			"SELECT @p IS NULL ",
+			[]string{""},
+			tRows{
+				{tRow{false}},
+			},
+		},
+		// Empty array of struct.
+		{
+			[]allFields{},
+			"SELECT * FROM UNNEST(@p) ",
+			allColumns,
+			tRows{},
+		},
+		// Struct with duplicate fields.
+		{
+			struct {
+				A int `spanner:"field"`
+				B int `spanner:"field"`
+			}{10, 20},
+			"SELECT * FROM UNNEST([@p]) ",
+			[]string{"field", "field"},
+			tRows{
+				{tRow{10, 20}},
+			},
+		},
+		// Struct with unnamed fields.
+		{
+			struct {
+				A string `spanner:""`
+			}{"hello"},
+			"SELECT * FROM UNNEST([@p]) ",
+			[]string{""},
+			tRows{
+				{tRow{"hello"}},
+			},
+		},
+		// Mixed struct.
+		{
+			struct {
+				DynamicStructField interface{}  `spanner:"f1"`
+				ArrayStructField   []*allFields `spanner:"f2"`
+			}{
+				DynamicStructField: s3.Interface(),
+				ArrayStructField:   []*allFields{nil},
+			},
+			"SELECT @p.f1.ff1, ARRAY_LENGTH(@p.f2), @p.f2[OFFSET(0)] IS NULL ",
+			[]string{"ff1", "", ""},
+			tRows{
+				{tRow{t1, 1, true}},
+			},
+		},
+	} {
+		iter := client.Single().Query(ctx, Statement{
+			SQL:    test.sql,
+			Params: map[string]interface{}{"p": test.param},
+		})
+		var gotRows []*Row
+		err := iter.Do(func(r *Row) error {
+			gotRows = append(gotRows, r)
+			return nil
+		})
+		if err != nil {
+			t.Errorf("Failed to execute test case %d, error: %v", i, err)
+		}
+
+		var wantRows []*Row
+		for j, row := range test.trows {
+			r, err := NewRow(test.cols, row.trow)
+			if err != nil {
+				t.Errorf("Invalid row %d in test case %d", j, i)
+			}
+			wantRows = append(wantRows, r)
+		}
+		if !testEqual(gotRows, wantRows) {
+			t.Errorf("%d: Want result %v, got result %v", i, wantRows, gotRows)
+		}
+	}
+}
+
+func TestIntegration_PDML(t *testing.T) {
 	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
 	defer cancel()
-	client, _, cleanup := prepare(ctx, t, singerDBStatements)
+	client, _, cleanup := prepareIntegrationTest(ctx, t, singerDBStatements)
 	defer cleanup()
 
 	columns := []string{"SingerId", "FirstName", "LastName"}
@@ -1891,7 +2070,7 @@
 }
 
 // Prepare initializes Cloud Spanner testing DB and clients.
-func prepare(ctx context.Context, t *testing.T, statements []string) (*Client, string, func()) {
+func prepareIntegrationTest(ctx context.Context, t *testing.T, statements []string) (*Client, string, func()) {
 	if admin == nil {
 		t.Skip("Integration tests skipped")
 	}
diff --git a/spanner/internal/testutil/funcmock.go b/spanner/internal/testutil/funcmock.go
new file mode 100644
index 0000000..d669851
--- /dev/null
+++ b/spanner/internal/testutil/funcmock.go
@@ -0,0 +1,65 @@
+/*
+Copyright 2018 Google LLC
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+package testutil
+
+import (
+	"context"
+
+	sppb "google.golang.org/genproto/googleapis/spanner/v1"
+	"google.golang.org/grpc"
+)
+
+// FuncMock overloads some of MockCloudSpannerClient's methods with pluggable
+// functions.
+//
+// Note: if you overload a method, you're in charge of making sure
+// MockCloudSpannerClient.ReceivedRequests receives the request appropriately.
+type FuncMock struct {
+	CommitFn           func(c context.Context, r *sppb.CommitRequest, opts ...grpc.CallOption) (*sppb.CommitResponse, error)
+	BeginTransactionFn func(c context.Context, r *sppb.BeginTransactionRequest, opts ...grpc.CallOption) (*sppb.Transaction, error)
+	GetSessionFn       func(c context.Context, r *sppb.GetSessionRequest, opts ...grpc.CallOption) (*sppb.Session, error)
+	CreateSessionFn    func(c context.Context, r *sppb.CreateSessionRequest, opts ...grpc.CallOption) (*sppb.Session, error)
+	*MockCloudSpannerClient
+}
+
+func (s FuncMock) Commit(c context.Context, r *sppb.CommitRequest, opts ...grpc.CallOption) (*sppb.CommitResponse, error) {
+	if s.CommitFn == nil {
+		return s.MockCloudSpannerClient.Commit(c, r, opts...)
+	}
+	return s.CommitFn(c, r, opts...)
+}
+
+func (s FuncMock) BeginTransaction(c context.Context, r *sppb.BeginTransactionRequest, opts ...grpc.CallOption) (*sppb.Transaction, error) {
+	if s.BeginTransactionFn == nil {
+		return s.MockCloudSpannerClient.BeginTransaction(c, r, opts...)
+	}
+	return s.BeginTransactionFn(c, r, opts...)
+}
+
+func (s *FuncMock) GetSession(c context.Context, r *sppb.GetSessionRequest, opts ...grpc.CallOption) (*sppb.Session, error) {
+	if s.GetSessionFn == nil {
+		return s.MockCloudSpannerClient.GetSession(c, r, opts...)
+	}
+	return s.GetSessionFn(c, r, opts...)
+}
+
+func (s *FuncMock) CreateSession(c context.Context, r *sppb.CreateSessionRequest, opts ...grpc.CallOption) (*sppb.Session, error) {
+	if s.CreateSessionFn == nil {
+		return s.MockCloudSpannerClient.CreateSession(c, r, opts...)
+	}
+	return s.CreateSessionFn(c, r, opts...)
+}
diff --git a/spanner/internal/testutil/mockclient.go b/spanner/internal/testutil/mockclient.go
index c2d498a..f808458 100644
--- a/spanner/internal/testutil/mockclient.go
+++ b/spanner/internal/testutil/mockclient.go
@@ -34,12 +34,6 @@
 	"google.golang.org/grpc/status"
 )
 
-// Action is a mocked RPC activity that MockCloudSpannerClient will take.
-type Action struct {
-	Method string
-	Err    error
-}
-
 // MockCloudSpannerClient is a mock implementation of sppb.SpannerClient.
 type MockCloudSpannerClient struct {
 	sppb.SpannerClient
@@ -48,60 +42,32 @@
 	t  *testing.T
 	// Live sessions on the client.
 	sessions map[string]bool
-	// Expected set of actions that will be executed by the client.
-	actions []Action
 	// Session ping history.
 	pings []string
-	// Injected error, will be returned by all APIs.
-	injErr map[string]error
-	// Client will not fail on any request.
-	nice bool
 	// Client will stall on any requests.
 	freezed chan struct{}
+
+	// ReceivedRequests is a channel of the requests received by the client.
+	// Received values should be type asserted against the *Request types in
+	// sppb, e.g. *sppb.GetSessionRequest. Deeply buffered so sends never block.
+	ReceivedRequests chan interface{}
 }
 
 // NewMockCloudSpannerClient creates new MockCloudSpannerClient instance.
-func NewMockCloudSpannerClient(t *testing.T, acts ...Action) *MockCloudSpannerClient {
-	mc := &MockCloudSpannerClient{t: t, sessions: map[string]bool{}, injErr: map[string]error{}}
-	mc.SetActions(acts...)
+func NewMockCloudSpannerClient(t *testing.T) *MockCloudSpannerClient {
+	mc := &MockCloudSpannerClient{
+		t:                t,
+		sessions:         map[string]bool{},
+		ReceivedRequests: make(chan interface{}, 100000),
+	}
+
 	// Produce a closed channel, so the default action of ready is to not block.
 	mc.Freeze()
 	mc.Unfreeze()
+
 	return mc
 }
 
-// MakeNice makes this a nice mock which will not fail on any request.
-func (m *MockCloudSpannerClient) MakeNice() {
-	m.mu.Lock()
-	defer m.mu.Unlock()
-	m.nice = true
-}
-
-// MakeStrict makes this a strict mock which will fail on any unexpected request.
-func (m *MockCloudSpannerClient) MakeStrict() {
-	m.mu.Lock()
-	defer m.mu.Unlock()
-	m.nice = false
-}
-
-// InjectError injects a global error that will be returned by all calls to method
-// regardless of the actions array.
-func (m *MockCloudSpannerClient) InjectError(method string, err error) {
-	m.mu.Lock()
-	defer m.mu.Unlock()
-	m.injErr[method] = err
-}
-
-// SetActions sets the new set of expected actions to MockCloudSpannerClient.
-func (m *MockCloudSpannerClient) SetActions(acts ...Action) {
-	m.mu.Lock()
-	defer m.mu.Unlock()
-	m.actions = nil
-	for _, act := range acts {
-		m.actions = append(m.actions, act)
-	}
-}
-
 // DumpPings dumps the ping history.
 func (m *MockCloudSpannerClient) DumpPings() []string {
 	m.mu.Lock()
@@ -123,11 +89,10 @@
 // CreateSession is a placeholder for SpannerClient.CreateSession.
 func (m *MockCloudSpannerClient) CreateSession(c context.Context, r *sppb.CreateSessionRequest, opts ...grpc.CallOption) (*sppb.Session, error) {
 	m.ready()
+	m.ReceivedRequests <- r
+
 	m.mu.Lock()
 	defer m.mu.Unlock()
-	if err := m.injErr["CreateSession"]; err != nil {
-		return nil, err
-	}
 	s := &sppb.Session{}
 	if r.Database != "mockdb" {
 		// Reject other databases
@@ -142,11 +107,10 @@
 // GetSession is a placeholder for SpannerClient.GetSession.
 func (m *MockCloudSpannerClient) GetSession(c context.Context, r *sppb.GetSessionRequest, opts ...grpc.CallOption) (*sppb.Session, error) {
 	m.ready()
+	m.ReceivedRequests <- r
+
 	m.mu.Lock()
 	defer m.mu.Unlock()
-	if err := m.injErr["GetSession"]; err != nil {
-		return nil, err
-	}
 	m.pings = append(m.pings, r.Name)
 	if _, ok := m.sessions[r.Name]; !ok {
 		return nil, status.Errorf(codes.NotFound, fmt.Sprintf("Session not found: %v", r.Name))
@@ -157,11 +121,10 @@
 // DeleteSession is a placeholder for SpannerClient.DeleteSession.
 func (m *MockCloudSpannerClient) DeleteSession(c context.Context, r *sppb.DeleteSessionRequest, opts ...grpc.CallOption) (*empty.Empty, error) {
 	m.ready()
+	m.ReceivedRequests <- r
+
 	m.mu.Lock()
 	defer m.mu.Unlock()
-	if err := m.injErr["DeleteSession"]; err != nil {
-		return nil, err
-	}
 	if _, ok := m.sessions[r.Name]; !ok {
 		// Session not found.
 		return &empty.Empty{}, status.Errorf(codes.NotFound, fmt.Sprintf("Session not found: %v", r.Name))
@@ -174,12 +137,10 @@
 // ExecuteStreamingSql is a mock implementation of SpannerClient.ExecuteStreamingSql.
 func (m *MockCloudSpannerClient) ExecuteStreamingSql(c context.Context, r *sppb.ExecuteSqlRequest, opts ...grpc.CallOption) (sppb.Spanner_ExecuteStreamingSqlClient, error) {
 	m.ready()
+	m.ReceivedRequests <- r
+
 	m.mu.Lock()
 	defer m.mu.Unlock()
-	act, err := m.expectAction("ExecuteStreamingSql")
-	if err != nil {
-		return nil, err
-	}
 	wantReq := &sppb.ExecuteSqlRequest{
 		Session: "mocksession",
 		Transaction: &sppb.TransactionSelector{
@@ -205,21 +166,16 @@
 	if !proto.Equal(r, wantReq) {
 		return nil, fmt.Errorf("got query request: %v, want: %v", r, wantReq)
 	}
-	if act.Err != nil {
-		return nil, act.Err
-	}
 	return nil, errors.New("query never succeeds on mock client")
 }
 
 // StreamingRead is a placeholder for SpannerClient.StreamingRead.
 func (m *MockCloudSpannerClient) StreamingRead(c context.Context, r *sppb.ReadRequest, opts ...grpc.CallOption) (sppb.Spanner_StreamingReadClient, error) {
 	m.ready()
+	m.ReceivedRequests <- r
+
 	m.mu.Lock()
 	defer m.mu.Unlock()
-	act, err := m.expectAction("StreamingRead", "StreamingReadIndex")
-	if err != nil {
-		return nil, err
-	}
 	wantReq := &sppb.ReadRequest{
 		Session: "mocksession",
 		Transaction: &sppb.TransactionSelector{
@@ -250,32 +206,19 @@
 			All:    false,
 		},
 	}
-	if act.Method == "StreamingIndexRead" {
-		wantReq.Index = "idx1"
-	}
 	if !proto.Equal(r, wantReq) {
 		return nil, fmt.Errorf("got query request: %v, want: %v", r, wantReq)
 	}
-	if act.Err != nil {
-		return nil, act.Err
-	}
 	return nil, errors.New("read never succeeds on mock client")
 }
 
 // BeginTransaction is a placeholder for SpannerClient.BeginTransaction.
 func (m *MockCloudSpannerClient) BeginTransaction(c context.Context, r *sppb.BeginTransactionRequest, opts ...grpc.CallOption) (*sppb.Transaction, error) {
 	m.ready()
+	m.ReceivedRequests <- r
+
 	m.mu.Lock()
 	defer m.mu.Unlock()
-	if !m.nice {
-		act, err := m.expectAction("BeginTransaction")
-		if err != nil {
-			return nil, err
-		}
-		if act.Err != nil {
-			return nil, act.Err
-		}
-	}
 	resp := &sppb.Transaction{Id: []byte("transaction-1")}
 	if _, ok := r.Options.Mode.(*sppb.TransactionOptions_ReadOnly_); ok {
 		resp.ReadTimestamp = &pbt.Timestamp{Seconds: 3, Nanos: 4}
@@ -286,67 +229,37 @@
 // Commit is a placeholder for SpannerClient.Commit.
 func (m *MockCloudSpannerClient) Commit(c context.Context, r *sppb.CommitRequest, opts ...grpc.CallOption) (*sppb.CommitResponse, error) {
 	m.ready()
+	m.ReceivedRequests <- r
+
 	m.mu.Lock()
 	defer m.mu.Unlock()
-	if !m.nice {
-		act, err := m.expectAction("Commit")
-		if err != nil {
-			return nil, err
-		}
-		if act.Err != nil {
-			return nil, act.Err
-		}
-	}
 	return &sppb.CommitResponse{CommitTimestamp: &pbt.Timestamp{Seconds: 1, Nanos: 2}}, nil
 }
 
 // Rollback is a placeholder for SpannerClient.Rollback.
 func (m *MockCloudSpannerClient) Rollback(c context.Context, r *sppb.RollbackRequest, opts ...grpc.CallOption) (*empty.Empty, error) {
 	m.ready()
+	m.ReceivedRequests <- r
+
 	m.mu.Lock()
 	defer m.mu.Unlock()
-	if !m.nice {
-		act, err := m.expectAction("Rollback")
-		if err != nil {
-			return nil, err
-		}
-		if act.Err != nil {
-			return nil, act.Err
-		}
-	}
 	return nil, nil
 }
 
 // PartitionQuery is a placeholder for SpannerServer.PartitionQuery.
 func (m *MockCloudSpannerClient) PartitionQuery(ctx context.Context, r *sppb.PartitionQueryRequest, opts ...grpc.CallOption) (*sppb.PartitionResponse, error) {
 	m.ready()
+	m.ReceivedRequests <- r
+
 	return nil, errors.New("Unimplemented")
 }
 
 // PartitionRead is a placeholder for SpannerServer.PartitionRead.
 func (m *MockCloudSpannerClient) PartitionRead(ctx context.Context, r *sppb.PartitionReadRequest, opts ...grpc.CallOption) (*sppb.PartitionResponse, error) {
 	m.ready()
-	return nil, errors.New("Unimplemented")
-}
+	m.ReceivedRequests <- r
 
-func (m *MockCloudSpannerClient) expectAction(methods ...string) (Action, error) {
-	for _, me := range methods {
-		if err := m.injErr[me]; err != nil {
-			return Action{}, err
-		}
-	}
-	if len(m.actions) == 0 {
-		m.t.Fatalf("unexpected %v executed", methods)
-	}
-	act := m.actions[0]
-	m.actions = m.actions[1:]
-	for _, me := range methods {
-		if me == act.Method {
-			return act, nil
-		}
-	}
-	m.t.Fatalf("unexpected call of one of %v, want method %s", methods, act.Method)
-	return Action{}, nil
+	return nil, errors.New("Unimplemented")
 }
 
 // Freeze stalls all requests.
@@ -363,13 +276,6 @@
 	close(m.freezed)
 }
 
-// CheckActionsConsumed checks that all actions have been consumed.
-func (m *MockCloudSpannerClient) CheckActionsConsumed() {
-	if len(m.actions) != 0 {
-		m.t.Fatalf("unconsumed mock client actions: %v", m.actions)
-	}
-}
-
 // ready checks conditions before executing requests
 // TODO: add checks for injected errors, actions
 func (m *MockCloudSpannerClient) ready() {
diff --git a/spanner/session.go b/spanner/session.go
index d058e57..8d34685 100644
--- a/spanner/session.go
+++ b/spanner/session.go
@@ -1086,7 +1086,7 @@
 	// If a Cloud Spanner can no longer locate the session (for example, if session is garbage collected), then caller
 	// should not try to return the session back into the session pool.
 	// TODO: once gRPC can return auxiliary error information, stop parsing the error message.
-	if ErrCode(err) == codes.NotFound && strings.Contains(ErrDesc(err), "Session not found:") {
+	if ErrCode(err) == codes.NotFound && strings.Contains(ErrDesc(err), "Session not found") {
 		return true
 	}
 	return false
diff --git a/spanner/session_test.go b/spanner/session_test.go
index 932d635..54b906c 100644
--- a/spanner/session_test.go
+++ b/spanner/session_test.go
@@ -23,43 +23,21 @@
 	"fmt"
 	"math/rand"
 	"sync"
+	"sync/atomic"
 	"testing"
 	"time"
 
 	"cloud.google.com/go/spanner/internal/testutil"
 	sppb "google.golang.org/genproto/googleapis/spanner/v1"
+	"google.golang.org/grpc"
 	"google.golang.org/grpc/codes"
 	"google.golang.org/grpc/status"
 )
 
-// setup prepares test environment for regular session pool tests.
-//
-// Note: be sure to call cleanup!
-func setup(t *testing.T, spc SessionPoolConfig) (sp *sessionPool, sc *testutil.MockCloudSpannerClient, cleanup func()) {
-	sc = testutil.NewMockCloudSpannerClient(t)
-	spc.getRPCClient = func() (sppb.SpannerClient, error) {
-		return sc, nil
-	}
-	if spc.HealthCheckInterval == 0 {
-		spc.HealthCheckInterval = 50 * time.Millisecond
-	}
-	if spc.healthCheckSampleInterval == 0 {
-		spc.healthCheckSampleInterval = 10 * time.Millisecond
-	}
-	sp, err := newSessionPool("mockdb", spc, nil)
-	if err != nil {
-		t.Fatalf("cannot create session pool: %v", err)
-	}
-	cleanup = func() {
-		sp.hc.close()
-		sp.close()
-	}
-	return
-}
-
 // TestSessionPoolConfigValidation tests session pool config validation.
 func TestSessionPoolConfigValidation(t *testing.T) {
 	t.Parallel()
+
 	sc := testutil.NewMockCloudSpannerClient(t)
 	for _, test := range []struct {
 		spc SessionPoolConfig
@@ -81,7 +59,7 @@
 		},
 	} {
 		if _, err := newSessionPool("mockdb", test.spc, nil); !testEqual(err, test.err) {
-			t.Errorf("want %v, got %v", test.err, err)
+			t.Fatalf("want %v, got %v", test.err, err)
 		}
 	}
 }
@@ -89,46 +67,47 @@
 // TestSessionCreation tests session creation during sessionPool.Take().
 func TestSessionCreation(t *testing.T) {
 	t.Parallel()
-
-	sp, sc, cleanup := setup(t, SessionPoolConfig{})
+	ctx := context.Background()
+	_, sp, mock, cleanup := serverClientMock(t, SessionPoolConfig{})
 	defer cleanup()
 
-	// Take three sessions from session pool, this should trigger session pool to create three new sessions.
+	// Take three sessions from session pool, this should trigger session pool
+	// to create three new sessions.
 	shs := make([]*sessionHandle, 3)
 	// gotDs holds the unique sessions taken from session pool.
 	gotDs := map[string]bool{}
 	for i := 0; i < len(shs); i++ {
 		var err error
-		shs[i], err = sp.take(context.Background())
+		shs[i], err = sp.take(ctx)
 		if err != nil {
-			t.Errorf("failed to get session(%v): %v", i, err)
+			t.Fatalf("failed to get session(%v): %v", i, err)
 		}
 		gotDs[shs[i].getID()] = true
 	}
 	if len(gotDs) != len(shs) {
-		t.Errorf("session pool created %v sessions, want %v", len(gotDs), len(shs))
+		t.Fatalf("session pool created %v sessions, want %v", len(gotDs), len(shs))
 	}
-	if wantDs := sc.DumpSessions(); !testEqual(gotDs, wantDs) {
-		t.Errorf("session pool creates sessions %v, want %v", gotDs, wantDs)
+	if wantDs := mock.DumpSessions(); !testEqual(gotDs, wantDs) {
+		t.Fatalf("session pool creates sessions %v, want %v", gotDs, wantDs)
 	}
 	// Verify that created sessions are recorded correctly in session pool.
 	sp.mu.Lock()
 	if int(sp.numOpened) != len(shs) {
-		t.Errorf("session pool reports %v open sessions, want %v", sp.numOpened, len(shs))
+		t.Fatalf("session pool reports %v open sessions, want %v", sp.numOpened, len(shs))
 	}
 	if sp.createReqs != 0 {
-		t.Errorf("session pool reports %v session create requests, want 0", int(sp.createReqs))
+		t.Fatalf("session pool reports %v session create requests, want 0", int(sp.createReqs))
 	}
 	sp.mu.Unlock()
 	// Verify that created sessions are tracked correctly by healthcheck queue.
 	hc := sp.hc
 	hc.mu.Lock()
 	if hc.queue.Len() != len(shs) {
-		t.Errorf("healthcheck queue length = %v, want %v", hc.queue.Len(), len(shs))
+		t.Fatalf("healthcheck queue length = %v, want %v", hc.queue.Len(), len(shs))
 	}
 	for _, s := range hc.queue.sessions {
 		if !gotDs[s.getID()] {
-			t.Errorf("session %v is in healthcheck queue, but it is not created by session pool", s.getID())
+			t.Fatalf("session %v is in healthcheck queue, but it is not created by session pool", s.getID())
 		}
 	}
 	hc.mu.Unlock()
@@ -137,253 +116,304 @@
 // TestTakeFromIdleList tests taking sessions from session pool's idle list.
 func TestTakeFromIdleList(t *testing.T) {
 	t.Parallel()
+	ctx := context.Background()
 
-	sp, sc, cleanup := setup(t, SessionPoolConfig{MaxIdle: 10}) // make sure maintainer keeps the idle sessions
+	// Make sure maintainer keeps the idle sessions.
+	_, sp, mock, cleanup := serverClientMock(t, SessionPoolConfig{MaxIdle: 10})
 	defer cleanup()
 
 	// Take ten sessions from session pool and recycle them.
 	shs := make([]*sessionHandle, 10)
 	for i := 0; i < len(shs); i++ {
 		var err error
-		shs[i], err = sp.take(context.Background())
+		shs[i], err = sp.take(ctx)
 		if err != nil {
-			t.Errorf("failed to get session(%v): %v", i, err)
+			t.Fatalf("failed to get session(%v): %v", i, err)
 		}
 	}
-	// Make sure it's sampled once before recycling, otherwise it will be cleaned up.
+	// Make sure it's sampled once before recycling, otherwise it will be
+	// cleaned up.
 	<-time.After(sp.SessionPoolConfig.healthCheckSampleInterval)
 	for i := 0; i < len(shs); i++ {
 		shs[i].recycle()
 	}
-	// Further session requests from session pool won't cause mockclient to create more sessions.
-	wantSessions := sc.DumpSessions()
-	// Take ten sessions from session pool again, this time all sessions should come from idle list.
+	// Further session requests from session pool won't cause mockclient to
+	// create more sessions.
+	wantSessions := mock.DumpSessions()
+	// Take ten sessions from session pool again, this time all sessions should
+	// come from idle list.
 	gotSessions := map[string]bool{}
 	for i := 0; i < len(shs); i++ {
-		sh, err := sp.take(context.Background())
+		sh, err := sp.take(ctx)
 		if err != nil {
-			t.Errorf("cannot take session from session pool: %v", err)
+			t.Fatalf("cannot take session from session pool: %v", err)
 		}
 		gotSessions[sh.getID()] = true
 	}
 	if len(gotSessions) != 10 {
-		t.Errorf("got %v unique sessions, want 10", len(gotSessions))
+		t.Fatalf("got %v unique sessions, want 10", len(gotSessions))
 	}
 	if !testEqual(gotSessions, wantSessions) {
-		t.Errorf("got sessions: %v, want %v", gotSessions, wantSessions)
+		t.Fatalf("got sessions: %v, want %v", gotSessions, wantSessions)
 	}
 }
 
-// TesttakeWriteSessionFromIdleList tests taking write sessions from session pool's idle list.
+// TestTakeWriteSessionFromIdleList tests taking write sessions from session
+// pool's idle list.
 func TestTakeWriteSessionFromIdleList(t *testing.T) {
 	t.Parallel()
+	ctx := context.Background()
 
-	sp, sc, cleanup := setup(t, SessionPoolConfig{MaxIdle: 20}) // make sure maintainer keeps the idle sessions
+	// Make sure maintainer keeps the idle sessions.
+	_, sp, mock, cleanup := serverClientMock(t, SessionPoolConfig{MaxIdle: 20})
 	defer cleanup()
 
-	acts := make([]testutil.Action, 20)
-	for i := 0; i < len(acts); i++ {
-		acts[i] = testutil.Action{"BeginTransaction", nil}
-	}
-	sc.SetActions(acts...)
 	// Take ten sessions from session pool and recycle them.
 	shs := make([]*sessionHandle, 10)
 	for i := 0; i < len(shs); i++ {
 		var err error
-		shs[i], err = sp.takeWriteSession(context.Background())
+		shs[i], err = sp.takeWriteSession(ctx)
 		if err != nil {
-			t.Errorf("failed to get session(%v): %v", i, err)
+			t.Fatalf("failed to get session(%v): %v", i, err)
 		}
 	}
-	// Make sure it's sampled once before recycling, otherwise it will be cleaned up.
+	// Make sure it's sampled once before recycling, otherwise it will be
+	// cleaned up.
 	<-time.After(sp.SessionPoolConfig.healthCheckSampleInterval)
 	for i := 0; i < len(shs); i++ {
 		shs[i].recycle()
 	}
-	// Further session requests from session pool won't cause mockclient to create more sessions.
-	wantSessions := sc.DumpSessions()
-	// Take ten sessions from session pool again, this time all sessions should come from idle list.
+	// Further session requests from session pool won't cause mockclient to
+	// create more sessions.
+	wantSessions := mock.DumpSessions()
+	// Take ten sessions from session pool again, this time all sessions should
+	// come from idle list.
 	gotSessions := map[string]bool{}
 	for i := 0; i < len(shs); i++ {
-		sh, err := sp.takeWriteSession(context.Background())
+		sh, err := sp.takeWriteSession(ctx)
 		if err != nil {
-			t.Errorf("cannot take session from session pool: %v", err)
+			t.Fatalf("cannot take session from session pool: %v", err)
 		}
 		gotSessions[sh.getID()] = true
 	}
 	if len(gotSessions) != 10 {
-		t.Errorf("got %v unique sessions, want 10", len(gotSessions))
+		t.Fatalf("got %v unique sessions, want 10", len(gotSessions))
 	}
 	if !testEqual(gotSessions, wantSessions) {
-		t.Errorf("got sessions: %v, want %v", gotSessions, wantSessions)
+		t.Fatalf("got sessions: %v, want %v", gotSessions, wantSessions)
 	}
 }
 
-// TestTakeFromIdleListChecked tests taking sessions from session pool's idle list, but with a extra ping check.
+// TestTakeFromIdleListChecked tests taking sessions from session pool's idle
+// list, but with an extra ping check.
 func TestTakeFromIdleListChecked(t *testing.T) {
 	t.Parallel()
-	if testing.Short() {
-		t.SkipNow()
-	}
+	ctx := context.Background()
 
-	sp, sc, cleanup := setup(t, SessionPoolConfig{MaxIdle: 1}) // make sure maintainer keeps the idle sessions
+	// Make sure maintainer keeps the idle sessions.
+	_, sp, mock, cleanup := serverClientMock(t, SessionPoolConfig{
+		MaxIdle:                   1,
+		HealthCheckInterval:       50 * time.Millisecond,
+		healthCheckSampleInterval: 10 * time.Millisecond,
+	})
 	defer cleanup()
 
 	// Stop healthcheck workers to simulate slow pings.
 	sp.hc.close()
+
 	// Create a session and recycle it.
-	sh, err := sp.take(context.Background())
+	sh, err := sp.take(ctx)
 	if err != nil {
-		t.Errorf("failed to get session: %v", err)
+		t.Fatalf("failed to get session: %v", err)
 	}
-	// Make sure it's sampled once before recycling, otherwise it will be cleaned up.
+
+	// Make sure it's sampled once before recycling, otherwise it will be
+	// cleaned up.
 	<-time.After(sp.SessionPoolConfig.healthCheckSampleInterval)
 	wantSid := sh.getID()
 	sh.recycle()
+
+	// TODO(deklerk) get rid of this
 	<-time.After(time.Second)
-	// Two back-to-back session requests, both of them should return the same session created before and
-	// none of them should trigger a session ping.
+
+	// Two back-to-back session requests, both of them should return the same
+	// session created before and none of them should trigger a session ping.
 	for i := 0; i < 2; i++ {
 		// Take the session from the idle list and recycle it.
-		sh, err = sp.take(context.Background())
+		sh, err = sp.take(ctx)
 		if err != nil {
-			t.Errorf("%v - failed to get session: %v", i, err)
+			t.Fatalf("%v - failed to get session: %v", i, err)
 		}
 		if gotSid := sh.getID(); gotSid != wantSid {
-			t.Errorf("%v - got session id: %v, want %v", i, gotSid, wantSid)
+			t.Fatalf("%v - got session id: %v, want %v", i, gotSid, wantSid)
 		}
-		// The two back-to-back session requests shouldn't trigger any session pings because sessionPool.Take
+
+		// The two back-to-back session requests shouldn't trigger any session
+		// pings because sessionPool.Take
 		// reschedules the next healthcheck.
-		if got, want := sc.DumpPings(), ([]string{wantSid}); !testEqual(got, want) {
-			t.Errorf("%v - got ping session requests: %v, want %v", i, got, want)
+		if got, want := mock.DumpPings(), ([]string{wantSid}); !testEqual(got, want) {
+			t.Fatalf("%v - got ping session requests: %v, want %v", i, got, want)
 		}
 		sh.recycle()
 	}
-	// Inject session error to mockclient, and take the session from the session pool, the old session should be destroyed and
-	// the session pool will create a new session.
-	sc.InjectError("GetSession", status.Errorf(codes.NotFound, "Session not found:"))
-	// Delay to trigger sessionPool.Take to ping the session.
-	<-time.After(time.Second)
-	sh, err = sp.take(context.Background())
-	if err != nil {
-		t.Errorf("failed to get session: %v", err)
+
+	// Inject session error to server stub, and take the session from the
+	// session pool, the old session should be destroyed and the session pool
+	// will create a new session.
+	mock.GetSessionFn = func(c context.Context, r *sppb.GetSessionRequest, opts ...grpc.CallOption) (*sppb.Session, error) {
+		mock.MockCloudSpannerClient.ReceivedRequests <- r
+		return nil, status.Errorf(codes.NotFound, "Session not found")
 	}
-	ds := sc.DumpSessions()
+
+	// Delay to trigger sessionPool.Take to ping the session.
+	// TODO(deklerk) get rid of this
+	<-time.After(time.Second)
+
+	// take will take the idle session. Then it will send a GetSession request
+	// to check if it's healthy. It'll discover that it's not healthy
+	// (NotFound), drop it, and create a new session.
+	sh, err = sp.take(ctx)
+	if err != nil {
+		t.Fatalf("failed to get session: %v", err)
+	}
+	ds := mock.DumpSessions()
 	if len(ds) != 1 {
-		t.Errorf("dumped sessions from mockclient: %v, want %v", ds, sh.getID())
+		t.Fatalf("dumped sessions from mockclient: %v, want %v", ds, sh.getID())
 	}
 	if sh.getID() == wantSid {
-		t.Errorf("sessionPool.Take still returns the same session %v, want it to create a new one", wantSid)
+		t.Fatalf("sessionPool.Take still returns the same session %v, want it to create a new one", wantSid)
 	}
 }
 
-// TestTakeFromIdleWriteListChecked tests taking sessions from session pool's idle list, but with a extra ping check.
+// TestTakeFromIdleWriteListChecked tests taking sessions from session pool's
+// idle list, but with an extra ping check.
 func TestTakeFromIdleWriteListChecked(t *testing.T) {
 	t.Parallel()
-	if testing.Short() {
-		t.SkipNow()
-	}
+	ctx := context.Background()
 
-	sp, sc, cleanup := setup(t, SessionPoolConfig{MaxIdle: 1}) // make sure maintainer keeps the idle sessions
+	// Make sure maintainer keeps the idle sessions.
+	_, sp, mock, cleanup := serverClientMock(t, SessionPoolConfig{
+		MaxIdle:                   1,
+		HealthCheckInterval:       50 * time.Millisecond,
+		healthCheckSampleInterval: 10 * time.Millisecond,
+	})
 	defer cleanup()
 
-	sc.MakeNice()
 	// Stop healthcheck workers to simulate slow pings.
 	sp.hc.close()
+
 	// Create a session and recycle it.
-	sh, err := sp.takeWriteSession(context.Background())
+	sh, err := sp.takeWriteSession(ctx)
 	if err != nil {
-		t.Errorf("failed to get session: %v", err)
+		t.Fatalf("failed to get session: %v", err)
 	}
 	wantSid := sh.getID()
-	// Make sure it's sampled once before recycling, otherwise it will be cleaned up.
+
+	// Make sure it's sampled once before recycling, otherwise it will be
+	// cleaned up.
 	<-time.After(sp.SessionPoolConfig.healthCheckSampleInterval)
 	sh.recycle()
+
+	// TODO(deklerk) get rid of this
 	<-time.After(time.Second)
-	// Two back-to-back session requests, both of them should return the same session created before and
-	// none of them should trigger a session ping.
+
+	// Two back-to-back session requests, both of them should return the same
+	// session created before and none of them should trigger a session ping.
 	for i := 0; i < 2; i++ {
 		// Take the session from the idle list and recycle it.
-		sh, err = sp.takeWriteSession(context.Background())
+		sh, err = sp.takeWriteSession(ctx)
 		if err != nil {
-			t.Errorf("%v - failed to get session: %v", i, err)
+			t.Fatalf("%v - failed to get session: %v", i, err)
 		}
 		if gotSid := sh.getID(); gotSid != wantSid {
-			t.Errorf("%v - got session id: %v, want %v", i, gotSid, wantSid)
+			t.Fatalf("%v - got session id: %v, want %v", i, gotSid, wantSid)
 		}
-		// The two back-to-back session requests shouldn't trigger any session pings because sessionPool.Take
-		// reschedules the next healthcheck.
-		if got, want := sc.DumpPings(), ([]string{wantSid}); !testEqual(got, want) {
-			t.Errorf("%v - got ping session requests: %v, want %v", i, got, want)
+		// The two back-to-back session requests shouldn't trigger any session
+		// pings because sessionPool.Take reschedules the next healthcheck.
+		if got, want := mock.DumpPings(), ([]string{wantSid}); !testEqual(got, want) {
+			t.Fatalf("%v - got ping session requests: %v, want %v", i, got, want)
 		}
 		sh.recycle()
 	}
-	// Inject session error to mockclient, and take the session from the session pool, the old session should be destroyed and
-	// the session pool will create a new session.
-	sc.InjectError("GetSession", status.Errorf(codes.NotFound, "Session not found:"))
-	// Delay to trigger sessionPool.Take to ping the session.
-	<-time.After(time.Second)
-	sh, err = sp.takeWriteSession(context.Background())
-	if err != nil {
-		t.Errorf("failed to get session: %v", err)
+
+	// Inject session error to mockclient, and take the session from the
+	// session pool, the old session should be destroyed and the session pool
+	// will create a new session.
+	mock.GetSessionFn = func(c context.Context, r *sppb.GetSessionRequest, opts ...grpc.CallOption) (*sppb.Session, error) {
+		mock.MockCloudSpannerClient.ReceivedRequests <- r
+		return nil, status.Errorf(codes.NotFound, "Session not found")
 	}
-	ds := sc.DumpSessions()
+
+	// Delay to trigger sessionPool.Take to ping the session.
+	// TODO(deklerk) get rid of this
+	<-time.After(time.Second)
+
+	sh, err = sp.takeWriteSession(ctx)
+	if err != nil {
+		t.Fatalf("failed to get session: %v", err)
+	}
+	ds := mock.DumpSessions()
 	if len(ds) != 1 {
-		t.Errorf("dumped sessions from mockclient: %v, want %v", ds, sh.getID())
+		t.Fatalf("dumped sessions from mockclient: %v, want %v", ds, sh.getID())
 	}
 	if sh.getID() == wantSid {
-		t.Errorf("sessionPool.Take still returns the same session %v, want it to create a new one", wantSid)
+		t.Fatalf("sessionPool.Take still returns the same session %v, want it to create a new one", wantSid)
 	}
 }
 
 // TestMaxOpenedSessions tests max open sessions constraint.
 func TestMaxOpenedSessions(t *testing.T) {
 	t.Parallel()
-	if testing.Short() {
-		t.SkipNow()
-	}
-
-	sp, _, cleanup := setup(t, SessionPoolConfig{MaxOpened: 1})
+	ctx := context.Background()
+	_, sp, _, cleanup := serverClientMock(t, SessionPoolConfig{MaxOpened: 1})
 	defer cleanup()
 
-	sh1, err := sp.take(context.Background())
+	sh1, err := sp.take(ctx)
 	if err != nil {
-		t.Errorf("cannot take session from session pool: %v", err)
+		t.Fatalf("cannot take session from session pool: %v", err)
 	}
-	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
-	defer cancel()
+
 	// Session request will timeout due to the max open sessions constraint.
-	_, gotErr := sp.take(ctx)
+	ctx2, cancel := context.WithTimeout(ctx, time.Second)
+	defer cancel()
+	_, gotErr := sp.take(ctx2)
 	if wantErr := errGetSessionTimeout(); !testEqual(gotErr, wantErr) {
-		t.Errorf("the second session retrival returns error %v, want %v", gotErr, wantErr)
+		t.Fatalf("the second session retrieval returns error %v, want %v", gotErr, wantErr)
 	}
+
 	go func() {
+		// TODO(deklerk) remove this
 		<-time.After(time.Second)
-		// destroy the first session to allow the next session request to proceed.
+		// Destroy the first session to allow the next session request to
+		// proceed.
 		sh1.destroy()
 	}()
-	// Now session request can be processed because the first session will be destroyed.
-	sh2, err := sp.take(context.Background())
+
+	// Now session request can be processed because the first session will be
+	// destroyed.
+	sh2, err := sp.take(ctx)
 	if err != nil {
-		t.Errorf("after the first session is destroyed, session retrival still returns error %v, want nil", err)
+		t.Fatalf("after the first session is destroyed, session retrieval still returns error %v, want nil", err)
 	}
 	if !sh2.session.isValid() || sh2.getID() == "" {
-		t.Errorf("got invalid session: %v", sh2.session)
+		t.Fatalf("got invalid session: %v", sh2.session)
 	}
 }
 
 // TestMinOpenedSessions tests min open session constraint.
 func TestMinOpenedSessions(t *testing.T) {
-	sp, _, cleanup := setup(t, SessionPoolConfig{MinOpened: 1})
+	t.Parallel()
+	ctx := context.Background()
+	_, sp, _, cleanup := serverClientMock(t, SessionPoolConfig{MinOpened: 1})
 	defer cleanup()
 
 	// Take ten sessions from session pool and recycle them.
 	var ss []*session
 	var shs []*sessionHandle
 	for i := 0; i < 10; i++ {
-		sh, err := sp.take(context.Background())
+		sh, err := sp.take(ctx)
 		if err != nil {
-			t.Errorf("failed to get session(%v): %v", i, err)
+			t.Fatalf("failed to get session(%v): %v", i, err)
 		}
 		ss = append(ss, sh.session)
 		shs = append(shs, sh)
@@ -392,32 +422,44 @@
 	for _, sh := range shs {
 		sh.recycle()
 	}
+
 	// Simulate session expiration.
 	for _, s := range ss {
 		s.destroy(true)
 	}
+
 	sp.mu.Lock()
 	defer sp.mu.Unlock()
-	// There should be still one session left in idle list due to the min open sessions constraint.
+	// There should be still one session left in idle list due to the min open
+	// sessions constraint.
 	if sp.idleList.Len() != 1 {
-		t.Errorf("got %v sessions in idle list, want 1 %d", sp.idleList.Len(), sp.numOpened)
+		t.Fatalf("got %v sessions in idle list, want 1 (numOpened: %d)", sp.idleList.Len(), sp.numOpened)
 	}
 }
 
 // TestMaxBurst tests max burst constraint.
 func TestMaxBurst(t *testing.T) {
 	t.Parallel()
-	if testing.Short() {
-		t.SkipNow()
-	}
-
-	sp, sc, cleanup := setup(t, SessionPoolConfig{MaxBurst: 1})
+	ctx := context.Background()
+	_, sp, mock, cleanup := serverClientMock(t, SessionPoolConfig{MaxBurst: 1})
 	defer cleanup()
 
 	// Will cause session creation RPC to be retried forever.
-	sc.InjectError("CreateSession", status.Errorf(codes.Unavailable, "try later"))
-	// This session request will never finish until the injected error is cleared.
-	go sp.take(context.Background())
+	allowRequests := make(chan struct{})
+	mock.CreateSessionFn = func(c context.Context, r *sppb.CreateSessionRequest, opts ...grpc.CallOption) (*sppb.Session, error) {
+		select {
+		case <-allowRequests:
+			return mock.MockCloudSpannerClient.CreateSession(c, r, opts...)
+		default:
+			mock.MockCloudSpannerClient.ReceivedRequests <- r
+			return nil, status.Errorf(codes.Unavailable, "try later")
+		}
+	}
+
+	// This session request will never finish until the injected error is
+	// cleared.
+	go sp.take(ctx)
+
 	// Poll for the execution of the first session request.
 	for {
 		sp.mu.Lock()
@@ -430,46 +472,46 @@
 		// The first session request is being executed.
 		break
 	}
-	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
+
+	ctx2, cancel := context.WithTimeout(ctx, time.Second)
 	defer cancel()
-	_, gotErr := sp.take(ctx)
+	_, gotErr := sp.take(ctx2)
+
 	// Since MaxBurst == 1, the second session request should block.
 	if wantErr := errGetSessionTimeout(); !testEqual(gotErr, wantErr) {
-		t.Errorf("session retrival returns error %v, want %v", gotErr, wantErr)
+		t.Fatalf("session retrieval returns error %v, want %v", gotErr, wantErr)
 	}
+
 	// Let the first session request succeed.
-	sc.InjectError("CreateSession", nil)
+	close(allowRequests)
+
 	// Now new session request can proceed because the first session request will eventually succeed.
-	sh, err := sp.take(context.Background())
+	sh, err := sp.take(ctx)
 	if err != nil {
-		t.Errorf("session retrival returns error %v, want nil", err)
+		t.Fatalf("session retrieval returns error %v, want nil", err)
 	}
 	if !sh.session.isValid() || sh.getID() == "" {
-		t.Errorf("got invalid session: %v", sh.session)
+		t.Fatalf("got invalid session: %v", sh.session)
 	}
 }
 
-// TestSessionrecycle tests recycling sessions.
+// TestSessionRecycle tests recycling sessions.
 func TestSessionRecycle(t *testing.T) {
 	t.Parallel()
-	if testing.Short() {
-		t.SkipNow()
-	}
-
-	// Set MaxIdle to ensure shs[0] is not destroyed from scale down.
-	sp, _, cleanup := setup(t, SessionPoolConfig{MinOpened: 1, MaxIdle: 2})
+	ctx := context.Background()
+	_, sp, _, cleanup := serverClientMock(t, SessionPoolConfig{MinOpened: 1, MaxIdle: 2})
 	defer cleanup()
 
 	// Test session is correctly recycled and reused.
 	for i := 0; i < 20; i++ {
-		s, err := sp.take(context.Background())
+		s, err := sp.take(ctx)
 		if err != nil {
-			t.Errorf("cannot get the session %v: %v", i, err)
+			t.Fatalf("cannot get the session %v: %v", i, err)
 		}
 		s.recycle()
 	}
 	if sp.numOpened != 1 {
-		t.Errorf("Expect session pool size %d, got %d", 1, sp.numOpened)
+		t.Fatalf("Expect session pool size %d, got %d", 1, sp.numOpened)
 	}
 }
 
@@ -478,24 +520,24 @@
 func TestSessionDestroy(t *testing.T) {
 	t.Skip("s.destroy(true) is flakey")
 	t.Parallel()
-
-	sp, _, cleanup := setup(t, SessionPoolConfig{MinOpened: 1})
+	ctx := context.Background()
+	_, sp, _, cleanup := serverClientMock(t, SessionPoolConfig{MinOpened: 1})
 	defer cleanup()
 
 	<-time.After(10 * time.Millisecond) // maintainer will create one session, we wait for it create session to avoid flakiness in test
-	sh, err := sp.take(context.Background())
+	sh, err := sp.take(ctx)
 	if err != nil {
-		t.Errorf("cannot get session from session pool: %v", err)
+		t.Fatalf("cannot get session from session pool: %v", err)
 	}
 	s := sh.session
 	sh.recycle()
 	if d := s.destroy(true); d || !s.isValid() {
 		// Session should be remaining because of min open sessions constraint.
-		t.Errorf("session %s invalid, want it to stay alive. (destroy in expiration mode, success: %v)", s.id, d)
+		t.Fatalf("session %s invalid, want it to stay alive. (destroy in expiration mode, success: %v)", s.id, d)
 	}
 	if d := s.destroy(false); !d || s.isValid() {
 		// Session should be destroyed.
-		t.Errorf("failed to destroy session %s. (destroy in default mode, success: %v)", s.id, d)
+		t.Fatalf("failed to destroy session %s. (destroy in default mode, success: %v)", s.id, d)
 	}
 }
 
@@ -526,7 +568,7 @@
 		got := heap.Pop(&hh).(*session)
 		want[idx].hcIndex = -1
 		if !testEqual(got, want[idx]) {
-			t.Errorf("%v: heap.Pop returns %v, want %v", idx, got, want[idx])
+			t.Fatalf("%v: heap.Pop returns %v, want %v", idx, got, want[idx])
 		}
 	}
 }
@@ -535,17 +577,17 @@
 // perform healthchecks properly.
 func TestHealthCheckScheduler(t *testing.T) {
 	t.Parallel()
-	if testing.Short() {
-		t.SkipNow()
-	}
-
-	sp, sc, cleanup := setup(t, SessionPoolConfig{})
+	ctx := context.Background()
+	_, sp, mock, cleanup := serverClientMock(t, SessionPoolConfig{
+		HealthCheckInterval:       50 * time.Millisecond,
+		healthCheckSampleInterval: 10 * time.Millisecond,
+	})
 	defer cleanup()
 
 	// Create 50 sessions.
 	ss := []string{}
 	for i := 0; i < 50; i++ {
-		sh, err := sp.take(context.Background())
+		sh, err := sp.take(ctx)
 		if err != nil {
 			t.Fatalf("cannot get session from session pool: %v", err)
 		}
@@ -554,7 +596,7 @@
 
 	// Wait for 10-30 pings per session.
 	waitFor(t, func() error {
-		dp := sc.DumpPings()
+		dp := mock.DumpPings()
 		gotPings := map[string]int64{}
 		for _, p := range dp {
 			gotPings[p]++
@@ -572,87 +614,92 @@
 
 // Tests that a fractions of sessions are prepared for write by health checker.
 func TestWriteSessionsPrepared(t *testing.T) {
-	if testing.Short() {
-		t.SkipNow()
-	}
-
-	sp, sc, cleanup := setup(t, SessionPoolConfig{WriteSessions: 0.5, MaxIdle: 20})
+	t.Parallel()
+	ctx := context.Background()
+	_, sp, _, cleanup := serverClientMock(t, SessionPoolConfig{WriteSessions: 0.5, MaxIdle: 20})
 	defer cleanup()
 
-	sc.MakeNice()
 	shs := make([]*sessionHandle, 10)
 	var err error
 	for i := 0; i < 10; i++ {
-		shs[i], err = sp.take(context.Background())
+		shs[i], err = sp.take(ctx)
 		if err != nil {
-			t.Errorf("cannot get session from session pool: %v", err)
+			t.Fatalf("cannot get session from session pool: %v", err)
 		}
 	}
 	// Now there are 10 sessions in the pool. Release them.
 	for _, sh := range shs {
 		sh.recycle()
 	}
+
 	// Sleep for 1s, allowing healthcheck workers to invoke begin transaction.
+	// TODO(deklerk) get rid of this
 	<-time.After(time.Second)
 	wshs := make([]*sessionHandle, 5)
 	for i := 0; i < 5; i++ {
-		wshs[i], err = sp.takeWriteSession(context.Background())
+		wshs[i], err = sp.takeWriteSession(ctx)
 		if err != nil {
-			t.Errorf("cannot get session from session pool: %v", err)
+			t.Fatalf("cannot get session from session pool: %v", err)
 		}
 		if wshs[i].getTransactionID() == nil {
-			t.Errorf("got nil transaction id from session pool")
+			t.Fatalf("got nil transaction id from session pool")
 		}
 	}
 	for _, sh := range wshs {
 		sh.recycle()
 	}
+
+	// TODO(deklerk) get rid of this
 	<-time.After(time.Second)
+
 	// Now force creation of 10 more sessions.
 	shs = make([]*sessionHandle, 20)
 	for i := 0; i < 20; i++ {
-		shs[i], err = sp.take(context.Background())
+		shs[i], err = sp.take(ctx)
 		if err != nil {
-			t.Errorf("cannot get session from session pool: %v", err)
+			t.Fatalf("cannot get session from session pool: %v", err)
 		}
 	}
+
 	// Now there are 20 sessions in the pool. Release them.
 	for _, sh := range shs {
 		sh.recycle()
 	}
+
+	// TODO(deklerk) get rid of this
 	<-time.After(time.Second)
+
 	if sp.idleWriteList.Len() != 10 {
-		t.Errorf("Expect 10 write prepared session, got: %d", sp.idleWriteList.Len())
+		t.Fatalf("Expect 10 write prepared session, got: %d", sp.idleWriteList.Len())
 	}
 }
 
 // TestTakeFromWriteQueue tests that sessionPool.take() returns write prepared sessions as well.
 func TestTakeFromWriteQueue(t *testing.T) {
 	t.Parallel()
-	if testing.Short() {
-		t.SkipNow()
-	}
-
-	sp, sc, cleanup := setup(t, SessionPoolConfig{MaxOpened: 1, WriteSessions: 1.0, MaxIdle: 1})
+	ctx := context.Background()
+	_, sp, _, cleanup := serverClientMock(t, SessionPoolConfig{MaxOpened: 1, WriteSessions: 1.0, MaxIdle: 1})
 	defer cleanup()
 
-	sc.MakeNice()
-	sh, err := sp.take(context.Background())
+	sh, err := sp.take(ctx)
 	if err != nil {
-		t.Errorf("cannot get session from session pool: %v", err)
+		t.Fatalf("cannot get session from session pool: %v", err)
 	}
 	sh.recycle()
+
+	// TODO(deklerk) get rid of this
 	<-time.After(time.Second)
+
 	// The session should now be in write queue but take should also return it.
 	if sp.idleWriteList.Len() == 0 {
-		t.Errorf("write queue unexpectedly empty")
+		t.Fatalf("write queue unexpectedly empty")
 	}
 	if sp.idleList.Len() != 0 {
-		t.Errorf("read queue not empty")
+		t.Fatalf("read queue not empty")
 	}
-	sh, err = sp.take(context.Background())
+	sh, err = sp.take(ctx)
 	if err != nil {
-		t.Errorf("cannot get session from session pool: %v", err)
+		t.Fatalf("cannot get session from session pool: %v", err)
 	}
 	sh.recycle()
 }
@@ -660,47 +707,63 @@
 // TestSessionHealthCheck tests healthchecking cases.
 func TestSessionHealthCheck(t *testing.T) {
 	t.Parallel()
-	if testing.Short() {
-		t.SkipNow()
-	}
-
-	sp, sc, cleanup := setup(t, SessionPoolConfig{})
+	ctx := context.Background()
+	_, sp, mock, cleanup := serverClientMock(t, SessionPoolConfig{
+		HealthCheckInterval:       50 * time.Millisecond,
+		healthCheckSampleInterval: 10 * time.Millisecond,
+	})
 	defer cleanup()
 
-	// Test pinging sessions.
-	sh, err := sp.take(context.Background())
-	if err != nil {
-		t.Errorf("cannot get session from session pool: %v", err)
+	var requestShouldErr int64 // 0 == false, 1 == true
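+	// requestShouldErr is read and written atomically because GetSessionFn is
+	// invoked from healthcheck worker goroutines.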
+	mock.GetSessionFn = func(c context.Context, r *sppb.GetSessionRequest, opts ...grpc.CallOption) (*sppb.Session, error) {
+		if shouldErr := atomic.LoadInt64(&requestShouldErr); shouldErr == 1 {
+			mock.MockCloudSpannerClient.ReceivedRequests <- r
+			return nil, status.Errorf(codes.NotFound, "Session not found")
+		}
+		return mock.MockCloudSpannerClient.GetSession(c, r, opts...)
 	}
+
+	// Test pinging sessions.
+	sh, err := sp.take(ctx)
+	if err != nil {
+		t.Fatalf("cannot get session from session pool: %v", err)
+	}
+
 	// Wait for healthchecker to send pings to session.
 	waitFor(t, func() error {
-		pings := sc.DumpPings()
+		pings := mock.DumpPings()
 		if len(pings) == 0 || pings[0] != sh.getID() {
 			return fmt.Errorf("healthchecker didn't send any ping to session %v", sh.getID())
 		}
 		return nil
 	})
 	// Test broken session detection.
-	sh, err = sp.take(context.Background())
+	sh, err = sp.take(ctx)
 	if err != nil {
-		t.Errorf("cannot get session from session pool: %v", err)
+		t.Fatalf("cannot get session from session pool: %v", err)
 	}
-	sc.InjectError("GetSession", status.Errorf(codes.NotFound, "Session not found:"))
+
+	atomic.SwapInt64(&requestShouldErr, 1)
+
 	// Wait for healthcheck workers to find the broken session and tear it down.
+	// TODO(deklerk) get rid of this
 	<-time.After(1 * time.Second)
+
 	s := sh.session
 	if sh.session.isValid() {
-		t.Errorf("session(%v) is still alive, want it to be dropped by healthcheck workers", s)
+		t.Fatalf("session(%v) is still alive, want it to be dropped by healthcheck workers", s)
 	}
-	sc.InjectError("GetSession", nil)
+
+	atomic.SwapInt64(&requestShouldErr, 0)
+
 	// Test garbage collection.
-	sh, err = sp.take(context.Background())
+	sh, err = sp.take(ctx)
 	if err != nil {
-		t.Errorf("cannot get session from session pool: %v", err)
+		t.Fatalf("cannot get session from session pool: %v", err)
 	}
 	sp.close()
 	if sh.session.isValid() {
-		t.Errorf("session(%v) is still alive, want it to be garbage collected", s)
+		t.Fatalf("session(%v) is still alive, want it to be garbage collected", s)
 	}
 }
 
@@ -715,13 +778,11 @@
 // and it is expected that all sessions that are taken from session pool remains valid.
 // When all test workers and healthcheck workers exit, mockclient, session pool
 // and healthchecker should be in consistent state.
-
 func TestStressSessionPool(t *testing.T) {
 	t.Parallel()
+	ctx := context.Background()
+
 	// Use concurrent workers to test different session pool built from different configurations.
-	if testing.Short() {
-		t.SkipNow()
-	}
 	for ti, cfg := range []SessionPoolConfig{
 		{},
 		{MinOpened: 10, MaxOpened: 100},
@@ -735,7 +796,6 @@
 		cfg.healthCheckSampleInterval = 10 * time.Millisecond
 		cfg.HealthCheckWorkers = 50
 		sc := testutil.NewMockCloudSpannerClient(t)
-		sc.MakeNice()
 		cfg.getRPCClient = func() (sppb.SpannerClient, error) {
 			return sc, nil
 		}
@@ -761,9 +821,9 @@
 						gotErr error
 					)
 					if takeWrite {
-						sh, gotErr = pool.takeWriteSession(context.Background())
+						sh, gotErr = pool.takeWriteSession(ctx)
 					} else {
-						sh, gotErr = pool.take(context.Background())
+						sh, gotErr = pool.take(ctx)
 					}
 					if gotErr != nil {
 						if pool.isValid() {
@@ -806,28 +866,28 @@
 		for sl := sp.idleList.Front(); sl != nil; sl = sl.Next() {
 			s := sl.Value.(*session)
 			if idleSessions[s.getID()] {
-				t.Errorf("%v: found duplicated session in idle list: %v", ti, s.getID())
+				t.Fatalf("%v: found duplicated session in idle list: %v", ti, s.getID())
 			}
 			idleSessions[s.getID()] = true
 		}
 		for sl := sp.idleWriteList.Front(); sl != nil; sl = sl.Next() {
 			s := sl.Value.(*session)
 			if idleSessions[s.getID()] {
-				t.Errorf("%v: found duplicated session in idle write list: %v", ti, s.getID())
+				t.Fatalf("%v: found duplicated session in idle write list: %v", ti, s.getID())
 			}
 			idleSessions[s.getID()] = true
 		}
 		sp.mu.Lock()
 		if int(sp.numOpened) != len(idleSessions) {
-			t.Errorf("%v: number of opened sessions (%v) != number of idle sessions (%v)", ti, sp.numOpened, len(idleSessions))
+			t.Fatalf("%v: number of opened sessions (%v) != number of idle sessions (%v)", ti, sp.numOpened, len(idleSessions))
 		}
 		if sp.createReqs != 0 {
-			t.Errorf("%v: number of pending session creations = %v, want 0", ti, sp.createReqs)
+			t.Fatalf("%v: number of pending session creations = %v, want 0", ti, sp.createReqs)
 		}
 		// Dump healthcheck queue.
 		for _, s := range sp.hc.queue.sessions {
 			if hcSessions[s.getID()] {
-				t.Errorf("%v: found duplicated session in healthcheck queue: %v", ti, s.getID())
+				t.Fatalf("%v: found duplicated session in healthcheck queue: %v", ti, s.getID())
 			}
 			hcSessions[s.getID()] = true
 		}
@@ -835,15 +895,15 @@
 
 		// Verify that idleSessions == hcSessions == mockSessions.
 		if !testEqual(idleSessions, hcSessions) {
-			t.Errorf("%v: sessions in idle list (%v) != sessions in healthcheck queue (%v)", ti, idleSessions, hcSessions)
+			t.Fatalf("%v: sessions in idle list (%v) != sessions in healthcheck queue (%v)", ti, idleSessions, hcSessions)
 		}
 		if !testEqual(hcSessions, mockSessions) {
-			t.Errorf("%v: sessions in healthcheck queue (%v) != sessions in mockclient (%v)", ti, hcSessions, mockSessions)
+			t.Fatalf("%v: sessions in healthcheck queue (%v) != sessions in mockclient (%v)", ti, hcSessions, mockSessions)
 		}
 		sp.close()
 		mockSessions = sc.DumpSessions()
 		if len(mockSessions) != 0 {
-			t.Errorf("Found live sessions: %v", mockSessions)
+			t.Fatalf("Found live sessions: %v", mockSessions)
 		}
 	}
 }
@@ -858,15 +918,11 @@
 func TestMaintainer(t *testing.T) {
 	t.Skip("asserting session state seems flakey")
 	t.Parallel()
-	if testing.Short() {
-		t.SkipNow()
-	}
-	var (
-		minOpened uint64 = 5
-		maxIdle   uint64 = 4
-	)
+	ctx := context.Background()
 
-	sp, _, cleanup := setup(t, SessionPoolConfig{MinOpened: minOpened, MaxIdle: maxIdle})
+	minOpened := uint64(5)
+	maxIdle := uint64(4)
+	_, sp, _, cleanup := serverClientMock(t, SessionPoolConfig{MinOpened: minOpened, MaxIdle: maxIdle})
 	defer cleanup()
 
 	sampleInterval := sp.SessionPoolConfig.healthCheckSampleInterval
@@ -886,14 +942,14 @@
 	shs := make([]*sessionHandle, 10)
 	for i := 0; i < len(shs); i++ {
 		var err error
-		shs[i], err = sp.take(context.Background())
+		shs[i], err = sp.take(ctx)
 		if err != nil {
-			t.Errorf("cannot get session from session pool: %v", err)
+			t.Fatalf("cannot get session from session pool: %v", err)
 		}
 	}
 	sp.mu.Lock()
 	if sp.numOpened != 10 {
-		t.Errorf("Scale out from normal use. Expect %d open, got %d", 10, sp.numOpened)
+		t.Fatalf("Scale out from normal use. Expect %d open, got %d", 10, sp.numOpened)
 	}
 	sp.mu.Unlock()
 
@@ -939,6 +995,7 @@
 }
 
 func waitFor(t *testing.T, assert func() error) {
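+	// Mark waitFor as a test helper so failures are reported at the caller's line.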
+	t.Helper()
 	timeout := 15 * time.Second
 	ta := time.After(timeout)
 
diff --git a/spanner/transaction_test.go b/spanner/transaction_test.go
index 9ac6c76..a710ecf 100644
--- a/spanner/transaction_test.go
+++ b/spanner/transaction_test.go
@@ -19,82 +19,100 @@
 import (
 	"context"
 	"errors"
+	"fmt"
+	"reflect"
 	"sync"
 	"testing"
 	"time"
 
 	"cloud.google.com/go/spanner/internal/testutil"
 	sppb "google.golang.org/genproto/googleapis/spanner/v1"
+	"google.golang.org/grpc"
 	"google.golang.org/grpc/codes"
 )
 
-var (
-	errAbrt = spannerErrorf(codes.Aborted, "")
-	errUsr  = errors.New("error")
-)
+// Single can only be used once.
+func TestSingle(t *testing.T) {
+	t.Parallel()
+	ctx := context.Background()
+	client, _, mock, cleanup := serverClientMock(t, SessionPoolConfig{})
+	defer cleanup()
 
-// setup sets up a Client using mockclient
-func mockClient(t *testing.T) (*sessionPool, *testutil.MockCloudSpannerClient, *Client) {
-	var (
-		mc       = testutil.NewMockCloudSpannerClient(t)
-		spc      = SessionPoolConfig{}
-		database = "mockdb"
-	)
-	spc.getRPCClient = func() (sppb.SpannerClient, error) {
-		return mc, nil
+	txn := client.Single()
+	defer txn.Close()
+	_, _, e := txn.acquire(ctx)
+	if e != nil {
+		t.Fatalf("Acquire for single use, got %v, want nil.", e)
 	}
-	sp, err := newSessionPool(database, spc, nil)
-	if err != nil {
-		t.Fatalf("cannot create session pool: %v", err)
+	_, _, e = txn.acquire(ctx)
+	if wantErr := errTxClosed(); !testEqual(e, wantErr) {
+		t.Fatalf("Second acquire for single use, got %v, want %v.", e, wantErr)
 	}
-	return sp, mc, &Client{
-		database:     database,
-		idleSessions: sp,
+
+	// Only one CreateSessionRequest is sent.
+	if err := shouldHaveReceived(mock, []interface{}{&sppb.CreateSessionRequest{}}); err != nil {
+		t.Fatal(err)
 	}
 }
 
-// TestReadOnlyAcquire tests acquire for ReadOnlyTransaction.
-func TestReadOnlyAcquire(t *testing.T) {
+// Re-using ReadOnlyTransaction: can recover from acquire failure.
+func TestReadOnlyTransaction_RecoverFromFailure(t *testing.T) {
 	t.Parallel()
-	_, mc, client := mockClient(t)
-	defer client.Close()
-	mc.SetActions(
-		testutil.Action{"BeginTransaction", errUsr},
-		testutil.Action{"BeginTransaction", nil},
-		testutil.Action{"BeginTransaction", nil},
-	)
+	ctx := context.Background()
+	client, _, mock, cleanup := serverClientMock(t, SessionPoolConfig{})
+	defer cleanup()
 
-	// Singleuse should only be used once.
-	txn := client.Single()
+	txn := client.ReadOnlyTransaction()
 	defer txn.Close()
-	_, _, e := txn.acquire(context.Background())
-	if e != nil {
-		t.Errorf("Acquire for single use, got %v, want nil.", e)
+
+	// First request will fail, which should trigger a retry.
+	errUsr := errors.New("error")
+	firstCall := true
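+	// firstCall needs no synchronization: this test drives acquire serially.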
+	mock.BeginTransactionFn = func(c context.Context, r *sppb.BeginTransactionRequest, opts ...grpc.CallOption) (*sppb.Transaction, error) {
+		if firstCall {
+			mock.MockCloudSpannerClient.ReceivedRequests <- r
+			firstCall = false
+			return nil, errUsr
+		}
+		return mock.MockCloudSpannerClient.BeginTransaction(c, r, opts...)
 	}
-	_, _, e = txn.acquire(context.Background())
-	if wantErr := errTxClosed(); !testEqual(e, wantErr) {
-		t.Errorf("Second acquire for single use, got %v, want %v.", e, wantErr)
-	}
-	// Multiuse can recover from acquire failure.
-	txn = client.ReadOnlyTransaction()
-	_, _, e = txn.acquire(context.Background())
+
+	_, _, e := txn.acquire(ctx)
 	if wantErr := toSpannerError(errUsr); !testEqual(e, wantErr) {
-		t.Errorf("Acquire for multi use, got %v, want %v.", e, wantErr)
+		t.Fatalf("Acquire for multi use, got %v, want %v.", e, wantErr)
 	}
-	_, _, e = txn.acquire(context.Background())
+	_, _, e = txn.acquire(ctx)
 	if e != nil {
-		t.Errorf("Acquire for multi use, got %v, want nil.", e)
+		t.Fatalf("Acquire for multi use, got %v, want nil.", e)
 	}
+}
+
+// ReadOnlyTransaction: cannot be used after close.
+func TestReadOnlyTransaction_UseAfterClose(t *testing.T) {
+	t.Parallel()
+	ctx := context.Background()
+	client, _, _, cleanup := serverClientMock(t, SessionPoolConfig{})
+	defer cleanup()
+
+	txn := client.ReadOnlyTransaction()
 	txn.Close()
-	// Multiuse can not be used after close.
-	_, _, e = txn.acquire(context.Background())
+
+	_, _, e := txn.acquire(ctx)
 	if wantErr := errTxClosed(); !testEqual(e, wantErr) {
-		t.Errorf("Second acquire for multi use, got %v, want %v.", e, wantErr)
+		t.Fatalf("Second acquire for multi use, got %v, want %v.", e, wantErr)
 	}
-	// Multiuse can be acquired concurrently.
-	txn = client.ReadOnlyTransaction()
+}
+
+// ReadOnlyTransaction: can be acquired concurrently.
+func TestReadOnlyTransaction_Concurrent(t *testing.T) {
+	t.Parallel()
+	ctx := context.Background()
+	client, _, mock, cleanup := serverClientMock(t, SessionPoolConfig{})
+	defer cleanup()
+	txn := client.ReadOnlyTransaction()
 	defer txn.Close()
-	mc.Freeze()
+
+	mock.Freeze()
 	var (
 		sh1 *sessionHandle
 		sh2 *sessionHandle
@@ -105,7 +123,7 @@
 	acquire := func(sh **sessionHandle, ts **sppb.TransactionSelector) {
 		defer wg.Done()
 		var e error
-		*sh, *ts, e = txn.acquire(context.Background())
+		*sh, *ts, e = txn.acquire(ctx)
 		if e != nil {
 			t.Errorf("Concurrent acquire for multiuse, got %v, expect nil.", e)
 		}
@@ -113,109 +131,214 @@
 	wg.Add(2)
 	go acquire(&sh1, &ts1)
 	go acquire(&sh2, &ts2)
+
+	// TODO(deklerk): Get rid of this.
 	<-time.After(100 * time.Millisecond)
-	mc.Unfreeze()
+
+	mock.Unfreeze()
 	wg.Wait()
-	if !testEqual(sh1.session, sh2.session) {
-		t.Errorf("Expect acquire to get same session handle, got %v and %v.", sh1, sh2)
+	if sh1.session.id != sh2.session.id {
+		t.Fatalf("Expected acquire to get same session handle, got %v and %v.", sh1, sh2)
 	}
 	if !testEqual(ts1, ts2) {
-		t.Errorf("Expect acquire to get same transaction selector, got %v and %v.", ts1, ts2)
+		t.Fatalf("Expected acquire to get same transaction selector, got %v and %v.", ts1, ts2)
 	}
 }
 
-// TestRetryOnAbort tests transaction retries on abort.
-func TestRetryOnAbort(t *testing.T) {
-	t.Parallel()
-	_, mc, client := mockClient(t)
-	defer client.Close()
-	// commit in writeOnlyTransaction
-	mc.SetActions(
-		testutil.Action{"Commit", errAbrt}, // abort on first commit
-		testutil.Action{"Commit", nil},
-	)
-
-	ms := []*Mutation{
-		Insert("Accounts", []string{"AccountId", "Nickname", "Balance"}, []interface{}{int64(1), "Foo", int64(50)}),
-		Insert("Accounts", []string{"AccountId", "Nickname", "Balance"}, []interface{}{int64(2), "Bar", int64(1)}),
-	}
-	if _, e := client.Apply(context.Background(), ms, ApplyAtLeastOnce()); e != nil {
-		t.Errorf("applyAtLeastOnce retry on abort, got %v, want nil.", e)
-	}
-	// begin and commit in ReadWriteTransaction
-	mc.SetActions(
-		testutil.Action{"BeginTransaction", nil},     // let takeWriteSession succeed and get a session handle
-		testutil.Action{"Commit", errAbrt},           // let first commit fail and retry will begin new transaction
-		testutil.Action{"BeginTransaction", errAbrt}, // this time we can fail the begin attempt
-		testutil.Action{"BeginTransaction", nil},
-		testutil.Action{"Commit", nil},
-	)
-
-	if _, e := client.Apply(context.Background(), ms); e != nil {
-		t.Errorf("ReadWriteTransaction retry on abort, got %v, want nil.", e)
-	}
-}
-
-// TestBadSession tests bad session (session not found error).
-// TODO: session closed from transaction close
-func TestBadSession(t *testing.T) {
+func TestApply_Single(t *testing.T) {
 	t.Parallel()
 	ctx := context.Background()
-	sp, mc, client := mockClient(t)
-	defer client.Close()
-	var sid string
-	// Prepare a session, get the session id for use in testing.
-	if s, e := sp.take(ctx); e != nil {
-		t.Fatal("Prepare session failed.")
-	} else {
-		sid = s.getID()
-		s.recycle()
-	}
+	client, _, mock, cleanup := serverClientMock(t, SessionPoolConfig{})
+	defer cleanup()
 
-	wantErr := spannerErrorf(codes.NotFound, "Session not found: %v", sid)
-	// ReadOnlyTransaction
-	mc.SetActions(
-		testutil.Action{"BeginTransaction", wantErr},
-		testutil.Action{"BeginTransaction", wantErr},
-		testutil.Action{"BeginTransaction", wantErr},
-	)
-	txn := client.ReadOnlyTransaction()
-	defer txn.Close()
-	if _, _, got := txn.acquire(ctx); !testEqual(wantErr, got) {
-		t.Errorf("Expect acquire to fail, got %v, want %v.", got, wantErr)
-	}
-	// The failure should recycle the session, we expect it to be used in following requests.
-	if got := txn.Query(ctx, NewStatement("SELECT 1")); !testEqual(wantErr, got.err) {
-		t.Errorf("Expect Query to fail, got %v, want %v.", got.err, wantErr)
-	}
-	if got := txn.Read(ctx, "Users", KeySets(Key{"alice"}, Key{"bob"}), []string{"name", "email"}); !testEqual(wantErr, got.err) {
-		t.Errorf("Expect Read to fail, got %v, want %v.", got.err, wantErr)
-	}
-	// writeOnlyTransaction
 	ms := []*Mutation{
 		Insert("Accounts", []string{"AccountId", "Nickname", "Balance"}, []interface{}{int64(1), "Foo", int64(50)}),
 		Insert("Accounts", []string{"AccountId", "Nickname", "Balance"}, []interface{}{int64(2), "Bar", int64(1)}),
 	}
-	mc.SetActions(testutil.Action{"Commit", wantErr})
-	if _, got := client.Apply(context.Background(), ms, ApplyAtLeastOnce()); !testEqual(wantErr, got) {
-		t.Errorf("Expect applyAtLeastOnce to fail, got %v, want %v.", got, wantErr)
+	if _, e := client.Apply(ctx, ms, ApplyAtLeastOnce()); e != nil {
+		t.Fatalf("applyAtLeastOnce retry on abort, got %v, want nil.", e)
+	}
+
+	if err := shouldHaveReceived(mock, []interface{}{
+		&sppb.CreateSessionRequest{},
+		&sppb.CommitRequest{},
+	}); err != nil {
+		t.Fatal(err)
 	}
 }
 
-func TestFunctionErrorReturned(t *testing.T) {
+// Transaction retries on abort.
+func TestApply_RetryOnAbort(t *testing.T) {
 	t.Parallel()
+	ctx := context.Background()
-	_, mc, client := mockClient(t)
-	defer client.Close()
-	mc.SetActions(
-		testutil.Action{"BeginTransaction", nil},
-		testutil.Action{"Rollback", nil},
-	)
+	client, _, mock, cleanup := serverClientMock(t, SessionPoolConfig{})
+	defer cleanup()
+
+	// First commit will fail, and the retry will begin a new transaction.
+	errAbrt := spannerErrorf(codes.Aborted, "")
+	firstCommitCall := true
+	mock.CommitFn = func(c context.Context, r *sppb.CommitRequest, opts ...grpc.CallOption) (*sppb.CommitResponse, error) {
+		if firstCommitCall {
+			mock.MockCloudSpannerClient.ReceivedRequests <- r
+			firstCommitCall = false
+			return nil, errAbrt
+		}
+		return mock.MockCloudSpannerClient.Commit(c, r, opts...)
+	}
+
+	ms := []*Mutation{
+		Insert("Accounts", []string{"AccountId"}, []interface{}{int64(1)}),
+	}
+
+	if _, e := client.Apply(ctx, ms); e != nil {
+		t.Fatalf("ReadWriteTransaction retry on abort, got %v, want nil.", e)
+	}
+
+	if err := shouldHaveReceived(mock, []interface{}{
+		&sppb.CreateSessionRequest{},
+		&sppb.BeginTransactionRequest{},
+		&sppb.CommitRequest{}, // First commit fails.
+		&sppb.BeginTransactionRequest{},
+		&sppb.CommitRequest{}, // Second commit succeeds.
+	}); err != nil {
+		t.Fatal(err)
+	}
+}
+
+// Tests that NotFound errors cause failures and aren't retried.
+func TestTransaction_NotFound(t *testing.T) {
+	t.Parallel()
+	ctx := context.Background()
+	client, _, mock, cleanup := serverClientMock(t, SessionPoolConfig{})
+	defer cleanup()
+
+	wantErr := spannerErrorf(codes.NotFound, "Session not found")
+	mock.BeginTransactionFn = func(c context.Context, r *sppb.BeginTransactionRequest, opts ...grpc.CallOption) (*sppb.Transaction, error) {
+		mock.MockCloudSpannerClient.ReceivedRequests <- r
+		return nil, wantErr
+	}
+	mock.CommitFn = func(c context.Context, r *sppb.CommitRequest, opts ...grpc.CallOption) (*sppb.CommitResponse, error) {
+		mock.MockCloudSpannerClient.ReceivedRequests <- r
+		return nil, wantErr
+	}
+
+	txn := client.ReadOnlyTransaction()
+	defer txn.Close()
+
+	if _, _, got := txn.acquire(ctx); !testEqual(wantErr, got) {
+		t.Fatalf("Expect acquire to fail, got %v, want %v.", got, wantErr)
+	}
+
+	// The failure should recycle the session; we expect it to be used in
+	// subsequent requests.
+	if got := txn.Query(ctx, NewStatement("SELECT 1")); !testEqual(wantErr, got.err) {
+		t.Fatalf("Expect Query to fail, got %v, want %v.", got.err, wantErr)
+	}
+
+	if got := txn.Read(ctx, "Users", KeySets(Key{"alice"}, Key{"bob"}), []string{"name", "email"}); !testEqual(wantErr, got.err) {
+		t.Fatalf("Expect Read to fail, got %v, want %v.", got.err, wantErr)
+	}
+
+	ms := []*Mutation{
+		Insert("Accounts", []string{"AccountId", "Nickname", "Balance"}, []interface{}{int64(1), "Foo", int64(50)}),
+		Insert("Accounts", []string{"AccountId", "Nickname", "Balance"}, []interface{}{int64(2), "Bar", int64(1)}),
+	}
+	if _, got := client.Apply(ctx, ms, ApplyAtLeastOnce()); !testEqual(wantErr, got) {
+		t.Fatalf("Expect Apply to fail, got %v, want %v.", got, wantErr)
+	}
+}
+
+// When an error is returned from the closure sent into ReadWriteTransaction, it
+// kicks off a rollback.
+func TestReadWriteTransaction_ErrorReturned(t *testing.T) {
+	t.Parallel()
+	ctx := context.Background()
+	client, _, mock, cleanup := serverClientMock(t, SessionPoolConfig{})
+	defer cleanup()
 
 	want := errors.New("an error")
-	_, got := client.ReadWriteTransaction(context.Background(),
-		func(context.Context, *ReadWriteTransaction) error { return want })
+	_, got := client.ReadWriteTransaction(ctx, func(context.Context, *ReadWriteTransaction) error {
+		return want
+	})
 	if got != want {
-		t.Errorf("got <%v>, want <%v>", got, want)
+		t.Fatalf("got %+v, want %+v", got, want)
 	}
-	mc.CheckActionsConsumed()
+	if err := shouldHaveReceived(mock, []interface{}{
+		&sppb.CreateSessionRequest{},
+		&sppb.BeginTransactionRequest{},
+		&sppb.RollbackRequest{},
+	}); err != nil {
+		t.Fatal(err)
+	}
+}
+
+// shouldHaveReceived asserts that exactly the requests in want were present
+// in the mock's ReceivedRequests channel. It only looks at type, not
+// contents.
+//
+// Note: this modifies the mock in place by popping items off the
+// ReceivedRequests channel.
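+//
+// For example:
+//
+//	if err := shouldHaveReceived(mock, []interface{}{
+//		&sppb.CreateSessionRequest{},
+//		&sppb.CommitRequest{},
+//	}); err != nil {
+//		t.Fatal(err)
+//	}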
+func shouldHaveReceived(mock *testutil.FuncMock, want []interface{}) error {
+	got := drainRequests(mock)
+
+	if len(got) != len(want) {
+		var gotMsg string
+		for _, r := range got {
+			gotMsg += fmt.Sprintf("%v: %+v\n", reflect.TypeOf(r), r)
+		}
+
+		var wantMsg string
+		for _, r := range want {
+			wantMsg += fmt.Sprintf("%v: %+v\n", reflect.TypeOf(r), r)
+		}
+
+		return fmt.Errorf("got %d requests, want %d requests:\ngot:\n%s\nwant:\n%s", len(got), len(want), gotMsg, wantMsg)
+	}
+
+	for i, w := range want {
+		if reflect.TypeOf(got[i]) != reflect.TypeOf(w) {
+			return fmt.Errorf("request %d: got %+v, want %+v", i, reflect.TypeOf(got[i]), reflect.TypeOf(w))
+		}
+	}
+
+	return nil
+}
+
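+// drainRequests empties the mock's ReceivedRequests channel without
+// blocking, returning everything received so far.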
+func drainRequests(mock *testutil.FuncMock) []interface{} {
+	var reqs []interface{}
+loop:
+	for {
+		select {
+		case req := <-mock.ReceivedRequests:
+			reqs = append(reqs, req)
+		default:
+			break loop
+		}
+	}
+	return reqs
+}
+
+// serverClientMock sets up a Client backed by a MockCloudSpannerClient that
+// is wrapped in a function-injectable FuncMock.
+//
+// Note: be sure to call cleanup!
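+//
+// Typical use:
+//
+//	client, sp, mock, cleanup := serverClientMock(t, SessionPoolConfig{})
+//	defer cleanup()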
+func serverClientMock(t *testing.T, spc SessionPoolConfig) (_ *Client, _ *sessionPool, _ *testutil.FuncMock, cleanup func()) {
+	rawServerStub := testutil.NewMockCloudSpannerClient(t)
+	mock := testutil.FuncMock{MockCloudSpannerClient: rawServerStub}
+	spc.getRPCClient = func() (sppb.SpannerClient, error) {
+		return &serverClientMock, nil
+	}
+	db := "mockdb"
+	sp, err := newSessionPool(db, spc, nil)
+	if err != nil {
+		t.Fatalf("cannot create session pool: %v", err)
+	}
+	client := Client{
+		database:     db,
+		idleSessions: sp,
+	}
+	cleanup = func() {
+		client.Close()
+		sp.hc.close()
+		sp.close()
+	}
+	return &client, sp, &serverClientMock, cleanup
 }
diff --git a/spanner/value_test.go b/spanner/value_test.go
index 4e11422..34207ed 100644
--- a/spanner/value_test.go
+++ b/spanner/value_test.go
@@ -17,14 +17,13 @@
 package spanner
 
 import (
-	"context"
 	"math"
 	"reflect"
 	"testing"
 	"time"
 
 	"cloud.google.com/go/civil"
-	proto "github.com/golang/protobuf/proto"
+	"github.com/golang/protobuf/proto"
 	proto3 "github.com/golang/protobuf/ptypes/struct"
 	sppb "google.golang.org/genproto/googleapis/spanner/v1"
 )
@@ -1394,171 +1393,3 @@
 		}
 	}
 }
-
-func TestStructParametersBind(t *testing.T) {
-	t.Parallel()
-	ctx := context.Background()
-	client, _, cleanup := prepare(ctx, t, nil)
-	defer cleanup()
-
-	type tRow []interface{}
-	type tRows []struct{ trow tRow }
-
-	type allFields struct {
-		Stringf string
-		Intf    int
-		Boolf   bool
-		Floatf  float64
-		Bytef   []byte
-		Timef   time.Time
-		Datef   civil.Date
-	}
-	allColumns := []string{
-		"Stringf",
-		"Intf",
-		"Boolf",
-		"Floatf",
-		"Bytef",
-		"Timef",
-		"Datef",
-	}
-	s1 := allFields{"abc", 300, false, 3.45, []byte("foo"), t1, d1}
-	s2 := allFields{"def", -300, false, -3.45, []byte("bar"), t2, d2}
-
-	dynamicStructType := reflect.StructOf([]reflect.StructField{
-		{Name: "A", Type: reflect.TypeOf(t1), Tag: `spanner:"ff1"`},
-	})
-	s3 := reflect.New(dynamicStructType)
-	s3.Elem().Field(0).Set(reflect.ValueOf(t1))
-
-	for i, test := range []struct {
-		param interface{}
-		sql   string
-		cols  []string
-		trows tRows
-	}{
-		// Struct value.
-		{
-			s1,
-			"SELECT" +
-				" @p.Stringf," +
-				" @p.Intf," +
-				" @p.Boolf," +
-				" @p.Floatf," +
-				" @p.Bytef," +
-				" @p.Timef," +
-				" @p.Datef",
-			allColumns,
-			tRows{
-				{tRow{"abc", 300, false, 3.45, []byte("foo"), t1, d1}},
-			},
-		},
-		// Array of struct value.
-		{
-			[]allFields{s1, s2},
-			"SELECT * FROM UNNEST(@p)",
-			allColumns,
-			tRows{
-				{tRow{"abc", 300, false, 3.45, []byte("foo"), t1, d1}},
-				{tRow{"def", -300, false, -3.45, []byte("bar"), t2, d2}},
-			},
-		},
-		// Null struct.
-		{
-			(*allFields)(nil),
-			"SELECT @p IS NULL",
-			[]string{""},
-			tRows{
-				{tRow{true}},
-			},
-		},
-		// Null Array of struct.
-		{
-			[]allFields(nil),
-			"SELECT @p IS NULL",
-			[]string{""},
-			tRows{
-				{tRow{true}},
-			},
-		},
-		// Empty struct.
-		{
-			struct{}{},
-			"SELECT @p IS NULL ",
-			[]string{""},
-			tRows{
-				{tRow{false}},
-			},
-		},
-		// Empty array of struct.
-		{
-			[]allFields{},
-			"SELECT * FROM UNNEST(@p) ",
-			allColumns,
-			tRows{},
-		},
-		// Struct with duplicate fields.
-		{
-			struct {
-				A int `spanner:"field"`
-				B int `spanner:"field"`
-			}{10, 20},
-			"SELECT * FROM UNNEST([@p]) ",
-			[]string{"field", "field"},
-			tRows{
-				{tRow{10, 20}},
-			},
-		},
-		// Struct with unnamed fields.
-		{
-			struct {
-				A string `spanner:""`
-			}{"hello"},
-			"SELECT * FROM UNNEST([@p]) ",
-			[]string{""},
-			tRows{
-				{tRow{"hello"}},
-			},
-		},
-		// Mixed struct.
-		{
-			struct {
-				DynamicStructField interface{}  `spanner:"f1"`
-				ArrayStructField   []*allFields `spanner:"f2"`
-			}{
-				DynamicStructField: s3.Interface(),
-				ArrayStructField:   []*allFields{nil},
-			},
-			"SELECT @p.f1.ff1, ARRAY_LENGTH(@p.f2), @p.f2[OFFSET(0)] IS NULL ",
-			[]string{"ff1", "", ""},
-			tRows{
-				{tRow{t1, 1, true}},
-			},
-		},
-	} {
-		iter := client.Single().Query(ctx, Statement{
-			SQL:    test.sql,
-			Params: map[string]interface{}{"p": test.param},
-		})
-		var gotRows []*Row
-		err := iter.Do(func(r *Row) error {
-			gotRows = append(gotRows, r)
-			return nil
-		})
-		if err != nil {
-			t.Errorf("Failed to execute test case %d, error: %v", i, err)
-		}
-
-		var wantRows []*Row
-		for j, row := range test.trows {
-			r, err := NewRow(test.cols, row.trow)
-			if err != nil {
-				t.Errorf("Invalid row %d in test case %d", j, i)
-			}
-			wantRows = append(wantRows, r)
-		}
-		if !testEqual(gotRows, wantRows) {
-			t.Errorf("%d: Want result %v, got result %v", i, wantRows, gotRows)
-		}
-	}
-}