spanner/spannertest: implement ARRAY_AGG function
This requires passing the argument type to aggregation functions, which
also simplifies the implementation of SUM.
Change-Id: Ia34b80d9a19334903c2676dd19532b8046c9f385
Reviewed-on: https://code-review.googlesource.com/c/gocloud/+/52811
Reviewed-by: kokoro <noreply+kokoro@google.com>
Reviewed-by: Knut Olav Løite <koloite@gmail.com>
diff --git a/spanner/spannertest/db.go b/spanner/spannertest/db.go
index 6e3884b..c3ac1cf 100644
--- a/spanner/spannertest/db.go
+++ b/spanner/spannertest/db.go
@@ -157,7 +157,7 @@
BYTES []byte
DATE string (RFC 3339 date; "YYYY-MM-DD")
TIMESTAMP string (RFC 3339 timestamp with zone; "YYYY-MM-DDTHH:MM:SSZ")
- ARRAY<T> []T
+ ARRAY<T> []interface{}
STRUCT TODO
*/
type row []interface{}
diff --git a/spanner/spannertest/db_query.go b/spanner/spannertest/db_query.go
index 1454921..2a08d77 100644
--- a/spanner/spannertest/db_query.go
+++ b/spanner/spannertest/db_query.go
@@ -412,6 +412,14 @@
if starArg && !fn.AcceptStar {
return nil, fmt.Errorf("aggregate function %s does not accept * as an argument", fexpr.Name)
}
+ var argType spansql.Type
+ if !starArg {
+ ci, err := ec.colInfo(fexpr.Args[0])
+ if err != nil {
+ return nil, err
+ }
+ argType = ci.Type
+ }
// Prepare output.
rawOut := &rawIter{
@@ -439,7 +447,7 @@
values = append(values, x)
}
}
- x, typ, err := fn.Eval(values)
+ x, typ, err := fn.Eval(values, argType)
if err != nil {
return nil, err
}
diff --git a/spanner/spannertest/db_test.go b/spanner/spannertest/db_test.go
index f6304aa..2f11f24 100644
--- a/spanner/spannertest/db_test.go
+++ b/spanner/spannertest/db_test.go
@@ -488,6 +488,14 @@
{int64(1), int64(25)}, // Jack(ID=1, Tenure=10), Sam(ID=3, Tenure=9), George(ID=5, Tenure=6)
},
},
+ {
+ `SELECT ARRAY_AGG(Cool) FROM Staff ORDER BY Name`,
+ nil,
+ [][]interface{}{
+ // Daniel, George (NULL), Jack (NULL), Sam, Teal'c
+ {[]interface{}{false, nil, nil, false, true}},
+ },
+ },
}
for _, test := range tests {
q, err := spansql.ParseQuery(test.q)
diff --git a/spanner/spannertest/funcs.go b/spanner/spannertest/funcs.go
index c44ccb9..66d8ed6 100644
--- a/spanner/spannertest/funcs.go
+++ b/spanner/spannertest/funcs.go
@@ -29,16 +29,30 @@
AcceptStar bool
// Every aggregate func takes one expression.
- Eval func(values []interface{}) (interface{}, spansql.Type, error)
+ Eval func(values []interface{}, typ spansql.Type) (interface{}, spansql.Type, error)
// TODO: Handle qualifiers such as DISTINCT.
}
// TODO: more aggregate funcs.
var aggregateFuncs = map[string]aggregateFunc{
+ "ARRAY_AGG": {
+ // https://cloud.google.com/spanner/docs/aggregate_functions#array_agg
+ Eval: func(values []interface{}, typ spansql.Type) (interface{}, spansql.Type, error) {
+ if typ.Array {
+ return nil, spansql.Type{}, fmt.Errorf("ARRAY_AGG unsupported on values of type %v", typ.SQL())
+ }
+ typ.Array = true // use as return type
+ if len(values) == 0 {
+ // "If there are zero input rows, this function returns NULL."
+ return nil, typ, nil
+ }
+ return values, typ, nil
+ },
+ },
"COUNT": {
AcceptStar: true,
- Eval: func(values []interface{}) (interface{}, spansql.Type, error) {
+ Eval: func(values []interface{}, typ spansql.Type) (interface{}, spansql.Type, error) {
// Count the number of non-NULL values.
// COUNT(*) receives a list of non-NULL placeholders rather than values,
// so every value will be non-NULL.
@@ -52,35 +66,40 @@
},
},
"SUM": {
- Eval: func(values []interface{}) (interface{}, spansql.Type, error) {
- // Ignoring NULL values, there may only be one type, either INT64 or FLOAT64.
- var seenInt, seenFloat bool
- var sumInt int64
- var sumFloat float64
- for _, v := range values {
- switch v := v.(type) {
- default:
- return nil, spansql.Type{}, fmt.Errorf("SUM only supports arguments of INT64 or FLOAT64 type, not %T", v)
- case nil:
- continue
- case int64:
- seenInt = true
- sumInt += v
- case float64:
- seenFloat = true
- sumFloat += v
+ Eval: func(values []interface{}, typ spansql.Type) (interface{}, spansql.Type, error) {
+ if typ.Array || !(typ.Base == spansql.Int64 || typ.Base == spansql.Float64) {
+ return nil, spansql.Type{}, fmt.Errorf("SUM only supports arguments of INT64 or FLOAT64 type, not %s", typ.SQL())
+ }
+ if typ.Base == spansql.Int64 {
+ var seen bool
+ var sum int64
+ for _, v := range values {
+ if v == nil {
+ continue
+ }
+ seen = true
+ sum += v.(int64)
}
+ if !seen {
+ // "Returns NULL if the input contains only NULLs".
+ return nil, typ, nil
+ }
+ return sum, typ, nil
}
- if !seenInt && !seenFloat {
+ var seen bool
+ var sum float64
+ for _, v := range values {
+ if v == nil {
+ continue
+ }
+ seen = true
+ sum += v.(float64)
+ }
+ if !seen {
// "Returns NULL if the input contains only NULLs".
- return nil, int64Type, nil
- } else if seenInt && seenFloat {
- // This shouldn't happen.
- return nil, spansql.Type{}, fmt.Errorf("internal error: SUM saw mix of INT64 and FLOAT64")
- } else if seenInt {
- return sumInt, int64Type, nil
+ return nil, typ, nil
}
- return sumFloat, float64Type, nil
+ return sum, typ, nil
},
},
}
diff --git a/spanner/spansql/keywords.go b/spanner/spansql/keywords.go
index c243b7a..e7a8941 100644
--- a/spanner/spansql/keywords.go
+++ b/spanner/spansql/keywords.go
@@ -129,9 +129,10 @@
// https://cloud.google.com/spanner/docs/functions-and-operators
var funcs = map[string]bool{
// Aggregate functions.
- "BIT_XOR": true,
- "COUNT": true,
- "SUM": true,
+ "ARRAY_AGG": true,
+ "BIT_XOR": true,
+ "COUNT": true,
+ "SUM": true,
// Mathematical functions.
"ABS": true,