feat(spanner): add support for NUMERIC data type (#2415)
* spanner: add support for NUMERIC data type
* Add integration test for Numeric support.
diff --git a/spanner/cmp_test.go b/spanner/cmp_test.go
index 76a0fb3..c8c42f3 100644
--- a/spanner/cmp_test.go
+++ b/spanner/cmp_test.go
@@ -17,6 +17,7 @@
package spanner
import (
+ "math/big"
"strings"
"cloud.google.com/go/internal/testutil"
@@ -27,7 +28,7 @@
func testEqual(a, b interface{}) bool {
return testutil.Equal(a, b,
cmp.AllowUnexported(TimestampBound{}, Error{}, TransactionOutcomeUnknownError{},
- Mutation{}, Row{}, Partition{}, BatchReadOnlyTransactionID{}),
+ Mutation{}, Row{}, Partition{}, BatchReadOnlyTransactionID{}, big.Rat{}, big.Int{}),
cmp.FilterPath(func(path cmp.Path) bool {
// Ignore Error.state, Error.sizeCache, and Error.unknownFields
if strings.HasSuffix(path.GoString(), ".err.(*status.Error).state") {
diff --git a/spanner/integration_test.go b/spanner/integration_test.go
index de5fecc..2d29a50 100644
--- a/spanner/integration_test.go
+++ b/spanner/integration_test.go
@@ -23,6 +23,7 @@
"fmt"
"log"
"math"
+ "math/big"
"os"
"reflect"
"regexp"
@@ -1318,7 +1319,44 @@
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
defer cancel()
- client, _, cleanup := prepareIntegrationTest(ctx, t, DefaultSessionPoolConfig, singerDBStatements)
+ stmts := singerDBStatements
+ if !isEmulatorEnvSet() {
+ stmts = []string{
+ `CREATE TABLE Singers (
+ SingerId INT64 NOT NULL,
+ FirstName STRING(1024),
+ LastName STRING(1024),
+ SingerInfo BYTES(MAX)
+ ) PRIMARY KEY (SingerId)`,
+ `CREATE INDEX SingerByName ON Singers(FirstName, LastName)`,
+ `CREATE TABLE Accounts (
+ AccountId INT64 NOT NULL,
+ Nickname STRING(100),
+ Balance INT64 NOT NULL,
+ ) PRIMARY KEY (AccountId)`,
+ `CREATE INDEX AccountByNickname ON Accounts(Nickname) STORING (Balance)`,
+ `CREATE TABLE Types (
+ RowID INT64 NOT NULL,
+ String STRING(MAX),
+ StringArray ARRAY<STRING(MAX)>,
+ Bytes BYTES(MAX),
+ BytesArray ARRAY<BYTES(MAX)>,
+ Int64a INT64,
+ Int64Array ARRAY<INT64>,
+ Bool BOOL,
+ BoolArray ARRAY<BOOL>,
+ Float64 FLOAT64,
+ Float64Array ARRAY<FLOAT64>,
+ Date DATE,
+ DateArray ARRAY<DATE>,
+ Timestamp TIMESTAMP,
+ TimestampArray ARRAY<TIMESTAMP>,
+ Numeric NUMERIC,
+ NumericArray ARRAY<NUMERIC>
+ ) PRIMARY KEY (RowID)`,
+ }
+ }
+ client, _, cleanup := prepareIntegrationTest(ctx, t, DefaultSessionPoolConfig, stmts)
defer cleanup()
t1, _ := time.Parse(time.RFC3339Nano, "2016-11-15T15:04:05.999999999Z")
@@ -1330,6 +1368,10 @@
d2, _ := civil.ParseDate("0001-01-01")
d3, _ := civil.ParseDate("9999-12-31")
+ n0 := big.Rat{}
+ n1 := *big.NewRat(123456789, 1)
+ n2 := *big.NewRat(123456789, 1000000000)
+
tests := []struct {
col string
val interface{}
@@ -1420,6 +1462,31 @@
{col: "TimestampArray", val: []time.Time{t1, t2, t3}, want: []NullTime{{t1, true}, {t2, true}, {t3, true}}},
}
+ if !isEmulatorEnvSet() {
+ for _, tc := range []struct {
+ col string
+ val interface{}
+ want interface{}
+ }{
+ {col: "Numeric", val: n1},
+ {col: "Numeric", val: n2},
+ {col: "Numeric", val: n1, want: NullNumeric{n1, true}},
+ {col: "Numeric", val: n2, want: NullNumeric{n2, true}},
+ {col: "Numeric", val: NullNumeric{n1, true}, want: n1},
+ {col: "Numeric", val: NullNumeric{n1, true}, want: NullNumeric{n1, true}},
+ {col: "Numeric", val: NullNumeric{n0, false}},
+ {col: "Numeric", val: nil, want: NullNumeric{}},
+ {col: "NumericArray", val: []big.Rat(nil), want: []NullNumeric(nil)},
+ {col: "NumericArray", val: []big.Rat{}, want: []NullNumeric{}},
+ {col: "NumericArray", val: []big.Rat{n1, n2}, want: []NullNumeric{{n1, true}, {n2, true}}},
+ {col: "NumericArray", val: []NullNumeric(nil)},
+ {col: "NumericArray", val: []NullNumeric{}},
+ {col: "NumericArray", val: []NullNumeric{{n1, true}, {n2, true}, {}}},
+ } {
+ tests = append(tests, tc)
+ }
+ }
+
// Write rows into table first.
var muts []*Mutation
for i, test := range tests {
@@ -3124,8 +3191,12 @@
return b
}
+func isEmulatorEnvSet() bool {
+ return os.Getenv("SPANNER_EMULATOR_HOST") != ""
+}
+
func skipEmulatorTest(t *testing.T) {
- if os.Getenv("SPANNER_EMULATOR_HOST") != "" {
+ if isEmulatorEnvSet() {
t.Skip("Skipping testing against the emulator.")
}
}
diff --git a/spanner/protoutils.go b/spanner/protoutils.go
index 3797980..6465b2c 100644
--- a/spanner/protoutils.go
+++ b/spanner/protoutils.go
@@ -18,6 +18,7 @@
import (
"encoding/base64"
+ "math/big"
"strconv"
"time"
@@ -64,6 +65,14 @@
return &sppb.Type{Code: sppb.TypeCode_FLOAT64}
}
+func numericProto(n *big.Rat) *proto3.Value {
+ return &proto3.Value{Kind: &proto3.Value_StringValue{StringValue: NumericString(n)}}
+}
+
+func numericType() *sppb.Type {
+ return &sppb.Type{Code: sppb.TypeCode_NUMERIC}
+}
+
func bytesProto(b []byte) *proto3.Value {
return &proto3.Value{Kind: &proto3.Value_StringValue{StringValue: base64.StdEncoding.EncodeToString(b)}}
}
diff --git a/spanner/row_test.go b/spanner/row_test.go
index 9a6ed29..d909297 100644
--- a/spanner/row_test.go
+++ b/spanner/row_test.go
@@ -745,7 +745,7 @@
[]*proto3.Value{stringProto("nan")},
},
&NullFloat64{},
- errDecodeColumn(0, errUnexpectedNumStr("nan")),
+ errDecodeColumn(0, errUnexpectedFloat64Str("nan")),
},
{
// Field specifies FLOAT64 type, but value is wrongly encoded.
@@ -756,7 +756,7 @@
[]*proto3.Value{stringProto("nan")},
},
proto.Float64(0),
- errDecodeColumn(0, errUnexpectedNumStr("nan")),
+ errDecodeColumn(0, errUnexpectedFloat64Str("nan")),
},
{
// Field specifies BYTES type, value is having a nil Kind.
diff --git a/spanner/value.go b/spanner/value.go
index d206478..f5f142a 100644
--- a/spanner/value.go
+++ b/spanner/value.go
@@ -21,6 +21,7 @@
"encoding/base64"
"fmt"
"math"
+ "math/big"
"reflect"
"strconv"
"time"
@@ -33,10 +34,27 @@
"google.golang.org/grpc/codes"
)
-// nullString is returned by the String methods of NullableValues when the
-// underlying database value is null.
-const nullString = "<null>"
-const commitTimestampPlaceholderString = "spanner.commit_timestamp()"
+const (
+ // nullString is returned by the String methods of NullableValues when the
+ // underlying database value is null.
+ nullString = "<null>"
+ commitTimestampPlaceholderString = "spanner.commit_timestamp()"
+
+ // NumericPrecisionDigits is the maximum number of digits in a NUMERIC
+ // value.
+ NumericPrecisionDigits = 38
+
+ // NumericScaleDigits is the maximum number of digits after the decimal
+ // point in a NUMERIC value.
+ NumericScaleDigits = 9
+)
+
+// NumericString returns a string representing a *big.Rat in a format compatible
+// with Spanner SQL. It returns a floating-point literal with 9 digits after the
+// decimal point.
+func NumericString(r *big.Rat) string {
+ return r.FloatString(NumericScaleDigits)
+}
var (
// CommitTimestamp is a special value used to tell Cloud Spanner to insert
@@ -387,6 +405,57 @@
return nil
}
+// NullNumeric represents a Cloud Spanner Numeric that may be NULL.
+type NullNumeric struct {
+ Numeric big.Rat
+ Valid bool // Valid is true if Numeric is not NULL.
+}
+
+// IsNull implements NullableValue.IsNull for NullNumeric.
+func (n NullNumeric) IsNull() bool {
+ return !n.Valid
+}
+
+// String implements Stringer.String for NullNumeric
+func (n NullNumeric) String() string {
+ if !n.Valid {
+ return nullString
+ }
+ return fmt.Sprintf("%v", NumericString(&n.Numeric))
+}
+
+// MarshalJSON implements json.Marshaler.MarshalJSON for NullNumeric.
+func (n NullNumeric) MarshalJSON() ([]byte, error) {
+ if n.Valid {
+ return []byte(fmt.Sprintf("%q", NumericString(&n.Numeric))), nil
+ }
+ return jsonNullBytes, nil
+}
+
+// UnmarshalJSON implements json.Unmarshaler.UnmarshalJSON for NullNumeric.
+func (n *NullNumeric) UnmarshalJSON(payload []byte) error {
+ if payload == nil {
+ return fmt.Errorf("payload should not be nil")
+ }
+ if bytes.Equal(payload, jsonNullBytes) {
+ n.Numeric = big.Rat{}
+ n.Valid = false
+ return nil
+ }
+ payload, err := trimDoubleQuotes(payload)
+ if err != nil {
+ return err
+ }
+ s := string(payload)
+ val, ok := (&big.Rat{}).SetString(s)
+ if !ok {
+ return fmt.Errorf("payload cannot be converted to big.Rat: got %v", string(payload))
+ }
+ n.Numeric = *val
+ n.Valid = true
+ return nil
+}
+
// NullRow represents a Cloud Spanner STRUCT that may be NULL.
// See also the document for Row.
// Note that NullRow is not a valid Cloud Spanner column Type.
@@ -950,6 +1019,107 @@
return err
}
*p = y
+ case *big.Rat:
+ if code != sppb.TypeCode_NUMERIC {
+ return errTypeMismatch(code, acode, ptr)
+ }
+ if isNull {
+ return errDstNotForNull(ptr)
+ }
+ x := v.GetStringValue()
+ y, ok := (&big.Rat{}).SetString(x)
+ if !ok {
+ return errUnexpectedNumericStr(x)
+ }
+ *p = *y
+ case *NullNumeric:
+ if p == nil {
+ return errNilDst(p)
+ }
+ if code != sppb.TypeCode_NUMERIC {
+ return errTypeMismatch(code, acode, ptr)
+ }
+ if isNull {
+ *p = NullNumeric{}
+ break
+ }
+ x := v.GetStringValue()
+ y, ok := (&big.Rat{}).SetString(x)
+ if !ok {
+ return errUnexpectedNumericStr(x)
+ }
+ *p = NullNumeric{*y, true}
+ case **big.Rat:
+ if p == nil {
+ return errNilDst(p)
+ }
+ if code != sppb.TypeCode_NUMERIC {
+ return errTypeMismatch(code, acode, ptr)
+ }
+ if isNull {
+ *p = nil
+ break
+ }
+ x := v.GetStringValue()
+ y, ok := (&big.Rat{}).SetString(x)
+ if !ok {
+ return errUnexpectedNumericStr(x)
+ }
+ *p = y
+ case *[]NullNumeric, *[]*big.Rat:
+ if p == nil {
+ return errNilDst(p)
+ }
+ if acode != sppb.TypeCode_NUMERIC {
+ return errTypeMismatch(code, acode, ptr)
+ }
+ if isNull {
+ switch sp := ptr.(type) {
+ case *[]NullNumeric:
+ *sp = nil
+ case *[]*big.Rat:
+ *sp = nil
+ }
+ break
+ }
+ x, err := getListValue(v)
+ if err != nil {
+ return err
+ }
+ switch sp := ptr.(type) {
+ case *[]NullNumeric:
+ y, err := decodeNullNumericArray(x)
+ if err != nil {
+ return err
+ }
+ *sp = y
+ case *[]*big.Rat:
+ y, err := decodeNumericPointerArray(x)
+ if err != nil {
+ return err
+ }
+ *sp = y
+ }
+ case *[]big.Rat:
+ if p == nil {
+ return errNilDst(p)
+ }
+ if acode != sppb.TypeCode_NUMERIC {
+ return errTypeMismatch(code, acode, ptr)
+ }
+ if isNull {
+ *p = nil
+ break
+ }
+ x, err := getListValue(v)
+ if err != nil {
+ return err
+ }
+ y, err := decodeNumericArray(x)
+ if err != nil {
+ return err
+ }
+ *p = y
case *time.Time:
var nt NullTime
if isNull {
@@ -1168,7 +1338,7 @@
}
// Check if the pointer is a variant of a base type.
- decodableType := getDecodableSpannerType(ptr)
+ decodableType := getDecodableSpannerType(ptr, true)
if decodableType != spannerTypeUnknown {
if isNull && !decodableType.supportsNull() {
return errDstNotForNull(ptr)
@@ -1222,6 +1392,7 @@
spannerTypeNonNullInt64
spannerTypeNonNullBool
spannerTypeNonNullFloat64
+ spannerTypeNonNullNumeric
spannerTypeNonNullTime
spannerTypeNonNullDate
spannerTypeNullString
@@ -1230,17 +1401,20 @@
spannerTypeNullFloat64
spannerTypeNullTime
spannerTypeNullDate
+ spannerTypeNullNumeric
spannerTypeArrayOfNonNullString
spannerTypeArrayOfByteArray
spannerTypeArrayOfNonNullInt64
spannerTypeArrayOfNonNullBool
spannerTypeArrayOfNonNullFloat64
+ spannerTypeArrayOfNonNullNumeric
spannerTypeArrayOfNonNullTime
spannerTypeArrayOfNonNullDate
spannerTypeArrayOfNullString
spannerTypeArrayOfNullInt64
spannerTypeArrayOfNullBool
spannerTypeArrayOfNullFloat64
+ spannerTypeArrayOfNullNumeric
spannerTypeArrayOfNullTime
spannerTypeArrayOfNullDate
)
@@ -1249,7 +1423,7 @@
// Spanner.
func (d decodableSpannerType) supportsNull() bool {
switch d {
- case spannerTypeNonNullString, spannerTypeNonNullInt64, spannerTypeNonNullBool, spannerTypeNonNullFloat64, spannerTypeNonNullTime, spannerTypeNonNullDate:
+ case spannerTypeNonNullString, spannerTypeNonNullInt64, spannerTypeNonNullBool, spannerTypeNonNullFloat64, spannerTypeNonNullTime, spannerTypeNonNullDate, spannerTypeNonNullNumeric:
return false
default:
return true
@@ -1265,17 +1439,26 @@
var typeOfNonNullTime = reflect.TypeOf(time.Time{})
var typeOfNonNullDate = reflect.TypeOf(civil.Date{})
+var typeOfNonNullNumeric = reflect.TypeOf(big.Rat{})
var typeOfNullString = reflect.TypeOf(NullString{})
var typeOfNullInt64 = reflect.TypeOf(NullInt64{})
var typeOfNullBool = reflect.TypeOf(NullBool{})
var typeOfNullFloat64 = reflect.TypeOf(NullFloat64{})
var typeOfNullTime = reflect.TypeOf(NullTime{})
var typeOfNullDate = reflect.TypeOf(NullDate{})
+var typeOfNullNumeric = reflect.TypeOf(NullNumeric{})
// getDecodableSpannerType returns the corresponding decodableSpannerType of
// the given pointer.
-func getDecodableSpannerType(ptr interface{}) decodableSpannerType {
- kind := reflect.Indirect(reflect.ValueOf(ptr)).Kind()
+func getDecodableSpannerType(ptr interface{}, isPtr bool) decodableSpannerType {
+ var val reflect.Value
+ var kind reflect.Kind
+ if isPtr {
+ val = reflect.Indirect(reflect.ValueOf(ptr))
+ } else {
+ val = reflect.ValueOf(ptr)
+ }
+ kind = val.Kind()
if kind == reflect.Invalid {
return spannerTypeInvalid
}
@@ -1290,8 +1473,16 @@
return spannerTypeNonNullBool
case reflect.Float64:
return spannerTypeNonNullFloat64
+ case reflect.Ptr:
+ t := val.Type()
+ if t.ConvertibleTo(typeOfNullNumeric) {
+ return spannerTypeNullNumeric
+ }
case reflect.Struct:
- t := reflect.Indirect(reflect.ValueOf(ptr)).Type()
+ t := val.Type()
+ if t.ConvertibleTo(typeOfNonNullNumeric) {
+ return spannerTypeNonNullNumeric
+ }
if t.ConvertibleTo(typeOfNonNullTime) {
return spannerTypeNonNullTime
}
@@ -1316,8 +1507,11 @@
if t.ConvertibleTo(typeOfNullDate) {
return spannerTypeNullDate
}
+ if t.ConvertibleTo(typeOfNullNumeric) {
+ return spannerTypeNullNumeric
+ }
case reflect.Slice:
- kind := reflect.Indirect(reflect.ValueOf(ptr)).Type().Elem().Kind()
+ kind := val.Type().Elem().Kind()
switch kind {
case reflect.Invalid:
return spannerTypeUnknown
@@ -1331,8 +1525,16 @@
return spannerTypeArrayOfNonNullBool
case reflect.Float64:
return spannerTypeArrayOfNonNullFloat64
+ case reflect.Ptr:
+ t := val.Type().Elem()
+ if t.ConvertibleTo(typeOfNullNumeric) {
+ return spannerTypeArrayOfNullNumeric
+ }
case reflect.Struct:
- t := reflect.Indirect(reflect.ValueOf(ptr)).Type().Elem()
+ t := val.Type().Elem()
+ if t.ConvertibleTo(typeOfNonNullNumeric) {
+ return spannerTypeArrayOfNonNullNumeric
+ }
if t.ConvertibleTo(typeOfNonNullTime) {
return spannerTypeArrayOfNonNullTime
}
@@ -1357,9 +1559,12 @@
if t.ConvertibleTo(typeOfNullDate) {
return spannerTypeArrayOfNullDate
}
+ if t.ConvertibleTo(typeOfNullNumeric) {
+ return spannerTypeArrayOfNullNumeric
+ }
case reflect.Slice:
// The only array-of-array type that is supported is [][]byte.
- kind := reflect.Indirect(reflect.ValueOf(ptr)).Type().Elem().Elem().Kind()
+ kind := val.Type().Elem().Elem().Kind()
switch kind {
case reflect.Uint8:
return spannerTypeArrayOfByteArray
@@ -1474,6 +1679,24 @@
} else {
result = &NullFloat64{x, !isNull}
}
+ case spannerTypeNonNullNumeric, spannerTypeNullNumeric:
+ if code != sppb.TypeCode_NUMERIC {
+ return errTypeMismatch(code, acode, ptr)
+ }
+ if isNull {
+ result = &NullNumeric{}
+ break
+ }
+ x := v.GetStringValue()
+ y, ok := (&big.Rat{}).SetString(x)
+ if !ok {
+ return errUnexpectedNumericStr(x)
+ }
+ if dsc == spannerTypeNonNullNumeric {
+ result = y
+ } else {
+ result = &NullNumeric{*y, true}
+ }
case spannerTypeNonNullTime, spannerTypeNullTime:
var nt NullTime
err := parseNullTime(v, &nt, code, isNull)
@@ -1591,6 +1814,23 @@
return err
}
result = y
+ case spannerTypeArrayOfNonNullNumeric, spannerTypeArrayOfNullNumeric:
+ if acode != sppb.TypeCode_NUMERIC {
+ return errTypeMismatch(code, acode, ptr)
+ }
+ if isNull {
+ ptr = nil
+ return nil
+ }
+ x, err := getListValue(v)
+ if err != nil {
+ return err
+ }
+ y, err := decodeGenericArray(reflect.TypeOf(ptr).Elem(), x, numericType(), "NUMERIC")
+ if err != nil {
+ return err
+ }
+ result = y
case spannerTypeArrayOfNonNullTime, spannerTypeArrayOfNullTime:
if acode != sppb.TypeCode_TIMESTAMP {
return errTypeMismatch(code, acode, ptr)
@@ -1682,10 +1922,16 @@
}
}
-// errUnexpectedNumStr returns error for decoder getting a unexpected string for
-// representing special float values.
-func errUnexpectedNumStr(s string) error {
- return spannerErrorf(codes.FailedPrecondition, "unexpected string value %q for number", s)
+// errUnexpectedNumericStr returns error for decoder getting an unexpected
+// string for representing special numeric values.
+func errUnexpectedNumericStr(s string) error {
+ return spannerErrorf(codes.FailedPrecondition, "unexpected string value %q for numeric number", s)
+}
+
+// errUnexpectedFloat64Str returns error for decoder getting an unexpected
+// string for representing special float values.
+func errUnexpectedFloat64Str(s string) error {
+ return spannerErrorf(codes.FailedPrecondition, "unexpected string value %q for float64 number", s)
}
// getFloat64Value returns the float64 value encoded in proto3.Value v whose
@@ -1710,7 +1956,7 @@
case "-Infinity":
return math.Inf(-1), nil
default:
- return 0, errUnexpectedNumStr(x.StringValue)
+ return 0, errUnexpectedFloat64Str(x.StringValue)
}
}
return 0, errSrcVal(v, "Number")
@@ -1888,7 +2134,7 @@
return a, nil
}
-// decodeFloat64PointerArray decodes proto3.ListValue pb into a NullFloat64 slice.
+// decodeFloat64PointerArray decodes proto3.ListValue pb into a *float slice.
func decodeFloat64PointerArray(pb *proto3.ListValue) ([]*float64, error) {
if pb == nil {
return nil, errNilListValue("FLOAT64")
@@ -1916,6 +2162,48 @@
return a, nil
}
+// decodeNullNumericArray decodes proto3.ListValue pb into a NullNumeric slice.
+func decodeNullNumericArray(pb *proto3.ListValue) ([]NullNumeric, error) {
+ if pb == nil {
+ return nil, errNilListValue("NUMERIC")
+ }
+ a := make([]NullNumeric, len(pb.Values))
+ for i, v := range pb.Values {
+ if err := decodeValue(v, numericType(), &a[i]); err != nil {
+ return nil, errDecodeArrayElement(i, v, "NUMERIC", err)
+ }
+ }
+ return a, nil
+}
+
+// decodeNumericPointerArray decodes proto3.ListValue pb into a *big.Rat slice.
+func decodeNumericPointerArray(pb *proto3.ListValue) ([]*big.Rat, error) {
+ if pb == nil {
+ return nil, errNilListValue("NUMERIC")
+ }
+ a := make([]*big.Rat, len(pb.Values))
+ for i, v := range pb.Values {
+ if err := decodeValue(v, numericType(), &a[i]); err != nil {
+ return nil, errDecodeArrayElement(i, v, "NUMERIC", err)
+ }
+ }
+ return a, nil
+}
+
+// decodeNumericArray decodes proto3.ListValue pb into a big.Rat slice.
+func decodeNumericArray(pb *proto3.ListValue) ([]big.Rat, error) {
+ if pb == nil {
+ return nil, errNilListValue("NUMERIC")
+ }
+ a := make([]big.Rat, len(pb.Values))
+ for i, v := range pb.Values {
+ if err := decodeValue(v, numericType(), &a[i]); err != nil {
+ return nil, errDecodeArrayElement(i, v, "NUMERIC", err)
+ }
+ }
+ return a, nil
+}
+
// decodeByteArray decodes proto3.ListValue pb into a slice of byte slice.
func decodeByteArray(pb *proto3.ListValue) ([][]byte, error) {
if pb == nil {
@@ -2364,6 +2652,43 @@
}
}
pt = listType(floatType())
+ case big.Rat:
+ pb.Kind = stringKind(NumericString(&v))
+ pt = numericType()
+ case []big.Rat:
+ if v != nil {
+ pb, err = encodeArray(len(v), func(i int) interface{} { return v[i] })
+ if err != nil {
+ return nil, nil, err
+ }
+ }
+ pt = listType(numericType())
+ case NullNumeric:
+ if v.Valid {
+ return encodeValue(v.Numeric)
+ }
+ pt = numericType()
+ case []NullNumeric:
+ if v != nil {
+ pb, err = encodeArray(len(v), func(i int) interface{} { return v[i] })
+ if err != nil {
+ return nil, nil, err
+ }
+ }
+ pt = listType(numericType())
+ case *big.Rat:
+ if v != nil {
+ pb.Kind = stringKind(NumericString(v))
+ }
+ pt = numericType()
+ case []*big.Rat:
+ if v != nil {
+ pb, err = encodeArray(len(v), func(i int) interface{} { return v[i] })
+ if err != nil {
+ return nil, nil, err
+ }
+ }
+ pt = listType(numericType())
case time.Time:
if v == commitTimestamp {
pb.Kind = stringKind(commitTimestampPlaceholderString)
@@ -2461,7 +2786,7 @@
}
// Check if the value is a variant of a base type.
- decodableType := getDecodableSpannerType(v)
+ decodableType := getDecodableSpannerType(v, false)
if decodableType != spannerTypeUnknown && decodableType != spannerTypeInvalid {
converted, err := convertCustomTypeValue(decodableType, v)
if err != nil {
@@ -2527,6 +2852,10 @@
destination = reflect.Indirect(reflect.New(reflect.TypeOf(civil.Date{})))
case spannerTypeNullDate:
destination = reflect.Indirect(reflect.New(reflect.TypeOf(NullDate{})))
+ case spannerTypeNonNullNumeric:
+ destination = reflect.Indirect(reflect.New(reflect.TypeOf(big.Rat{})))
+ case spannerTypeNullNumeric:
+ destination = reflect.Indirect(reflect.New(reflect.TypeOf(NullNumeric{})))
case spannerTypeArrayOfNonNullString:
if reflect.ValueOf(v).IsNil() {
return []string(nil), nil
@@ -2592,6 +2921,16 @@
return []NullDate(nil), nil
}
destination = reflect.MakeSlice(reflect.TypeOf([]NullDate{}), reflect.ValueOf(v).Len(), reflect.ValueOf(v).Cap())
+ case spannerTypeArrayOfNonNullNumeric:
+ if reflect.ValueOf(v).IsNil() {
+ return []big.Rat(nil), nil
+ }
+ destination = reflect.MakeSlice(reflect.TypeOf([]big.Rat{}), reflect.ValueOf(v).Len(), reflect.ValueOf(v).Cap())
+ case spannerTypeArrayOfNullNumeric:
+ if reflect.ValueOf(v).IsNil() {
+ return []NullNumeric(nil), nil
+ }
+ destination = reflect.MakeSlice(reflect.TypeOf([]NullNumeric{}), reflect.ValueOf(v).Len(), reflect.ValueOf(v).Cap())
default:
// This should not be possible.
return nil, fmt.Errorf("unknown decodable type found: %v", sourceType)
@@ -2602,11 +2941,11 @@
if destination.Kind() == reflect.Slice || destination.Kind() == reflect.Array {
sourceSlice := reflect.ValueOf(v)
for i := 0; i < destination.Len(); i++ {
- source := reflect.Indirect(sourceSlice.Index(i))
+ source := sourceSlice.Index(i)
destination.Index(i).Set(source.Convert(destination.Type().Elem()))
}
} else {
- source := reflect.Indirect(reflect.ValueOf(v))
+ source := reflect.ValueOf(v)
destination.Set(source.Convert(destination.Type()))
}
// Return the converted value.
@@ -2736,7 +3075,7 @@
return true
}
- decodableType := getDecodableSpannerType(v)
+ decodableType := getDecodableSpannerType(v, false)
return decodableType != spannerTypeUnknown && decodableType != spannerTypeInvalid
}
}
diff --git a/spanner/value_test.go b/spanner/value_test.go
index d447639..303aacf 100644
--- a/spanner/value_test.go
+++ b/spanner/value_test.go
@@ -20,6 +20,7 @@
"encoding/json"
"fmt"
"math"
+ "math/big"
"reflect"
"testing"
"time"
@@ -187,6 +188,7 @@
type CustomFloat64 float64
type CustomTime time.Time
type CustomDate civil.Date
+ type CustomNumeric big.Rat
type CustomNullString NullString
type CustomNullInt64 NullInt64
@@ -194,6 +196,7 @@
type CustomNullFloat64 NullFloat64
type CustomNullTime NullTime
type CustomNullDate NullDate
+ type CustomNullNumeric NullNumeric
sValue := "abc"
var sNilPtr *string
@@ -207,15 +210,19 @@
var tNilPtr *time.Time
dValue := d1
var dNilPtr *civil.Date
+ numValuePtr := big.NewRat(12345, 1e3)
+ var numNilPtr *big.Rat
+ num2ValuePtr := big.NewRat(12345, 1e4)
var (
- tString = stringType()
- tInt = intType()
- tBool = boolType()
- tFloat = floatType()
- tBytes = bytesType()
- tTime = timeType()
- tDate = dateType()
+ tString = stringType()
+ tInt = intType()
+ tBool = boolType()
+ tFloat = floatType()
+ tBytes = bytesType()
+ tTime = timeType()
+ tDate = dateType()
+ tNumeric = numericType()
)
for i, test := range []struct {
in interface{}
@@ -271,6 +278,17 @@
{[]float64{3.141, 0.618, math.Inf(-1)}, listProto(floatProto(3.141), floatProto(0.618), floatProto(math.Inf(-1))), listType(tFloat), "[]float64"},
{[]NullFloat64{{3.141, true}, {0.618, false}}, listProto(floatProto(3.141), nullProto()), listType(tFloat), "[]NullFloat64"},
{[]*float64{&fValue, fNilPtr}, listProto(floatProto(3.14), nullProto()), listType(tFloat), "[]NullFloat64"},
+ // NUMERIC / NUMERIC ARRAY
+ {*numValuePtr, numericProto(numValuePtr), tNumeric, "big.Rat"},
+ {numValuePtr, numericProto(numValuePtr), tNumeric, "*big.Rat"},
+ {numNilPtr, nullProto(), tNumeric, "*big.Rat with null"},
+ {NullNumeric{*numValuePtr, true}, numericProto(numValuePtr), tNumeric, "NullNumeric with value"},
+ {NullNumeric{*numValuePtr, false}, nullProto(), tNumeric, "NullNumeric with null"},
+ {[]big.Rat(nil), nullProto(), listType(tNumeric), "null []big.Rat"},
+ {[]big.Rat{*numValuePtr, *num2ValuePtr}, listProto(numericProto(numValuePtr), numericProto(num2ValuePtr)), listType(tNumeric), "[]big.Rat"},
+ {[]NullNumeric{{*numValuePtr, true}, {*numValuePtr, false}}, listProto(numericProto(numValuePtr), nullProto()), listType(tNumeric), "[]NullNumeric"},
+ {[]*big.Rat{nil, numValuePtr}, listProto(nullProto(), numericProto(numValuePtr)), listType(tNumeric), "[]*big.Rat"},
+ {[]*big.Rat(nil), nullProto(), listType(tNumeric), "null []*big.Rat"},
// TIMESTAMP / TIMESTAMP ARRAY
{t1, timeProto(t1), tTime, "time"},
{NullTime{t1, true}, timeProto(t1), tTime, "NullTime with value"},
@@ -366,6 +384,14 @@
{customStructToBytes{[]byte("A"), []byte("B")}, bytesProto([]byte("AB")), tBytes, "a struct to bytes"},
{customStructToTime{"A", "B"}, timeProto(tValue), tTime, "a struct to time"},
{customStructToDate{"A", "B"}, dateProto(dValue), tDate, "a struct to date"},
+ // CUSTOM NUMERIC / CUSTOM NUMERIC ARRAY
+ {CustomNumeric(*numValuePtr), numericProto(numValuePtr), tNumeric, "CustomNumeric"},
+ {CustomNullNumeric{*numValuePtr, true}, numericProto(numValuePtr), tNumeric, "CustomNullNumeric with value"},
+ {CustomNullNumeric{*numValuePtr, false}, nullProto(), tNumeric, "CustomNullNumeric with null"},
+ {[]CustomNumeric(nil), nullProto(), listType(tNumeric), "null []CustomNumeric"},
+ {[]CustomNumeric{CustomNumeric(*numValuePtr), CustomNumeric(*num2ValuePtr)}, listProto(numericProto(numValuePtr), numericProto(num2ValuePtr)), listType(tNumeric), "[]CustomNumeric"},
+ {[]CustomNullNumeric(nil), nullProto(), listType(tNumeric), "null []CustomNullNumeric"},
+ {[]CustomNullNumeric{{*numValuePtr, true}, {*num2ValuePtr, false}}, listProto(numericProto(numValuePtr), nullProto()), listType(tNumeric), "[]CustomNullNumeric"},
} {
got, gotType, err := encodeValue(test.in)
if err != nil {
@@ -1205,6 +1231,7 @@
type CustomFloat64 float64
type CustomTime time.Time
type CustomDate civil.Date
+ type CustomNumeric big.Rat
type CustomNullString NullString
type CustomNullInt64 NullInt64
@@ -1212,6 +1239,7 @@
type CustomNullFloat64 NullFloat64
type CustomNullTime NullTime
type CustomNullDate NullDate
+ type CustomNullNumeric NullNumeric
// Pointer values.
sValue := "abc"
@@ -1231,6 +1259,10 @@
var fNilPtr *float64
f2Value := 6.626
+ numValuePtr := big.NewRat(12345, 1e3)
+ var numNilPtr *big.Rat
+ num2ValuePtr := big.NewRat(12345, 1e4)
+
tValue := t1
var tNilPtr *time.Time
t2Value := t2
@@ -1309,9 +1341,23 @@
{desc: "decode NULL to []NullFloat64", proto: nullProto(), protoType: listType(floatType()), want: []NullFloat64(nil)},
// FLOAT64 ARRAY with []float64
{desc: "decode ARRAY<FLOAT64> to []float64", proto: listProto(floatProto(math.Inf(1)), floatProto(math.Inf(-1)), floatProto(3.1)), protoType: listType(floatType()), want: []float64{math.Inf(1), math.Inf(-1), 3.1}},
- // FLOAT64 ARRAY with []NullFloat64
+ // FLOAT64 ARRAY with []*float64
{desc: "decode ARRAY<FLOAT64> to []*float64", proto: listProto(floatProto(fValue), nullProto(), floatProto(f2Value)), protoType: listType(floatType()), want: []*float64{&fValue, nil, &f2Value}},
{desc: "decode NULL to []*float64", proto: nullProto(), protoType: listType(floatType()), want: []*float64(nil)},
+ // NUMERIC
+ {desc: "decode NUMERIC to big.Rat", proto: numericProto(numValuePtr), protoType: numericType(), want: *numValuePtr},
+ {desc: "decode NUMERIC to NullNumeric", proto: numericProto(numValuePtr), protoType: numericType(), want: NullNumeric{*numValuePtr, true}},
+ {desc: "decode NULL to NullNumeric", proto: nullProto(), protoType: numericType(), want: NullNumeric{}},
+ {desc: "decode NUMERIC to *big.Rat", proto: numericProto(numValuePtr), protoType: numericType(), want: numValuePtr},
+ {desc: "decode NULL to *big.Rat", proto: nullProto(), protoType: numericType(), want: numNilPtr},
+ // NUMERIC ARRAY with []NullNumeric
+ {desc: "decode ARRAY<Numeric> to []NullNumeric", proto: listProto(numericProto(numValuePtr), numericProto(num2ValuePtr), nullProto()), protoType: listType(numericType()), want: []NullNumeric{{*numValuePtr, true}, {*num2ValuePtr, true}, {}}},
+ {desc: "decode NULL to []NullNumeric", proto: nullProto(), protoType: listType(numericType()), want: []NullNumeric(nil)},
+ // NUMERIC ARRAY with []big.Rat
+ {desc: "decode ARRAY<NUMERIC> to []big.Rat", proto: listProto(numericProto(numValuePtr), numericProto(num2ValuePtr)), protoType: listType(numericType()), want: []big.Rat{*numValuePtr, *num2ValuePtr}},
+ // NUMERIC ARRAY with []*big.Rat
+ {desc: "decode ARRAY<NUMERIC> to []*big.Rat", proto: listProto(numericProto(numValuePtr), nullProto(), numericProto(num2ValuePtr)), protoType: listType(numericType()), want: []*big.Rat{numValuePtr, nil, num2ValuePtr}},
+ {desc: "decode NULL to []*big.Rat", proto: nullProto(), protoType: listType(numericType()), want: []*big.Rat(nil)},
// TIMESTAMP
{desc: "decode TIMESTAMP to time.Time", proto: timeProto(t1), protoType: timeType(), want: t1},
{desc: "decode TIMESTAMP to NullTime", proto: timeProto(t1), protoType: timeType(), want: NullTime{t1, true}},
@@ -1504,6 +1550,7 @@
{desc: "decode INT64 to CustomInt64", proto: intProto(-100), protoType: intType(), want: CustomInt64(-100)},
{desc: "decode BOOL to CustomBool", proto: boolProto(true), protoType: boolType(), want: CustomBool(true)},
{desc: "decode FLOAT64 to CustomFloat64", proto: floatProto(6.626), protoType: floatType(), want: CustomFloat64(6.626)},
+ {desc: "decode NUMERIC to CustomNumeric", proto: numericProto(numValuePtr), protoType: numericType(), want: CustomNumeric(*numValuePtr)},
{desc: "decode TIMESTAMP to CustomTimestamp", proto: timeProto(t1), protoType: timeType(), want: CustomTime(t1)},
{desc: "decode DATE to CustomDate", proto: dateProto(d1), protoType: dateType(), want: CustomDate(d1)},
@@ -1512,6 +1559,7 @@
{desc: "decode NULL to CustomInt64", proto: nullProto(), protoType: intType(), want: CustomInt64(0), wantErr: true},
{desc: "decode NULL to CustomBool", proto: nullProto(), protoType: boolType(), want: CustomBool(false), wantErr: true},
{desc: "decode NULL to CustomFloat64", proto: nullProto(), protoType: floatType(), want: CustomFloat64(0), wantErr: true},
+ {desc: "decode NULL to CustomNumeric", proto: nullProto(), protoType: numericType(), want: CustomNumeric{}, wantErr: true},
{desc: "decode NULL to CustomTime", proto: nullProto(), protoType: timeType(), want: CustomTime{}, wantErr: true},
{desc: "decode NULL to CustomDate", proto: nullProto(), protoType: dateType(), want: CustomDate{}, wantErr: true},
@@ -1519,6 +1567,7 @@
{desc: "decode INT64 to CustomNullInt64", proto: intProto(-100), protoType: intType(), want: CustomNullInt64{-100, true}},
{desc: "decode BOOL to CustomNullBool", proto: boolProto(true), protoType: boolType(), want: CustomNullBool{true, true}},
{desc: "decode FLOAT64 to CustomNullFloat64", proto: floatProto(6.626), protoType: floatType(), want: CustomNullFloat64{6.626, true}},
+ {desc: "decode NUMERIC to CustomNullNumeric", proto: numericProto(numValuePtr), protoType: numericType(), want: CustomNullNumeric{*numValuePtr, true}},
{desc: "decode TIMESTAMP to CustomNullTime", proto: timeProto(t1), protoType: timeType(), want: CustomNullTime{t1, true}},
{desc: "decode DATE to CustomNullDate", proto: dateProto(d1), protoType: dateType(), want: CustomNullDate{d1, true}},
@@ -1526,6 +1575,7 @@
{desc: "decode NULL to CustomNullInt64", proto: nullProto(), protoType: intType(), want: CustomNullInt64{}},
{desc: "decode NULL to CustomNullBool", proto: nullProto(), protoType: boolType(), want: CustomNullBool{}},
{desc: "decode NULL to CustomNullFloat64", proto: nullProto(), protoType: floatType(), want: CustomNullFloat64{}},
+ {desc: "decode NULL to CustomNullNumeric", proto: nullProto(), protoType: numericType(), want: CustomNullNumeric{}},
{desc: "decode NULL to CustomNullTime", proto: nullProto(), protoType: timeType(), want: CustomNullTime{}},
{desc: "decode NULL to CustomNullDate", proto: nullProto(), protoType: dateType(), want: CustomNullDate{}},
@@ -1556,6 +1606,12 @@
{desc: "decode ARRAY<FLOAT64> to []CustomFloat64", proto: listProto(floatProto(3.14), floatProto(6.626)), protoType: listType(floatType()), want: []CustomFloat64{3.14, 6.626}},
{desc: "decode NULL to []CustomNullFloat64", proto: nullProto(), protoType: listType(floatType()), want: []CustomNullFloat64(nil)},
{desc: "decode ARRAY<FLOAT64> to []CustomNullFloat64", proto: listProto(floatProto(3.14), nullProto(), floatProto(6.626)), protoType: listType(floatType()), want: []CustomNullFloat64{{3.14, true}, {}, {6.626, true}}},
+ // NUMERIC ARRAY
+ {desc: "decode NULL to []CustomNumeric", proto: nullProto(), protoType: listType(numericType()), want: []CustomNumeric(nil)},
+ {desc: "decode ARRAY<NUMERIC> with NULL values to []CustomNumeric", proto: listProto(numericProto(numValuePtr), nullProto(), numericProto(num2ValuePtr)), protoType: listType(numericType()), want: []CustomNumeric{}, wantErr: true},
+ {desc: "decode ARRAY<NUMERIC> to []CustomNumeric", proto: listProto(numericProto(numValuePtr), numericProto(num2ValuePtr)), protoType: listType(numericType()), want: []CustomNumeric{CustomNumeric(*numValuePtr), CustomNumeric(*num2ValuePtr)}},
+ {desc: "decode NULL to []CustomNullNumeric", proto: nullProto(), protoType: listType(numericType()), want: []CustomNullNumeric(nil)},
+ {desc: "decode ARRAY<NUMERIC> to []CustomNullNumeric", proto: listProto(numericProto(numValuePtr), nullProto(), numericProto(num2ValuePtr)), protoType: listType(numericType()), want: []CustomNullNumeric{{*numValuePtr, true}, {}, {*num2ValuePtr, true}}},
// TIME ARRAY
{desc: "decode NULL to []CustomTime", proto: nullProto(), protoType: listType(timeType()), want: []CustomTime(nil)},
{desc: "decode ARRAY<TIMESTAMP> with NULL values to []CustomTime", proto: listProto(timeProto(t1), nullProto(), timeProto(t2)), protoType: listType(timeType()), want: []CustomTime{}, wantErr: true},
@@ -1590,7 +1646,7 @@
continue
}
got := reflect.Indirect(gotp).Interface()
- if !testutil.Equal(got, test.want, cmp.AllowUnexported(CustomTime{}, CustomDate{}, Row{})) {
+ if !testutil.Equal(got, test.want, cmp.AllowUnexported(CustomNumeric{}, CustomTime{}, CustomDate{}, Row{}, big.Rat{}, big.Int{})) {
t.Errorf("%s: unexpected decoding result - got %v (%T), want %v (%T)", test.desc, got, got, test.want, test.want)
}
}
@@ -1622,6 +1678,7 @@
type CustomFloat64 float64
type CustomTime time.Time
type CustomDate civil.Date
+ type CustomNumeric big.Rat
type CustomNullString NullString
type CustomNullInt64 NullInt64
@@ -1629,6 +1686,7 @@
type CustomNullFloat64 NullFloat64
type CustomNullTime NullTime
type CustomNullDate NullDate
+ type CustomNullNumeric NullNumeric
type StringEmbedded struct {
string
@@ -1655,6 +1713,9 @@
{NullFloat64{}, spannerTypeNullFloat64},
{NullTime{}, spannerTypeNullTime},
{NullDate{}, spannerTypeNullDate},
+ {*big.NewRat(1234, 1000), spannerTypeNonNullNumeric},
+ {big.Rat{}, spannerTypeNonNullNumeric},
+ {NullNumeric{}, spannerTypeNullNumeric},
{[]string{"foo", "bar"}, spannerTypeArrayOfNonNullString},
{[][]byte{{1, 2, 3}, {3, 2, 1}}, spannerTypeArrayOfByteArray},
@@ -1670,6 +1731,9 @@
{[]NullFloat64{}, spannerTypeArrayOfNullFloat64},
{[]NullTime{}, spannerTypeArrayOfNullTime},
{[]NullDate{}, spannerTypeArrayOfNullDate},
+ {[]big.Rat{}, spannerTypeArrayOfNonNullNumeric},
+ {[]big.Rat{*big.NewRat(1234, 1000), *big.NewRat(1234, 100)}, spannerTypeArrayOfNonNullNumeric},
+ {[]NullNumeric{}, spannerTypeArrayOfNullNumeric},
{CustomString("foo"), spannerTypeNonNullString},
{CustomInt64(-100), spannerTypeNonNullInt64},
@@ -1677,6 +1741,7 @@
{CustomFloat64(3.141592), spannerTypeNonNullFloat64},
{CustomTime(time.Now()), spannerTypeNonNullTime},
{CustomDate(civil.DateOf(time.Now())), spannerTypeNonNullDate},
+ {CustomNumeric(*big.NewRat(1234, 1000)), spannerTypeNonNullNumeric},
{[]CustomString{}, spannerTypeArrayOfNonNullString},
{[]CustomInt64{}, spannerTypeArrayOfNonNullInt64},
@@ -1684,6 +1749,7 @@
{[]CustomFloat64{}, spannerTypeArrayOfNonNullFloat64},
{[]CustomTime{}, spannerTypeArrayOfNonNullTime},
{[]CustomDate{}, spannerTypeArrayOfNonNullDate},
+ {[]CustomNumeric{}, spannerTypeArrayOfNonNullNumeric},
{CustomNullString{}, spannerTypeNullString},
{CustomNullInt64{}, spannerTypeNullInt64},
@@ -1691,6 +1757,7 @@
{CustomNullFloat64{}, spannerTypeNullFloat64},
{CustomNullTime{}, spannerTypeNullTime},
{CustomNullDate{}, spannerTypeNullDate},
+ {CustomNullNumeric{}, spannerTypeNullNumeric},
{[]CustomNullString{}, spannerTypeArrayOfNullString},
{[]CustomNullInt64{}, spannerTypeArrayOfNullInt64},
@@ -1698,14 +1765,22 @@
{[]CustomNullFloat64{}, spannerTypeArrayOfNullFloat64},
{[]CustomNullTime{}, spannerTypeArrayOfNullTime},
{[]CustomNullDate{}, spannerTypeArrayOfNullDate},
+ {[]CustomNullNumeric{}, spannerTypeArrayOfNullNumeric},
{StringEmbedded{}, spannerTypeUnknown},
{NullStringEmbedded{}, spannerTypeUnknown},
} {
+ // Pass a pointer to the original value.
gotp := reflect.New(reflect.TypeOf(test.in))
- got := getDecodableSpannerType(gotp.Interface())
+ got := getDecodableSpannerType(gotp.Interface(), true)
if got != test.want {
- t.Errorf("%d: unexpected decodable type - got %v, want %v", i, got, test.want)
+ t.Errorf("%d: unexpected decodable type from a pointer - got %v, want %v", i, got, test.want)
+ }
+
+ // Pass the original value.
+ got = getDecodableSpannerType(test.in, false)
+ if got != test.want {
+ t.Errorf("%d: unexpected decodable type from a value - got %v, want %v", i, got, test.want)
}
}
}
@@ -2331,6 +2406,15 @@
{input: NullDate{}, expect: "null"},
},
},
+ {
+ "NullNumeric",
+ []testcase{
+ {input: NullNumeric{*big.NewRat(1234123456789, 1e9), true}, expect: `"1234.123456789"`},
+ {input: &NullNumeric{*big.NewRat(1234123456789, 1e9), true}, expect: `"1234.123456789"`},
+ {input: &NullNumeric{*big.NewRat(1234123456789, 1e9), false}, expect: "null"},
+ {input: NullNumeric{}, expect: "null"},
+ },
+ },
} {
t.Run(test.name, func(t *testing.T) {
for _, tc := range test.cases {
@@ -2419,6 +2503,16 @@
{input: []byte(`"hello`), got: NullDate{}, isNull: true, expect: nullString, expectError: true},
},
},
+ {
+ "NullNumeric",
+ []testcase{
+ {input: []byte(`"1234.123456789"`), got: NullNumeric{}, isNull: false, expect: "1234.123456789", expectError: false},
+ {input: []byte("null"), got: NullNumeric{}, isNull: true, expect: nullString, expectError: false},
+ {input: nil, got: NullNumeric{}, isNull: true, expect: nullString, expectError: true},
+ {input: []byte(""), got: NullNumeric{}, isNull: true, expect: nullString, expectError: true},
+ {input: []byte(`"1234.123456789`), got: NullNumeric{}, isNull: true, expect: nullString, expectError: true},
+ },
+ },
} {
t.Run(test.name, func(t *testing.T) {
for _, tc := range test.cases {
@@ -2441,6 +2535,9 @@
case NullDate:
err := json.Unmarshal(tc.input, &v)
expectUnmarshalNullableTypes(t, err, v, tc.isNull, tc.expect, tc.expectError)
+ case NullNumeric:
+ err := json.Unmarshal(tc.input, &v)
+ expectUnmarshalNullableTypes(t, err, v, tc.isNull, tc.expect, tc.expectError)
default:
t.Fatalf("Unknown type: %T", v)
}