spanner/spannertest: fix data race in query execution

Whenever a table is used as the data source for a query (most SELECT
statements), the table needs to be completely read before evalSelect
returns, since that is what holds the table lock. The prior code
erroneously only did that when the outermost rowIter was a tableIter,
when that is in fact never the case. The correct approach is to flatten
the output rowIter whenever a tableIter is used as a data source.

Change-Id: I21e26ad06e7c2bb295c583d2a34178122cd2b5c1
Reviewed-on: https://code-review.googlesource.com/c/gocloud/+/52995
Reviewed-by: kokoro <noreply+kokoro@google.com>
Reviewed-by: Knut Olav Løite <koloite@gmail.com>
diff --git a/spanner/spannertest/db.go b/spanner/spannertest/db.go
index c3ac1cf..f84ebd7 100644
--- a/spanner/spannertest/db.go
+++ b/spanner/spannertest/db.go
@@ -171,6 +171,15 @@
 	return v
 }
 
+// copyData returns a copy of the row.
+func (r row) copyAllData() row {
+	dst := make(row, 0, len(r))
+	for i := range r {
+		dst = append(dst, r.copyDataElem(i))
+	}
+	return dst
+}
+
 // copyData returns a copy of a subset of a row.
 func (r row) copyData(indexes []int) row {
 	if len(indexes) == 0 {
diff --git a/spanner/spannertest/db_query.go b/spanner/spannertest/db_query.go
index 86ddf5b..698cdc6 100644
--- a/spanner/spannertest/db_query.go
+++ b/spanner/spannertest/db_query.go
@@ -129,7 +129,7 @@
 		} else if err != nil {
 			return nil, err
 		}
-		raw.rows = append(raw.rows, row)
+		raw.rows = append(raw.rows, row.copyAllData())
 	}
 	return raw, nil
 }
@@ -347,16 +347,15 @@
 		defer t.mu.Unlock()
 		ri = &tableIter{t: t}
 		ec.cols = t.cols
-	}
-	defer func() {
-		// If we're about to return a tableIter, convert it to a rawIter
+
+		// On the way out, convert the result to a rawIter
 		// so that the table may be safely unlocked.
-		if evalErr == nil {
-			if ti, ok := ri.(*tableIter); ok {
-				ri, evalErr = toRawIter(ti)
+		defer func() {
+			if evalErr == nil {
+				ri, evalErr = toRawIter(ri)
 			}
-		}
-	}()
+		}()
+	}
 
 	// Apply WHERE.
 	if sel.Where != nil {
diff --git a/spanner/spannertest/db_test.go b/spanner/spannertest/db_test.go
index ac46aa5..2c28197 100644
--- a/spanner/spannertest/db_test.go
+++ b/spanner/spannertest/db_test.go
@@ -19,6 +19,7 @@
 import (
 	"io"
 	"reflect"
+	"sync"
 	"testing"
 
 	"google.golang.org/grpc/codes"
@@ -736,6 +737,79 @@
 	}
 }
 
+func TestConcurrentReadInsert(t *testing.T) {
+	// Check that data is safely copied during a query.
+	tbl := &spansql.CreateTable{
+		Name: "Tablino",
+		Columns: []spansql.ColumnDef{
+			{Name: "A", Type: spansql.Type{Base: spansql.Int64}},
+		},
+		PrimaryKey: []spansql.KeyPart{{Column: "A"}},
+	}
+
+	var db database
+	if st := db.ApplyDDL(tbl); st.Code() != codes.OK {
+		t.Fatalf("Creating table: %v", st.Err())
+	}
+
+	// Insert some initial data.
+	tx := db.NewTransaction()
+	tx.Start()
+	err := db.Insert(tx, "Tablino", []string{"A"}, []*structpb.ListValue{
+		listV(stringV("1")),
+		listV(stringV("2")),
+		listV(stringV("4")),
+	})
+	if err != nil {
+		t.Fatalf("Inserting data: %v", err)
+	}
+	if _, err := tx.Commit(); err != nil {
+		t.Fatalf("Committing changes: %v", err)
+	}
+
+	// Now insert "3", and query concurrently.
+	q, err := spansql.ParseQuery(`SELECT * FROM Tablino WHERE A > 2`)
+	if err != nil {
+		t.Fatalf("ParseQuery: %v", err)
+	}
+	var out [][]interface{}
+
+	var wg sync.WaitGroup
+	wg.Add(2)
+	go func() {
+		defer wg.Done()
+
+		ri, err := db.Query(q, nil)
+		if err != nil {
+			t.Errorf("Query: %v", err)
+			return
+		}
+		out = slurp(t, ri)
+	}()
+	go func() {
+		defer wg.Done()
+
+		tx := db.NewTransaction()
+		tx.Start()
+		err := db.Insert(tx, "Tablino", []string{"A"}, []*structpb.ListValue{
+			listV(stringV("3")),
+		})
+		if err != nil {
+			t.Errorf("Inserting data: %v", err)
+			return
+		}
+		if _, err := tx.Commit(); err != nil {
+			t.Errorf("Committing changes: %v", err)
+		}
+	}()
+	wg.Wait()
+
+	// We should get either 1 or 2 rows (value 4 should be included, and value 3 might).
+	if n := len(out); n != 1 && n != 2 {
+		t.Fatalf("Concurrent read returned %d rows, want 1 or 2", n)
+	}
+}
+
 func slurp(t *testing.T, ri rowIter) (all [][]interface{}) {
 	t.Helper()
 	for {