config: support map of jwt claim headers (#1906)

* config: support map of jwt claim headers

* fix array handling, add test

* update docs

* use separate hook, add tests
This commit is contained in:
Caleb Doxsey 2021-02-17 13:43:18 -07:00 committed by GitHub
parent d04416a5fd
commit 1a1cc30c67
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
14 changed files with 482 additions and 269 deletions

View file

@ -9,14 +9,103 @@ import (
"reflect"
"strconv"
"strings"
"unicode"
envoy_config_cluster_v3 "github.com/envoyproxy/go-control-plane/envoy/config/cluster/v3"
"github.com/mitchellh/mapstructure"
"google.golang.org/protobuf/encoding/protojson"
"google.golang.org/protobuf/proto"
"gopkg.in/yaml.v2"
"github.com/pomerium/pomerium/internal/httputil"
)
// JWTClaimHeaders are headers to add to a request based on IDP claims.
type JWTClaimHeaders map[string]string
// NewJWTClaimHeaders creates a JWTClaimHeaders map from a slice of claims.
func NewJWTClaimHeaders(claims ...string) JWTClaimHeaders {
hdrs := make(JWTClaimHeaders)
for _, claim := range claims {
k := httputil.PomeriumJWTHeaderName(claim)
hdrs[k] = claim
}
return hdrs
}
// UnmarshalJSON unmarshals JSON data into the JWTClaimHeaders.
func (hdrs *JWTClaimHeaders) UnmarshalJSON(data []byte) error {
var m map[string]interface{}
if json.Unmarshal(data, &m) == nil {
*hdrs = make(map[string]string)
for k, v := range m {
str := fmt.Sprint(v)
(*hdrs)[k] = str
}
return nil
}
var a []interface{}
if json.Unmarshal(data, &a) == nil {
var vs []string
for _, v := range a {
vs = append(vs, fmt.Sprint(v))
}
*hdrs = NewJWTClaimHeaders(vs...)
return nil
}
var s string
if json.Unmarshal(data, &s) == nil {
*hdrs = NewJWTClaimHeaders(strings.FieldsFunc(s, func(r rune) bool {
return r == ',' || unicode.IsSpace(r)
})...)
return nil
}
return fmt.Errorf("JWTClaimHeaders must be an object or an array of values, got: %s", data)
}
// UnmarshalYAML uses UnmarshalJSON to unmarshal YAML data into the JWTClaimHeaders.
func (hdrs *JWTClaimHeaders) UnmarshalYAML(unmarshal func(interface{}) error) error {
var i interface{}
err := unmarshal(&i)
if err != nil {
return err
}
m, err := serializable(i)
if err != nil {
return err
}
bs, err := json.Marshal(m)
if err != nil {
return err
}
return hdrs.UnmarshalJSON(bs)
}
func decodeJWTClaimHeadersHookFunc() mapstructure.DecodeHookFunc {
return func(f, t reflect.Type, data interface{}) (interface{}, error) {
if t != reflect.TypeOf(JWTClaimHeaders{}) {
return data, nil
}
bs, err := json.Marshal(data)
if err != nil {
return nil, err
}
var hdrs JWTClaimHeaders
err = json.Unmarshal(bs, &hdrs)
if err != nil {
return nil, err
}
return hdrs, nil
}
}
// A StringSlice is a slice of strings.
type StringSlice []string
@ -108,6 +197,7 @@ type WeightedURL struct {
LbWeight uint32
}
// Validate validates the WeightedURL.
func (u *WeightedURL) Validate() error {
if u.URL.Hostname() == "" {
return errHostnameMustBeSpecified
@ -145,6 +235,7 @@ func (u *WeightedURL) String() string {
return fmt.Sprintf("{url=%s, weight=%d}", str, u.LbWeight)
}
// WeightedURLs is a slice of WeightedURL.
type WeightedURLs []WeightedURL
// ParseWeightedUrls parses
@ -220,6 +311,7 @@ func (urls WeightedURLs) Flatten() ([]string, []uint32, error) {
return str, wghts, nil
}
// DecodePolicyBase64Hook creates a mapstructure DecodeHookFunc.
func DecodePolicyBase64Hook() mapstructure.DecodeHookFunc {
return func(f, t reflect.Type, data interface{}) (interface{}, error) {
if t != reflect.TypeOf([]Policy{}) {
@ -249,6 +341,7 @@ func DecodePolicyBase64Hook() mapstructure.DecodeHookFunc {
}
}
// DecodePolicyHookFunc creates a mapstructure DecodeHookFunc.
func DecodePolicyHookFunc() mapstructure.DecodeHookFunc {
return func(f, t reflect.Type, data interface{}) (interface{}, error) {
if t != reflect.TypeOf(Policy{}) {