mirror of
https://github.com/pomerium/pomerium.git
synced 2025-05-28 08:27:26 +02:00
config: support map of jwt claim headers (#1906)
* config: support map of jwt claim headers * fix array handling, add test * update docs * use separate hook, add tests
This commit is contained in:
parent
d04416a5fd
commit
1a1cc30c67
14 changed files with 482 additions and 269 deletions
|
@ -9,14 +9,103 @@ import (
|
|||
"reflect"
|
||||
"strconv"
|
||||
"strings"
|
||||
"unicode"
|
||||
|
||||
envoy_config_cluster_v3 "github.com/envoyproxy/go-control-plane/envoy/config/cluster/v3"
|
||||
"github.com/mitchellh/mapstructure"
|
||||
"google.golang.org/protobuf/encoding/protojson"
|
||||
"google.golang.org/protobuf/proto"
|
||||
"gopkg.in/yaml.v2"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/httputil"
|
||||
)
|
||||
|
||||
// JWTClaimHeaders are headers to add to a request based on IDP claims.
|
||||
type JWTClaimHeaders map[string]string
|
||||
|
||||
// NewJWTClaimHeaders creates a JWTClaimHeaders map from a slice of claims.
|
||||
func NewJWTClaimHeaders(claims ...string) JWTClaimHeaders {
|
||||
hdrs := make(JWTClaimHeaders)
|
||||
for _, claim := range claims {
|
||||
k := httputil.PomeriumJWTHeaderName(claim)
|
||||
hdrs[k] = claim
|
||||
}
|
||||
return hdrs
|
||||
}
|
||||
|
||||
// UnmarshalJSON unmarshals JSON data into the JWTClaimHeaders.
|
||||
func (hdrs *JWTClaimHeaders) UnmarshalJSON(data []byte) error {
|
||||
var m map[string]interface{}
|
||||
if json.Unmarshal(data, &m) == nil {
|
||||
*hdrs = make(map[string]string)
|
||||
for k, v := range m {
|
||||
str := fmt.Sprint(v)
|
||||
(*hdrs)[k] = str
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
var a []interface{}
|
||||
if json.Unmarshal(data, &a) == nil {
|
||||
var vs []string
|
||||
for _, v := range a {
|
||||
vs = append(vs, fmt.Sprint(v))
|
||||
}
|
||||
*hdrs = NewJWTClaimHeaders(vs...)
|
||||
return nil
|
||||
}
|
||||
|
||||
var s string
|
||||
if json.Unmarshal(data, &s) == nil {
|
||||
*hdrs = NewJWTClaimHeaders(strings.FieldsFunc(s, func(r rune) bool {
|
||||
return r == ',' || unicode.IsSpace(r)
|
||||
})...)
|
||||
return nil
|
||||
}
|
||||
|
||||
return fmt.Errorf("JWTClaimHeaders must be an object or an array of values, got: %s", data)
|
||||
}
|
||||
|
||||
// UnmarshalYAML uses UnmarshalJSON to unmarshal YAML data into the JWTClaimHeaders.
|
||||
func (hdrs *JWTClaimHeaders) UnmarshalYAML(unmarshal func(interface{}) error) error {
|
||||
var i interface{}
|
||||
err := unmarshal(&i)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
m, err := serializable(i)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
bs, err := json.Marshal(m)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return hdrs.UnmarshalJSON(bs)
|
||||
}
|
||||
|
||||
func decodeJWTClaimHeadersHookFunc() mapstructure.DecodeHookFunc {
|
||||
return func(f, t reflect.Type, data interface{}) (interface{}, error) {
|
||||
if t != reflect.TypeOf(JWTClaimHeaders{}) {
|
||||
return data, nil
|
||||
}
|
||||
|
||||
bs, err := json.Marshal(data)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var hdrs JWTClaimHeaders
|
||||
err = json.Unmarshal(bs, &hdrs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return hdrs, nil
|
||||
}
|
||||
}
|
||||
|
||||
// A StringSlice is a slice of strings.
|
||||
type StringSlice []string
|
||||
|
||||
|
@ -108,6 +197,7 @@ type WeightedURL struct {
|
|||
LbWeight uint32
|
||||
}
|
||||
|
||||
// Validate validates the WeightedURL.
|
||||
func (u *WeightedURL) Validate() error {
|
||||
if u.URL.Hostname() == "" {
|
||||
return errHostnameMustBeSpecified
|
||||
|
@ -145,6 +235,7 @@ func (u *WeightedURL) String() string {
|
|||
return fmt.Sprintf("{url=%s, weight=%d}", str, u.LbWeight)
|
||||
}
|
||||
|
||||
// WeightedURLs is a slice of WeightedURL.
|
||||
type WeightedURLs []WeightedURL
|
||||
|
||||
// ParseWeightedUrls parses
|
||||
|
@ -220,6 +311,7 @@ func (urls WeightedURLs) Flatten() ([]string, []uint32, error) {
|
|||
return str, wghts, nil
|
||||
}
|
||||
|
||||
// DecodePolicyBase64Hook creates a mapstructure DecodeHookFunc.
|
||||
func DecodePolicyBase64Hook() mapstructure.DecodeHookFunc {
|
||||
return func(f, t reflect.Type, data interface{}) (interface{}, error) {
|
||||
if t != reflect.TypeOf([]Policy{}) {
|
||||
|
@ -249,6 +341,7 @@ func DecodePolicyBase64Hook() mapstructure.DecodeHookFunc {
|
|||
}
|
||||
}
|
||||
|
||||
// DecodePolicyHookFunc creates a mapstructure DecodeHookFunc.
|
||||
func DecodePolicyHookFunc() mapstructure.DecodeHookFunc {
|
||||
return func(f, t reflect.Type, data interface{}) (interface{}, error) {
|
||||
if t != reflect.TypeOf(Policy{}) {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue