blob: d3f26ae1dbb8279a53b9a245203e8a7d7dd9d93c [file] [log] [blame]
// Copyright 2014 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package main
import (
"fmt"
"log"
"net/http"
"os"
"path/filepath"
"strings"
"google.golang.org/api/googleapi"
prediction "google.golang.org/api/prediction/v1.6"
)
// init registers this demo under the name "prediction", passing the
// space-separated list of scopes it requests and predictionMain as the
// entry point.
func init() {
	scope := strings.Join([]string{
		prediction.DevstorageFullControlScope,
		prediction.DevstorageReadOnlyScope,
		prediction.DevstorageReadWriteScope,
		prediction.PredictionScope,
	}, " ")
	registerDemo("prediction", scope, predictionMain)
}
// predictionType bundles the Prediction API service handle with the
// command-line arguments that identify the model to train and query.
type predictionType struct {
	// api is the Prediction service used for every request.
	api *prediction.Service

	// projectNumber is passed as the project argument of each Trainedmodels call.
	projectNumber string

	// bucketName and trainingFileName are joined to form the storage
	// location of the training data when a new model is inserted.
	bucketName       string
	trainingFileName string

	// modelName is the ID of the trained model.
	modelName string
}
// predictionMain is the entry point of the Prediction API demo.
//
// Training data is uploaded beforehand to a pre-created Google Cloud Storage
// bucket; the Prediction API is then asked to train a model from that data.
// After a few minutes the model should be fully trained and ready for
// prediction. At that point a piece of text is sent to the model, the
// Prediction API attempts to classify it, and the results are printed.
//
// To get started, follow the "Hello Prediction!" Getting Started Guide:
// https://developers.google.com/prediction/docs/hello_world
//
// argv must hold exactly four values, in order: project number, bucket name,
// training data file name and model name. Otherwise a usage line is written
// to standard error and the function returns without doing anything.
//
// Example usage:
//
//	go-api-demo -clientid="my-clientid" -secret="my-secret" prediction
//	    my-project-number my-bucket-name my-training-filename my-model-name
//
// Example output:
//
//	Predict result: language=Spanish
//	English Score: 0.000000
//	French Score: 0.000000
//	Spanish Score: 1.000000
//	analyze: output feature text=&{157 English}
//	analyze: output feature text=&{149 French}
//	analyze: output feature text=&{100 Spanish}
//	feature text count=406
func predictionMain(client *http.Client, argv []string) {
	const usage = "Usage: prediction project_number bucket training_data model_name"
	if len(argv) != 4 {
		fmt.Fprintln(os.Stderr, usage)
		return
	}

	svc, err := prediction.New(client)
	if err != nil {
		log.Fatalf("unable to create prediction API client: %v", err)
	}

	// Positional arguments map onto the demo state in the order given in
	// the usage line.
	demo := new(predictionType)
	demo.api = svc
	demo.projectNumber, demo.bucketName = argv[0], argv[1]
	demo.trainingFileName, demo.modelName = argv[2], argv[3]

	demo.trainModel()
	demo.predictModel()
}
// trainModel makes sure a trained model with ID t.modelName exists.
//
// It first looks the model up. If the lookup fails with an API error whose
// code is 404 Not Found, a new model is inserted with its training data
// location built from t.bucketName and t.trainingFileName. Any other lookup
// failure, or a failed insert, terminates the program via log.Fatalf.
//
// If the model's training status is anything other than "DONE", a message is
// printed and the process exits with status 0; the caller is only returned
// to once the model is ready for predictions.
func (t *predictionType) trainModel() {
	// First, check to see if our trained model already exists.
	res, err := t.api.Trainedmodels.Get(t.projectNumber, t.modelName).Do()
	if err != nil {
		// Only an API-level 404 means "model does not exist". Errors of any
		// other type (e.g. transport failures) or with any other status code
		// must not be mistaken for a missing model.
		if ae, ok := err.(*googleapi.Error); !ok || ae.Code != http.StatusNotFound {
			log.Fatalf("error getting trained model: %v", err)
		}
		log.Printf("Training model not found, creating new model.")
		// The storage location is a slash-separated bucket/object path, so
		// normalise the OS-specific separator produced by filepath.Join.
		location := filepath.ToSlash(filepath.Join(t.bucketName, t.trainingFileName))
		res, err = t.api.Trainedmodels.Insert(t.projectNumber, &prediction.Insert{
			Id:                  t.modelName,
			StorageDataLocation: location,
		}).Do()
		if err != nil {
			log.Fatalf("unable to create trained model: %v", err)
		}
	}
	if res.TrainingStatus != "DONE" {
		// Training is still in progress; the user has to re-run later.
		fmt.Println("Training model. Please wait and re-run program after a few minutes.")
		os.Exit(0)
	}
}
// predictModel sends a fixed Spanish sentence to the trained model, prints
// the predicted label together with the score of every candidate label, and
// then prints details from the model analysis: each text entry of the output
// feature and the text count of each text input feature.
//
// Any API failure terminates the program via log.Fatalf.
func (t *predictionType) predictModel() {
	// Model has now been trained. Predict with it.
	input := &prediction.Input{
		Input: &prediction.InputInput{
			CsvInstance: []interface{}{
				"Hola, con quien hablo",
			},
		},
	}
	res, err := t.api.Trainedmodels.Predict(t.projectNumber, t.modelName, input).Do()
	if err != nil {
		log.Fatalf("unable to get trained prediction: %v", err)
	}
	fmt.Printf("Predict result: language=%v\n", res.OutputLabel)
	for _, m := range res.OutputMulti {
		fmt.Printf("%v Score: %v\n", m.Label, m.Score)
	}

	// Now analyze the model.
	an, err := t.api.Trainedmodels.Analyze(t.projectNumber, t.modelName).Do()
	if err != nil {
		log.Fatalf("unable to analyze trained model: %v", err)
	}

	// The description and its nested objects are optional in the response;
	// guard each pointer so an absent section is skipped instead of causing
	// a nil dereference.
	desc := an.DataDescription
	if desc == nil {
		return
	}
	if out := desc.OutputFeature; out != nil {
		for _, f := range out.Text {
			fmt.Printf("analyze: output feature text=%v\n", f)
		}
	}
	for _, f := range desc.Features {
		// Non-text features carry no Text statistics.
		if f == nil || f.Text == nil {
			continue
		}
		fmt.Printf("feature text count=%v\n", f.Text.Count)
	}
}