bigquery: add Schema and TotalRows

Make the schema and the totalRows fields of the tabledata.list
response available via exported fields on the RowIterator.

Fixes #765.

Change-Id: I87258c75bfca0e2f5515605782b2d1ac1151db5c
Reviewed-on: https://code-review.googlesource.com/21190
Reviewed-by: kokoro <noreply+kokoro@google.com>
Reviewed-by: Michael Darakananda <pongad@google.com>
diff --git a/bigquery/integration_test.go b/bigquery/integration_test.go
index 593aa2d..dfaba93 100644
--- a/bigquery/integration_test.go
+++ b/bigquery/integration_test.go
@@ -600,7 +600,7 @@
 	}
 
 	// Test reading directly into a []Value.
-	valueLists, err := readAll(table.Read(ctx))
+	valueLists, schema, _, err := readAll(table.Read(ctx))
 	if err != nil {
 		t.Fatal(err)
 	}
@@ -610,6 +610,9 @@
 		if err := it.Next(&got); err != nil {
 			t.Fatal(err)
 		}
+		if !testutil.Equal(it.Schema, schema) {
+			t.Fatalf("got schema %v, want %v", it.Schema, schema)
+		}
 		want := []Value(vl)
 		if !testutil.Equal(got, want) {
 			t.Errorf("%d: got %v, want %v", i, got, want)
@@ -954,7 +957,7 @@
 	if err := wait(ctx, job); err != nil {
 		t.Fatal(err)
 	}
-	checkRead(t, "reader load", table.Read(ctx), wantRows)
+	checkReadAndTotalRows(t, "reader load", table.Read(ctx), wantRows)
 
 }
 
@@ -1284,7 +1287,7 @@
 	if err != nil {
 		t.Fatal(err)
 	}
-	checkRead(t, "external query", iter, wantRows)
+	checkReadAndTotalRows(t, "external query", iter, wantRows)
 
 	// Make a table pointing to the file, and query it.
 	// BigQuery does not allow a Table.Read on an external table.
@@ -1302,7 +1305,7 @@
 	if err != nil {
 		t.Fatal(err)
 	}
-	checkRead(t, "external table", iter, wantRows)
+	checkReadAndTotalRows(t, "external table", iter, wantRows)
 
 	// While we're here, check that the table metadata is correct.
 	md, err := table.Metadata(ctx)
@@ -1466,19 +1469,28 @@
 }
 
 func checkRead(t *testing.T, msg string, it *RowIterator, want [][]Value) {
-	if msg2, ok := compareRead(it, want); !ok {
+	if msg2, ok := compareRead(it, want, false); !ok {
 		t.Errorf("%s: %s", msg, msg2)
 	}
 }
 
-func compareRead(it *RowIterator, want [][]Value) (msg string, ok bool) {
-	got, err := readAll(it)
+func checkReadAndTotalRows(t *testing.T, msg string, it *RowIterator, want [][]Value) {
+	if msg2, ok := compareRead(it, want, true); !ok {
+		t.Errorf("%s: %s", msg, msg2)
+	}
+}
+
+func compareRead(it *RowIterator, want [][]Value, compareTotalRows bool) (msg string, ok bool) {
+	got, _, totalRows, err := readAll(it)
 	if err != nil {
 		return err.Error(), false
 	}
 	if len(got) != len(want) {
 		return fmt.Sprintf("got %d rows, want %d", len(got), len(want)), false
 	}
+	if compareTotalRows && len(got) != int(totalRows) {
+		return fmt.Sprintf("got %d rows, but totalRows = %d", len(got), totalRows), false
+	}
 	sort.Sort(byCol0(got))
 	for i, r := range got {
 		gotRow := []Value(r)
@@ -1490,18 +1502,24 @@
 	return "", true
 }
 
-func readAll(it *RowIterator) ([][]Value, error) {
-	var rows [][]Value
+func readAll(it *RowIterator) ([][]Value, Schema, uint64, error) {
+	var (
+		rows      [][]Value
+		schema    Schema
+		totalRows uint64
+	)
 	for {
 		var vals []Value
 		err := it.Next(&vals)
 		if err == iterator.Done {
-			return rows, nil
+			return rows, schema, totalRows, nil
 		}
 		if err != nil {
-			return nil, err
+			return nil, nil, 0, err
 		}
 		rows = append(rows, vals)
+		schema = it.Schema
+		totalRows = it.TotalRows
 	}
 }
 
diff --git a/bigquery/iterator.go b/bigquery/iterator.go
index 5a82edc..f3ce2e5 100644
--- a/bigquery/iterator.go
+++ b/bigquery/iterator.go
@@ -48,9 +48,14 @@
 	// is also set, StartIndex is ignored.
 	StartIndex uint64
 
-	rows [][]Value
+	// The schema of the table. Available after the first call to Next.
+	Schema Schema
 
-	schema       Schema       // populated on first call to fetch
+	// The total number of rows in the result. Available after the first call to Next.
+	// May be zero just after rows were inserted.
+	TotalRows uint64
+
+	rows         [][]Value
 	structLoader structLoader // used to populate a pointer to a struct
 }
 
@@ -113,12 +118,12 @@
 	if vl == nil {
 		// This can only happen if dst is a pointer to a struct. We couldn't
 		// set vl above because we need the schema.
-		if err := it.structLoader.set(dst, it.schema); err != nil {
+		if err := it.structLoader.set(dst, it.Schema); err != nil {
 			return err
 		}
 		vl = &it.structLoader
 	}
-	return vl.Load(row, it.schema)
+	return vl.Load(row, it.Schema)
 }
 
 func isStructPtr(x interface{}) bool {
@@ -130,12 +135,13 @@
 func (it *RowIterator) PageInfo() *iterator.PageInfo { return it.pageInfo }
 
 func (it *RowIterator) fetch(pageSize int, pageToken string) (string, error) {
-	res, err := it.pf(it.ctx, it.table, it.schema, it.StartIndex, int64(pageSize), pageToken)
+	res, err := it.pf(it.ctx, it.table, it.Schema, it.StartIndex, int64(pageSize), pageToken)
 	if err != nil {
 		return "", err
 	}
 	it.rows = append(it.rows, res.rows...)
-	it.schema = res.schema
+	it.Schema = res.schema
+	it.TotalRows = res.totalRows
 	return res.pageToken, nil
 }
 
diff --git a/bigquery/iterator_test.go b/bigquery/iterator_test.go
index 1ecfb53..50cf94f 100644
--- a/bigquery/iterator_test.go
+++ b/bigquery/iterator_test.go
@@ -64,6 +64,7 @@
 		want           [][]Value
 		wantErr        error
 		wantSchema     Schema
+		wantTotalRows  uint64
 	}{
 		{
 			desc: "Iteration over single empty page",
@@ -87,11 +88,13 @@
 						pageToken: "",
 						rows:      [][]Value{{1, 2}, {11, 12}},
 						schema:    iiSchema,
+						totalRows: 4,
 					},
 				},
 			},
-			want:       [][]Value{{1, 2}, {11, 12}},
-			wantSchema: iiSchema,
+			want:          [][]Value{{1, 2}, {11, 12}},
+			wantSchema:    iiSchema,
+			wantTotalRows: 4,
 		},
 		{
 			desc: "Iteration over single page with different schema",
@@ -115,6 +118,7 @@
 						pageToken: "a",
 						rows:      [][]Value{{1, 2}, {11, 12}},
 						schema:    iiSchema,
+						totalRows: 4,
 					},
 				},
 				"a": {
@@ -122,11 +126,13 @@
 						pageToken: "",
 						rows:      [][]Value{{101, 102}, {111, 112}},
 						schema:    iiSchema,
+						totalRows: 4,
 					},
 				},
 			},
