spanner/spannertest: bind column info to evalContext, disconnect from table
This is more correct in the presence of aggregation and other similar
advanced features, and better isolates expression evaluation from a
table context.
Change-Id: Ice1a417520f7c188ec9588c3843233b02dc08155
Reviewed-on: https://code-review.googlesource.com/c/gocloud/+/52737
Reviewed-by: kokoro <noreply+kokoro@google.com>
Reviewed-by: Knut Olav Løite <koloite@gmail.com>
diff --git a/spanner/spannertest/db.go b/spanner/spannertest/db.go
index 6fb3d8e..a099b58 100644
--- a/spanner/spannertest/db.go
+++ b/spanner/spannertest/db.go
@@ -64,8 +64,9 @@
// colInfo represents information about a column in a table or result set.
type colInfo struct {
- Name string
- Type spansql.Type
+ Name string
+ Type spansql.Type
+ AggIndex int // Index+1 of SELECT list for which this is an aggregate value.
}
// commitTimestampSentinel is a sentinel value for TIMESTAMP fields with allow_commit_timestamp=true.
@@ -869,7 +870,7 @@
n := 0
for i := 0; i < len(t.rows); {
ec := evalContext{
- table: t,
+ cols: t.cols,
row: t.rows[i],
params: params,
}
diff --git a/spanner/spannertest/db_eval.go b/spanner/spannertest/db_eval.go
index 750eba3..1da4fc1 100644
--- a/spanner/spannertest/db_eval.go
+++ b/spanner/spannertest/db_eval.go
@@ -29,8 +29,10 @@
// evalContext represents the context for evaluating an expression.
type evalContext struct {
- table *table // may be nil
- row row // set if table is set, only during expr evaluation
+ // cols and row are set during expr evaluation.
+ cols []colInfo
+ row row
+
params queryParams
}
@@ -357,22 +359,30 @@
case spansql.IsOp:
return ec.evalBoolExpr(e)
case aggSentinel:
- // Aggregate value is always last in the row.
- // TODO: This could be tightened up by including colInfo in evalContext.
- return ec.row[len(ec.row)-1], nil
+ // Match up e.AggIndex with the column.
+ // They might have been reordered.
+ ci := -1
+ for i, col := range ec.cols {
+ if col.AggIndex == e.AggIndex {
+ ci = i
+ break
+ }
+ }
+ if ci < 0 {
+ return 0, fmt.Errorf("internal error: did not find aggregate column %d", e.AggIndex)
+ }
+ return ec.row[ci], nil
}
}
func (ec evalContext) evalID(id spansql.ID) (interface{}, error) {
// TODO: look beyond column names.
- if ec.table == nil {
- return nil, fmt.Errorf("identifier %s when not SELECTing on a table is not supported", string(id))
+ for i, col := range ec.cols {
+ if col.Name == string(id) {
+ return ec.row.copyDataElem(i), nil
+ }
}
- i, ok := ec.table.colIndex[string(id)]
- if !ok {
- return nil, fmt.Errorf("couldn't resolve identifier %s", string(id))
- }
- return ec.row.copyDataElem(i), nil
+ return nil, fmt.Errorf("couldn't resolve identifier %s", string(id))
}
func evalLimit(lim spansql.Limit, params queryParams) (int64, error) {
@@ -507,10 +517,9 @@
return colInfo{Type: spansql.Type{Base: spansql.Bool}}, nil
case spansql.ID:
// TODO: support more than only naming a table column.
- name := string(e)
- if ec.table != nil {
- if i, ok := ec.table.colIndex[name]; ok {
- return ec.table.cols[i], nil
+ for _, col := range ec.cols {
+ if col.Name == string(e) {
+ return col, nil
}
}
case spansql.Paren:
@@ -520,7 +529,7 @@
// Empirically, though, the real Spanner returns Int64.
return colInfo{Type: int64Type}, nil
case aggSentinel:
- return colInfo{Type: e.Type}, nil
+ return colInfo{Type: e.Type, AggIndex: e.AggIndex}, nil
}
return colInfo{}, fmt.Errorf("can't deduce column type from expression [%s]", e.SQL())
}
diff --git a/spanner/spannertest/db_query.go b/spanner/spannertest/db_query.go
index 160d7f4..503cc99 100644
--- a/spanner/spannertest/db_query.go
+++ b/spanner/spannertest/db_query.go
@@ -58,7 +58,8 @@
// It is transient only; it is never stored and only used during evaluation.
type aggSentinel struct {
spansql.Expr
- Type spansql.Type
+ Type spansql.Type
+ AggIndex int // Index+1 of SELECT list.
}
// nullIter is a rowIter that returns one empty row only.
@@ -305,7 +306,7 @@
t.mu.Lock()
defer t.mu.Unlock()
ri = &tableIter{t: t}
- ec.table = t
+ ec.cols = t.cols
}
defer func() {
// If we're about to return a tableIter, convert it to a rawIter
@@ -455,14 +456,16 @@
aggType = int64Type
}
rawOut.cols = append(raw.cols, colInfo{
- // TODO: Generate more interesting colInfo?
- Name: fexpr.SQL(),
- Type: aggType,
+ Name: fexpr.SQL(),
+ Type: aggType,
+ AggIndex: aggI + 1,
})
ri = rawOut
+ ec.cols = rawOut.cols
sel.List[aggI] = aggSentinel{ // Mutate query so evalExpr in selIter picks out the new value.
- Type: aggType,
+ Type: aggType,
+ AggIndex: aggI + 1,
}
}
@@ -474,7 +477,7 @@
selectStar := len(sel.List) == 1 && sel.List[0] == spansql.Star
if selectStar {
// Every column will appear in the output.
- colInfos = append([]colInfo(nil), ec.table.cols...)
+ colInfos = ec.cols
} else {
for _, e := range sel.List {
ci, err := ec.colInfo(e)