// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package auth

import (
	"bytes"
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"mime"
	"net/http"
	"net/url"
	"strconv"
	"strings"
	"time"

	"cloud.google.com/go/auth/internal"
)

// AuthorizationHandler is a 3-legged-OAuth helper that prompts the user for
// OAuth consent at the specified auth code URL and returns an auth code and
// state upon approval.
//
// The returned state is compared against the state configured in
// [AuthorizationHandlerOptions] before the auth code is exchanged for a
// token. A non-nil err aborts the flow and is returned to the caller
// unchanged.
type AuthorizationHandler func(authCodeURL string) (code string, state string, err error)

// Options3LO are the options for doing a 3-legged OAuth2 flow.
type Options3LO struct {
	// ClientID is the application's ID.
	ClientID string
	// ClientSecret is the application's secret. Not required if AuthHandlerOpts
	// is set.
	ClientSecret string
	// AuthURL is the URL for authenticating.
	AuthURL string
	// TokenURL is the URL for retrieving a token.
	TokenURL string
	// AuthStyle describes how client info is sent in the token request. It
	// must be set to a value other than StyleUnknown.
	AuthStyle Style
	// RefreshToken is the token used to refresh the credential. Not required
	// if AuthHandlerOpts is set.
	RefreshToken string
	// RedirectURL is the URL to redirect users to. Optional.
	RedirectURL string
	// Scopes specifies requested permissions for the Token. Optional.
	Scopes []string

	// URLParams are the set of values to apply to the token exchange and
	// token refresh requests. Only the first value of each key is sent.
	// Optional.
	URLParams url.Values
	// Client is the client to be used to make the underlying token requests.
	// Optional.
	Client *http.Client
	// EarlyTokenExpiry is the time before the token expires that it should be
	// refreshed. If not set the default value is 10 seconds. Optional.
	EarlyTokenExpiry time.Duration

	// AuthHandlerOpts provides a set of options for doing a
	// 3-legged OAuth2 flow with a custom [AuthorizationHandler]. Optional.
	AuthHandlerOpts *AuthorizationHandlerOptions
}

// validate reports the first missing required option, or nil when the options
// are complete enough to build a token provider.
//
// ClientID, AuthURL, TokenURL and a known AuthStyle are always required.
// ClientSecret and RefreshToken are required only when AuthHandlerOpts is not
// set. When AuthHandlerOpts is set its Handler must be non-nil, because the
// handler-based provider invokes it on every token fetch and a nil function
// would otherwise panic at that point rather than fail here.
func (o *Options3LO) validate() error {
	if o == nil {
		return errors.New("auth: options must be provided")
	}
	if o.ClientID == "" {
		return errors.New("auth: client ID must be provided")
	}
	if o.AuthHandlerOpts == nil && o.ClientSecret == "" {
		return errors.New("auth: client secret must be provided")
	}
	if o.AuthURL == "" {
		return errors.New("auth: auth URL must be provided")
	}
	if o.TokenURL == "" {
		return errors.New("auth: token URL must be provided")
	}
	if o.AuthStyle == StyleUnknown {
		return errors.New("auth: auth style must be provided")
	}
	if o.AuthHandlerOpts == nil && o.RefreshToken == "" {
		return errors.New("auth: refresh token must be provided")
	}
	if o.AuthHandlerOpts != nil && o.AuthHandlerOpts.Handler == nil {
		return errors.New("auth: authorization handler must be provided")
	}
	return nil
}

// PKCEOptions holds parameters to support PKCE.
type PKCEOptions struct {
	// Challenge is the un-padded, base64-url-encoded string of the encrypted
	// code verifier. When non-empty it is added to the auth code URL.
	Challenge string
	// ChallengeMethod is the encryption method (ex. S256). When non-empty it
	// is added to the auth code URL.
	ChallengeMethod string
	// Verifier is the original, non-encrypted secret. When non-empty it is
	// sent with the auth code during the token exchange.
	Verifier string
}

// tokenJSON mirrors the JSON body returned by a token endpoint. It carries
// both the fields of a successful response and the fields of an error
// response so a single decode covers either outcome.
type tokenJSON struct {
	AccessToken  string `json:"access_token"`
	TokenType    string `json:"token_type"`
	RefreshToken string `json:"refresh_token"`
	ExpiresIn    int    `json:"expires_in"`
	// error fields
	ErrorCode        string `json:"error"`
	ErrorDescription string `json:"error_description"`
	ErrorURI         string `json:"error_uri"`
}

