idtoken: validate if token is expired (#492)
Updates: #484
diff --git a/idtoken/validate.go b/idtoken/validate.go
index 518528b..d614c90 100644
--- a/idtoken/validate.go
+++ b/idtoken/validate.go
@@ -17,6 +17,7 @@
"math/big"
"net/http"
"strings"
+ "time"
htransport "google.golang.org/api/transport/http"
)
@@ -27,7 +28,11 @@
googleSACertsURL string = "https://www.googleapis.com/oauth2/v3/certs"
)
-var defaultValidator = &Validator{client: newCachingClient(http.DefaultClient)}
+var (
+ defaultValidator = &Validator{client: newCachingClient(http.DefaultClient)}
+ // now aliases time.Now for testing.
+ now = time.Now
+)
// Payload represents a decoded payload of an ID Token.
type Payload struct {
@@ -129,6 +134,10 @@
return nil, fmt.Errorf("idtoken: audience provided does not match aud claim in the JWT")
}
+ if now().Unix() > payload.Expires {
+ return nil, fmt.Errorf("idtoken: token expired")
+ }
+
switch header.Algorithm {
case "RS256":
if err := v.validateRS256(ctx, header.KeyID, jwt.hashedContent(), sig); err != nil {
diff --git a/idtoken/validate_test.go b/idtoken/validate_test.go
index 54ff07c..15a32eb 100644
--- a/idtoken/validate_test.go
+++ b/idtoken/validate_test.go
@@ -18,13 +18,20 @@
"math/big"
"net/http"
"testing"
+ "time"
"google.golang.org/api/option"
)
const (
- keyID = "1234"
- testAudience = "test-audience"
+ keyID = "1234"
+ testAudience = "test-audience"
+ expiry int64 = 233431200
+)
+
+var (
+ beforeExp = func() time.Time { return time.Unix(expiry-1, 0) }
+ afterExp = func() time.Time { return time.Unix(expiry+1, 0) }
)
func TestValidateRS256(t *testing.T) {
@@ -34,11 +41,41 @@
keyID string
n *big.Int
e int
+ nowFunc func() time.Time
wantErr bool
}{
- {name: "works", keyID: keyID, n: pk.N, e: pk.E, wantErr: false},
- {name: "no matching key", keyID: "5678", n: pk.N, e: pk.E, wantErr: true},
- {name: "sig does not match", keyID: keyID, n: new(big.Int).SetBytes([]byte("42")), e: 42, wantErr: true},
+ {
+ name: "works",
+ keyID: keyID,
+ n: pk.N,
+ e: pk.E,
+ nowFunc: beforeExp,
+ wantErr: false,
+ },
+ {
+ name: "no matching key",
+ keyID: "5678",
+ n: pk.N,
+ e: pk.E,
+ nowFunc: beforeExp,
+ wantErr: true,
+ },
+ {
+ name: "sig does not match",
+ keyID: keyID,
+ n: new(big.Int).SetBytes([]byte("42")),
+ e: 42,
+ nowFunc: beforeExp,
+ wantErr: true,
+ },
+ {
+ name: "token expired",
+ keyID: keyID,
+ n: pk.N,
+ e: pk.E,
+ nowFunc: afterExp,
+ wantErr: true,
+ },
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
@@ -64,6 +101,9 @@
}
}),
}
+ oldNow := now
+ defer func() { now = oldNow }()
+ now = tt.nowFunc
v, err := NewValidator(context.Background(), option.WithHTTPClient(client))
if err != nil {
@@ -87,11 +127,41 @@
keyID string
x *big.Int
y *big.Int
+ nowFunc func() time.Time
wantErr bool
}{
- {name: "works", keyID: keyID, x: pk.X, y: pk.Y, wantErr: false},
- {name: "no matching key", keyID: "5678", x: pk.X, y: pk.Y, wantErr: true},
- {name: "sig does not match", keyID: keyID, x: new(big.Int), y: new(big.Int), wantErr: true},
+ {
+ name: "works",
+ keyID: keyID,
+ x: pk.X,
+ y: pk.Y,
+ nowFunc: beforeExp,
+ wantErr: false,
+ },
+ {
+ name: "no matching key",
+ keyID: "5678",
+ x: pk.X,
+ y: pk.Y,
+ nowFunc: beforeExp,
+ wantErr: true,
+ },
+ {
+ name: "sig does not match",
+ keyID: keyID,
+ x: new(big.Int),
+ y: new(big.Int),
+ nowFunc: beforeExp,
+ wantErr: true,
+ },
+ {
+ name: "token expired",
+ keyID: keyID,
+ x: pk.X,
+ y: pk.Y,
+ nowFunc: afterExp,
+ wantErr: true,
+ },
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
@@ -117,6 +187,9 @@
}
}),
}
+ oldNow := now
+ defer func() { now = oldNow }()
+ now = tt.nowFunc
v, err := NewValidator(context.Background(), option.WithHTTPClient(client))
if err != nil {
@@ -176,6 +249,7 @@
payload := Payload{
Issuer: "example.com",
Audience: testAudience,
+ Expires: expiry,
}
hb, err := json.Marshal(&header)