// Copyright 2019 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package testutil

import (
	"bytes"
	"context"
	"errors"
	"fmt"
	"log"
	"os"
	"strings"

	"google.golang.org/api/option"
	"google.golang.org/grpc"
	"google.golang.org/grpc/metadata"
)

// HeaderChecker defines header checking and validation rules for any outgoing metadata.
//
// A HeaderChecker pairs one metadata key with a function that decides whether
// the values sent under that key are acceptable.
type HeaderChecker struct {
	// Key is the header name to be checked against e.g. "x-goog-api-client".
	// A header is reported as missing when Key is absent from the outgoing
	// metadata.
	Key string

	// ValuesValidator validates the header values retrieved from mapping against
	// Key in the Headers. It receives every value recorded for Key; returning a
	// non-nil error marks the header as invalid and the error text is included
	// in the failure report.
	ValuesValidator func(values ...string) error
}

// HeadersEnforcer asserts that outgoing RPC headers
// are present and match expectations. If the expected headers
// are not present or don't match expectations, it'll invoke OnFailure
// with the validation error, or instead log.Fatalf if OnFailure is nil.
//
// It expects that every declared key will be present in the outgoing
// RPC header and each value will be validated by the validation function.
type HeadersEnforcer struct {
	// Checkers maps header keys that are expected to be sent in the metadata
	// of outgoing gRPC requests, against the values passed into the custom
	// validation functions.
	//
	// If Checkers is nil or empty, only the default header "x-goog-api-client"
	// will be checked for (using XGoogClientHeaderChecker).
	// Otherwise, if you supply Checkers, those keys and their respective
	// validation functions will be checked.
	Checkers []*HeaderChecker

	// OnFailure is the function that will be invoked after all validation
	// failures have been composed. It receives a format string followed by
	// its arguments. If OnFailure is nil, the failure is written to standard
	// error via a logger's Fatalf, which terminates the process.
	OnFailure func(fmt_ string, args ...interface{})
}

// StreamInterceptors returns the StreamClientInterceptor functions that verify
// the expected headers are present and valid on streaming RPCs.
//
// Clients that already install their own StreamClientInterceptor(s) should
// pass these interceptors as the last arguments to
// WithChainStreamInterceptor.
//
// Alternatively, users may apply the gRPC options produced by DialOptions,
// which install all applicable gRPC interceptors at once.
func (h *HeadersEnforcer) StreamInterceptors() []grpc.StreamClientInterceptor {
	interceptors := []grpc.StreamClientInterceptor{
		h.interceptStream,
	}
	return interceptors
}

// UnaryInterceptors returns the UnaryClientInterceptor functions that verify
// the expected headers are present and valid on unary RPCs.
//
// Clients that already install their own UnaryClientInterceptor(s) should
// pass these interceptors as the last arguments to
// WithChainUnaryInterceptor.
//
// Alternatively, users may apply the gRPC options produced by DialOptions,
// which install all applicable gRPC interceptors at once.
func (h *HeadersEnforcer) UnaryInterceptors() []grpc.UnaryClientInterceptor {
	interceptors := []grpc.UnaryClientInterceptor{
		h.interceptUnary,
	}
	return interceptors
}

// DialOptions returns the gRPC DialOptions that install the header-enforcing
// interceptors: the stream interceptor first, followed by the unary one.
func (h *HeadersEnforcer) DialOptions() []grpc.DialOption {
	streamOpt := grpc.WithStreamInterceptor(h.interceptStream)
	unaryOpt := grpc.WithUnaryInterceptor(h.interceptUnary)
	return []grpc.DialOption{streamOpt, unaryOpt}
}

// CallOptions returns ClientOptions that wrap each DialOption produced by
// DialOptions, in the same order, so that the header-enforcing unary and
// stream interceptors are installed on clients configured via ClientOption.
func (h *HeadersEnforcer) CallOptions() (copts []option.ClientOption) {
	for _, dialOpt := range h.DialOptions() {
		wrapped := option.WithGRPCDialOption(dialOpt)
		copts = append(copts, wrapped)
	}
	return copts
}

// interceptUnary is a grpc.UnaryClientInterceptor that validates the outgoing
// metadata carried by ctx and then hands the call to invoker unchanged,
// returning whatever error invoker returns.
func (h *HeadersEnforcer) interceptUnary(
	ctx context.Context,
	method string,
	req, res interface{},
	cc *grpc.ClientConn,
	invoker grpc.UnaryInvoker,
	opts ...grpc.CallOption,
) error {
	// Headers are validated before the RPC is issued; failures are reported
	// through checkMetadata and do not alter the call itself.
	h.checkMetadata(ctx, method)

	return invoker(ctx, method, req, res, cc, opts...)
}

