| // Copyright 2017 Google LLC |
| // |
| // Licensed under the Apache License, Version 2.0 (the "License"); |
| // you may not use this file except in compliance with the License. |
| // You may obtain a copy of the License at |
| // |
| // http://www.apache.org/licenses/LICENSE-2.0 |
| // |
| // Unless required by applicable law or agreed to in writing, software |
| // distributed under the License is distributed on an "AS IS" BASIS, |
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| // See the License for the specific language governing permissions and |
| // limitations under the License. |
| |
| package firestore |
| |
| // A simple mock server. |
| |
| import ( |
| "fmt" |
| "sort" |
| "strings" |
| |
| "cloud.google.com/go/internal/testutil" |
| "github.com/golang/protobuf/proto" |
| "github.com/golang/protobuf/ptypes/empty" |
| "golang.org/x/net/context" |
| pb "google.golang.org/genproto/googleapis/firestore/v1beta1" |
| "google.golang.org/grpc/codes" |
| "google.golang.org/grpc/status" |
| ) |
| |
| type mockServer struct { |
| pb.FirestoreServer |
| |
| Addr string |
| |
| reqItems []reqItem |
| resps []interface{} |
| } |
| |
| type reqItem struct { |
| wantReq proto.Message |
| adjust func(gotReq proto.Message) |
| } |
| |
| func newMockServer() (*mockServer, error) { |
| srv, err := testutil.NewServer() |
| if err != nil { |
| return nil, err |
| } |
| mock := &mockServer{Addr: srv.Addr} |
| pb.RegisterFirestoreServer(srv.Gsrv, mock) |
| srv.Start() |
| return mock, nil |
| } |
| |
| // addRPC adds a (request, response) pair to the server's list of expected |
| // interactions. The server will compare the incoming request with wantReq |
| // using proto.Equal. The response can be a message or an error. |
| // |
| // For the Listen RPC, resp should be a []interface{}, where each element |
| // is either ListenResponse or an error. |
| // |
| // Passing nil for wantReq disables the request check. |
| func (s *mockServer) addRPC(wantReq proto.Message, resp interface{}) { |
| s.addRPCAdjust(wantReq, resp, nil) |
| } |
| |
| // addRPCAdjust is like addRPC, but accepts a function that can be used |
| // to tweak the requests before comparison, for example to adjust for |
| // randomness. |
| func (s *mockServer) addRPCAdjust(wantReq proto.Message, resp interface{}, adjust func(proto.Message)) { |
| s.reqItems = append(s.reqItems, reqItem{wantReq, adjust}) |
| s.resps = append(s.resps, resp) |
| } |
| |
| // popRPC compares the request with the next expected (request, response) pair. |
| // It returns the response, or an error if the request doesn't match what |
| // was expected or there are no expected rpcs. |
| func (s *mockServer) popRPC(gotReq proto.Message) (interface{}, error) { |
| if len(s.reqItems) == 0 { |
| panic("out of RPCs") |
| } |
| ri := s.reqItems[0] |
| s.reqItems = s.reqItems[1:] |
| if ri.wantReq != nil { |
| if ri.adjust != nil { |
| ri.adjust(gotReq) |
| } |
| |
| // Sort FieldTransforms by FieldPath, since slice order is undefined and proto.Equal |
| // is strict about order. |
| switch gotReqTyped := gotReq.(type) { |
| case *pb.CommitRequest: |
| for _, w := range gotReqTyped.Writes { |
| switch opTyped := w.Operation.(type) { |
| case *pb.Write_Transform: |
| sort.Sort(ByFieldPath(opTyped.Transform.FieldTransforms)) |
| } |
| } |
| } |
| |
| if !proto.Equal(gotReq, ri.wantReq) { |
| return nil, fmt.Errorf("mockServer: bad request\ngot: %T\n%s\nwant: %T\n%s", |
| gotReq, proto.MarshalTextString(gotReq), |
| ri.wantReq, proto.MarshalTextString(ri.wantReq)) |
| } |
| } |
| resp := s.resps[0] |
| s.resps = s.resps[1:] |
| if err, ok := resp.(error); ok { |
| return nil, err |
| } |
| return resp, nil |
| } |
| |
| func (a ByFieldPath) Len() int { return len(a) } |
| func (a ByFieldPath) Swap(i, j int) { a[i], a[j] = a[j], a[i] } |
| func (a ByFieldPath) Less(i, j int) bool { return a[i].FieldPath < a[j].FieldPath } |
| |
| type ByFieldPath []*pb.DocumentTransform_FieldTransform |
| |
| func (s *mockServer) reset() { |
| s.reqItems = nil |
| s.resps = nil |
| } |
| |
| func (s *mockServer) GetDocument(_ context.Context, req *pb.GetDocumentRequest) (*pb.Document, error) { |
| res, err := s.popRPC(req) |
| if err != nil { |
| return nil, err |
| } |
| return res.(*pb.Document), nil |
| } |
| |
| func (s *mockServer) Commit(_ context.Context, req *pb.CommitRequest) (*pb.CommitResponse, error) { |
| res, err := s.popRPC(req) |
| if err != nil { |
| return nil, err |
| } |
| return res.(*pb.CommitResponse), nil |
| } |
| |
| func (s *mockServer) BatchGetDocuments(req *pb.BatchGetDocumentsRequest, bs pb.Firestore_BatchGetDocumentsServer) error { |
| res, err := s.popRPC(req) |
| if err != nil { |
| return err |
| } |
| responses := res.([]interface{}) |
| for _, res := range responses { |
| switch res := res.(type) { |
| case *pb.BatchGetDocumentsResponse: |
| if err := bs.Send(res); err != nil { |
| return err |
| } |
| case error: |
| return res |
| default: |
| panic(fmt.Sprintf("bad response type in BatchGetDocuments: %+v", res)) |
| } |
| } |
| return nil |
| } |
| |
| func (s *mockServer) RunQuery(req *pb.RunQueryRequest, qs pb.Firestore_RunQueryServer) error { |
| res, err := s.popRPC(req) |
| if err != nil { |
| return err |
| } |
| responses := res.([]interface{}) |
| for _, res := range responses { |
| switch res := res.(type) { |
| case *pb.RunQueryResponse: |
| if err := qs.Send(res); err != nil { |
| return err |
| } |
| case error: |
| return res |
| default: |
| panic(fmt.Sprintf("bad response type in RunQuery: %+v", res)) |
| } |
| } |
| return nil |
| } |
| |
| func (s *mockServer) BeginTransaction(_ context.Context, req *pb.BeginTransactionRequest) (*pb.BeginTransactionResponse, error) { |
| res, err := s.popRPC(req) |
| if err != nil { |
| return nil, err |
| } |
| return res.(*pb.BeginTransactionResponse), nil |
| } |
| |
| func (s *mockServer) Rollback(_ context.Context, req *pb.RollbackRequest) (*empty.Empty, error) { |
| res, err := s.popRPC(req) |
| if err != nil { |
| return nil, err |
| } |
| return res.(*empty.Empty), nil |
| } |
| |
| func (s *mockServer) Listen(stream pb.Firestore_ListenServer) error { |
| req, err := stream.Recv() |
| if err != nil { |
| return err |
| } |
| responses, err := s.popRPC(req) |
| if err != nil { |
| if status.Code(err) == codes.Unknown && strings.Contains(err.Error(), "mockServer") { |
| // The stream will retry on Unknown, but we don't want that to happen if |
| // the error comes from us. |
| panic(err) |
| } |
| return err |
| } |
| for _, res := range responses.([]interface{}) { |
| if err, ok := res.(error); ok { |
| return err |
| } |
| if err := stream.Send(res.(*pb.ListenResponse)); err != nil { |
| return err |
| } |
| } |
| return nil |
| } |