blob: d90b1c63d7c9c877dd4b622ebba50ee0d4c1ab29 [file] [log] [blame]
// Copyright 2019 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package testutil
import (
"bytes"
"context"
"errors"
"fmt"
"log"
"os"
"strings"
"google.golang.org/api/option"
"google.golang.org/grpc"
"google.golang.org/grpc/metadata"
)
// HeaderChecker defines header checking and validation rules for any outgoing metadata.
//
// HeadersEnforcer.checkMetadata looks Key up in the outgoing metadata of each
// intercepted RPC and, when the key is present, hands its values to
// ValuesValidator.
type HeaderChecker struct {
// Key is the header name to be checked against e.g. "x-goog-api-client".
Key string
// ValuesValidator validates the header values retrieved from mapping against
// Key in the Headers. A nil result accepts the values; a non-nil error is
// reported as a validation failure for Key.
ValuesValidator func(values ...string) error
}
// HeadersEnforcer asserts that outgoing RPC headers
// are present and match expectations. If the expected headers
// are not present or don't match expectations, it'll invoke OnFailure
// with the validation error, or instead log.Fatal if OnFailure is nil.
//
// It expects that every declared key will be present in the outgoing
// RPC header and each value will be validated by the validation function.
type HeadersEnforcer struct {
// Checkers lists the header keys that are expected to be sent in the metadata
// of outgoing gRPC requests, each paired with the custom validation function
// that the header's values are passed into.
//
// If Checkers is nil or empty, only the default header "x-goog-api-client"
// will be checked for.
// Otherwise, if you supply Checkers, those keys and their respective
// validation functions will be checked.
Checkers []*HeaderChecker
// OnFailure is the function that will be invoked after all validation
// failures have been composed, or when the outgoing context carries no
// metadata at all. It receives a format string followed by its arguments.
// If OnFailure is nil, a logger writing to os.Stderr calls Fatalf instead.
OnFailure func(fmt_ string, args ...interface{})
}
// StreamInterceptors provides the StreamClientInterceptor functions that
// verify the expected headers are present and valid on streaming RPCs.
//
// Clients that install their own StreamClientInterceptor(s) should pass
// these interceptors as the last arguments to WithChainStreamInterceptor.
//
// Alternatively, the gRPC options produced by DialOptions install every
// applicable gRPC interceptor in one step.
func (h *HeadersEnforcer) StreamInterceptors() []grpc.StreamClientInterceptor {
	enforce := grpc.StreamClientInterceptor(h.interceptStream)
	return []grpc.StreamClientInterceptor{enforce}
}
// UnaryInterceptors provides the UnaryClientInterceptor functions that
// verify the expected headers are present and valid on unary RPCs.
//
// Clients that install their own UnaryClientInterceptor(s) should pass
// these interceptors as the last arguments to WithChainUnaryInterceptor.
//
// Alternatively, the gRPC options produced by DialOptions install every
// applicable gRPC interceptor in one step.
func (h *HeadersEnforcer) UnaryInterceptors() []grpc.UnaryClientInterceptor {
	enforce := grpc.UnaryClientInterceptor(h.interceptUnary)
	return []grpc.UnaryClientInterceptor{enforce}
}
// DialOptions builds the gRPC DialOptions that chain this enforcer's stream
// interceptor and unary interceptor, in that order, so that the presence and
// validity of the expected headers is enforced on a connection.
func (h *HeadersEnforcer) DialOptions() []grpc.DialOption {
	streamOpt := grpc.WithChainStreamInterceptor(h.interceptStream)
	unaryOpt := grpc.WithChainUnaryInterceptor(h.interceptUnary)
	return []grpc.DialOption{streamOpt, unaryOpt}
}
// CallOptions returns ClientOptions that carry the header-enforcing unary and
// stream interceptors. Each DialOption produced by DialOptions is wrapped with
// option.WithGRPCDialOption, preserving order.
func (h *HeadersEnforcer) CallOptions() (copts []option.ClientOption) {
	dialOpts := h.DialOptions()
	for i := range dialOpts {
		wrapped := option.WithGRPCDialOption(dialOpts[i])
		copts = append(copts, wrapped)
	}
	return copts
}
// interceptUnary is the unary client interceptor. It validates the outgoing
// metadata carried by ctx, then continues the call through invoker and
// returns whatever error invoker reports.
func (h *HeadersEnforcer) interceptUnary(ctx context.Context, method string, req, res interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
	h.checkMetadata(ctx, method)
	err := invoker(ctx, method, req, res, cc, opts...)
	return err
}
// interceptStream is the stream client interceptor. It validates the outgoing
// metadata carried by ctx, then opens the stream through streamer and returns
// streamer's results unchanged.
func (h *HeadersEnforcer) interceptStream(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) {
	h.checkMetadata(ctx, method)
	stream, err := streamer(ctx, desc, cc, method, opts...)
	return stream, err
}
// XGoogClientHeaderChecker is a HeaderChecker for the "x-goog-api-client"
// header on outgoing metadata. Its validator accepts the header as soon as any
// one value contains "gl-go/", and rejects an empty value list or a list in
// which no value matches.
var XGoogClientHeaderChecker = &HeaderChecker{
	Key: "x-goog-api-client",
	ValuesValidator: func(values ...string) error {
		if len(values) < 1 {
			return errors.New("expecting values")
		}
		for i := range values {
			// TODO: check for exact version strings.
			// Checks for other client identifiers can be added here.
			if strings.Contains(values[i], "gl-go/") {
				return nil
			}
		}
		return errors.New("unmatched values")
	},
}
// DefaultHeadersEnforcer builds a new HeadersEnforcer whose only checker is
// XGoogClientHeaderChecker, i.e. at bare minimum it checks that the
// "x-goog-api-client" key is present in the outgoing metadata headers.
// OnFailure is left nil, so any validation failure ends in log.Fatalf with the
// error message.
func DefaultHeadersEnforcer() *HeadersEnforcer {
	enforcer := new(HeadersEnforcer)
	enforcer.Checkers = []*HeaderChecker{XGoogClientHeaderChecker}
	return enforcer
}
// checkMetadata validates the outgoing metadata attached to ctx for the RPC
// named by method and reports any problem through h.OnFailure.
//
// When h.OnFailure is nil, failures are reported with Fatalf on a logger that
// writes to os.Stderr. When h.Checkers is empty, XGoogClientHeaderChecker is
// used as the only checker.
//
// If ctx carries no outgoing metadata, a single "Missing metadata" failure is
// reported and no checker runs. Otherwise every checker runs, and all missing
// headers and validator errors are reported together in one failure call.
//
// A checker's Key is looked up exactly as written first and then lowercased,
// so a mixed-case Key still matches metadata whose keys were lowercased.
func (h *HeadersEnforcer) checkMetadata(ctx context.Context, method string) {
	onFailure := h.OnFailure
	if onFailure == nil {
		// Do not log the time prefix, it is noisy in test failure logs.
		lgr := log.New(os.Stderr, "", 0)
		onFailure = lgr.Fatalf
	}

	md, ok := metadata.FromOutgoingContext(ctx)
	if !ok {
		onFailure("Missing metadata for method %q", method)
		return
	}

	checkers := h.Checkers
	if len(checkers) == 0 {
		// Fall back to the default HeaderChecker. A fresh slice is built
		// instead of appending to h.Checkers, so the caller's backing array
		// is never written to.
		checkers = []*HeaderChecker{XGoogClientHeaderChecker}
	}

	var errBuf bytes.Buffer
	for _, checker := range checkers {
		hdrKey := checker.Key
		outHdrValues, ok := md[hdrKey]
		if !ok {
			// gRPC's metadata helpers (New, Pairs, Append,
			// AppendToOutgoingContext) store keys lowercased, so retry
			// with the lowercased key before declaring it missing.
			outHdrValues, ok = md[strings.ToLower(hdrKey)]
		}
		if !ok {
			fmt.Fprintf(&errBuf, "missing header %q\n", hdrKey)
			continue
		}
		if err := checker.ValuesValidator(outHdrValues...); err != nil {
			fmt.Fprintf(&errBuf, "header %q: %v\n", hdrKey, err)
		}
	}

	if errBuf.Len() != 0 {
		onFailure("For method %q, errors:\n%s", method, errBuf.String())
	}
}