spanner: Track stacktrace of sessionPool.take()

Sessions are automatically checked out of the session pool by the client
library when one is needed, and returned to the pool when the user closes
the transaction and/or row iterator that was used. If the user however
forgets to close the transaction or row iterator, the session will leak
and the user will eventually get an error while trying to start a transaction.
This error can occur in a part of the application that is completely
unrelated to the part that causes the session leak, which makes these bugs
very hard to debug and track down.
This change allows the user to instruct the session pool to keep track
of the stacktrace of each goroutine that checks out a session from the pool.
The stacktraces of all checked out sessions + the time the session was
taken from the pool is then included in the error that is returned when the
session pool has been exhausted and no more sessions can be returned. This
option can be used to track down the part(s) of the application that are
causing a session leak.
This feature is disabled by default, and must be enabled specifically by
a user to have any effect.

Updates #1616.

Change-Id: I2ba84b65f391a99d0bed364d2a8e94f7467e3704
Reviewed-on: https://code-review.googlesource.com/c/gocloud/+/47150
Reviewed-by: kokoro <noreply+kokoro@google.com>
Reviewed-by: Cody Oss <codyoss@google.com>
Reviewed-by: Tyler Bui-Palsulich <tbp@google.com>
diff --git a/spanner/errors.go b/spanner/errors.go
index af03c32..547a6af 100644
--- a/spanner/errors.go
+++ b/spanner/errors.go
@@ -41,6 +41,9 @@
 	Desc string
 	// trailers are the trailers returned in the response, if any.
 	trailers metadata.MD
+	// additionalInformation optionally contains any additional information
+	// about the error.
+	additionalInformation string
 }
 
 // Error implements error.Error.
@@ -49,7 +52,10 @@
 		return fmt.Sprintf("spanner: OK")
 	}
 	code := ErrCode(e)
-	return fmt.Sprintf("spanner: code = %q, desc = %q", code, e.Desc)
+	if e.additionalInformation == "" {
+		return fmt.Sprintf("spanner: code = %q, desc = %q", code, e.Desc)
+	}
+	return fmt.Sprintf("spanner: code = %q, desc = %q, additional information = %s", code, e.Desc, e.additionalInformation)
 }
 
 // Unwrap returns the wrapped error (if any).
@@ -115,11 +121,11 @@
 	}
 	switch {
 	case err == context.DeadlineExceeded || err == context.Canceled:
-		return &Error{status.FromContextError(err).Code(), status.FromContextError(err).Err(), err.Error(), trailers}
+		return &Error{status.FromContextError(err).Code(), status.FromContextError(err).Err(), err.Error(), trailers, ""}
 	case status.Code(err) == codes.Unknown:
-		return &Error{codes.Unknown, err, err.Error(), trailers}
+		return &Error{codes.Unknown, err, err.Error(), trailers, ""}
 	default:
-		return &Error{status.Convert(err).Code(), err, status.Convert(err).Message(), trailers}
+		return &Error{status.Convert(err).Code(), err, status.Convert(err).Message(), trailers, ""}
 	}
 }
 
diff --git a/spanner/session.go b/spanner/session.go
index 573a008..6545c8a 100644
--- a/spanner/session.go
+++ b/spanner/session.go
@@ -24,6 +24,7 @@
 	"log"
 	"math"
 	"math/rand"
+	"runtime/debug"
 	"strings"
 	"sync"
 	"time"
@@ -44,19 +45,39 @@
 	// session is a pointer to a session object. Transactions never need to
 	// access it directly.
 	session *session
