spanner/spannertest, spanner/spansql: support functions, COUNT(*)

The implementation in spannertest only supports COUNT(*) in a SELECT;
other functions and aggregations are not yet supported. In particular,
other aggregations will require some more restructuring of evalSelect.

Updates #1181.

Change-Id: Ifc859647cda57895f6940219379a91091654b4e9
Reviewed-on: https://code-review.googlesource.com/c/gocloud/+/43590
Reviewed-by: kokoro <noreply+kokoro@google.com>
Reviewed-by: Knut Olav Løite <koloite@gmail.com>
diff --git a/spanner/spannertest/db_eval.go b/spanner/spannertest/db_eval.go
index f9f6d43..62af253 100644
--- a/spanner/spannertest/db_eval.go
+++ b/spanner/spannertest/db_eval.go
@@ -34,7 +34,7 @@
 	params queryParams
 }
 
-func (d *database) evalSelect(sel spansql.Select, params queryParams, aux []spansql.Expr) (*resultIter, error) {
+func (d *database) evalSelect(sel spansql.Select, params queryParams, aux []spansql.Expr) (ri *resultIter, evalErr error) {
 	// TODO: weave this in below.
 	if len(sel.From) == 0 && sel.Where == nil {
 		// Simple expressions.
@@ -70,6 +70,24 @@
 		return nil, err
 	}
 
+	ri = &resultIter{}
+
+	// Handle COUNT(*) specially.
+	// TODO: Handle aggregation more generally.
+	if len(sel.List) == 1 && isCountStar(sel.List[0]) {
+		// Replace the `COUNT(*)` with `1`, then aggregate on the way out.
+		sel.List[0] = spansql.IntegerLiteral(1)
+		defer func() {
+			if evalErr != nil {
+				return
+			}
+			count := int64(len(ri.rows))
+			ri.rows = []resultRow{
+				{data: []interface{}{count}},
+			}
+		}()
+	}
+
 	// TODO: Support table sampling.
 
 	t.mu.Lock()
@@ -79,7 +97,6 @@
 		params: params,
 	}
 
-	ri := &resultIter{}
 	for _, e := range sel.List {
 		ci, err := ec.colInfo(e)
 		if err != nil {
@@ -420,3 +437,14 @@
 	}
 	return match
 }
+
+func isCountStar(e spansql.Expr) bool {
+	f, ok := e.(spansql.Func)
+	if !ok {
+		return false
+	}
+	if f.Name != "COUNT" || len(f.Args) != 1 {
+		return false
+	}
+	return f.Args[0] == spansql.Star
+}
diff --git a/spanner/spannertest/db_test.go b/spanner/spannertest/db_test.go
index 1d6a78c..8579cc0 100644
--- a/spanner/spannertest/db_test.go
+++ b/spanner/spannertest/db_test.go
@@ -229,6 +229,13 @@
 				{"Sam", 1.75},
 			},
 		},
