idtoken: populate Claims map (#498)
Unmarshal all claims into Payload.Claims.
Fixes #497.
diff --git a/idtoken/validate.go b/idtoken/validate.go
index d614c90..83efb33 100644
--- a/idtoken/validate.go
+++ b/idtoken/validate.go
@@ -282,10 +282,12 @@
if err != nil {
return nil, err
}
- err = json.Unmarshal(dp, &p)
- if err != nil {
+ if err := json.Unmarshal(dp, &p); err != nil {
return nil, fmt.Errorf("idtoken: unable to unmarshal JWT payload: %v", err)
}
+ if err := json.Unmarshal(dp, &p.Claims); err != nil {
+ return nil, fmt.Errorf("idtoken: unable to unmarshal JWT payload claims: %v", err)
+ }
return &p, nil
}
diff --git a/idtoken/validate_test.go b/idtoken/validate_test.go
index 15a32eb..46d5b87 100644
--- a/idtoken/validate_test.go
+++ b/idtoken/validate_test.go
@@ -77,6 +77,7 @@
wantErr: true,
},
}
+
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
client := &http.Client{
@@ -110,11 +111,35 @@
t.Fatalf("NewValidator(...) = %q, want nil", err)
}
payload, err := v.Validate(context.Background(), idToken, testAudience)
- if !tt.wantErr && err != nil {
- t.Fatalf("Validate(ctx, %s, %s) = %q, want nil", idToken, testAudience, err)
+ if tt.wantErr && err != nil {
+ // Got the error we wanted.
+ return
}
- if !tt.wantErr && payload.Audience != testAudience {
- t.Fatalf("got %v, want %v", payload.Audience, testAudience)
+ if !tt.wantErr && err != nil {
+ t.Fatalf("Validate(ctx, %s, %s): got err %q, want nil", idToken, testAudience, err)
+ }
+ if tt.wantErr && err == nil {
+ t.Fatalf("Validate(ctx, %s, %s): got nil err, want err", idToken, testAudience)
+ }
+ if payload == nil {
+ t.Fatalf("Got nil payload, err: %v", err)
+ }
+ if payload.Audience != testAudience {
+ t.Fatalf("Validate(ctx, %s, %s): got %v, want %v", idToken, testAudience, payload.Audience, testAudience)
+ }
+ if len(payload.Claims) == 0 {
+ t.Fatalf("Validate(ctx, %s, %s): missing Claims map. payload.Claims = %+v", idToken, testAudience, payload.Claims)
+ }
+ if got, ok := payload.Claims["aud"]; !ok {
+ t.Fatalf("Validate(ctx, %s, %s): missing aud claim. payload.Claims = %+v", idToken, testAudience, payload.Claims)
+ } else {
+ got, ok := got.(string)
+ if !ok {
+ t.Fatalf("Validate(ctx, %s, %s): aud wasn't a string. payload.Claims = %+v", idToken, testAudience, payload.Claims)
+ }
+ if got != testAudience {
+ t.Fatalf("Validate(ctx, %s, %s): Payload[aud] want %v got %v", idToken, testAudience, testAudience, got)
+ }
}
})
}