add global jwt_issuer_format option (#5508)

Add a corresponding global setting for the existing route-level
jwt_issuer_format option. The route-level option will take precedence
when set to a non-empty string.
This commit is contained in:
Kenneth Jenkins 2025-03-11 14:11:50 -07:00 committed by GitHub
parent b86c9931b1
commit ad183873f4
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
11 changed files with 902 additions and 781 deletions

View file

@ -149,6 +149,7 @@ func newPolicyEvaluator(
evaluator.WithGoogleCloudServerlessAuthenticationServiceAccount(opts.GetGoogleCloudServerlessAuthenticationServiceAccount()), evaluator.WithGoogleCloudServerlessAuthenticationServiceAccount(opts.GetGoogleCloudServerlessAuthenticationServiceAccount()),
evaluator.WithJWTClaimsHeaders(opts.JWTClaimsHeaders), evaluator.WithJWTClaimsHeaders(opts.JWTClaimsHeaders),
evaluator.WithJWTGroupsFilter(opts.JWTGroupsFilter), evaluator.WithJWTGroupsFilter(opts.JWTGroupsFilter),
evaluator.WithDefaultJWTIssuerFormat(opts.JWTIssuerFormat),
) )
} }

View file

@ -16,6 +16,7 @@ type evaluatorConfig struct {
GoogleCloudServerlessAuthenticationServiceAccount string GoogleCloudServerlessAuthenticationServiceAccount string
JWTClaimsHeaders config.JWTClaimHeaders JWTClaimsHeaders config.JWTClaimHeaders
JWTGroupsFilter config.JWTGroupsFilter JWTGroupsFilter config.JWTGroupsFilter
DefaultJWTIssuerFormat config.JWTIssuerFormat
} }
// cacheKey() returns a hash over the configuration, except for the policies. // cacheKey() returns a hash over the configuration, except for the policies.
@ -105,3 +106,10 @@ func WithJWTGroupsFilter(groups config.JWTGroupsFilter) Option {
cfg.JWTGroupsFilter = groups cfg.JWTGroupsFilter = groups
} }
} }
// WithDefaultJWTIssuerFormat sets the default JWT issuer format in the config.
func WithDefaultJWTIssuerFormat(format config.JWTIssuerFormat) Option {
return func(cfg *evaluatorConfig) {
cfg.DefaultJWTIssuerFormat = format
}
}

View file

@ -332,6 +332,7 @@ func updateStore(ctx context.Context, store *store.Store, cfg *evaluatorConfig)
) )
store.UpdateJWTClaimHeaders(cfg.JWTClaimsHeaders) store.UpdateJWTClaimHeaders(cfg.JWTClaimsHeaders)
store.UpdateJWTGroupsFilter(cfg.JWTGroupsFilter) store.UpdateJWTGroupsFilter(cfg.JWTGroupsFilter)
store.UpdateDefaultJWTIssuerFormat(cfg.DefaultJWTIssuerFormat)
store.UpdateRoutePolicies(cfg.Policies) store.UpdateRoutePolicies(cfg.Policies)
store.UpdateSigningKey(jwk) store.UpdateSigningKey(jwk)

View file

