spanner/spannertest: evaluate more arithmetic operators

This covers unary negation, unary not, bitwise and/xor/or,
as well as reporting column types for expressions involving
any possible arithmetic operator.

Change-Id: I47f595945d0b8964e8475ed10f9040ebb2fdc9ba
Reviewed-on: https://code-review.googlesource.com/c/gocloud/+/52310
Reviewed-by: kokoro <noreply+kokoro@google.com>
Reviewed-by: Knut Olav Løite <koloite@gmail.com>
diff --git a/spanner/spannertest/db_eval.go b/spanner/spannertest/db_eval.go
index 07c7340..b89162d 100644
--- a/spanner/spannertest/db_eval.go
+++ b/spanner/spannertest/db_eval.go
@@ -318,6 +318,34 @@
 
 func (ec evalContext) evalArithOp(e spansql.ArithOp) (interface{}, error) {
 	switch e.Op {
+	case spansql.Neg:
+		rhs, err := ec.evalExpr(e.RHS)
+		if err != nil {
+			return nil, err
+		}
+		switch rhs := rhs.(type) {
+		case float64:
+			return -rhs, nil
+		case int64:
+			return -rhs, nil
+		}
+		return nil, fmt.Errorf("RHS of %s evaluates to %T, want FLOAT64 or INT64", e.SQL(), rhs)
+	case spansql.BitNot:
+		rhs, err := ec.evalExpr(e.RHS)
+		if err != nil {
+			return nil, err
+		}
+		switch rhs := rhs.(type) {
+		case int64:
+			return ^rhs, nil
+		case []byte:
+			b := append([]byte(nil), rhs...) // deep copy
+			for i := range b {
+				b[i] = ^b[i]
+			}
+			return b, nil
+		}
+		return nil, fmt.Errorf("RHS of %s evaluates to %T, want INT64 or BYTES", e.SQL(), rhs)
 	case spansql.Div:
 		lhs, err := ec.evalFloat64(e.LHS)
 		if err != nil {
@@ -369,7 +397,51 @@
 		case spansql.Mul:
 			return f1 * f2, nil
 		}
+	case spansql.BitAnd, spansql.BitXor, spansql.BitOr:
+		lhs, err := ec.evalExpr(e.LHS)
+		if err != nil {
+			return nil, err
+		}
+		rhs, err := ec.evalExpr(e.RHS)
+		if err != nil {
+			return nil, err
+		}
+		i1, ok1 := lhs.(int64)
+		i2, ok2 := rhs.(int64)
+		if ok1 && ok2 {
+			switch e.Op {
+			case spansql.BitAnd:
+				return i1 & i2, nil
+			case spansql.BitXor:
+				return i1 ^ i2, nil
+			case spansql.BitOr:
+				return i1 | i2, nil
+			}
+		}
+		b1, ok1 := lhs.([]byte)
+		b2, ok2 := rhs.([]byte)
+		if !ok1 || !ok2 {
+			return nil, fmt.Errorf("arguments of %s evaluate to (%T, %T), want (INT64, INT64) or (BYTES, BYTES)", e.SQL(), lhs, rhs)
+		}
+		if len(b1) != len(b2) {
+			return nil, fmt.Errorf("arguments of %s evaluate to BYTES of unequal lengths (%d vs %d)", e.SQL(), len(b1), len(b2))
+		}
+		var f func(x, y byte) byte
+		switch e.Op {
+		case spansql.BitAnd:
+			f = func(x, y byte) byte { return x & y }
+		case spansql.BitXor:
+			f = func(x, y byte) byte { return x ^ y }
+		case spansql.BitOr:
+			f = func(x, y byte) byte { return x | y }
+		}
+		b := make([]byte, len(b1))
+		for i := range b1 {
+			b[i] = f(b1[i], b2[i])
+		}
+		return b, nil
 	}
+	// TODO: Concat, BitShl, BitShr
 	return nil, fmt.Errorf("TODO: evalArithOp(%s %v)", e.SQL(), e.Op)
 }
 
@@ -537,6 +609,7 @@
 var (
 	int64Type   = spansql.Type{Base: spansql.Int64}
 	float64Type = spansql.Type{Base: spansql.Float64}
+	stringType  = spansql.Type{Base: spansql.String}
 )
 
 func (ec evalContext) colInfo(e spansql.Expr) (colInfo, error) {
@@ -596,6 +669,8 @@
 	switch ao.Op {
 	default:
 		return spansql.Type{}, fmt.Errorf("can't deduce column type from ArithOp [%s]", ao.SQL())
+	case spansql.Neg, spansql.BitNot:
+		return rhs, nil
 	case spansql.Add, spansql.Sub, spansql.Mul:
 		if lhs == int64Type && rhs == int64Type {
 			return int64Type, nil
@@ -603,8 +678,15 @@
 		return float64Type, nil
 	case spansql.Div:
 		return float64Type, nil
+	case spansql.Concat:
+		if !lhs.Array {
+			return stringType, nil
+		}
+		return lhs, nil
+	case spansql.BitShl, spansql.BitShr, spansql.BitAnd, spansql.BitXor, spansql.BitOr:
+		// "All bitwise operators return the same type and the same length as the first operand."
+		return lhs, nil
 	}
-	// TODO: more operators
 }
 
 func evalLike(str, pat string) bool {
diff --git a/spanner/spansql/types.go b/spanner/spansql/types.go
index eaa1d3c..251133b 100644
--- a/spanner/spansql/types.go
+++ b/spanner/spansql/types.go
@@ -300,7 +300,7 @@
 
 type LogicalOp struct {
 	Op       LogicalOperator
-	LHS, RHS BoolExpr // only RHS is set for Not
+	LHS, RHS BoolExpr // only RHS is set for Neg, BitNot
 }
 
 func (LogicalOp) isBoolExpr() {}