spanner: correctly allow MinOpened sessions to be spun up

The session maintainer incorrectly uses the channel returned by time.After
in two select statements. However, time.After does not _close_ its channel,
it _sends a single value_ on it, so only one of the two select statements
ever observes the timeout. This CL changes the behavior to use a
context.WithTimeout instead, whose ctx.Done() channel is closed at the
deadline and can therefore be observed by any number of selects.
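
For illustration only (not part of this change), a minimal sketch of the
difference between the two primitives:

    d := 10 * time.Millisecond

    timeout := time.After(d)
    <-timeout    // receives the single value sent after d
    <-timeout    // blocks forever: nothing else is ever sent

    ctx, cancel := context.WithTimeout(context.Background(), d)
    defer cancel()
    <-ctx.Done() // Done returns a channel that is closed at the deadline,
    <-ctx.Done() // so every subsequent receive unblocks immediately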

Also moves shrinkPool and replenishPool into their own methods.

Fixes #1259

Change-Id: Ide2e417ca51ea9bc2416bde5e19d82defe427da6
Reviewed-on: https://code-review.googlesource.com/c/36711
Reviewed-by: kokoro <noreply+kokoro@google.com>
Reviewed-by: Jonathan Amsterdam <jba@google.com>
diff --git a/spanner/session.go b/spanner/session.go
index 784ed76..4c33dad 100644
--- a/spanner/session.go
+++ b/spanner/session.go
@@ -961,83 +961,8 @@
 	var (
 		windowSize uint64 = 10
 		iteration  uint64
-		timeout    <-chan time.Time
 	)
 
-	// replenishPool is run if numOpened is less than sessionsToKeep, timeouts on sampleInterval.
-	replenishPool := func(sessionsToKeep uint64) {
-		ctx, _ := context.WithTimeout(context.Background(), hc.sampleInterval)
-		for {
-			select {
-			case <-timeout:
-				return
-			default:
-			}
-
-			p := hc.pool
-			p.mu.Lock()
-			// Take budget before the actual session creation.
-			if sessionsToKeep <= p.numOpened {
-				p.mu.Unlock()
-				break
-			}
-			p.numOpened++
-			trace.RecordStat(ctx, trace.OpenSessionCount, int64(p.numOpened))
-			p.createReqs++
-			shouldPrepareWrite := p.shouldPrepareWrite()
-			p.mu.Unlock()
-			var (
-				s   *session
-				err error
-			)
-			if s, err = p.createSession(ctx); err != nil {
-				log.Printf("Failed to create session, error: %v", toSpannerError(err))
-				continue
-			}
-			if shouldPrepareWrite {
-				if err = s.prepareForWrite(ctx); err != nil {
-					p.recycle(s)
-					log.Printf("Failed to prepare session, error: %v", toSpannerError(err))
-					continue
-				}
-			}
-			p.recycle(s)
-		}
-	}
-
-	// shrinkPool, scales down the session pool.
-	shrinkPool := func(sessionsToKeep uint64) {
-		for {
-			select {
-			case <-timeout:
-				return
-			default:
-			}
-
-			p := hc.pool
-			p.mu.Lock()
-
-			if sessionsToKeep >= p.numOpened {
-				p.mu.Unlock()
-				break
-			}
-
-			var s *session
-			if p.idleList.Len() > 0 {
-				s = p.idleList.Front().Value.(*session)
-			} else if p.idleWriteList.Len() > 0 {
-				s = p.idleWriteList.Front().Value.(*session)
-			}
-			p.mu.Unlock()
-			if s != nil {
-				// destroy session as expire.
-				s.destroy(true)
-			} else {
-				break
-			}
-		}
-	}
-
 	for {
 		if hc.isClosing() {
 			hc.waitWorkers.Done()
@@ -1061,24 +986,95 @@
 			minUint64(currSessionsOpened, hc.pool.MaxIdle+maxSessionsInUse))
 		hc.mu.Unlock()
 
-		timeout = time.After(hc.sampleInterval)
+		ctx, cancel := context.WithTimeout(context.Background(), hc.sampleInterval)
+
 		// Replenish or Shrink pool if needed.
 		// Note: we don't need to worry about pending create session requests, we only need to sample the current sessions in use.
 		// the routines will not try to create extra / delete creating sessions.
 		if sessionsToKeep > currSessionsOpened {
-			replenishPool(sessionsToKeep)
+			hc.replenishPool(ctx, sessionsToKeep)
 		} else {
-			shrinkPool(sessionsToKeep)
+			hc.shrinkPool(ctx, sessionsToKeep)
 		}
 
 		select {
-		case <-timeout:
+		case <-ctx.Done():
 		case <-hc.done:
+			cancel()
 		}
 		iteration++
 	}
 }
 
+// replenishPool creates new sessions until numOpened reaches sessionsToKeep, or until ctx times out (sampleInterval).
+func (hc *healthChecker) replenishPool(ctx context.Context, sessionsToKeep uint64) {
+	for {
+		if ctx.Err() != nil {
+			return
+		}
+
+		p := hc.pool
+		p.mu.Lock()
+		// Take budget before the actual session creation.
+		if sessionsToKeep <= p.numOpened {
+			p.mu.Unlock()
+			break
+		}
+		p.numOpened++
+		trace.RecordStat(ctx, trace.OpenSessionCount, int64(p.numOpened))
+		p.createReqs++
+		shouldPrepareWrite := p.shouldPrepareWrite()
+		p.mu.Unlock()
+		var (
+			s   *session
+			err error
+		)
+		if s, err = p.createSession(ctx); err != nil {
+			log.Printf("Failed to create session, error: %v", toSpannerError(err))
+			continue
+		}
+		if shouldPrepareWrite {
+			if err = s.prepareForWrite(ctx); err != nil {
+				p.recycle(s)
+				log.Printf("Failed to prepare session, error: %v", toSpannerError(err))
+				continue
+			}
+		}
+		p.recycle(s)
+	}
+}
+
+// shrinkPool scales down the session pool by destroying idle sessions until numOpened reaches sessionsToKeep, or until ctx times out.
+func (hc *healthChecker) shrinkPool(ctx context.Context, sessionsToKeep uint64) {
+	for {
+		if ctx.Err() != nil {
+			return
+		}
+
+		p := hc.pool
+		p.mu.Lock()
+
+		if sessionsToKeep >= p.numOpened {
+			p.mu.Unlock()
+			break
+		}
+
+		var s *session
+		if p.idleList.Len() > 0 {
+			s = p.idleList.Front().Value.(*session)
+		} else if p.idleWriteList.Len() > 0 {
+			s = p.idleWriteList.Front().Value.(*session)
+		}
+		p.mu.Unlock()
+		if s != nil {
+			// destroy session as expire.
+			s.destroy(true)
+		} else {
+			break
+		}
+	}
+}
+
 // shouldDropSession returns true if a particular error leads to the removal of a session
 func shouldDropSession(err error) bool {
 	if err == nil {
diff --git a/spanner/session_test.go b/spanner/session_test.go
index 54b906c..be39ea9 100644
--- a/spanner/session_test.go
+++ b/spanner/session_test.go
@@ -980,6 +980,61 @@
 	})
 }
 
+// Tests that the maintainer creates up to MinOpened sessions.
+//
+// Historical context: This test also checks that a low healthCheckSampleInterval
+// does not prevent it from opening sessions. See: https://github.com/googleapis/google-cloud-go/issues/1259
+func TestMaintainer_CreatesSessions(t *testing.T) {
+	t.Parallel()
+
+	rawServerStub := testutil.NewMockCloudSpannerClient(t)
+	serverClientMock := testutil.FuncMock{MockCloudSpannerClient: rawServerStub}
+	serverClientMock.CreateSessionFn = func(c context.Context, r *sppb.CreateSessionRequest, opts ...grpc.CallOption) (*sppb.Session, error) {
+		time.Sleep(10 * time.Millisecond)
+		return rawServerStub.CreateSession(c, r, opts...)
+	}
+	spc := SessionPoolConfig{
+		MinOpened:                 10,
+		MaxIdle:                   10,
+		healthCheckSampleInterval: time.Millisecond,
+		getRPCClient: func() (sppb.SpannerClient, error) {
+			return &serverClientMock, nil
+		},
+	}
+	db := "mockdb"
+	sp, err := newSessionPool(db, spc, nil)
+	if err != nil {
+		t.Fatalf("cannot create session pool: %v", err)
+	}
+	client := Client{
+		database:     db,
+		idleSessions: sp,
+	}
+	defer func() {
+		client.Close()
+		sp.hc.close()
+		sp.close()
+	}()
+
+	timeoutAmt := 2 * time.Second
+	timeout := time.After(timeoutAmt)
+	var numOpened uint64
+loop:
+	for {
+		select {
+		case <-timeout:
+			t.Fatalf("timed out after %v, got %d session(s), want %d", timeoutAmt, numOpened, spc.MinOpened)
+		default:
+			sp.mu.Lock()
+			numOpened = sp.numOpened
+			sp.mu.Unlock()
+			if numOpened == 10 {
+				break loop
+			}
+		}
+	}
+}
+
 func (s1 *session) Equal(s2 *session) bool {
 	return s1.client == s2.client &&
 		s1.id == s2.id &&