spanner: correctly allow MinOpened sessions to be spun up
The session maintainer incorrectly used the channel returned by time.After in two
select statements. However, time.After does not _close_ its channel, it _sends_ a
single value on it, so only one of the select statements ever receives the signal.
This CL changes the behavior to use a context.WithTimeout instead; its Done channel
is closed when the timeout fires, so every select statement observes it.
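
For illustration only (this sketch is not part of the CL), the difference can be
seen in a small standalone program: the channel returned by time.After delivers a
single value, so only the first receiver observes it, while the Done channel of an
expired context is closed, so every receiver observes it.

    package main

    import (
        "context"
        "fmt"
        "time"
    )

    func main() {
        after := time.After(10 * time.Millisecond)
        ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond)
        defer cancel()

        // Let both the timer and the context deadline fire.
        time.Sleep(20 * time.Millisecond)

        for i := 0; i < 2; i++ {
            select {
            case <-after:
                fmt.Println("time.After fired") // printed on the first iteration only
            default:
                fmt.Println("time.After: no signal") // printed on the second iteration
            }
            select {
            case <-ctx.Done():
                fmt.Println("ctx.Done fired") // printed on every iteration
            default:
                fmt.Println("ctx.Done: no signal") // never reached once the context expires
            }
        }
    }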
Also moves shrinkPool and replenishPool into their own methods.
Fixes #1259
Change-Id: Ide2e417ca51ea9bc2416bde5e19d82defe427da6
Reviewed-on: https://code-review.googlesource.com/c/36711
Reviewed-by: kokoro <noreply+kokoro@google.com>
Reviewed-by: Jonathan Amsterdam <jba@google.com>
diff --git a/spanner/session.go b/spanner/session.go
index 784ed76..4c33dad 100644
--- a/spanner/session.go
+++ b/spanner/session.go
@@ -961,83 +961,8 @@
var (
windowSize uint64 = 10
iteration uint64
- timeout <-chan time.Time
)
- // replenishPool is run if numOpened is less than sessionsToKeep, timeouts on sampleInterval.
- replenishPool := func(sessionsToKeep uint64) {
- ctx, _ := context.WithTimeout(context.Background(), hc.sampleInterval)
- for {
- select {
- case <-timeout:
- return
- default:
- }
-
- p := hc.pool
- p.mu.Lock()
- // Take budget before the actual session creation.
- if sessionsToKeep <= p.numOpened {
- p.mu.Unlock()
- break
- }
- p.numOpened++
- trace.RecordStat(ctx, trace.OpenSessionCount, int64(p.numOpened))
- p.createReqs++
- shouldPrepareWrite := p.shouldPrepareWrite()
- p.mu.Unlock()
- var (
- s *session
- err error
- )
- if s, err = p.createSession(ctx); err != nil {
- log.Printf("Failed to create session, error: %v", toSpannerError(err))
- continue
- }
- if shouldPrepareWrite {
- if err = s.prepareForWrite(ctx); err != nil {
- p.recycle(s)
- log.Printf("Failed to prepare session, error: %v", toSpannerError(err))
- continue
- }
- }
- p.recycle(s)
- }
- }
-
- // shrinkPool, scales down the session pool.
- shrinkPool := func(sessionsToKeep uint64) {
- for {
- select {
- case <-timeout:
- return
- default:
- }
-
- p := hc.pool
- p.mu.Lock()
-
- if sessionsToKeep >= p.numOpened {
- p.mu.Unlock()
- break
- }
-
- var s *session
- if p.idleList.Len() > 0 {
- s = p.idleList.Front().Value.(*session)
- } else if p.idleWriteList.Len() > 0 {
- s = p.idleWriteList.Front().Value.(*session)
- }
- p.mu.Unlock()
- if s != nil {
- // destroy session as expire.
- s.destroy(true)
- } else {
- break
- }
- }
- }
-
for {
if hc.isClosing() {
hc.waitWorkers.Done()
@@ -1061,24 +986,95 @@
minUint64(currSessionsOpened, hc.pool.MaxIdle+maxSessionsInUse))
hc.mu.Unlock()
- timeout = time.After(hc.sampleInterval)
+ ctx, cancel := context.WithTimeout(context.Background(), hc.sampleInterval)
+
// Replenish or Shrink pool if needed.
// Note: we don't need to worry about pending create session requests, we only need to sample the current sessions in use.
// the routines will not try to create extra / delete creating sessions.
if sessionsToKeep > currSessionsOpened {
- replenishPool(sessionsToKeep)
+ hc.replenishPool(ctx, sessionsToKeep)
} else {
- shrinkPool(sessionsToKeep)
+ hc.shrinkPool(ctx, sessionsToKeep)
}
select {
- case <-timeout:
+ case <-ctx.Done():
case <-hc.done:
+ cancel()
}
iteration++
}
}
+// replenishPool is run when numOpened is less than sessionsToKeep; it stops once ctx expires (after sampleInterval).
+func (hc *healthChecker) replenishPool(ctx context.Context, sessionsToKeep uint64) {
+ for {
+ if ctx.Err() != nil {
+ return
+ }
+
+ p := hc.pool
+ p.mu.Lock()
+ // Take budget before the actual session creation.
+ if sessionsToKeep <= p.numOpened {
+ p.mu.Unlock()
+ break
+ }
+ p.numOpened++
+ trace.RecordStat(ctx, trace.OpenSessionCount, int64(p.numOpened))
+ p.createReqs++
+ shouldPrepareWrite := p.shouldPrepareWrite()
+ p.mu.Unlock()
+ var (
+ s *session
+ err error
+ )
+ if s, err = p.createSession(ctx); err != nil {
+ log.Printf("Failed to create session, error: %v", toSpannerError(err))
+ continue
+ }
+ if shouldPrepareWrite {
+ if err = s.prepareForWrite(ctx); err != nil {
+ p.recycle(s)
+ log.Printf("Failed to prepare session, error: %v", toSpannerError(err))
+ continue
+ }
+ }
+ p.recycle(s)
+ }
+}
+
+// shrinkPool scales down the session pool by destroying idle sessions until sessionsToKeep is reached or ctx expires.
+func (hc *healthChecker) shrinkPool(ctx context.Context, sessionsToKeep uint64) {
+ for {
+ if ctx.Err() != nil {
+ return
+ }
+
+ p := hc.pool
+ p.mu.Lock()
+
+ if sessionsToKeep >= p.numOpened {
+ p.mu.Unlock()
+ break
+ }
+
+ var s *session
+ if p.idleList.Len() > 0 {
+ s = p.idleList.Front().Value.(*session)
+ } else if p.idleWriteList.Len() > 0 {
+ s = p.idleWriteList.Front().Value.(*session)
+ }
+ p.mu.Unlock()
+ if s != nil {
+ // destroy the session as if it had expired.
+ s.destroy(true)
+ } else {
+ break
+ }
+ }
+}
+
// shouldDropSession returns true if a particular error leads to the removal of a session
func shouldDropSession(err error) bool {
if err == nil {
diff --git a/spanner/session_test.go b/spanner/session_test.go
index 54b906c..be39ea9 100644
--- a/spanner/session_test.go
+++ b/spanner/session_test.go
@@ -980,6 +980,61 @@
})
}
+// Tests that the maintainer creates sessions until MinOpened is reached.
+//
+// Historical context: this test also checks that a low healthCheckSampleInterval
+// does not prevent the maintainer from opening sessions. See: https://github.com/googleapis/google-cloud-go/issues/1259
+func TestMaintainer_CreatesSessions(t *testing.T) {
+ t.Parallel()
+
+ rawServerStub := testutil.NewMockCloudSpannerClient(t)
+ serverClientMock := testutil.FuncMock{MockCloudSpannerClient: rawServerStub}
+ serverClientMock.CreateSessionFn = func(c context.Context, r *sppb.CreateSessionRequest, opts ...grpc.CallOption) (*sppb.Session, error) {
+ time.Sleep(10 * time.Millisecond)
+ return rawServerStub.CreateSession(c, r, opts...)
+ }
+ spc := SessionPoolConfig{
+ MinOpened: 10,
+ MaxIdle: 10,
+ healthCheckSampleInterval: time.Millisecond,
+ getRPCClient: func() (sppb.SpannerClient, error) {
+ return &serverClientMock, nil
+ },
+ }
+ db := "mockdb"
+ sp, err := newSessionPool(db, spc, nil)
+ if err != nil {
+ t.Fatalf("cannot create session pool: %v", err)
+ }
+ client := Client{
+ database: db,
+ idleSessions: sp,
+ }
+ defer func() {
+ client.Close()
+ sp.hc.close()
+ sp.close()
+ }()
+
+ timeoutAmt := 2 * time.Second
+ timeout := time.After(timeoutAmt)
+ var numOpened uint64
+loop:
+ for {
+ select {
+ case <-timeout:
+ t.Fatalf("timed out after %v, got %d session(s), want %d", timeoutAmt, numOpened, spc.MinOpened)
+ default:
+ sp.mu.Lock()
+ numOpened = sp.numOpened
+ sp.mu.Unlock()
+ if numOpened == 10 {
+ break loop
+ }
+ }
+ }
+}
+
func (s1 *session) Equal(s2 *session) bool {
return s1.client == s2.client &&
s1.id == s2.id &&