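datastore: replay on -short

When the tests run with -short and the file datastore.replay exists,
replay the recorded RPCs from it instead of skipping the integration
tests. Recording is now enabled with a boolean -record flag that always
writes to datastore.replay and cannot be combined with -short; the
-replay flag is removed. Switch from veneer/grpcreplay to
veneer/rpcreplay.
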
Change-Id: I02c1ce15f9ddf4ce71d827006ca1f952e77ea9ce
diff --git a/datastore/integration_test.go b/datastore/integration_test.go
index 0e96b73..6329a31 100644
--- a/datastore/integration_test.go
+++ b/datastore/integration_test.go
@@ -26,7 +26,8 @@
 	"sync"
 	"testing"
 	"time"
-	"veneer/grpcreplay"
+
+	"veneer/rpcreplay"
 
 	"cloud.google.com/go/internal/testutil"
 	"golang.org/x/net/context"
@@ -45,12 +46,14 @@
 // when the tests are run in parallel.
 var suffix string
 
+const replayFilename = "datastore.replay" // written by -record, read back when running with -short
+
 var (
-	recordFilename = flag.String("record", "", "filename to record RPCs")
-	replayFilename = flag.String("replay", "", "filename to replay RPCs")
-	dump           = flag.Bool("dump", false, "dump record/replay file")
-	logFlag        = flag.Bool("log", false, "log on replay")
-	dialOptions    []grpc.DialOption
+	record      = flag.Bool("record", false, "record RPCs") // records to replayFilename; rejected together with -short
+	dump        = flag.Bool("dump", false, "dump record/replay file")
+	logFlag     = flag.Bool("log", false, "log on replay")
+	dialOptions []grpc.DialOption
+	replaying   bool // true when -short is set and replayFilename exists; integration tests then run instead of skipping
 )
 
 func TestMain(m *testing.M) {
@@ -59,15 +62,35 @@
 
 func testMain(m *testing.M) int {
 	flag.Parse()
-	switch {
-	case *recordFilename != "" && *replayFilename != "":
-		log.Fatal("cannot provide both -record and -replay")
-	case *recordFilename != "":
+	if testing.Short() {
+		if *record {
+			log.Fatal("cannot combine -short and -record")
+		}
+		if _, err := os.Stat(replayFilename); err == nil {
+			rep, err := rpcreplay.NewReplayer(replayFilename)
+			if err != nil {
+				log.Fatal(err)
+			}
+			dialOptions = rep.DialOptions()
+			if *logFlag {
+				rep.SetLogFunc(log.Printf)
+			}
+			if err := timeNow.UnmarshalBinary(rep.Initial()); err != nil {
+				log.Fatal(err)
+			}
+			defer rep.Close()
+			log.Printf("replaying from %s", replayFilename)
+			if *dump {
+				rpcreplay.Fprint(os.Stdout, replayFilename)
+			}
+			replaying = true
+		}
+	} else if *record {
 		b, err := timeNow.MarshalBinary()
 		if err != nil {
 			log.Fatal(err)
 		}
-		rec, err := grpcreplay.NewRecorder(*recordFilename, b)
+		rec, err := rpcreplay.NewRecorder(replayFilename, b)
 		if err != nil {
 			log.Fatal(err)
 		}
@@ -77,28 +100,10 @@
 				log.Fatalf("closing recorder: %v", err)
 			}
 			if *dump {
-				grpcreplay.FprintFile(os.Stdout, *recordFilename)
+				rpcreplay.Fprint(os.Stdout, replayFilename)
 			}
 		}()
-		log.Printf("recording to %s", *recordFilename)
-
-	case *replayFilename != "":
-		rep, err := grpcreplay.NewReplayer(*replayFilename)
-		if err != nil {
-			log.Fatal(err)
-		}
-		dialOptions = rep.DialOptions()
-		if *logFlag {
-			rep.SetLogFunc(log.Printf)
-		}
-		if err := timeNow.UnmarshalBinary(rep.Initial()); err != nil {
-			log.Fatal(err)
-		}
-		defer rep.Close()
-		log.Printf("replaying from %s", *replayFilename)
-		if *dump {
-			grpcreplay.FprintFile(os.Stdout, *replayFilename)
-		}
+		log.Printf("recording to %s", replayFilename)
 	}
 	suffix = fmt.Sprintf("-t%d", timeNow.UnixNano())
 	return m.Run()
@@ -127,7 +132,7 @@
 }
 
 func TestBasics(t *testing.T) {
-	if testing.Short() {
+	if testing.Short() && !replaying {
 		t.Skip("Integration tests skipped in short mode")
 	}
 	ctx, _ := context.WithTimeout(context.Background(), time.Second*20)
@@ -160,7 +165,7 @@
 }
 
 func TestTopLevelKeyLoaded(t *testing.T) {
-	if testing.Short() {
+	if testing.Short() && !replaying {
 		t.Skip("Integration tests skipped in short mode")
 	}
 
@@ -200,7 +205,7 @@
 }
 
 func TestListValues(t *testing.T) {
-	if testing.Short() {
+	if testing.Short() && !replaying {
 		t.Skip("Integration tests skipped in short mode")
 	}
 	ctx := context.Background()
@@ -227,7 +232,7 @@
 }
 
 func TestGetMulti(t *testing.T) {
-	if testing.Short() {
+	if testing.Short() && !replaying {
 		t.Skip("Integration tests skipped in short mode")
 	}
 	ctx := context.Background()
@@ -300,7 +305,7 @@
 }
 
 func TestUnindexableValues(t *testing.T) {
-	if testing.Short() {
+	if testing.Short() && !replaying {
 		t.Skip("Integration tests skipped in short mode")
 	}
 	ctx := context.Background()
@@ -331,7 +336,7 @@
 }
 
 func TestNilKey(t *testing.T) {
-	if testing.Short() {
+	if testing.Short() && !replaying {
 		t.Skip("Integration tests skipped in short mode")
 	}
 	ctx := context.Background()
@@ -416,7 +421,7 @@
 }
 
 func TestFilters(t *testing.T) {
-	if testing.Short() {
+	if testing.Short() && !replaying {
 		t.Skip("Integration tests skipped in short mode")
 	}
 	ctx := context.Background()
@@ -505,7 +510,7 @@
 type ckey struct{}
 
 func TestLargeQuery(t *testing.T) {
-	if testing.Short() {
+	if testing.Short() && !replaying {
 		t.Skip("Integration tests skipped in short mode")
 	}
 	ctx := context.Background()
@@ -677,7 +682,7 @@
 	// TODO(jba): either make this actually test eventual consistency, or
 	// delete it. Currently it behaves the same with or without the
 	// EventualConsistency call.
-	if testing.Short() {
+	if testing.Short() && !replaying {
 		t.Skip("Integration tests skipped in short mode")
 	}
 	ctx := context.Background()
@@ -704,7 +709,7 @@
 }
 
 func TestProjection(t *testing.T) {
-	if testing.Short() {
+	if testing.Short() && !replaying {
 		t.Skip("Integration tests skipped in short mode")
 	}
 	ctx := context.Background()
@@ -750,7 +755,7 @@
 }
 
 func TestAllocateIDs(t *testing.T) {
-	if testing.Short() {
+	if testing.Short() && !replaying {
 		t.Skip("Integration tests skipped in short mode")
 	}
 	ctx := context.Background()
@@ -776,7 +781,7 @@
 }
 
 func TestGetAllWithFieldMismatch(t *testing.T) {
-	if testing.Short() {
+	if testing.Short() && !replaying {
 		t.Skip("Integration tests skipped in short mode")
 	}
 	ctx := context.Background()
@@ -823,7 +828,7 @@
 }
 
 func TestKindlessQueries(t *testing.T) {
-	if testing.Short() {
+	if testing.Short() && !replaying {
 		t.Skip("Integration tests skipped in short mode")
 	}
 	ctx := context.Background()
@@ -947,7 +952,7 @@
 }
 
 func TestTransaction(t *testing.T) {
-	if testing.Short() {
+	if testing.Short() && !replaying {
 		t.Skip("Integration tests skipped in short mode")
 	}
 	ctx := context.Background()
@@ -1053,7 +1058,7 @@
 }
 
 func TestNilPointers(t *testing.T) {
-	if testing.Short() {
+	if testing.Short() && !replaying {
 		t.Skip("Integration tests skipped in short mode")
 	}
 	ctx := context.Background()
@@ -1091,7 +1096,7 @@
 }
 
 func TestNestedRepeatedElementNoIndex(t *testing.T) {
-	if testing.Short() {
+	if testing.Short() && !replaying {
 		t.Skip("Integration tests skipped in short mode")
 	}
 	ctx := context.Background()