pomerium/internal/identity/claims.go

151 lines
3.5 KiB
Go

package identity
import (
	"encoding/json"
	"fmt"
	"reflect"
	"sort"

	"google.golang.org/protobuf/types/known/anypb"
	"google.golang.org/protobuf/types/known/structpb"

	"github.com/pomerium/pomerium/pkg/protoutil"
)
// SessionClaims are claims that are attached to a session so we can store the raw id token.
//
// Claims is embedded, so its methods are promoted to SessionClaims. In
// particular, json.Unmarshal into a *SessionClaims goes through
// Claims.UnmarshalJSON: the decoded object populates the embedded Claims map
// and RawIDToken is never read from the JSON. Set it with SetRawIDToken.
type SessionClaims struct {
	// Claims holds the decoded claims.
	Claims
	// RawIDToken is the raw id token, as last passed to SetRawIDToken.
	RawIDToken string
}

// SetRawIDToken sets the raw id token.
//
// Any previously stored token is replaced; the embedded Claims are not touched.
func (sc *SessionClaims) SetRawIDToken(rawIDToken string) {
	sc.RawIDToken = rawIDToken
}
// Claims are JWT claims.
//
// Claims maps each claim name to its decoded value. When filled from JSON
// (NewClaimsFromRaw, UnmarshalJSON) the values are whatever encoding/json
// produces for an interface{} target: map[string]interface{},
// []interface{}, float64, string, bool or nil.
type Claims map[string]interface{}
// NewClaimsFromRaw creates a new Claims map from a map of raw messages.
//
// Every raw message is decoded on its own. Entries whose JSON cannot be
// decoded are left out of the result instead of being reported as an error.
// The returned map is never nil, even for a nil or empty raw.
func NewClaimsFromRaw(raw map[string]json.RawMessage) Claims {
	result := Claims{}
	for name, msg := range raw {
		var decoded interface{}
		if json.Unmarshal(msg, &decoded) != nil {
			// Best effort: skip values that are not valid JSON.
			continue
		}
		result[name] = decoded
	}
	return result
}
// UnmarshalJSON unmarshals the raw json data into the claims object.
//
// The members of the JSON object are merged into the existing map. Keys in
// data overwrite existing entries, and all other existing entries are kept.
// A nil map is allocated before decoding, so the receiver is non-nil on
// return even if decoding fails. JSON null leaves the (allocated) map as is.
func (claims *Claims) UnmarshalJSON(data []byte) error {
	if *claims == nil {
		*claims = Claims{}
	}

	// Decode into a plain map: decoding into *Claims would recurse into this method.
	var decoded map[string]interface{}
	if err := json.Unmarshal(data, &decoded); err != nil {
		return err
	}

	target := *claims
	for key, value := range decoded {
		target[key] = value
	}
	return nil
}
// Claims takes the claims data and fills v.
//
// The claims are encoded to JSON and decoded into v, so v must be something
// json.Unmarshal accepts (typically a pointer to a struct with json tags).
// Errors from either the encode or the decode step are returned unchanged.
func (claims Claims) Claims(v interface{}) error {
	encoded, err := json.Marshal(claims)
	if err == nil {
		err = json.Unmarshal(encoded, v)
	}
	return err
}
// Flatten flattens the claims to a FlattenedClaims map. For example:
//
//	{ "a": { "b": { "c": 12345 } } } => { "a.b.c": [12345] }
//
// Nested maps are walked recursively and their keys joined with ".". A
// slice contributes its elements as the value list. Every other value,
// including nil, becomes a single-element list. An empty nested map
// contributes no keys.
//
// Distinct claims can flatten to the same key, e.g. a literal "a.b" claim
// alongside {"a": {"b": ...}}. Map iteration order must not pick a winner,
// so keys are visited in sorted order and colliding value lists are
// concatenated in that order. The result is deterministic and no value is
// dropped.
func (claims Claims) Flatten() FlattenedClaims {
	keys := make([]string, 0, len(claims))
	for k := range claims {
		keys = append(keys, k)
	}
	sort.Strings(keys)

	flattened := make(FlattenedClaims, len(claims))
	// add stores values under key, appending to (rather than replacing) any
	// values an earlier claim already flattened to the same key.
	add := func(key string, values []interface{}) {
		if existing, ok := flattened[key]; ok {
			flattened[key] = append(existing, values...)
			return
		}
		flattened[key] = values
	}

	for _, k := range keys {
		v := claims[k]
		rv := reflect.ValueOf(v)
		switch rv.Kind() {
		case reflect.Map:
			// Map keys are stringified so non-string-keyed maps can be flattened too.
			subClaims := make(Claims, rv.Len())
			iter := rv.MapRange()
			for iter.Next() {
				subClaims[fmt.Sprint(iter.Key().Interface())] = iter.Value().Interface()
			}
			// Distinct sub-keys of one claim always yield distinct dotted keys,
			// so the iteration order of this inner loop cannot affect the result.
			for sk, sv := range subClaims.Flatten() {
				add(k+"."+sk, sv)
			}
		case reflect.Slice:
			slc := make([]interface{}, rv.Len())
			for i := range slc {
				slc[i] = rv.Index(i).Interface()
			}
			add(k, slc)
		default:
			add(k, []interface{}{v})
		}
	}
	return flattened
}
// ToAnyMap converts the claims into a map of string => any.
//
// Every claim value is converted with protoutil.ToAny and stored under its
// claim name. The returned map is never nil.
func (claims Claims) ToAnyMap() map[string]*anypb.Any {
	out := make(map[string]*anypb.Any, len(claims))
	for name, value := range claims {
		out[name] = protoutil.ToAny(value)
	}
	return out
}
// FlattenedClaims are a set of claims flattened into a single-level map.
//
// Each key is a dot-separated path into the original nested claims, and each
// value is the list of values found at that path (see Claims.Flatten).
type FlattenedClaims map[string][]interface{}
// NewFlattenedClaimsFromPB creates a new FlattenedClaims from the protobuf struct type.
//
// Each map entry becomes one claim whose values are the list elements
// converted with ListValue.AsSlice: numbers become float64, structs become
// map[string]interface{}, lists become []interface{}, and null becomes nil.
// A nil m yields an empty, non-nil map. A nil *ListValue entry yields a
// single nil value.
func NewFlattenedClaimsFromPB(m map[string]*structpb.ListValue) FlattenedClaims {
	// Convert entry by entry instead of round-tripping the whole map through
	// encoding/json. That round trip ignored its errors, so one value protojson
	// refuses to encode (e.g. a NaN number or a Value with no kind set) made
	// json.Marshal fail and every claim was silently discarded.
	claims := make(FlattenedClaims, len(m))
	for k, lv := range m {
		if lv == nil {
			// A nil list encoded to JSON null, which flattening turned into [nil];
			// keep that result for compatibility.
			claims[k] = []interface{}{nil}
			continue
		}
		claims[k] = lv.AsSlice()
	}
	return claims
}
// ToPB converts the flattened claims into a protobuf type.
//
// Each claim's value list becomes a structpb.ListValue whose elements are
// produced by protoutil.ToStruct. A nil FlattenedClaims yields a nil map; an
// empty value list yields a ListValue with an empty, non-nil Values slice.
func (claims FlattenedClaims) ToPB() map[string]*structpb.ListValue {
	if claims == nil {
		return nil
	}

	out := make(map[string]*structpb.ListValue, len(claims))
	for name, values := range claims {
		list := &structpb.ListValue{Values: make([]*structpb.Value, 0, len(values))}
		for _, value := range values {
			list.Values = append(list.Values, protoutil.ToStruct(value))
		}
		out[name] = list
	}
	return out
}
// UnmarshalJSON unmarshals JSON into the flattened claims.
//
// The data is decoded as nested Claims, flattened with Claims.Flatten, and
// the resulting keys are merged into the receiver, overwriting keys that
// already exist. On a decode error the receiver is left untouched; a nil
// map stays nil.
func (claims *FlattenedClaims) UnmarshalJSON(data []byte) error {
	var nested Claims
	if err := json.Unmarshal(data, &nested); err != nil {
		return err
	}

	if *claims == nil {
		*claims = FlattenedClaims{}
	}
	dst := *claims
	for key, values := range nested.Flatten() {
		dst[key] = values
	}
	return nil
}