// Copyright 2019 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package bigquery

import (
	"context"
	"fmt"
	"time"

	"cloud.google.com/go/internal/optional"
	"cloud.google.com/go/internal/trace"
	bq "google.golang.org/api/bigquery/v2"
)

// Model represents a reference to a BigQuery ML model.
// Within the API, models are used largely for communicating
// statistical information about a given model, as creation of models is only
// supported via BigQuery queries (e.g. CREATE MODEL .. AS ..).
//
// For more information, see the BigQuery ML documentation:
// https://cloud.google.com/bigquery/docs/bigqueryml
type Model struct {
	// ProjectID is the ID of the project that contains the model.
	ProjectID string
	// DatasetID is the ID of the dataset that contains the model.
	DatasetID string
	// ModelID must contain only letters (a-z, A-Z), numbers (0-9), or underscores (_).
	// The maximum length is 1,024 characters.
	ModelID string

	// c is the client whose underlying service is used to issue the
	// Models API calls (get, patch, delete) for this reference.
	c *Client
}

// FullyQualifiedName returns the model's identifier rendered as
// projectID:datasetID.modelID.
func (m *Model) FullyQualifiedName() string {
	// Every operand is a string, so Sprint concatenates them without
	// inserting any separating spaces.
	return fmt.Sprint(m.ProjectID, ":", m.DatasetID, ".", m.ModelID)
}

// toBQ converts the model reference into the API's ModelReference
// representation, copying the project, dataset and model IDs.
func (m *Model) toBQ() *bq.ModelReference {
	ref := new(bq.ModelReference)
	ref.ProjectId = m.ProjectID
	ref.DatasetId = m.DatasetID
	ref.ModelId = m.ModelID
	return ref
}

// Metadata retrieves the model's metadata, including its ML training
// statistics, from the service.
//
// The call is wrapped in a trace span that is ended with the error (if any)
// returned to the caller.
func (m *Model) Metadata(ctx context.Context) (mm *ModelMetadata, err error) {
	ctx = trace.StartSpan(ctx, "cloud.google.com/go/bigquery.Model.Metadata")
	// err is the named result, so the span records whatever is finally returned.
	defer func() { trace.EndSpan(ctx, err) }()

	call := m.c.bqs.Models.Get(m.ProjectID, m.DatasetID, m.ModelID).Context(ctx)
	setClientHeader(call.Header())

	var fetched *bq.Model
	if err := runWithRetry(ctx, func() error {
		var doErr error
		fetched, doErr = call.Do()
		return doErr
	}); err != nil {
		return nil, err
	}
	return bqToModelMetadata(fetched)
}

// Update applies the changes described by mm to the model and returns the
// resulting metadata.
//
// When etag is non-empty it is sent as an If-Match header on the patch
// request. An error is returned without contacting the service if mm cannot
// be converted into an API request body.
func (m *Model) Update(ctx context.Context, mm ModelMetadataToUpdate, etag string) (md *ModelMetadata, err error) {
	ctx = trace.StartSpan(ctx, "cloud.google.com/go/bigquery.Model.Update")
	// err is the named result, so the span records whatever is finally returned.
	defer func() { trace.EndSpan(ctx, err) }()

	patch, err := mm.toBQ()
	if err != nil {
		return nil, err
	}

	req := m.c.bqs.Models.Patch(m.ProjectID, m.DatasetID, m.ModelID, patch).Context(ctx)
	setClientHeader(req.Header())
	if len(etag) > 0 {
		req.Header().Set("If-Match", etag)
	}

	var updated *bq.Model
	err = runWithRetry(ctx, func() (err error) {
		updated, err = req.Do()
		return err
	})
	if err != nil {
		return nil, err
	}
	return bqToModelMetadata(updated)
}

// Delete deletes an ML model.
//
// The delete request is issued through runWithRetry, the same helper that
// Metadata and Update use for their API calls, so that errors the helper
// treats as retryable are handled consistently across all Model operations.
// The call is wrapped in a trace span that is ended with the returned error.
func (m *Model) Delete(ctx context.Context) (err error) {
	ctx = trace.StartSpan(ctx, "cloud.google.com/go/bigquery.Model.Delete")
	// err is the named result, so the span records whatever is finally returned.
	defer func() { trace.EndSpan(ctx, err) }()

	call := m.c.bqs.Models.Delete(m.ProjectID, m.DatasetID, m.ModelID).Context(ctx)
	setClientHeader(call.Header())
	return runWithRetry(ctx, func() error {
		return call.Do()
	})
}

