mirror of
https://github.com/pomerium/pomerium.git
synced 2025-05-25 15:07:33 +02:00
authorize: support authenticating with idp tokens (#5484)
* identity: add support for verifying access and identity tokens * allow overriding with policy option * authenticate: add verify endpoints * wip * implement session creation * add verify test * implement idp token login * fix tests * add pr permission * make session ids route-specific * rename method * add test * add access token test * test for newUserFromIDPClaims * more tests * make the session id per-idp * use type for * add test * remove nil checks
This commit is contained in:
parent
6e22b7a19a
commit
b9fd926618
36 changed files with 2791 additions and 885 deletions
172
internal/jwtutil/jwtutil.go
Normal file
172
internal/jwtutil/jwtutil.go
Normal file
|
@ -0,0 +1,172 @@
|
|||
// Package jwtutil contains functions for working with JWTs.
|
||||
package jwtutil
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Claims represent claims in a JWT.
|
||||
type Claims map[string]any
|
||||
|
||||
// UnmarshalJSON implements a custom unmarshaller for claims data.
|
||||
func (claims *Claims) UnmarshalJSON(raw []byte) error {
|
||||
dst := map[string]any{}
|
||||
dec := json.NewDecoder(bytes.NewReader(raw))
|
||||
dec.UseNumber()
|
||||
err := dec.Decode(&dst)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
*claims = Claims(dst)
|
||||
return nil
|
||||
}
|
||||
|
||||
// registered claims
|
||||
|
||||
// GetIssuer gets the iss claim.
|
||||
func (claims Claims) GetIssuer() (issuer string, ok bool) {
|
||||
return claims.GetString("iss")
|
||||
}
|
||||
|
||||
// GetSubject gets the sub claim.
|
||||
func (claims Claims) GetSubject() (subject string, ok bool) {
|
||||
return claims.GetString("sub")
|
||||
}
|
||||
|
||||
// GetAudience gets the aud claim.
|
||||
func (claims Claims) GetAudience() (audiences []string, ok bool) {
|
||||
return claims.GetStringSlice("aud")
|
||||
}
|
||||
|
||||
// GetExpirationTime gets the exp claim.
|
||||
func (claims Claims) GetExpirationTime() (expirationTime time.Time, ok bool) {
|
||||
return claims.GetNumericDate("exp")
|
||||
}
|
||||
|
||||
// GetNotBefore gets the nbf claim.
|
||||
func (claims Claims) GetNotBefore() (notBefore time.Time, ok bool) {
|
||||
return claims.GetNumericDate("nbf")
|
||||
}
|
||||
|
||||
// GetIssuedAt gets the iat claim.
|
||||
func (claims Claims) GetIssuedAt() (issuedAt time.Time, ok bool) {
|
||||
return claims.GetNumericDate("iat")
|
||||
}
|
||||
|
||||
// GetJWTID gets the jti claim.
|
||||
func (claims Claims) GetJWTID() (jwtID string, ok bool) {
|
||||
return claims.GetString("jti")
|
||||
}
|
||||
|
||||
// custom claims
|
||||
|
||||
// GetUserID returns the oid or sub claim.
|
||||
func (claims Claims) GetUserID() (userID string, ok bool) {
|
||||
if oid, ok := claims.GetString("oid"); ok {
|
||||
return oid, true
|
||||
}
|
||||
|
||||
if sub, ok := claims.GetSubject(); ok {
|
||||
return sub, true
|
||||
}
|
||||
|
||||
return "", false
|
||||
}
|
||||
|
||||
// GetNumericDate returns the claim as a numeric date.
|
||||
func (claims Claims) GetNumericDate(name string) (tm time.Time, ok bool) {
|
||||
if claims == nil {
|
||||
return tm, false
|
||||
}
|
||||
|
||||
raw, ok := claims[name]
|
||||
if !ok {
|
||||
return tm, false
|
||||
}
|
||||
|
||||
switch v := raw.(type) {
|
||||
case float32:
|
||||
return time.Unix(int64(v), 0), true
|
||||
case float64:
|
||||
return time.Unix(int64(v), 0), true
|
||||
case int64:
|
||||
return time.Unix(v, 0), true
|
||||
case int32:
|
||||
return time.Unix(int64(v), 0), true
|
||||
case int16:
|
||||
return time.Unix(int64(v), 0), true
|
||||
case int8:
|
||||
return time.Unix(int64(v), 0), true
|
||||
case int:
|
||||
return time.Unix(int64(v), 0), true
|
||||
case uint64:
|
||||
return time.Unix(int64(v), 0), true
|
||||
case uint32:
|
||||
return time.Unix(int64(v), 0), true
|
||||
case uint16:
|
||||
return time.Unix(int64(v), 0), true
|
||||
case uint8:
|
||||
return time.Unix(int64(v), 0), true
|
||||
case uint:
|
||||
return time.Unix(int64(v), 0), true
|
||||
case json.Number:
|
||||
i, err := v.Int64()
|
||||
if err != nil {
|
||||
if f, err := v.Float64(); err == nil {
|
||||
i = int64(f)
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
return tm, false
|
||||
}
|
||||
return time.Unix(i, 0), true
|
||||
}
|
||||
|
||||
return tm, false
|
||||
}
|
||||
|
||||
// GetString returns the claim as a string.
|
||||
func (claims Claims) GetString(name string) (value string, ok bool) {
|
||||
raw, ok := claims[name]
|
||||
if !ok {
|
||||
return value, false
|
||||
}
|
||||
|
||||
return toString(raw), true
|
||||
}
|
||||
|
||||
// GetStringSlice returns the claim as a slice of strings.
|
||||
func (claims Claims) GetStringSlice(name string) (values []string, ok bool) {
|
||||
raw, ok := claims[name]
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
return toStringSlice(raw), true
|
||||
}
|
||||
|
||||
func toString(data any) string {
|
||||
switch v := data.(type) {
|
||||
case string:
|
||||
return v
|
||||
}
|
||||
return fmt.Sprint(data)
|
||||
}
|
||||
|
||||
func toStringSlice(obj any) []string {
|
||||
v := reflect.ValueOf(obj)
|
||||
switch v.Kind() {
|
||||
case reflect.Slice:
|
||||
vs := make([]string, v.Len())
|
||||
for i := 0; i < v.Len(); i++ {
|
||||
vs[i] = toString(v.Index(i).Interface())
|
||||
}
|
||||
return vs
|
||||
}
|
||||
|
||||
return []string{toString(obj)}
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue