blob: 456a06196a64ac6ff75a5bc686ecd71122260cb5 [file] [log] [blame]
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package grpctransport
import (
"context"
"errors"
"log"
"net"
"testing"
"cloud.google.com/go/auth"
"cloud.google.com/go/auth/credentials"
echo "cloud.google.com/go/auth/grpctransport/testdata"
"cloud.google.com/go/auth/internal"
"github.com/google/go-cmp/cmp"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/metadata"
)
// TestCheckDirectPathEndPoint checks which endpoint strings
// checkDirectPathEndPoint accepts: bare hosts, host:port pairs and dns:///
// targets are expected to pass, while an empty endpoint and non-dns schemes
// are expected to be rejected.
func TestCheckDirectPathEndPoint(t *testing.T) {
	type endpointCase struct {
		name     string
		endpoint string
		want     bool
	}
	cases := []endpointCase{
		{name: "empty endpoint are disallowed", endpoint: "", want: false},
		{name: "dns schemes are allowed", endpoint: "dns:///foo", want: true},
		{name: "host without no prefix are allowed", endpoint: "foo", want: true},
		{name: "host with port are allowed", endpoint: "foo:1234", want: true},
		{name: "non-dns schemes are disallowed", endpoint: "https://foo", want: false},
	}
	for _, c := range cases {
		t.Run(c.name, func(t *testing.T) {
			got := checkDirectPathEndPoint(c.endpoint)
			if got != c.want {
				t.Fatalf("got %v, want %v", got, c.want)
			}
		})
	}
}
// TestDial_FailsValidation verifies that Dial rejects invalid option sets:
// nil options, and DisableAuthentication combined with any explicit source of
// credentials (a Credentials value, a credentials file, or credentials JSON).
func TestDial_FailsValidation(t *testing.T) {
	tests := []struct {
		name string
		opts *Options
	}{
		{
			name: "missing options",
		},
		{
			name: "has creds with disable options, tp",
			opts: &Options{
				DisableAuthentication: true,
				Credentials: auth.NewCredentials(&auth.CredentialsOptions{
					TokenProvider: &staticTP{tok: &auth.Token{Value: "fakeToken"}},
				}),
			},
		},
		{
			name: "has creds with disable options, cred file",
			opts: &Options{
				DisableAuthentication: true,
				DetectOpts: &credentials.DetectOptions{
					CredentialsFile: "abc.123",
				},
			},
		},
		{
			name: "has creds with disable options, cred json",
			opts: &Options{
				DisableAuthentication: true,
				DetectOpts: &credentials.DetectOptions{
					CredentialsJSON: []byte(`{"foo":"bar"}`),
				},
			},
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			_, err := Dial(context.Background(), false, tt.opts)
			if err == nil {
				// Name the function actually under test so failures are
				// attributed correctly.
				t.Fatal("Dial() = _, nil, want error")
			}
		})
	}
}
// TestDial_SkipValidation shows that options which fail validate() (here,
// DisableAuthentication together with explicit Credentials) pass once
// InternalOptions.SkipValidation is set.
func TestDial_SkipValidation(t *testing.T) {
	o := &Options{
		DisableAuthentication: true,
		Credentials: auth.NewCredentials(&auth.CredentialsOptions{
			TokenProvider: &staticTP{tok: &auth.Token{Value: "fakeToken"}},
		}),
	}
	t.Run("invalid opts", func(t *testing.T) {
		err := o.validate()
		if err == nil {
			t.Fatalf("opts.validate() = nil, want error")
		}
	})
	t.Run("skip invalid opts", func(t *testing.T) {
		// NOTE: this mutates the options shared with the previous subtest, so
		// it relies on running after "invalid opts".
		o.InternalOptions = &InternalOptions{SkipValidation: true}
		err := o.validate()
		if err != nil {
			t.Fatalf("opts.validate() = %v, want nil", err)
		}
	})
}
// TestOptions_ResolveDetectOptions checks how resolveDetectOptions merges
// caller-provided DetectOpts with InternalOptions. Covered cases:
//   - Self-signed JWTs are enabled by EnableJWTWithScope or by an explicit
//     Audience.
//   - Default scopes and audiences apply only when the caller supplied
//     neither.
func TestOptions_ResolveDetectOptions(t *testing.T) {
	const credFile = "/path/to/a/file"
	cases := []struct {
		name string
		in   *Options
		want *credentials.DetectOptions
	}{
		{
			name: "base",
			in: &Options{
				DetectOpts: &credentials.DetectOptions{Scopes: []string{"scope"}, CredentialsFile: credFile},
			},
			want: &credentials.DetectOptions{Scopes: []string{"scope"}, CredentialsFile: credFile},
		},
		{
			name: "self-signed, with scope",
			in: &Options{
				InternalOptions: &InternalOptions{EnableJWTWithScope: true},
				DetectOpts:      &credentials.DetectOptions{Scopes: []string{"scope"}, CredentialsFile: credFile},
			},
			want: &credentials.DetectOptions{Scopes: []string{"scope"}, CredentialsFile: credFile, UseSelfSignedJWT: true},
		},
		{
			name: "self-signed, with aud",
			in: &Options{
				DetectOpts: &credentials.DetectOptions{Audience: "aud", CredentialsFile: credFile},
			},
			want: &credentials.DetectOptions{Audience: "aud", CredentialsFile: credFile, UseSelfSignedJWT: true},
		},
		{
			name: "use default scopes",
			in: &Options{
				InternalOptions: &InternalOptions{DefaultScopes: []string{"default"}, DefaultAudience: "default"},
				DetectOpts:      &credentials.DetectOptions{CredentialsFile: credFile},
			},
			want: &credentials.DetectOptions{Scopes: []string{"default"}, CredentialsFile: credFile},
		},
		{
			name: "don't use default scopes, scope provided",
			in: &Options{
				InternalOptions: &InternalOptions{DefaultScopes: []string{"default"}, DefaultAudience: "default"},
				DetectOpts:      &credentials.DetectOptions{Scopes: []string{"non-default"}, CredentialsFile: credFile},
			},
			want: &credentials.DetectOptions{Scopes: []string{"non-default"}, CredentialsFile: credFile},
		},
		{
			name: "don't use default scopes, aud provided",
			in: &Options{
				InternalOptions: &InternalOptions{DefaultScopes: []string{"default"}, DefaultAudience: "default"},
				DetectOpts:      &credentials.DetectOptions{Audience: "non-default", CredentialsFile: credFile},
			},
			want: &credentials.DetectOptions{Audience: "non-default", CredentialsFile: credFile, UseSelfSignedJWT: true},
		},
		{
			name: "use default aud",
			in: &Options{
				InternalOptions: &InternalOptions{DefaultAudience: "default"},
				DetectOpts:      &credentials.DetectOptions{CredentialsFile: credFile},
			},
			want: &credentials.DetectOptions{Audience: "default", CredentialsFile: credFile},
		},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			resolved := tc.in.resolveDetectOptions()
			diff := cmp.Diff(tc.want, resolved)
			if diff == "" {
				return
			}
			t.Errorf("mismatch (-want +got):\n%s", diff)
		})
	}
}
// TestGrpcCredentialsProvider_GetClientUniverseDomain checks the two expected
// behaviors of getClientUniverseDomain. An empty client universe domain
// resolves to internal.DefaultUniverseDomain; any other value is returned
// unchanged.
func TestGrpcCredentialsProvider_GetClientUniverseDomain(t *testing.T) {
	const customUniverse = "example.com"
	for _, tc := range []struct {
		name, universeDomain, want string
	}{
		{"default", "", internal.DefaultUniverseDomain},
		{"non-default", customUniverse, customUniverse},
	} {
		t.Run(tc.name, func(t *testing.T) {
			p := &grpcCredentialsProvider{clientUniverseDomain: tc.universeDomain}
			if got := p.getClientUniverseDomain(); got != tc.want {
				t.Errorf("got %q, want %q", got, tc.want)
			}
		})
	}
}
// TestGrpcCredentialsProvider_TokenType verifies the authorization header
// that GetRequestMetadata builds. It should be "<type> <value>", with "Bearer"
// used when the token has no type.
func TestGrpcCredentialsProvider_TokenType(t *testing.T) {
	tests := []struct {
		name string
		tok  *auth.Token
		want string
	}{
		{
			name: "type set",
			tok: &auth.Token{
				Value: "token",
				Type:  "Basic",
			},
			want: "Basic token",
		},
		{
			name: "type not set",
			tok: &auth.Token{
				Value: "token",
			},
			want: "Bearer token",
		},
	}
	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			cp := grpcCredentialsProvider{
				creds: &auth.Credentials{
					TokenProvider: &staticTP{tok: tc.tok},
				},
			}
			m, err := cp.GetRequestMetadata(context.Background(), "")
			if err != nil {
				// t.Fatalf rather than log.Fatalf: log.Fatalf calls os.Exit and
				// aborts the whole test binary without reporting this failure.
				t.Fatalf("cp.GetRequestMetadata() = %v, want nil", err)
			}
			if got := m["authorization"]; got != tc.want {
				t.Errorf("got %q, want %q", got, tc.want)
			}
		})
	}
}