+	// checkoutTime is the time the session was checked out of the pool.
+	checkoutTime time.Time
+	// trackedSessionHandle is the linked list node which links the session to
+	// the list of tracked session handles. trackedSessionHandle is only set if
+	// TrackSessionHandles has been enabled in the session pool configuration.
+	trackedSessionHandle *list.Element
+	// stack is the call stack of the goroutine that checked out the session
+	// from the pool. This can be used to track down session leak problems.
+	stack []byte
 }
 
 // recycle gives the inner session object back to its home session pool. It is
 // safe to call recycle multiple times but only the first one would take effect.
 func (sh *sessionHandle) recycle() {
 	sh.mu.Lock()
-	defer sh.mu.Unlock()
 	if sh.session == nil {
 		// sessionHandle has already been recycled.
+		sh.mu.Unlock()
 		return
 	}
+	p := sh.session.pool
+	tracked := sh.trackedSessionHandle
 	sh.session.recycle()
 	sh.session = nil
+	sh.trackedSessionHandle = nil
+	sh.checkoutTime = time.Time{}
+	sh.stack = nil
+	sh.mu.Unlock()
+	if tracked != nil {
+		p.mu.Lock()
+		p.trackedSessionHandles.Remove(tracked)
+		p.mu.Unlock()
+	}
 }
 
 // getID gets the Cloud Spanner session ID from the internal session object.
@@ -109,8 +130,18 @@
 func (sh *sessionHandle) destroy() {
 	sh.mu.Lock()
 	s := sh.session
+	p := s.pool
+	tracked := sh.trackedSessionHandle
 	sh.session = nil
+	sh.trackedSessionHandle = nil
+	sh.checkoutTime = time.Time{}
+	sh.stack = nil
 	sh.mu.Unlock()
+	if tracked != nil {
+		p.mu.Lock()
+		p.trackedSessionHandles.Remove(tracked)
+		p.mu.Unlock()
+	}
 	if s == nil {
 		// sessionHandle has already been destroyed..
 		return
@@ -376,6 +407,13 @@
 	// Defaults to 5m.
 	HealthCheckInterval time.Duration
 
+	// TrackSessionHandles determines whether the session pool will keep track
+	// of the stacktrace of the goroutines that take sessions from the pool.
+	// This setting can be used to track down session leak problems.
+	//
+	// Defaults to false.
+	TrackSessionHandles bool
+
 	// healthCheckSampleInterval is how often the health checker samples live
 	// session (for use in maintaining session pool size).
 	//
@@ -450,6 +488,10 @@
 	valid bool
 	// sc is used to create the sessions for the pool.
 	sc *sessionClient
+	// trackedSessionHandles contains all sessions handles that have been
+	// checked out of the pool. The list is only filled if TrackSessionHandles
+	// has been enabled.
+	trackedSessionHandles list.List
 	// idleList caches idle session IDs. Session IDs in this list can be
 	// allocated for use.
 	idleList list.List
@@ -621,6 +663,68 @@
 // sessionPool.take().
 var errGetSessionTimeout = spannerErrorf(codes.Canceled, "timeout / context canceled during getting session")
 
+// newSessionHandle creates a new session handle for the given session for this
+// session pool. The session handle will also hold a copy of the current call
+// stack if the session pool has been configured to track the call stacks of
+// sessions being checked out of the pool.
+func (p *sessionPool) newSessionHandle(s *session) (sh *sessionHandle) {
+	sh = &sessionHandle{session: s, checkoutTime: time.Now()}
+	if p.TrackSessionHandles {
+		p.mu.Lock()
+		sh.trackedSessionHandle = p.trackedSessionHandles.PushBack(sh)
+		p.mu.Unlock()
+		sh.stack = debug.Stack()
+	}
+	return sh
+}
+
+// errGetSessionTimeout returns error for context timeout during
+// sessionPool.take().
+func (p *sessionPool) errGetSessionTimeout() error {
+	if p.TrackSessionHandles {
+		return p.errGetSessionTimeoutWithTrackedSessionHandles()
+	}
+	return p.errGetBasicSessionTimeout()
+}
+
+// errGetBasicSessionTimeout returns error for context timeout during
+// sessionPool.take() without any tracked sessionHandles.
+func (p *sessionPool) errGetBasicSessionTimeout() error {
+	return spannerErrorf(codes.Canceled, "timeout / context canceled during getting session.\n"+
+		"Enable SessionPoolConfig.TrackSessionHandles if you suspect a session leak to get more information about the checked out sessions.")
+}
+
+// errGetSessionTimeoutWithTrackedSessionHandles returns error for context
+// timeout during sessionPool.take() including a stacktrace of each checked
+// out session handle.
+func (p *sessionPool) errGetSessionTimeoutWithTrackedSessionHandles() error {
+	err := spannerErrorf(codes.Canceled, "timeout / context canceled during getting session.")
+	err.(*Error).additionalInformation = p.getTrackedSessionHandleStacksLocked()
+	return err
+}
+
+// getTrackedSessionHandleStacksLocked returns a string containing the
+// stacktrace of all currently checked out sessions of the pool. This method
+// acquires p.mu itself, so the caller must NOT hold p.mu when calling it.
+func (p *sessionPool) getTrackedSessionHandleStacksLocked() string {
+	p.mu.Lock()
+	defer p.mu.Unlock()
+	stackTraces := ""
+	i := 1
+	element := p.trackedSessionHandles.Front()
+	for element != nil {
+		sh := element.Value.(*sessionHandle)
+		sh.mu.Lock()
+		if sh.stack != nil {
+			stackTraces = fmt.Sprintf("%s\n\nSession %d checked out of pool at %s by goroutine:\n%s", stackTraces, i, sh.checkoutTime.Format(time.RFC3339), sh.stack)
+		}
+		sh.mu.Unlock()
+		element = element.Next()
+		i++
+	}
+	return stackTraces
+}
+
 // shouldPrepareWriteLocked returns true if we should prepare more sessions for write.
 func (p *sessionPool) shouldPrepareWriteLocked() bool {
 	return !p.disableBackgroundPrepareSessions && float64(p.numOpened)*p.WriteSessions > float64(p.idleWriteList.Len()+int(p.prepareReqs))
@@ -710,7 +814,7 @@
 			if !p.isHealthy(s) {
 				continue
 			}
-			return &sessionHandle{session: s}, nil
+			return p.newSessionHandle(s), nil
 		}
 
 		// Idle list is empty, block if session pool has reached max session
@@ -722,7 +826,7 @@
 			select {
 			case <-ctx.Done():
 				trace.TracePrintf(ctx, nil, "Context done waiting for session")
-				return nil, errGetSessionTimeout
+				return nil, p.errGetSessionTimeout()
 			case <-mayGetSession:
 			}
 			continue
@@ -743,7 +847,7 @@
 		}
 		trace.TracePrintf(ctx, map[string]interface{}{"sessionID": s.getID()},
 			"Created session")
-		return &sessionHandle{session: s}, nil
+		return p.newSessionHandle(s), nil
 	}
 }
 
