pomerium/internal/identity/claims.go
dependabot[bot] ec495bb682
chore(deps): bump github.com/golangci/golangci-lint from 1.48.0 to 1.50.0 (#3667)
* chore(deps): bump github.com/golangci/golangci-lint

Bumps [github.com/golangci/golangci-lint](https://github.com/golangci/golangci-lint) from 1.48.0 to 1.50.0.
- [Release notes](https://github.com/golangci/golangci-lint/releases)
- [Changelog](https://github.com/golangci/golangci-lint/blob/master/CHANGELOG.md)
- [Commits](https://github.com/golangci/golangci-lint/compare/v1.48.0...v1.50.0)

---
updated-dependencies:
- dependency-name: github.com/golangci/golangci-lint
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>

* lint

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: Caleb Doxsey <cdoxsey@pomerium.com>
2022-10-19 09:36:59 -06:00

150 lines
3.5 KiB
Go

package identity
import (
	"encoding/json"
	"fmt"
	"reflect"
	"sort"

	"google.golang.org/protobuf/types/known/anypb"
	"google.golang.org/protobuf/types/known/structpb"

	"github.com/pomerium/pomerium/pkg/protoutil"
)
// SessionClaims are claims that are attached to a session so we can store the raw id token.
//
// NOTE(review): Claims is embedded, so its methods are promoted onto
// SessionClaims, including the pointer-receiver UnmarshalJSON. As a result,
// json.Unmarshal into a *SessionClaims decodes the whole document into the
// embedded Claims map and never sets RawIDToken. json.Marshal (there is no
// MarshalJSON here) instead emits {"Claims": {...}, "RawIDToken": "..."}.
// Confirm callers never rely on a JSON round trip of this type.
type SessionClaims struct {
// Claims holds the decoded claims for the session.
Claims
// RawIDToken is the raw id token string; populated via SetRawIDToken.
RawIDToken string
}
// SetRawIDToken records rawIDToken on the session claims, replacing any
// value stored previously.
func (sc *SessionClaims) SetRawIDToken(rawIDToken string) {
	sc.RawIDToken = rawIDToken
}
// Claims are JWT claims, keyed by claim name.
//
// Values produced by the decoding helpers in this file (NewClaimsFromRaw,
// UnmarshalJSON) come from encoding/json decoding into interface{}. They are
// therefore the generic JSON types: string, float64, bool, nil, []interface{}
// and map[string]interface{}.
type Claims map[string]interface{}
// NewClaimsFromRaw decodes every raw JSON value in raw and returns the
// results as a Claims map keyed like raw.
//
// Entries whose value fails to decode (including empty/nil messages) are
// skipped silently rather than reported. A nil or empty raw yields an
// empty, non-nil Claims.
func NewClaimsFromRaw(raw map[string]json.RawMessage) Claims {
	out := make(Claims, len(raw))
	for key, msg := range raw {
		var decoded interface{}
		if json.Unmarshal(msg, &decoded) != nil {
			continue
		}
		out[key] = decoded
	}
	return out
}
// UnmarshalJSON decodes a JSON object and merges its members into claims.
//
// A nil receiver map is allocated first, even if decoding then fails. Keys
// present in data overwrite existing entries; other existing entries are
// kept. A JSON null decodes to nothing and leaves the map's contents as they
// were.
func (claims *Claims) UnmarshalJSON(data []byte) error {
	if *claims == nil {
		*claims = Claims{}
	}

	var decoded map[string]interface{}
	if err := json.Unmarshal(data, &decoded); err != nil {
		return err
	}

	dst := *claims
	for key, value := range decoded {
		dst[key] = value
	}
	return nil
}
// Claims decodes the claims into v by encoding the map as JSON and decoding
// that JSON into v. Therefore v must be something json.Unmarshal accepts,
// typically a non-nil pointer to a struct or map. The first marshal or
// unmarshal error encountered is returned.
func (claims Claims) Claims(v interface{}) error {
	encoded, err := json.Marshal(claims)
	if err == nil {
		err = json.Unmarshal(encoded, v)
	}
	return err
}
// Flatten flattens the claims to a FlattenedClaims map. For example:
//
//	{ "a": { "b": { "c": 12345 } } } => { "a.b.c": [12345] }
//
// Nested maps are flattened recursively, joining keys with "." (non-string
// map keys are converted with fmt.Sprint). A slice or array becomes the value
// list for its key. Any other value, including nil, becomes a single-element
// list.
//
// Keys are visited in sorted order so the result is deterministic. When two
// different paths flatten to the same key (e.g. a literal "a.b" claim next to
// {"a": {"b": ...}}), their values are concatenated in that order. Previously,
// one of them was dropped depending on map iteration order.
func (claims Claims) Flatten() FlattenedClaims {
	flattened := make(FlattenedClaims)

	// add appends values under key. A key seen for the first time stores the
	// given slice as-is, so an empty slice claim still produces a non-nil,
	// zero-length entry.
	add := func(key string, values ...interface{}) {
		if existing, ok := flattened[key]; ok {
			flattened[key] = append(existing, values...)
		} else {
			flattened[key] = values
		}
	}

	keys := make([]string, 0, len(claims))
	for k := range claims {
		keys = append(keys, k)
	}
	sort.Strings(keys)

	for _, k := range keys {
		v := claims[k]
		rv := reflect.ValueOf(v)
		switch rv.Kind() {
		case reflect.Map:
			subClaims := make(Claims, rv.Len())
			iter := rv.MapRange()
			for iter.Next() {
				subClaims[fmt.Sprint(iter.Key().Interface())] = iter.Value().Interface()
			}
			// Every sk is distinct, so k+"."+sk values never collide with each
			// other here; collisions can only come from other top-level keys,
			// which are handled in sorted order.
			for sk, sv := range subClaims.Flatten() {
				add(k+"."+sk, sv...)
			}
		case reflect.Slice, reflect.Array:
			values := make([]interface{}, rv.Len())
			for i := range values {
				values[i] = rv.Index(i).Interface()
			}
			add(k, values...)
		default:
			// Also covers nil: reflect.ValueOf(nil) has Kind Invalid.
			add(k, v)
		}
	}
	return flattened
}
// ToAnyMap converts each claim value with protoutil.ToAny and returns the
// results keyed by claim name. Nil or empty claims yield an empty, non-nil
// map.
func (claims Claims) ToAnyMap() map[string]*anypb.Any {
	out := make(map[string]*anypb.Any, len(claims))
	for name, value := range claims {
		out[name] = protoutil.ToAny(value)
	}
	return out
}
// FlattenedClaims are a set of claims flattened into a single-level map.
//
// Each key is a dot-joined claim path (as produced by Claims.Flatten) and each
// value is the list of values found at that path.
type FlattenedClaims map[string][]interface{}
// NewFlattenedClaimsFromPB builds FlattenedClaims from their protobuf form by
// round-tripping m through encoding/json and decoding the JSON into a
// FlattenedClaims.
//
// It never fails. A nil m yields an empty, non-nil result, and errors from
// the JSON round trip are deliberately ignored, leaving the result empty.
//
// NOTE(review): this relies on *structpb.ListValue marshaling to a plain JSON
// array under encoding/json. Confirm against the structpb version in use.
func NewFlattenedClaimsFromPB(m map[string]*structpb.ListValue) FlattenedClaims {
	result := FlattenedClaims{}
	if m == nil {
		return result
	}
	// Best effort: a failed Marshal or Unmarshal just leaves result empty.
	if encoded, err := json.Marshal(m); err == nil {
		_ = json.Unmarshal(encoded, &result)
	}
	return result
}
// ToPB converts the flattened claims into protobuf list values, converting
// every element with protoutil.ToStruct. Nil claims map to nil; empty (non-nil)
// claims map to an empty, non-nil map.
func (claims FlattenedClaims) ToPB() map[string]*structpb.ListValue {
	if claims == nil {
		return nil
	}

	out := make(map[string]*structpb.ListValue, len(claims))
	for key, values := range claims {
		list := &structpb.ListValue{Values: make([]*structpb.Value, 0, len(values))}
		for _, value := range values {
			list.Values = append(list.Values, protoutil.ToStruct(value))
		}
		out[key] = list
	}
	return out
}
// UnmarshalJSON decodes data as (possibly nested) Claims, flattens them with
// Claims.Flatten, and merges the result into claims, overwriting existing
// keys. When decoding fails the receiver is left untouched, so a nil map
// stays nil.
func (claims *FlattenedClaims) UnmarshalJSON(data []byte) error {
	var nested Claims
	if err := json.Unmarshal(data, &nested); err != nil {
		return err
	}

	if *claims == nil {
		*claims = FlattenedClaims{}
	}
	out := *claims
	for key, values := range nested.Flatten() {
		out[key] = values
	}
	return nil
}