// Copyright 2017 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package firestore

import (
	"context"
	"testing"
	"time"

	pb "cloud.google.com/go/firestore/apiv1/firestorepb"
	"google.golang.org/api/iterator"
	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/status"
	"google.golang.org/protobuf/types/known/emptypb"
	"google.golang.org/protobuf/types/known/timestamppb"
)

// TestRunTransaction exercises Client.RunTransaction against the mock server
// in four scenarios that are all expected to succeed: an empty transaction, a
// read followed by an update, a query, and a commit that is aborted once and
// then retried.
func TestRunTransaction(t *testing.T) {
	const (
		dbPath  = "projects/projectID/databases/(default)"
		docPath = dbPath + "/documents/C/a"
	)

	client, srv, cleanup := newMock(t)
	defer cleanup()

	ctx := context.Background()
	txID := []byte{1}
	beginReq := &pb.BeginTransactionRequest{Database: dbPath}
	beginRes := &pb.BeginTransactionResponse{Transaction: txID}
	// A commit without writes, reused by every scenario that does not
	// modify any document.
	emptyCommit := &pb.CommitRequest{Database: dbPath, Transaction: txID}
	noop := func(context.Context, *Transaction) error { return nil }

	// Scenario 1: empty transaction.
	srv.addRPC(beginReq, beginRes)
	srv.addRPC(emptyCommit, &pb.CommitResponse{CommitTime: aTimestamp})
	if err := client.RunTransaction(ctx, noop); err != nil {
		t.Fatal(err)
	}

	// Scenario 2: read the "count" field of C/a, then write it back
	// incremented by one.
	srv.reset()
	srv.addRPC(beginReq, beginRes)
	stored := &pb.Document{
		Name:       docPath,
		CreateTime: aTimestamp,
		UpdateTime: aTimestamp2,
		Fields:     map[string]*pb.Value{"count": intval(1)},
	}
	getReq := &pb.BatchGetDocumentsRequest{
		Database:            client.path(),
		Documents:           []string{docPath},
		ConsistencySelector: &pb.BatchGetDocumentsRequest_Transaction{Transaction: txID},
	}
	getRes := []interface{}{
		&pb.BatchGetDocumentsResponse{
			Result:   &pb.BatchGetDocumentsResponse_Found{Found: stored},
			ReadTime: aTimestamp2,
		},
	}
	srv.addRPC(getReq, getRes)
	updated := &pb.Document{
		Name:   stored.Name,
		Fields: map[string]*pb.Value{"count": intval(2)},
	}
	updateCommit := &pb.CommitRequest{
		Database:    dbPath,
		Transaction: txID,
		Writes: []*pb.Write{{
			Operation:  &pb.Write_Update{Update: updated},
			UpdateMask: &pb.DocumentMask{FieldPaths: []string{"count"}},
			CurrentDocument: &pb.Precondition{
				ConditionType: &pb.Precondition_Exists{Exists: true},
			},
		}},
	}
	srv.addRPC(updateCommit, &pb.CommitResponse{CommitTime: aTimestamp3})
	err := client.RunTransaction(ctx, func(_ context.Context, tx *Transaction) error {
		ref := client.Collection("C").Doc("a")
		snap, err := tx.Get(ref)
		if err != nil {
			return err
		}
		val, err := snap.DataAt("count")
		if err != nil {
			return err
		}
		next := val.(int64) + 1
		return tx.Update(ref, []Update{{Path: "count", Value: next}})
	})
	if err != nil {
		t.Fatal(err)
	}

	// Scenario 3: a query over collection C inside the transaction. The
	// mock answers with an empty stream, so the first Next must yield
	// iterator.Done.
	srv.reset()
	srv.addRPC(beginReq, beginRes)
	queryReq := &pb.RunQueryRequest{
		Parent: dbPath + "/documents",
		QueryType: &pb.RunQueryRequest_StructuredQuery{
			StructuredQuery: &pb.StructuredQuery{
				From: []*pb.StructuredQuery_CollectionSelector{{CollectionId: "C"}},
			},
		},
		ConsistencySelector: &pb.RunQueryRequest_Transaction{Transaction: txID},
	}
	srv.addRPC(queryReq, []interface{}{})
	srv.addRPC(emptyCommit, &pb.CommitResponse{CommitTime: aTimestamp3})
	err = client.RunTransaction(ctx, func(_ context.Context, tx *Transaction) error {
		iter := tx.Documents(client.Collection("C"))
		defer iter.Stop()
		if _, err := iter.Next(); err != iterator.Done {
			return err
		}
		return nil
	})
	if err != nil {
		t.Fatal(err)
	}

	// Scenario 4: the first commit is aborted, so the whole transaction is
	// retried. The second BeginTransaction must name the aborted
	// transaction in RetryTransaction.
	srv.reset()
	srv.addRPC(beginReq, beginRes)
	srv.addRPC(emptyCommit, status.Errorf(codes.Aborted, ""))
	retryBegin := &pb.BeginTransactionRequest{
		Database: dbPath,
		Options: &pb.TransactionOptions{
			Mode: &pb.TransactionOptions_ReadWrite_{
				ReadWrite: &pb.TransactionOptions_ReadWrite{RetryTransaction: txID},
			},
		},
	}
	srv.addRPC(retryBegin, beginRes)
	srv.addRPC(emptyCommit, &pb.CommitResponse{CommitTime: aTimestamp})
	if err := client.RunTransaction(ctx, noop); err != nil {
		t.Fatal(err)
	}
}

