// blob: 631c82e88dc19b8f2e0ec8cd547f3d8ce79e1a79 [file] [log] [blame]
// Copyright 2020 Google LLC.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package idtoken
import (
"bytes"
"context"
"crypto"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/rsa"
"encoding/base64"
"encoding/json"
"io/ioutil"
"math/big"
"net/http"
"testing"
"time"
"google.golang.org/api/option"
)
// Fixtures shared by the RS256 and ES256 validation tests.
const (
	// keyID is the key identifier written into the header of every test token.
	keyID = "1234"
	// testAudience is the audience written into every test token and passed
	// to Validate.
	testAudience = "test-audience"
	// expiry is the expiration written into every test token; it is
	// interpreted as Unix seconds by the clock stubs below.
	expiry int64 = 233431200
)

// Clock stubs that place "now" one second on either side of expiry.
var (
	beforeExp = func() time.Time { return time.Unix(expiry, 0).Add(-time.Second) }
	afterExp  = func() time.Time { return time.Unix(expiry, 0).Add(time.Second) }
)
// TestValidateRS256 signs one RS256 token and validates it against a fake
// certificate endpoint. Each case varies the key the endpoint serves (key ID,
// modulus, exponent) or the current time, and states whether Validate is
// expected to fail.
func TestValidateRS256(t *testing.T) {
	idToken, pub := createRS256JWT(t)

	cases := []struct {
		name    string
		keyID   string
		n       *big.Int
		e       int
		nowFunc func() time.Time
		wantErr bool
	}{
		{name: "works", keyID: keyID, n: pub.N, e: pub.E, nowFunc: beforeExp, wantErr: false},
		{name: "no matching key", keyID: "5678", n: pub.N, e: pub.E, nowFunc: beforeExp, wantErr: true},
		{name: "sig does not match", keyID: keyID, n: new(big.Int).SetBytes([]byte("42")), e: 42, nowFunc: beforeExp, wantErr: true},
		{name: "token expired", keyID: keyID, n: pub.N, e: pub.E, nowFunc: afterExp, wantErr: true},
	}

	for _, tt := range cases {
		t.Run(tt.name, func(t *testing.T) {
			// Every request made through this client is answered with a
			// single-key certificate response built from the test case.
			transport := RoundTripFn(func(req *http.Request) *http.Response {
				certs := certResponse{
					Keys: []jwk{
						{
							Kid: tt.keyID,
							N:   base64.RawURLEncoding.EncodeToString(tt.n.Bytes()),
							E:   base64.RawURLEncoding.EncodeToString(big.NewInt(int64(tt.e)).Bytes()),
						},
					},
				}
				body, err := json.Marshal(&certs)
				if err != nil {
					t.Fatalf("unable to marshal response: %v", err)
				}
				return &http.Response{
					StatusCode: http.StatusOK,
					Body:       ioutil.NopCloser(bytes.NewReader(body)),
					Header:     http.Header{},
				}
			})
			client := &http.Client{Transport: transport}

			// Swap the package clock for the duration of this case only.
			savedNow := now
			defer func() { now = savedNow }()
			now = tt.nowFunc

			v, err := NewValidator(context.Background(), option.WithHTTPClient(client))
			if err != nil {
				t.Fatalf("NewValidator(...) = %q, want nil", err)
			}

			payload, err := v.Validate(context.Background(), idToken, testAudience)
			switch {
			case err != nil && tt.wantErr:
				// Got the error we wanted.
				return
			case err != nil:
				t.Fatalf("Validate(ctx, %s, %s): got err %q, want nil", idToken, testAudience, err)
			case tt.wantErr:
				t.Fatalf("Validate(ctx, %s, %s): got nil err, want err", idToken, testAudience)
			}

			if payload == nil {
				t.Fatalf("Got nil payload, err: %v", err)
			}
			if payload.Audience != testAudience {
				t.Fatalf("Validate(ctx, %s, %s): got %v, want %v", idToken, testAudience, payload.Audience, testAudience)
			}
			if len(payload.Claims) == 0 {
				t.Fatalf("Validate(ctx, %s, %s): missing Claims map. payload.Claims = %+v", idToken, testAudience, payload.Claims)
			}

			// The raw claims map must carry the same audience as a string.
			rawAud, found := payload.Claims["aud"]
			if !found {
				t.Fatalf("Validate(ctx, %s, %s): missing aud claim. payload.Claims = %+v", idToken, testAudience, payload.Claims)
			}
			aud, isString := rawAud.(string)
			if !isString {
				t.Fatalf("Validate(ctx, %s, %s): aud wasn't a string. payload.Claims = %+v", idToken, testAudience, payload.Claims)
			}
			if aud != testAudience {
				t.Fatalf("Validate(ctx, %s, %s): Payload[aud] want %v got %v", idToken, testAudience, testAudience, aud)
			}
		})
	}
}
// TestValidateES256 signs one ES256 token and validates it against a fake
// certificate endpoint. Each case varies the key the endpoint serves (key ID,
// X and Y coordinates) or the current time, and states whether Validate is
// expected to fail.
func TestValidateES256(t *testing.T) {
	idToken, pk := createES256JWT(t)
	tests := []struct {
		name    string
		keyID   string
		x       *big.Int
		y       *big.Int
		nowFunc func() time.Time
		wantErr bool
	}{
		{
			name:    "works",
			keyID:   keyID,
			x:       pk.X,
			y:       pk.Y,
			nowFunc: beforeExp,
			wantErr: false,
		},
		{
			name:    "no matching key",
			keyID:   "5678",
			x:       pk.X,
			y:       pk.Y,
			nowFunc: beforeExp,
			wantErr: true,
		},
		{
			name:    "sig does not match",
			keyID:   keyID,
			x:       new(big.Int),
			y:       new(big.Int),
			nowFunc: beforeExp,
			wantErr: true,
		},
		{
			name:    "token expired",
			keyID:   keyID,
			x:       pk.X,
			y:       pk.Y,
			nowFunc: afterExp,
			wantErr: true,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			// Every request made through this client is answered with a
			// single-key certificate response built from the test case.
			client := &http.Client{
				Transport: RoundTripFn(func(req *http.Request) *http.Response {
					cr := certResponse{
						Keys: []jwk{
							{
								Kid: tt.keyID,
								X:   base64.RawURLEncoding.EncodeToString(tt.x.Bytes()),
								Y:   base64.RawURLEncoding.EncodeToString(tt.y.Bytes()),
							},
						},
					}
					b, err := json.Marshal(&cr)
					if err != nil {
						t.Fatalf("unable to marshal response: %v", err)
					}
					return &http.Response{
						StatusCode: 200,
						Body:       ioutil.NopCloser(bytes.NewReader(b)),
						Header:     make(http.Header),
					}
				}),
			}

			// Swap the package clock for the duration of this case only.
			oldNow := now
			defer func() { now = oldNow }()
			now = tt.nowFunc

			v, err := NewValidator(context.Background(), option.WithHTTPClient(client))
			if err != nil {
				t.Fatalf("NewValidator(...) = %q, want nil", err)
			}
			payload, err := v.Validate(context.Background(), idToken, testAudience)
			if tt.wantErr {
				// Failure cases must actually fail; without this check a
				// validator that accepts everything would pass the test.
				if err == nil {
					t.Fatalf("Validate(ctx, %s, %s): got nil err, want err", idToken, testAudience)
				}
				return
			}
			if err != nil {
				t.Fatalf("Validate(ctx, %s, %s) = %q, want nil", idToken, testAudience, err)
			}
			// Guard the dereference below so a nil payload is reported as a
			// test failure instead of a panic.
			if payload == nil {
				t.Fatalf("Validate(ctx, %s, %s): got nil payload, want non-nil", idToken, testAudience)
			}
			if payload.Audience != testAudience {
				t.Fatalf("got %v, want %v", payload.Audience, testAudience)
			}
		})
	}
}
// createES256JWT builds the common test token with alg "ES256", signs its
// hashed content with a freshly generated P-256 key, and returns token.String()
// together with the matching public key. Any failure aborts the test.
func createES256JWT(t *testing.T) (string, ecdsa.PublicKey) {
	t.Helper()
	token := commonToken(t, "ES256")
	privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
	if err != nil {
		t.Fatalf("unable to generate key: %v", err)
	}
	r, s, err := ecdsa.Sign(rand.Reader, privateKey, token.hashedContent())
	if err != nil {
		t.Fatalf("unable to sign content: %v", err)
	}

	// big.Int.Bytes drops leading zero bytes, so each half of the signature
	// is left-padded to es256KeySize. The JWS ES256 signature is the
	// fixed-width concatenation R || S; padding only R would yield a short
	// signature whenever S happens to start with a zero byte.
	leftPad := func(b []byte) []byte {
		padded := make([]byte, es256KeySize)
		copy(padded[es256KeySize-len(b):], b)
		return padded
	}
	sig := append(leftPad(r.Bytes()), leftPad(s.Bytes())...)

	token.signature = base64.RawURLEncoding.EncodeToString(sig)
	return token.String(), privateKey.PublicKey
}
// createRS256JWT builds the common test token with alg "RS256", signs its
// hashed content (PKCS #1 v1.5, SHA-256) with a freshly generated 2048-bit RSA
// key, and returns token.String() together with the matching public key. Any
// failure aborts the test.
func createRS256JWT(t *testing.T) (string, rsa.PublicKey) {
	t.Helper()

	tok := commonToken(t, "RS256")

	signer, err := rsa.GenerateKey(rand.Reader, 2048)
	if err != nil {
		t.Fatalf("unable to generate key: %v", err)
	}

	var rawSig []byte
	if rawSig, err = rsa.SignPKCS1v15(rand.Reader, signer, crypto.SHA256, tok.hashedContent()); err != nil {
		t.Fatalf("unable to sign content: %v", err)
	}
	tok.signature = base64.RawURLEncoding.EncodeToString(rawSig)

	return tok.String(), signer.PublicKey
}
// commonToken returns an unsigned jwt whose header and payload are the
// base64url (unpadded) encodings of the JSON-marshalled test header and
// payload. The header carries keyID, the given alg and type "JWT"; the payload
// carries issuer "example.com", testAudience and expiry. A marshalling failure
// aborts the test.
func commonToken(t *testing.T, alg string) *jwt {
	t.Helper()

	rawHeader, err := json.Marshal(&jwtHeader{
		KeyID:     keyID,
		Algorithm: alg,
		Type:      "JWT",
	})
	if err != nil {
		t.Fatalf("unable to marshall header: %v", err)
	}

	rawPayload, err := json.Marshal(&Payload{
		Issuer:   "example.com",
		Audience: testAudience,
		Expires:  expiry,
	})
	if err != nil {
		t.Fatalf("unable to marshall payload: %v", err)
	}

	enc := base64.RawURLEncoding
	return &jwt{
		header:  enc.EncodeToString(rawHeader),
		payload: enc.EncodeToString(rawPayload),
	}
}
// RoundTripFn adapts a plain function to the RoundTrip method so that a test
// can answer HTTP requests in-process.
type RoundTripFn func(req *http.Request) *http.Response

// RoundTrip calls f with req and returns its response with a nil error.
func (f RoundTripFn) RoundTrip(req *http.Request) (*http.Response, error) {
	resp := f(req)
	return resp, nil
}