spanner/spannertest, spanner/spansql: support functions, COUNT(*)
The implementation in spannertest only supports COUNT(*) in a SELECT;
other functions and aggregations are not yet supported. In particular,
other aggregations will require some more restructuring of evalSelect.
Updates #1181.
Change-Id: Ifc859647cda57895f6940219379a91091654b4e9
Reviewed-on: https://code-review.googlesource.com/c/gocloud/+/43590
Reviewed-by: kokoro <noreply+kokoro@google.com>
Reviewed-by: Knut Olav Løite <koloite@gmail.com>
diff --git a/spanner/spannertest/db_eval.go b/spanner/spannertest/db_eval.go
index f9f6d43..62af253 100644
--- a/spanner/spannertest/db_eval.go
+++ b/spanner/spannertest/db_eval.go
@@ -34,7 +34,7 @@
params queryParams
}
-func (d *database) evalSelect(sel spansql.Select, params queryParams, aux []spansql.Expr) (*resultIter, error) {
+func (d *database) evalSelect(sel spansql.Select, params queryParams, aux []spansql.Expr) (ri *resultIter, evalErr error) {
// TODO: weave this in below.
if len(sel.From) == 0 && sel.Where == nil {
// Simple expressions.
@@ -70,6 +70,24 @@
return nil, err
}
+ ri = &resultIter{}
+
+ // Handle COUNT(*) specially.
+ // TODO: Handle aggregation more generally.
+ if len(sel.List) == 1 && isCountStar(sel.List[0]) {
+ // Replace the `COUNT(*)` with `1`, then aggregate on the way out.
+ sel.List[0] = spansql.IntegerLiteral(1)
+ defer func() {
+ if evalErr != nil {
+ return
+ }
+ count := int64(len(ri.rows))
+ ri.rows = []resultRow{
+ {data: []interface{}{count}},
+ }
+ }()
+ }
+
// TODO: Support table sampling.
t.mu.Lock()
@@ -79,7 +97,6 @@
params: params,
}
- ri := &resultIter{}
for _, e := range sel.List {
ci, err := ec.colInfo(e)
if err != nil {
@@ -420,3 +437,14 @@
}
return match
}
+
+func isCountStar(e spansql.Expr) bool {
+ f, ok := e.(spansql.Func)
+ if !ok {
+ return false
+ }
+ if f.Name != "COUNT" || len(f.Args) != 1 {
+ return false
+ }
+ return f.Args[0] == spansql.Star
+}
diff --git a/spanner/spannertest/db_test.go b/spanner/spannertest/db_test.go
index 1d6a78c..8579cc0 100644
--- a/spanner/spannertest/db_test.go
+++ b/spanner/spannertest/db_test.go
@@ -229,6 +229,13 @@
{"Sam", 1.75},
},
},
+ {
+ `SELECT COUNT(*) FROM Staff WHERE Name < "T"`,
+ nil,
+ [][]interface{}{
+ {int64(4)},
+ },
+ },
}
for _, test := range tests {
q, err := spansql.ParseQuery(test.q)
diff --git a/spanner/spansql/parser.go b/spanner/spansql/parser.go
index b4c3fbe..83de139 100644
--- a/spanner/spansql/parser.go
+++ b/spanner/spansql/parser.go
@@ -410,7 +410,7 @@
// TODO: backtick (`) for quoted identifiers.
// TODO: array, struct, date, timestamp literals
switch p.s[0] {
- case ',', ';', '(', ')':
+ case ',', ';', '(', ')', '*':
// Single character symbol.
p.cur.value, p.s = p.s[:1], p.s[1:]
return
@@ -1178,6 +1178,39 @@
return nil, p.errorf("got %q, want literal or parameter", tok.value)
}
+func (p *parser) parseExprList() ([]Expr, error) {
+ if err := p.expect("("); err != nil {
+ return nil, err
+ }
+ var list []Expr
+ for {
+ if err := p.expect(")"); err == nil {
+ break
+ }
+ p.back()
+
+ e, err := p.parseExpr()
+ if err != nil {
+ return nil, err
+ }
+ list = append(list, e)
+
+ // ")" or "," should be next.
+ tok := p.next()
+ if tok.err != nil {
+ return nil, err
+ }
+ if tok.value == ")" {
+ break
+ } else if tok.value == "," {
+ continue
+ } else {
+ return nil, p.errorf(`got %q, want ")" or ","`, tok.value)
+ }
+ }
+ return list, nil
+}
+
/*
Expressions
@@ -1414,7 +1447,25 @@
return Paren{Expr: e}, nil
}
- return p.parseLit()
+ lit, err := p.parseLit()
+ if err != nil {
+ return nil, err
+ }
+
+ // If the literal was an identifier, and there's an open paren next,
+ // this is a function invocation.
+ if id, ok := lit.(ID); ok && p.sniff("(") {
+ list, err := p.parseExprList()
+ if err != nil {
+ return nil, err
+ }
+ return Func{
+ Name: string(id),
+ Args: list,
+ }, nil
+ }
+
+ return lit, nil
}
func (p *parser) parseLit() (Expr, error) {
@@ -1432,7 +1483,7 @@
return StringLiteral(tok.string), nil
}
- // Handle some reserved keywords that become specific values.
+ // Handle some reserved keywords and special tokens that become specific values.
// TODO: Handle the other 92 keywords.
switch tok.value {
case "TRUE":
@@ -1441,6 +1492,8 @@
return False, nil
case "NULL":
return Null, nil
+ case "*":
+ return Star, nil
}
// TODO: more types of literals (array, struct, date, timestamp).
diff --git a/spanner/spansql/parser_test.go b/spanner/spansql/parser_test.go
index 6c2e38b..42a95e9 100644
--- a/spanner/spansql/parser_test.go
+++ b/spanner/spansql/parser_test.go
@@ -55,6 +55,19 @@
Limit: Param("limit"),
},
},
+ {`SELECT COUNT(*) FROM Packages`,
+ Query{
+ Select: Select{
+ List: []Expr{
+ Func{
+ Name: "COUNT",
+ Args: []Expr{Star},
+ },
+ },
+ From: []SelectFrom{{Table: "Packages"}},
+ },
+ },
+ },
}
for _, test := range tests {
got, err := ParseQuery(test.in)
diff --git a/spanner/spansql/sql.go b/spanner/spansql/sql.go
index d2298e0..5a3b184 100644
--- a/spanner/spansql/sql.go
+++ b/spanner/spansql/sql.go
@@ -243,6 +243,18 @@
return str
}
+func (f Func) SQL() string {
+ str := f.Name + "("
+ for i, e := range f.Args {
+ if i > 0 {
+ str += ", "
+ }
+ str += e.SQL()
+ }
+ str += ")"
+ return str
+}
+
func (p Paren) SQL() string { return "(" + p.Expr.SQL() + ")" }
func (id ID) SQL() string { return string(id) }
@@ -256,6 +268,7 @@
}
func (n NullLiteral) SQL() string { return "NULL" }
+func (StarExpr) SQL() string { return "*" }
func (il IntegerLiteral) SQL() string { return strconv.Itoa(int(il)) }
func (fl FloatLiteral) SQL() string { return strconv.FormatFloat(float64(fl), 'g', -1, 64) }
diff --git a/spanner/spansql/types.go b/spanner/spansql/types.go
index 957ae8f..9b5cefe 100644
--- a/spanner/spansql/types.go
+++ b/spanner/spansql/types.go
@@ -250,6 +250,17 @@
SQL() string
}
+// Func represents a function call.
+type Func struct {
+ Name string
+ Args []Expr
+
+ // TODO: various functions permit as-expressions, which might warrant different types in here.
+}
+
+func (Func) isBoolExpr() {} // possibly bool
+func (Func) isExpr() {}
+
// Paren represents a parenthesised expression.
type Paren struct {
Expr Expr
@@ -308,6 +319,13 @@
func (StringLiteral) isExpr() {}
+type StarExpr int
+
+// Star represents a "*" in an expression.
+const Star = StarExpr(0)
+
+func (StarExpr) isExpr() {}
+
// DDL
// https://cloud.google.com/spanner/docs/data-definition-language#ddl_syntax