// expiry converts the relative ExpiresIn lifetime, in seconds, into an
// absolute time measured from the moment of the call. A zero ExpiresIn yields
// the zero time.Time.
func (e *tokenJSON) expiry() time.Time {
	if e.ExpiresIn == 0 {
		return time.Time{}
	}
	lifetime := time.Duration(e.ExpiresIn) * time.Second
	return time.Now().Add(lifetime)
}

// client returns the HTTP client to use for token requests: the
// caller-supplied Client when one is set, otherwise whatever
// internal.CloneDefaultClient provides.
func (o *Options3LO) client() *http.Client {
	if c := o.Client; c != nil {
		return c
	}
	return internal.CloneDefaultClient()
}

// authCodeURL returns a URL that points to an OAuth2 consent page.
//
// The query always carries response_type=code and the client ID. The redirect
// URI, space-joined scopes, state, and PKCE challenge and challenge method
// are added when they are non-empty. Entries in values are applied last, so
// they can add extra parameters or override the defaults; only the first
// value of each key is used. The query is appended to AuthURL with "&" when
// AuthURL already contains a "?", and with "?" otherwise.
func (o *Options3LO) authCodeURL(state string, values url.Values) string {
	var buf bytes.Buffer
	buf.WriteString(o.AuthURL)
	v := url.Values{
		"response_type": {"code"},
		"client_id":     {o.ClientID},
	}
	if o.RedirectURL != "" {
		v.Set("redirect_uri", o.RedirectURL)
	}
	if len(o.Scopes) > 0 {
		v.Set("scope", strings.Join(o.Scopes, " "))
	}
	if state != "" {
		v.Set("state", state)
	}
	if o.AuthHandlerOpts != nil && o.AuthHandlerOpts.PKCEOpts != nil {
		pkce := o.AuthHandlerOpts.PKCEOpts
		if pkce.Challenge != "" {
			v.Set(codeChallengeKey, pkce.Challenge)
		}
		if pkce.ChallengeMethod != "" {
			v.Set(codeChallengeMethodKey, pkce.ChallengeMethod)
		}
	}
	// Read from the caller's values, not from v itself: copying v onto v
	// would drop every caller-supplied parameter.
	for k := range values {
		v.Set(k, values.Get(k))
	}
	if strings.Contains(o.AuthURL, "?") {
		buf.WriteByte('&')
	} else {
		buf.WriteByte('?')
	}
	buf.WriteString(v.Encode())
	return buf.String()
}

// New3LOTokenProvider returns a [TokenProvider] based on the 3-legged OAuth2
// configuration. The TokenProvider caches and auto-refreshes tokens by
// default.
//
// The options are validated first and any validation error is returned
// as-is. When opts.AuthHandlerOpts is set the returned provider drives the
// authorization handler; otherwise it refreshes tokens using
// opts.RefreshToken.
func New3LOTokenProvider(opts *Options3LO) (TokenProvider, error) {
	if err := opts.validate(); err != nil {
		return nil, err
	}
	if opts.AuthHandlerOpts != nil {
		return new3LOTokenProviderWithAuthHandler(opts), nil
	}
	refresher := &tokenProvider3LO{
		opts:         opts,
		refreshToken: opts.RefreshToken,
		client:       opts.client(),
	}
	cacheOpts := &CachedTokenProviderOptions{ExpireEarly: opts.EarlyTokenExpiry}
	return NewCachedTokenProvider(refresher, cacheOpts), nil
}

// AuthorizationHandlerOptions provides a set of options to specify for doing a
// 3-legged OAuth2 flow with a custom [AuthorizationHandler].
type AuthorizationHandlerOptions struct {
	// Handler specifies the handler used for the authorization part of the
	// flow. It is called with the auth code URL each time a token is fetched.
	Handler AuthorizationHandler
	// State is used to verify that the "state" is identical in the request
	// and response before exchanging the auth code for an OAuth2 token.
	State string
	// PKCEOpts allows setting configurations for PKCE. Optional.
	PKCEOpts *PKCEOptions
}

