mirror of
https://github.com/pomerium/pomerium.git
synced 2025-04-29 10:26:29 +02:00
150 lines
3.5 KiB
Go
150 lines
3.5 KiB
Go
package identity
|
|
|
|
import (
|
|
"encoding/json"
|
|
"fmt"
|
|
"reflect"
|
|
|
|
"google.golang.org/protobuf/types/known/anypb"
|
|
"google.golang.org/protobuf/types/known/structpb"
|
|
|
|
"github.com/pomerium/pomerium/pkg/protoutil"
|
|
)
|
|
|
|
// SessionClaims are claims that are attached to a session so we can store the raw id token.
|
|
type SessionClaims struct {
|
|
Claims
|
|
RawIDToken string
|
|
}
|
|
|
|
// SetRawIDToken sets the raw id token.
|
|
func (claims *SessionClaims) SetRawIDToken(rawIDToken string) {
|
|
claims.RawIDToken = rawIDToken
|
|
}
|
|
|
|
// Claims holds JWT claims as a map from claim name to its (JSON-decoded)
// value.
type Claims map[string]any
|
|
|
|
// NewClaimsFromRaw builds a Claims map by JSON-decoding every raw message in
// raw. Entries whose message fails to decode are skipped (best effort) rather
// than reported. The returned map is never nil, even when raw is nil.
func NewClaimsFromRaw(raw map[string]json.RawMessage) Claims {
	claims := make(Claims, len(raw))
	for name, msg := range raw {
		var decoded any
		if json.Unmarshal(msg, &decoded) != nil {
			continue // drop undecodable values instead of failing the whole map
		}
		claims[name] = decoded
	}
	return claims
}
|
|
|
|
// UnmarshalJSON decodes a JSON object into claims.
//
// A nil map is allocated first, before decoding, so claims is non-nil even
// when an error is returned. Decoded keys are merged into any existing
// entries (overwriting same-named keys) rather than replacing the whole map.
// Input that is not a JSON object (or null) yields the decoding error.
func (claims *Claims) UnmarshalJSON(data []byte) error {
	if *claims == nil {
		*claims = make(Claims)
	}

	var decoded map[string]any
	if err := json.Unmarshal(data, &decoded); err != nil {
		return err
	}

	dst := *claims
	for key, value := range decoded {
		dst[key] = value
	}
	return nil
}
|
|
|
|
// Claims copies the claims into v by encoding them as JSON and decoding the
// result into v with encoding/json, so v should be a non-nil pointer (for
// example to a struct with json tags). Encoding and decoding errors are
// returned as-is.
func (claims Claims) Claims(v any) error {
	encoded, err := json.Marshal(claims)
	if err == nil {
		err = json.Unmarshal(encoded, v)
	}
	return err
}
|
|
|
|
// Flatten collapses nested maps in the claims into a single-level
// FlattenedClaims whose keys join each nesting level with ".". For example:
//
//	{ "a": { "b": { "c": 12345 } } } => { "a.b.c": [12345] }
//
// Every value in the result is a list:
//   - a map value (any key/value types) is flattened recursively, its keys
//     converted to strings with fmt.Sprint; an empty map contributes no keys;
//   - a slice value is copied element by element (elements themselves are
//     not flattened);
//   - any other value, including nil, becomes a one-element list.
//
// NOTE: a literal dotted key (e.g. "a.b") and a nested path producing the
// same key overwrite each other, and which one survives depends on map
// iteration order.
func (claims Claims) Flatten() FlattenedClaims {
	out := make(FlattenedClaims)
	for key, value := range claims {
		switch val := reflect.ValueOf(value); val.Kind() {
		case reflect.Map:
			nested := make(Claims, val.Len())
			for it := val.MapRange(); it.Next(); {
				nested[fmt.Sprint(it.Key().Interface())] = it.Value().Interface()
			}
			for nestedKey, nestedValues := range nested.Flatten() {
				out[key+"."+nestedKey] = nestedValues
			}
		case reflect.Slice:
			n := val.Len()
			items := make([]any, 0, n)
			for i := 0; i < n; i++ {
				items = append(items, val.Index(i).Interface())
			}
			out[key] = items
		default:
			out[key] = []any{value}
		}
	}
	return out
}
|
|
|
|
// ToAnyMap converts the claims into a map of string => any.
|
|
func (claims Claims) ToAnyMap() map[string]*anypb.Any {
|
|
m := map[string]*anypb.Any{}
|
|
for k, v := range claims {
|
|
m[k] = protoutil.ToAny(v)
|
|
}
|
|
return m
|
|
}
|
|
|
|
// FlattenedClaims is a single-level claims map in which every key maps to a
// list of values; see Claims.Flatten for how nested claims are collapsed.
type FlattenedClaims map[string][]any
|
|
|
|
// NewFlattenedClaimsFromPB creates a new FlattenedClaims from the protobuf struct type.
|
|
func NewFlattenedClaimsFromPB(m map[string]*structpb.ListValue) FlattenedClaims {
|
|
claims := make(FlattenedClaims)
|
|
if m == nil {
|
|
return claims
|
|
}
|
|
bs, _ := json.Marshal(m)
|
|
_ = json.Unmarshal(bs, &claims)
|
|
return claims
|
|
}
|
|
|
|
// ToPB converts the flattened claims into a protobuf type.
|
|
func (claims FlattenedClaims) ToPB() map[string]*structpb.ListValue {
|
|
if claims == nil {
|
|
return nil
|
|
}
|
|
m := make(map[string]*structpb.ListValue)
|
|
for k, vs := range claims {
|
|
svs := make([]*structpb.Value, len(vs))
|
|
for i, v := range vs {
|
|
svs[i] = protoutil.ToStruct(v)
|
|
}
|
|
m[k] = &structpb.ListValue{Values: svs}
|
|
}
|
|
return m
|
|
}
|
|
|
|
// UnmarshalJSON decodes a (possibly nested) JSON object as Claims, flattens
// it with Claims.Flatten, and stores the result in claims. The map is
// allocated only after decoding succeeds. Existing entries are kept unless a
// flattened key from data overwrites them. Decoding errors are returned
// unchanged.
func (claims *FlattenedClaims) UnmarshalJSON(data []byte) error {
	var nested Claims
	if err := json.Unmarshal(data, &nested); err != nil {
		return err
	}

	flat := nested.Flatten()
	if *claims == nil {
		*claims = make(FlattenedClaims, len(flat))
	}
	dst := *claims
	for key, values := range flat {
		dst[key] = values
	}
	return nil
}
|