spanner/spannertest: change internal representation of DATE/TIMESTAMP values

Using civil.Date/time.Time internally for these columns, rather than the
encoded string values, is less likely to have problems such as
comparison errors. It'll also make future features (e.g. functions using
these value types) easier.

This adds some infrastructure for coercing parts of expressions, and
applies it to string literal arguments of comparison operators. The
coercion can be extended to other literals and operators in the future.

Fixes #2195.

Change-Id: Ifa44fae15cfe08c412f4d997f27c7ac8c70b9e22
Reviewed-on: https://code-review.googlesource.com/c/gocloud/+/56476
Reviewed-by: kokoro <noreply+kokoro@google.com>
Reviewed-by: Knut Olav Løite <koloite@gmail.com>
diff --git a/spanner/spannertest/db.go b/spanner/spannertest/db.go
index f84ebd7..58e5722 100644
--- a/spanner/spannertest/db.go
+++ b/spanner/spannertest/db.go
@@ -36,6 +36,7 @@
 
 	structpb "github.com/golang/protobuf/ptypes/struct"
 
+	"cloud.google.com/go/civil"
 	"cloud.google.com/go/spanner/spansql"
 )
 
@@ -155,8 +156,8 @@
 	FLOAT64		float64
 	STRING		string
 	BYTES		[]byte
-	DATE		string (RFC 3339 date; "YYYY-MM-DD")
-	TIMESTAMP	string (RFC 3339 timestamp with zone; "YYYY-MM-DDTHH:MM:SSZ")
+	DATE		civil.Date
+	TIMESTAMP	time.Time (location set to UTC)
 	ARRAY<T>	[]interface{}
 	STRUCT		TODO
 */
@@ -383,8 +384,7 @@
 				return err
 			}
 			if x == commitTimestampSentinel {
-				// Cloud Spanner commit timestamps have microsecond granularity.
-				x = tx.commitTimestamp.Format("2006-01-02T15:04:05.999999Z")
+				x = tx.commitTimestamp
 			}
 
 			r[i] = x
@@ -785,6 +785,7 @@
 	return true
 }
 
+// valForType converts a value from its RPC form into its internal representation.
 func valForType(v *structpb.Value, t spansql.Type) (interface{}, error) {
 	if _, ok := v.Kind.(*structpb.Value_NullValue); ok {
 		// TODO: enforce NOT NULL constraints?
@@ -843,12 +844,12 @@
 		// The Spanner protocol encodes DATE in RFC 3339 date format.
 		sv, ok := v.Kind.(*structpb.Value_StringValue)
 		if ok {
-			// Store it internally as a string, but validate its value.
 			s := sv.StringValue
-			if _, err := time.Parse("2006-01-02", s); err != nil {
+			d, err := parseAsDate(s)
+			if err != nil {
 				return nil, fmt.Errorf("bad DATE string %q: %v", s, err)
 			}
-			return s, nil
+			return d, nil
 		}
 	case spansql.Timestamp:
 		// The Spanner protocol encodes TIMESTAMP in RFC 3339 timestamp format with zone Z.
@@ -858,11 +859,11 @@
 			if strings.ToLower(s) == "spanner.commit_timestamp()" {
 				return commitTimestampSentinel, nil
 			}
-			// Store it internally as a string, but validate its value.
-			if _, err := time.Parse("2006-01-02T15:04:05.999999999Z", s); err != nil {
+			t, err := parseAsTimestamp(s)
+			if err != nil {
 				return nil, fmt.Errorf("bad TIMESTAMP string %q: %v", s, err)
 			}
-			return s, nil
+			return t, nil
 		}
 	}
 	return nil, fmt.Errorf("unsupported inserting value kind %T into column of type %s", v.Kind, t.SQL())
@@ -932,3 +933,8 @@
 		return n, nil
 	}
 }
