spanner/spannertest: enforce commit timestamp selection, transaction exclusion

The requirements for commit timestamps are such that this simplified
simulation needs to enforce only one read-write transaction operating at
a time.

Updates #1774.

Change-Id: I8c1d0349fdc7d08f86f0eafceb703eb82f7b5b08
Reviewed-on: https://code-review.googlesource.com/c/gocloud/+/51832
Reviewed-by: Knut Olav Løite <koloite@gmail.com>
Reviewed-by: kokoro <noreply+kokoro@google.com>
diff --git a/spanner/spannertest/README.md b/spanner/spannertest/README.md
index 170f764..9ccb19f 100644
--- a/spanner/spannertest/README.md
+++ b/spanner/spannertest/README.md
@@ -22,7 +22,6 @@
 - INSERT/UPDATE DML statements
 - case insensitivity
 - alternate literal types (esp. strings)
-- allow_commit_timestamp
 - STRUCT types
 - expression functions
 - expression type casting, coercion
diff --git a/spanner/spannertest/db.go b/spanner/spannertest/db.go
index dd28c48..579fe13 100644
--- a/spanner/spannertest/db.go
+++ b/spanner/spannertest/db.go
@@ -44,6 +44,8 @@
 	lastTS  time.Time // last commit timestamp
 	tables  map[string]*table
 	indexes map[string]struct{} // only record their existence
