spanner/spannertest: switch to storing rows in primary key order

This is faster for non-trivial tables for many operations (because we
can use binary searches), and is a better match for production Spanner.

Change-Id: I2a3271a8bca1e466a923a1367d2220615a6e1328
Reviewed-on: https://code-review.googlesource.com/c/gocloud/+/44351
Reviewed-by: kokoro <noreply+kokoro@google.com>
Reviewed-by: Knut Olav Løite <koloite@gmail.com>
diff --git a/spanner/spannertest/db.go b/spanner/spannertest/db.go
index 34619e4..255a253 100644
--- a/spanner/spannertest/db.go
+++ b/spanner/spannertest/db.go
@@ -22,6 +22,7 @@
 // TODO: missing transactionality in a serious way!
 
 import (
+	"bytes"
 	"fmt"
 	"sort"
 	"strconv"
@@ -51,6 +52,7 @@
 	colIndex map[string]int // col name to index
 	pkCols   int            // number of primary key columns (may be 0)
 
+	// Rows are stored in primary key order.
 	rows []row
 }
 
@@ -251,12 +253,12 @@
 func (d *database) Insert(tbl string, cols []string, values []*structpb.ListValue) error {
 	return d.writeValues(tbl, cols, values, func(t *table, colIndexes []int, r row) error {
 		pk := r[:t.pkCols]
-		if t.rowForPK(pk) >= 0 {
+		rowNum, found := t.rowForPK(pk)
+		if found {
 			// TODO: how do we return `ALREADY_EXISTS`?
 			return status.Errorf(codes.Unknown, "row already in table")
 		}
-
-		t.rows = append(t.rows, r)
+		t.insertRow(rowNum, r)
 		return nil
 	})
 }
@@ -267,8 +269,8 @@
 			return status.Errorf(codes.InvalidArgument, "cannot update table %s with no columns in primary key", tbl)
 		}
 		pk := r[:t.pkCols]