// TestTransactionErrors checks the error that Client.RunTransaction returns
// in a series of failure scenarios: permanent RPC errors from
// BeginTransaction, Get and Commit; a failing rollback; reads issued after a
// write; a write in a read-only transaction; exhausted retries; and a nested
// transaction.
//
// Each scenario calls srv.reset() and then registers, in order, the RPCs it
// expects the client to issue.
func TestTransactionErrors(t *testing.T) {
	ctx := context.Background()
	const db = "projects/projectID/databases/(default)"
	c, srv, cleanup := newMock(t)
	defer cleanup()

	var (
		tid        = []byte{1}
		unknownErr = status.Errorf(codes.Unknown, "so sad")
		beginReq   = &pb.BeginTransactionRequest{
			Database: db,
		}
		beginRes = &pb.BeginTransactionResponse{Transaction: tid}
		getReq   = &pb.BatchGetDocumentsRequest{
			Database:            c.path(),
			Documents:           []string{db + "/documents/C/a"},
			ConsistencySelector: &pb.BatchGetDocumentsRequest_Transaction{Transaction: tid},
		}
		rollbackReq = &pb.RollbackRequest{Database: db, Transaction: tid}
		commitReq   = &pb.CommitRequest{Database: db, Transaction: tid}
	)

	// BeginTransaction has a permanent error.
	srv.addRPC(beginReq, unknownErr)
	err := c.RunTransaction(ctx, func(context.Context, *Transaction) error { return nil })
	if status.Code(err) != codes.Unknown {
		t.Errorf("got <%v>, want Unknown", err)
	}

	// Get has a permanent error.
	get := func(_ context.Context, tx *Transaction) error {
		_, err := tx.Get(c.Doc("C/a"))
		return err
	}
	srv.reset()
	srv.addRPC(beginReq, beginRes)
	srv.addRPC(getReq, unknownErr)
	srv.addRPC(rollbackReq, &emptypb.Empty{})
	err = c.RunTransaction(ctx, get)
	if status.Code(err) != codes.Unknown {
		t.Errorf("got <%v>, want Unknown", err)
	}

	// Get has a permanent error, but the rollback fails. We still
	// return Get's error.
	srv.reset()
	srv.addRPC(beginReq, beginRes)
	srv.addRPC(getReq, unknownErr)
	srv.addRPC(rollbackReq, status.Errorf(codes.FailedPrecondition, ""))
	err = c.RunTransaction(ctx, get)
	if status.Code(err) != codes.Unknown {
		t.Errorf("got <%v>, want Unknown", err)
	}

	// Commit has a permanent error.
	srv.reset()
	srv.addRPC(beginReq, beginRes)
	srv.addRPC(getReq, []interface{}{
		&pb.BatchGetDocumentsResponse{
			Result: &pb.BatchGetDocumentsResponse_Found{Found: &pb.Document{
				Name:       db + "/documents/C/a",
				CreateTime: aTimestamp,
				UpdateTime: aTimestamp2,
			}},
			ReadTime: aTimestamp2,
		},
	})
	srv.addRPC(commitReq, unknownErr)
	err = c.RunTransaction(ctx, get)
	if status.Code(err) != codes.Unknown {
		t.Errorf("got <%v>, want Unknown", err)
	}

	// Read after write.
	srv.reset()
	srv.addRPC(beginReq, beginRes)
	srv.addRPC(rollbackReq, &emptypb.Empty{})
	err = c.RunTransaction(ctx, func(_ context.Context, tx *Transaction) error {
		if err := tx.Delete(c.Doc("C/a")); err != nil {
			return err
		}
		if _, err := tx.Get(c.Doc("C/a")); err != nil {
			return err
		}
		return nil
	})
	if err != errReadAfterWrite {
		t.Errorf("got <%v>, want <%v>", err, errReadAfterWrite)
	}

	// Read after write, with query.
	srv.reset()
	srv.addRPC(beginReq, beginRes)
	srv.addRPC(rollbackReq, &emptypb.Empty{})
	err = c.RunTransaction(ctx, func(_ context.Context, tx *Transaction) error {
		if err := tx.Delete(c.Doc("C/a")); err != nil {
			return err
		}
		it := tx.Documents(c.Collection("C").Select("x"))
		defer it.Stop()
		if _, err := it.Next(); err != iterator.Done {
			return err
		}
		return nil
	})
	if err != errReadAfterWrite {
		t.Errorf("got <%v>, want <%v>", err, errReadAfterWrite)
	}

	// Read after write, with query and GetAll.
	srv.reset()
	srv.addRPC(beginReq, beginRes)
	srv.addRPC(rollbackReq, &emptypb.Empty{})
	err = c.RunTransaction(ctx, func(_ context.Context, tx *Transaction) error {
		if err := tx.Delete(c.Doc("C/a")); err != nil {
			return err
		}
		_, err := tx.Documents(c.Collection("C").Select("x")).GetAll()
		return err
	})
	if err != errReadAfterWrite {
		t.Errorf("got <%v>, want <%v>", err, errReadAfterWrite)
	}

	// Read after write fails even if the user ignores the read's error.
	srv.reset()
	srv.addRPC(beginReq, beginRes)
	srv.addRPC(rollbackReq, &emptypb.Empty{})
	err = c.RunTransaction(ctx, func(_ context.Context, tx *Transaction) error {
		if err := tx.Delete(c.Doc("C/a")); err != nil {
			return err
		}
		// Deliberately drop the read's error and report success: this
		// scenario asserts that RunTransaction surfaces errReadAfterWrite
		// on its own, which is what distinguishes it from the plain
		// "Read after write" scenario above.
		_, _ = tx.Get(c.Doc("C/a"))
		return nil
	})
	if err != errReadAfterWrite {
		t.Errorf("got <%v>, want <%v>", err, errReadAfterWrite)
	}

	// Write in read-only transaction.
	srv.reset()
	srv.addRPC(
		&pb.BeginTransactionRequest{
			Database: db,
			Options: &pb.TransactionOptions{
				Mode: &pb.TransactionOptions_ReadOnly_{ReadOnly: &pb.TransactionOptions_ReadOnly{}},
			},
		},
		beginRes,
	)
	srv.addRPC(rollbackReq, &emptypb.Empty{})
	err = c.RunTransaction(ctx, func(_ context.Context, tx *Transaction) error {
		return tx.Delete(c.Doc("C/a"))
	}, ReadOnly)
	if err != errWriteReadOnly {
		t.Errorf("got <%v>, want <%v>", err, errWriteReadOnly)
	}

	// Too many retries: with MaxAttempts(2), both commits are aborted and
	// the Aborted status is what the caller sees.
	srv.reset()
	srv.addRPC(beginReq, beginRes)
	srv.addRPC(commitReq, status.Errorf(codes.Aborted, ""))
	srv.addRPC(
		&pb.BeginTransactionRequest{
			Database: db,
			Options: &pb.TransactionOptions{
				Mode: &pb.TransactionOptions_ReadWrite_{
					ReadWrite: &pb.TransactionOptions_ReadWrite{RetryTransaction: tid},
				},
			},
		},
		beginRes,
	)
	srv.addRPC(commitReq, status.Errorf(codes.Aborted, ""))
	srv.addRPC(rollbackReq, &emptypb.Empty{})
	err = c.RunTransaction(ctx, func(context.Context, *Transaction) error { return nil },
		MaxAttempts(2))
	if status.Code(err) != codes.Aborted {
		t.Errorf("got <%v>, want Aborted", err)
	}

	// Nested transaction: starting a transaction with the context handed
	// to the outer callback must fail with errNestedTransaction.
	srv.reset()
	srv.addRPC(beginReq, beginRes)
	srv.addRPC(rollbackReq, &emptypb.Empty{})
	err = c.RunTransaction(ctx, func(ctx context.Context, tx *Transaction) error {
		return c.RunTransaction(ctx, func(context.Context, *Transaction) error { return nil })
	})
	if got, want := err, errNestedTransaction; got != want {
		t.Errorf("got <%v>, want <%v>", got, want)
	}
}