// ModelMetadata represents information about a BigQuery ML model.
//
// The unexported column and training-run fields hold values in the
// underlying API representation; they are exposed through the
// RawFeatureColumns, RawLabelColumns and RawTrainingRuns methods.
type ModelMetadata struct {
	// The user-friendly description of the model.
	Description string

	// The user-friendly name of the model.
	Name string

	// The type of the model.  Possible values include:
	// "LINEAR_REGRESSION" - a linear regression model
	// "LOGISTIC_REGRESSION" - a logistic regression model
	// "KMEANS" - a k-means clustering model
	Type string

	// The creation time of the model.
	CreationTime time.Time

	// The last modified time of the model.
	LastModifiedTime time.Time

	// The expiration time of the model.
	ExpirationTime time.Time

	// The geographic location where the model resides.  This value is
	// inherited from the encapsulating dataset.
	Location string

	// Custom encryption configuration (e.g., Cloud KMS keys).
	EncryptionConfig *EncryptionConfig

	// The input feature columns used to train the model.
	featureColumns []*bq.StandardSqlField

	// The label columns used to train the model.  Output
	// from the model will have a "predicted_" prefix for these columns.
	labelColumns []*bq.StandardSqlField

	// Information for all training runs, ordered by increasing start times.
	trainingRuns []*bq.TrainingRun

	// The labels associated with the model.
	Labels map[string]string

	// ETag is the ETag obtained when reading metadata. Pass it to Model.Update
	// to ensure that the metadata hasn't changed since it was read.
	ETag string
}

// TrainingRun represents information about a single training run for a BigQuery ML model.
// It is defined with bq.TrainingRun from "google.golang.org/api/bigquery/v2" as its
// underlying type, so it carries the same fields as that API type.
//
// Experimental:  This information may be modified or removed in future versions of this package.
type TrainingRun bq.TrainingRun

// RawTrainingRuns returns the model's training run statistics. Each element
// is a copy of the underlying "google.golang.org/api/bigquery/v2" value
// converted to TrainingRun; those API types are subject to change without
// warning. It returns nil when the metadata holds no training runs.
// It is EXPERIMENTAL and subject to change or removal without notice.
func (mm *ModelMetadata) RawTrainingRuns() []*TrainingRun {
	n := len(mm.trainingRuns)
	if n == 0 {
		return nil
	}

	runs := make([]*TrainingRun, 0, n)
	for i := range mm.trainingRuns {
		// Convert a copy so callers do not alias the stored API value.
		run := TrainingRun(*mm.trainingRuns[i])
		runs = append(runs, &run)
	}
	return runs
}

// RawLabelColumns returns the label columns used to train an ML model. The
// columns are stored using types from "google.golang.org/api/bigquery/v2",
// which are subject to change without warning, and are converted to
// StandardSQLField values on each call.
// It is EXPERIMENTAL and subject to change or removal without notice.
func (mm *ModelMetadata) RawLabelColumns() ([]*StandardSQLField, error) {
	raw := mm.labelColumns
	return bqToModelCols(raw)
}

// RawFeatureColumns returns the input feature columns used to train an ML
// model. The columns are stored using types from
// "google.golang.org/api/bigquery/v2", which are subject to change without
// warning, and are converted to StandardSQLField values on each call.
// It is EXPERIMENTAL and subject to change or removal without notice.
func (mm *ModelMetadata) RawFeatureColumns() ([]*StandardSQLField, error) {
	raw := mm.featureColumns
	return bqToModelCols(raw)
}