@ -247,18 +247,16 @@ func (e *headersEvaluatorEvaluation) getGroupIDs(ctx context.Context) []string {
return make([]string, 0) return make([]string, 0)
} }
func (e *headersEvaluatorEvaluation) getJWTPayloadIss() (string, error) { func (e *headersEvaluatorEvaluation) getJWTPayloadIss() string {
var issuerFormat string issuerFormat := e.evaluator.store.GetDefaultJWTIssuerFormat()
if e.request.Policy != nil { if e.request.Policy != nil && e.request.Policy.JWTIssuerFormat != "" {
issuerFormat = e.request.Policy.JWTIssuerFormat issuerFormat = e.request.Policy.JWTIssuerFormat
} }
switch issuerFormat { switch issuerFormat {
case "uri": case config.JWTIssuerFormatURI:
return fmt.Sprintf("https://%s/", e.request.HTTP.Hostname), nil return fmt.Sprintf("https://%s/", e.request.HTTP.Hostname)
case "", "hostOnly":
return e.request.HTTP.Hostname, nil
default: default:
return "", fmt.Errorf("unsupported JWT issuer format: %s", issuerFormat) return e.request.HTTP.Hostname
} }
} }
@ -412,14 +410,9 @@ func (e *headersEvaluatorEvaluation) getJWTPayload(ctx context.Context) (map[str
return e.cachedJWTPayload, nil return e.cachedJWTPayload, nil
} }
iss, err := e.getJWTPayloadIss()
if err != nil {
return nil, err
}
e.gotJWTPayload = true e.gotJWTPayload = true
e.cachedJWTPayload = map[string]any{ e.cachedJWTPayload = map[string]any{
"iss": iss, "iss": e.getJWTPayloadIss(),
"aud": e.getJWTPayloadAud(), "aud": e.getJWTPayloadAud(),
"jti": e.getJWTPayloadJTI(), "jti": e.getJWTPayloadJTI(),
"iat": e.getJWTPayloadIAT(), "iat": e.getJWTPayloadIAT(),

View file

@ -437,34 +437,59 @@ func TestHeadersEvaluator(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, "u1@example.com", output.Headers.Get("X-Pomerium-Claim-Email")) assert.Equal(t, "u1@example.com", output.Headers.Get("X-Pomerium-Claim-Email"))
}) })
}
t.Run("issuer format", func(t *testing.T) { func TestHeadersEvaluator_JWTIssuerFormat(t *testing.T) {
t.Parallel() privateJWK, _ := newJWK(t)
for _, tc := range []struct { store := store.New()
format string store.UpdateSigningKey(privateJWK)
input string
output string eval := func(_ *testing.T, input *Request) (*HeadersResponse, error) {
}{ ctx := context.Background()
{"", "example.com", "example.com"}, e := NewHeadersEvaluator(store)
{"hostOnly", "host-only.example.com", "host-only.example.com"}, return e.Evaluate(ctx, input)
{"uri", "uri.example.com", "https://uri.example.com/"}, }
} {
hostname := "route.example.com"
cases := []struct {
globalFormat config.JWTIssuerFormat
routeFormat config.JWTIssuerFormat
expected string
}{
{"", "", "route.example.com"},
{"hostOnly", "", "route.example.com"},
{"uri", "", "https://route.example.com/"},
{"", "hostOnly", "route.example.com"},
{"hostOnly", "hostOnly", "route.example.com"},
{"uri", "hostOnly", "route.example.com"},
{"", "uri", "https://route.example.com/"},
{"hostOnly", "uri", "https://route.example.com/"},
{"uri", "uri", "https://route.example.com/"},
}
for _, tc := range cases {
t.Run("", func(t *testing.T) {
store.UpdateDefaultJWTIssuerFormat(tc.globalFormat)
output, err := eval(t, output, err := eval(t,
nil,
&Request{ &Request{
HTTP: RequestHTTP{ HTTP: RequestHTTP{
Hostname: tc.input, Hostname: hostname,
}, },
Policy: &config.Policy{ Policy: &config.Policy{
JWTIssuerFormat: tc.format, JWTIssuerFormat: tc.routeFormat,
}, },
}) })
require.NoError(t, err) require.NoError(t, err)
m := decodeJWTAssertion(t, output.Headers) m := decodeJWTAssertion(t, output.Headers)
assert.Equal(t, tc.output, m["iss"], "unexpected issuer for format=%s", tc.format) assert.Equal(t, tc.expected, m["iss"],
} "unexpected issuer for global format=%q, route format=%q",
}) tc.globalFormat, tc.routeFormat)
})
}
} }
func TestHeadersEvaluator_JWTGroupsFilter(t *testing.T) { func TestHeadersEvaluator_JWTGroupsFilter(t *testing.T) {

View file

@ -33,6 +33,7 @@ type Store struct {
googleCloudServerlessAuthenticationServiceAccount atomic.Pointer[string] googleCloudServerlessAuthenticationServiceAccount atomic.Pointer[string]
jwtClaimHeaders atomic.Pointer[map[string]string] jwtClaimHeaders atomic.Pointer[map[string]string]
jwtGroupsFilter atomic.Pointer[config.JWTGroupsFilter] jwtGroupsFilter atomic.Pointer[config.JWTGroupsFilter]
defaultJWTIssuerFormat atomic.Pointer[config.JWTIssuerFormat]
signingKey atomic.Pointer[jose.JSONWebKey] signingKey atomic.Pointer[jose.JSONWebKey]
} }
@ -66,6 +67,13 @@ func (s *Store) GetJWTGroupsFilter() config.JWTGroupsFilter {
return config.JWTGroupsFilter{} return config.JWTGroupsFilter{}
} }
func (s *Store) GetDefaultJWTIssuerFormat() config.JWTIssuerFormat {
if f := s.defaultJWTIssuerFormat.Load(); f != nil {
return *f
}
return ""
}
func (s *Store) GetSigningKey() *jose.JSONWebKey { func (s *Store) GetSigningKey() *jose.JSONWebKey {
return s.signingKey.Load() return s.signingKey.Load()
} }
@ -89,6 +97,12 @@ func (s *Store) UpdateJWTGroupsFilter(groups config.JWTGroupsFilter) {
s.jwtGroupsFilter.Store(&groups) s.jwtGroupsFilter.Store(&groups)
} }
// UpdateDefaultJWTIssuerFormat updates the JWT groups filter in the store.
func (s *Store) UpdateDefaultJWTIssuerFormat(format config.JWTIssuerFormat) {
// This isn't used by the Rego code, so we don't need to write it to the opastorage.Store instance.
s.defaultJWTIssuerFormat.Store(&format)
}
// UpdateRoutePolicies updates the route policies in the store. // UpdateRoutePolicies updates the route policies in the store.
func (s *Store) UpdateRoutePolicies(routePolicies []*config.Policy) { func (s *Store) UpdateRoutePolicies(routePolicies []*config.Policy) {
s.write("/route_policies", routePolicies) s.write("/route_policies", routePolicies)

View file

@ -23,6 +23,7 @@ import (
"github.com/pomerium/pomerium/internal/hashutil" "github.com/pomerium/pomerium/internal/hashutil"
"github.com/pomerium/pomerium/internal/httputil" "github.com/pomerium/pomerium/internal/httputil"
"github.com/pomerium/pomerium/internal/urlutil" "github.com/pomerium/pomerium/internal/urlutil"
configpb "github.com/pomerium/pomerium/pkg/grpc/config"
"github.com/pomerium/pomerium/pkg/policy/parser" "github.com/pomerium/pomerium/pkg/policy/parser"
) )
@ -617,3 +618,50 @@ func (f JWTGroupsFilter) Equal(other JWTGroupsFilter) bool {
} }
return f.set.Equal(other.set) return f.set.Equal(other.set)
} }
type JWTIssuerFormat string
const (
JWTIssuerFormatUnset JWTIssuerFormat = ""
JWTIssuerFormatHostOnly JWTIssuerFormat = "hostOnly"
JWTIssuerFormatURI JWTIssuerFormat = "uri"
)
var knownJWTIssuerFormats = map[JWTIssuerFormat]struct{}{
JWTIssuerFormatUnset: {},
JWTIssuerFormatHostOnly: {},
JWTIssuerFormatURI: {},
}
func JWTIssuerFormatFromPB(format *configpb.IssuerFormat) JWTIssuerFormat {
if format == nil {
return JWTIssuerFormatUnset
}
switch *format {
case configpb.IssuerFormat_IssuerHostOnly:
return JWTIssuerFormatHostOnly
case configpb.IssuerFormat_IssuerURI:
return JWTIssuerFormatURI
default:
return JWTIssuerFormatUnset
}
}
func (f JWTIssuerFormat) ToPB() *configpb.IssuerFormat {
switch f {
case JWTIssuerFormatUnset:
return nil
case JWTIssuerFormatHostOnly:
return configpb.IssuerFormat_IssuerHostOnly.Enum()
case JWTIssuerFormatURI:
return configpb.IssuerFormat_IssuerURI.Enum()
default:
return nil
}
}
func (f JWTIssuerFormat) Valid() bool {
_, ok := knownJWTIssuerFormats[f]
return ok
}