// TestTransactionGetAll runs the shared testGetAll helper with
// Transaction.GetAll as the function under test, from inside a transaction
// callback.
func TestTransactionGetAll(t *testing.T) {
	const dbPath = "projects/projectID/databases/(default)"

	client, srv, cleanup := newMock(t)
	defer cleanup()

	txID := []byte{1}
	srv.addRPC(
		&pb.BeginTransactionRequest{Database: dbPath},
		&pb.BeginTransactionResponse{Transaction: txID},
	)

	// The batch-get request that the helper is told to expect; it is tied
	// to the transaction through the consistency selector.
	wantReq := &pb.BatchGetDocumentsRequest{
		Database: dbPath,
		Documents: []string{
			dbPath + "/documents/C/a",
			dbPath + "/documents/C/b",
			dbPath + "/documents/C/c",
		},
		ConsistencySelector: &pb.BatchGetDocumentsRequest_Transaction{Transaction: txID},
	}

	run := func(_ context.Context, tx *Transaction) error {
		testGetAll(t, client, srv, dbPath, tx.GetAll, wantReq)
		// The commit expectation is registered only after the helper has
		// returned, matching the order in which the RPCs are issued.
		srv.addRPC(
			&pb.CommitRequest{Database: dbPath, Transaction: txID},
			&pb.CommitResponse{CommitTime: aTimestamp},
		)
		return nil
	}
	if err := client.RunTransaction(context.Background(), run); err != nil {
		t.Fatal(err)
	}
}

