| // Copyright 2020 Google LLC. |
| // Use of this source code is governed by a BSD-style |
| // license that can be found in the LICENSE file. |
| |
| package idtoken |
| |
import (
	"context"
	"crypto"
	"crypto/ecdsa"
	"crypto/elliptic"
	"crypto/rsa"
	"crypto/sha256"
	"encoding/base64"
	"encoding/json"
	"fmt"
	"math/big"
	"net/http"
	"strings"
	"time"

	htransport "google.golang.org/api/transport/http"
)
| |
// es256KeySize is the length in bytes of each of the two integers (r and s)
// that are concatenated to form an ES256 signature.
const es256KeySize int = 32

// Endpoints from which the public keys used for verification are fetched.
const (
	// googleIAPCertsURL is queried when verifying ES256-signed tokens.
	googleIAPCertsURL string = "https://www.gstatic.com/iap/verify/public_key-jwk"
	// googleSACertsURL is queried when verifying RS256-signed tokens.
	googleSACertsURL string = "https://www.googleapis.com/oauth2/v3/certs"
)
| |
| var defaultValidator = &Validator{client: newCachingClient(http.DefaultClient)} |
| |
// Payload holds the claims decoded from the payload segment of an ID Token.
type Payload struct {
	// Issuer is the "iss" claim.
	Issuer string `json:"iss"`
	// Audience is the "aud" claim.
	Audience string `json:"aud"`
	// Expires is the "exp" claim.
	Expires int64 `json:"exp"`
	// IssuedAt is the "iat" claim.
	IssuedAt int64 `json:"iat"`
	// Subject is the "sub" claim; it is omitted from encoded output when empty.
	Subject string `json:"sub,omitempty"`
	// Claims is skipped by encoding/json (tag "-") when a Payload is encoded
	// or decoded as a struct.
	Claims map[string]interface{} `json:"-"`
}
| |
// jwt holds the three dot-separated segments of a JSON Web Token exactly as
// they appear in the token (still encoded). Its methods decode, parse and hash
// those segments.
type jwt struct {
	header    string // first segment
	payload   string // second segment
	signature string // third segment
}
| |
// jwtHeader is the JSON object carried in a parsed JWT's header segment.
type jwtHeader struct {
	// Algorithm is the "alg" value; validate dispatches on it.
	Algorithm string `json:"alg"`
	// Type is the "typ" value.
	Type string `json:"typ"`
	// KeyID is the "kid" value, used to select the key that verifies the
	// signature.
	KeyID string `json:"kid"`
}
| |
| // certResponse represents a list jwks. It is the format returned from known |
| // Google cert endpoints. |
| type certResponse struct { |
| Keys []jwk `json:"keys"` |
| } |
| |
// jwk is a pared-down JSON Web Key holding only the fields used by Google's
// cert endpoints.
type jwk struct {
	// Alg, Crv, Kty and Use mirror the JWK members of the same name.
	Alg string `json:"alg"`
	Crv string `json:"crv"`
	// Kid is compared against the key ID in a token's header.
	Kid string `json:"kid"`
	Kty string `json:"kty"`
	Use string `json:"use"`
	// E and N are the encoded RSA public exponent and modulus.
	E string `json:"e"`
	N string `json:"n"`
	// X and Y are the encoded coordinates of an elliptic curve public key.
	X string `json:"x"`
	Y string `json:"y"`
}
| |
| // Validator provides a way to validate Google ID Tokens with a user provided |
| // http.Client. |
| type Validator struct { |
| client *cachingClient |
| } |
| |
| // NewValidator creates a Validator that uses the options provided to configure |
| // a the internal http.Client that will be used to make requests to fetch JWKs. |
| func NewValidator(ctx context.Context, opts ...ClientOption) (*Validator, error) { |
| client, _, err := htransport.NewClient(ctx, opts...) |
| if err != nil { |
| return nil, err |
| } |
| return &Validator{client: newCachingClient(client)}, nil |
| } |
| |
| // Validate is used to validate the provided idToken with a known Google cert |
| // URL. If audience is not empty the audience claim of the Token is validated. |
| // Upon successful validation a parsed token Payload is returned allowing the |
| // caller to validate any additional claims. |
| func (v *Validator) Validate(ctx context.Context, idToken string, audience string) (*Payload, error) { |
| return v.validate(ctx, idToken, audience) |
| } |
| |
| // Validate is used to validate the provided idToken with a known Google cert |
| // URL. If audience is not empty the audience claim of the Token is validated. |
| // Upon successful validation a parsed token Payload is returned allowing the |
| // caller to validate any additional claims. |
| func Validate(ctx context.Context, idToken string, audience string) (*Payload, error) { |
| // TODO(codyoss): consider adding a check revoked version of the api. See: https://pkg.go.dev/firebase.google.com/go/auth?tab=doc#Client.VerifyIDTokenAndCheckRevoked |
| return defaultValidator.validate(ctx, idToken, audience) |
| } |
| |
| func (v *Validator) validate(ctx context.Context, idToken string, audience string) (*Payload, error) { |
| jwt, err := parseJWT(idToken) |
| if err != nil { |
| return nil, err |
| } |
| header, err := jwt.parsedHeader() |
| if err != nil { |
| return nil, err |
| } |
| payload, err := jwt.parsedPayload() |
| if err != nil { |
| return nil, err |
| } |
| sig, err := jwt.decodedSignature() |
| if err != nil { |
| return nil, err |
| } |
| |
| if audience != "" && payload.Audience != audience { |
| return nil, fmt.Errorf("idtoken: audience provided does not match aud claim in the JWT") |
| } |
| |
| switch header.Algorithm { |
| case "RS256": |
| if err := v.validateRS256(ctx, header.KeyID, jwt.hashedContent(), sig); err != nil { |
| return nil, err |
| } |
| case "ES256": |
| if err := v.validateES256(ctx, header.KeyID, jwt.hashedContent(), sig); err != nil { |
| return nil, err |
| } |
| default: |
| return nil, fmt.Errorf("idtoken: expected JWT signed with RS256 or ES256 but found %q", header.Algorithm) |
| } |
| |
| return payload, nil |
| } |
| |
| func (v *Validator) validateRS256(ctx context.Context, keyID string, hashedContent []byte, sig []byte) error { |
| certResp, err := v.client.getCert(ctx, googleSACertsURL) |
| if err != nil { |
| return err |
| } |
| j, err := findMatchingKey(certResp, keyID) |
| if err != nil { |
| return err |
| } |
| dn, err := decode(j.N) |
| if err != nil { |
| return err |
| } |
| de, err := decode(j.E) |
| if err != nil { |
| return err |
| } |
| |
| pk := &rsa.PublicKey{ |
| N: new(big.Int).SetBytes(dn), |
| E: int(new(big.Int).SetBytes(de).Int64()), |
| } |
| return rsa.VerifyPKCS1v15(pk, crypto.SHA256, hashedContent, sig) |
| } |
| |
| func (v *Validator) validateES256(ctx context.Context, keyID string, hashedContent []byte, sig []byte) error { |
| certResp, err := v.client.getCert(ctx, googleIAPCertsURL) |
| if err != nil { |
| return err |
| } |
| j, err := findMatchingKey(certResp, keyID) |
| if err != nil { |
| return err |
| } |
| dx, err := decode(j.X) |
| if err != nil { |
| return err |
| } |
| dy, err := decode(j.Y) |
| if err != nil { |
| return err |
| } |
| |
| pk := &ecdsa.PublicKey{ |
| Curve: elliptic.P256(), |
| X: new(big.Int).SetBytes(dx), |
| Y: new(big.Int).SetBytes(dy), |
| } |
| r := big.NewInt(0).SetBytes(sig[:es256KeySize]) |
| s := big.NewInt(0).SetBytes(sig[es256KeySize:]) |
| if valid := ecdsa.Verify(pk, hashedContent, r, s); !valid { |
| return fmt.Errorf("idtoken: ES256 signature not valid") |
| } |
| return nil |
| } |
| |
| func findMatchingKey(response *certResponse, keyID string) (*jwk, error) { |
| if response == nil { |
| return nil, fmt.Errorf("idtoken: cert response is nil") |
| } |
| for _, v := range response.Keys { |
| if v.Kid == keyID { |
| return &v, nil |
| } |
| } |
| return nil, fmt.Errorf("idtoken: could not find matching cert keyId for the token provided") |
| } |
| |
| func parseJWT(idToken string) (*jwt, error) { |
| segments := strings.Split(idToken, ".") |
| if len(segments) != 3 { |
| return nil, fmt.Errorf("idtoken: invalid token, token must have three segments; found %d", len(segments)) |
| } |
| return &jwt{ |
| header: segments[0], |
| payload: segments[1], |
| signature: segments[2], |
| }, nil |
| } |
| |
| // decodedHeader base64 decodes the header segment. |
| func (j *jwt) decodedHeader() ([]byte, error) { |
| dh, err := decode(j.header) |
| if err != nil { |
| return nil, fmt.Errorf("idtoken: unable to decode JWT header: %v", err) |
| } |
| return dh, nil |
| } |
| |
| // decodedPayload base64 payload the header segment. |
| func (j *jwt) decodedPayload() ([]byte, error) { |
| p, err := decode(j.payload) |
| if err != nil { |
| return nil, fmt.Errorf("idtoken: unable to decode JWT payload: %v", err) |
| } |
| return p, nil |
| } |
| |
| // decodedPayload base64 payload the header segment. |
| func (j *jwt) decodedSignature() ([]byte, error) { |
| p, err := decode(j.signature) |
| if err != nil { |
| return nil, fmt.Errorf("idtoken: unable to decode JWT signature: %v", err) |
| } |
| return p, nil |
| } |
| |
| // parsedHeader returns a struct representing a JWT header. |
| func (j *jwt) parsedHeader() (jwtHeader, error) { |
| var h jwtHeader |
| dh, err := j.decodedHeader() |
| if err != nil { |
| return h, err |
| } |
| err = json.Unmarshal(dh, &h) |
| if err != nil { |
| return h, fmt.Errorf("idtoken: unable to unmarshal JWT header: %v", err) |
| } |
| return h, nil |
| } |
| |
| // parsedPayload returns a struct representing a JWT payload. |
| func (j *jwt) parsedPayload() (*Payload, error) { |
| var p Payload |
| dp, err := j.decodedPayload() |
| if err != nil { |
| return nil, err |
| } |
| err = json.Unmarshal(dp, &p) |
| if err != nil { |
| return nil, fmt.Errorf("idtoken: unable to unmarshal JWT payload: %v", err) |
| } |
| return &p, nil |
| } |
| |
| // hashedContent gets the SHA256 checksum for verification of the JWT. |
| func (j *jwt) hashedContent() []byte { |
| signedContent := j.header + "." + j.payload |
| hashed := sha256.Sum256([]byte(signedContent)) |
| return hashed[:] |
| } |
| |
| func (j *jwt) String() string { |
| return fmt.Sprintf("%s.%s.%s", j.header, j.payload, j.signature) |
| } |
| |
// decode decodes s using the unpadded, URL-safe base64 alphabet
// (base64.RawURLEncoding).
func decode(s string) ([]byte, error) {
	enc := base64.RawURLEncoding
	return enc.DecodeString(s)
}