spanner/spannertest: bind column info to evalContext, disconnect from table

This is more correct in the presence of aggregation and other similar
advanced features, and better isolates expression evaluation from a
table context.

Change-Id: Ice1a417520f7c188ec9588c3843233b02dc08155
Reviewed-on: https://code-review.googlesource.com/c/gocloud/+/52737
Reviewed-by: kokoro <noreply+kokoro@google.com>
Reviewed-by: Knut Olav Løite <koloite@gmail.com>
diff --git a/spanner/spannertest/db.go b/spanner/spannertest/db.go
index 6fb3d8e..a099b58 100644
--- a/spanner/spannertest/db.go
+++ b/spanner/spannertest/db.go
@@ -64,8 +64,9 @@
 
 // colInfo represents information about a column in a table or result set.
 type colInfo struct {
-	Name string
-	Type spansql.Type
+	Name     string
+	Type     spansql.Type
+	AggIndex int // Index+1 of SELECT list for which this is an aggregate value.
 }
 
 // commitTimestampSentinel is a sentinel value for TIMESTAMP fields with allow_commit_timestamp=true.
@@ -869,7 +870,7 @@
 		n := 0
 		for i := 0; i < len(t.rows); {
 			ec := evalContext{
-				table:  t,
+				cols:   t.cols,
 				row:    t.rows[i],
 				params: params,
 			}
diff --git a/spanner/spannertest/db_eval.go b/spanner/spannertest/db_eval.go
index 750eba3..1da4fc1 100644
--- a/spanner/spannertest/db_eval.go
+++ b/spanner/spannertest/db_eval.go
@@ -29,8 +29,10 @@
 
 // evalContext represents the context for evaluating an expression.
 type evalContext struct {
-	table  *table // may be nil
-	row    row    // set if table is set, only during expr evaluation
+	// cols and row are set during expr evaluation.
+	cols []colInfo
+	row  row
+
 	params queryParams
 }
 
@@ -357,22 +359,30 @@
 	case spansql.IsOp:
 		return ec.evalBoolExpr(e)
 	case aggSentinel:
-		// Aggregate value is always last in the row.
-		// TODO: This could be tightened up by including colInfo in evalContext.
-		return ec.row[len(ec.row)-1], nil
+		// Match up e.AggIndex with the column.
+		// They might have been reordered.
+		ci := -1
+		for i, col := range ec.cols {
+			if col.AggIndex == e.AggIndex {
+				ci = i
+				break
+			}
+		}
+		if ci < 0 {
+			return 0, fmt.Errorf("internal error: did not find aggregate column %d", e.AggIndex)
+		}
+		return ec.row[ci], nil
 	}
 }
 
 func (ec evalContext) evalID(id spansql.ID) (interface{}, error) {
 	// TODO: look beyond column names.
-	if ec.table == nil {
-		return nil, fmt.Errorf("identifier %s when not SELECTing on a table is not supported", string(id))
+	for i, col := range ec.cols {
+		if col.Name == string(id) {
+			return ec.row.copyDataElem(i), nil
+		}
 	}
-	i, ok := ec.table.colIndex[string(id)]
-	if !ok {
-		return nil, fmt.Errorf("couldn't resolve identifier %s", string(id))
-	}
-	return ec.row.copyDataElem(i), nil
+	return nil, fmt.Errorf("couldn't resolve identifier %s", string(id))
 }
 
 func evalLimit(lim spansql.Limit, params queryParams) (int64, error) {
@@ -507,10 +517,9 @@
 		return colInfo{Type: spansql.Type{Base: spansql.Bool}}, nil
 	case spansql.ID:
 		// TODO: support more than only naming a table column.
-		name := string(e)
-		if ec.table != nil {
-			if i, ok := ec.table.colIndex[name]; ok {
-				return ec.table.cols[i], nil
+		for _, col := range ec.cols {
+			if col.Name == string(e) {
+				return col, nil
 			}
 		}
 	case spansql.Paren:
@@ -520,7 +529,7 @@
 		// Empirically, though, the real Spanner returns Int64.
 		return colInfo{Type: int64Type}, nil
 	case aggSentinel:
-		return colInfo{Type: e.Type}, nil
+		return colInfo{Type: e.Type, AggIndex: e.AggIndex}, nil
 	}
 	return colInfo{}, fmt.Errorf("can't deduce column type from expression [%s]", e.SQL())
 }
diff --git a/spanner/spannertest/db_query.go b/spanner/spannertest/db_query.go
index 160d7f4..503cc99 100644
--- a/spanner/spannertest/db_query.go
+++ b/spanner/spannertest/db_query.go
@@ -58,7 +58,8 @@
 // It is transient only; it is never stored and only used during evaluation.
 type aggSentinel struct {
 	spansql.Expr
-	Type spansql.Type
+	Type     spansql.Type
+	AggIndex int // Index+1 of SELECT list.
 }
 
 // nullIter is a rowIter that returns one empty row only.
@@ -305,7 +306,7 @@
 		t.mu.Lock()
 		defer t.mu.Unlock()
 		ri = &tableIter{t: t}
-		ec.table = t
+		ec.cols = t.cols
 	}
 	defer func() {
 		// If we're about to return a tableIter, convert it to a rawIter
@@ -455,14 +456,16 @@
 			aggType = int64Type
 		}
 		rawOut.cols = append(raw.cols, colInfo{
-			// TODO: Generate more interesting colInfo?
-			Name: fexpr.SQL(),
-			Type: aggType,
+			Name:     fexpr.SQL(),
+			Type:     aggType,
+			AggIndex: aggI + 1,
 		})
 
 		ri = rawOut
+		ec.cols = rawOut.cols
 		sel.List[aggI] = aggSentinel{ // Mutate query so evalExpr in selIter picks out the new value.
-			Type: aggType,
+			Type:     aggType,
+			AggIndex: aggI + 1,
 		}
 	}
 
@@ -474,7 +477,7 @@
 	selectStar := len(sel.List) == 1 && sel.List[0] == spansql.Star
 	if selectStar {
 		// Every column will appear in the output.
-		colInfos = append([]colInfo(nil), ec.table.cols...)
+		colInfos = ec.cols
 	} else {
 		for _, e := range sel.List {
 			ci, err := ec.colInfo(e)