// bqToModelCols converts a slice of API column definitions into
// StandardSQLField values, preserving order. It returns a nil slice when s
// has no elements, and stops at the first element that fails to convert,
// returning that error.
func bqToModelCols(s []*bq.StandardSqlField) ([]*StandardSQLField, error) {
	if len(s) == 0 {
		return nil, nil
	}

	converted := make([]*StandardSQLField, 0, len(s))
	for i := range s {
		field, err := bqToStandardSQLField(s[i])
		if err != nil {
			return nil, err
		}
		converted = append(converted, field)
	}
	return converted, nil
}

// bqToModelMetadata builds a ModelMetadata from the API's Model
// representation. Timestamps are converted from Unix milliseconds, and the
// column and training-run slices are stored as-is (not copied). The returned
// error is always nil.
func bqToModelMetadata(m *bq.Model) (*ModelMetadata, error) {
	md := new(ModelMetadata)

	// Plain descriptive fields.
	md.Description = m.Description
	md.Name = m.FriendlyName
	md.Type = m.ModelType
	md.Location = m.Location
	md.Labels = m.Labels
	md.ETag = m.Etag

	// Fields that need conversion from the API representation.
	md.ExpirationTime = unixMillisToTime(m.ExpirationTime)
	md.CreationTime = unixMillisToTime(m.CreationTime)
	md.LastModifiedTime = unixMillisToTime(m.LastModifiedTime)
	md.EncryptionConfig = bqToEncryptionConfig(m.EncryptionConfiguration)

	// Raw API values, surfaced through the Raw* accessor methods.
	md.featureColumns = m.FeatureColumns
	md.labelColumns = m.LabelColumns
	md.trainingRuns = m.TrainingRuns

	return md, nil
}

// ModelMetadataToUpdate is used when updating an ML model's metadata.
// Only non-nil fields will be updated; for ExpirationTime, the zero value
// plays the role of "unset".
type ModelMetadataToUpdate struct {
	// The user-friendly description of this model.
	Description optional.String

	// The user-friendly name of this model.
	Name optional.String

	// The time when this model expires.  To remove a model's expiration,
	// set ExpirationTime to NeverExpire.  The zero value is ignored.
	ExpirationTime time.Time

	// The model's encryption configuration.
	EncryptionConfig *EncryptionConfig

	// labelUpdater is embedded to carry pending label changes; toBQ merges
	// the result of its update method into the request.
	// NOTE(review): labelUpdater is declared elsewhere in the package —
	// confirm which exported methods it promotes onto this type.
	labelUpdater
}

// toBQ converts the update description into the API's Model representation
// for use as a patch body.
//
// Description and Name are sent only when non-nil, and are force-sent so an
// empty string is still transmitted. EncryptionConfig is sent when non-nil.
// ExpirationTime is validated first; NeverExpire results in a null
// ExpirationTime field, any other non-zero time is sent as Unix milliseconds,
// and the zero time leaves the field untouched. Label changes come from
// mm.update. An error is returned only for an invalid expiration time.
func (mm *ModelMetadataToUpdate) toBQ() (*bq.Model, error) {
	out := new(bq.Model)

	if mm.Description != nil {
		out.Description = optional.ToString(mm.Description)
		out.ForceSendFields = append(out.ForceSendFields, "Description")
	}

	if mm.Name != nil {
		out.FriendlyName = optional.ToString(mm.Name)
		out.ForceSendFields = append(out.ForceSendFields, "FriendlyName")
	}

	if mm.EncryptionConfig != nil {
		out.EncryptionConfiguration = mm.EncryptionConfig.toBQ()
	}

	if !validExpiration(mm.ExpirationTime) {
		return nil, invalidTimeError(mm.ExpirationTime)
	}
	switch {
	case mm.ExpirationTime == NeverExpire:
		// An explicit null removes the model's expiration.
		out.NullFields = append(out.NullFields, "ExpirationTime")
	case !mm.ExpirationTime.IsZero():
		// Nanoseconds to milliseconds.
		out.ExpirationTime = mm.ExpirationTime.UnixNano() / 1e6
		out.ForceSendFields = append(out.ForceSendFields, "ExpirationTime")
	}

	labels, forceFields, nullFields := mm.update()
	out.Labels = labels
	out.ForceSendFields = append(out.ForceSendFields, forceFields...)
	out.NullFields = append(out.NullFields, nullFields...)
	return out, nil
}