// TestRunTransaction_Retries checks that each retry attempt has the same
// amount of commit writes: the first commit is aborted, and the retried
// commit is expected to carry the identical single update.
func TestRunTransaction_Retries(t *testing.T) {
	const (
		dbPath  = "projects/projectID/databases/(default)"
		docPath = dbPath + "/documents/C/a"
	)

	client, srv, cleanup := newMock(t)
	defer cleanup()

	ctx := context.Background()
	txID := []byte{1}

	// The document state both commit attempts are expected to write.
	written := &pb.Document{
		Name:   docPath,
		Fields: map[string]*pb.Value{"count": intval(7)},
	}
	// newCommit builds the commit request expected on every attempt: one
	// update of "count" on a document that must already exist.
	newCommit := func() *pb.CommitRequest {
		return &pb.CommitRequest{
			Database:    dbPath,
			Transaction: txID,
			Writes: []*pb.Write{{
				Operation:  &pb.Write_Update{Update: written},
				UpdateMask: &pb.DocumentMask{FieldPaths: []string{"count"}},
				CurrentDocument: &pb.Precondition{
					ConditionType: &pb.Precondition_Exists{Exists: true},
				},
			}},
		}
	}

	// First attempt: begin succeeds, commit is aborted.
	srv.addRPC(
		&pb.BeginTransactionRequest{Database: dbPath},
		&pb.BeginTransactionResponse{Transaction: txID},
	)
	srv.addRPC(
		newCommit(),
		status.Errorf(codes.Aborted, "something failed! please retry me!"),
	)

	// Second attempt: begin names the aborted transaction, commit succeeds.
	retryBegin := &pb.BeginTransactionRequest{
		Database: dbPath,
		Options: &pb.TransactionOptions{
			Mode: &pb.TransactionOptions_ReadWrite_{
				ReadWrite: &pb.TransactionOptions_ReadWrite{RetryTransaction: txID},
			},
		},
	}
	srv.addRPC(retryBegin, &pb.BeginTransactionResponse{Transaction: txID})
	srv.addRPC(newCommit(), &pb.CommitResponse{CommitTime: aTimestamp3})

	setCount := func(_ context.Context, tx *Transaction) error {
		ref := client.Collection("C").Doc("a")
		return tx.Update(ref, []Update{{Path: "count", Value: 7}})
	}
	if err := client.RunTransaction(ctx, setCount); err != nil {
		t.Fatal(err)
	}
}