+
+	rwMu sync.Mutex // held by read-write transactions
 }
 
 type table struct {
@@ -71,34 +73,75 @@
 var commitTimestampSentinel = &struct{}{}
 
 // transaction records information about a running transaction.
+// This is not safe for concurrent use.
 type transaction struct {
-	commitTimestamp time.Time
+	// readOnly is whether this transaction was constructed
+	// for read-only use, and should yield errors if used
+	// to perform a mutation.
+	readOnly bool
+
+	d               *database
+	commitTimestamp time.Time // not set if readOnly
+	unlock          func()    // may be nil
 }
 
-func (d *database) startTransaction() *transaction {
+func (d *database) NewReadOnlyTransaction() *transaction {
+	return &transaction{
+		readOnly: true,
+	}
+}
+
+func (d *database) NewTransaction() *transaction {
+	return &transaction{
+		d: d,
+	}
+}
+
+// Start starts the transaction and commits to a specific commit timestamp.
+// This also locks out any other read-write transaction on this database
+// until Commit/Rollback are called.
+func (tx *transaction) Start() {
 	// Commit timestamps are only guaranteed to be unique
 	// when transactions write to overlapping sets of fields.
 	// This simulated database exceeds that guarantee.
-	d.mu.Lock()
-	defer d.mu.Unlock()
 
+	// Grab rwMu for the duration of this transaction.
+	// Take it before d.mu so we don't hold that lock
+	// while waiting for d.rwMu, which is held for longer.
+	tx.d.rwMu.Lock()
+
+	tx.d.mu.Lock()
 	const tsRes = 1 * time.Microsecond
 	now := time.Now().UTC().Truncate(tsRes)
-	if now.Equal(d.lastTS) {
-		now = now.Add(tsRes)
+	if !now.After(tx.d.lastTS) {
+		now = tx.d.lastTS.Add(tsRes)
 	}
-	d.lastTS = now
+	tx.d.lastTS = now
+	tx.d.mu.Unlock()
 
-	return &transaction{
-		commitTimestamp: now,
+	tx.commitTimestamp = now
+	tx.unlock = tx.d.rwMu.Unlock
+}
+
+func (tx *transaction) checkMutable() error {
+	if tx.readOnly {
+		// TODO: is this the right status?
+		return status.Errorf(codes.InvalidArgument, "transaction is read-only")
 	}
+	return nil
 }
 
 func (tx *transaction) Commit() (time.Time, error) {
+	if tx.unlock != nil {
+		tx.unlock()
+	}
 	return tx.commitTimestamp, nil
 }
 
 func (tx *transaction) Rollback() {
+	if tx.unlock != nil {
+		tx.unlock()
+	}
 	// TODO: actually rollback
 }
 
@@ -243,6 +286,10 @@
 
 // writeValues executes a write option (Insert, Update, etc.).
 func (d *database) writeValues(tx *transaction, tbl string, cols []string, values []*structpb.ListValue, f func(t *table, colIndexes []int, r row) error) error {
+	if err := tx.checkMutable(); err != nil {
+		return err
+	}
+
 	t, err := d.table(tbl)
 	if err != nil {
 		return err
@@ -350,6 +397,10 @@
 // TODO: Replace
 
 func (d *database) Delete(tx *transaction, table string, keys []*structpb.ListValue, keyRanges keyRangeList, all bool) error {
+	if err := tx.checkMutable(); err != nil {
+		return err
+	}
+
 	t, err := d.table(table)
 	if err != nil {
 		return err
diff --git a/spanner/spannertest/db_test.go b/spanner/spannertest/db_test.go
index 8babaca..7e7d916 100644
--- a/spanner/spannertest/db_test.go
+++ b/spanner/spannertest/db_test.go
@@ -83,7 +83,8 @@
 	}
 
 	// Insert a subset of columns.
-	tx := db.startTransaction()
+	tx := db.NewTransaction()
+	tx.Start()
 	err := db.Insert(tx, "Staff", []string{"ID", "Name", "Tenure", "Height"}, []*structpb.ListValue{
 		// int64 arrives as a decimal string.
 		listV(stringV("1"), stringV("Jack"), stringV("10"), floatV(1.85)),
@@ -202,7 +203,8 @@
 	if st.Code() != codes.OK {
 		t.Fatalf("Adding column: %v", st.Err())
 	}
-	tx = db.startTransaction()
+	tx = db.NewTransaction()
+	tx.Start()
 	err = db.Update(tx, "Staff", []string{"Name", "ID", "FirstSeen", "To"}, []*structpb.ListValue{
 		listV(stringV("Jack"), stringV("1"), stringV("1994-10-28"), nullV()),
 		listV(stringV("Daniel"), stringV("2"), stringV("1994-10-28"), nullV()),
@@ -217,7 +219,8 @@
 
 	// Add some more data, then delete it with a KeyRange.
 	// The queries below ensure that this was all deleted.
-	tx = db.startTransaction()
+	tx = db.NewTransaction()
+	tx.Start()
 	err = db.Insert(tx, "Staff", []string{"Name", "ID"}, []*structpb.ListValue{
 		listV(stringV("01"), stringV("1")),
 		listV(stringV("03"), stringV("3")),
@@ -283,7 +286,8 @@
 	if st.Code() != codes.OK {
 		t.Fatalf("Adding column: %v", st.Err())
 	}
-	tx = db.startTransaction()
+	tx = db.NewTransaction()
+	tx.Start()
 	err = db.Update(tx, "Staff", []string{"Name", "ID", "RawBytes"}, []*structpb.ListValue{
 		// bytes {0x01 0x00 0x01} encode as base-64 AQAB.
 		listV(stringV("Jack"), stringV("1"), stringV("AQAB")),
@@ -461,7 +465,8 @@
 		t.Fatalf("Creating table: %v", st.Err())
 	}
 
-	tx := db.startTransaction()
+	tx := db.NewTransaction()
+	tx.Start()
 	err := db.Insert(tx, "Timeseries", []string{"Name", "Observed", "Value"}, []*structpb.ListValue{
 		listV(stringV("box"), stringV("1"), floatV(1.1)),
 		listV(stringV("cupcake"), stringV("1"), floatV(6)),
diff --git a/spanner/spannertest/inmem.go b/spanner/spannertest/inmem.go
index a3cffb0..2566573 100644
--- a/spanner/spannertest/inmem.go
+++ b/spanner/spannertest/inmem.go
@@ -386,50 +386,52 @@
 
 // readTx returns a transaction for the given session and transaction selector.
 // It is used by read/query operations (ExecuteStreamingSql, StreamingRead).
-func (s *server) readTx(ctx context.Context, session string, tsel *spannerpb.TransactionSelector) (tx *transaction, err error) {
+func (s *server) readTx(ctx context.Context, session string, tsel *spannerpb.TransactionSelector) (tx *transaction, cleanup func(), err error) {
 	s.mu.Lock()
 	sess, ok := s.sessions[session]
 	s.mu.Unlock()
 	if !ok {
 		// TODO: what error does the real Spanner return?
-		return nil, status.Errorf(codes.NotFound, "unknown session %q", session)
+		return nil, nil, status.Errorf(codes.NotFound, "unknown session %q", session)
 	}
 
 	sess.mu.Lock()
 	sess.lastUse = time.Now()
 	sess.mu.Unlock()
 
-	singleUse := func() (*transaction, error) {
-		tx := s.db.startTransaction()
-		return tx, nil
-	}
-	singleUseReadOnly := func() (*transaction, error) {
-		// TODO: figure out a way to make this read-only.
-		return singleUse()
+	// Only give a read-only transaction regardless of whether the selector
+	// is requesting a read-write or read-only one, since this is in readTx
+	// and so shouldn't be mutating anyway.
+	singleUse := func() (*transaction, func(), error) {
+		tx := s.db.NewReadOnlyTransaction()
+		return tx, tx.Rollback, nil
 	}
 
 	if tsel.GetSelector() == nil {
-		return singleUseReadOnly()
+		return singleUse()
 	}
 
 	switch sel := tsel.Selector.(type) {
 	default:
-		return nil, fmt.Errorf("TransactionSelector type %T not supported", sel)
+		return nil, nil, fmt.Errorf("TransactionSelector type %T not supported", sel)
 	case *spannerpb.TransactionSelector_SingleUse:
 		// Ignore options (e.g. timestamps).
 		switch mode := sel.SingleUse.Mode.(type) {
 		case *spannerpb.TransactionOptions_ReadOnly_:
-			return singleUseReadOnly()
+			return singleUse()
 		case *spannerpb.TransactionOptions_ReadWrite_:
 			return singleUse()
 		default:
-			return nil, fmt.Errorf("single use transaction in mode %T not supported", mode)
+			return nil, nil, fmt.Errorf("single use transaction in mode %T not supported", mode)
 		}
 	case *spannerpb.TransactionSelector_Id:
-		id := sel.Id // []byte
-		_ = id       // TODO: lookup an existing transaction by ID.
-		tx := s.db.startTransaction()
-		return tx, nil
+		sess.mu.Lock()
+		tx, ok := sess.transactions[string(sel.Id)]
+		sess.mu.Unlock()
+		if !ok {
+			return nil, nil, fmt.Errorf("no transaction with id %q", sel.Id)
+		}
+		return tx, func() {}, nil
 	}
 }
 
@@ -470,11 +472,11 @@
 }
 
 func (s *server) ExecuteStreamingSql(req *spannerpb.ExecuteSqlRequest, stream spannerpb.Spanner_ExecuteStreamingSqlServer) error {
-	tx, err := s.readTx(stream.Context(), req.Session, req.Transaction)
+	tx, cleanup, err := s.readTx(stream.Context(), req.Session, req.Transaction)
 	if err != nil {
 		return err
 	}
-	defer tx.Rollback()
+	defer cleanup()
 
 	q, err := spansql.ParseQuery(req.Sql)
 	if err != nil {
@@ -502,11 +504,11 @@
 // TODO: Read
 
 func (s *server) StreamingRead(req *spannerpb.ReadRequest, stream spannerpb.Spanner_StreamingReadServer) error {
-	tx, err := s.readTx(stream.Context(), req.Session, req.Transaction)
+	tx, cleanup, err := s.readTx(stream.Context(), req.Session, req.Transaction)
 	if err != nil {
 		return err
 	}
-	defer tx.Rollback()
+	defer cleanup()
 
 	// Bail out if various advanced features are being used.
 	if req.Index != "" {
@@ -598,7 +600,7 @@
 	}
 
 	id := genRandomTransaction()
-	tx := s.db.startTransaction()
+	tx := s.db.NewTransaction()
 
 	sess.mu.Lock()
 	sess.lastUse = time.Now()
@@ -626,6 +628,7 @@
 			tx.Rollback()
 		}
 	}()
+	tx.Start()
 
 	for _, m := range req.Mutations {
 		switch op := m.Operation.(type) {