+
+func parseAsDate(s string) (civil.Date, error) { return civil.ParseDate(s) } // s is an RFC 3339 date ("YYYY-MM-DD")
+func parseAsTimestamp(s string) (time.Time, error) { // s is an RFC 3339 timestamp with zone Z
+	return time.Parse("2006-01-02T15:04:05.999999999Z", s) // the layout's trailing "Z" is literal text, so s must end in "Z"; the result is in UTC
+}
diff --git a/spanner/spannertest/db_eval.go b/spanner/spannertest/db_eval.go
index 2b57d04..ae4e3c7 100644
--- a/spanner/spannertest/db_eval.go
+++ b/spanner/spannertest/db_eval.go
@@ -24,7 +24,9 @@
 	"regexp"
 	"strconv"
 	"strings"
+	"time"
 
+	"cloud.google.com/go/civil"
 	"cloud.google.com/go/spanner/spansql"
 )
 
@@ -40,6 +42,17 @@
 	params queryParams
 }
 
+// coercedValue represents a literal value that has been coerced to a different type.
+// It never leaves this package and is never persisted.
+type coercedValue struct {
+	spansql.Expr             // not a real Expr; embedded only so a coercedValue can stand in for one
+	val          interface{} // internal representation of the coerced value
+	// TODO: type?
+	orig spansql.Expr // the expression that was coerced; SQL() delegates to it
+}
+
+func (cv coercedValue) SQL() string { return cv.orig.SQL() } // renders as the original, uncoerced expression
+
 func (ec evalContext) evalExprList(list []spansql.Expr) ([]interface{}, error) {
 	var out []interface{}
 	for _, e := range list {
@@ -95,8 +108,15 @@
 			return false, fmt.Errorf("unhandled LogicalOp %d", be.Op)
 		}
 	case spansql.ComparisonOp:
+		// Per https://cloud.google.com/spanner/docs/operators#comparison_operators,
+		// "Cloud Spanner SQL will generally coerce literals to the type of non-literals, where present".
+		// Before evaluating be.LHS and be.RHS, do any necessary coercion.
+		be, err := ec.coerceComparisonOpArgs(be)
+		if err != nil {
+			return false, err
+		}
+
 		var lhs, rhs interface{}
-		var err error
 		lhs, err = ec.evalExpr(be.LHS)
 		if err != nil {
 			return false, err
@@ -332,6 +352,8 @@
 	switch e := e.(type) {
 	default:
 		return nil, fmt.Errorf("TODO: evalExpr(%s %T)", e.SQL(), e)
+	case coercedValue:
+		return e.val, nil
 	case spansql.ID:
 		return ec.evalID(e)
 	case spansql.Param:
@@ -447,6 +469,67 @@
 	return nil, fmt.Errorf("couldn't resolve identifier %s", string(id))
 }
 
+func (ec evalContext) coerceComparisonOpArgs(co spansql.ComparisonOp) (spansql.ComparisonOp, error) {
+	// Coerces a string literal operand to the other operand's type; see https://cloud.google.com/spanner/docs/operators#comparison_operators
+
+	if co.RHS2 != nil {
+		// TODO: Handle co.RHS2 for BETWEEN. The rules for that aren't clear.
+		return co, nil // until then, co is returned with no coercion applied
+	}
+
+	// Look for a string literal on LHS or RHS. LHS is checked first, and at most one side is coerced.
+	var err error
+	if slit, ok := co.LHS.(spansql.StringLiteral); ok {
+		co.LHS, err = ec.coerceString(co.RHS, slit) // the other operand (RHS) determines the target type
+		return co, err
+	}
+	if slit, ok := co.RHS.(spansql.StringLiteral); ok {
+		co.RHS, err = ec.coerceString(co.LHS, slit) // the other operand (LHS) determines the target type
+		return co, err
+	}
+
+	// TODO: Other coercion literals. The int64/float64 code elsewhere may be able to be simplified.
+
+	return co, nil
+}
+
+// coerceString converts a string literal into an expression compatible with the type of target, as reported by ec.colInfo.
+func (ec evalContext) coerceString(target spansql.Expr, slit spansql.StringLiteral) (spansql.Expr, error) {
+	ci, err := ec.colInfo(target)
+	if err != nil {
+		return nil, err
+	}
+	if ci.Type.Array {
+		return nil, fmt.Errorf("unable to coerce string literal %q to match array type", slit)
+	}
+	switch ci.Type.Base {
+	case spansql.String:
+		return slit, nil // already the right type; returned unwrapped
+	case spansql.Date:
+		d, err := parseAsDate(string(slit))
+		if err != nil {
+			return nil, fmt.Errorf("coercing string literal %q to DATE: %v", slit, err)
+		}
+		return coercedValue{ // wraps the parsed date; slit is kept for SQL()
+			val:  d,
+			orig: slit,
+		}, nil
+	case spansql.Timestamp:
+		t, err := parseAsTimestamp(string(slit))
+		if err != nil {
+			return nil, fmt.Errorf("coercing string literal %q to TIMESTAMP: %v", slit, err)
+		}
+		return coercedValue{ // wraps the parsed timestamp; slit is kept for SQL()
+			val:  t,
+			orig: slit,
+		}, nil
+	}
+
+	// TODO: Any others?
+
+	return nil, fmt.Errorf("unable to coerce string literal %q to match %v", slit, ci.Type) // every other base type is rejected
+}
+
 func evalLiteralOrParam(lop spansql.LiteralOrParam, params queryParams) (int64, error) {
 	switch v := lop.(type) {
 	case spansql.IntegerLiteral:
@@ -549,8 +632,23 @@
 		}
 		return 0
 	case string:
-		// This handles DATE and TIMESTAMP too.
 		return strings.Compare(x, y.(string))
+	case civil.Date:
+		y := y.(civil.Date)
+		if x.Before(y) {
+			return -1
+		} else if x.After(y) {
+			return 1
+		}
+		return 0
+	case time.Time:
+		y := y.(time.Time)
+		if x.Before(y) {
+			return -1
+		} else if x.After(y) {
+			return 1
+		}
+		return 0
 	case []byte:
 		return bytes.Compare(x, y.([]byte))
 	}
@@ -568,7 +666,7 @@
 	case spansql.IntegerLiteral:
 		return colInfo{Type: int64Type}, nil
 	case spansql.StringLiteral:
-		return colInfo{Type: spansql.Type{Base: spansql.String}}, nil
+		return colInfo{Type: stringType}, nil
 	case spansql.BytesLiteral:
 		return colInfo{Type: spansql.Type{Base: spansql.Bytes}}, nil
 	case spansql.ArithOp:
diff --git a/spanner/spannertest/db_query.go b/spanner/spannertest/db_query.go
index 698cdc6..0476c99 100644
--- a/spanner/spannertest/db_query.go
+++ b/spanner/spannertest/db_query.go
@@ -265,7 +265,7 @@
 }
 
 type queryParam struct {
-	Value interface{}
+	Value interface{} // internal representation
 	Type  spansql.Type
 }
 
diff --git a/spanner/spannertest/db_test.go b/spanner/spannertest/db_test.go
index b8fa1b9..bc8d67a 100644
--- a/spanner/spannertest/db_test.go
+++ b/spanner/spannertest/db_test.go
@@ -17,6 +17,7 @@
 package spannertest
 
 import (
+	"fmt"
 	"io"
 	"reflect"
 	"sync"
@@ -26,6 +27,7 @@
 
 	structpb "github.com/golang/protobuf/ptypes/struct"
 
+	"cloud.google.com/go/civil"
 	"cloud.google.com/go/spanner/spansql"
 )
 
@@ -451,7 +453,7 @@
 		},
 		{
 			`SELECT Name FROM Staff WHERE FirstSeen >= @min`,
-			queryParams{"min": queryParam{Value: "1996-01-01", Type: spansql.Type{Base: spansql.Date}}},
+			queryParams{"min": dateParam("1996-01-01")},
 			[][]interface{}{
 				{"George"},
 			},
@@ -465,7 +467,8 @@
 		},
 		{
 			// The keyword "To" needs quoting in queries.
-			"SELECT COUNT(*) FROM Staff WHERE `To` IS NOT NULL",
+			// Check coercion of comparison operator literal args too.
+			"SELECT COUNT(*) FROM Staff WHERE `To` > '2000-01-01T00:00:00Z'",
 			nil,
 			[][]interface{}{
 				{int64(1)},
@@ -845,6 +848,14 @@
 func floatParam(f float64) queryParam { return queryParam{Value: f, Type: float64Type} }
 func nullParam() queryParam           { return queryParam{Value: nil} }
 
+func dateParam(s string) queryParam { // builds a DATE query parameter from a date string such as "1996-01-01"
+	d, err := civil.ParseDate(s)
+	if err != nil {
+		panic(fmt.Sprintf("bad test date %q: %v", s, err)) // panics instead of returning an error, so it can be called inline in test tables
+	}
+	return queryParam{Value: d, Type: spansql.Type{Base: spansql.Date}}
+}
+
 func TestRowCmp(t *testing.T) {
 	r := func(x ...interface{}) []interface{} { return x }
 	tests := []struct {
diff --git a/spanner/spannertest/inmem.go b/spanner/spannertest/inmem.go
index 9029fc0..ba53a6f 100644
--- a/spanner/spannertest/inmem.go
+++ b/spanner/spannertest/inmem.go
@@ -70,6 +70,7 @@
 	adminpb "google.golang.org/genproto/googleapis/spanner/admin/database/v1"
 	spannerpb "google.golang.org/genproto/googleapis/spanner/v1"
 
+	"cloud.google.com/go/civil"
 	"cloud.google.com/go/spanner/spansql"
 )
 
@@ -738,8 +739,10 @@
 }
 
 func parseQueryParam(v *structpb.Value, typ *spannerpb.Type) (queryParam, error) {
-	// TODO: Use typeFromSpannerType in here?
+	// TODO: Use valForType and typeFromSpannerType more comprehensively here?
+	// They are currently used only for StringValue values, since those are what mostly need parsing.
 
+	rawv := v
 	switch v := v.Kind.(type) {
 	default:
 		return queryParam{}, fmt.Errorf("unsupported well-known type value kind %T", v)
@@ -748,24 +751,15 @@
 	case *structpb.Value_NumberValue:
 		return queryParam{Value: v.NumberValue, Type: float64Type}, nil
 	case *structpb.Value_StringValue:
-		switch typ.Code {
-		case spannerpb.TypeCode_INT64:
-			return queryParam{Value: v.StringValue, Type: int64Type}, nil
-		case spannerpb.TypeCode_TIMESTAMP:
-			return queryParam{Value: v.StringValue, Type: spansql.Type{Base: spansql.Timestamp}}, nil
-		case spannerpb.TypeCode_DATE:
-			return queryParam{Value: v.StringValue, Type: spansql.Type{Base: spansql.Date}}, nil
-		case spannerpb.TypeCode_BYTES:
-			b, err := base64.StdEncoding.DecodeString(v.StringValue)
-			if err != nil {
-				return queryParam{}, err
-			}
-			return queryParam{Value: b, Type: spansql.Type{Base: spansql.Bytes, Len: spansql.MaxLen}}, nil
-		default:
-			// All other types represented on the wire as a string are stored internally as strings.
-			// We don't often get a type hint unfortunately, so ths type code here may be wrong.
-			return queryParam{Value: v.StringValue, Type: stringType}, nil
+		t, err := typeFromSpannerType(typ)
+		if err != nil {
+			return queryParam{}, err
 		}
+		val, err := valForType(rawv, t)
+		if err != nil {
+			return queryParam{}, err
+		}
+		return queryParam{Value: val, Type: t}, nil
 	case *structpb.Value_ListValue:
 		var list []interface{}
 		for _, elem := range v.ListValue.Values {
@@ -858,6 +852,13 @@
 		return &structpb.Value{Kind: &structpb.Value_StringValue{x}}, nil
 	case []byte:
 		return &structpb.Value{Kind: &structpb.Value_StringValue{base64.StdEncoding.EncodeToString(x)}}, nil
+	case civil.Date:
+		// RFC 3339 date format.
+		return &structpb.Value{Kind: &structpb.Value_StringValue{x.String()}}, nil
+	case time.Time:
+		// RFC 3339 timestamp format with zone Z.
+		s := x.Format("2006-01-02T15:04:05.999999999Z")
+		return &structpb.Value{Kind: &structpb.Value_StringValue{s}}, nil
 	case nil:
 		return &structpb.Value{Kind: &structpb.Value_NullValue{}}, nil
 	case []interface{}: