spanner/spannertest: implement GROUP BY, aggregation
This adds full support for GROUP BY and applying aggregation using COUNT
and SUM as part of a SELECT list.
Also revise TODO list in README.md.
Change-Id: I52c8c41ee4751b1ed4c2632e09e3bbda9b953b93
Reviewed-on: https://code-review.googlesource.com/c/gocloud/+/52735
Reviewed-by: kokoro <noreply+kokoro@google.com>
Reviewed-by: Knut Olav Løite <koloite@gmail.com>
diff --git a/spanner/spannertest/README.md b/spanner/spannertest/README.md
index 9ccb19f..edb9656 100644
--- a/spanner/spannertest/README.md
+++ b/spanner/spannertest/README.md
@@ -15,17 +15,16 @@
Here's a list of features that are missing or incomplete. It is roughly ordered
by ascending esotericism:
-- SELECT GROUP BY
-- SELECT HAVING
-- arithmetic expressions (operators, parens)
-- transaction simulation
+- expression functions
+- more aggregation functions
- INSERT/UPDATE DML statements
- case insensitivity
- alternate literal types (esp. strings)
- STRUCT types
-- expression functions
-- expression type casting, coercion
+- SELECT HAVING
- joins
+- transaction simulation
+- expression type casting, coercion
- query offset
- SELECT aliases
- subselects
diff --git a/spanner/spannertest/db.go b/spanner/spannertest/db.go
index 63075fe..6fb3d8e 100644
--- a/spanner/spannertest/db.go
+++ b/spanner/spannertest/db.go
@@ -207,6 +207,9 @@
if _, ok := d.tables[stmt.Name]; ok {
return status.Newf(codes.AlreadyExists, "table %s already exists", stmt.Name)
}
+ if len(stmt.PrimaryKey) == 0 {
+ return status.Newf(codes.InvalidArgument, "table %s has no primary key", stmt.Name)
+ }
// TODO: check stmt.Interleave details.
diff --git a/spanner/spannertest/db_eval.go b/spanner/spannertest/db_eval.go
index be20e55..750eba3 100644
--- a/spanner/spannertest/db_eval.go
+++ b/spanner/spannertest/db_eval.go
@@ -356,6 +356,10 @@
return ec.evalBoolExpr(e)
case spansql.IsOp:
return ec.evalBoolExpr(e)
+ case aggSentinel:
+ // Aggregate value is always last in the row.
+ // TODO: This could be tightened up by including colInfo in evalContext.
+ return ec.row[len(ec.row)-1], nil
}
}
@@ -515,6 +519,8 @@
// There isn't necessarily something sensible here.
// Empirically, though, the real Spanner returns Int64.
return colInfo{Type: int64Type}, nil
+ case aggSentinel:
+ return colInfo{Type: e.Type}, nil
}
return colInfo{}, fmt.Errorf("can't deduce column type from expression [%s]", e.SQL())
}
diff --git a/spanner/spannertest/db_query.go b/spanner/spannertest/db_query.go
index fff96d1..160d7f4 100644
--- a/spanner/spannertest/db_query.go
+++ b/spanner/spannertest/db_query.go
@@ -33,8 +33,8 @@
The order of operations among those supported by Cloud Spanner is
FROM + JOIN + set ops [TODO: JOIN and set ops]
WHERE
- GROUP BY [TODO]
- aggregation [TODO]
+ GROUP BY
+ aggregation
HAVING [TODO]
SELECT
DISTINCT
@@ -54,6 +54,13 @@
Next() (row, error)
}
+// aggSentinel is a synthetic expression that refers to an aggregated value.
+// It is transient only; it is never stored and only used during evaluation.
+type aggSentinel struct {
+ spansql.Expr
+ Type spansql.Type
+}
+
// nullIter is a rowIter that returns one empty row only.
// This is used for queries without a table.
type nullIter struct {
@@ -319,23 +326,144 @@
}
}
- // Handle COUNT(*) specially.
- // TODO: Handle aggregation more generally.
- if len(sel.List) == 1 && isCountStar(sel.List[0]) {
- // Replace the `COUNT(*)` with `1`, then aggregate on the way out.
- sel.List[0] = spansql.IntegerLiteral(1)
- defer func() {
- if evalErr != nil {
- return
- }
- raw, err := toRawIter(ri)
+ // Apply GROUP BY.
+ // This only reorders rows to group rows together;
+ // aggregation happens next.
+ var rowGroups [][2]int // Sequence of half-open intervals of row numbers.
+ if len(sel.GroupBy) > 0 {
+ raw, err := toRawIter(ri)
+ if err != nil {
+ return nil, err
+ }
+ keys := make([][]interface{}, 0, len(raw.rows))
+ for _, row := range raw.rows {
+ // Evaluate sort key for this row.
+ // TODO: Support referring to expression names in the SELECT list;
+ // this may require passing through sel.List, or maybe mutating
+ // sel.GroupBy to copy the referenced values. This will also be
+ // required to support grouping by aliases.
+ ec.row = row
+ key, err := ec.evalExprList(sel.GroupBy)
if err != nil {
- ri, evalErr = nil, err
+ return nil, err
}
- count := int64(len(raw.rows))
- raw.rows = []row{{count}}
- ri, evalErr = raw, nil
- }()
+ keys = append(keys, key)
+ }
+
+ // Reorder rows base on the evaluated keys.
+ ers := externalRowSorter{rows: raw.rows, keys: keys}
+ sort.Sort(ers)
+ raw.rows = ers.rows
+
+ // Record groups as a sequence of row intervals.
+ // Each group is a run of the same keys.
+ start := 0
+ for i := 1; i < len(keys); i++ {
+ if compareValLists(keys[i-1], keys[i], nil) == 0 {
+ continue
+ }
+ rowGroups = append(rowGroups, [2]int{start, i})
+ start = i
+ }
+ if len(keys) > 0 {
+ rowGroups = append(rowGroups, [2]int{start, len(keys)})
+ }
+
+ ri = raw
+ }
+
+ // Handle aggregation.
+ // TODO: Support more than one aggregation function; does Spanner support that?
+ aggI := -1
+ for i, e := range sel.List {
+ // Supported aggregate funcs have exactly one arg.
+ f, ok := e.(spansql.Func)
+ if !ok || len(f.Args) != 1 {
+ continue
+ }
+ _, ok = aggregateFuncs[f.Name]
+ if !ok {
+ continue
+ }
+ if aggI > -1 {
+ return nil, fmt.Errorf("only one aggregate function is supported")
+ }
+ aggI = i
+ }
+ if aggI > -1 {
+ raw, err := toRawIter(ri)
+ if err != nil {
+ return nil, err
+ }
+ if len(rowGroups) == 0 {
+ // No grouping, so aggregation applies to the entire table (e.g. COUNT(*)).
+ rowGroups = [][2]int{{0, len(raw.rows)}}
+ }
+ fexpr := sel.List[aggI].(spansql.Func)
+ fn := aggregateFuncs[fexpr.Name]
+ starArg := fexpr.Args[0] == spansql.Star
+ if starArg && !fn.AcceptStar {
+ return nil, fmt.Errorf("aggregate function %s does not accept * as an argument", fexpr.Name)
+ }
+
+ // Prepare output.
+ rawOut := &rawIter{
+ // Same as input columns, but also the aggregate value.
+ // Add the colInfo for the aggregate at the end
+ // so we know the type.
+ // Make a copy for safety.
+ cols: append([]colInfo(nil), raw.cols...),
+ }
+
+ var aggType spansql.Type
+ for _, rg := range rowGroups {
+ // Compute aggregate value across this group.
+ var values []interface{}
+ for i := rg[0]; i < rg[1]; i++ {
+ ec.row = raw.rows[i]
+ if starArg {
+ // A non-NULL placeholder is sufficient for aggregation.
+ values = append(values, 1)
+ } else {
+ x, err := ec.evalExpr(fexpr.Args[0])
+ if err != nil {
+ return nil, err
+ }
+ values = append(values, x)
+ }
+ }
+ x, typ, err := fn.Eval(values)
+ if err != nil {
+ return nil, err
+ }
+ aggType = typ
+ // Output for the row group is the first row of the group (arbitrary,
+ // but it should be representative), and the aggregate value.
+ // TODO: Should this exclude the aggregated expressions so they can't be selected?
+ repRow := raw.rows[rg[0]]
+ var outRow row
+ for i := range repRow {
+ outRow = append(outRow, repRow.copyDataElem(i))
+ }
+ outRow = append(outRow, x)
+ rawOut.rows = append(rawOut.rows, outRow)
+ }
+
+ if aggType == (spansql.Type{}) {
+ // Fallback; there might not be any groups.
+ // TODO: Should this be in aggregateFunc?
+ aggType = int64Type
+ }
+ rawOut.cols = append(raw.cols, colInfo{
+ // TODO: Generate more interesting colInfo?
+ Name: fexpr.SQL(),
+ Type: aggType,
+ })
+
+ ri = rawOut
+ sel.List[aggI] = aggSentinel{ // Mutate query so evalExpr in selIter picks out the new value.
+ Type: aggType,
+ }
}
// TODO: Support table sampling.
@@ -372,13 +500,18 @@
return ri, nil
}
-func isCountStar(e spansql.Expr) bool {
- f, ok := e.(spansql.Func)
- if !ok {
- return false
- }
- if f.Name != "COUNT" || len(f.Args) != 1 {
- return false
- }
- return f.Args[0] == spansql.Star
+// externalRowSorter implements sort.Interface for a slice of rows
+// with an external sort key.
+type externalRowSorter struct {
+ rows []row
+ keys [][]interface{}
+}
+
+func (ers externalRowSorter) Len() int { return len(ers.rows) }
+func (ers externalRowSorter) Less(i, j int) bool {
+ return compareValLists(ers.keys[i], ers.keys[j], nil) < 0
+}
+func (ers externalRowSorter) Swap(i, j int) {
+ ers.rows[i], ers.rows[j] = ers.rows[j], ers.rows[i]
+ ers.keys[i], ers.keys[j] = ers.keys[j], ers.keys[i]
}
diff --git a/spanner/spannertest/db_test.go b/spanner/spannertest/db_test.go
index 439f039..09eb4cc 100644
--- a/spanner/spannertest/db_test.go
+++ b/spanner/spannertest/db_test.go
@@ -300,6 +300,41 @@
t.Fatalf("Committing changes: %v", err)
}
+ // Prepare the sample tables from the Cloud Spanner docs.
+ // https://cloud.google.com/spanner/docs/query-syntax#appendix-a-examples-with-sample-data
+ for _, ct := range []*spansql.CreateTable{
+ // TODO: Roster, TeamMascot when we implement JOINs.
+ {
+ Name: "PlayerStats",
+ Columns: []spansql.ColumnDef{
+ {Name: "LastName", Type: spansql.Type{Base: spansql.String}},
+ {Name: "OpponentID", Type: spansql.Type{Base: spansql.Int64}},
+ {Name: "PointsScored", Type: spansql.Type{Base: spansql.Int64}},
+ },
+ PrimaryKey: []spansql.KeyPart{{Column: "LastName"}, {Column: "OpponentID"}}, // TODO: is this right?
+ },
+ } {
+ st := db.ApplyDDL(ct)
+ if st.Code() != codes.OK {
+ t.Fatalf("Creating table: %v", st.Err())
+ }
+ }
+ tx = db.NewTransaction()
+ tx.Start()
+ err = db.Insert(tx, "PlayerStats", []string{"LastName", "OpponentID", "PointsScored"}, []*structpb.ListValue{
+ listV(stringV("Adams"), stringV("51"), stringV("3")),
+ listV(stringV("Buchanan"), stringV("77"), stringV("0")),
+ listV(stringV("Coolidge"), stringV("77"), stringV("1")),
+ listV(stringV("Adams"), stringV("52"), stringV("4")),
+ listV(stringV("Buchanan"), stringV("50"), stringV("13")),
+ })
+ if err != nil {
+ t.Fatalf("Inserting data: %v", err)
+ }
+ if _, err := tx.Commit(); err != nil {
+ t.Fatalf("Commiting changes: %v", err)
+ }
+
// Do some complex queries.
tests := []struct {
q string
@@ -431,6 +466,18 @@
{true, false},
},
},
+ // From https://cloud.google.com/spanner/docs/query-syntax#group-by-clause_1:
+ {
+ // TODO: Ordering matters? Our implementation sorts by the GROUP BY key,
+ // but nothing documented seems to guarantee that.
+ `SELECT LastName, SUM(PointsScored) FROM PlayerStats GROUP BY LastName`,
+ nil,
+ [][]interface{}{
+ {"Adams", int64(7)},
+ {"Buchanan", int64(13)},
+ {"Coolidge", int64(1)},
+ },
+ },
}
for _, test := range tests {
q, err := spansql.ParseQuery(test.q)
diff --git a/spanner/spannertest/funcs.go b/spanner/spannertest/funcs.go
new file mode 100644
index 0000000..c44ccb9
--- /dev/null
+++ b/spanner/spannertest/funcs.go
@@ -0,0 +1,86 @@
+/*
+Copyright 2020 Google LLC
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+package spannertest
+
+import (
+ "fmt"
+
+ "cloud.google.com/go/spanner/spansql"
+)
+
+// This file contains implementations of query functions.
+
+type aggregateFunc struct {
+ // Whether the function can take a * arg (only COUNT).
+ AcceptStar bool
+
+ // Every aggregate func takes one expression.
+ Eval func(values []interface{}) (interface{}, spansql.Type, error)
+
+ // TODO: Handle qualifiers such as DISTINCT.
+}
+
+// TODO: more aggregate funcs.
+var aggregateFuncs = map[string]aggregateFunc{
+ "COUNT": {
+ AcceptStar: true,
+ Eval: func(values []interface{}) (interface{}, spansql.Type, error) {
+ // Count the number of non-NULL values.
+ // COUNT(*) receives a list of non-NULL placeholders rather than values,
+ // so every value will be non-NULL.
+ var n int64
+ for _, v := range values {
+ if v != nil {
+ n++
+ }
+ }
+ return n, int64Type, nil
+ },
+ },
+ "SUM": {
+ Eval: func(values []interface{}) (interface{}, spansql.Type, error) {
+ // Ignoring NULL values, there may only be one type, either INT64 or FLOAT64.
+ var seenInt, seenFloat bool
+ var sumInt int64
+ var sumFloat float64
+ for _, v := range values {
+ switch v := v.(type) {
+ default:
+ return nil, spansql.Type{}, fmt.Errorf("SUM only supports arguments of INT64 or FLOAT64 type, not %T", v)
+ case nil:
+ continue
+ case int64:
+ seenInt = true
+ sumInt += v
+ case float64:
+ seenFloat = true
+ sumFloat += v
+ }
+ }
+ if !seenInt && !seenFloat {
+ // "Returns NULL if the input contains only NULLs".
+ return nil, int64Type, nil
+ } else if seenInt && seenFloat {
+ // This shouldn't happen.
+ return nil, spansql.Type{}, fmt.Errorf("internal error: SUM saw mix of INT64 and FLOAT64")
+ } else if seenInt {
+ return sumInt, int64Type, nil
+ }
+ return sumFloat, float64Type, nil
+ },
+ },
+}