spanner: use BatchCreateSessions to init pool

Adds a session client that is responsible for creating
sessions for the session pool and for other consumers that
need a session. The session client encapsulates the details
of spreading sessions evenly over all available gRPC
channels.

The session client also supports creating batches of
sessions, which the session pool uses to speed up the
initialization of a large pool or to create a burst of
sessions. These batches are also automatically spread evenly
over all available channels; for example, a batch of 100
sessions created over 4 gRPC channels results in one
BatchCreateSessions request for 25 sessions on each channel.

The default MinOpened session pool configuration is changed
to 100. The session pool uses the BatchCreateSessions RPC to
initialize the pool if MinOpened > 0.
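
Callers that want to keep the previous on-demand behavior can
pass an explicit SessionPoolConfig to NewClientWithConfig. A
minimal, hypothetical sketch (the database path is a
placeholder):

    package main

    import (
        "context"
        "log"

        "cloud.google.com/go/spanner"
    )

    func main() {
        ctx := context.Background()
        // MinOpened=0 opts out of eager pool initialization, so no
        // BatchCreateSessions call is made at client creation; sessions
        // are created on demand instead.
        client, err := spanner.NewClientWithConfig(ctx,
            "projects/P/instances/I/databases/D",
            spanner.ClientConfig{
                SessionPoolConfig: spanner.SessionPoolConfig{MinOpened: 0},
            })
        if err != nil {
            log.Fatal(err)
        }
        defer client.Close()
    }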

Updates #1566.

Change-Id: I13e6fbc321688cdbd396913e4f4c01aa8631fb2c
Reviewed-on: https://code-review.googlesource.com/c/gocloud/+/45111
Reviewed-by: kokoro <noreply+kokoro@google.com>
Reviewed-by: Jean de Klerk <deklerk@google.com>
diff --git a/spanner/client.go b/spanner/client.go
index 1f15cb0..2a30b8e 100644
--- a/spanner/client.go
+++ b/spanner/client.go
@@ -21,7 +21,6 @@
 	"fmt"
 	"os"
 	"regexp"
-	"sync/atomic"
 	"time"
 
 	"cloud.google.com/go/internal/trace"
@@ -64,16 +63,8 @@
 // Client is a client for reading and writing data to a Cloud Spanner database.
 // A client is safe to use concurrently, except for its Close method.
 type Client struct {
-	// rr must be accessed through atomic operations.
-	rr      uint32
-	clients []*vkit.Client
-
-	database string
-	// Metadata to be sent with each request.
-	md           metadata.MD
+	sc           *sessionClient
 	idleSessions *sessionPool
-	// sessionLabels for the sessions created by this client.
-	sessionLabels map[string]string
 }
 
 // ClientConfig has configurations for the client.
@@ -110,23 +101,12 @@
 // form projects/PROJECT_ID/instances/INSTANCE_ID/databases/DATABASE_ID. It uses
 // a default configuration.
 func NewClient(ctx context.Context, database string, opts ...option.ClientOption) (*Client, error) {
-	return NewClientWithConfig(ctx, database, ClientConfig{}, opts...)
+	return NewClientWithConfig(ctx, database, ClientConfig{SessionPoolConfig: DefaultSessionPoolConfig}, opts...)
 }
 
 // NewClientWithConfig creates a client to a database. A valid database name has
 // the form projects/PROJECT_ID/instances/INSTANCE_ID/databases/DATABASE_ID.
 func NewClientWithConfig(ctx context.Context, database string, config ClientConfig, opts ...option.ClientOption) (c *Client, err error) {
-	c = &Client{
-		database: database,
-		md:       metadata.Pairs(resourcePrefixHeader, database),
-	}
-
-	// Make a copy of labels.
-	c.sessionLabels = make(map[string]string)
-	for k, v := range config.SessionLabels {
-		c.sessionLabels[k] = v
-	}
-
 	// Prepare gRPC channels.
 	if config.NumChannels == 0 {
 		config.NumChannels = numChannels
@@ -137,7 +117,7 @@
 		config.MaxOpened = uint64(config.NumChannels * 100)
 	}
 	if config.MaxBurst == 0 {
-		config.MaxBurst = 10
+		config.MaxBurst = DefaultSessionPoolConfig.MaxBurst
 	}
 
 	// Validate database path.
@@ -174,43 +154,44 @@
 	// TODO(deklerk): This should be replaced with a balancer with
 	// config.NumChannels connections, instead of config.NumChannels
 	// clients.
+	var clients []*vkit.Client
 	for i := 0; i < config.NumChannels; i++ {
 		client, err := vkit.NewClient(ctx, allOpts...)
 		if err != nil {
 			return nil, errDial(i, err)
 		}
-		c.clients = append(c.clients, client)
+		clients = append(clients, client)
 	}
 
-	// Prepare session pool.
-	// TODO: support more loadbalancing options.
-	config.SessionPoolConfig.getRPCClient = func() (*vkit.Client, error) {
-		return c.rrNext(), nil
+	// TODO(loite): Remove as the original map cannot be changed by the user
+	// anyway, and the client library does not change it either.
+	// Make a copy of labels.
+	sessionLabels := make(map[string]string)
+	for k, v := range config.SessionLabels {
+		sessionLabels[k] = v
 	}
-	config.SessionPoolConfig.sessionLabels = c.sessionLabels
-	sp, err := newSessionPool(database, config.SessionPoolConfig, c.md)
+	// Create a session client.
+	sc := newSessionClient(clients, database, sessionLabels, metadata.Pairs(resourcePrefixHeader, database))
+	// Create a session pool.
+	config.SessionPoolConfig.sessionLabels = sessionLabels
+	sp, err := newSessionPool(sc, config.SessionPoolConfig)
 	if err != nil {
-		c.Close()
+		sc.close()
 		return nil, err
 	}
-	c.idleSessions = sp
+	c = &Client{
+		sc:           sc,
+		idleSessions: sp,
+	}
 	return c, nil
 }
 
-// rrNext returns the next available vkit Cloud Spanner RPC client in a
-// round-robin manner.
-func (c *Client) rrNext() *vkit.Client {
-	return c.clients[atomic.AddUint32(&c.rr, 1)%uint32(len(c.clients))]
-}
-
 // Close closes the client.
 func (c *Client) Close() {
 	if c.idleSessions != nil {
 		c.idleSessions.close()
 	}
-	for _, gpc := range c.clients {
-		gpc.Close()
-	}
+	c.sc.close()
 }
 
 // Single provides a read-only snapshot transaction optimized for the case
@@ -273,8 +254,7 @@
 	}()
 
 	// Create session.
-	sc := c.rrNext()
-	s, err = createSession(ctx, sc, c.database, c.sessionLabels, c.md)
+	s, err = c.sc.createSession(ctx)
 	if err != nil {
 		return nil, err
 	}
@@ -318,8 +298,7 @@
 // BatchReadOnlyTransactionFromID reconstruct a BatchReadOnlyTransaction from
 // BatchReadOnlyTransactionID
 func (c *Client) BatchReadOnlyTransactionFromID(tid BatchReadOnlyTransactionID) *BatchReadOnlyTransaction {
-	sc := c.rrNext()
-	s := &session{valid: true, client: sc, id: tid.sid, createTime: time.Now(), md: c.md}
+	s := c.sc.sessionWithID(tid.sid)
 	sh := &sessionHandle{session: s}
 
 	t := &BatchReadOnlyTransaction{
diff --git a/spanner/integration_test.go b/spanner/integration_test.go
index 6e1384b..70c19c9 100644
--- a/spanner/integration_test.go
+++ b/spanner/integration_test.go
@@ -213,6 +213,32 @@
 	}
 }
 
+func TestIntegration_InitSessionPool(t *testing.T) {
+	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
+	defer cancel()
+	// Set up testing environment.
+	client, _, cleanup := prepareIntegrationTest(ctx, t, DefaultSessionPoolConfig, singerDBStatements)
+	defer cleanup()
+	sp := client.idleSessions
+	sp.mu.Lock()
+	want := sp.MinOpened
+	sp.mu.Unlock()
+	var numOpened int
+	for {
+		select {
+		case <-ctx.Done():
+			t.Fatalf("timed out, got %d session(s), want %d", numOpened, want)
+		default:
+			sp.mu.Lock()
+			numOpened = sp.idleList.Len() + sp.idleWriteList.Len()
+			sp.mu.Unlock()
+			if uint64(numOpened) == want {
+				return
+			}
+		}
+	}
+}
+
 // Test SingleUse transaction.
 func TestIntegration_SingleUse(t *testing.T) {
 	t.Parallel()
@@ -220,7 +246,7 @@
 	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
 	defer cancel()
 	// Set up testing environment.
-	client, _, cleanup := prepareIntegrationTest(ctx, t, singerDBStatements)
+	client, _, cleanup := prepareIntegrationTest(ctx, t, DefaultSessionPoolConfig, singerDBStatements)
 	defer cleanup()
 
 	writes := []struct {
@@ -420,7 +446,7 @@
 	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
 	defer cancel()
 	// Set up testing environment.
-	client, _, cleanup := prepareIntegrationTest(ctx, t, singerDBStatements)
+	client, _, cleanup := prepareIntegrationTest(ctx, t, DefaultSessionPoolConfig, singerDBStatements)
 	defer cleanup()
 
 	writes := []struct {
@@ -463,7 +489,7 @@
 	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
 	defer cancel()
 	// Set up testing environment.
-	client, _, cleanup := prepareIntegrationTest(ctx, t, singerDBStatements)
+	client, _, cleanup := prepareIntegrationTest(ctx, t, DefaultSessionPoolConfig, singerDBStatements)
 	defer cleanup()
 
 	writes := []struct {
@@ -649,7 +675,7 @@
 
 	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
 	defer cancel()
-	client, _, cleanup := prepareIntegrationTest(ctx, t, singerDBStatements)
+	client, _, cleanup := prepareIntegrationTest(ctx, t, DefaultSessionPoolConfig, singerDBStatements)
 	defer cleanup()
 
 	for i, tb := range []TimestampBound{
@@ -682,7 +708,7 @@
 	// Give a longer deadline because of transaction backoffs.
 	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
 	defer cancel()
-	client, _, cleanup := prepareIntegrationTest(ctx, t, singerDBStatements)
+	client, _, cleanup := prepareIntegrationTest(ctx, t, DefaultSessionPoolConfig, singerDBStatements)
 	defer cleanup()
 
 	// Set up two accounts
@@ -773,7 +799,7 @@
 	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
 	defer cancel()
 	// Set up testing environment.
-	client, _, cleanup := prepareIntegrationTest(ctx, t, readDBStatements)
+	client, _, cleanup := prepareIntegrationTest(ctx, t, DefaultSessionPoolConfig, readDBStatements)
 	defer cleanup()
 
 	// Includes k0..k14. Strings sort lexically, eg "k1" < "k10" < "k2".
@@ -841,7 +867,7 @@
 	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
 	defer cancel()
 	// Set up testing environment.
-	client, _, cleanup := prepareIntegrationTest(ctx, t, readDBStatements)
+	client, _, cleanup := prepareIntegrationTest(ctx, t, DefaultSessionPoolConfig, readDBStatements)
 	defer cleanup()
 
 	var ms []*Mutation
@@ -887,7 +913,7 @@
 	// You cannot use a transaction from inside a read-write transaction.
 	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
 	defer cancel()
-	client, _, cleanup := prepareIntegrationTest(ctx, t, singerDBStatements)
+	client, _, cleanup := prepareIntegrationTest(ctx, t, DefaultSessionPoolConfig, singerDBStatements)
 	defer cleanup()
 
 	_, err := client.ReadWriteTransaction(ctx, func(ctx context.Context, tx *ReadWriteTransaction) error {
@@ -919,7 +945,9 @@
 
 	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
 	defer cancel()
-	client, dbPath, cleanup := prepareIntegrationTest(ctx, t, singerDBStatements)
+	// Create a client with MinOpened=0 to prevent the session pool maintainer
+	// from repeatedly trying to create sessions for the invalid database.
+	client, dbPath, cleanup := prepareIntegrationTest(ctx, t, SessionPoolConfig{}, singerDBStatements)
 	defer cleanup()
 
 	// Drop the testing database.
@@ -970,7 +998,7 @@
 
 	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
 	defer cancel()
-	client, _, cleanup := prepareIntegrationTest(ctx, t, singerDBStatements)
+	client, _, cleanup := prepareIntegrationTest(ctx, t, DefaultSessionPoolConfig, singerDBStatements)
 	defer cleanup()
 
 	t1, _ := time.Parse(time.RFC3339Nano, "2016-11-15T15:04:05.999999999Z")
@@ -1118,7 +1146,7 @@
 
 	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
 	defer cancel()
-	client, _, cleanup := prepareIntegrationTest(ctx, t, singerDBStatements)
+	client, _, cleanup := prepareIntegrationTest(ctx, t, DefaultSessionPoolConfig, singerDBStatements)
 	defer cleanup()
 
 	tests := []struct {
@@ -1205,7 +1233,7 @@
 
 	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
 	defer cancel()
-	client, _, cleanup := prepareIntegrationTest(ctx, t, nil)
+	client, _, cleanup := prepareIntegrationTest(ctx, t, DefaultSessionPoolConfig, nil)
 	defer cleanup()
 
 	for _, test := range []struct {
@@ -1250,7 +1278,7 @@
 
 	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
 	defer cancel()
-	client, _, cleanup := prepareIntegrationTest(ctx, t, nil)
+	client, _, cleanup := prepareIntegrationTest(ctx, t, DefaultSessionPoolConfig, nil)
 	defer cleanup()
 
 	newRow := func(vals []interface{}) *Row {
@@ -1306,7 +1334,7 @@
 
 	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
 	defer cancel()
-	client, _, cleanup := prepareIntegrationTest(ctx, t, singerDBStatements)
+	client, _, cleanup := prepareIntegrationTest(ctx, t, DefaultSessionPoolConfig, singerDBStatements)
 	defer cleanup()
 
 	accounts := []*Mutation{
@@ -1348,12 +1376,12 @@
 func TestIntegration_InvalidDatabase(t *testing.T) {
 	t.Parallel()
 
-	if testProjectID == "" {
-		t.Skip("Integration tests skipped: GCLOUD_TESTS_GOLANG_PROJECT_ID is missing")
+	if databaseAdmin == nil {
+		t.Skip("Integration tests skipped")
 	}
 	ctx := context.Background()
 	dbPath := fmt.Sprintf("projects/%v/instances/%v/databases/invalid", testProjectID, testInstanceID)
-	c, err := createClient(ctx, dbPath)
+	c, err := createClient(ctx, dbPath, SessionPoolConfig{})
 	// Client creation should succeed even if the database is invalid.
 	if err != nil {
 		t.Fatal(err)
@@ -1369,7 +1397,7 @@
 
 	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
 	defer cancel()
-	client, _, cleanup := prepareIntegrationTest(ctx, t, readDBStatements)
+	client, _, cleanup := prepareIntegrationTest(ctx, t, DefaultSessionPoolConfig, readDBStatements)
 	defer cleanup()
 
 	// Read over invalid table fails
@@ -1415,7 +1443,7 @@
 
 	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
 	defer cancel()
-	client, _, cleanup := prepareIntegrationTest(ctx, t, singerDBStatements)
+	client, _, cleanup := prepareIntegrationTest(ctx, t, DefaultSessionPoolConfig, singerDBStatements)
 	defer cleanup()
 
 	// Test 1: User error should abort the transaction.
@@ -1556,13 +1584,13 @@
 	)
 	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
 	defer cancel()
-	client, dbPath, cleanup := prepareIntegrationTest(ctx, t, simpleDBStatements)
+	client, dbPath, cleanup := prepareIntegrationTest(ctx, t, DefaultSessionPoolConfig, simpleDBStatements)
 	defer cleanup()
 
 	if err = populate(ctx, client); err != nil {
 		t.Fatal(err)
 	}
-	if client2, err = createClient(ctx, dbPath); err != nil {
+	if client2, err = createClient(ctx, dbPath, SessionPoolConfig{}); err != nil {
 		t.Fatal(err)
 	}
 	defer client2.Close()
@@ -1642,13 +1670,13 @@
 	)
 	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
 	defer cancel()
-	client, dbPath, cleanup := prepareIntegrationTest(ctx, t, simpleDBStatements)
+	client, dbPath, cleanup := prepareIntegrationTest(ctx, t, DefaultSessionPoolConfig, simpleDBStatements)
 	defer cleanup()
 
 	if err = populate(ctx, client); err != nil {
 		t.Fatal(err)
 	}
-	if client2, err = createClient(ctx, dbPath); err != nil {
+	if client2, err = createClient(ctx, dbPath, SessionPoolConfig{}); err != nil {
 		t.Fatal(err)
 	}
 	defer client2.Close()
@@ -1729,7 +1757,7 @@
 	)
 	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
 	defer cancel()
-	client, _, cleanup := prepareIntegrationTest(ctx, t, simpleDBStatements)
+	client, _, cleanup := prepareIntegrationTest(ctx, t, DefaultSessionPoolConfig, simpleDBStatements)
 	defer cleanup()
 
 	if txn, err = client.BatchReadOnlyTransaction(ctx, StrongRead()); err != nil {
@@ -1758,7 +1786,7 @@
 
 	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
 	defer cancel()
-	client, _, cleanup := prepareIntegrationTest(ctx, t, ctsDBStatements)
+	client, _, cleanup := prepareIntegrationTest(ctx, t, DefaultSessionPoolConfig, ctsDBStatements)
 	defer cleanup()
 
 	type testTableRow struct {
@@ -1828,7 +1856,7 @@
 
 	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
 	defer cancel()
-	client, _, cleanup := prepareIntegrationTest(ctx, t, singerDBStatements)
+	client, _, cleanup := prepareIntegrationTest(ctx, t, DefaultSessionPoolConfig, singerDBStatements)
 	defer cleanup()
 
 	// Function that reads a single row's first name from within a transaction.
@@ -1995,7 +2023,7 @@
 
 	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
 	defer cancel()
-	client, _, cleanup := prepareIntegrationTest(ctx, t, nil)
+	client, _, cleanup := prepareIntegrationTest(ctx, t, DefaultSessionPoolConfig, nil)
 	defer cleanup()
 
 	type tRow []interface{}
@@ -2165,7 +2193,7 @@
 
 	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
 	defer cancel()
-	client, _, cleanup := prepareIntegrationTest(ctx, t, singerDBStatements)
+	client, _, cleanup := prepareIntegrationTest(ctx, t, DefaultSessionPoolConfig, singerDBStatements)
 	defer cleanup()
 
 	columns := []string{"SingerId", "FirstName", "LastName"}
@@ -2219,7 +2247,7 @@
 
 	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
 	defer cancel()
-	client, _, cleanup := prepareIntegrationTest(ctx, t, singerDBStatements)
+	client, _, cleanup := prepareIntegrationTest(ctx, t, DefaultSessionPoolConfig, singerDBStatements)
 	defer cleanup()
 
 	columns := []string{"SingerId", "FirstName", "LastName"}
@@ -2272,7 +2300,7 @@
 
 	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
 	defer cancel()
-	client, _, cleanup := prepareIntegrationTest(ctx, t, singerDBStatements)
+	client, _, cleanup := prepareIntegrationTest(ctx, t, DefaultSessionPoolConfig, singerDBStatements)
 	defer cleanup()
 
 	_, err := client.ReadWriteTransaction(ctx, func(ctx context.Context, tx *ReadWriteTransaction) (err error) {
@@ -2296,7 +2324,7 @@
 
 	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
 	defer cancel()
-	client, _, cleanup := prepareIntegrationTest(ctx, t, singerDBStatements)
+	client, _, cleanup := prepareIntegrationTest(ctx, t, DefaultSessionPoolConfig, singerDBStatements)
 	defer cleanup()
 
 	columns := []string{"SingerId", "FirstName", "LastName"}
@@ -2347,7 +2375,7 @@
 
 	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
 	defer cancel()
-	client, _, cleanup := prepareIntegrationTest(ctx, t, singerDBStatements)
+	client, _, cleanup := prepareIntegrationTest(ctx, t, DefaultSessionPoolConfig, singerDBStatements)
 	defer cleanup()
 
 	columns := []string{"SingerId", "FirstName", "LastName"}
@@ -2399,7 +2427,7 @@
 }
 
 // Prepare initializes Cloud Spanner testing DB and clients.
-func prepareIntegrationTest(ctx context.Context, t *testing.T, statements []string) (*Client, string, func()) {
+func prepareIntegrationTest(ctx context.Context, t *testing.T, spc SessionPoolConfig, statements []string) (*Client, string, func()) {
 	if databaseAdmin == nil {
 		t.Skip("Integration tests skipped")
 	}
@@ -2419,7 +2447,7 @@
 	if _, err := op.Wait(ctx); err != nil {
 		t.Fatalf("cannot create testing DB %v: %v", dbPath, err)
 	}
-	client, err := createClient(ctx, dbPath)
+	client, err := createClient(ctx, dbPath, spc)
 	if err != nil {
 		t.Fatalf("cannot create data client on DB %v: %v", dbPath, err)
 	}
@@ -2563,9 +2591,9 @@
 }
 
 // createClient creates Cloud Spanner data client.
-func createClient(ctx context.Context, dbPath string) (client *Client, err error) {
+func createClient(ctx context.Context, dbPath string, spc SessionPoolConfig) (client *Client, err error) {
 	client, err = NewClientWithConfig(ctx, dbPath, ClientConfig{
-		SessionPoolConfig: SessionPoolConfig{WriteSessions: 0.2},
+		SessionPoolConfig: spc,
 	}, option.WithTokenSource(testutil.TokenSource(ctx, Scope)), option.WithEndpoint(endpoint))
 	if err != nil {
 		return nil, fmt.Errorf("cannot create data client on DB %v: %v", dbPath, err)
diff --git a/spanner/internal/testutil/inmem_spanner_server.go b/spanner/internal/testutil/inmem_spanner_server.go
index 965e5c2..8c61acc 100644
--- a/spanner/internal/testutil/inmem_spanner_server.go
+++ b/spanner/internal/testutil/inmem_spanner_server.go
@@ -56,6 +56,7 @@
 const (
 	MethodBeginTransaction    string = "BEGIN_TRANSACTION"
 	MethodCommitTransaction   string = "COMMIT_TRANSACTION"
+	MethodBatchCreateSession  string = "BATCH_CREATE_SESSION"
 	MethodCreateSession       string = "CREATE_SESSION"
 	MethodDeleteSession       string = "DELETE_SESSION"
 	MethodGetSession          string = "GET_SESSION"
@@ -206,6 +207,8 @@
 
 	TotalSessionsCreated() uint
 	TotalSessionsDeleted() uint
+	SetMaxSessionsReturnedByServerPerBatchRequest(sessionCount int32)
+	SetMaxSessionsReturnedByServerInTotal(sessionCount int32)
 
 	ReceivedRequests() chan interface{}
 	DumpSessions() map[string]bool
@@ -249,7 +252,10 @@
 
 	totalSessionsCreated uint
 	totalSessionsDeleted uint
-	receivedRequests     chan interface{}
+	// The maximum number of sessions that will be created per batch request.
+	maxSessionsReturnedByServerPerBatchRequest int32
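+	// The maximum total number of sessions that may exist on the server at
+	// any time; once reached, session creation returns ResourceExhausted.
+	// A value of 0 means no limit.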
+	maxSessionsReturnedByServerInTotal         int32
+	receivedRequests                           chan interface{}
 	// Session ping history.
 	pings []string
 
@@ -362,6 +368,18 @@
 	return s.totalSessionsDeleted
 }
 
+func (s *inMemSpannerServer) SetMaxSessionsReturnedByServerPerBatchRequest(sessionCount int32) {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+	s.maxSessionsReturnedByServerPerBatchRequest = sessionCount
+}
+
+func (s *inMemSpannerServer) SetMaxSessionsReturnedByServerInTotal(sessionCount int32) {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+	s.maxSessionsReturnedByServerInTotal = sessionCount
+}
+
 func (s *inMemSpannerServer) ReceivedRequests() chan interface{} {
 	return s.receivedRequests
 }
@@ -393,6 +411,7 @@
 
 func (s *inMemSpannerServer) initDefaults() {
 	s.sessionCounter = 0
+	s.maxSessionsReturnedByServerPerBatchRequest = 100
 	s.sessions = make(map[string]*spannerpb.Session)
 	s.sessionLastUseTime = make(map[string]time.Time)
 	s.transactions = make(map[string]*spannerpb.Transaction)
@@ -401,9 +420,7 @@
 	s.transactionCounters = make(map[string]*uint64)
 }
 
-func (s *inMemSpannerServer) generateSessionName(database string) string {
-	s.mu.Lock()
-	defer s.mu.Unlock()
+func (s *inMemSpannerServer) generateSessionNameLocked(database string) string {
 	s.sessionCounter++
 	return fmt.Sprintf("%s/sessions/%d", database, s.sessionCounter)
 }
@@ -524,13 +541,16 @@
 		}
 		totalExecutionTime := time.Duration(int64(executionTime.MinimumExecutionTime) + randTime)
 		<-time.After(totalExecutionTime)
+		s.mu.Lock()
 		if executionTime.Errors != nil && len(executionTime.Errors) > 0 {
 			err := executionTime.Errors[0]
 			if !executionTime.KeepError {
 				executionTime.Errors = executionTime.Errors[1:]
 			}
+			s.mu.Unlock()
 			return err
 		}
+		s.mu.Unlock()
 	}
 	return nil
 }
@@ -542,16 +562,52 @@
 	if req.Database == "" {
 		return nil, gstatus.Error(codes.InvalidArgument, "Missing database")
 	}
-	sessionName := s.generateSessionName(req.Database)
+	s.mu.Lock()
+	defer s.mu.Unlock()
+	if s.maxSessionsReturnedByServerInTotal > int32(0) && int32(len(s.sessions)) == s.maxSessionsReturnedByServerInTotal {
+		return nil, gstatus.Error(codes.ResourceExhausted, "No more sessions available")
+	}
+	sessionName := s.generateSessionNameLocked(req.Database)
 	ts := getCurrentTimestamp()
 	session := &spannerpb.Session{Name: sessionName, CreateTime: ts, ApproximateLastUseTime: ts}
-	s.mu.Lock()
 	s.totalSessionsCreated++
 	s.sessions[sessionName] = session
-	s.mu.Unlock()
 	return session, nil
 }
 
+func (s *inMemSpannerServer) BatchCreateSessions(ctx context.Context, req *spannerpb.BatchCreateSessionsRequest) (*spannerpb.BatchCreateSessionsResponse, error) {
+	if err := s.simulateExecutionTime(MethodBatchCreateSession, req); err != nil {
+		return nil, err
+	}
+	if req.Database == "" {
+		return nil, gstatus.Error(codes.InvalidArgument, "Missing database")
+	}
+	if req.SessionCount <= 0 {
+		return nil, gstatus.Error(codes.InvalidArgument, "Session count must be > 0")
+	}
+	sessionsToCreate := req.SessionCount
+	s.mu.Lock()
+	defer s.mu.Unlock()
+	if s.maxSessionsReturnedByServerInTotal > int32(0) && int32(len(s.sessions)) >= s.maxSessionsReturnedByServerInTotal {
+		return nil, gstatus.Error(codes.ResourceExhausted, "No more sessions available")
+	}
+	if sessionsToCreate > s.maxSessionsReturnedByServerPerBatchRequest {
+		sessionsToCreate = s.maxSessionsReturnedByServerPerBatchRequest
+	}
+	if s.maxSessionsReturnedByServerInTotal > int32(0) && (sessionsToCreate+int32(len(s.sessions))) > s.maxSessionsReturnedByServerInTotal {
+		sessionsToCreate = s.maxSessionsReturnedByServerInTotal - int32(len(s.sessions))
+	}
+	sessions := make([]*spannerpb.Session, sessionsToCreate)
+	for i := int32(0); i < sessionsToCreate; i++ {
+		sessionName := s.generateSessionNameLocked(req.Database)
+		ts := getCurrentTimestamp()
+		sessions[i] = &spannerpb.Session{Name: sessionName, CreateTime: ts, ApproximateLastUseTime: ts}
+		s.totalSessionsCreated++
+		s.sessions[sessionName] = sessions[i]
+	}
+	return &spannerpb.BatchCreateSessionsResponse{Session: sessions}, nil
+}
+
 func (s *inMemSpannerServer) GetSession(ctx context.Context, req *spannerpb.GetSessionRequest) (*spannerpb.Session, error) {
 	if err := s.simulateExecutionTime(MethodGetSession, req); err != nil {
 		return nil, err
diff --git a/spanner/oc_test.go b/spanner/oc_test.go
index 7e95713..3e3d9f8 100644
--- a/spanner/oc_test.go
+++ b/spanner/oc_test.go
@@ -33,7 +33,10 @@
 	ms := stestutil.NewMockCloudSpanner(t, trxTs)
 	ms.Serve()
 	ctx := context.Background()
-	c, err := NewClient(ctx, "projects/P/instances/I/databases/D",
+	c, err := NewClientWithConfig(ctx, "projects/P/instances/I/databases/D",
+		ClientConfig{SessionPoolConfig: SessionPoolConfig{
+			MinOpened: 0,
+		}},
 		option.WithEndpoint(ms.Addr()),
 		option.WithGRPCDialOption(grpc.WithInsecure()),
 		option.WithoutAuthentication())
diff --git a/spanner/pdml.go b/spanner/pdml.go
index 242a48e..6f160a2 100644
--- a/spanner/pdml.go
+++ b/spanner/pdml.go
@@ -41,9 +41,8 @@
 		s  *session
 		sh *sessionHandle
 	)
-	// Create a session that will be used only for this request.
-	sc := c.rrNext()
-	s, err = createSession(ctx, sc, c.database, c.sessionLabels, c.md)
+	// Create session.
+	s, err = c.sc.createSession(ctx)
 	if err != nil {
 		return 0, toSpannerError(err)
 	}
diff --git a/spanner/session.go b/spanner/session.go
index 20329c3..3010b44 100644
--- a/spanner/session.go
+++ b/spanner/session.go
@@ -22,6 +22,7 @@
 	"context"
 	"fmt"
 	"log"
+	"math"
 	"math/rand"
 	"strings"
 	"sync"
@@ -185,7 +186,7 @@
 	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
 	defer cancel()
 	// s.getID is safe even when s is invalid.
-	_, err := s.client.GetSession(contextWithOutgoingMetadata(ctx, s.pool.md), &sppb.GetSessionRequest{Name: s.getID()})
+	_, err := s.client.GetSession(contextWithOutgoingMetadata(ctx, s.md), &sppb.GetSessionRequest{Name: s.getID()})
 	return err
 }
 
@@ -303,7 +304,7 @@
 	if s.isWritePrepared() {
 		return nil
 	}
-	tx, err := beginTransaction(ctx, s.getID(), s.client)
+	tx, err := beginTransaction(contextWithOutgoingMetadata(ctx, s.md), s.getID(), s.client)
 	if err != nil {
 		return err
 	}
@@ -313,10 +314,6 @@
 
 // SessionPoolConfig stores configurations of a session pool.
 type SessionPoolConfig struct {
-	// getRPCClient is the caller supplied method for getting a gRPC client to
-	// Cloud Spanner, this makes session pool able to use client pooling.
-	getRPCClient func() (*vkit.Client, error)
-
 	// MaxOpened is the maximum number of opened sessions allowed by the session
 	// pool. If the client tries to open a session and there are already
 	// MaxOpened sessions, it will block until one becomes available or the
@@ -332,7 +329,7 @@
 	// therefore it is posssible that the number of opened sessions drops below
 	// MinOpened.
 	//
-	// Defaults to 0.
+	// Defaults to 100.
 	MinOpened uint64
 
 	// MaxIdle is the maximum number of idle sessions, pool is allowed to keep.
@@ -348,7 +345,7 @@
 	// WriteSessions is the fraction of sessions we try to keep prepared for
 	// write.
 	//
-	// Defaults to 0.
+	// Defaults to 0.2.
 	WriteSessions float64
 
 	// HealthCheckWorkers is number of workers used by health checker for this
@@ -372,25 +369,59 @@
 	sessionLabels map[string]string
 }
 
-// errNoRPCGetter returns error for SessionPoolConfig missing getRPCClient method.
-func errNoRPCGetter() error {
-	return spannerErrorf(codes.InvalidArgument, "require SessionPoolConfig.getRPCClient != nil, got nil")
+// DefaultSessionPoolConfig is the default configuration for the session pool
+// that will be used for a Spanner client, unless the user supplies a specific
+// session pool config.
+var DefaultSessionPoolConfig = SessionPoolConfig{
+	MinOpened:           100,
+	MaxOpened:           numChannels * 100,
+	MaxBurst:            10,
+	WriteSessions:       0.2,
+	HealthCheckWorkers:  10,
+	HealthCheckInterval: 5 * time.Minute,
 }
 
 // errMinOpenedGTMapOpened returns error for SessionPoolConfig.MaxOpened < SessionPoolConfig.MinOpened when SessionPoolConfig.MaxOpened is set.
 func errMinOpenedGTMaxOpened(maxOpened, minOpened uint64) error {
 	return spannerErrorf(codes.InvalidArgument,
-		"require SessionPoolConfig.MaxOpened >= SessionPoolConfig.MinOpened, got %v and %v", maxOpened, minOpened)
+		"require SessionPoolConfig.MaxOpened >= SessionPoolConfig.MinOpened, got %d and %d", maxOpened, minOpened)
+}
+
+// errWriteFractionOutOfRange returns an error for
+// SessionPoolConfig.WriteSessions < 0 or SessionPoolConfig.WriteSessions > 1.
+func errWriteFractionOutOfRange(writeFraction float64) error {
+	return spannerErrorf(codes.InvalidArgument,
+		"require SessionPoolConfig.WriteSessions >= 0.0 && SessionPoolConfig.WriteSessions <= 1.0, got %.2f", writeFraction)
+}
+
+// errHealthCheckWorkersNegative returns error for
+// SessionPoolConfig.HealthCheckWorkers < 0
+func errHealthCheckWorkersNegative(workers int) error {
+	return spannerErrorf(codes.InvalidArgument,
+		"require SessionPoolConfig.HealthCheckWorkers >= 0, got %d", workers)
+}
+
+// errHealthCheckIntervalNegative returns error for
+// SessionPoolConfig.HealthCheckInterval < 0
+func errHealthCheckIntervalNegative(interval time.Duration) error {
+	return spannerErrorf(codes.InvalidArgument,
+		"require SessionPoolConfig.HealthCheckInterval >= 0, got %v", interval)
 }
 
 // validate verifies that the SessionPoolConfig is good for use.
 func (spc *SessionPoolConfig) validate() error {
-	if spc.getRPCClient == nil {
-		return errNoRPCGetter()
-	}
 	if spc.MinOpened > spc.MaxOpened && spc.MaxOpened > 0 {
 		return errMinOpenedGTMaxOpened(spc.MaxOpened, spc.MinOpened)
 	}
+	if spc.WriteSessions < 0.0 || spc.WriteSessions > 1.0 {
+		return errWriteFractionOutOfRange(spc.WriteSessions)
+	}
+	if spc.HealthCheckWorkers < 0 {
+		return errHealthCheckWorkersNegative(spc.HealthCheckWorkers)
+	}
+	if spc.HealthCheckInterval < 0 {
+		return errHealthCheckIntervalNegative(spc.HealthCheckInterval)
+	}
 	return nil
 }
 
@@ -400,8 +431,8 @@
 	mu sync.Mutex
 	// valid marks the validity of the session pool.
 	valid bool
-	// db is the database name that all sessions in the pool are associated with.
-	db string
+	// sc is used to create the sessions for the pool.
+	sc *sessionClient
 	// idleList caches idle session IDs. Session IDs in this list can be
 	// allocated for use.
 	idleList list.List
@@ -418,23 +449,20 @@
 	prepareReqs uint64
 	// configuration of the session pool.
 	SessionPoolConfig
-	// Metadata to be sent with each request
-	md metadata.MD
 	// hc is the health checker
 	hc *healthChecker
 }
 
 // newSessionPool creates a new session pool.
-func newSessionPool(db string, config SessionPoolConfig, md metadata.MD) (*sessionPool, error) {
+func newSessionPool(sc *sessionClient, config SessionPoolConfig) (*sessionPool, error) {
 	if err := config.validate(); err != nil {
 		return nil, err
 	}
 	pool := &sessionPool{
-		db:                db,
+		sc:                sc,
 		valid:             true,
 		mayGetSession:     make(chan struct{}),
 		SessionPoolConfig: config,
-		md:                md,
 	}
 	if config.HealthCheckWorkers == 0 {
 		// With 10 workers and assuming average latency of 5ms for
@@ -456,10 +484,76 @@
 	// healthChecker can effectively mantain
 	// 100 checks_per_worker/sec * 10 workers * 300 seconds = 300K sessions.
 	pool.hc = newHealthChecker(config.HealthCheckInterval, config.HealthCheckWorkers, config.healthCheckSampleInterval, pool)
+
+	// First initialize the pool before we indicate that the healthchecker is
+	// ready. This prevents the maintainer from starting before the pool has
+	// been initialized, which means that we guarantee that the initial
+	// sessions are created using BatchCreateSessions.
+	if config.MinOpened > 0 {
+		numSessions := minUint64(config.MinOpened, math.MaxInt32)
+		if err := pool.initPool(int32(numSessions)); err != nil {
+			return nil, err
+		}
+	}
 	close(pool.hc.ready)
 	return pool, nil
 }
 
+func (p *sessionPool) initPool(numSessions int32) error {
+	p.mu.Lock()
+	// Take budget before the actual session creation.
+	p.numOpened += uint64(numSessions)
+	recordStat(context.Background(), OpenSessionCount, int64(p.numOpened))
+	p.createReqs += uint64(numSessions)
+	p.mu.Unlock()
+	// Asynchronously create the initial sessions for the pool.
+	return p.sc.batchCreateSessions(numSessions, p)
+}
+
+// sessionReady is executed by the sessionClient when a session has been
+// created and is ready for use. This method adds the new session to the
+// pool and decreases the number of sessions that are being created.
+func (p *sessionPool) sessionReady(s *session) {
+	p.mu.Lock()
+	defer p.mu.Unlock()
+	// Set this pool as the home pool of the session and register it with the
+	// health checker.
+	s.pool = p
+	p.hc.register(s)
+	p.createReqs--
+	// Insert the session at a random position in the pool to prevent the
+	// sessions affiliated with one channel from being placed sequentially in
+	// the pool.
+	if p.idleList.Len() > 0 {
+		pos := rand.Intn(p.idleList.Len())
+		before := p.idleList.Front()
+		for i := 0; i < pos; i++ {
+			before = before.Next()
+		}
+		s.setIdleList(p.idleList.InsertBefore(s, before))
+	} else {
+		s.setIdleList(p.idleList.PushBack(s))
+	}
+	// Notify other waiters blocking on session creation.
+	close(p.mayGetSession)
+	p.mayGetSession = make(chan struct{})
+}
+
+// sessionCreationFailed is called by the sessionClient when the creation of
+// one or more requested sessions finishes with an error. sessionCreationFailed
+// will decrease the number of sessions being created and notify any waiters
+// that session creation failed.
+func (p *sessionPool) sessionCreationFailed(err error, numSessions int32) {
+	p.mu.Lock()
+	defer p.mu.Unlock()
+	p.createReqs -= uint64(numSessions)
+	p.numOpened -= uint64(numSessions)
+	recordStat(context.Background(), OpenSessionCount, int64(p.numOpened))
+	// Notify other waiters blocking on session creation.
+	close(p.mayGetSession)
+	p.mayGetSession = make(chan struct{})
+}
+
 // isValid checks if the session pool is still valid.
 func (p *sessionPool) isValid() bool {
 	if p == nil {
@@ -524,12 +618,7 @@
 		p.mayGetSession = make(chan struct{})
 		p.mu.Unlock()
 	}
-	sc, err := p.getRPCClient()
-	if err != nil {
-		doneCreate(false)
-		return nil, err
-	}
-	s, err := createSession(ctx, sc, p.db, p.sessionLabels, p.md)
+	s, err := p.sc.createSession(ctx)
 	if err != nil {
 		doneCreate(false)
 		// Should return error directly because of the previous retries on
@@ -545,20 +634,6 @@
 	return s, nil
 }
 
-func createSession(ctx context.Context, sc *vkit.Client, db string, labels map[string]string, md metadata.MD) (*session, error) {
-	var s *session
-	sid, e := sc.CreateSession(ctx, &sppb.CreateSessionRequest{
-		Database: db,
-		Session:  &sppb.Session{Labels: labels},
-	})
-	if e != nil {
-		return nil, toSpannerError(e)
-	}
-	// If no error, construct the new session.
-	s = &session{valid: true, client: sc, id: sid.Name, createTime: time.Now(), md: md}
-	return s, nil
-}
-
 func (p *sessionPool) isHealthy(s *session) bool {
 	if s.getNextCheck().Add(2 * p.hc.getInterval()).Before(time.Now()) {
 		// TODO: figure out if we need to schedule a new healthcheck worker here.
@@ -577,7 +652,6 @@
 // for read operations.
 func (p *sessionPool) take(ctx context.Context) (*sessionHandle, error) {
 	trace.TracePrintf(ctx, nil, "Acquiring a read-only session")
-	ctx = contextWithOutgoingMetadata(ctx, p.md)
 	for {
 		var (
 			s   *session
@@ -649,7 +723,6 @@
 // returned should be used for read write transactions.
 func (p *sessionPool) takeWriteSession(ctx context.Context) (*sessionHandle, error) {
 	trace.TracePrintf(ctx, nil, "Acquiring a read-write session")
-	ctx = contextWithOutgoingMetadata(ctx, p.md)
 	for {
 		var (
 			s   *session
@@ -1004,7 +1077,7 @@
 		ws := getNextForTx()
 		if ws != nil {
 			ctx, cancel := context.WithTimeout(context.Background(), time.Second)
-			err := ws.prepareForWrite(contextWithOutgoingMetadata(ctx, hc.pool.md))
+			err := ws.prepareForWrite(ctx)
 			cancel()
 			if err != nil {
 				// Skip handling prepare error, session can be prepared in next
@@ -1044,6 +1117,9 @@
 	// Wait so that pool is ready.
 	<-hc.ready
 
+	// A maintenance window is 10 iterations. The maintainer executes one
+	// iteration every hc.sampleInterval, which defaults to 1 minute, so the
+	// default maintenance window is 10 minutes.
 	windowSize := uint64(10)
 
 	for iteration := uint64(0); ; iteration++ {
diff --git a/spanner/session_test.go b/spanner/session_test.go
index 449ab5b..17b8ea2 100644
--- a/spanner/session_test.go
+++ b/spanner/session_test.go
@@ -25,8 +25,8 @@
 	"testing"
 	"time"
 
-	vkit "cloud.google.com/go/spanner/apiv1"
 	. "cloud.google.com/go/spanner/internal/testutil"
+	sppb "google.golang.org/genproto/googleapis/spanner/v1"
 	"google.golang.org/grpc/codes"
 	"google.golang.org/grpc/status"
 )
@@ -42,21 +42,38 @@
 		err error
 	}{
 		{
-			SessionPoolConfig{},
-			errNoRPCGetter(),
-		},
-		{
 			SessionPoolConfig{
-				getRPCClient: func() (*vkit.Client, error) {
-					return client.clients[0], nil
-				},
 				MinOpened: 10,
 				MaxOpened: 5,
 			},
 			errMinOpenedGTMaxOpened(5, 10),
 		},
+		{
+			SessionPoolConfig{
+				WriteSessions: -0.1,
+			},
+			errWriteFractionOutOfRange(-0.1),
+		},
+		{
+			SessionPoolConfig{
+				WriteSessions: 2.0,
+			},
+			errWriteFractionOutOfRange(2.0),
+		},
+		{
+			SessionPoolConfig{
+				HealthCheckWorkers: -1,
+			},
+			errHealthCheckWorkersNegative(-1),
+		},
+		{
+			SessionPoolConfig{
+				HealthCheckInterval: -time.Second,
+			},
+			errHealthCheckIntervalNegative(-time.Second),
+		},
 	} {
-		if _, err := newSessionPool("mockdb", test.spc, nil); !testEqual(err, test.err) {
+		if _, err := newSessionPool(client.sc, test.spc); !testEqual(err, test.err) {
 			t.Fatalf("want %v, got %v", test.err, err)
 		}
 	}
@@ -459,8 +476,8 @@
 	defer sp.mu.Unlock()
 	// There should be still one session left in idle list due to the min open
 	// sessions constraint.
-	if sp.idleList.Len() != 1 {
-		t.Fatalf("got %v sessions in idle list, want 1 %d", sp.idleList.Len(), sp.numOpened)
+	if sp.idleList.Len() != int(sp.MinOpened) {
+		t.Fatalf("got %v sessions in idle list, want %d", sp.idleList.Len(), sp.MinOpened)
 	}
 }
 
@@ -1100,7 +1117,7 @@
 	})
 }
 
-// Tests that maintainer creates up to MinOpened connections.
+// Tests that the session pool creates up to MinOpened connections.
 //
 // Historical context: This test also checks that a low
 // healthCheckSampleInterval does not prevent it from opening connections.
@@ -1108,11 +1125,57 @@
 // creations to time out. That should not be considered a problem, but it
 // could cause the test case to fail if it happens too often.
 // See: https://github.com/googleapis/google-cloud-go/issues/1259
-func TestMaintainer_CreatesSessions(t *testing.T) {
+func TestInit_CreatesSessions(t *testing.T) {
 	t.Parallel()
 	spc := SessionPoolConfig{
 		MinOpened:                 10,
 		MaxIdle:                   10,
+		WriteSessions:             0.0,
+		healthCheckSampleInterval: 20 * time.Millisecond,
+	}
+	server, client, teardown := setupMockedTestServerWithConfig(t,
+		ClientConfig{
+			SessionPoolConfig: spc,
+			NumChannels:       4,
+		})
+	defer teardown()
+	sp := client.idleSessions
+
+	timeout := time.After(4 * time.Second)
+	var numOpened int
+loop:
+	for {
+		select {
+		case <-timeout:
+			t.Fatalf("timed out, got %d session(s), want %d", numOpened, spc.MinOpened)
+		default:
+			sp.mu.Lock()
+			numOpened = sp.idleList.Len() + sp.idleWriteList.Len()
+			sp.mu.Unlock()
+			if numOpened == 10 {
+				break loop
+			}
+		}
+	}
+	_, err := shouldHaveReceived(server.TestSpanner, []interface{}{
+		&sppb.BatchCreateSessionsRequest{},
+		&sppb.BatchCreateSessionsRequest{},
+		&sppb.BatchCreateSessionsRequest{},
+		&sppb.BatchCreateSessionsRequest{},
+	})
+	if err != nil {
+		t.Fatal(err)
+	}
+}
+
+// Tests that a session pool with MinOpened > 0 also prepares the configured
+// WriteSessions fraction of the sessions for writes.
+func TestInit_PreparesSessions(t *testing.T) {
+	t.Parallel()
+	spc := SessionPoolConfig{
+		MinOpened:                 10,
+		MaxIdle:                   10,
+		WriteSessions:             0.5,
 		healthCheckSampleInterval: 20 * time.Millisecond,
 	}
 	_, client, teardown := setupMockedTestServerWithConfig(t,
@@ -1124,17 +1187,18 @@
 
 	timeoutAmt := 4 * time.Second
 	timeout := time.After(timeoutAmt)
-	var numOpened uint64
+	var numPrepared int
+	want := int(spc.WriteSessions * float64(spc.MinOpened))
 loop:
 	for {
 		select {
 		case <-timeout:
-			t.Fatalf("timed out after %v, got %d session(s), want %d", timeoutAmt, numOpened, spc.MinOpened)
+			t.Fatalf("timed out after %v, got %d write-prepared session(s), want %d", timeoutAmt, numPrepared, want)
 		default:
 			sp.mu.Lock()
-			numOpened = sp.numOpened
+			numPrepared = sp.idleWriteList.Len()
 			sp.mu.Unlock()
-			if numOpened == 10 {
+			if numPrepared == want {
 				break loop
 			}
 		}
diff --git a/spanner/sessionclient.go b/spanner/sessionclient.go
new file mode 100644
index 0000000..b3857a0
--- /dev/null
+++ b/spanner/sessionclient.go
@@ -0,0 +1,233 @@
+/*
+Copyright 2019 Google LLC
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+package spanner
+
+import (
+	"context"
+	"fmt"
+	"sync"
+	"time"
+
+	"cloud.google.com/go/internal/trace"
+	vkit "cloud.google.com/go/spanner/apiv1"
+	sppb "google.golang.org/genproto/googleapis/spanner/v1"
+	"google.golang.org/grpc/codes"
+	"google.golang.org/grpc/metadata"
+)
+
+// sessionConsumer is passed to the batchCreateSessions method and will receive
+// the sessions that are created as they become available. A sessionConsumer
+// implementation must be safe for concurrent use.
+//
+// The interface is implemented by sessionPool and is used for testing the
+// sessionClient.
+type sessionConsumer interface {
+	// sessionReady is called when a session has been created and is ready for
+	// use.
+	sessionReady(s *session)
+
+	// sessionCreationFailed is called when the creation of a sub-batch of
+	// sessions failed. The numSessions argument specifies the number of
+	// sessions that could not be created as a result of this error. A
+	// consumer may receive multiple errors per batch.
+	sessionCreationFailed(err error, numSessions int32)
+}
+
+// sessionClient creates sessions for a database, either in batches or one at a
+// time. Each session will be affiliated with a gRPC channel. sessionClient
+// will ensure that the sessions that are created are evenly distributed over
+// all available channels.
+type sessionClient struct {
+	mu     sync.Mutex
+	rr     int
+	closed bool
+
+	gapicClients  []*vkit.Client
+	database      string
+	sessionLabels map[string]string
+	md            metadata.MD
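+	// batchTimeout is the timeout used for each BatchCreateSessions RPC that
+	// is executed as part of a call to batchCreateSessions.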
+	batchTimeout  time.Duration
+}
+
+// newSessionClient creates a session client to use for a database.
+func newSessionClient(gapicClients []*vkit.Client, database string, sessionLabels map[string]string, md metadata.MD) *sessionClient {
+	return &sessionClient{
+		gapicClients:  gapicClients,
+		database:      database,
+		sessionLabels: sessionLabels,
+		md:            md,
+		batchTimeout:  time.Minute,
+	}
+}
+
+func (sc *sessionClient) close() error {
+	sc.mu.Lock()
+	defer sc.mu.Unlock()
+	sc.closed = true
+	var errs []error
+	for _, gpc := range sc.gapicClients {
+		if err := gpc.Close(); err != nil {
+			errs = append(errs, err)
+		}
+	}
+	switch len(errs) {
+	case 0:
+		return nil
+	case 1:
+		return errs[0]
+	default:
+		return fmt.Errorf("closing gapic clients returned multiple errors: %v", errs)
+	}
+}
+
+// createSession creates one session for the database of the sessionClient. The
+// session is created using one synchronous RPC.
+func (sc *sessionClient) createSession(ctx context.Context) (*session, error) {
+	ctx = contextWithOutgoingMetadata(ctx, sc.md)
+	sc.mu.Lock()
+	if sc.closed {
+		sc.mu.Unlock()
+		return nil, spannerErrorf(codes.FailedPrecondition, "SessionClient is closed")
+	}
+	client := sc.rrNextGapicClientLocked()
+	sc.mu.Unlock()
+	sid, err := client.CreateSession(ctx, &sppb.CreateSessionRequest{
+		Database: sc.database,
+		Session:  &sppb.Session{Labels: sc.sessionLabels},
+	})
+	if err != nil {
+		return nil, toSpannerError(err)
+	}
+	return &session{valid: true, client: client, id: sid.Name, createTime: time.Now(), md: sc.md}, nil
+}
+
+// batchCreateSessions creates a batch of sessions for the database of the
+// sessionClient and returns these to the given sessionConsumer.
+//
+// createSessionCount is the number of sessions that should be created. The
+// sessionConsumer is guaranteed to receive the requested number of sessions if
+// no error occurs. If one or more errors occur, the sessionConsumer will
+// receive any number of sessions + any number of errors, where each error will
+// include the number of sessions that could not be created as a result of the
+// error. The sum of returned sessions and errored sessions will be equal to
+// the number of requested sessions.
+func (sc *sessionClient) batchCreateSessions(createSessionCount int32, consumer sessionConsumer) error {
+	// The sessions that we create should be evenly distributed over all the
+	// channels (gapic clients) that are used by the client. Each gapic client
+	// will do a request for a fraction of the total.
+	sessionCountPerChannel := createSessionCount / int32(len(sc.gapicClients))
+	// The remainder of the calculation will be added to the number of sessions
+	// that will be created for the first channel, to ensure that we create the
+	// exact number of requested sessions.
+	remainder := createSessionCount % int32(len(sc.gapicClients))
+	sc.mu.Lock()
+	defer sc.mu.Unlock()
+	if sc.closed {
+		return spannerErrorf(codes.FailedPrecondition, "SessionClient is closed")
+	}
+	// Spread the session creation over all available gRPC channels. Spanner
+	// will maintain server side caches for a session on the gRPC channel that
+	// is used by the session. A session should therefore always use the same
+	// channel, and the sessions should be as evenly distributed as possible
+	// over the channels.
+	for i := 0; i < len(sc.gapicClients); i++ {
+		client := sc.rrNextGapicClientLocked()
+		// Determine the number of sessions that should be created for this
+		// channel. The createCount for the first channel is increased by the
+		// remainder of dividing the total number of sessions by the number
+		// of channels. All other channels use the result of the division
+		// as-is.
+		createCountForChannel := sessionCountPerChannel
+		if i == 0 {
+			// We add the remainder to the first gRPC channel we use. We could
+			// also spread the remainder over all channels, but this ensures
+			// that small batches of sessions (i.e. less than numChannels) are
+			// created in one RPC.
+			createCountForChannel += remainder
+		}
+		if createCountForChannel > 0 {
+			go sc.executeBatchCreateSessions(client, createCountForChannel, sc.sessionLabels, sc.md, consumer)
+		}
+	}
+	return nil
+}
+
+// executeBatchCreateSessions executes the gRPC call for creating a batch of
+// sessions.
+func (sc *sessionClient) executeBatchCreateSessions(client *vkit.Client, createCount int32, labels map[string]string, md metadata.MD, consumer sessionConsumer) {
+	ctx, cancel := context.WithTimeout(context.Background(), sc.batchTimeout)
+	defer cancel()
+	ctx = contextWithOutgoingMetadata(ctx, sc.md)
+
+	ctx = trace.StartSpan(ctx, "cloud.google.com/go/spanner.BatchCreateSessions")
+	defer func() { trace.EndSpan(ctx, nil) }()
+	trace.TracePrintf(ctx, nil, "Creating a batch of %d sessions", createCount)
+	remainingCreateCount := createCount
+	for {
+		sc.mu.Lock()
+		closed := sc.closed
+		sc.mu.Unlock()
+		if closed {
+			err := spannerErrorf(codes.Canceled, "Session client closed")
+			trace.TracePrintf(ctx, nil, "Session client closed while creating a batch of %d sessions: %v", createCount, err)
+			consumer.sessionCreationFailed(err, remainingCreateCount)
+			break
+		}
+		if ctx.Err() != nil {
+			trace.TracePrintf(ctx, nil, "Context error while creating a batch of %d sessions: %v", createCount, ctx.Err())
+			consumer.sessionCreationFailed(ctx.Err(), remainingCreateCount)
+			break
+		}
+		response, err := client.BatchCreateSessions(ctx, &sppb.BatchCreateSessionsRequest{
+			SessionCount:    remainingCreateCount,
+			Database:        sc.database,
+			SessionTemplate: &sppb.Session{Labels: labels},
+		})
+		if err != nil {
+			trace.TracePrintf(ctx, nil, "Error creating a batch of %d sessions: %v", remainingCreateCount, err)
+			consumer.sessionCreationFailed(err, remainingCreateCount)
+			break
+		}
+		actuallyCreated := int32(len(response.Session))
+		trace.TracePrintf(ctx, nil, "Received a batch of %d sessions", actuallyCreated)
+		for _, s := range response.Session {
+			consumer.sessionReady(&session{valid: true, client: client, id: s.Name, createTime: time.Now(), md: md})
+		}
+		if actuallyCreated < remainingCreateCount {
+			// Spanner could return fewer sessions than requested. In that case,
+			// we should do another call using the same gRPC channel.
+			remainingCreateCount -= actuallyCreated
+		} else {
+			trace.TracePrintf(ctx, nil, "Finished creating %d sessions", createCount)
+			break
+		}
+	}
+}
+
+func (sc *sessionClient) sessionWithID(id string) *session {
+	sc.mu.Lock()
+	defer sc.mu.Unlock()
+	return &session{valid: true, client: sc.rrNextGapicClientLocked(), id: id, createTime: time.Now(), md: sc.md}
+}
+
+// rrNextGapicClientLocked returns the next gRPC client to use for session creation. The
+// client is set on the session, and used by all subsequent gRPC calls on the
+// session. Using the same channel for all gRPC calls for a session ensures the
+// optimal usage of server side caches.
+func (sc *sessionClient) rrNextGapicClientLocked() *vkit.Client {
+	sc.rr = (sc.rr + 1) % len(sc.gapicClients)
+	return sc.gapicClients[sc.rr]
+}
diff --git a/spanner/sessionclient_test.go b/spanner/sessionclient_test.go
new file mode 100644
index 0000000..82be7fd
--- /dev/null
+++ b/spanner/sessionclient_test.go
@@ -0,0 +1,315 @@
+/*
+Copyright 2019 Google LLC
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+package spanner
+
+import (
+	"context"
+	"sync"
+	"testing"
+	"time"
+
+	vkit "cloud.google.com/go/spanner/apiv1"
+	. "cloud.google.com/go/spanner/internal/testutil"
+	"google.golang.org/grpc/codes"
+	"google.golang.org/grpc/status"
+)
+
+type testSessionCreateError struct {
+	err error
+	num int32
+}
+
+type testConsumer struct {
+	numExpected int32
+
+	mu       sync.Mutex
+	sessions []*session
+	errors   []*testSessionCreateError
+	numErr   int32
+
+	receivedAll chan struct{}
+}
+
+func (tc *testConsumer) sessionReady(s *session) {
+	tc.mu.Lock()
+	defer tc.mu.Unlock()
+	tc.sessions = append(tc.sessions, s)
+	tc.checkReceivedAll()
+}
+
+func (tc *testConsumer) sessionCreationFailed(err error, num int32) {
+	tc.mu.Lock()
+	defer tc.mu.Unlock()
+	tc.errors = append(tc.errors, &testSessionCreateError{
+		err: err,
+		num: num,
+	})
+	tc.numErr += num
+	tc.checkReceivedAll()
+}
+
+func (tc *testConsumer) checkReceivedAll() {
+	if int32(len(tc.sessions))+tc.numErr == tc.numExpected {
+		close(tc.receivedAll)
+	}
+}
+
+func newTestConsumer(numExpected int32) *testConsumer {
+	return &testConsumer{
+		numExpected: numExpected,
+		receivedAll: make(chan struct{}),
+	}
+}
+
+func TestCreateAndCloseSession(t *testing.T) {
+	t.Parallel()
+
+	server, client, teardown := setupMockedTestServerWithConfig(t, ClientConfig{
+		SessionPoolConfig: SessionPoolConfig{
+			MinOpened: 0,
+			MaxOpened: 100,
+		},
+	})
+	defer teardown()
+
+	s, err := client.sc.createSession(context.Background())
+	if err != nil {
+		t.Fatalf("createSession() return error mismatch\ngot: %v\nwant: nil", err)
+	}
+	if s == nil {
+		t.Fatalf("createSession() return value mismatch\ngot: %v\nwant: any session", s)
+	}
+	if server.TestSpanner.TotalSessionsCreated() != 1 {
+		t.Fatalf("number of sessions created mismatch\ngot: %v\nwant: %v", server.TestSpanner.TotalSessionsCreated(), 1)
+	}
+	s.delete(context.Background())
+	if server.TestSpanner.TotalSessionsDeleted() != 1 {
+		t.Fatalf("number of sessions deleted mismatch\ngot: %v\nwant: %v", server.TestSpanner.TotalSessionsDeleted(), 1)
+	}
+}
+
+func TestBatchCreateAndCloseSession(t *testing.T) {
+	t.Parallel()
+
+	numSessions := int32(100)
+	server, opts, serverTeardown := NewMockedSpannerInMemTestServer(t)
+	defer serverTeardown()
+	for numChannels := 1; numChannels <= 32; numChannels *= 2 {
+		prevCreated := server.TestSpanner.TotalSessionsCreated()
+		prevDeleted := server.TestSpanner.TotalSessionsDeleted()
+		client, err := NewClientWithConfig(context.Background(), "projects/p/instances/i/databases/d", ClientConfig{
+			NumChannels: numChannels,
+			SessionPoolConfig: SessionPoolConfig{
+				MinOpened: 0,
+				MaxOpened: 400,
+			}}, opts...)
+		if err != nil {
+			t.Fatal(err)
+		}
+		consumer := newTestConsumer(numSessions)
+		client.sc.batchCreateSessions(numSessions, consumer)
+		<-consumer.receivedAll
+		if len(consumer.sessions) != int(numSessions) {
+			t.Fatalf("returned number of sessions mismatch\ngot: %v\nwant: %v", len(consumer.sessions), numSessions)
+		}
+		created := server.TestSpanner.TotalSessionsCreated() - prevCreated
+		if created != uint(numSessions) {
+			t.Fatalf("number of sessions created mismatch\ngot: %v\nwant: %v", created, numSessions)
+		}
+		// Check that all channels are used evenly.
+		channelCounts := make(map[*vkit.Client]int32)
+		for _, s := range consumer.sessions {
+			channelCounts[s.client]++
+		}
+		if len(channelCounts) != numChannels {
+			t.Fatalf("number of channels used mismatch\ngot: %v\nwant: %v", len(channelCounts), numChannels)
+		}
+		for _, c := range channelCounts {
+			if c < numSessions/int32(numChannels) || c > numSessions/int32(numChannels)+(numSessions%int32(numChannels)) {
+				t.Fatalf("channel used an unexpected number of times\ngot: %v\nwant between %v and %v", c, numSessions/int32(numChannels), numSessions/int32(numChannels)+(numSessions%int32(numChannels)))
+			}
+		}
+		// Delete the sessions.
+		for _, s := range consumer.sessions {
+			s.delete(context.Background())
+		}
+		deleted := server.TestSpanner.TotalSessionsDeleted() - prevDeleted
+		if deleted != uint(numSessions) {
+			t.Fatalf("number of sessions deleted mismatch\ngot: %v\nwant %v", deleted, numSessions)
+		}
+		client.Close()
+	}
+}
+
+func TestBatchCreateSessionsWithExceptions(t *testing.T) {
+	t.Parallel()
+
+	numSessions := int32(100)
+	server, opts, serverTeardown := NewMockedSpannerInMemTestServer(t)
+	defer serverTeardown()
+
+	// Run the test with everything between 1 and numChannels errors.
+	for numErrors := int32(1); numErrors <= numChannels; numErrors++ {
+		// Make sure that the error is not always the first call.
+		for firstErrorAt := numErrors - 1; firstErrorAt < numChannels-numErrors+1; firstErrorAt++ {
+			client, err := NewClientWithConfig(context.Background(), "projects/p/instances/i/databases/d", ClientConfig{
+				NumChannels: numChannels,
+				SessionPoolConfig: SessionPoolConfig{
+					MinOpened: 0,
+					MaxOpened: 400,
+				}}, opts...)
+			if err != nil {
+				t.Fatal(err)
+			}
+			// Register the errors on the server.
+			errors := make([]error, numErrors+firstErrorAt)
+			for i := firstErrorAt; i < numErrors+firstErrorAt; i++ {
+				errors[i] = spannerErrorf(codes.FailedPrecondition, "session creation failed")
+			}
+			server.TestSpanner.PutExecutionTime(MethodBatchCreateSession, SimulatedExecutionTime{
+				Errors: errors,
+			})
+			consumer := newTestConsumer(numSessions)
+			client.sc.batchCreateSessions(numSessions, consumer)
+			<-consumer.receivedAll
+
+			sessionsReturned := int32(len(consumer.sessions))
+			if int32(len(consumer.errors)) != numErrors {
+				t.Fatalf("Error count mismatch\nGot: %d\nWant: %d", len(consumer.errors), numErrors)
+			}
+			for _, e := range consumer.errors {
+				if g, w := status.Code(e.err), codes.FailedPrecondition; g != w {
+					t.Fatalf("error code mismatch\ngot: %v\nwant: %v", g, w)
+				}
+			}
+			maxExpectedSessions := numSessions - numErrors*(numSessions/numChannels)
+			minExpectedSessions := numSessions - numErrors*(numSessions/numChannels+1)
+			if sessionsReturned < minExpectedSessions || sessionsReturned > maxExpectedSessions {
+				t.Fatalf("session count mismatch\ngot: %v\nwant between %v and %v", sessionsReturned, minExpectedSessions, maxExpectedSessions)
+			}
+			client.Close()
+		}
+	}
+}
+
+func TestBatchCreateSessions_ServerReturnsLessThanRequestedSessions(t *testing.T) {
+	t.Parallel()
+
+	numChannels := 4
+	server, client, teardown := setupMockedTestServerWithConfig(t, ClientConfig{
+		NumChannels: numChannels,
+		SessionPoolConfig: SessionPoolConfig{
+			MinOpened: 0,
+			MaxOpened: 100,
+		},
+	})
+	defer teardown()
+	// Ensure that the server will never return more than 10 sessions per batch
+	// create request.
+	server.TestSpanner.SetMaxSessionsReturnedByServerPerBatchRequest(10)
+	numSessions := int32(100)
+	// Request a batch of sessions that is larger than will be returned by the
+	// server in one request. The server will return at most 10 sessions per
+	// request. The sessionCreator will spread these requests over the 4
+	// channels that are available, i.e. do requests for 25 sessions in each
+	// request. The batch should still return 100 sessions.
+	consumer := newTestConsumer(numSessions)
+	client.sc.batchCreateSessions(numSessions, consumer)
+	<-consumer.receivedAll
+	if len(consumer.errors) > 0 {
+		t.Fatalf("Error count mismatch\nGot: %d\nWant: %d", len(consumer.errors), 0)
+	}
+	returnedSessionCount := int32(len(consumer.sessions))
+	if returnedSessionCount != numSessions {
+		t.Fatalf("Returned sessions mismatch\nGot: %v\nWant: %v", returnedSessionCount, numSessions)
+	}
+}
+
+func TestBatchCreateSessions_ServerExhausted(t *testing.T) {
+	t.Parallel()
+
+	numChannels := 4
+	server, client, teardown := setupMockedTestServerWithConfig(t, ClientConfig{
+		NumChannels: numChannels,
+		SessionPoolConfig: SessionPoolConfig{
+			MinOpened: 0,
+			MaxOpened: 100,
+		},
+	})
+	defer teardown()
+	numSessions := int32(100)
+	maxSessions := int32(50)
+	// Ensure that the server will never return more than 50 sessions in total.
+	server.TestSpanner.SetMaxSessionsReturnedByServerInTotal(maxSessions)
+	consumer := newTestConsumer(numSessions)
+	client.sc.batchCreateSessions(numSessions, consumer)
+	<-consumer.receivedAll
+	// Session creation should end with at least one RESOURCE_EXHAUSTED error.
+	if len(consumer.errors) == 0 {
+		t.Fatalf("Error count mismatch\nGot: %d\nWant: > %d", len(consumer.errors), 0)
+	}
+	for _, e := range consumer.errors {
+		if g, w := status.Code(e.err), codes.ResourceExhausted; g != w {
+			t.Fatalf("Error code mismatch\nGot: %v\nWant: %v", g, w)
+		}
+	}
+	// The number of returned sessions should be equal to the max of the
+	// server.
+	returnedSessionCount := int32(len(consumer.sessions))
+	if returnedSessionCount != maxSessions {
+		t.Fatalf("Returned sessions mismatch\nGot: %v\nWant: %v", returnedSessionCount, maxSessions)
+	}
+	if consumer.numErr != (numSessions - maxSessions) {
+		t.Fatalf("Num errored sessions mismatch\nGot: %v\nWant: %v", consumer.numErr, numSessions-maxSessions)
+	}
+}
+
+func TestBatchCreateSessions_WithTimeout(t *testing.T) {
+	t.Parallel()
+
+	numSessions := int32(100)
+	server, opts, serverTeardown := NewMockedSpannerInMemTestServer(t)
+	defer serverTeardown()
+	server.TestSpanner.PutExecutionTime(MethodBatchCreateSession, SimulatedExecutionTime{
+		MinimumExecutionTime: time.Second,
+	})
+	client, err := NewClientWithConfig(context.Background(), "projects/p/instances/i/databases/d", ClientConfig{
+		SessionPoolConfig: SessionPoolConfig{
+			MinOpened: 0,
+			MaxOpened: 400,
+		}}, opts...)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	client.sc.batchTimeout = 10 * time.Millisecond
+	consumer := newTestConsumer(numSessions)
+	client.sc.batchCreateSessions(numSessions, consumer)
+	<-consumer.receivedAll
+	if len(consumer.sessions) > 0 {
+		t.Fatalf("Returned number of sessions mismatch\ngot: %v\nwant: %v", len(consumer.sessions), 0)
+	}
+	if len(consumer.errors) != numChannels {
+		t.Fatalf("Returned number of errors mismatch\ngot: %v\nwant: %v", len(consumer.errors), numChannels)
+	}
+	for _, e := range consumer.errors {
+		if g, w := status.Code(e.err), codes.DeadlineExceeded; g != w {
+			t.Fatalf("Error code mismatch\ngot: %v\nwant: %v", g, w)
+		}
+	}
+	client.Close()
+}