pomerium/internal/jwtutil/jwtutil.go
Caleb Doxsey b9fd926618
authorize: support authenticating with idp tokens (#5484)
* identity: add support for verifying access and identity tokens

* allow overriding with policy option

* authenticate: add verify endpoints

* wip

* implement session creation

* add verify test

* implement idp token login

* fix tests

* add pr permission

* make session ids route-specific

* rename method

* add test

* add access token test

* test for newUserFromIDPClaims

* more tests

* make the session id per-idp

* use type for

* add test

* remove nil checks
2025-02-18 13:02:06 -07:00

172 lines
3.6 KiB
Go

// Package jwtutil contains functions for working with JWTs.
package jwtutil
import (
"bytes"
"encoding/json"
"fmt"
"reflect"
"time"
)
// Claims is the set of claims carried in a JWT payload, keyed by claim name.
type Claims map[string]any

// UnmarshalJSON implements json.Unmarshaler for Claims. Numeric claim values
// are decoded as json.Number rather than float64, so integer claims such as
// exp, nbf and iat are kept as their exact textual form instead of being
// rounded through a float. The receiver is only overwritten after a
// successful decode; on error it is left unchanged.
func (claims *Claims) UnmarshalJSON(raw []byte) error {
	decoder := json.NewDecoder(bytes.NewReader(raw))
	decoder.UseNumber()

	parsed := make(map[string]any)
	if err := decoder.Decode(&parsed); err != nil {
		return err
	}
	*claims = Claims(parsed)
	return nil
}
// registered claims
// GetIssuer returns the registered "iss" (issuer) claim as a string. ok
// reports whether the claim is present.
func (claims Claims) GetIssuer() (issuer string, ok bool) {
	issuer, ok = claims.GetString("iss")
	return issuer, ok
}
// GetSubject returns the registered "sub" (subject) claim as a string. ok
// reports whether the claim is present.
func (claims Claims) GetSubject() (subject string, ok bool) {
	subject, ok = claims.GetString("sub")
	return subject, ok
}
// GetAudience returns the registered "aud" (audience) claim as a slice of
// strings. A list value is converted element by element; a single value
// becomes a one-element slice. ok reports whether the claim is present.
func (claims Claims) GetAudience() (audiences []string, ok bool) {
	audiences, ok = claims.GetStringSlice("aud")
	return audiences, ok
}
// GetExpirationTime returns the registered "exp" (expiration time) claim
// interpreted as a NumericDate. ok is false when the claim is missing or
// not numeric.
func (claims Claims) GetExpirationTime() (expirationTime time.Time, ok bool) {
	expirationTime, ok = claims.GetNumericDate("exp")
	return expirationTime, ok
}
// GetNotBefore returns the registered "nbf" (not before) claim interpreted
// as a NumericDate. ok is false when the claim is missing or not numeric.
func (claims Claims) GetNotBefore() (notBefore time.Time, ok bool) {
	notBefore, ok = claims.GetNumericDate("nbf")
	return notBefore, ok
}
// GetIssuedAt returns the registered "iat" (issued at) claim interpreted as
// a NumericDate. ok is false when the claim is missing or not numeric.
func (claims Claims) GetIssuedAt() (issuedAt time.Time, ok bool) {
	issuedAt, ok = claims.GetNumericDate("iat")
	return issuedAt, ok
}
// GetJWTID returns the registered "jti" (JWT ID) claim as a string. ok
// reports whether the claim is present.
func (claims Claims) GetJWTID() (jwtID string, ok bool) {
	jwtID, ok = claims.GetString("jti")
	return jwtID, ok
}
// custom claims
// GetUserID returns an identifier for the token's user. The "oid" claim takes
// precedence when present; otherwise the registered "sub" claim is used.
// ok is false only when neither claim exists.
func (claims Claims) GetUserID() (userID string, ok bool) {
	if oid, found := claims.GetString("oid"); found {
		return oid, true
	}
	// Fall back to the subject; when sub is absent this yields ("", false).
	return claims.GetSubject()
}
// GetNumericDate returns the named claim interpreted as a JWT NumericDate,
// i.e. a count of seconds since the Unix epoch.
//
// Values of any built-in integer or floating-point type are accepted, as is
// json.Number (the form UnmarshalJSON stores numbers in). Fractional seconds
// are truncated toward zero. ok is false when the claim is absent or its
// value is not numeric. A nil Claims reports every claim as absent.
func (claims Claims) GetNumericDate(name string) (tm time.Time, ok bool) {
	// Indexing a nil map is safe and simply reports the key as missing.
	raw, ok := claims[name]
	if !ok {
		return tm, false
	}

	// NOTE(review): uint64/uint values above math.MaxInt64 wrap to negative
	// times in the int64 conversion; such values are not expected from JSON.
	switch v := raw.(type) {
	case float32:
		return time.Unix(int64(v), 0), true
	case float64:
		return time.Unix(int64(v), 0), true
	case int64:
		return time.Unix(v, 0), true
	case int32:
		return time.Unix(int64(v), 0), true
	case int16:
		return time.Unix(int64(v), 0), true
	case int8:
		return time.Unix(int64(v), 0), true
	case int:
		return time.Unix(int64(v), 0), true
	case uint64:
		return time.Unix(int64(v), 0), true
	case uint32:
		return time.Unix(int64(v), 0), true
	case uint16:
		return time.Unix(int64(v), 0), true
	case uint8:
		return time.Unix(int64(v), 0), true
	case uint:
		return time.Unix(int64(v), 0), true
	case json.Number:
		// RFC 7519 allows non-integer NumericDates, so try an exact integer
		// parse first and fall back to a float parse (e.g. "1700000000.5",
		// "1.7e9"). Each err is scoped to its own if statement so a failed
		// Int64 parse cannot mask a successful Float64 parse.
		if i, err := v.Int64(); err == nil {
			return time.Unix(i, 0), true
		}
		if f, err := v.Float64(); err == nil {
			return time.Unix(int64(f), 0), true
		}
	}
	return tm, false
}
// GetString returns the named claim rendered as a string. String values are
// returned unchanged and other values are formatted via toString. ok reports
// whether the claim is present; a missing claim yields ("", false).
func (claims Claims) GetString(name string) (value string, ok bool) {
	if raw, present := claims[name]; present {
		return toString(raw), true
	}
	return "", false
}
// GetStringSlice returns the named claim as a slice of strings, converting a
// list element by element and wrapping a single value in a one-element slice
// (via toStringSlice). ok reports whether the claim is present; a missing
// claim yields (nil, false).
func (claims Claims) GetStringSlice(name string) (values []string, ok bool) {
	raw, present := claims[name]
	if present {
		values = toStringSlice(raw)
	}
	return values, present
}
// toString renders a claim value as a string. String values are returned
// as-is; anything else is formatted with fmt.Sprint.
//
// NOTE(review): fmt.Sprint(nil) is "<nil>", so a JSON null claim is reported
// as the literal string "<nil>" — confirm callers are fine with that.
func toString(data any) string {
	if s, isString := data.(string); isString {
		return s
	}
	return fmt.Sprint(data)
}

// toStringSlice renders a claim value as a slice of strings. A slice of any
// element type is converted element by element with toString (an empty or nil
// slice yields an empty, non-nil result). Any other value, including nil, is
// wrapped as a single-element slice.
func toStringSlice(obj any) []string {
	rv := reflect.ValueOf(obj)
	if rv.Kind() != reflect.Slice {
		return []string{toString(obj)}
	}
	out := make([]string, 0, rv.Len())
	for idx := 0; idx < rv.Len(); idx++ {
		out = append(out, toString(rv.Index(idx).Interface()))
	}
	return out
}