spanner/spannertest: evaluate more arithmetic operators
This covers unary negation, unary not, bitwise and/xor/or,
as well as reporting column types for expressions involving
any possible arithmetic operator.
Change-Id: I47f595945d0b8964e8475ed10f9040ebb2fdc9ba
Reviewed-on: https://code-review.googlesource.com/c/gocloud/+/52310
Reviewed-by: kokoro <noreply+kokoro@google.com>
Reviewed-by: Knut Olav Løite <koloite@gmail.com>
diff --git a/spanner/spannertest/db_eval.go b/spanner/spannertest/db_eval.go
index 07c7340..b89162d 100644
--- a/spanner/spannertest/db_eval.go
+++ b/spanner/spannertest/db_eval.go
@@ -318,6 +318,34 @@
func (ec evalContext) evalArithOp(e spansql.ArithOp) (interface{}, error) {
switch e.Op {
+ case spansql.Neg:
+ rhs, err := ec.evalExpr(e.RHS)
+ if err != nil {
+ return nil, err
+ }
+ switch rhs := rhs.(type) {
+ case float64:
+ return -rhs, nil
+ case int64:
+ return -rhs, nil
+ }
+ return nil, fmt.Errorf("RHS of %s evaluates to %T, want FLOAT64 or INT64", e.SQL(), rhs)
+ case spansql.BitNot:
+ rhs, err := ec.evalExpr(e.RHS)
+ if err != nil {
+ return nil, err
+ }
+ switch rhs := rhs.(type) {
+ case int64:
+ return ^rhs, nil
+ case []byte:
+ b := append([]byte(nil), rhs...) // deep copy
+ for i := range b {
+ b[i] = ^b[i]
+ }
+ return b, nil
+ }
+ return nil, fmt.Errorf("RHS of %s evaluates to %T, want INT64 or BYTES", e.SQL(), rhs)
case spansql.Div:
lhs, err := ec.evalFloat64(e.LHS)
if err != nil {
@@ -369,7 +397,51 @@
case spansql.Mul:
return f1 * f2, nil
}
+ case spansql.BitAnd, spansql.BitXor, spansql.BitOr:
+ lhs, err := ec.evalExpr(e.LHS)
+ if err != nil {
+ return nil, err
+ }
+ rhs, err := ec.evalExpr(e.RHS)
+ if err != nil {
+ return nil, err
+ }
+ i1, ok1 := lhs.(int64)
+ i2, ok2 := rhs.(int64)
+ if ok1 && ok2 {
+ switch e.Op {
+ case spansql.BitAnd:
+ return i1 & i2, nil
+ case spansql.BitXor:
+ return i1 ^ i2, nil
+ case spansql.BitOr:
+ return i1 | i2, nil
+ }
+ }
+ b1, ok1 := lhs.([]byte)
+ b2, ok2 := rhs.([]byte)
+ if !ok1 || !ok2 {
+ return nil, fmt.Errorf("arguments of %s evaluate to (%T, %T), want (INT64, INT64) or (BYTES, BYTES)", e.SQL(), lhs, rhs)
+ }
+ if len(b1) != len(b2) {
+ return nil, fmt.Errorf("arguments of %s evaluate to BYTES of unequal lengths (%d vs %d)", e.SQL(), len(b1), len(b2))
+ }
+ var f func(x, y byte) byte
+ switch e.Op {
+ case spansql.BitAnd:
+ f = func(x, y byte) byte { return x & y }
+ case spansql.BitXor:
+ f = func(x, y byte) byte { return x ^ y }
+ case spansql.BitOr:
+ f = func(x, y byte) byte { return x | y }
+ }
+ b := make([]byte, len(b1))
+ for i := range b1 {
+ b[i] = f(b1[i], b2[i])
+ }
+ return b, nil
}
+ // TODO: Concat, BitShl, BitShr
return nil, fmt.Errorf("TODO: evalArithOp(%s %v)", e.SQL(), e.Op)
}
@@ -537,6 +609,7 @@
var (
int64Type = spansql.Type{Base: spansql.Int64}
float64Type = spansql.Type{Base: spansql.Float64}
+ stringType = spansql.Type{Base: spansql.String}
)
func (ec evalContext) colInfo(e spansql.Expr) (colInfo, error) {
@@ -596,6 +669,8 @@
switch ao.Op {
default:
return spansql.Type{}, fmt.Errorf("can't deduce column type from ArithOp [%s]", ao.SQL())
+ case spansql.Neg, spansql.BitNot:
+ return rhs, nil
case spansql.Add, spansql.Sub, spansql.Mul:
if lhs == int64Type && rhs == int64Type {
return int64Type, nil
@@ -603,8 +678,15 @@
return float64Type, nil
case spansql.Div:
return float64Type, nil
+ case spansql.Concat:
+ if !lhs.Array {
+ return stringType, nil
+ }
+ return lhs, nil
+ case spansql.BitShl, spansql.BitShr, spansql.BitAnd, spansql.BitXor, spansql.BitOr:
+ // "All bitwise operators return the same type and the same length as the first operand."
+ return lhs, nil
}
- // TODO: more operators
}
func evalLike(str, pat string) bool {
diff --git a/spanner/spansql/types.go b/spanner/spansql/types.go
index eaa1d3c..251133b 100644
--- a/spanner/spansql/types.go
+++ b/spanner/spansql/types.go
@@ -300,7 +300,7 @@
type LogicalOp struct {
Op LogicalOperator
- LHS, RHS BoolExpr // only RHS is set for Not
+ LHS, RHS BoolExpr // only RHS is set for Neg, BitNot
}
func (LogicalOp) isBoolExpr() {}