// TestNewClient_DetectedServiceAccount dials a local insecure echo server.
// Credentials are detected from a service-account file using a self-signed
// JWT. The test checks that each request carries:
//   - an authorization header,
//   - the user-supplied "Foo" metadata,
//   - the quota project taken from the environment.
func TestNewClient_DetectedServiceAccount(t *testing.T) {
	testQuota := "testquota"
	wantHeader := "bar"
	t.Setenv(internal.QuotaProjectEnvVar, testQuota)
	l, err := net.Listen("tcp", "localhost:0")
	if err != nil {
		t.Fatal(err)
	}
	gsrv := grpc.NewServer()
	defer gsrv.Stop()
	echo.RegisterEchoerServer(gsrv, &fakeEchoService{
		Fn: func(ctx context.Context, _ *echo.EchoRequest) (*echo.EchoReply, error) {
			md, ok := metadata.FromIncomingContext(ctx)
			if !ok {
				t.Error("unable to extract metadata")
				return nil, errors.New("oops")
			}
			if got := md.Get("authorization"); len(got) != 1 {
				t.Errorf(`got "", want an auth token`)
			}
			if got := md.Get("Foo"); len(got) != 1 || got[0] != wantHeader {
				t.Errorf("got %q, want %q", got, wantHeader)
			}
			if got := md.Get(quotaProjectHeaderKey); len(got) != 1 || got[0] != testQuota {
				t.Errorf("got %q, want %q", got, testQuota)
			}
			return &echo.EchoReply{}, nil
		},
	})
	go func() {
		// Log instead of panicking. A panic in a background goroutine aborts
		// every test in the binary, and this goroutine may outlive the test, so
		// it must not call t methods. A server that fails to start still
		// surfaces as a client.Echo error below.
		if err := gsrv.Serve(l); err != nil {
			log.Printf("gsrv.Serve() = %v", err)
		}
	}()
	pool, err := Dial(context.Background(), false, &Options{
		Metadata: map[string]string{"Foo": wantHeader},
		InternalOptions: &InternalOptions{
			DefaultEndpointTemplate: l.Addr().String(),
		},
		DetectOpts: &credentials.DetectOptions{
			Audience:         l.Addr().String(),
			CredentialsFile:  "../internal/testdata/sa_universe_domain.json",
			UseSelfSignedJWT: true,
		},
		GRPCDialOpts:   []grpc.DialOption{grpc.WithTransportCredentials(insecure.NewCredentials())},
		UniverseDomain: "example.com", // Also configured in sa_universe_domain.json
	})
	if err != nil {
		t.Fatalf("Dial() = %v", err)
	}
	// Deferred after gsrv.Stop, so the client connection closes first (LIFO).
	defer pool.Close()
	client := echo.NewEchoerClient(pool)
	if _, err := client.Echo(context.Background(), &echo.EchoRequest{}); err != nil {
		t.Fatalf("client.Echo() = %v", err)
	}
}
// staticTP is a token-provider test double. It always returns the token it
// was built with and never reports an error.
type staticTP struct {
	tok *auth.Token
}

// Token hands back the stored token; the context argument is not consulted.
func (s *staticTP) Token(_ context.Context) (*auth.Token, error) {
	return s.tok, nil
}
// fakeEchoService is an Echoer server whose Echo behavior is supplied by the
// test through Fn. Embedding echo.UnimplementedEchoerServer supplies the
// remaining generated server methods.
type fakeEchoService struct {
	Fn func(context.Context, *echo.EchoRequest) (*echo.EchoReply, error)
	echo.UnimplementedEchoerServer
}

// Echo forwards the call verbatim to the test-provided Fn. Fn must be
// non-nil.
func (f *fakeEchoService) Echo(ctx context.Context, req *echo.EchoRequest) (*echo.EchoReply, error) {
	return f.Fn(ctx, req)
}