spanner/spannertest: implement GROUP BY, aggregation

This adds full support for GROUP BY and applying aggregation using COUNT
and SUM as part of a SELECT list.

Also revise TODO list in README.md.

Change-Id: I52c8c41ee4751b1ed4c2632e09e3bbda9b953b93
Reviewed-on: https://code-review.googlesource.com/c/gocloud/+/52735
Reviewed-by: kokoro <noreply+kokoro@google.com>
Reviewed-by: Knut Olav Løite <koloite@gmail.com>
diff --git a/spanner/spannertest/README.md b/spanner/spannertest/README.md
index 9ccb19f..edb9656 100644
--- a/spanner/spannertest/README.md
+++ b/spanner/spannertest/README.md
@@ -15,17 +15,16 @@
 Here's a list of features that are missing or incomplete. It is roughly ordered
 by ascending esotericism:
 
-- SELECT GROUP BY
-- SELECT HAVING
-- arithmetic expressions (operators, parens)
-- transaction simulation
+- expression functions
+- more aggregation functions
 - INSERT/UPDATE DML statements
 - case insensitivity
 - alternate literal types (esp. strings)
 - STRUCT types
-- expression functions
-- expression type casting, coercion
+- SELECT HAVING
 - joins
+- transaction simulation
+- expression type casting, coercion
 - query offset
 - SELECT aliases
 - subselects
diff --git a/spanner/spannertest/db.go b/spanner/spannertest/db.go
index 63075fe..6fb3d8e 100644
--- a/spanner/spannertest/db.go
+++ b/spanner/spannertest/db.go
@@ -207,6 +207,9 @@
 		if _, ok := d.tables[stmt.Name]; ok {
 			return status.Newf(codes.AlreadyExists, "table %s already exists", stmt.Name)
 		}
+		if len(stmt.PrimaryKey) == 0 {
+			return status.Newf(codes.InvalidArgument, "table %s has no primary key", stmt.Name)
+		}
 
 		// TODO: check stmt.Interleave details.
 
diff --git a/spanner/spannertest/db_eval.go b/spanner/spannertest/db_eval.go
index be20e55..750eba3 100644
--- a/spanner/spannertest/db_eval.go
+++ b/spanner/spannertest/db_eval.go
@@ -356,6 +356,10 @@
 		return ec.evalBoolExpr(e)
 	case spansql.IsOp:
 		return ec.evalBoolExpr(e)
+	case aggSentinel:
+		// Aggregate value is always last in the row.
+		// TODO: This could be tightened up by including colInfo in evalContext.
+		return ec.row[len(ec.row)-1], nil
 	}
 }
 
@@ -515,6 +519,8 @@
 		// There isn't necessarily something sensible here.
 		// Empirically, though, the real Spanner returns Int64.
 		return colInfo{Type: int64Type}, nil
+	case aggSentinel:
+		return colInfo{Type: e.Type}, nil
 	}
 	return colInfo{}, fmt.Errorf("can't deduce column type from expression [%s]", e.SQL())
 }
diff --git a/spanner/spannertest/db_query.go b/spanner/spannertest/db_query.go
index fff96d1..160d7f4 100644
--- a/spanner/spannertest/db_query.go
+++ b/spanner/spannertest/db_query.go
@@ -33,8 +33,8 @@
 The order of operations among those supported by Cloud Spanner is
 	FROM + JOIN + set ops [TODO: JOIN and set ops]
 	WHERE
-	GROUP BY [TODO]
-	aggregation [TODO]
+	GROUP BY
+	aggregation
 	HAVING [TODO]
 	SELECT
 	DISTINCT
@@ -54,6 +54,13 @@
 	Next() (row, error)
 }
 
+// aggSentinel is a synthetic expression that refers to an aggregated value.
+// It is transient only; it is never stored and only used during evaluation.
+type aggSentinel struct {
+	spansql.Expr
+	Type spansql.Type
+}
+
 // nullIter is a rowIter that returns one empty row only.
 // This is used for queries without a table.
 type nullIter struct {
@@ -319,23 +326,144 @@
 		}
 	}
 
