feat: support NUMERIC as key (#3627)
Co-authored-by: skuruppu <skuruppu@google.com>
diff --git a/spanner/client_test.go b/spanner/client_test.go
index 4fc6219..cd29a79 100644
--- a/spanner/client_test.go
+++ b/spanner/client_test.go
@@ -20,6 +20,7 @@
"context"
"fmt"
"io"
+ "math/big"
"os"
"strings"
"testing"
@@ -2267,3 +2268,27 @@
t.Errorf("Span status mismatch\nGot: %v\nWant: %v", s.Code, codes.InvalidArgument)
}
}
+
+func TestClient_Single_Read_WithNumericKey(t *testing.T) {
+ t.Parallel()
+
+ _, client, teardown := setupMockedTestServer(t)
+ defer teardown()
+ ctx := context.Background()
+ iter := client.Single().Read(ctx, "Albums", KeySets(Key{*big.NewRat(1, 1)}), []string{"SingerId", "AlbumId", "AlbumTitle"})
+ defer iter.Stop()
+ rowCount := int64(0)
+ for {
+ _, err := iter.Next()
+ if err == iterator.Done {
+ break
+ }
+ if err != nil {
+ t.Fatal(err)
+ }
+ rowCount++
+ }
+ if rowCount != SelectSingerIDAlbumIDAlbumTitleFromAlbumsRowCount {
+ t.Fatalf("row count mismatch\nGot: %v\nWant: %v", rowCount, SelectSingerIDAlbumIDAlbumTitleFromAlbumsRowCount)
+ }
+}
diff --git a/spanner/key.go b/spanner/key.go
index d981901..aa876ec 100644
--- a/spanner/key.go
+++ b/spanner/key.go
@@ -19,6 +19,7 @@
import (
"bytes"
"fmt"
+ "math/big"
"time"
"cloud.google.com/go/civil"
@@ -84,7 +85,7 @@
pb, _, err = encodeValue(int64(v))
case float32:
pb, _, err = encodeValue(float64(v))
- case int64, float64, NullInt64, NullFloat64, bool, NullBool, []byte, string, NullString, time.Time, civil.Date, NullTime, NullDate:
+ case int64, float64, NullInt64, NullFloat64, bool, NullBool, []byte, string, NullString, time.Time, civil.Date, NullTime, NullDate, big.Rat, NullNumeric:
pb, _, err = encodeValue(v)
case Encoder:
part, err = v.EncodeSpanner()
@@ -150,7 +151,7 @@
} else {
fmt.Fprint(b, nullString)
}
- case NullInt64, NullFloat64, NullBool:
+ case NullInt64, NullFloat64, NullBool, NullNumeric:
// The above types implement fmt.Stringer.
fmt.Fprintf(b, "%s", v)
case NullString, NullDate, NullTime:
@@ -164,6 +165,8 @@
fmt.Fprintf(b, "%q", v)
case time.Time:
fmt.Fprintf(b, "%q", v.Format(time.RFC3339Nano))
+ case big.Rat:
+ fmt.Fprintf(b, "%v", NumericString(&v))
case Encoder:
var err error
part, err = v.EncodeSpanner()
diff --git a/spanner/key_test.go b/spanner/key_test.go
index 0d5723e..da103fd 100644
--- a/spanner/key_test.go
+++ b/spanner/key_test.go
@@ -18,6 +18,7 @@
import (
"errors"
+ "math/big"
"testing"
"time"
@@ -133,6 +134,11 @@
wantStr: `("2016-11-15")`,
},
{
+ k: Key{*big.NewRat(1, 1)},
+ wantProto: listValueProto(stringProto("1.000000000")),
+ wantStr: `(1.000000000)`,
+ },
+ {
k: Key{[]byte("value")},
wantProto: listValueProto(bytesProto([]byte("value"))),
wantStr: `("value")`,
@@ -204,6 +210,16 @@
wantStr: `(1,<null>,"value",1.5,true)`,
},
{
+ k: Key{NullNumeric{*big.NewRat(2, 3), true}},
+ wantProto: listValueProto(stringProto("0.666666667")),
+ wantStr: "(0.666666667)",
+ },
+ {
+ k: Key{NullNumeric{big.Rat{}, false}},
+ wantProto: listValueProto(nullProto()),
+ wantStr: "(<null>)",
+ },
+ {
k: Key{customKeyToString("value")},
wantProto: listValueProto(stringProto("value")),
wantStr: `("value")`,