blob: ce574ab405ef6308f3f91a3bccb4c99f0165b7d7 [file] [log] [blame]
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package httptransport
import (
"context"
"net/http"
"net/http/httptest"
"strings"
"testing"
"cloud.google.com/go/auth"
"cloud.google.com/go/auth/credentials"
"cloud.google.com/go/auth/internal"
"github.com/google/go-cmp/cmp"
)
// TestAddAuthorizationMiddleware exercises AddAuthorizationMiddleware with
// missing and present client/credentials arguments. Cases marked wantErr only
// require a non-nil error; the remaining cases send a request through the
// client to a local test server and check that the Authorization header
// contains the expected token value.
func TestAddAuthorizationMiddleware(t *testing.T) {
	creds := auth.NewCredentials(&auth.CredentialsOptions{
		TokenProvider: staticTP("fakeToken"),
	})
	tests := []struct {
		name    string
		client  *http.Client
		creds   *auth.Credentials
		wantErr bool
		want    string
	}{
		{
			name:    "missing both required fields",
			wantErr: true,
		},
		{
			name:    "missing client field",
			creds:   creds,
			wantErr: true,
		},
		{
			name:    "missing creds field",
			client:  internal.CloneDefaultClient(),
			wantErr: true,
		},
		{
			name:   "works",
			client: internal.CloneDefaultClient(),
			creds:  creds,
			want:   "fakeToken",
		},
		{
			name:   "works, no transport",
			client: &http.Client{},
			creds:  creds,
			want:   "fakeToken",
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			err := AddAuthorizationMiddleware(tt.client, tt.creds)
			if tt.wantErr {
				if err == nil {
					t.Fatalf("AddAuthorizationMiddleware() = nil, want error")
				}
				return
			}
			// A success case must not silently continue after a setup failure.
			if err != nil {
				t.Fatalf("AddAuthorizationMiddleware() = %v, want nil", err)
			}
			ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
				got := r.Header.Get("Authorization")
				if !strings.Contains(got, tt.want) {
					t.Errorf("got %q, want contain %q", got, tt.want)
				}
			}))
			defer ts.Close()
			resp, err := tt.client.Get(ts.URL)
			if err != nil {
				t.Fatalf("client.Get() = %v", err)
			}
			// Close the body so the underlying connection is released.
			resp.Body.Close()
		})
	}
}
func TestNewClient_FailsValidation(t *testing.T) {
tests := []struct {
name string
opts *Options
}{
{
name: "missing options",
},
{
name: "has creds with disable options, tp",
opts: &Options{
DisableAuthentication: true,
Credentials: auth.NewCredentials(&auth.CredentialsOptions{
TokenProvider: staticTP("fakeToken"),
}),
},
},
{
name: "has creds with disable options, cred file",
opts: &Options{
DisableAuthentication: true,
DetectOpts: &credentials.DetectOptions{
CredentialsFile: "abc.123",
},
},
},
{
name: "has creds with disable options, cred json",
opts: &Options{
DisableAuthentication: true,
DetectOpts: &credentials.DetectOptions{
CredentialsJSON: []byte(`{"foo":"bar"}`),
},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
_, err := NewClient(tt.opts)
if err == nil {
t.Fatal("NewClient() = _, nil, want error")
}
})
}
}
func TestDial_SkipValidation(t *testing.T) {
opts := &Options{
DisableAuthentication: true,
Credentials: auth.NewCredentials(&auth.CredentialsOptions{
TokenProvider: staticTP("fakeToken"),
}),
}
t.Run("invalid opts", func(t *testing.T) {
if err := opts.validate(); err == nil {
t.Fatalf("opts.validate() = nil, want error")
}
})
t.Run("skip invalid opts", func(t *testing.T) {
opts.InternalOptions = &InternalOptions{SkipValidation: true}
if err := opts.validate(); err != nil {
t.Fatalf("opts.validate() = %v, want nil", err)
}
})
}
// TestOptions_ResolveDetectOptions is a table-driven test comparing the
// result of Options.resolveDetectOptions with the expected
// credentials.DetectOptions for several combinations of user-supplied
// scopes/audience and internal defaults.
func TestOptions_ResolveDetectOptions(t *testing.T) {
	// Every case uses the same placeholder credentials path.
	const credsFile = "/path/to/a/file"

	cases := []struct {
		name string
		opts *Options
		want *credentials.DetectOptions
	}{
		{
			name: "base",
			opts: &Options{
				DetectOpts: &credentials.DetectOptions{
					Scopes:          []string{"scope"},
					CredentialsFile: credsFile,
				},
			},
			want: &credentials.DetectOptions{
				Scopes:          []string{"scope"},
				CredentialsFile: credsFile,
			},
		},
		{
			name: "self-signed, with scope",
			opts: &Options{
				InternalOptions: &InternalOptions{
					EnableJWTWithScope: true,
				},
				DetectOpts: &credentials.DetectOptions{
					Scopes:          []string{"scope"},
					CredentialsFile: credsFile,
				},
			},
			want: &credentials.DetectOptions{
				Scopes:           []string{"scope"},
				CredentialsFile:  credsFile,
				UseSelfSignedJWT: true,
			},
		},
		{
			name: "self-signed, with aud",
			opts: &Options{
				DetectOpts: &credentials.DetectOptions{
					Audience:        "aud",
					CredentialsFile: credsFile,
				},
			},
			want: &credentials.DetectOptions{
				Audience:         "aud",
				CredentialsFile:  credsFile,
				UseSelfSignedJWT: true,
			},
		},
		{
			name: "use default scopes",
			opts: &Options{
				InternalOptions: &InternalOptions{
					DefaultScopes:   []string{"default"},
					DefaultAudience: "default",
				},
				DetectOpts: &credentials.DetectOptions{
					CredentialsFile: credsFile,
				},
			},
			want: &credentials.DetectOptions{
				Scopes:          []string{"default"},
				CredentialsFile: credsFile,
			},
		},
		{
			name: "don't use default scopes, scope provided",
			opts: &Options{
				InternalOptions: &InternalOptions{
					DefaultScopes:   []string{"default"},
					DefaultAudience: "default",
				},
				DetectOpts: &credentials.DetectOptions{
					Scopes:          []string{"non-default"},
					CredentialsFile: credsFile,
				},
			},
			want: &credentials.DetectOptions{
				Scopes:          []string{"non-default"},
				CredentialsFile: credsFile,
			},
		},
		{
			name: "don't use default scopes, aud provided",
			opts: &Options{
				InternalOptions: &InternalOptions{
					DefaultScopes:   []string{"default"},
					DefaultAudience: "default",
				},
				DetectOpts: &credentials.DetectOptions{
					Audience:        "non-default",
					CredentialsFile: credsFile,
				},
			},
			want: &credentials.DetectOptions{
				Audience:         "non-default",
				CredentialsFile:  credsFile,
				UseSelfSignedJWT: true,
			},
		},
		{
			name: "use default aud",
			opts: &Options{
				InternalOptions: &InternalOptions{
					DefaultAudience: "default",
				},
				DetectOpts: &credentials.DetectOptions{
					CredentialsFile: credsFile,
				},
			},
			want: &credentials.DetectOptions{
				Audience:        "default",
				CredentialsFile: credsFile,
			},
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			resolved := tc.opts.resolveDetectOptions()
			diff := cmp.Diff(tc.want, resolved)
			if diff != "" {
				t.Errorf("mismatch (-want +got):\n%s", diff)
			}
		})
	}
}
// TestNewClient_DetectedServiceAccount builds a client from a service account
// credentials file with self-signed JWT enabled and sends one request to a
// local test server. The server handler asserts that an Authorization header
// is present, that the custom "Foo" header is forwarded, and that the quota
// project header carries the value set through internal.QuotaProjectEnvVar.
func TestNewClient_DetectedServiceAccount(t *testing.T) {
	testQuota := "testquota"
	wantHeader := "bar"
	t.Setenv(internal.QuotaProjectEnvVar, testQuota)
	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if got := r.Header.Get("Authorization"); got == "" {
			t.Errorf(`got "", want an auth token`)
		}
		if got := r.Header.Get("Foo"); got != wantHeader {
			t.Errorf("got %q, want %q", got, wantHeader)
		}
		if got := r.Header.Get(quotaProjectHeaderKey); got != testQuota {
			t.Errorf("got %q, want %q", got, testQuota)
		}
	}))
	defer ts.Close()
	client, err := NewClient(&Options{
		Headers: http.Header{"Foo": []string{wantHeader}},
		InternalOptions: &InternalOptions{
			DefaultEndpointTemplate: ts.URL,
		},
		DetectOpts: &credentials.DetectOptions{
			Audience:         ts.URL,
			CredentialsFile:  "../internal/testdata/sa.json",
			UseSelfSignedJWT: true,
		},
	})
	if err != nil {
		t.Fatalf("NewClient() = %v", err)
	}
	req, err := http.NewRequest(http.MethodGet, ts.URL, nil)
	if err != nil {
		t.Fatal(err)
	}
	resp, err := client.Do(req)
	if err != nil {
		t.Fatalf("client.Do() = %v", err)
	}
	// Close the body so the underlying connection is released.
	resp.Body.Close()
}
// TestNewClient_APIKey builds a client configured with an API key and sends
// one request to a local test server. The server handler asserts that the key
// arrives as the "key" query parameter, that the custom "Foo" header is
// forwarded, and that the quota project header carries the value set through
// internal.QuotaProjectEnvVar.
func TestNewClient_APIKey(t *testing.T) {
	testQuota := "testquota"
	apiKey := "thereisnospoon"
	wantHeader := "bar"
	t.Setenv(internal.QuotaProjectEnvVar, testQuota)
	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		got := r.URL.Query().Get("key")
		if got != apiKey {
			t.Errorf("got %q, want %q", got, apiKey)
		}
		if got := r.Header.Get("Foo"); got != wantHeader {
			t.Errorf("got %q, want %q", got, wantHeader)
		}
		if got := r.Header.Get(quotaProjectHeaderKey); got != testQuota {
			t.Errorf("got %q, want %q", got, testQuota)
		}
	}))
	defer ts.Close()
	client, err := NewClient(&Options{
		APIKey:  apiKey,
		Headers: http.Header{"Foo": []string{wantHeader}},
	})
	if err != nil {
		t.Fatalf("NewClient() = %v", err)
	}
	resp, err := client.Get(ts.URL)
	if err != nil {
		t.Fatalf("client.Get() = %v", err)
	}
	// Close the body so the underlying connection is released.
	resp.Body.Close()
}
// TestNewClient_BaseRoundTripper builds a client with a custom
// BaseRoundTripper (rt, which adds a "Bar" header) plus a configured "Foo"
// header, then sends one request to a local test server whose handler asserts
// that both headers arrive.
func TestNewClient_BaseRoundTripper(t *testing.T) {
	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		got := r.Header.Get("Foo")
		if want := "foo"; got != want {
			t.Errorf("got %q, want %q", got, want)
		}
		got = r.Header.Get("Bar")
		if want := "bar"; got != want {
			t.Errorf("got %q, want %q", got, want)
		}
	}))
	defer ts.Close()
	client, err := NewClient(&Options{
		BaseRoundTripper: &rt{key: "Bar", value: "bar"},
		Headers:          http.Header{"Foo": []string{"foo"}},
		APIKey:           "key",
	})
	if err != nil {
		t.Fatalf("NewClient() = %v", err)
	}
	resp, err := client.Get(ts.URL)
	if err != nil {
		t.Fatalf("client.Get() = %v", err)
	}
	// Close the body so the underlying connection is released.
	resp.Body.Close()
}
// staticTP is a test token provider whose token value is the string itself.
type staticTP string

// Token returns a token whose Value is the receiver's string. The context is
// ignored and the returned error is always nil.
func (tp staticTP) Token(context.Context) (*auth.Token, error) {
	tok := &auth.Token{Value: string(tp)}
	return tok, nil
}
// rt is a test http.RoundTripper that adds a single fixed header to each
// request before handing it to http.DefaultTransport.
type rt struct {
	key, value string
}

// RoundTrip clones req (keeping its context), adds the configured header to
// the clone, and sends the clone through http.DefaultTransport. The caller's
// request is left unmodified.
func (r *rt) RoundTrip(req *http.Request) (*http.Response, error) {
	outbound := req.Clone(req.Context())
	outbound.Header.Add(r.key, r.value)
	return http.DefaultTransport.RoundTrip(outbound)
}