blob: ab828b20d55004265eaec0c149ffe5b231287b1e [file] [log] [blame]
/*
Copyright 2019 Google LLC
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package spannertest
// This file contains the part of the Spanner fake that evaluates expressions.
import (
"fmt"
"regexp"
"strconv"
"strings"
"cloud.google.com/go/spanner/spansql"
)
// evalContext represents the context for evaluating an expression.
type evalContext struct {
table *table // may be nil
row row // set if table is set, only during expr evaluation
params queryParams
}
func (d *database) evalSelect(sel spansql.Select, params queryParams, aux []spansql.Expr) (ri *resultIter, evalErr error) {
// TODO: weave this in below.
if len(sel.From) == 0 && sel.Where == nil {
// Simple expressions.
ri := &resultIter{}
ec := evalContext{
params: params,
}
var r row
for _, e := range sel.List {
ci, err := ec.colInfo(e)
if err != nil {
return nil, err
}
// TODO: set column names?
ri.Cols = append(ri.Cols, ci)
x, err := ec.evalExpr(e)
if err != nil {
return nil, err
}
r = append(r, x)
}
ri.rows = []resultRow{{data: r}}
return ri, nil
}
if len(sel.From) != 1 {
return nil, fmt.Errorf("selecting from more than one table not supported")
}
tableName := sel.From[0].Table
t, err := d.table(tableName)
if err != nil {
return nil, err
}
ri = &resultIter{}
// Handle COUNT(*) specially.
// TODO: Handle aggregation more generally.
if len(sel.List) == 1 && isCountStar(sel.List[0]) {
// Replace the `COUNT(*)` with `1`, then aggregate on the way out.
sel.List[0] = spansql.IntegerLiteral(1)
defer func() {
if evalErr != nil {
return
}
count := int64(len(ri.rows))
ri.rows = []resultRow{
{data: []interface{}{count}},
}
}()
}
// TODO: Support table sampling.
t.mu.Lock()
defer t.mu.Unlock()
ec := evalContext{
table: t,
params: params,
}
// Is this a `SELECT *` query?
selectStar := len(sel.List) == 1 && sel.List[0] == spansql.Star
if selectStar {
// Every column will appear in the output.
ri.Cols = append([]colInfo(nil), ec.table.cols...)
} else {
for _, e := range sel.List {
ci, err := ec.colInfo(e)
if err != nil {
return nil, err
}
// TODO: deal with ci.Name == ""?
ri.Cols = append(ri.Cols, ci)
}
}
for _, r := range t.rows {
ec.row = r
// See if we want this row.
if sel.Where != nil {
b, err := ec.evalBoolExpr(sel.Where)
if err != nil {
return nil, err
}
if !b {
continue
}
}
// Evaluate SELECT expression list on the row.
var res resultRow
if selectStar {
var out []interface{}
for i := range r {
out = append(out, r.copyDataElem(i))
}
res.data = out
} else {
out, err := ec.evalExprList(sel.List)
if err != nil {
return nil, err
}
res.data = out
}
a, err := ec.evalExprList(aux)
if err != nil {
return nil, err
}
res.aux = a
ri.rows = append(ri.rows, res)
}
return ri, nil
}
func (ec evalContext) evalExprList(list []spansql.Expr) ([]interface{}, error) {
var out []interface{}
for _, e := range list {
x, err := ec.evalExpr(e)
if err != nil {
return nil, err
}
out = append(out, x)
}
return out, nil
}
func (ec evalContext) evalBoolExpr(be spansql.BoolExpr) (bool, error) {
switch be := be.(type) {
default:
return false, fmt.Errorf("unhandled BoolExpr %T", be)
case spansql.BoolLiteral:
return bool(be), nil
case spansql.ID, spansql.Paren:
e, err := ec.evalExpr(be)
if err != nil {
return false, err
}
if e == nil { // NULL is a false boolean.
return false, nil
}
b, ok := e.(bool)
if !ok {
return false, fmt.Errorf("got %T, want bool", e)
}
return b, nil
case spansql.LogicalOp:
var lhs, rhs bool
var err error
if be.LHS != nil {
lhs, err = ec.evalBoolExpr(be.LHS)
if err != nil {
return false, err
}
}
rhs, err = ec.evalBoolExpr(be.RHS)
if err != nil {
return false, err
}
switch be.Op {
case spansql.And:
return lhs && rhs, nil
case spansql.Or:
return lhs || rhs, nil
case spansql.Not:
return !rhs, nil
default:
return false, fmt.Errorf("unhandled LogicalOp %d", be.Op)
}
case spansql.ComparisonOp:
var lhs, rhs interface{}
var err error
lhs, err = ec.evalExpr(be.LHS)
if err != nil {
return false, err
}
rhs, err = ec.evalExpr(be.RHS)
if err != nil {
return false, err
}
switch be.Op {
default:
return false, fmt.Errorf("TODO: ComparisonOp %d", be.Op)
case spansql.Lt:
return compareVals(lhs, rhs) < 0, nil
case spansql.Le:
return compareVals(lhs, rhs) <= 0, nil
case spansql.Gt:
return compareVals(lhs, rhs) > 0, nil
case spansql.Ge:
return compareVals(lhs, rhs) >= 0, nil
case spansql.Eq:
return compareVals(lhs, rhs) == 0, nil
case spansql.Ne:
return compareVals(lhs, rhs) != 0, nil
case spansql.Like, spansql.NotLike:
left, ok := lhs.(string)
if !ok {
// TODO: byte works here too?
return false, fmt.Errorf("LHS of LIKE is %T, not string", lhs)
}
right, ok := rhs.(string)
if !ok {
// TODO: byte works here too?
return false, fmt.Errorf("RHS of LIKE is %T, not string", rhs)
}
match := evalLike(left, right)
if be.Op == spansql.NotLike {
match = !match
}
return match, nil
case spansql.Between, spansql.NotBetween:
rhs2, err := ec.evalExpr(be.RHS2)
if err != nil {
return false, err
}
b := compareVals(rhs, lhs) <= 0 && compareVals(lhs, rhs2) <= 0
if be.Op == spansql.NotBetween {
b = !b
}
return b, nil
}
case spansql.IsOp:
lhs, err := ec.evalExpr(be.LHS)
if err != nil {
return false, err
}
var b bool
switch rhs := be.RHS.(type) {
default:
return false, fmt.Errorf("unhandled IsOp %T", rhs)
case spansql.BoolLiteral:
lhsBool, ok := lhs.(bool)
if !ok {
return false, fmt.Errorf("non-bool value %T on LHS for %s", lhs, be.SQL())
}
b = (lhsBool == bool(rhs))
case spansql.NullLiteral:
b = (lhs == nil)
}
if be.Neg {
b = !b
}
return b, nil
}
}
func (ec evalContext) evalExpr(e spansql.Expr) (interface{}, error) {
switch e := e.(type) {
default:
return nil, fmt.Errorf("TODO: evalExpr(%s %T)", e.SQL(), e)
case spansql.ID:
return ec.evalID(e)
case spansql.Param:
v, ok := ec.params[string(e)]
if !ok {
return 0, fmt.Errorf("unbound param %s", e.SQL())
}
return v, nil
case spansql.IntegerLiteral:
return int64(e), nil
case spansql.FloatLiteral:
return float64(e), nil
case spansql.StringLiteral:
return string(e), nil
case spansql.BytesLiteral:
return []byte(e), nil
case spansql.NullLiteral:
return nil, nil
case spansql.BoolLiteral:
return bool(e), nil
case spansql.Paren:
return ec.evalExpr(e.Expr)
case spansql.LogicalOp:
return ec.evalBoolExpr(e)
case spansql.IsOp:
return ec.evalBoolExpr(e)
}
}
func (ec evalContext) evalID(id spansql.ID) (interface{}, error) {
// TODO: look beyond column names.
if ec.table == nil {
return nil, fmt.Errorf("identifier %s when not SELECTing on a table is not supported", string(id))
}
i, ok := ec.table.colIndex[string(id)]
if !ok {
return nil, fmt.Errorf("couldn't resolve identifier %s", string(id))
}
return ec.row.copyDataElem(i), nil
}
func evalLimit(lim spansql.Limit, params queryParams) (int64, error) {
switch lim := lim.(type) {
case spansql.IntegerLiteral:
return int64(lim), nil
case spansql.Param:
return paramAsInteger(lim, params)
default:
return 0, fmt.Errorf("LIMIT with %T not supported", lim)
}
}
func paramAsInteger(p spansql.Param, params queryParams) (int64, error) {
v, ok := params[string(p)]
if !ok {
return 0, fmt.Errorf("unbound param %s", p.SQL())
}
switch v := v.(type) {
default:
return 0, fmt.Errorf("can't interpret parameter %s value of type %T as integer", p.SQL(), v)
case int64:
return v, nil
case string:
x, err := strconv.ParseInt(v, 10, 64)
if err != nil {
return 0, fmt.Errorf("bad int64 string %q: %v", v, err)
}
return x, nil
}
}
func compareVals(x, y interface{}) int {
// NULL is always the minimum possible value.
if x == nil && y == nil {
return 0
} else if x == nil {
return -1
} else if y == nil {
return 1
}
// TODO: coerce between comparable types (e.g. int64/float64).
switch x := x.(type) {
default:
panic(fmt.Sprintf("unhandled comparison on %T", x))
case bool:
// false < true
y := y.(bool)
if !x && y {
return -1
} else if x && !y {
return 1
}
return 0
case int64:
if s, ok := y.(string); ok {
var err error
y, err = strconv.ParseInt(s, 10, 64)
if err != nil {
panic(fmt.Sprintf("bad int64 string %q: %v", s, err))
}
}
y := y.(int64)
if x < y {
return -1
} else if x > y {
return 1
}
return 0
case float64:
y := y.(float64)
if x < y {
return -1
} else if x > y {
return 1
}
return 0
case string:
// This handles DATE too.
return strings.Compare(x, y.(string))
}
}
func (ec evalContext) colInfo(e spansql.Expr) (colInfo, error) {
// TODO: more types
switch e := e.(type) {
case spansql.IntegerLiteral:
return colInfo{Type: spansql.Type{Base: spansql.Int64}}, nil
case spansql.StringLiteral:
return colInfo{Type: spansql.Type{Base: spansql.String}}, nil
case spansql.BytesLiteral:
return colInfo{Type: spansql.Type{Base: spansql.Bytes}}, nil
case spansql.LogicalOp, spansql.ComparisonOp, spansql.IsOp:
return colInfo{Type: spansql.Type{Base: spansql.Bool}}, nil
case spansql.ID:
// TODO: support more than only naming a table column.
name := string(e)
if ec.table != nil {
if i, ok := ec.table.colIndex[name]; ok {
return ec.table.cols[i], nil
}
}
case spansql.Paren:
return ec.colInfo(e.Expr)
case spansql.NullLiteral:
// There isn't necessarily something sensible here.
// Empirically, though, the real Spanner returns Int64.
return colInfo{Type: spansql.Type{Base: spansql.Int64}}, nil
}
return colInfo{}, fmt.Errorf("can't deduce column type from expression [%s]", e.SQL())
}
func evalLike(str, pat string) bool {
/*
% matches any number of chars.
_ matches a single char.
TODO: handle escaping
*/
// Lean on regexp for simplicity.
pat = regexp.QuoteMeta(pat)
pat = strings.Replace(pat, "%", ".*", -1)
pat = strings.Replace(pat, "_", ".", -1)
match, err := regexp.MatchString(pat, str)
if err != nil {
panic(fmt.Sprintf("internal error: constructed bad regexp /%s/: %v", pat, err))
}
return match
}
func isCountStar(e spansql.Expr) bool {
f, ok := e.(spansql.Func)
if !ok {
return false
}
if f.Name != "COUNT" || len(f.Args) != 1 {
return false
}
return f.Args[0] == spansql.Star
}