// new3LOTokenProviderWithAuthHandler builds the handler-driven token provider
// and wraps it in a cached provider. opts.AuthHandlerOpts must be non-nil;
// its State is captured for later comparison with the handler's reply.
func new3LOTokenProviderWithAuthHandler(opts *Options3LO) TokenProvider {
	handlerProvider := &tokenProviderWithHandler{
		opts:  opts,
		state: opts.AuthHandlerOpts.State,
	}
	cacheOpts := &CachedTokenProviderOptions{ExpireEarly: opts.EarlyTokenExpiry}
	return NewCachedTokenProvider(handlerProvider, cacheOpts)
}

// exchange performs the last leg of the 3LO flow, trading the authorization
// code for a token. It returns the Token, the refresh token reported by the
// server, and any error from the token request.
//
// The request uses the authorization_code grant and includes the redirect
// URI and PKCE verifier when they are configured. Entries in URLParams are
// applied last, first value only.
func (o *Options3LO) exchange(ctx context.Context, code string) (*Token, string, error) {
	params := url.Values{}
	params.Set("grant_type", "authorization_code")
	params.Set("code", code)
	if o.RedirectURL != "" {
		params.Set("redirect_uri", o.RedirectURL)
	}
	if h := o.AuthHandlerOpts; h != nil && h.PKCEOpts != nil && h.PKCEOpts.Verifier != "" {
		params.Set(codeVerifierKey, h.PKCEOpts.Verifier)
	}
	for key := range o.URLParams {
		params.Set(key, o.URLParams.Get(key))
	}
	return fetchToken(ctx, o, params)
}

// tokenProvider3LO fetches tokens with the refresh_token grant. It keeps the
// current refresh token and replaces it whenever the server returns a
// different, non-empty one.
//
// This struct is not safe for concurrent access alone, but the way it is used
// in this package by wrapping it with a cachedTokenProvider makes it so.
//
// NOTE(review): client is populated at construction, but the Token method
// issues its request through fetchToken, which obtains the HTTP client from
// opts rather than from this field.
type tokenProvider3LO struct {
	opts         *Options3LO
	client       *http.Client
	refreshToken string
}

// Token trades the stored refresh token for a fresh access token. It fails
// immediately when no refresh token is held. If the server answers with a
// new, non-empty refresh token, that token is kept for the next call.
func (tp *tokenProvider3LO) Token(ctx context.Context) (*Token, error) {
	if tp.refreshToken == "" {
		return nil, errors.New("auth: token expired and refresh token is not set")
	}
	params := url.Values{}
	params.Set("grant_type", "refresh_token")
	params.Set("refresh_token", tp.refreshToken)
	for key := range tp.opts.URLParams {
		params.Set(key, tp.opts.URLParams.Get(key))
	}

	tok, rotated, err := fetchToken(ctx, tp.opts, params)
	if err != nil {
		return nil, err
	}
	if rotated != "" && rotated != tp.refreshToken {
		tp.refreshToken = rotated
	}
	return tok, nil
}

// tokenProviderWithHandler obtains tokens by running the configured
// AuthorizationHandler and exchanging the auth code it returns. state is the
// value sent in the auth code URL and expected back from the handler.
type tokenProviderWithHandler struct {
	opts  *Options3LO
	state string
}

// Token runs the interactive part of the flow: it builds the auth code URL,
// passes it to the authorization handler, checks that the returned state
// matches the expected one, and exchanges the auth code for a token. The
// refresh token returned by the exchange is discarded.
func (tp tokenProviderWithHandler) Token(ctx context.Context) (*Token, error) {
	consentURL := tp.opts.authCodeURL(tp.state, nil)
	code, returnedState, err := tp.opts.AuthHandlerOpts.Handler(consentURL)
	switch {
	case err != nil:
		return nil, err
	case returnedState != tp.state:
		return nil, errors.New("auth: state mismatch in 3-legged-OAuth flow")
	}
	tok, _, err := tp.opts.exchange(ctx, code)
	return tok, err
}