-		rowNum := t.rowForPK(pk)
-		if rowNum < 0 {
+		rowNum, found := t.rowForPK(pk)
+		if !found {
 			// TODO: is this the right way to return `NOT_FOUND`?
 			return status.Errorf(codes.NotFound, "row not in table")
 		}
@@ -283,10 +285,10 @@
 func (d *database) InsertOrUpdate(tbl string, cols []string, values []*structpb.ListValue) error {
 	return d.writeValues(tbl, cols, values, func(t *table, colIndexes []int, r row) error {
 		pk := r[:t.pkCols]
-		rowNum := t.rowForPK(pk)
-		if rowNum < 0 {
+		rowNum, found := t.rowForPK(pk)
+		if !found {
 			// New row; do an insert.
-			t.rows = append(t.rows, r)
+			t.insertRow(rowNum, r)
 		} else {
 			// Existing row; do an update.
 			for _, i := range colIndexes {
@@ -319,8 +321,8 @@
 			return err
 		}
 		// Not an error if the key does not exist.
-		rowNum := t.rowForPK(pk)
-		if rowNum >= 0 {
+		rowNum, found := t.rowForPK(pk)
+		if found {
 			copy(t.rows[rowNum:], t.rows[rowNum+1:])
 			t.rows = t.rows[:len(t.rows)-1]
 		}
@@ -335,16 +337,10 @@
 		if err != nil {
 			return err
 		}
-		for rowNum := 0; rowNum < len(t.rows); {
-			rowPK := t.rows[rowNum][:t.pkCols]
-			if !r.includePK(rowPK) {
-				rowNum++
-				continue
-			}
-
-			// Row is in range.
-			copy(t.rows[rowNum:], t.rows[rowNum+1:])
-			t.rows = t.rows[:len(t.rows)-1]
+		startRow, endRow := t.findRange(r)
+		if n := endRow - startRow; n > 0 {
+			copy(t.rows[startRow:], t.rows[endRow:])
+			t.rows = t.rows[:len(t.rows)-n]
 		}
 	}
 
@@ -414,8 +410,8 @@
 				return err
 			}
 			// Not an error if the key does not exist.
-			rowNum := t.rowForPK(pk)
-			if rowNum < 0 {
+			rowNum, found := t.rowForPK(pk)
+			if !found {
 				continue
 			}
 			ri.add(t.rows[rowNum], colIndexes)
@@ -512,6 +508,41 @@
 	return nil
 }
 
+func (t *table) insertRow(rowNum int, r row) {
+	t.rows = append(t.rows, nil)
+	copy(t.rows[rowNum+1:], t.rows[rowNum:])
+	t.rows[rowNum] = r
+}
+
+// findRange finds the rows included in the key range,
+// reporting it as a half-open interval.
+// r.startKey and r.endKey should be populated.
+func (t *table) findRange(r *keyRange) (int, int) {
+	// TODO: This is incorrect for primary keys with descending order.
+	// It might be sufficient for the caller to switch start/end in that case.
+
+	// startRow is the first row matching the range.
+	startRow := sort.Search(len(t.rows), func(i int) bool {
+		return rowCmp(r.startKey, t.rows[i][:t.pkCols]) <= 0
+	})
+	if startRow == len(t.rows) {
+		return startRow, startRow
+	}
+	if !r.startClosed && rowCmp(r.startKey, t.rows[startRow][:t.pkCols]) == 0 {
+		startRow++
+	}
+
+	// endRow is one more than the last row matching the range.
+	endRow := sort.Search(len(t.rows), func(i int) bool {
+		return rowCmp(r.endKey, t.rows[i][:t.pkCols]) < 0
+	})
+	if !r.endClosed && rowCmp(r.endKey, t.rows[endRow-1][:t.pkCols]) == 0 {
+		endRow--
+	}
+
+	return startRow, endRow
+}
+
 // colIndexes returns the indexes for the named columns.
 func (t *table) colIndexes(cols []string) ([]int, error) {
 	var is []int
@@ -551,18 +582,20 @@
 	return pk, nil
 }
 
-// rowForPK returns the index of t.rows that holds the row for the given primary key.
-// It returns -1 if it isn't found, including when the table's primary key has no columns.
-func (t *table) rowForPK(pk []interface{}) int {
+// rowForPK returns the index of t.rows that holds the row for the given primary key, and true.
+// If the given primary key isn't found, it returns the row that should hold it, and false.
+func (t *table) rowForPK(pk []interface{}) (row int, found bool) {
 	if len(pk) != t.pkCols {
 		panic(fmt.Sprintf("primary key length mismatch: got %d values, table has %d", len(pk), t.pkCols))
 	}
-	for i, row := range t.rows {
-		if rowCmp(pk, row[:t.pkCols]) == 0 {
-			return i
-		}
+
+	i := sort.Search(len(t.rows), func(i int) bool {
+		return rowCmp(pk, t.rows[i][:t.pkCols]) <= 0
+	})
+	if i == len(t.rows) {
+		return i, false
 	}
-	return -1
+	return i, rowCmp(pk, t.rows[i][:t.pkCols]) == 0
 }
 
 // rowCmp compares two rows, returning -1/0/+1.
@@ -649,24 +682,20 @@
 	startKey, endKey []interface{}
 }
 
-type keyRangeList []*keyRange
-
-func (kr *keyRange) includePK(pk []interface{}) bool {
-	// rowCmp permits its first argument to be a prefix,
-	// so the calls to it below use kr.fooKey as the first arg.
-
-	// TODO: This is incorrect for primary keys with descending order.
-	// It might be sufficient for the caller to switch start/end in that case.
-
-	cmp := rowCmp(kr.startKey, pk)
-	if cmp > 0 || (cmp == 0 && !kr.startClosed) {
-		// Row is before range.
-		return false
+func (r *keyRange) String() string {
+	var sb bytes.Buffer // TODO: Switch to strings.Builder when we drop support for Go 1.9.
+	if r.startClosed {
+		sb.WriteString("[")
+	} else {
+		sb.WriteString("(")
 	}
-	cmp = rowCmp(kr.endKey, pk)
-	if cmp < 0 || (cmp == 0 && !kr.endClosed) {
-		// Row is after range.
-		return false
+	fmt.Fprintf(&sb, "%v,%v", r.startKey, r.endKey)
+	if r.endClosed {
+		sb.WriteString("]")
+	} else {
+		sb.WriteString(")")
 	}
-	return true
+	return sb.String()
 }
+
+type keyRangeList []*keyRange
diff --git a/spanner/spannertest/db_test.go b/spanner/spannertest/db_test.go
index 7a7203a..b8296a5 100644
--- a/spanner/spannertest/db_test.go
+++ b/spanner/spannertest/db_test.go
@@ -148,10 +148,11 @@
 	}
 	all = slurp(ri)
 	wantAll = [][]interface{}{
-		{int64(10), "Jack", 1.85},
+		// Primary key is (Name, ID), so results should come back sorted by Name then ID.
 		{int64(11), "Daniel", 1.83},
+		{int64(6), "George", 1.73},
+		{int64(10), "Jack", 1.85},
 		{int64(9), "Sam", 1.75},
-		{int64(8), "Teal'c", 1.91},
 	}
 	if !reflect.DeepEqual(all, wantAll) {
 		t.Errorf("ReadAll data wrong.\n got %v\nwant %v", all, wantAll)
@@ -336,7 +337,7 @@
 	}
 }
 
-func TestKeyRangeInclude(t *testing.T) {
+func TestKeyRange(t *testing.T) {
 	r := func(x ...interface{}) []interface{} { return x }
 	closedClosed := func(start, end []interface{}) *keyRange {
 		return &keyRange{
@@ -353,6 +354,12 @@
 			startClosed: true,
 		}
 	}
+	openOpen := func(start, end []interface{}) *keyRange {
+		return &keyRange{
+			startKey: start,
+			endKey:   end,
+		}
+	}
 	tests := []struct {
 		kr      *keyRange
 		include [][]interface{}
@@ -400,6 +407,20 @@
 			},
 			exclude: [][]interface{}{
 				r("Alice", "1999-11-07"),
+				r("Bob", "2000-01-01"),
+				r("Bob", "2004-07-07"),
+				r("Charlie", "1999-11-07"),
+			},
+		},
+		{
+			kr: openOpen(r("Bob", "1999-11-06"), r("Bob", "2000-01-01")),
+			include: [][]interface{}{
+				r("Bob", "1999-11-07"),
+			},
+			exclude: [][]interface{}{
+				r("Alice", "1999-11-07"),
+				r("Bob", "1999-11-06"),
+				r("Bob", "2000-01-01"),
 				r("Bob", "2004-07-07"),
 				r("Charlie", "1999-11-07"),
 			},
@@ -426,14 +447,26 @@
 		},
 	}
 	for _, test := range tests {
+		tbl := &table{
+			pkCols: 2,
+		}
+		for _, pk := range append(test.include, test.exclude...) {
+			rowNum, _ := tbl.rowForPK(pk)
+			tbl.insertRow(rowNum, pk)
+		}
+		start, end := tbl.findRange(test.kr)
+		has := func(pk []interface{}) bool {
+			n, _ := tbl.rowForPK(pk)
+			return start <= n && n < end
+		}
 		for _, pk := range test.include {
-			if !test.kr.includePK(pk) {
-				t.Errorf("(%v).includePK(%v) = false, want true", test.kr, pk)
+			if !has(pk) {
+				t.Errorf("keyRange %v does not include %v", test.kr, pk)
 			}
 		}
 		for _, pk := range test.exclude {
-			if test.kr.includePK(pk) {
-				t.Errorf("(%v).includePK(%v) = true, want false", test.kr, pk)
+			if has(pk) {
+				t.Errorf("keyRange %v includes %v", test.kr, pk)
 			}
 		}
 	}