bigquery: test that model queries work
Test that CREATE MODEL and PREDICT both work, and also that we can
see a model when we list tables.
Change-Id: Ia4f9ef7e8a98043da953661bbc06d4f237dddd85
Reviewed-on: https://code-review.googlesource.com/30210
Reviewed-by: kokoro <noreply+kokoro@google.com>
Reviewed-by: Jean de Klerk <deklerk@google.com>
diff --git a/bigquery/integration_test.go b/bigquery/integration_test.go
index ca30bbb..82e3626 100644
--- a/bigquery/integration_test.go
+++ b/bigquery/integration_test.go
@@ -1153,7 +1153,7 @@
('b', [1], STRUCT<BOOL>(FALSE)),
('c', [2], STRUCT<BOOL>(TRUE))`,
table.DatasetID, table.TableID)
- if err := dmlInsert(ctx, sql); err != nil {
+ if err := runDML(ctx, sql); err != nil {
t.Fatal(err)
}
wantRows := [][]Value{
@@ -1164,7 +1164,7 @@
checkRead(t, "DML", table.Read(ctx), wantRows)
}
-func dmlInsert(ctx context.Context, sql string) error {
+func runDML(ctx context.Context, sql string) error {
// Retry insert; sometimes it fails with INTERNAL.
return internal.Retry(ctx, gax.Backoff{}, func() (bool, error) {
// Use DML to insert.
@@ -1223,7 +1223,7 @@
"VALUES ('%s', '%s', '%s', '%s')",
table.DatasetID, table.TableID,
d, CivilTimeString(tm), CivilDateTimeString(dtm), ts.Format("2006-01-02 15:04:05"))
- if err := dmlInsert(ctx, query); err != nil {
+ if err := runDML(ctx, query); err != nil {
t.Fatal(err)
}
wantRows = append(wantRows, wantRows[0])
@@ -1508,7 +1508,7 @@
sql := fmt.Sprintf(`INSERT %s.%s (name, num)
VALUES ('a', 1), ('b', 2), ('c', 3)`,
table.DatasetID, table.TableID)
- if err := dmlInsert(ctx, sql); err != nil {
+ if err := runDML(ctx, sql); err != nil {
t.Fatal(err)
}
// Extract to a GCS object as CSV.
@@ -1872,6 +1872,73 @@
}
}
+func TestIntegration_Model(t *testing.T) {
+	// Exercises an ML model end to end: CREATE MODEL, ml.PREDICT, table listing and deletion.
+	if client == nil {
+		t.Skip("Integration tests skipped")
+	}
+	ctx := context.Background()
+	schema := Schema{
+		{Name: "input", Type: IntegerFieldType},
+		{Name: "label", Type: IntegerFieldType},
+	}
+	table := newTable(t, schema)
+	defer table.Delete(ctx)
+
+	// Insert the training data: four (input, label) rows.
+	tableName := fmt.Sprintf("%s.%s", table.DatasetID, table.TableID)
+	sql := fmt.Sprintf(`INSERT %s (input, label)
+ VALUES (1, 0), (2, 1), (3, 0), (4, 1)`,
+		tableName)
+	wantNumRows := 4 // the prediction query below is expected to return one row per inserted row
+	if err := runDML(ctx, sql); err != nil {
+		t.Fatal(err)
+	}
+
+	model := dataset.Table(tableIDs.New()) // unique ID, like newTable, so reruns cannot collide on a fixed name
+	modelName := fmt.Sprintf("%s.%s", model.DatasetID, model.TableID)
+	sql = fmt.Sprintf(`CREATE MODEL %s OPTIONS (model_type='logistic_reg') AS SELECT input, label FROM %s`,
+		modelName, tableName)
+	if err := runDML(ctx, sql); err != nil {
+		t.Fatal(err)
+	}
+	defer model.Delete(ctx) // best-effort cleanup when a check below fails first; its error is ignored
+
+	sql = fmt.Sprintf(`SELECT * FROM ml.PREDICT(MODEL %s, TABLE %s)`, modelName, tableName)
+	q := client.Query(sql)
+	ri, err := q.Read(ctx)
+	if err != nil {
+		t.Fatal(err)
+	}
+	rows, _, _, err := readAll(ri)
+	if err != nil {
+		t.Fatal(err)
+	}
+	if got := len(rows); got != wantNumRows {
+		t.Fatalf("got %d rows in prediction table, want %d", got, wantNumRows)
+	}
+	iter := dataset.Tables(ctx) // the model must appear when the dataset's tables are listed
+	seen := false
+	for {
+		tbl, err := iter.Next()
+		if err == iterator.Done {
+			break
+		}
+		if err != nil {
+			t.Fatal(err)
+		}
+		if tbl.TableID == model.TableID {
+			seen = true
+		}
+	}
+	if !seen {
+		t.Fatal("model not listed in dataset")
+	}
+	if err := model.Delete(ctx); err != nil { // explicit delete so that a failure to delete fails the test
+		t.Fatal(err)
+	}
+}
+
// Creates a new, temporary table with a unique name and the given schema.
func newTable(t *testing.T, s Schema) *Table {
table := dataset.Table(tableIDs.New())