// fetchToken posts the form values v to o.TokenURL and returns a Token, the
// refresh token found in the response, and/or an error.
//
// Client credentials are attached according to o.AuthStyle: with
// StyleInParams they are added to v (which is modified in place), with
// StyleInHeader they are query-escaped and sent as HTTP Basic auth.
//
// The response body is decoded as a query string when its Content-Type is
// application/x-www-form-urlencoded or text/plain, and as JSON otherwise.
// An *Error carrying the response and body is returned when the status is
// outside 2xx or the body reports an error code. A plain error is returned
// when a 2xx body cannot be parsed or lacks an access token. The refresh
// token is returned alongside an error whenever the body was parsed.
func fetchToken(ctx context.Context, o *Options3LO, v url.Values) (*Token, string, error) {
	var refreshToken string
	if o.AuthStyle == StyleInParams {
		if o.ClientID != "" {
			v.Set("client_id", o.ClientID)
		}
		if o.ClientSecret != "" {
			v.Set("client_secret", o.ClientSecret)
		}
	}
	// Bind the context when the request is created; this avoids the extra
	// request copy made by WithContext and reports a nil context as an error.
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, o.TokenURL, strings.NewReader(v.Encode()))
	if err != nil {
		return nil, refreshToken, err
	}
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	if o.AuthStyle == StyleInHeader {
		req.SetBasicAuth(url.QueryEscape(o.ClientID), url.QueryEscape(o.ClientSecret))
	}

	// Make request
	r, err := o.client().Do(req)
	if err != nil {
		return nil, refreshToken, err
	}
	defer r.Body.Close()
	body, err := internal.ReadAll(r.Body)
	if err != nil {
		return nil, refreshToken, fmt.Errorf("auth: cannot fetch token: %w", err)
	}

	failureStatus := r.StatusCode < 200 || r.StatusCode > 299
	tokError := &Error{
		Response: r,
		Body:     body,
	}

	var token *Token
	// errors ignored because of default switch on content
	content, _, _ := mime.ParseMediaType(r.Header.Get("Content-Type"))
	switch content {
	case "application/x-www-form-urlencoded", "text/plain":
		// some endpoints return a query string
		vals, err := url.ParseQuery(string(body))
		if err != nil {
			if failureStatus {
				return nil, refreshToken, tokError
			}
			return nil, refreshToken, fmt.Errorf("auth: cannot parse response: %w", err)
		}
		tokError.code = vals.Get("error")
		tokError.description = vals.Get("error_description")
		tokError.uri = vals.Get("error_uri")
		token = &Token{
			Value:    vals.Get("access_token"),
			Type:     vals.Get("token_type"),
			Metadata: make(map[string]interface{}, len(vals)),
		}
		// Every response field is kept as metadata, as the raw []string.
		for key, fieldVals := range vals {
			token.Metadata[key] = fieldVals
		}
		refreshToken = vals.Get("refresh_token")
		// A missing or non-numeric expires_in parses to 0 and leaves Expiry
		// unset, so the conversion error is deliberately ignored.
		expires, _ := strconv.Atoi(vals.Get("expires_in"))
		if expires != 0 {
			token.Expiry = time.Now().Add(time.Duration(expires) * time.Second)
		}
	default:
		var tj tokenJSON
		if err = json.Unmarshal(body, &tj); err != nil {
			if failureStatus {
				return nil, refreshToken, tokError
			}
			return nil, refreshToken, fmt.Errorf("auth: cannot parse json: %w", err)
		}
		tokError.code = tj.ErrorCode
		tokError.description = tj.ErrorDescription
		tokError.uri = tj.ErrorURI
		token = &Token{
			Value:    tj.AccessToken,
			Type:     tj.TokenType,
			Expiry:   tj.expiry(),
			Metadata: make(map[string]interface{}),
		}
		// Metadata is optional, so a failure to decode it is not fatal; the
		// error is skipped on purpose.
		_ = json.Unmarshal(body, &token.Metadata)
		refreshToken = tj.RefreshToken
	}
	// according to spec, servers should respond status 400 in error case
	// https://www.rfc-editor.org/rfc/rfc6749#section-5.2
	// but some unorthodox servers respond 200 in error case
	if failureStatus || tokError.code != "" {
		return nil, refreshToken, tokError
	}
	if token.Value == "" {
		return nil, refreshToken, errors.New("auth: server response missing access_token")
	}
	return token, refreshToken, nil
}
