spanner: disallow untyped nil in param binding
Change-Id: Ic0b88fc3a18db17fd6a2e635bdacfc0c092c644b
Reviewed-on: https://code-review.googlesource.com/13830
Reviewed-by: kokoro <noreply+kokoro@google.com>
Reviewed-by: Vikas Kedia <vikask@google.com>
diff --git a/spanner/statement.go b/spanner/statement.go
index 8e422b0..d04c200 100644
--- a/spanner/statement.go
+++ b/spanner/statement.go
@@ -17,6 +17,7 @@
package spanner
import (
+ "errors"
"fmt"
proto3 "github.com/golang/protobuf/ptypes/struct"
@@ -54,12 +55,17 @@
}
se, ok := toSpannerError(err).(*Error)
if !ok {
- return spannerErrorf(codes.InvalidArgument, "failed to bind query parameter(name: %q, value: %q), error = <%v>", k, v, err)
+ return spannerErrorf(codes.InvalidArgument, "failed to bind query parameter(name: %q, value: %v), error = <%v>", k, v, err)
}
- se.decorate(fmt.Sprintf("failed to bind query parameter(name: %q, value: %q)", k, v))
+ se.decorate(fmt.Sprintf("failed to bind query parameter(name: %q, value: %v)", k, v))
return se
}
+var (
+ errNilParam = errors.New("use T(nil), not nil")
+ errNoType = errors.New("no type information")
+)
+
// bindParams binds parameters in a Statement to a sppb.ExecuteSqlRequest.
func (s *Statement) bindParams(r *sppb.ExecuteSqlRequest) error {
r.Params = &proto3.Struct{
@@ -67,10 +73,16 @@
}
r.ParamTypes = map[string]*sppb.Type{}
for k, v := range s.Params {
+ if v == nil {
+ return errBindParam(k, v, errNilParam)
+ }
val, t, err := encodeValue(v)
if err != nil {
return errBindParam(k, v, err)
}
+ if t == nil { // should not happen, because of nil check above
+ return errBindParam(k, v, errNoType)
+ }
r.Params.Fields[k] = val
r.ParamTypes[k] = t
}
diff --git a/spanner/statement_test.go b/spanner/statement_test.go
index 9854f5d..4607c80 100644
--- a/spanner/statement_test.go
+++ b/spanner/statement_test.go
@@ -60,11 +60,24 @@
}
// Verify type error reporting.
- st.Params["var"] = struct{}{}
- wantErr := errBindParam("var", struct{}{}, errEncoderUnsupportedType(struct{}{}))
- var got sppb.ExecuteSqlRequest
- if err := st.bindParams(&got); !reflect.DeepEqual(err, wantErr) {
- t.Errorf("got unexpected error: %v, want: %v", err, wantErr)
+ for _, test := range []struct {
+ val interface{}
+ wantErr error
+ }{
+ {
+ struct{}{},
+ errBindParam("var", struct{}{}, errEncoderUnsupportedType(struct{}{})),
+ },
+ {
+ nil,
+ errBindParam("var", nil, errNilParam),
+ },
+ } {
+ st.Params["var"] = test.val
+ var got sppb.ExecuteSqlRequest
+ if err := st.bindParams(&got); !reflect.DeepEqual(err, test.wantErr) {
+ t.Errorf("value %#v:\ngot: %v\nwant: %v", test.val, err, test.wantErr)
+ }
}
}