blob: 7cc90ea7d904a0983584477b0b80e801ce677f36 [file] [log] [blame]
// Copyright 2019 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package testutil_test
import (
"strconv"
. "cloud.google.com/go/spanner/internal/testutil"
"context"
"flag"
"fmt"
"log"
"net"
"os"
"strings"
"testing"
structpb "github.com/golang/protobuf/ptypes/struct"
spannerpb "google.golang.org/genproto/googleapis/spanner/v1"
"google.golang.org/grpc/codes"
apiv1 "cloud.google.com/go/spanner/apiv1"
"google.golang.org/api/iterator"
"google.golang.org/api/option"
"google.golang.org/grpc"
gstatus "google.golang.org/grpc/status"
)
// clientOpt is the option tests should use to connect to the test server.
// It is initialized by TestMain.
var serverAddress string
var clientOpt option.ClientOption
var testSpanner InMemSpannerServer
// Mocked selectSQL statement.
const selectSQL = "SELECT FOO FROM BAR"
const selectRowCount int64 = 2
const selectColCount int = 1
var selectValues = [...]int64{1, 2}
// Mocked DML statement.
const updateSQL = "UPDATE FOO SET BAR=1 WHERE ID=ID"
const updateRowCount int64 = 2
func TestMain(m *testing.M) {
flag.Parse()
testSpanner = NewInMemSpannerServer()
serv := grpc.NewServer()
spannerpb.RegisterSpannerServer(serv, testSpanner)
lis, err := net.Listen("tcp", "localhost:0")
if err != nil {
log.Fatal(err)
}
go serv.Serve(lis)
serverAddress = lis.Addr().String()
conn, err := grpc.Dial(serverAddress, grpc.WithInsecure())
if err != nil {
log.Fatal(err)
}
clientOpt = option.WithGRPCConn(conn)
os.Exit(m.Run())
}
// Resets the mock server to its default values and registers a mocked result
// for the statements "SELECT FOO FROM BAR" and
// "UPDATE FOO SET BAR=1 WHERE ID=ID".
func setup() {
testSpanner.Reset()
fields := make([]*spannerpb.StructType_Field, selectColCount)
fields[0] = &spannerpb.StructType_Field{
Name: "FOO",
Type: &spannerpb.Type{Code: spannerpb.TypeCode_INT64},
}
rowType := &spannerpb.StructType{
Fields: fields,
}
metadata := &spannerpb.ResultSetMetadata{
RowType: rowType,
}
rows := make([]*structpb.ListValue, selectRowCount)
for idx, value := range selectValues {
rowValue := make([]*structpb.Value, selectColCount)
rowValue[0] = &structpb.Value{
Kind: &structpb.Value_StringValue{StringValue: strconv.FormatInt(value, 10)},
}
rows[idx] = &structpb.ListValue{
Values: rowValue,
}
}
resultSet := &spannerpb.ResultSet{
Metadata: metadata,
Rows: rows,
}
result := &StatementResult{Type: StatementResultResultSet, ResultSet: resultSet}
testSpanner.PutStatementResult(selectSQL, result)
updateResult := &StatementResult{Type: StatementResultUpdateCount, UpdateCount: updateRowCount}
testSpanner.PutStatementResult(updateSQL, updateResult)
}
func TestSpannerCreateSession(t *testing.T) {
testSpanner.Reset()
var expectedName = fmt.Sprintf("projects/%s/instances/%s/databases/%s/sessions/", "[PROJECT]", "[INSTANCE]", "[DATABASE]")
var formattedDatabase = fmt.Sprintf("projects/%s/instances/%s/databases/%s", "[PROJECT]", "[INSTANCE]", "[DATABASE]")
var request = &spannerpb.CreateSessionRequest{
Database: formattedDatabase,
}
c, err := apiv1.NewClient(context.Background(), clientOpt)
if err != nil {
t.Fatal(err)
}
resp, err := c.CreateSession(context.Background(), request)
if err != nil {
t.Fatal(err)
}
if strings.Index(resp.Name, expectedName) != 0 {
t.Errorf("Session name mismatch\nGot: %s\nWant: Name should start with %s)", resp.Name, expectedName)
}
}
func TestSpannerCreateSession_Unavailable(t *testing.T) {
testSpanner.Reset()
var expectedName = fmt.Sprintf("projects/%s/instances/%s/databases/%s/sessions/", "[PROJECT]", "[INSTANCE]", "[DATABASE]")
var formattedDatabase = fmt.Sprintf("projects/%s/instances/%s/databases/%s", "[PROJECT]", "[INSTANCE]", "[DATABASE]")
var request = &spannerpb.CreateSessionRequest{
Database: formattedDatabase,
}
c, err := apiv1.NewClient(context.Background(), clientOpt)
if err != nil {
t.Fatal(err)
}
testSpanner.SetError(gstatus.Error(codes.Unavailable, "Temporary unavailable"))
resp, err := c.CreateSession(context.Background(), request)
if err != nil {
t.Fatal(err)
}
if strings.Index(resp.Name, expectedName) != 0 {
t.Errorf("Session name mismatch\nGot: %s\nWant: Name should start with %s)", resp.Name, expectedName)
}
}
func TestSpannerGetSession(t *testing.T) {
testSpanner.Reset()
var formattedDatabase = fmt.Sprintf("projects/%s/instances/%s/databases/%s", "[PROJECT]", "[INSTANCE]", "[DATABASE]")
var createRequest = &spannerpb.CreateSessionRequest{
Database: formattedDatabase,
}
c, err := apiv1.NewClient(context.Background(), clientOpt)
if err != nil {
t.Fatal(err)
}
createResp, err := c.CreateSession(context.Background(), createRequest)
if err != nil {
t.Fatal(err)
}
var getRequest = &spannerpb.GetSessionRequest{
Name: createResp.Name,
}
getResp, err := c.GetSession(context.Background(), getRequest)
if err != nil {
t.Fatal(err)
}
if getResp.Name != getRequest.Name {
t.Errorf("Session name mismatch\nGot: %s\nWant: Name should start with %s)", getResp.Name, getRequest.Name)
}
}
func TestSpannerListSessions(t *testing.T) {
testSpanner.Reset()
const expectedNumberOfSessions = 5
var expectedName = fmt.Sprintf("projects/%s/instances/%s/databases/%s/sessions/", "[PROJECT]", "[INSTANCE]", "[DATABASE]")
var formattedDatabase = fmt.Sprintf("projects/%s/instances/%s/databases/%s", "[PROJECT]", "[INSTANCE]", "[DATABASE]")
var createRequest = &spannerpb.CreateSessionRequest{
Database: formattedDatabase,
}
c, err := apiv1.NewClient(context.Background(), clientOpt)
if err != nil {
t.Fatal(err)
}
for i := 0; i < expectedNumberOfSessions; i++ {
_, err := c.CreateSession(context.Background(), createRequest)
if err != nil {
t.Fatal(err)
}
}
var listRequest = &spannerpb.ListSessionsRequest{
Database: formattedDatabase,
}
var sessionCount int
listResp := c.ListSessions(context.Background(), listRequest)
for {
session, err := listResp.Next()
if err == iterator.Done {
break
}
if err != nil {
t.Fatal(err)
}
if strings.Index(session.Name, expectedName) != 0 {
t.Errorf("Session name mismatch\nGot: %s\nWant: Name should start with %s)", session.Name, expectedName)
}
sessionCount++
}
if sessionCount != expectedNumberOfSessions {
t.Errorf("Session count mismatch\nGot: %d\nWant: %d", sessionCount, expectedNumberOfSessions)
}
}
func TestSpannerDeleteSession(t *testing.T) {
testSpanner.Reset()
const expectedNumberOfSessions = 5
var formattedDatabase = fmt.Sprintf("projects/%s/instances/%s/databases/%s", "[PROJECT]", "[INSTANCE]", "[DATABASE]")
var createRequest = &spannerpb.CreateSessionRequest{
Database: formattedDatabase,
}
c, err := apiv1.NewClient(context.Background(), clientOpt)
if err != nil {
t.Fatal(err)
}
for i := 0; i < expectedNumberOfSessions; i++ {
_, err := c.CreateSession(context.Background(), createRequest)
if err != nil {
t.Fatal(err)
}
}
var listRequest = &spannerpb.ListSessionsRequest{
Database: formattedDatabase,
}
var sessionCount int
listResp := c.ListSessions(context.Background(), listRequest)
for {
session, err := listResp.Next()
if err == iterator.Done {
break
}
if err != nil {
t.Fatal(err)
}
var deleteRequest = &spannerpb.DeleteSessionRequest{
Name: session.Name,
}
c.DeleteSession(context.Background(), deleteRequest)
sessionCount++
}
if sessionCount != expectedNumberOfSessions {
t.Errorf("Session count mismatch\nGot: %d\nWant: %d", sessionCount, expectedNumberOfSessions)
}
// Re-list all sessions. This should now be empty.
listResp = c.ListSessions(context.Background(), listRequest)
_, err = listResp.Next()
if err != iterator.Done {
t.Errorf("expected empty session iterator")
}
}
func TestSpannerExecuteSql(t *testing.T) {
setup()
c, err := apiv1.NewClient(context.Background(), clientOpt)
if err != nil {
t.Fatal(err)
}
var formattedDatabase = fmt.Sprintf("projects/%s/instances/%s/databases/%s", "[PROJECT]", "[INSTANCE]", "[DATABASE]")
var createRequest = &spannerpb.CreateSessionRequest{
Database: formattedDatabase,
}
session, err := c.CreateSession(context.Background(), createRequest)
if err != nil {
t.Fatal(err)
}
request := &spannerpb.ExecuteSqlRequest{
Session: session.Name,
Sql: selectSQL,
Transaction: &spannerpb.TransactionSelector{
Selector: &spannerpb.TransactionSelector_SingleUse{
SingleUse: &spannerpb.TransactionOptions{
Mode: &spannerpb.TransactionOptions_ReadOnly_{
ReadOnly: &spannerpb.TransactionOptions_ReadOnly{
ReturnReadTimestamp: false,
TimestampBound: &spannerpb.TransactionOptions_ReadOnly_Strong{
Strong: true,
},
},
},
},
},
},
Seqno: 1,
QueryMode: spannerpb.ExecuteSqlRequest_NORMAL,
}
response, err := c.ExecuteSql(context.Background(), request)
if err != nil {
t.Fatal(err)
}
var rowCount int64
for _, row := range response.Rows {
if len(row.Values) != selectColCount {
t.Fatalf("Column count mismatch\nGot: %d\nWant: %d", len(row.Values), selectColCount)
}
rowCount++
}
if rowCount != selectRowCount {
t.Fatalf("Row count mismatch\nGot: %d\nWant: %d", rowCount, selectRowCount)
}
}
func TestSpannerExecuteSqlDml(t *testing.T) {
setup()
c, err := apiv1.NewClient(context.Background(), clientOpt)
if err != nil {
t.Fatal(err)
}
var formattedDatabase = fmt.Sprintf("projects/%s/instances/%s/databases/%s", "[PROJECT]", "[INSTANCE]", "[DATABASE]")
var createRequest = &spannerpb.CreateSessionRequest{
Database: formattedDatabase,
}
session, err := c.CreateSession(context.Background(), createRequest)
if err != nil {
t.Fatal(err)
}
request := &spannerpb.ExecuteSqlRequest{
Session: session.Name,
Sql: updateSQL,
Transaction: &spannerpb.TransactionSelector{
Selector: &spannerpb.TransactionSelector_Begin{
Begin: &spannerpb.TransactionOptions{
Mode: &spannerpb.TransactionOptions_ReadWrite_{
ReadWrite: &spannerpb.TransactionOptions_ReadWrite{},
},
},
},
},
Seqno: 1,
QueryMode: spannerpb.ExecuteSqlRequest_NORMAL,
}
response, err := c.ExecuteSql(context.Background(), request)
if err != nil {
t.Fatal(err)
}
var rowCount int64 = response.Stats.GetRowCountExact()
if rowCount != updateRowCount {
t.Fatalf("Update count mismatch\nGot: %d\nWant: %d", rowCount, updateRowCount)
}
}
func TestSpannerExecuteStreamingSql(t *testing.T) {
setup()
c, err := apiv1.NewClient(context.Background(), clientOpt)
if err != nil {
t.Fatal(err)
}
var formattedDatabase = fmt.Sprintf("projects/%s/instances/%s/databases/%s", "[PROJECT]", "[INSTANCE]", "[DATABASE]")
var createRequest = &spannerpb.CreateSessionRequest{
Database: formattedDatabase,
}
session, err := c.CreateSession(context.Background(), createRequest)
if err != nil {
t.Fatal(err)
}
request := &spannerpb.ExecuteSqlRequest{
Session: session.Name,
Sql: selectSQL,
Transaction: &spannerpb.TransactionSelector{
Selector: &spannerpb.TransactionSelector_SingleUse{
SingleUse: &spannerpb.TransactionOptions{
Mode: &spannerpb.TransactionOptions_ReadOnly_{
ReadOnly: &spannerpb.TransactionOptions_ReadOnly{
ReturnReadTimestamp: false,
TimestampBound: &spannerpb.TransactionOptions_ReadOnly_Strong{
Strong: true,
},
},
},
},
},
},
Seqno: 1,
QueryMode: spannerpb.ExecuteSqlRequest_NORMAL,
}
response, err := c.ExecuteStreamingSql(context.Background(), request)
if err != nil {
t.Fatal(err)
}
var rowIndex int64
var colCount int
for {
for rowIndexInPartial := int64(0); rowIndexInPartial < MaxRowsPerPartialResultSet; rowIndexInPartial++ {
partial, err := response.Recv()
if err != nil {
t.Fatal(err)
}
if rowIndex == 0 {
colCount = len(partial.Metadata.RowType.Fields)
if colCount != selectColCount {
t.Fatalf("Column count mismatch\nGot: %d\nWant: %d", colCount, selectColCount)
}
}
for col := 0; col < colCount; col++ {
pIndex := rowIndexInPartial*int64(colCount) + int64(col)
val, err := strconv.ParseInt(partial.Values[pIndex].GetStringValue(), 10, 64)
if err != nil {
t.Fatalf("Error parsing integer at #%d: %v", pIndex, err)
}
if val != selectValues[rowIndex] {
t.Fatalf("Value mismatch at index %d\nGot: %d\nWant: %d", rowIndex, val, selectValues[rowIndex])
}
}
rowIndex++
}
if rowIndex == selectRowCount {
break
}
}
if rowIndex != selectRowCount {
t.Fatalf("Row count mismatch\nGot: %d\nWant: %d", rowIndex, selectRowCount)
}
}
func TestSpannerExecuteBatchDml(t *testing.T) {
setup()
c, err := apiv1.NewClient(context.Background(), clientOpt)
if err != nil {
t.Fatal(err)
}
var formattedDatabase = fmt.Sprintf("projects/%s/instances/%s/databases/%s", "[PROJECT]", "[INSTANCE]", "[DATABASE]")
var createRequest = &spannerpb.CreateSessionRequest{
Database: formattedDatabase,
}
session, err := c.CreateSession(context.Background(), createRequest)
if err != nil {
t.Fatal(err)
}
statements := make([]*spannerpb.ExecuteBatchDmlRequest_Statement, 3)
for idx := 0; idx < len(statements); idx++ {
statements[idx] = &spannerpb.ExecuteBatchDmlRequest_Statement{Sql: updateSQL}
}
executeBatchDmlRequest := &spannerpb.ExecuteBatchDmlRequest{
Session: session.Name,
Statements: statements,
Transaction: &spannerpb.TransactionSelector{
Selector: &spannerpb.TransactionSelector_Begin{
Begin: &spannerpb.TransactionOptions{
Mode: &spannerpb.TransactionOptions_ReadWrite_{
ReadWrite: &spannerpb.TransactionOptions_ReadWrite{},
},
},
},
},
Seqno: 1,
}
response, err := c.ExecuteBatchDml(context.Background(), executeBatchDmlRequest)
if err != nil {
t.Fatal(err)
}
var totalRowCount int64
for _, res := range response.ResultSets {
var rowCount int64 = res.Stats.GetRowCountExact()
if rowCount != updateRowCount {
t.Fatalf("Update count mismatch\nGot: %d\nWant: %d", rowCount, updateRowCount)
}
totalRowCount += rowCount
}
if totalRowCount != updateRowCount*int64(len(statements)) {
t.Fatalf("Total update count mismatch\nGot: %d\nWant: %d", totalRowCount, updateRowCount*int64(len(statements)))
}
}
func TestBeginTransaction(t *testing.T) {
setup()
c, err := apiv1.NewClient(context.Background(), clientOpt)
if err != nil {
t.Fatal(err)
}
var formattedDatabase = fmt.Sprintf("projects/%s/instances/%s/databases/%s", "[PROJECT]", "[INSTANCE]", "[DATABASE]")
var createRequest = &spannerpb.CreateSessionRequest{
Database: formattedDatabase,
}
session, err := c.CreateSession(context.Background(), createRequest)
if err != nil {
t.Fatal(err)
}
beginRequest := &spannerpb.BeginTransactionRequest{
Session: session.Name,
Options: &spannerpb.TransactionOptions{
Mode: &spannerpb.TransactionOptions_ReadWrite_{
ReadWrite: &spannerpb.TransactionOptions_ReadWrite{},
},
},
}
tx, err := c.BeginTransaction(context.Background(), beginRequest)
if err != nil {
t.Fatal(err)
}
expectedName := fmt.Sprintf("%s/transactions/", session.Name)
if strings.Index(string(tx.Id), expectedName) != 0 {
t.Errorf("Transaction name mismatch\nGot: %s\nWant: Name should start with %s)", string(tx.Id), expectedName)
}
}
func TestCommitTransaction(t *testing.T) {
setup()
c, err := apiv1.NewClient(context.Background(), clientOpt)
if err != nil {
t.Fatal(err)
}
var formattedDatabase = fmt.Sprintf("projects/%s/instances/%s/databases/%s", "[PROJECT]", "[INSTANCE]", "[DATABASE]")
var createRequest = &spannerpb.CreateSessionRequest{
Database: formattedDatabase,
}
session, err := c.CreateSession(context.Background(), createRequest)
if err != nil {
t.Fatal(err)
}
beginRequest := &spannerpb.BeginTransactionRequest{
Session: session.Name,
Options: &spannerpb.TransactionOptions{
Mode: &spannerpb.TransactionOptions_ReadWrite_{
ReadWrite: &spannerpb.TransactionOptions_ReadWrite{},
},
},
}
tx, err := c.BeginTransaction(context.Background(), beginRequest)
if err != nil {
t.Fatal(err)
}
commitRequest := &spannerpb.CommitRequest{
Session: session.Name,
Transaction: &spannerpb.CommitRequest_TransactionId{
TransactionId: tx.Id,
},
}
resp, err := c.Commit(context.Background(), commitRequest)
if err != nil {
t.Fatal(err)
}
if resp.CommitTimestamp == nil {
t.Fatalf("No commit timestamp returned")
}
}
func TestRollbackTransaction(t *testing.T) {
setup()
c, err := apiv1.NewClient(context.Background(), clientOpt)
if err != nil {
t.Fatal(err)
}
var formattedDatabase = fmt.Sprintf("projects/%s/instances/%s/databases/%s", "[PROJECT]", "[INSTANCE]", "[DATABASE]")
var createRequest = &spannerpb.CreateSessionRequest{
Database: formattedDatabase,
}
session, err := c.CreateSession(context.Background(), createRequest)
if err != nil {
t.Fatal(err)
}
beginRequest := &spannerpb.BeginTransactionRequest{
Session: session.Name,
Options: &spannerpb.TransactionOptions{
Mode: &spannerpb.TransactionOptions_ReadWrite_{
ReadWrite: &spannerpb.TransactionOptions_ReadWrite{},
},
},
}
tx, err := c.BeginTransaction(context.Background(), beginRequest)
if err != nil {
t.Fatal(err)
}
rollbackRequest := &spannerpb.RollbackRequest{
Session: session.Name,
TransactionId: tx.Id,
}
err = c.Rollback(context.Background(), rollbackRequest)
if err != nil {
t.Fatal(err)
}
}