blob: 37c5e1aa1e02f2b6800ed6082539698588c857d9 [file] [log] [blame]
// Copyright 2019 Google LLC
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package testutil
import (
	"bytes"
	"context"
	"errors"
	"fmt"
	"log"
	"os"
	"strings"

	"google.golang.org/api/option"
	"google.golang.org/grpc"
	"google.golang.org/grpc/metadata"
)
// HeaderChecker pairs an outgoing-metadata key with a function that judges
// the values sent under that key.
type HeaderChecker struct {
	// Key names the header to look up in outgoing metadata,
	// for example "x-goog-api-client".
	Key string

	// ValuesValidator receives the values found under Key and returns a
	// non-nil error describing why they are unacceptable.
	ValuesValidator func(values ...string) error
}
// HeadersEnforcer asserts that outgoing RPC headers
// are present and match expectations. If the expected headers
// are not present or don't match expectations, it'll invoke OnFailure
// with the validation error, or instead log.Fatal if OnFailure is nil.
// It expects that every declared key will be present in the outgoing
// RPC header and each value will be validated by the validation function.
type HeadersEnforcer struct {
// Checkers maps header keys that are expected to be sent in the metadata
// of outgoing gRPC requests, against the values passed into the custom
// validation functions.
// If Checkers is nil or empty, only the default header "x-goog-api-client"
// will be checked for.
// Otherwise, if you supply Matchers, those keys and their respective
// validation functions will be checked.
Checkers []*HeaderChecker
// OnFailure is the function that will be invoked after all validation
// failures have been composed. If OnFailure is nil, log.Fatal will be
// invoked instead.
OnFailure func(fmt_ string, args ...interface{})
// DialOptions returns gRPC DialOptions consisting of unary and stream interceptors
// to enforce the presence and validity of expected headers.
func (h *HeadersEnforcer) DialOptions() []grpc.DialOption {
return []grpc.DialOption{
// CallOptions returns ClientOptions consisting of unary and stream interceptors
// to enforce the presence and validity of expected headers.
func (h *HeadersEnforcer) CallOptions() (copts []option.ClientOption) {
dopts := h.DialOptions()
for _, dopt := range dopts {
copts = append(copts, option.WithGRPCDialOption(dopt))
func (h *HeadersEnforcer) interceptUnary(ctx context.Context, method string, req, res interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
h.checkMetadata(ctx, method)
return invoker(ctx, method, req, res, cc, opts...)
func (h *HeadersEnforcer) interceptStream(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) {
h.checkMetadata(ctx, method)
return streamer(ctx, desc, cc, method, opts...)
// XGoogClientHeaderChecker is a HeaderChecker that ensures that the "x-goog-api-client"
// header is present on outgoing metadata.
var XGoogClientHeaderChecker = &HeaderChecker{
Key: "x-goog-api-client",
ValuesValidator: func(values ...string) error {
if len(values) == 0 {
return errors.New("expecting values")
for _, value := range values {
switch {
case strings.Contains(value, "gl-go/"):
// TODO: check for exact version strings.
return nil
default: // Add others here.
return errors.New("unmatched values")
// DefaultHeadersEnforcer returns a HeadersEnforcer that at bare minimum checks that
// the "x-goog-api-client" key is present in the outgoing metadata headers. On any
// validation failure, it will invoke log.Fatalf with the error message.
func DefaultHeadersEnforcer() *HeadersEnforcer {
return &HeadersEnforcer{
Checkers: []*HeaderChecker{XGoogClientHeaderChecker},
func (h *HeadersEnforcer) checkMetadata(ctx context.Context, method string) {
onFailure := h.OnFailure
if onFailure == nil {
lgr := log.New(os.Stderr, "", 0) // Do not log the time prefix, it is noisy in test failure logs.
onFailure = func(fmt_ string, args ...interface{}) {
lgr.Fatalf(fmt_, args...)
md, ok := metadata.FromOutgoingContext(ctx)
if !ok {
onFailure("Missing metadata for method %q", method)
checkers := h.Checkers
if len(checkers) == 0 {
// Instead use the default HeaderChecker.
checkers = append(checkers, XGoogClientHeaderChecker)
errBuf := new(bytes.Buffer)
for _, checker := range checkers {
hdrKey := checker.Key
outHdrValues, ok := md[hdrKey]
if !ok {
fmt.Fprintf(errBuf, "missing header %q\n", hdrKey)
if err := checker.ValuesValidator(outHdrValues...); err != nil {
fmt.Fprintf(errBuf, "header %q: %v\n", hdrKey, err)
if errBuf.Len() != 0 {
onFailure("For method %q, errors:\n%s", method, errBuf)