View file

@ -196,6 +196,12 @@ type Options struct {
// List of JWT claims to insert as x-pomerium-claim-* headers on proxied requests // List of JWT claims to insert as x-pomerium-claim-* headers on proxied requests
JWTClaimsHeaders JWTClaimHeaders `mapstructure:"jwt_claims_headers" yaml:"jwt_claims_headers,omitempty"` JWTClaimsHeaders JWTClaimHeaders `mapstructure:"jwt_claims_headers" yaml:"jwt_claims_headers,omitempty"`
// JWTIssuerFormat controls the default format of the 'iss' claim in JWTs passed to upstream services.
// Possible values:
// - "hostOnly" (default): Issuer strings will be the hostname of the route, with no scheme or trailing slash.
// - "uri": Issuer strings will be a complete URI, including the scheme and ending with a trailing slash.
JWTIssuerFormat JWTIssuerFormat `mapstructure:"jwt_issuer_format" yaml:"jwt_issuer_format,omitempty"`
// BearerTokenFormat indicates how authorization bearer tokens are interepreted. Possible values: // BearerTokenFormat indicates how authorization bearer tokens are interepreted. Possible values:
// - "default": Only Bearer tokens prefixed with Pomerium- will be interpreted by Pomerium. // - "default": Only Bearer tokens prefixed with Pomerium- will be interpreted by Pomerium.
// - "idp_access_token": The Bearer token will be interpreted as an IdP access token. // - "idp_access_token": The Bearer token will be interpreted as an IdP access token.
@ -761,6 +767,10 @@ func (o *Options) Validate() error {
} }
} }
if !o.JWTIssuerFormat.Valid() {
return fmt.Errorf("config: unsupported jwt_issuer_format value %q", o.JWTIssuerFormat)
}
return nil return nil
} }
@ -1514,6 +1524,9 @@ func (o *Options) ApplySettings(ctx context.Context, certsIndex *cryptutil.Certi
if len(settings.JwtGroupsFilter) > 0 { if len(settings.JwtGroupsFilter) > 0 {
o.JWTGroupsFilter = NewJWTGroupsFilter(settings.JwtGroupsFilter) o.JWTGroupsFilter = NewJWTGroupsFilter(settings.JwtGroupsFilter)
} }
if f := JWTIssuerFormatFromPB(settings.JwtIssuerFormat); f != JWTIssuerFormatUnset {
o.JWTIssuerFormat = f
}
setDuration(&o.DefaultUpstreamTimeout, settings.DefaultUpstreamTimeout) setDuration(&o.DefaultUpstreamTimeout, settings.DefaultUpstreamTimeout)
set(&o.MetricsAddr, settings.MetricsAddress) set(&o.MetricsAddr, settings.MetricsAddress)
set(&o.MetricsBasicAuth, settings.MetricsBasicAuth) set(&o.MetricsBasicAuth, settings.MetricsBasicAuth)
@ -1624,6 +1637,7 @@ func (o *Options) ToProto() *config.Config {
settings.JwtClaimsHeaders = o.JWTClaimsHeaders settings.JwtClaimsHeaders = o.JWTClaimsHeaders
settings.BearerTokenFormat = o.BearerTokenFormat.ToPB() settings.BearerTokenFormat = o.BearerTokenFormat.ToPB()
settings.JwtGroupsFilter = o.JWTGroupsFilter.ToSlice() settings.JwtGroupsFilter = o.JWTGroupsFilter.ToSlice()
settings.JwtIssuerFormat = o.JWTIssuerFormat.ToPB()
copyOptionalDuration(&settings.DefaultUpstreamTimeout, o.DefaultUpstreamTimeout) copyOptionalDuration(&settings.DefaultUpstreamTimeout, o.DefaultUpstreamTimeout)
copySrcToOptionalDest(&settings.MetricsAddress, &o.MetricsAddr) copySrcToOptionalDest(&settings.MetricsAddress, &o.MetricsAddr)
copySrcToOptionalDest(&settings.MetricsBasicAuth, &o.MetricsBasicAuth) copySrcToOptionalDest(&settings.MetricsBasicAuth, &o.MetricsBasicAuth)

View file

@ -167,7 +167,7 @@ type Policy struct {
// Possible values: // Possible values:
// - "hostOnly" (default): Issuer strings will be the hostname of the route, with no scheme or trailing slash. // - "hostOnly" (default): Issuer strings will be the hostname of the route, with no scheme or trailing slash.
// - "uri": Issuer strings will be a complete URI, including the scheme and ending with a trailing slash. // - "uri": Issuer strings will be a complete URI, including the scheme and ending with a trailing slash.
JWTIssuerFormat string `mapstructure:"jwt_issuer_format" yaml:"jwt_issuer_format,omitempty"` JWTIssuerFormat JWTIssuerFormat `mapstructure:"jwt_issuer_format" yaml:"jwt_issuer_format,omitempty"`
// BearerTokenFormat indicates how authorization bearer tokens are interepreted. Possible values: // BearerTokenFormat indicates how authorization bearer tokens are interepreted. Possible values:
// - "default": Only Bearer tokens prefixed with Pomerium- will be interpreted by Pomerium // - "default": Only Bearer tokens prefixed with Pomerium- will be interpreted by Pomerium
// - "idp_access_token": The Bearer token will be interpreted as an IdP access token. // - "idp_access_token": The Bearer token will be interpreted as an IdP access token.
@ -309,6 +309,7 @@ func NewPolicyFromProto(pb *configpb.Route) (*Policy, error) {
IDPClientID: pb.GetIdpClientId(), IDPClientID: pb.GetIdpClientId(),
IDPClientSecret: pb.GetIdpClientSecret(), IDPClientSecret: pb.GetIdpClientSecret(),
JWTGroupsFilter: NewJWTGroupsFilter(pb.JwtGroupsFilter), JWTGroupsFilter: NewJWTGroupsFilter(pb.JwtGroupsFilter),
JWTIssuerFormat: JWTIssuerFormatFromPB(pb.JwtIssuerFormat),
KubernetesServiceAccountToken: pb.GetKubernetesServiceAccountToken(), KubernetesServiceAccountToken: pb.GetKubernetesServiceAccountToken(),
KubernetesServiceAccountTokenFile: pb.GetKubernetesServiceAccountTokenFile(), KubernetesServiceAccountTokenFile: pb.GetKubernetesServiceAccountTokenFile(),
LogoURL: pb.GetLogoUrl(), LogoURL: pb.GetLogoUrl(),
@ -390,9 +391,9 @@ func NewPolicyFromProto(pb *configpb.Route) (*Policy, error) {
switch pb.GetJwtIssuerFormat() { switch pb.GetJwtIssuerFormat() {
case configpb.IssuerFormat_IssuerHostOnly: case configpb.IssuerFormat_IssuerHostOnly:
p.JWTIssuerFormat = "hostOnly" p.JWTIssuerFormat = JWTIssuerFormatHostOnly
case configpb.IssuerFormat_IssuerURI: case configpb.IssuerFormat_IssuerURI:
p.JWTIssuerFormat = "uri" p.JWTIssuerFormat = JWTIssuerFormatURI
} }
p.BearerTokenFormat = BearerTokenFormatFromPB(pb.BearerTokenFormat) p.BearerTokenFormat = BearerTokenFormatFromPB(pb.BearerTokenFormat)
@ -468,6 +469,7 @@ func (p *Policy) ToProto() (*configpb.Route, error) {
Id: p.ID, Id: p.ID,
IdleTimeout: idleTimeout, IdleTimeout: idleTimeout,
JwtGroupsFilter: p.JWTGroupsFilter.ToSlice(), JwtGroupsFilter: p.JWTGroupsFilter.ToSlice(),
JwtIssuerFormat: p.JWTIssuerFormat.ToPB(),
KubernetesServiceAccountToken: p.KubernetesServiceAccountToken, KubernetesServiceAccountToken: p.KubernetesServiceAccountToken,
KubernetesServiceAccountTokenFile: p.KubernetesServiceAccountTokenFile, KubernetesServiceAccountTokenFile: p.KubernetesServiceAccountTokenFile,
LogoUrl: p.LogoURL, LogoUrl: p.LogoURL,
@ -555,13 +557,6 @@ func (p *Policy) ToProto() (*configpb.Route, error) {
pb.LoadBalancingWeights = weights pb.LoadBalancingWeights = weights
} }
switch p.JWTIssuerFormat {
case "", "hostOnly":
pb.JwtIssuerFormat = configpb.IssuerFormat_IssuerHostOnly
case "uri":
pb.JwtIssuerFormat = configpb.IssuerFormat_IssuerURI
}
pb.BearerTokenFormat = p.BearerTokenFormat.ToPB() pb.BearerTokenFormat = p.BearerTokenFormat.ToPB()
for _, rwh := range p.RewriteResponseHeaders { for _, rwh := range p.RewriteResponseHeaders {
@ -698,6 +693,10 @@ func (p *Policy) Validate() error {
p.compiledRegex, _ = regexp.Compile(rawRE) p.compiledRegex, _ = regexp.Compile(rawRE)
} }
if !p.JWTIssuerFormat.Valid() {
return fmt.Errorf("config: unsupported jwt_issuer_format value %q", p.JWTIssuerFormat)
}
return nil return nil
} }

File diff suppressed because it is too large Load diff

View file

@ -123,7 +123,7 @@ message Route {
string kubernetes_service_account_token = 26; string kubernetes_service_account_token = 26;
string kubernetes_service_account_token_file = 64; string kubernetes_service_account_token_file = 64;
bool enable_google_cloud_serverless_authentication = 42; bool enable_google_cloud_serverless_authentication = 42;
IssuerFormat jwt_issuer_format = 65; optional IssuerFormat jwt_issuer_format = 65;
repeated string jwt_groups_filter = 66; repeated string jwt_groups_filter = 66;
optional BearerTokenFormat bearer_token_format = 70; optional BearerTokenFormat bearer_token_format = 70;
@ -160,7 +160,7 @@ message Policy {
string remediation = 9; string remediation = 9;
} }
// Next ID: 139. // Next ID: 140.
message Settings { message Settings {
message Certificate { message Certificate {
bytes cert_bytes = 3; bytes cert_bytes = 3;
@ -214,6 +214,7 @@ message Settings {
map<string, string> set_response_headers = 69; map<string, string> set_response_headers = 69;
// repeated string jwt_claims_headers = 37; // repeated string jwt_claims_headers = 37;
map<string, string> jwt_claims_headers = 63; map<string, string> jwt_claims_headers = 63;
optional IssuerFormat jwt_issuer_format = 139;
repeated string jwt_groups_filter = 119; repeated string jwt_groups_filter = 119;
optional BearerTokenFormat bearer_token_format = 138; optional BearerTokenFormat bearer_token_format = 138;
optional google.protobuf.Duration default_upstream_timeout = 39; optional google.protobuf.Duration default_upstream_timeout = 39;