-			want:       [][]Value{{1, 2}, {11, 12}, {101, 102}, {111, 112}},
-			wantSchema: iiSchema,
+			want:          [][]Value{{1, 2}, {11, 12}, {101, 102}, {111, 112}},
+			wantSchema:    iiSchema,
+			wantTotalRows: 4,
 		},
 		{
 			desc: "Server response includes empty page",
@@ -240,7 +246,7 @@
 		}
 		it := newRowIterator(context.Background(), nil, pf.fetchPage)
 		it.PageInfo().Token = tc.pageToken
-		values, schema, err := consumeRowIterator(it)
+		values, schema, totalRows, err := consumeRowIterator(it)
 		if err != tc.wantErr {
 			t.Fatalf("%s: got %v, want %v", tc.desc, err, tc.wantErr)
 		}
@@ -250,35 +256,31 @@
 		if (len(schema) != 0 || len(tc.wantSchema) != 0) && !testutil.Equal(schema, tc.wantSchema) {
 			t.Errorf("%s: iterator.Schema:\ngot: %v\nwant: %v", tc.desc, schema, tc.wantSchema)
 		}
+		if totalRows != tc.wantTotalRows {
+			t.Errorf("%s: totalRows: got %d, want %d", tc.desc, totalRows, tc.wantTotalRows)
+		}
 	}
 }
 
-type valueListWithSchema struct {
-	vals   valueList
-	schema Schema
-}
-
-func (v *valueListWithSchema) Load(vs []Value, s Schema) error {
-	v.vals.Load(vs, s)
-	v.schema = s
-	return nil
-}
-
 // consumeRowIterator reads the schema and all values from a RowIterator and returns them.
-func consumeRowIterator(it *RowIterator) ([][]Value, Schema, error) {
-	var got [][]Value
-	var schema Schema
+func consumeRowIterator(it *RowIterator) ([][]Value, Schema, uint64, error) {
+	var (
+		got       [][]Value
+		schema    Schema
+		totalRows uint64
+	)
 	for {
-		var vls valueListWithSchema
+		var vls []Value
 		err := it.Next(&vls)
 		if err == iterator.Done {
-			return got, schema, nil
+			return got, schema, totalRows, nil
 		}
 		if err != nil {
-			return got, schema, err
+			return got, schema, totalRows, err
 		}
-		got = append(got, vls.vals)
-		schema = vls.schema
+		got = append(got, vls)
+		schema = it.Schema
+		totalRows = it.TotalRows
 	}
 }
 
@@ -333,7 +335,7 @@
 		}
 		it := newRowIterator(context.Background(), nil, pf.fetchPage)
 
-		values, _, err := consumeRowIterator(it)
+		values, _, _, err := consumeRowIterator(it)
 		if err != nil {
 			t.Fatal(err)
 		}
diff --git a/bigquery/job.go b/bigquery/job.go
index 91040ed..b415e9f 100644
--- a/bigquery/job.go
+++ b/bigquery/job.go
@@ -271,7 +271,7 @@
 	}
 	dt := bqToTable(destTable, j.c)
 	it := newRowIterator(ctx, dt, pf)
-	it.schema = schema
+	it.Schema = schema
 	return it, nil
 }