mirror of
https://github.com/pomerium/pomerium.git
synced 2025-05-24 14:37:12 +02:00
implement session creation
This commit is contained in:
parent
24b35e26a5
commit
b95ad4dbc3
15 changed files with 646 additions and 148 deletions
160
internal/jwtutil/jwtutil.go
Normal file
160
internal/jwtutil/jwtutil.go
Normal file
|
@ -0,0 +1,160 @@
|
|||
// Package jwtutil contains functions for working with JWTs.
|
||||
package jwtutil
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Claims represent claims in a JWT.
|
||||
type Claims map[string]any
|
||||
|
||||
// UnmarshalJSON implements a custom unmarshaller for claims data.
|
||||
func (claims *Claims) UnmarshalJSON(raw []byte) error {
|
||||
dst := map[string]any{}
|
||||
dec := json.NewDecoder(bytes.NewReader(raw))
|
||||
dec.UseNumber()
|
||||
err := dec.Decode(&dst)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
*claims = Claims(dst)
|
||||
return nil
|
||||
}
|
||||
|
||||
// registered claims
|
||||
|
||||
// GetIssuer gets the iss claim.
|
||||
func (claims Claims) GetIssuer() (issuer string, ok bool) {
|
||||
return claims.GetString("iss")
|
||||
}
|
||||
|
||||
// GetSubject gets the sub claim.
|
||||
func (claims Claims) GetSubject() (subject string, ok bool) {
|
||||
return claims.GetString("sub")
|
||||
}
|
||||
|
||||
// GetAudience gets the aud claim.
|
||||
func (claims Claims) GetAudience() (audiences []string, ok bool) {
|
||||
return claims.GetStringSlice("aud")
|
||||
}
|
||||
|
||||
// GetExpirationTime gets the exp claim.
|
||||
func (claims Claims) GetExpirationTime() (expirationTime time.Time, ok bool) {
|
||||
return claims.GetNumericDate("exp")
|
||||
}
|
||||
|
||||
// GetNotBefore gets the nbf claim.
|
||||
func (claims Claims) GetNotBefore() (notBefore time.Time, ok bool) {
|
||||
return claims.GetNumericDate("nbf")
|
||||
}
|
||||
|
||||
// GetIssuedAt gets the iat claim.
|
||||
func (claims Claims) GetIssuedAt() (issuedAt time.Time, ok bool) {
|
||||
return claims.GetNumericDate("iat")
|
||||
}
|
||||
|
||||
// GetJWTID gets the jti claim.
|
||||
func (claims Claims) GetJWTID() (jwtID string, ok bool) {
|
||||
return claims.GetString("jti")
|
||||
}
|
||||
|
||||
// custom claims
|
||||
|
||||
// GetUserID returns the oid or sub claim.
|
||||
func (claims Claims) GetUserID() (userID string, ok bool) {
|
||||
if oid, ok := claims.GetString("oid"); ok {
|
||||
return oid, true
|
||||
}
|
||||
|
||||
if sub, ok := claims.GetSubject(); ok {
|
||||
return sub, true
|
||||
}
|
||||
|
||||
return "", false
|
||||
}
|
||||
|
||||
// GetNumericDate returns the claim as a numeric date.
|
||||
func (claims Claims) GetNumericDate(name string) (tm time.Time, ok bool) {
|
||||
if claims == nil {
|
||||
return tm, false
|
||||
}
|
||||
|
||||
raw, ok := claims[name]
|
||||
if !ok {
|
||||
return tm, false
|
||||
}
|
||||
|
||||
switch v := raw.(type) {
|
||||
case float64:
|
||||
return time.Unix(int64(v), 0), true
|
||||
case int64:
|
||||
return time.Unix(v, 0), true
|
||||
case json.Number:
|
||||
i, err := v.Int64()
|
||||
if err != nil {
|
||||
if f, err := v.Float64(); err == nil {
|
||||
i = int64(f)
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
return tm, false
|
||||
}
|
||||
return time.Unix(i, 0), true
|
||||
}
|
||||
|
||||
return tm, false
|
||||
}
|
||||
|
||||
// GetString returns the claim as a string.
|
||||
func (claims Claims) GetString(name string) (value string, ok bool) {
|
||||
if claims == nil {
|
||||
return value, false
|
||||
}
|
||||
|
||||
raw, ok := claims[name]
|
||||
if !ok {
|
||||
return value, false
|
||||
}
|
||||
|
||||
return toString(raw), true
|
||||
}
|
||||
|
||||
// GetStringSlice returns the claim as a slice of strings.
|
||||
func (claims Claims) GetStringSlice(name string) (values []string, ok bool) {
|
||||
if claims == nil {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
raw, ok := claims[name]
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
return toStringSlice(raw), true
|
||||
}
|
||||
|
||||
func toString(data any) string {
|
||||
switch v := data.(type) {
|
||||
case string:
|
||||
return v
|
||||
}
|
||||
return fmt.Sprint(data)
|
||||
}
|
||||
|
||||
func toStringSlice(obj any) []string {
|
||||
v := reflect.ValueOf(obj)
|
||||
switch v.Kind() {
|
||||
case reflect.Slice:
|
||||
vs := make([]string, v.Len())
|
||||
for i := 0; i < v.Len(); i++ {
|
||||
vs[i] = toString(v.Index(i).Interface())
|
||||
}
|
||||
return vs
|
||||
}
|
||||
|
||||
return []string{toString(obj)}
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue