| // Copyright 2017 Google LLC |
| // |
| // Licensed under the Apache License, Version 2.0 (the "License"); |
| // you may not use this file except in compliance with the License. |
| // You may obtain a copy of the License at |
| // |
| // http://www.apache.org/licenses/LICENSE-2.0 |
| // |
| // Unless required by applicable law or agreed to in writing, software |
| // distributed under the License is distributed on an "AS IS" BASIS, |
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| // See the License for the specific language governing permissions and |
| // limitations under the License. |
| |
| package firestore |
| |
| import ( |
| "context" |
| "testing" |
| |
| "github.com/golang/protobuf/ptypes/empty" |
| "google.golang.org/api/iterator" |
| pb "google.golang.org/genproto/googleapis/firestore/v1" |
| "google.golang.org/grpc/codes" |
| "google.golang.org/grpc/status" |
| ) |
| |
| func TestRunTransaction(t *testing.T) { |
| ctx := context.Background() |
| c, srv, cleanup := newMock(t) |
| defer cleanup() |
| |
| const db = "projects/projectID/databases/(default)" |
| tid := []byte{1} |
| |
| beginReq := &pb.BeginTransactionRequest{Database: db} |
| beginRes := &pb.BeginTransactionResponse{Transaction: tid} |
| commitReq := &pb.CommitRequest{Database: db, Transaction: tid} |
| // Empty transaction. |
| srv.addRPC(beginReq, beginRes) |
| srv.addRPC(commitReq, &pb.CommitResponse{CommitTime: aTimestamp}) |
| err := c.RunTransaction(ctx, func(context.Context, *Transaction) error { return nil }) |
| if err != nil { |
| t.Fatal(err) |
| } |
| |
| // Transaction with read and write. |
| srv.reset() |
| srv.addRPC(beginReq, beginRes) |
| aDoc := &pb.Document{ |
| Name: db + "/documents/C/a", |
| CreateTime: aTimestamp, |
| UpdateTime: aTimestamp2, |
| Fields: map[string]*pb.Value{"count": intval(1)}, |
| } |
| srv.addRPC( |
| &pb.BatchGetDocumentsRequest{ |
| Database: c.path(), |
| Documents: []string{db + "/documents/C/a"}, |
| ConsistencySelector: &pb.BatchGetDocumentsRequest_Transaction{tid}, |
| }, []interface{}{ |
| &pb.BatchGetDocumentsResponse{ |
| Result: &pb.BatchGetDocumentsResponse_Found{aDoc}, |
| ReadTime: aTimestamp2, |
| }, |
| }) |
| aDoc2 := &pb.Document{ |
| Name: aDoc.Name, |
| Fields: map[string]*pb.Value{"count": intval(2)}, |
| } |
| srv.addRPC( |
| &pb.CommitRequest{ |
| Database: db, |
| Transaction: tid, |
| Writes: []*pb.Write{{ |
| Operation: &pb.Write_Update{aDoc2}, |
| UpdateMask: &pb.DocumentMask{FieldPaths: []string{"count"}}, |
| CurrentDocument: &pb.Precondition{ |
| ConditionType: &pb.Precondition_Exists{true}, |
| }, |
| }}, |
| }, |
| &pb.CommitResponse{CommitTime: aTimestamp3}, |
| ) |
| err = c.RunTransaction(ctx, func(_ context.Context, tx *Transaction) error { |
| docref := c.Collection("C").Doc("a") |
| doc, err := tx.Get(docref) |
| if err != nil { |
| return err |
| } |
| count, err := doc.DataAt("count") |
| if err != nil { |
| return err |
| } |
| return tx.Update(docref, []Update{{Path: "count", Value: count.(int64) + 1}}) |
| }) |
| if err != nil { |
| t.Fatal(err) |
| } |
| |
| // Query |
| srv.reset() |
| srv.addRPC(beginReq, beginRes) |
| srv.addRPC( |
| &pb.RunQueryRequest{ |
| Parent: db + "/documents", |
| QueryType: &pb.RunQueryRequest_StructuredQuery{ |
| &pb.StructuredQuery{ |
| From: []*pb.StructuredQuery_CollectionSelector{{CollectionId: "C"}}, |
| }, |
| }, |
| ConsistencySelector: &pb.RunQueryRequest_Transaction{tid}, |
| }, |
| []interface{}{}, |
| ) |
| srv.addRPC(commitReq, &pb.CommitResponse{CommitTime: aTimestamp3}) |
| err = c.RunTransaction(ctx, func(_ context.Context, tx *Transaction) error { |
| it := tx.Documents(c.Collection("C")) |
| defer it.Stop() |
| _, err := it.Next() |
| if err != iterator.Done { |
| return err |
| } |
| return nil |
| }) |
| if err != nil { |
| t.Fatal(err) |
| } |
| |
| // Retry entire transaction. |
| srv.reset() |
| srv.addRPC(beginReq, beginRes) |
| srv.addRPC(commitReq, status.Errorf(codes.Aborted, "")) |
| srv.addRPC( |
| &pb.BeginTransactionRequest{ |
| Database: db, |
| Options: &pb.TransactionOptions{ |
| Mode: &pb.TransactionOptions_ReadWrite_{ |
| &pb.TransactionOptions_ReadWrite{RetryTransaction: tid}, |
| }, |
| }, |
| }, |
| beginRes, |
| ) |
| srv.addRPC(commitReq, &pb.CommitResponse{CommitTime: aTimestamp}) |
| err = c.RunTransaction(ctx, func(_ context.Context, tx *Transaction) error { return nil }) |
| if err != nil { |
| t.Fatal(err) |
| } |
| } |
| |
| func TestTransactionErrors(t *testing.T) { |
| ctx := context.Background() |
| const db = "projects/projectID/databases/(default)" |
| c, srv, cleanup := newMock(t) |
| defer cleanup() |
| |
| var ( |
| tid = []byte{1} |
| unknownErr = status.Errorf(codes.Unknown, "so sad") |
| beginReq = &pb.BeginTransactionRequest{ |
| Database: db, |
| } |
| beginRes = &pb.BeginTransactionResponse{Transaction: tid} |
| getReq = &pb.BatchGetDocumentsRequest{ |
| Database: c.path(), |
| Documents: []string{db + "/documents/C/a"}, |
| ConsistencySelector: &pb.BatchGetDocumentsRequest_Transaction{tid}, |
| } |
| rollbackReq = &pb.RollbackRequest{Database: db, Transaction: tid} |
| commitReq = &pb.CommitRequest{Database: db, Transaction: tid} |
| ) |
| |
| // BeginTransaction has a permanent error. |
| srv.addRPC(beginReq, unknownErr) |
| err := c.RunTransaction(ctx, func(context.Context, *Transaction) error { return nil }) |
| if status.Code(err) != codes.Unknown { |
| t.Errorf("got <%v>, want Unknown", err) |
| } |
| |
| // Get has a permanent error. |
| get := func(_ context.Context, tx *Transaction) error { |
| _, err := tx.Get(c.Doc("C/a")) |
| return err |
| } |
| srv.reset() |
| srv.addRPC(beginReq, beginRes) |
| srv.addRPC(getReq, unknownErr) |
| srv.addRPC(rollbackReq, &empty.Empty{}) |
| err = c.RunTransaction(ctx, get) |
| if status.Code(err) != codes.Unknown { |
| t.Errorf("got <%v>, want Unknown", err) |
| } |
| |
| // Get has a permanent error, but the rollback fails. We still |
| // return Get's error. |
| srv.reset() |
| srv.addRPC(beginReq, beginRes) |
| srv.addRPC(getReq, unknownErr) |
| srv.addRPC(rollbackReq, status.Errorf(codes.FailedPrecondition, "")) |
| err = c.RunTransaction(ctx, get) |
| if status.Code(err) != codes.Unknown { |
| t.Errorf("got <%v>, want Unknown", err) |
| } |
| |
| // Commit has a permanent error. |
| srv.reset() |
| srv.addRPC(beginReq, beginRes) |
| srv.addRPC(getReq, []interface{}{ |
| &pb.BatchGetDocumentsResponse{ |
| Result: &pb.BatchGetDocumentsResponse_Found{&pb.Document{ |
| Name: "projects/projectID/databases/(default)/documents/C/a", |
| CreateTime: aTimestamp, |
| UpdateTime: aTimestamp2, |
| }}, |
| ReadTime: aTimestamp2, |
| }, |
| }) |
| srv.addRPC(commitReq, unknownErr) |
| err = c.RunTransaction(ctx, get) |
| if status.Code(err) != codes.Unknown { |
| t.Errorf("got <%v>, want Unknown", err) |
| } |
| |
| // Read after write. |
| srv.reset() |
| srv.addRPC(beginReq, beginRes) |
| srv.addRPC(rollbackReq, &empty.Empty{}) |
| err = c.RunTransaction(ctx, func(_ context.Context, tx *Transaction) error { |
| if err := tx.Delete(c.Doc("C/a")); err != nil { |
| return err |
| } |
| if _, err := tx.Get(c.Doc("C/a")); err != nil { |
| return err |
| } |
| return nil |
| }) |
| if err != errReadAfterWrite { |
| t.Errorf("got <%v>, want <%v>", err, errReadAfterWrite) |
| } |
| |
| // Read after write, with query. |
| srv.reset() |
| srv.addRPC(beginReq, beginRes) |
| srv.addRPC(rollbackReq, &empty.Empty{}) |
| err = c.RunTransaction(ctx, func(_ context.Context, tx *Transaction) error { |
| if err := tx.Delete(c.Doc("C/a")); err != nil { |
| return err |
| } |
| it := tx.Documents(c.Collection("C").Select("x")) |
| defer it.Stop() |
| if _, err := it.Next(); err != iterator.Done { |
| return err |
| } |
| return nil |
| }) |
| if err != errReadAfterWrite { |
| t.Errorf("got <%v>, want <%v>", err, errReadAfterWrite) |
| } |
| |
| // Read after write, with query and GetAll. |
| srv.reset() |
| srv.addRPC(beginReq, beginRes) |
| srv.addRPC(rollbackReq, &empty.Empty{}) |
| err = c.RunTransaction(ctx, func(_ context.Context, tx *Transaction) error { |
| if err := tx.Delete(c.Doc("C/a")); err != nil { |
| return err |
| } |
| _, err := tx.Documents(c.Collection("C").Select("x")).GetAll() |
| return err |
| }) |
| if err != errReadAfterWrite { |
| t.Errorf("got <%v>, want <%v>", err, errReadAfterWrite) |
| } |
| |
| // Read after write fails even if the user ignores the read's error. |
| srv.reset() |
| srv.addRPC(beginReq, beginRes) |
| srv.addRPC(rollbackReq, &empty.Empty{}) |
| err = c.RunTransaction(ctx, func(_ context.Context, tx *Transaction) error { |
| if err := tx.Delete(c.Doc("C/a")); err != nil { |
| return err |
| } |
| if _, err := tx.Get(c.Doc("C/a")); err != nil { |
| return err |
| } |
| return nil |
| }) |
| if err != errReadAfterWrite { |
| t.Errorf("got <%v>, want <%v>", err, errReadAfterWrite) |
| } |
| |
| // Write in read-only transaction. |
| srv.reset() |
| srv.addRPC( |
| &pb.BeginTransactionRequest{ |
| Database: db, |
| Options: &pb.TransactionOptions{ |
| Mode: &pb.TransactionOptions_ReadOnly_{&pb.TransactionOptions_ReadOnly{}}, |
| }, |
| }, |
| beginRes, |
| ) |
| srv.addRPC(rollbackReq, &empty.Empty{}) |
| err = c.RunTransaction(ctx, func(_ context.Context, tx *Transaction) error { |
| return tx.Delete(c.Doc("C/a")) |
| }, ReadOnly) |
| if err != errWriteReadOnly { |
| t.Errorf("got <%v>, want <%v>", err, errWriteReadOnly) |
| } |
| |
| // Too many retries. |
| srv.reset() |
| srv.addRPC(beginReq, beginRes) |
| srv.addRPC(commitReq, status.Errorf(codes.Aborted, "")) |
| srv.addRPC( |
| &pb.BeginTransactionRequest{ |
| Database: db, |
| Options: &pb.TransactionOptions{ |
| Mode: &pb.TransactionOptions_ReadWrite_{ |
| &pb.TransactionOptions_ReadWrite{RetryTransaction: tid}, |
| }, |
| }, |
| }, |
| beginRes, |
| ) |
| srv.addRPC(commitReq, status.Errorf(codes.Aborted, "")) |
| srv.addRPC(rollbackReq, &empty.Empty{}) |
| err = c.RunTransaction(ctx, func(context.Context, *Transaction) error { return nil }, |
| MaxAttempts(2)) |
| if status.Code(err) != codes.Aborted { |
| t.Errorf("got <%v>, want Aborted", err) |
| } |
| |
| // Nested transaction. |
| srv.reset() |
| srv.addRPC(beginReq, beginRes) |
| srv.addRPC(rollbackReq, &empty.Empty{}) |
| err = c.RunTransaction(ctx, func(ctx context.Context, tx *Transaction) error { |
| return c.RunTransaction(ctx, func(context.Context, *Transaction) error { return nil }) |
| }) |
| if got, want := err, errNestedTransaction; got != want { |
| t.Errorf("got <%v>, want <%v>", got, want) |
| } |
| } |
| |
| func TestTransactionGetAll(t *testing.T) { |
| c, srv, cleanup := newMock(t) |
| defer cleanup() |
| |
| const dbPath = "projects/projectID/databases/(default)" |
| tid := []byte{1} |
| beginReq := &pb.BeginTransactionRequest{Database: dbPath} |
| beginRes := &pb.BeginTransactionResponse{Transaction: tid} |
| srv.addRPC(beginReq, beginRes) |
| req := &pb.BatchGetDocumentsRequest{ |
| Database: dbPath, |
| Documents: []string{ |
| dbPath + "/documents/C/a", |
| dbPath + "/documents/C/b", |
| dbPath + "/documents/C/c", |
| }, |
| ConsistencySelector: &pb.BatchGetDocumentsRequest_Transaction{tid}, |
| } |
| err := c.RunTransaction(context.Background(), func(_ context.Context, tx *Transaction) error { |
| testGetAll(t, c, srv, dbPath, |
| func(drs []*DocumentRef) ([]*DocumentSnapshot, error) { return tx.GetAll(drs) }, |
| req) |
| commitReq := &pb.CommitRequest{Database: dbPath, Transaction: tid} |
| srv.addRPC(commitReq, &pb.CommitResponse{CommitTime: aTimestamp}) |
| return nil |
| }) |
| if err != nil { |
| t.Fatal(err) |
| } |
| } |
| |
| // Each retry attempt has the same amount of commit writes. |
| func TestRunTransaction_Retries(t *testing.T) { |
| ctx := context.Background() |
| c, srv, cleanup := newMock(t) |
| defer cleanup() |
| |
| const db = "projects/projectID/databases/(default)" |
| tid := []byte{1} |
| |
| srv.addRPC( |
| &pb.BeginTransactionRequest{Database: db}, |
| &pb.BeginTransactionResponse{Transaction: tid}, |
| ) |
| |
| aDoc := &pb.Document{ |
| Name: db + "/documents/C/a", |
| CreateTime: aTimestamp, |
| UpdateTime: aTimestamp2, |
| Fields: map[string]*pb.Value{"count": intval(1)}, |
| } |
| aDoc2 := &pb.Document{ |
| Name: aDoc.Name, |
| Fields: map[string]*pb.Value{"count": intval(7)}, |
| } |
| |
| srv.addRPC( |
| &pb.CommitRequest{ |
| Database: db, |
| Transaction: tid, |
| Writes: []*pb.Write{{ |
| Operation: &pb.Write_Update{aDoc2}, |
| UpdateMask: &pb.DocumentMask{FieldPaths: []string{"count"}}, |
| CurrentDocument: &pb.Precondition{ |
| ConditionType: &pb.Precondition_Exists{true}, |
| }, |
| }}, |
| }, |
| status.Errorf(codes.Aborted, "something failed! please retry me!"), |
| ) |
| |
| srv.addRPC( |
| &pb.BeginTransactionRequest{ |
| Database: db, |
| Options: &pb.TransactionOptions{ |
| Mode: &pb.TransactionOptions_ReadWrite_{ |
| &pb.TransactionOptions_ReadWrite{RetryTransaction: tid}, |
| }, |
| }, |
| }, |
| &pb.BeginTransactionResponse{Transaction: tid}, |
| ) |
| |
| srv.addRPC( |
| &pb.CommitRequest{ |
| Database: db, |
| Transaction: tid, |
| Writes: []*pb.Write{{ |
| Operation: &pb.Write_Update{aDoc2}, |
| UpdateMask: &pb.DocumentMask{FieldPaths: []string{"count"}}, |
| CurrentDocument: &pb.Precondition{ |
| ConditionType: &pb.Precondition_Exists{true}, |
| }, |
| }}, |
| }, |
| &pb.CommitResponse{CommitTime: aTimestamp3}, |
| ) |
| |
| err := c.RunTransaction(ctx, func(_ context.Context, tx *Transaction) error { |
| docref := c.Collection("C").Doc("a") |
| return tx.Update(docref, []Update{{Path: "count", Value: 7}}) |
| }) |
| if err != nil { |
| t.Fatal(err) |
| } |
| } |
| |
| // Non-transactional operations are allowed in transactions (although |
| // discouraged). |
| func TestRunTransaction_NonTransactionalOp(t *testing.T) { |
| ctx := context.Background() |
| c, srv, cleanup := newMock(t) |
| defer cleanup() |
| |
| const db = "projects/projectID/databases/(default)" |
| tid := []byte{1} |
| |
| beginReq := &pb.BeginTransactionRequest{Database: db} |
| beginRes := &pb.BeginTransactionResponse{Transaction: tid} |
| |
| srv.reset() |
| srv.addRPC(beginReq, beginRes) |
| aDoc := &pb.Document{ |
| Name: db + "/documents/C/a", |
| CreateTime: aTimestamp, |
| UpdateTime: aTimestamp2, |
| Fields: map[string]*pb.Value{"count": intval(1)}, |
| } |
| srv.addRPC( |
| &pb.BatchGetDocumentsRequest{ |
| Database: c.path(), |
| Documents: []string{db + "/documents/C/a"}, |
| }, []interface{}{ |
| &pb.BatchGetDocumentsResponse{ |
| Result: &pb.BatchGetDocumentsResponse_Found{aDoc}, |
| ReadTime: aTimestamp2, |
| }, |
| }) |
| srv.addRPC( |
| &pb.CommitRequest{ |
| Database: db, |
| Transaction: tid, |
| }, |
| &pb.CommitResponse{CommitTime: aTimestamp3}, |
| ) |
| |
| if err := c.RunTransaction(ctx, func(ctx2 context.Context, tx *Transaction) error { |
| docref := c.Collection("C").Doc("a") |
| if _, err := c.GetAll(ctx2, []*DocumentRef{docref}); err != nil { |
| t.Fatal(err) |
| } |
| return nil |
| }); err != nil { |
| t.Fatal(err) |
| } |
| } |