+		{
+			`SELECT COUNT(*) FROM Staff WHERE Name < "T"`,
+			nil,
+			[][]interface{}{
+				{int64(4)},
+			},
+		},
 	}
 	for _, test := range tests {
 		q, err := spansql.ParseQuery(test.q)
diff --git a/spanner/spansql/parser.go b/spanner/spansql/parser.go
index b4c3fbe..83de139 100644
--- a/spanner/spansql/parser.go
+++ b/spanner/spansql/parser.go
@@ -410,7 +410,7 @@
 	// TODO: backtick (`) for quoted identifiers.
 	// TODO: array, struct, date, timestamp literals
 	switch p.s[0] {
-	case ',', ';', '(', ')':
+	case ',', ';', '(', ')', '*':
 		// Single character symbol.
 		p.cur.value, p.s = p.s[:1], p.s[1:]
 		return
@@ -1178,6 +1178,39 @@
 	return nil, p.errorf("got %q, want literal or parameter", tok.value)
 }
 
+func (p *parser) parseExprList() ([]Expr, error) {
+	if err := p.expect("("); err != nil {
+		return nil, err
+	}
+	var list []Expr
+	for {
+		if err := p.expect(")"); err == nil {
+			break
+		}
+		p.back()
+
+		e, err := p.parseExpr()
+		if err != nil {
+			return nil, err
+		}
+		list = append(list, e)
+
+		// ")" or "," should be next.
+		tok := p.next()
+		if tok.err != nil {
+			return nil, err
+		}
+		if tok.value == ")" {
+			break
+		} else if tok.value == "," {
+			continue
+		} else {
+			return nil, p.errorf(`got %q, want ")" or ","`, tok.value)
+		}
+	}
+	return list, nil
+}
+
 /*
 Expressions
 
@@ -1414,7 +1447,25 @@
 		return Paren{Expr: e}, nil
 	}
 
-	return p.parseLit()
+	lit, err := p.parseLit()
+	if err != nil {
+		return nil, err
+	}
+
+	// If the literal was an identifier, and there's an open paren next,
+	// this is a function invocation.
+	if id, ok := lit.(ID); ok && p.sniff("(") {
+		list, err := p.parseExprList()
+		if err != nil {
+			return nil, err
+		}
+		return Func{
+			Name: string(id),
+			Args: list,
+		}, nil
+	}
+
+	return lit, nil
 }
 
 func (p *parser) parseLit() (Expr, error) {
@@ -1432,7 +1483,7 @@
 		return StringLiteral(tok.string), nil
 	}
 
-	// Handle some reserved keywords that become specific values.
+	// Handle some reserved keywords and special tokens that become specific values.
 	// TODO: Handle the other 92 keywords.
 	switch tok.value {
 	case "TRUE":
@@ -1441,6 +1492,8 @@
 		return False, nil
 	case "NULL":
 		return Null, nil
+	case "*":
+		return Star, nil
 	}
 
 	// TODO: more types of literals (array, struct, date, timestamp).
diff --git a/spanner/spansql/parser_test.go b/spanner/spansql/parser_test.go
index 6c2e38b..42a95e9 100644
--- a/spanner/spansql/parser_test.go
+++ b/spanner/spansql/parser_test.go
@@ -55,6 +55,19 @@
 				Limit: Param("limit"),
 			},
 		},
+		{`SELECT COUNT(*) FROM Packages`,
+			Query{
+				Select: Select{
+					List: []Expr{
+						Func{
+							Name: "COUNT",
+							Args: []Expr{Star},
+						},
+					},
+					From: []SelectFrom{{Table: "Packages"}},
+				},
+			},
+		},
 	}
 	for _, test := range tests {
 		got, err := ParseQuery(test.in)
diff --git a/spanner/spansql/sql.go b/spanner/spansql/sql.go
index d2298e0..5a3b184 100644
--- a/spanner/spansql/sql.go
+++ b/spanner/spansql/sql.go
@@ -243,6 +243,18 @@
 	return str
 }
 
+func (f Func) SQL() string {
+	str := f.Name + "("
+	for i, e := range f.Args {
+		if i > 0 {
+			str += ", "
+		}
+		str += e.SQL()
+	}
+	str += ")"
+	return str
+}
+
 func (p Paren) SQL() string { return "(" + p.Expr.SQL() + ")" }
 
 func (id ID) SQL() string   { return string(id) }
@@ -256,6 +268,7 @@
 }
 
 func (n NullLiteral) SQL() string { return "NULL" }
+func (StarExpr) SQL() string      { return "*" }
 
 func (il IntegerLiteral) SQL() string { return strconv.Itoa(int(il)) }
 func (fl FloatLiteral) SQL() string   { return strconv.FormatFloat(float64(fl), 'g', -1, 64) }
diff --git a/spanner/spansql/types.go b/spanner/spansql/types.go
index 957ae8f..9b5cefe 100644
--- a/spanner/spansql/types.go
+++ b/spanner/spansql/types.go
@@ -250,6 +250,17 @@
 	SQL() string
 }
 
+// Func represents a function call.
+type Func struct {
+	Name string
+	Args []Expr
+
+	// TODO: various functions permit as-expressions, which might warrant different types in here.
+}
+
+func (Func) isBoolExpr() {} // possibly bool
+func (Func) isExpr()     {}
+
 // Paren represents a parenthesised expression.
 type Paren struct {
 	Expr Expr
@@ -308,6 +319,13 @@
 
 func (StringLiteral) isExpr() {}
 
+type StarExpr int
+
+// Star represents a "*" in an expression.
+const Star = StarExpr(0)
+
+func (StarExpr) isExpr() {}
+
 // DDL
 // https://cloud.google.com/spanner/docs/data-definition-language#ddl_syntax