spanner/spannertest: implement ARRAY_AGG function

This requires passing the argument type to aggregation functions, which
also simplifies the implementation of SUM.

Change-Id: Ia34b80d9a19334903c2676dd19532b8046c9f385
Reviewed-on: https://code-review.googlesource.com/c/gocloud/+/52811
Reviewed-by: kokoro <noreply+kokoro@google.com>
Reviewed-by: Knut Olav Løite <koloite@gmail.com>
diff --git a/spanner/spannertest/db.go b/spanner/spannertest/db.go
index 6e3884b..c3ac1cf 100644
--- a/spanner/spannertest/db.go
+++ b/spanner/spannertest/db.go
@@ -157,7 +157,7 @@
 	BYTES		[]byte
 	DATE		string (RFC 3339 date; "YYYY-MM-DD")
 	TIMESTAMP	string (RFC 3339 timestamp with zone; "YYYY-MM-DDTHH:MM:SSZ")
-	ARRAY<T>	[]T
+	ARRAY<T>	[]interface{}
 	STRUCT		TODO
 */
 type row []interface{}
diff --git a/spanner/spannertest/db_query.go b/spanner/spannertest/db_query.go
index 1454921..2a08d77 100644
--- a/spanner/spannertest/db_query.go
+++ b/spanner/spannertest/db_query.go
@@ -412,6 +412,14 @@
 		if starArg && !fn.AcceptStar {
 			return nil, fmt.Errorf("aggregate function %s does not accept * as an argument", fexpr.Name)
 		}
+		var argType spansql.Type
+		if !starArg {
+			ci, err := ec.colInfo(fexpr.Args[0])
+			if err != nil {
+				return nil, err
+			}
+			argType = ci.Type
+		}
 
 		// Prepare output.
 		rawOut := &rawIter{
@@ -439,7 +447,7 @@
 					values = append(values, x)
 				}
 			}
-			x, typ, err := fn.Eval(values)
+			x, typ, err := fn.Eval(values, argType)
 			if err != nil {
 				return nil, err
 			}
diff --git a/spanner/spannertest/db_test.go b/spanner/spannertest/db_test.go
index f6304aa..2f11f24 100644
--- a/spanner/spannertest/db_test.go
+++ b/spanner/spannertest/db_test.go
@@ -488,6 +488,14 @@
 				{int64(1), int64(25)}, // Jack(ID=1, Tenure=10), Sam(ID=3, Tenure=9), George(ID=5, Tenure=6)
 			},
 		},
+		{
+			`SELECT ARRAY_AGG(Cool) FROM Staff ORDER BY Name`,
+			nil,
+			[][]interface{}{
+				// Daniel, George (NULL), Jack (NULL), Sam, Teal'c
+				{[]interface{}{false, nil, nil, false, true}},
+			},
+		},
 	}
 	for _, test := range tests {
 		q, err := spansql.ParseQuery(test.q)
diff --git a/spanner/spannertest/funcs.go b/spanner/spannertest/funcs.go
index c44ccb9..66d8ed6 100644
--- a/spanner/spannertest/funcs.go
+++ b/spanner/spannertest/funcs.go
@@ -29,16 +29,30 @@
 	AcceptStar bool
 
 	// Every aggregate func takes one expression.
-	Eval func(values []interface{}) (interface{}, spansql.Type, error)
+	Eval func(values []interface{}, typ spansql.Type) (interface{}, spansql.Type, error)
 
 	// TODO: Handle qualifiers such as DISTINCT.
 }
 
 // TODO: more aggregate funcs.
 var aggregateFuncs = map[string]aggregateFunc{
+	"ARRAY_AGG": {
+		// https://cloud.google.com/spanner/docs/aggregate_functions#array_agg
+		Eval: func(values []interface{}, typ spansql.Type) (interface{}, spansql.Type, error) {
+			if typ.Array {
+				return nil, spansql.Type{}, fmt.Errorf("ARRAY_AGG unsupported on values of type %v", typ.SQL())
+			}
+			typ.Array = true // use as return type
+			if len(values) == 0 {
+				// "If there are zero input rows, this function returns NULL."
+				return nil, typ, nil
+			}
+			return values, typ, nil
+		},
+	},
 	"COUNT": {
 		AcceptStar: true,
-		Eval: func(values []interface{}) (interface{}, spansql.Type, error) {
+		Eval: func(values []interface{}, typ spansql.Type) (interface{}, spansql.Type, error) {
 			// Count the number of non-NULL values.
 			// COUNT(*) receives a list of non-NULL placeholders rather than values,
 			// so every value will be non-NULL.
@@ -52,35 +66,40 @@
 		},
 	},
 	"SUM": {
-		Eval: func(values []interface{}) (interface{}, spansql.Type, error) {
-			// Ignoring NULL values, there may only be one type, either INT64 or FLOAT64.
-			var seenInt, seenFloat bool
-			var sumInt int64
-			var sumFloat float64
-			for _, v := range values {
-				switch v := v.(type) {
-				default:
-					return nil, spansql.Type{}, fmt.Errorf("SUM only supports arguments of INT64 or FLOAT64 type, not %T", v)
-				case nil:
-					continue
-				case int64:
-					seenInt = true
-					sumInt += v
-				case float64:
-					seenFloat = true
-					sumFloat += v
+		Eval: func(values []interface{}, typ spansql.Type) (interface{}, spansql.Type, error) {
+			if typ.Array || !(typ.Base == spansql.Int64 || typ.Base == spansql.Float64) {
+				return nil, spansql.Type{}, fmt.Errorf("SUM only supports arguments of INT64 or FLOAT64 type, not %s", typ.SQL())
+			}
+			if typ.Base == spansql.Int64 {
+				var seen bool
+				var sum int64
+				for _, v := range values {
+					if v == nil {
+						continue
+					}
+					seen = true
+					sum += v.(int64)
 				}
+				if !seen {
+					// "Returns NULL if the input contains only NULLs".
+					return nil, typ, nil
+				}
+				return sum, typ, nil
 			}
-			if !seenInt && !seenFloat {
+			var seen bool
+			var sum float64
+			for _, v := range values {
+				if v == nil {
+					continue
+				}
+				seen = true
+				sum += v.(float64)
+			}
+			if !seen {
 				// "Returns NULL if the input contains only NULLs".
-				return nil, int64Type, nil
-			} else if seenInt && seenFloat {
-				// This shouldn't happen.
-				return nil, spansql.Type{}, fmt.Errorf("internal error: SUM saw mix of INT64 and FLOAT64")
-			} else if seenInt {
-				return sumInt, int64Type, nil
+				return nil, typ, nil
 			}
-			return sumFloat, float64Type, nil
+			return sum, typ, nil
 		},
 	},
 }
diff --git a/spanner/spansql/keywords.go b/spanner/spansql/keywords.go
index c243b7a..e7a8941 100644
--- a/spanner/spansql/keywords.go
+++ b/spanner/spansql/keywords.go
@@ -129,9 +129,10 @@
 // https://cloud.google.com/spanner/docs/functions-and-operators
 var funcs = map[string]bool{
 	// Aggregate functions.
-	"BIT_XOR": true,
-	"COUNT":   true,
-	"SUM":     true,
+	"ARRAY_AGG": true,
+	"BIT_XOR":   true,
+	"COUNT":     true,
+	"SUM":       true,
 
 	// Mathematical functions.
 	"ABS": true,