-	// Handle COUNT(*) specially.
-	// TODO: Handle aggregation more generally.
-	if len(sel.List) == 1 && isCountStar(sel.List[0]) {
-		// Replace the `COUNT(*)` with `1`, then aggregate on the way out.
-		sel.List[0] = spansql.IntegerLiteral(1)
-		defer func() {
-			if evalErr != nil {
-				return
-			}
-			raw, err := toRawIter(ri)
+	// Apply GROUP BY.
+	// This only reorders rows to group rows together;
+	// aggregation happens next.
+	var rowGroups [][2]int // Sequence of half-open intervals of row numbers.
+	if len(sel.GroupBy) > 0 {
+		raw, err := toRawIter(ri)
+		if err != nil {
+			return nil, err
+		}
+		keys := make([][]interface{}, 0, len(raw.rows))
+		for _, row := range raw.rows {
+			// Evaluate sort key for this row.
+			// TODO: Support referring to expression names in the SELECT list;
+			// this may require passing through sel.List, or maybe mutating
+			// sel.GroupBy to copy the referenced values. This will also be
+			// required to support grouping by aliases.
+			ec.row = row
+			key, err := ec.evalExprList(sel.GroupBy)
 			if err != nil {
-				ri, evalErr = nil, err
+				return nil, err
 			}
-			count := int64(len(raw.rows))
-			raw.rows = []row{{count}}
-			ri, evalErr = raw, nil
-		}()
+			keys = append(keys, key)
+		}
+
+		// Reorder rows base on the evaluated keys.
+		ers := externalRowSorter{rows: raw.rows, keys: keys}
+		sort.Sort(ers)
+		raw.rows = ers.rows
+
+		// Record groups as a sequence of row intervals.
+		// Each group is a run of the same keys.
+		start := 0
+		for i := 1; i < len(keys); i++ {
+			if compareValLists(keys[i-1], keys[i], nil) == 0 {
+				continue
+			}
+			rowGroups = append(rowGroups, [2]int{start, i})
+			start = i
+		}
+		if len(keys) > 0 {
+			rowGroups = append(rowGroups, [2]int{start, len(keys)})
+		}
+
+		ri = raw
+	}
+
+	// Handle aggregation.
+	// TODO: Support more than one aggregation function; does Spanner support that?
+	aggI := -1
+	for i, e := range sel.List {
+		// Supported aggregate funcs have exactly one arg.
+		f, ok := e.(spansql.Func)
+		if !ok || len(f.Args) != 1 {
+			continue
+		}
+		_, ok = aggregateFuncs[f.Name]
+		if !ok {
+			continue
+		}
+		if aggI > -1 {
+			return nil, fmt.Errorf("only one aggregate function is supported")
+		}
+		aggI = i
+	}
+	if aggI > -1 {
+		raw, err := toRawIter(ri)
+		if err != nil {
+			return nil, err
+		}
+		if len(rowGroups) == 0 {
+			// No grouping, so aggregation applies to the entire table (e.g. COUNT(*)).
+			rowGroups = [][2]int{{0, len(raw.rows)}}
+		}
+		fexpr := sel.List[aggI].(spansql.Func)
+		fn := aggregateFuncs[fexpr.Name]
+		starArg := fexpr.Args[0] == spansql.Star
+		if starArg && !fn.AcceptStar {
+			return nil, fmt.Errorf("aggregate function %s does not accept * as an argument", fexpr.Name)
+		}
+
+		// Prepare output.
+		rawOut := &rawIter{
+			// Same as input columns, but also the aggregate value.
+			// Add the colInfo for the aggregate at the end
+			// so we know the type.
+			// Make a copy for safety.
+			cols: append([]colInfo(nil), raw.cols...),
+		}
+
+		var aggType spansql.Type
+		for _, rg := range rowGroups {
+			// Compute aggregate value across this group.
+			var values []interface{}
+			for i := rg[0]; i < rg[1]; i++ {
+				ec.row = raw.rows[i]
+				if starArg {
+					// A non-NULL placeholder is sufficient for aggregation.
+					values = append(values, 1)
+				} else {
+					x, err := ec.evalExpr(fexpr.Args[0])
+					if err != nil {
+						return nil, err
+					}
+					values = append(values, x)
+				}
+			}
+			x, typ, err := fn.Eval(values)
+			if err != nil {
+				return nil, err
+			}
+			aggType = typ
+			// Output for the row group is the first row of the group (arbitrary,
+			// but it should be representative), and the aggregate value.
+			// TODO: Should this exclude the aggregated expressions so they can't be selected?
+			repRow := raw.rows[rg[0]]
+			var outRow row
+			for i := range repRow {
+				outRow = append(outRow, repRow.copyDataElem(i))
+			}
+			outRow = append(outRow, x)
+			rawOut.rows = append(rawOut.rows, outRow)
+		}
+
+		if aggType == (spansql.Type{}) {
+			// Fallback; there might not be any groups.
+			// TODO: Should this be in aggregateFunc?
+			aggType = int64Type
+		}
+		rawOut.cols = append(raw.cols, colInfo{
+			// TODO: Generate more interesting colInfo?
+			Name: fexpr.SQL(),
+			Type: aggType,
+		})
+
+		ri = rawOut
+		sel.List[aggI] = aggSentinel{ // Mutate query so evalExpr in selIter picks out the new value.
+			Type: aggType,
+		}
 	}
 
 	// TODO: Support table sampling.
@@ -372,13 +500,18 @@
 	return ri, nil
 }
 