// interceptStream is a grpc.StreamClientInterceptor that validates the
// outgoing metadata carried by ctx and then hands the call to streamer
// unchanged, returning whatever stream and error streamer returns.
func (h *HeadersEnforcer) interceptStream(
	ctx context.Context,
	desc *grpc.StreamDesc,
	cc *grpc.ClientConn,
	method string,
	streamer grpc.Streamer,
	opts ...grpc.CallOption,
) (grpc.ClientStream, error) {
	// Headers are validated before the stream is created; failures are
	// reported through checkMetadata and do not alter the call itself.
	h.checkMetadata(ctx, method)

	return streamer(ctx, desc, cc, method, opts...)
}

// XGoogClientHeaderChecker is a HeaderChecker for the "x-goog-api-client"
// header on outgoing metadata. Its validator rejects an empty value list and
// otherwise accepts the header as soon as any value contains "gl-go/".
var XGoogClientHeaderChecker = &HeaderChecker{
	Key: "x-goog-api-client",
	ValuesValidator: func(values ...string) error {
		if len(values) == 0 {
			return errors.New("expecting values")
		}
		for _, v := range values {
			// TODO: check for exact version strings.
			if strings.Contains(v, "gl-go/") {
				return nil
			}
			// Add other recognized markers here.
		}
		return errors.New("unmatched values")
	},
}

// DefaultHeadersEnforcer returns a HeadersEnforcer whose only checker is
// XGoogClientHeaderChecker, i.e. at bare minimum it checks that the
// "x-goog-api-client" key is present in the outgoing metadata headers.
// OnFailure is left nil, so any validation failure is reported through
// log.Fatalf with the error message.
func DefaultHeadersEnforcer() *HeadersEnforcer {
	defaultCheckers := []*HeaderChecker{XGoogClientHeaderChecker}
	return &HeadersEnforcer{Checkers: defaultCheckers}
}

// checkMetadata validates the outgoing metadata attached to ctx for the RPC
// named by method.
//
// Every configured checker must find its Key in the metadata and, when it has
// a ValuesValidator, that validator must accept the values. If h.Checkers is
// empty, XGoogClientHeaderChecker is used instead. All problems are collected
// and reported in a single call to h.OnFailure; when OnFailure is nil the
// report is written to standard error and the process exits via Fatalf.
//
// A nil entry in Checkers is reported as a failure rather than dereferenced,
// and a checker with a nil ValuesValidator only asserts that the header is
// present.
func (h *HeadersEnforcer) checkMetadata(ctx context.Context, method string) {
	onFailure := h.OnFailure
	if onFailure == nil {
		lgr := log.New(os.Stderr, "", 0) // Do not log the time prefix, it is noisy in test failure logs.
		onFailure = func(fmt_ string, args ...interface{}) {
			lgr.Fatalf(fmt_, args...)
		}
	}

	md, ok := metadata.FromOutgoingContext(ctx)
	if !ok {
		onFailure("Missing metadata for method %q", method)
		return
	}

	checkers := h.Checkers
	if len(checkers) == 0 {
		// Instead use the default HeaderChecker. A fresh slice is built rather
		// than appending, so the backing array of h.Checkers is never written
		// to (it may be shared by concurrent RPCs).
		checkers = []*HeaderChecker{XGoogClientHeaderChecker}
	}

	errBuf := new(bytes.Buffer)
	for i, checker := range checkers {
		if checker == nil {
			fmt.Fprintf(errBuf, "checker #%d is nil\n", i)
			continue
		}

		hdrKey := checker.Key
		outHdrValues, ok := md[hdrKey]
		if !ok {
			// gRPC's metadata helpers store keys lowercased, so retry with
			// the lowercased key to accept e.g. "X-Goog-Api-Client".
			outHdrValues, ok = md[strings.ToLower(hdrKey)]
		}
		if !ok {
			fmt.Fprintf(errBuf, "missing header %q\n", hdrKey)
			continue
		}

		if checker.ValuesValidator == nil {
			// Presence-only check: nothing further to validate.
			continue
		}
		if err := checker.ValuesValidator(outHdrValues...); err != nil {
			fmt.Fprintf(errBuf, "header %q: %v\n", hdrKey, err)
		}
	}

	if errBuf.Len() != 0 {
		onFailure("For method %q, errors:\n%s", method, errBuf)
	}
}