// TestRunTransaction_NonTransactionalOp checks that a non-transactional
// operation (Client.GetAll, whose request carries no transaction selector) is
// allowed inside a transaction callback, although discouraged, and that the
// surrounding transaction still commits.
func TestRunTransaction_NonTransactionalOp(t *testing.T) {
	const (
		dbPath  = "projects/projectID/databases/(default)"
		docPath = dbPath + "/documents/C/a"
	)

	client, srv, cleanup := newMock(t)
	defer cleanup()

	ctx := context.Background()
	txID := []byte{1}

	srv.reset()
	srv.addRPC(
		&pb.BeginTransactionRequest{Database: dbPath},
		&pb.BeginTransactionResponse{Transaction: txID},
	)

	stored := &pb.Document{
		Name:       docPath,
		CreateTime: aTimestamp,
		UpdateTime: aTimestamp2,
		Fields:     map[string]*pb.Value{"count": intval(1)},
	}
	// No ConsistencySelector here: the read goes through the client rather
	// than through the transaction.
	getReq := &pb.BatchGetDocumentsRequest{
		Database:  client.path(),
		Documents: []string{docPath},
	}
	getRes := []interface{}{
		&pb.BatchGetDocumentsResponse{
			Result:   &pb.BatchGetDocumentsResponse_Found{Found: stored},
			ReadTime: aTimestamp2,
		},
	}
	srv.addRPC(getReq, getRes)
	srv.addRPC(
		&pb.CommitRequest{Database: dbPath, Transaction: txID},
		&pb.CommitResponse{CommitTime: aTimestamp3},
	)

	err := client.RunTransaction(ctx, func(txCtx context.Context, _ *Transaction) error {
		refs := []*DocumentRef{client.Collection("C").Doc("a")}
		_, getErr := client.GetAll(txCtx, refs)
		if getErr != nil {
			t.Fatal(getErr)
		}
		return nil
	})
	if err != nil {
		t.Fatal(err)
	}
}

// TestTransaction_WithReadOptions runs a transaction whose callback issues a
// Get through tx.WithReadOptions(ReadTime(tm)) and checks only that
// RunTransaction itself returns no error; the Get's result is not inspected.
func TestTransaction_WithReadOptions(t *testing.T) {
	ctx := context.Background()
	c, srv, cleanup := newMock(t)
	defer cleanup()

	const db = "projects/projectID/databases/(default)"
	tm := time.Date(2021, time.February, 20, 0, 0, 0, 0, time.UTC)
	// Convert with timestamppb.New so that Seconds and Nanos are both
	// populated. Forcing UnixNano() into the int32 Nanos field would
	// truncate the value and leave Seconds at zero.
	ts := timestamppb.New(tm)
	tid := []byte{1}

	beginReq := &pb.BeginTransactionRequest{Database: db}
	beginRes := &pb.BeginTransactionResponse{Transaction: tid}

	srv.reset()
	srv.addRPC(beginReq, beginRes)

	// The commit expectation is registered twice on purpose, preserving the
	// existing RPC script.
	// NOTE(review): presumably the Get below, for which no
	// BatchGetDocuments RPC is registered, consumes the first entry and the
	// real commit consumes the second — confirm against the mock server.
	for i := 0; i < 2; i++ {
		srv.addRPC(
			&pb.CommitRequest{
				Database:    db,
				Transaction: tid,
			},
			&pb.CommitResponse{CommitTime: ts},
		)
	}

	if err := c.RunTransaction(ctx, func(ctx2 context.Context, tx *Transaction) error {
		docref := c.Collection("C").Doc("a")
		// The snapshot and error are intentionally discarded; only the
		// outcome of RunTransaction is asserted.
		_, _ = tx.WithReadOptions(ReadTime(tm)).Get(docref)
		return nil
	}); err != nil {
		t.Fatal(err)
	}
}