-func isCountStar(e spansql.Expr) bool {
-	f, ok := e.(spansql.Func)
-	if !ok {
-		return false
-	}
-	if f.Name != "COUNT" || len(f.Args) != 1 {
-		return false
-	}
-	return f.Args[0] == spansql.Star
+// externalRowSorter implements sort.Interface for a slice of rows
+// with an external sort key.
+type externalRowSorter struct {
+	rows []row
+	keys [][]interface{}
+}
+
+func (ers externalRowSorter) Len() int { return len(ers.rows) }
+func (ers externalRowSorter) Less(i, j int) bool {
+	return compareValLists(ers.keys[i], ers.keys[j], nil) < 0
+}
+func (ers externalRowSorter) Swap(i, j int) {
+	ers.rows[i], ers.rows[j] = ers.rows[j], ers.rows[i]
+	ers.keys[i], ers.keys[j] = ers.keys[j], ers.keys[i]
 }
diff --git a/spanner/spannertest/db_test.go b/spanner/spannertest/db_test.go
index 439f039..09eb4cc 100644
--- a/spanner/spannertest/db_test.go
+++ b/spanner/spannertest/db_test.go
@@ -300,6 +300,41 @@
 		t.Fatalf("Committing changes: %v", err)
 	}
 
+	// Prepare the sample tables from the Cloud Spanner docs.
+	// https://cloud.google.com/spanner/docs/query-syntax#appendix-a-examples-with-sample-data
+	for _, ct := range []*spansql.CreateTable{
+		// TODO: Roster, TeamMascot when we implement JOINs.
+		{
+			Name: "PlayerStats",
+			Columns: []spansql.ColumnDef{
+				{Name: "LastName", Type: spansql.Type{Base: spansql.String}},
+				{Name: "OpponentID", Type: spansql.Type{Base: spansql.Int64}},
+				{Name: "PointsScored", Type: spansql.Type{Base: spansql.Int64}},
+			},
+			PrimaryKey: []spansql.KeyPart{{Column: "LastName"}, {Column: "OpponentID"}}, // TODO: is this right?
+		},
+	} {
+		st := db.ApplyDDL(ct)
+		if st.Code() != codes.OK {
+			t.Fatalf("Creating table: %v", st.Err())
+		}
+	}
+	tx = db.NewTransaction()
+	tx.Start()
+	err = db.Insert(tx, "PlayerStats", []string{"LastName", "OpponentID", "PointsScored"}, []*structpb.ListValue{
+		listV(stringV("Adams"), stringV("51"), stringV("3")),
+		listV(stringV("Buchanan"), stringV("77"), stringV("0")),
+		listV(stringV("Coolidge"), stringV("77"), stringV("1")),
+		listV(stringV("Adams"), stringV("52"), stringV("4")),
+		listV(stringV("Buchanan"), stringV("50"), stringV("13")),
+	})
+	if err != nil {
+		t.Fatalf("Inserting data: %v", err)
+	}
+	if _, err := tx.Commit(); err != nil {
+		t.Fatalf("Commiting changes: %v", err)
+	}
+
 	// Do some complex queries.
 	tests := []struct {
 		q      string
@@ -431,6 +466,18 @@
 				{true, false},
 			},
 		},
