blob: afabe814c3a2b27227f0cfde97860eb88318ef7c [file] [log] [blame]
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package idtoken
import (
"bytes"
"context"
"crypto"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/rsa"
"crypto/sha256"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"math/big"
"net/http"
"testing"
"time"
"cloud.google.com/go/auth/internal/jwt"
)
// Fixed inputs shared by every test token built in this file.
const (
	// keyID is the key ID written into each test token header.
	keyID = "1234"
	// testAudience is the audience of each test token and the audience the
	// tests validate against.
	testAudience = "test-audience"
	// expiry is the expiration time of each test token, in Unix seconds.
	expiry int64 = 233431200
)

// Clock stubs one second on either side of expiry. Tests swap them in for the
// package-level now function to exercise the expiration check.
var (
	beforeExp = func() time.Time { return time.Unix(expiry, 0).Add(-time.Second) }
	afterExp  = func() time.Time { return time.Unix(expiry, 0).Add(time.Second) }
)
// TestValidateRS256 validates an RS256-signed token against a stubbed
// certificate endpoint that serves a single key built from each case's key ID,
// modulus and exponent. Only the case with the signing key's ID and public key,
// evaluated before expiry, must succeed. On success the payload must report
// testAudience both as Audience and as the "aud" entry of Claims.
func TestValidateRS256(t *testing.T) {
	idToken, pk := createRS256JWT(t)
	cases := []struct {
		name    string
		keyID   string
		n       *big.Int
		e       int
		nowFunc func() time.Time
		wantErr bool
	}{
		{
			name:    "works",
			keyID:   keyID,
			n:       pk.N,
			e:       pk.E,
			nowFunc: beforeExp,
			wantErr: false,
		},
		{
			name:    "no matching key",
			keyID:   "5678",
			n:       pk.N,
			e:       pk.E,
			nowFunc: beforeExp,
			wantErr: true,
		},
		{
			name:    "sig does not match",
			keyID:   keyID,
			n:       new(big.Int).SetBytes([]byte("42")),
			e:       42,
			nowFunc: beforeExp,
			wantErr: true,
		},
		{
			name:    "token expired",
			keyID:   keyID,
			n:       pk.N,
			e:       pk.E,
			nowFunc: afterExp,
			wantErr: true,
		},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			stub := RoundTripFn(func(req *http.Request) *http.Response {
				key := jwk{
					Kid: tc.keyID,
					N:   base64.RawURLEncoding.EncodeToString(tc.n.Bytes()),
					E:   base64.RawURLEncoding.EncodeToString(big.NewInt(int64(tc.e)).Bytes()),
				}
				body, err := json.Marshal(&certResponse{Keys: []jwk{key}})
				if err != nil {
					t.Fatalf("unable to marshal response: %v", err)
				}
				return &http.Response{
					StatusCode: 200,
					Body:       io.NopCloser(bytes.NewReader(body)),
					Header:     make(http.Header),
				}
			})
			client := &http.Client{Transport: stub}

			// Swap in the case's clock and restore the real one afterwards.
			savedNow := now
			defer func() { now = savedNow }()
			now = tc.nowFunc

			v, err := NewValidator(&ValidatorOptions{Client: client})
			if err != nil {
				t.Fatalf("NewValidator(...) = %q, want nil", err)
			}
			payload, err := v.Validate(context.Background(), idToken, testAudience)
			switch {
			case tc.wantErr && err != nil:
				// Got the error we wanted.
				return
			case err != nil:
				t.Fatalf("Validate(ctx, %s, %s): got err %q, want nil", idToken, testAudience, err)
			case tc.wantErr:
				t.Fatalf("Validate(ctx, %s, %s): got nil err, want err", idToken, testAudience)
			}

			if payload == nil {
				t.Fatalf("Got nil payload, err: %v", err)
			}
			if payload.Audience != testAudience {
				t.Fatalf("Validate(ctx, %s, %s): got %v, want %v", idToken, testAudience, payload.Audience, testAudience)
			}
			if len(payload.Claims) == 0 {
				t.Fatalf("Validate(ctx, %s, %s): missing Claims map. payload.Claims = %+v", idToken, testAudience, payload.Claims)
			}
			rawAud, found := payload.Claims["aud"]
			if !found {
				t.Fatalf("Validate(ctx, %s, %s): missing aud claim. payload.Claims = %+v", idToken, testAudience, payload.Claims)
			}
			aud, isString := rawAud.(string)
			if !isString {
				t.Fatalf("Validate(ctx, %s, %s): aud wasn't a string. payload.Claims = %+v", idToken, testAudience, payload.Claims)
			}
			if aud != testAudience {
				t.Fatalf("Validate(ctx, %s, %s): Payload[aud] want %v got %v", idToken, testAudience, testAudience, aud)
			}
		})
	}
}
// TestValidateES256 validates an ES256-signed token against a stubbed
// certificate endpoint that serves a single key built from each case's key ID
// and curve coordinates. Only the case with the signing key's ID and public key,
// evaluated before expiry, must succeed. Every other case must make Validate
// return an error.
func TestValidateES256(t *testing.T) {
	idToken, pk := createES256JWT(t)
	tests := []struct {
		name    string
		keyID   string
		x       *big.Int
		y       *big.Int
		nowFunc func() time.Time
		wantErr bool
	}{
		{
			name:    "works",
			keyID:   keyID,
			x:       pk.X,
			y:       pk.Y,
			nowFunc: beforeExp,
			wantErr: false,
		},
		{
			name:    "no matching key",
			keyID:   "5678",
			x:       pk.X,
			y:       pk.Y,
			nowFunc: beforeExp,
			wantErr: true,
		},
		{
			name:    "sig does not match",
			keyID:   keyID,
			x:       new(big.Int),
			y:       new(big.Int),
			nowFunc: beforeExp,
			wantErr: true,
		},
		{
			name:    "token expired",
			keyID:   keyID,
			x:       pk.X,
			y:       pk.Y,
			nowFunc: afterExp,
			wantErr: true,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			client := &http.Client{
				Transport: RoundTripFn(func(req *http.Request) *http.Response {
					cr := certResponse{
						Keys: []jwk{
							{
								Kid: tt.keyID,
								X:   base64.RawURLEncoding.EncodeToString(tt.x.Bytes()),
								Y:   base64.RawURLEncoding.EncodeToString(tt.y.Bytes()),
							},
						},
					}
					b, err := json.Marshal(&cr)
					if err != nil {
						t.Fatalf("unable to marshal response: %v", err)
					}
					return &http.Response{
						StatusCode: 200,
						Body:       io.NopCloser(bytes.NewReader(b)),
						Header:     make(http.Header),
					}
				}),
			}
			oldNow := now
			defer func() { now = oldNow }()
			now = tt.nowFunc
			v, err := NewValidator(&ValidatorOptions{
				Client: client,
			})
			if err != nil {
				t.Fatalf("NewValidator(...) = %q, want nil", err)
			}
			payload, err := v.Validate(context.Background(), idToken, testAudience)
			if tt.wantErr {
				// Previously a nil error was accepted here, so the negative
				// cases never actually asserted anything.
				if err == nil {
					t.Fatalf("Validate(ctx, %s, %s): got nil err, want err", idToken, testAudience)
				}
				return
			}
			if err != nil {
				t.Fatalf("Validate(ctx, %s, %s) = %q, want nil", idToken, testAudience, err)
			}
			// Guard against a nil payload so a broken Validate fails cleanly
			// rather than panicking on the dereference below.
			if payload == nil {
				t.Fatalf("Validate(ctx, %s, %s): got nil payload, want non-nil", idToken, testAudience)
			}
			if payload.Audience != testAudience {
				t.Fatalf("got %v, want %v", payload.Audience, testAudience)
			}
		})
	}
}
// TestParsePayload checks that ParsePayload recovers the audience of a
// well-formed RS256 test token and reports an error for a token it cannot parse.
func TestParsePayload(t *testing.T) {
	validToken, _ := createRS256JWT(t)
	cases := []struct {
		name                string
		token               string
		wantPayloadAudience string
		wantErr             bool
	}{
		{
			name:                "valid token",
			token:               validToken,
			wantPayloadAudience: testAudience,
		},
		{
			name:    "unparseable token",
			token:   "aaa.bbb.ccc",
			wantErr: true,
		},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			payload, err := ParsePayload(tc.token)
			if (err != nil) != tc.wantErr {
				t.Errorf("ParsePayload(%q) got error %v, wantErr = %v", tc.token, err, tc.wantErr)
			}
			// Cases with no expected audience only check the error outcome.
			if tc.wantPayloadAudience == "" {
				return
			}
			if payload == nil || payload.Audience != tc.wantPayloadAudience {
				t.Errorf("ParsePayload(%q) got payload %+v, want payload with audience = %q", tc.token, payload, tc.wantPayloadAudience)
			}
		})
	}
}
// createES256JWT builds a compact JWS ("header.claims.signature") signed with
// ECDSA P-256 over SHA-256, using a freshly generated key. It returns the token
// and the public half of that key.
//
// The signature segment is the fixed-width concatenation R || S that RFC 7518
// §3.4 requires. Each half is es256KeySize bytes; es256KeySize is declared
// elsewhere in the package and is presumably 32 for P-256, so confirm before
// changing it.
func createES256JWT(t *testing.T) (string, ecdsa.PublicKey) {
	t.Helper()
	header, claims := commonToken(t, "ES256")
	privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
	if err != nil {
		t.Fatalf("unable to generate key: %v", err)
	}
	signedContent := header + "." + claims
	hashed := sha256.Sum256([]byte(signedContent))
	r, s, err := ecdsa.Sign(rand.Reader, privateKey, hashed[:])
	if err != nil {
		t.Fatalf("unable to sign content: %v", err)
	}
	// big.Int.Bytes drops leading zeros, so r or s can come out shorter than
	// es256KeySize. Previously only r was padded, and roughly 1 run in 256
	// produced a short, non-conformant signature. FillBytes left-pads each half
	// into its fixed-width slot.
	sig := make([]byte, 2*es256KeySize)
	r.FillBytes(sig[:es256KeySize])
	s.FillBytes(sig[es256KeySize:])
	signature := base64.RawURLEncoding.EncodeToString(sig)
	return fmt.Sprintf("%s.%s.%s", header, claims, signature), privateKey.PublicKey
}
// createRS256JWT builds a compact JWS ("header.claims.signature") signed with
// RSASSA-PKCS1-v1_5 over SHA-256, using a freshly generated 2048-bit RSA key.
// It returns the token and the public half of that key.
func createRS256JWT(t *testing.T) (string, rsa.PublicKey) {
	t.Helper()
	header, claims := commonToken(t, jwt.HeaderAlgRSA256)
	key, err := rsa.GenerateKey(rand.Reader, 2048)
	if err != nil {
		t.Fatalf("unable to generate key: %v", err)
	}
	signingInput := header + "." + claims
	digest := sha256.Sum256([]byte(signingInput))
	rawSig, err := rsa.SignPKCS1v15(rand.Reader, key, crypto.SHA256, digest[:])
	if err != nil {
		t.Fatalf("unable to sign content: %v", err)
	}
	token := fmt.Sprintf("%s.%s", signingInput, base64.RawURLEncoding.EncodeToString(rawSig))
	return token, key.PublicKey
}
// commonToken returns the header and claims segments of a test JWT, each as
// unpadded base64url-encoded JSON. The header carries alg, keyID and
// jwt.HeaderType. The claims carry issuer "example.com", audience testAudience
// and expiration expiry. Marshalling failures abort the test.
func commonToken(t *testing.T, alg string) (string, string) {
	t.Helper()
	headerJSON, err := json.Marshal(&jwt.Header{
		KeyID:     keyID,
		Algorithm: alg,
		Type:      jwt.HeaderType,
	})
	if err != nil {
		t.Fatalf("unable to marshall header: %v", err)
	}
	claimsJSON, err := json.Marshal(&Payload{
		Issuer:   "example.com",
		Audience: testAudience,
		Expires:  expiry,
	})
	if err != nil {
		t.Fatalf("unable to marshall payload: %v", err)
	}
	enc := base64.RawURLEncoding
	return enc.EncodeToString(headerJSON), enc.EncodeToString(claimsJSON)
}
// RoundTripFn adapts a plain function into an http.RoundTripper so tests can
// stub HTTP responses without touching the network. The adapted function has
// no way to report a transport failure, so RoundTrip always returns a nil error.
type RoundTripFn func(req *http.Request) *http.Response

// RoundTrip hands req to f and returns f's response with a nil error.
func (f RoundTripFn) RoundTrip(req *http.Request) (*http.Response, error) {
	resp := f(req)
	return resp, nil
}