spanner/spannertest: switch to storing rows in primary key order
This is faster for non-trivial tables for many operations (because we
can use binary searches), and is a better match for production Spanner.
Change-Id: I2a3271a8bca1e466a923a1367d2220615a6e1328
Reviewed-on: https://code-review.googlesource.com/c/gocloud/+/44351
Reviewed-by: kokoro <noreply+kokoro@google.com>
Reviewed-by: Knut Olav Løite <koloite@gmail.com>
diff --git a/spanner/spannertest/db.go b/spanner/spannertest/db.go
index 34619e4..255a253 100644
--- a/spanner/spannertest/db.go
+++ b/spanner/spannertest/db.go
@@ -22,6 +22,7 @@
// TODO: missing transactionality in a serious way!
import (
+ "bytes"
"fmt"
"sort"
"strconv"
@@ -51,6 +52,7 @@
colIndex map[string]int // col name to index
pkCols int // number of primary key columns (may be 0)
+ // Rows are stored in primary key order.
rows []row
}
@@ -251,12 +253,12 @@
func (d *database) Insert(tbl string, cols []string, values []*structpb.ListValue) error {
return d.writeValues(tbl, cols, values, func(t *table, colIndexes []int, r row) error {
pk := r[:t.pkCols]
- if t.rowForPK(pk) >= 0 {
+ rowNum, found := t.rowForPK(pk)
+ if found {
// TODO: how do we return `ALREADY_EXISTS`?
return status.Errorf(codes.Unknown, "row already in table")
}
-
- t.rows = append(t.rows, r)
+ t.insertRow(rowNum, r)
return nil
})
}
@@ -267,8 +269,8 @@
return status.Errorf(codes.InvalidArgument, "cannot update table %s with no columns in primary key", tbl)
}
pk := r[:t.pkCols]
- rowNum := t.rowForPK(pk)
- if rowNum < 0 {
+ rowNum, found := t.rowForPK(pk)
+ if !found {
// TODO: is this the right way to return `NOT_FOUND`?
return status.Errorf(codes.NotFound, "row not in table")
}
@@ -283,10 +285,10 @@
func (d *database) InsertOrUpdate(tbl string, cols []string, values []*structpb.ListValue) error {
return d.writeValues(tbl, cols, values, func(t *table, colIndexes []int, r row) error {
pk := r[:t.pkCols]
- rowNum := t.rowForPK(pk)
- if rowNum < 0 {
+ rowNum, found := t.rowForPK(pk)
+ if !found {
// New row; do an insert.
- t.rows = append(t.rows, r)
+ t.insertRow(rowNum, r)
} else {
// Existing row; do an update.
for _, i := range colIndexes {
@@ -319,8 +321,8 @@
return err
}
// Not an error if the key does not exist.
- rowNum := t.rowForPK(pk)
- if rowNum >= 0 {
+ rowNum, found := t.rowForPK(pk)
+ if found {
copy(t.rows[rowNum:], t.rows[rowNum+1:])
t.rows = t.rows[:len(t.rows)-1]
}
@@ -335,16 +337,10 @@
if err != nil {
return err
}
- for rowNum := 0; rowNum < len(t.rows); {
- rowPK := t.rows[rowNum][:t.pkCols]
- if !r.includePK(rowPK) {
- rowNum++
- continue
- }
-
- // Row is in range.
- copy(t.rows[rowNum:], t.rows[rowNum+1:])
- t.rows = t.rows[:len(t.rows)-1]
+ startRow, endRow := t.findRange(r)
+ if n := endRow - startRow; n > 0 {
+ copy(t.rows[startRow:], t.rows[endRow:])
+ t.rows = t.rows[:len(t.rows)-n]
}
}
@@ -414,8 +410,8 @@
return err
}
// Not an error if the key does not exist.
- rowNum := t.rowForPK(pk)
- if rowNum < 0 {
+ rowNum, found := t.rowForPK(pk)
+ if !found {
continue
}
ri.add(t.rows[rowNum], colIndexes)
@@ -512,6 +508,41 @@
return nil
}
+// insertRow places r at index rowNum of t.rows, moving the rows that were
+// at rowNum and beyond up by one position.
+func (t *table) insertRow(rowNum int, r row) {
+	oldLen := len(t.rows)
+	// Grow by one slot, then shift the tail right to open a gap at rowNum.
+	t.rows = append(t.rows, nil)
+	copy(t.rows[rowNum+1:], t.rows[rowNum:oldLen])
+	t.rows[rowNum] = r
+}
+
+// findRange finds the rows included in the key range,
+// reporting it as a half-open interval [startRow, endRow) of indexes into t.rows.
+// r.startKey and r.endKey should be populated.
+//
+// The result always satisfies startRow <= endRow; a key range that matches
+// no rows (including an inverted one) yields startRow == endRow.
+func (t *table) findRange(r *keyRange) (int, int) {
+	// TODO: This is incorrect for primary keys with descending order.
+	// It might be sufficient for the caller to switch start/end in that case.
+
+	// Whether a bound is open or closed is folded into the search predicate
+	// rather than adjusting the found index by one afterwards. rowCmp permits
+	// its first argument to be a prefix, so several consecutive rows may
+	// compare equal to a bound; an open bound must exclude all of them.
+	// This also avoids indexing t.rows[endRow-1] when endRow is 0.
+
+	// startRow is the first row matching the range.
+	startRow := sort.Search(len(t.rows), func(i int) bool {
+		cmp := rowCmp(r.startKey, t.rows[i][:t.pkCols])
+		if r.startClosed {
+			return cmp <= 0
+		}
+		return cmp < 0
+	})
+
+	// endRow is one more than the last row matching the range.
+	endRow := sort.Search(len(t.rows), func(i int) bool {
+		cmp := rowCmp(r.endKey, t.rows[i][:t.pkCols])
+		if r.endClosed {
+			return cmp < 0
+		}
+		return cmp <= 0
+	})
+
+	// An empty or inverted key range would otherwise put endRow before startRow.
+	if endRow < startRow {
+		endRow = startRow
+	}
+
+	return startRow, endRow
+}
+
// colIndexes returns the indexes for the named columns.
func (t *table) colIndexes(cols []string) ([]int, error) {
var is []int
@@ -551,18 +582,20 @@
return pk, nil
}
-// rowForPK returns the index of t.rows that holds the row for the given primary key.
-// It returns -1 if it isn't found, including when the table's primary key has no columns.
-func (t *table) rowForPK(pk []interface{}) int {
+// rowForPK returns the index of t.rows that holds the row for the given primary key, and true.
+// If the given primary key isn't found, it returns the index at which a row with that key should be inserted, and false.
+func (t *table) rowForPK(pk []interface{}) (row int, found bool) {
if len(pk) != t.pkCols {
panic(fmt.Sprintf("primary key length mismatch: got %d values, table has %d", len(pk), t.pkCols))
}
- for i, row := range t.rows {
- if rowCmp(pk, row[:t.pkCols]) == 0 {
- return i
- }
+
+ i := sort.Search(len(t.rows), func(i int) bool {
+ return rowCmp(pk, t.rows[i][:t.pkCols]) <= 0
+ })
+ if i == len(t.rows) {
+ return i, false
}
- return -1
+ return i, rowCmp(pk, t.rows[i][:t.pkCols]) == 0
}
// rowCmp compares two rows, returning -1/0/+1.
@@ -649,24 +682,20 @@
startKey, endKey []interface{}
}
-type keyRangeList []*keyRange
-
-func (kr *keyRange) includePK(pk []interface{}) bool {
- // rowCmp permits its first argument to be a prefix,
- // so the calls to it below use kr.fooKey as the first arg.
-
- // TODO: This is incorrect for primary keys with descending order.
- // It might be sufficient for the caller to switch start/end in that case.
-
- cmp := rowCmp(kr.startKey, pk)
- if cmp > 0 || (cmp == 0 && !kr.startClosed) {
- // Row is before range.
- return false
+func (r *keyRange) String() string {
+ var sb bytes.Buffer // TODO: Switch to strings.Builder when we drop support for Go 1.9.
+ if r.startClosed {
+ sb.WriteString("[")
+ } else {
+ sb.WriteString("(")
}
- cmp = rowCmp(kr.endKey, pk)
- if cmp < 0 || (cmp == 0 && !kr.endClosed) {
- // Row is after range.
- return false
+ fmt.Fprintf(&sb, "%v,%v", r.startKey, r.endKey)
+ if r.endClosed {
+ sb.WriteString("]")
+ } else {
+ sb.WriteString(")")
}
- return true
+ return sb.String()
}
+
+type keyRangeList []*keyRange
diff --git a/spanner/spannertest/db_test.go b/spanner/spannertest/db_test.go
index 7a7203a..b8296a5 100644
--- a/spanner/spannertest/db_test.go
+++ b/spanner/spannertest/db_test.go
@@ -148,10 +148,11 @@
}
all = slurp(ri)
wantAll = [][]interface{}{
- {int64(10), "Jack", 1.85},
+ // Primary key is (Name, ID), so results should come back sorted by Name then ID.
{int64(11), "Daniel", 1.83},
+ {int64(6), "George", 1.73},
+ {int64(10), "Jack", 1.85},
{int64(9), "Sam", 1.75},
- {int64(8), "Teal'c", 1.91},
}
if !reflect.DeepEqual(all, wantAll) {
t.Errorf("ReadAll data wrong.\n got %v\nwant %v", all, wantAll)
@@ -336,7 +337,7 @@
}
}
-func TestKeyRangeInclude(t *testing.T) {
+func TestKeyRange(t *testing.T) {
r := func(x ...interface{}) []interface{} { return x }
closedClosed := func(start, end []interface{}) *keyRange {
return &keyRange{
@@ -353,6 +354,12 @@
startClosed: true,
}
}
+ openOpen := func(start, end []interface{}) *keyRange {
+ return &keyRange{
+ startKey: start,
+ endKey: end,
+ }
+ }
tests := []struct {
kr *keyRange
include [][]interface{}
@@ -400,6 +407,20 @@
},
exclude: [][]interface{}{
r("Alice", "1999-11-07"),
+ r("Bob", "2000-01-01"),
+ r("Bob", "2004-07-07"),
+ r("Charlie", "1999-11-07"),
+ },
+ },
+ {
+ kr: openOpen(r("Bob", "1999-11-06"), r("Bob", "2000-01-01")),
+ include: [][]interface{}{
+ r("Bob", "1999-11-07"),
+ },
+ exclude: [][]interface{}{
+ r("Alice", "1999-11-07"),
+ r("Bob", "1999-11-06"),
+ r("Bob", "2000-01-01"),
r("Bob", "2004-07-07"),
r("Charlie", "1999-11-07"),
},
@@ -426,14 +447,26 @@
},
}
for _, test := range tests {
+ tbl := &table{
+ pkCols: 2,
+ }
+ for _, pk := range append(test.include, test.exclude...) {
+ rowNum, _ := tbl.rowForPK(pk)
+ tbl.insertRow(rowNum, pk)
+ }
+ start, end := tbl.findRange(test.kr)
+ has := func(pk []interface{}) bool {
+ n, _ := tbl.rowForPK(pk)
+ return start <= n && n < end
+ }
for _, pk := range test.include {
- if !test.kr.includePK(pk) {
- t.Errorf("(%v).includePK(%v) = false, want true", test.kr, pk)
+ if !has(pk) {
+ t.Errorf("keyRange %v does not include %v", test.kr, pk)
}
}
for _, pk := range test.exclude {
- if test.kr.includePK(pk) {
- t.Errorf("(%v).includePK(%v) = true, want false", test.kr, pk)
+ if has(pk) {
+ t.Errorf("keyRange %v includes %v", test.kr, pk)
}
}
}