@@ -795,7 +899,7 @@
 				select {
 				case <-ctx.Done():
 					trace.TracePrintf(ctx, nil, "Context done waiting for session")
-					return nil, errGetSessionTimeout
+					return nil, p.errGetSessionTimeout()
 				case <-mayGetSession:
 				}
 				continue
@@ -825,7 +929,7 @@
 				return nil, toSpannerError(err)
 			}
 		}
-		return &sessionHandle{session: s}, nil
+		return p.newSessionHandle(s), nil
 	}
 }
 
diff --git a/spanner/session_test.go b/spanner/session_test.go
index 0016709..b137784 100644
--- a/spanner/session_test.go
+++ b/spanner/session_test.go
@@ -25,10 +25,12 @@
 	"log"
 	"math/rand"
 	"os"
+	"strings"
 	"testing"
 	"time"
 
 	. "cloud.google.com/go/spanner/internal/testutil"
+	"google.golang.org/api/iterator"
 	sppb "google.golang.org/genproto/googleapis/spanner/v1"
 	"google.golang.org/grpc/codes"
 	"google.golang.org/grpc/status"
@@ -464,6 +466,86 @@
 	}
 }
 
+// TestSessionLeak tests leaking a session and getting the stack of the
+// goroutine that leaked it.
+func TestSessionLeak(t *testing.T) {
+	t.Parallel()
+	ctx := context.Background()
+
+	_, client, teardown := setupMockedTestServerWithConfig(t, ClientConfig{
+		SessionPoolConfig: SessionPoolConfig{
+			TrackSessionHandles: true,
+			MinOpened:           0,
+			MaxOpened:           1,
+		},
+	})
+	defer teardown()
+
+	// Execute a query without calling rowIterator.Stop. This will cause the
+	// session not to be returned to the pool.
+	single := client.Single()
+	iter := single.Query(ctx, NewStatement(SelectFooFromBar))
+	for {
+		_, err := iter.Next()
+		if err == iterator.Done {
+			break
+		}
+		if err != nil {
+			t.Fatalf("Got unexpected error while iterating results: %v\n", err)
+		}
+	}
+	// The session should not have been returned to the pool.
+	if g, w := client.idleSessions.idleList.Len(), 0; g != w {
+		t.Fatalf("Idle sessions count mismatch\nGot: %d\nWant: %d\n", g, w)
+	}
+	// The checked out session should contain a stack trace.
+	if single.sh.stack == nil {
+		t.Fatalf("Missing stacktrace from session handle")
+	}
+	stack := fmt.Sprintf("%s", single.sh.stack)
+	testMethod := "TestSessionLeak"
+	if !strings.Contains(stack, testMethod) {
+		t.Fatalf("Stacktrace does not contain '%s'\nGot: %s", testMethod, stack)
+	}
+	// Return the session to the pool.
+	iter.Stop()
+	// The stack should now have been removed from the session handle.
+	if single.sh.stack != nil {
+		t.Fatalf("Got unexpected stacktrace in session handle: %s", single.sh.stack)
+	}
+
+	// Do another query and hold on to the session.
+	single = client.Single()
+	iter = single.Query(ctx, NewStatement(SelectFooFromBar))
+	for {
+		_, err := iter.Next()
+		if err == iterator.Done {
+			break
+		}
+		if err != nil {
+			t.Fatalf("Got unexpected error while iterating results: %v\n", err)
+		}
+	}
+	// Try to do another query. This will fail as MaxOpened=1.
+	ctxWithTimeout, cancel := context.WithTimeout(ctx, time.Millisecond*10)
+	defer cancel()
+	single2 := client.Single()
+	iter2 := single2.Query(ctxWithTimeout, NewStatement(SelectFooFromBar))
+	_, gotErr := iter2.Next()
+	wantErr := client.idleSessions.errGetSessionTimeoutWithTrackedSessionHandles()
+	// The error should contain the stacktraces of all the checked out
+	// sessions.
+	if !testEqual(gotErr, wantErr) {
+		t.Fatalf("Error mismatch on iterating result set.\nGot: %v\nWant: %v\n", gotErr, wantErr)
+	}
+	if !strings.Contains(gotErr.Error(), testMethod) {
+		t.Fatalf("Error does not contain '%s'\nGot: %s", testMethod, gotErr.Error())
+	}
+	// Close iterators to check sessions back into the pool before closing.
+	iter2.Stop()
+	iter.Stop()
+}
+
 // TestMaxOpenedSessions tests max open sessions constraint.
 func TestMaxOpenedSessions(t *testing.T) {
 	t.Parallel()
@@ -486,7 +568,7 @@
 	ctx2, cancel := context.WithTimeout(ctx, 10*time.Millisecond)
 	defer cancel()
 	_, gotErr := sp.take(ctx2)
-	if wantErr := errGetSessionTimeout; gotErr != wantErr {
+	if wantErr := sp.errGetBasicSessionTimeout(); !testEqual(gotErr, wantErr) {
 		t.Fatalf("the second session retrival returns error %v, want %v", gotErr, wantErr)
 	}
 	doneWaiting := make(chan struct{})
@@ -619,7 +701,7 @@
 	_, gotErr := sp.take(ctx2)
 
 	// Since MaxBurst == 1, the second session request should block.
-	if wantErr := errGetSessionTimeout; gotErr != wantErr {
+	if wantErr := sp.errGetBasicSessionTimeout(); !testEqual(gotErr, wantErr) {
 		t.Fatalf("session retrival returns error %v, want %v", gotErr, wantErr)
 	}