+		// From https://cloud.google.com/spanner/docs/query-syntax#group-by-clause_1:
+		{
+			// TODO: Ordering matters? Our implementation sorts by the GROUP BY key,
+			// but nothing documented seems to guarantee that.
+			`SELECT LastName, SUM(PointsScored) FROM PlayerStats GROUP BY LastName`,
+			nil,
+			[][]interface{}{
+				{"Adams", int64(7)},
+				{"Buchanan", int64(13)},
+				{"Coolidge", int64(1)},
+			},
+		},
 	}
 	for _, test := range tests {
 		q, err := spansql.ParseQuery(test.q)
diff --git a/spanner/spannertest/funcs.go b/spanner/spannertest/funcs.go
new file mode 100644
index 0000000..c44ccb9
--- /dev/null
+++ b/spanner/spannertest/funcs.go
@@ -0,0 +1,86 @@
+/*
+Copyright 2020 Google LLC
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+package spannertest
+
+import (
+	"fmt"
+
+	"cloud.google.com/go/spanner/spansql"
+)
+
+// This file contains implementations of query functions.
+
+type aggregateFunc struct {
+	// Whether the function can take a * arg (only COUNT).
+	AcceptStar bool
+
+	// Every aggregate func takes one expression.
+	Eval func(values []interface{}) (interface{}, spansql.Type, error)
+
+	// TODO: Handle qualifiers such as DISTINCT.
+}
+
+// TODO: more aggregate funcs.
+var aggregateFuncs = map[string]aggregateFunc{
+	"COUNT": {
+		AcceptStar: true,
+		Eval: func(values []interface{}) (interface{}, spansql.Type, error) {
+			// Count the number of non-NULL values.
+			// COUNT(*) receives a list of non-NULL placeholders rather than values,
+			// so every value will be non-NULL.
+			var n int64
+			for _, v := range values {
+				if v != nil {
+					n++
+				}
+			}
+			return n, int64Type, nil
+		},
+	},
+	"SUM": {
+		Eval: func(values []interface{}) (interface{}, spansql.Type, error) {
+			// Ignoring NULL values, there may only be one type, either INT64 or FLOAT64.
+			var seenInt, seenFloat bool
+			var sumInt int64
+			var sumFloat float64
+			for _, v := range values {
+				switch v := v.(type) {
+				default:
+					return nil, spansql.Type{}, fmt.Errorf("SUM only supports arguments of INT64 or FLOAT64 type, not %T", v)
+				case nil:
+					continue
+				case int64:
+					seenInt = true
+					sumInt += v
+				case float64:
+					seenFloat = true
+					sumFloat += v
+				}
+			}
+			if !seenInt && !seenFloat {
+				// "Returns NULL if the input contains only NULLs".
+				return nil, int64Type, nil
+			} else if seenInt && seenFloat {
+				// This shouldn't happen.
+				return nil, spansql.Type{}, fmt.Errorf("internal error: SUM saw mix of INT64 and FLOAT64")
+			} else if seenInt {
+				return sumInt, int64Type, nil
+			}
+			return sumFloat, float64Type, nil
+		},
+	},
+}