diff --git a/authorize/evaluator/evaluator.go b/authorize/evaluator/evaluator.go index ed7abb614..c563572c9 100644 --- a/authorize/evaluator/evaluator.go +++ b/authorize/evaluator/evaluator.go @@ -295,12 +295,7 @@ func (e *Evaluator) evaluatePolicy(ctx context.Context, req *Request) (*PolicyRe } func (e *Evaluator) evaluateHeaders(ctx context.Context, req *Request) (*HeadersResponse, error) { - headersReq, err := NewHeadersRequestFromPolicy(req.Policy, req.HTTP) - if err != nil { - return nil, err - } - headersReq.Session = req.Session - res, err := e.headersEvaluators.Evaluate(ctx, headersReq) + res, err := e.headersEvaluators.Evaluate(ctx, req) if err != nil { return nil, err } diff --git a/authorize/evaluator/headers_evaluator.go b/authorize/evaluator/headers_evaluator.go index 2ceeb5e6f..e5a706917 100644 --- a/authorize/evaluator/headers_evaluator.go +++ b/authorize/evaluator/headers_evaluator.go @@ -2,65 +2,15 @@ package evaluator import ( "context" - "fmt" "net/http" "time" - envoy_config_cluster_v3 "github.com/envoyproxy/go-control-plane/envoy/config/cluster/v3" "github.com/open-policy-agent/opa/rego" "github.com/pomerium/pomerium/authorize/internal/store" - "github.com/pomerium/pomerium/config" "github.com/pomerium/pomerium/internal/telemetry/trace" ) -// HeadersRequest is the input to the headers.rego script. -type HeadersRequest struct { - EnableGoogleCloudServerlessAuthentication bool `json:"enable_google_cloud_serverless_authentication"` - EnableRoutingKey bool `json:"enable_routing_key"` - Issuer string `json:"issuer"` - Audience string `json:"audience"` - KubernetesServiceAccountToken string `json:"kubernetes_service_account_token"` - ToAudience string `json:"to_audience"` - Session RequestSession `json:"session"` - ClientCertificate ClientCertificateInfo `json:"client_certificate"` - SetRequestHeaders map[string]string `json:"set_request_headers"` -} - -// NewHeadersRequestFromPolicy creates a new HeadersRequest from a policy. -func NewHeadersRequestFromPolicy(policy *config.Policy, http RequestHTTP) (*HeadersRequest, error) { - input := new(HeadersRequest) - input.Audience = http.Hostname - var issuerFormat string - if policy != nil { - issuerFormat = policy.JWTIssuerFormat - } - switch issuerFormat { - case "", "hostOnly": - input.Issuer = http.Hostname - case "uri": - input.Issuer = fmt.Sprintf("https://%s/", http.Hostname) - default: - return nil, fmt.Errorf("invalid issuer format: %q", policy.JWTIssuerFormat) - } - if policy != nil { - input.EnableGoogleCloudServerlessAuthentication = policy.EnableGoogleCloudServerlessAuthentication - input.EnableRoutingKey = policy.EnvoyOpts.GetLbPolicy() == envoy_config_cluster_v3.Cluster_RING_HASH || - policy.EnvoyOpts.GetLbPolicy() == envoy_config_cluster_v3.Cluster_MAGLEV - var err error - input.KubernetesServiceAccountToken, err = policy.GetKubernetesServiceAccountToken() - if err != nil { - return nil, err - } - for _, wu := range policy.To { - input.ToAudience = "https://" + wu.URL.Hostname() - } - input.ClientCertificate = http.ClientCertificate - input.SetRequestHeaders = policy.SetRequestHeaders - } - return input, nil -} - // HeadersResponse is the output from the headers.rego script. type HeadersResponse struct { Headers http.Header @@ -79,7 +29,7 @@ func NewHeadersEvaluator(store *store.Store) *HeadersEvaluator { } // Evaluate evaluates the headers.rego script. -func (e *HeadersEvaluator) Evaluate(ctx context.Context, req *HeadersRequest, options ...rego.EvalOption) (*HeadersResponse, error) { +func (e *HeadersEvaluator) Evaluate(ctx context.Context, req *Request, options ...rego.EvalOption) (*HeadersResponse, error) { ctx, span := trace.StartSpan(ctx, "authorize.HeadersEvaluator.Evaluate") defer span.End() diff --git a/authorize/evaluator/headers_evaluator_evaluation.go b/authorize/evaluator/headers_evaluator_evaluation.go index c5f9c30a1..3604919a3 100644 --- a/authorize/evaluator/headers_evaluator_evaluation.go +++ b/authorize/evaluator/headers_evaluator_evaluation.go @@ -12,6 +12,7 @@ import ( "strings" "time" + envoy_config_cluster_v3 "github.com/envoyproxy/go-control-plane/envoy/config/cluster/v3" "github.com/go-jose/go-jose/v3" "github.com/google/uuid" "google.golang.org/protobuf/types/known/structpb" @@ -26,7 +27,7 @@ import ( // A headersEvaluatorEvaluation is a single evaluation of the headers evaluator. type headersEvaluatorEvaluation struct { evaluator *HeadersEvaluator - request *HeadersRequest + request *Request response *HeadersResponse now time.Time @@ -50,7 +51,7 @@ type headersEvaluatorEvaluation struct { cachedSignedJWT string } -func newHeadersEvaluatorEvaluation(evaluator *HeadersEvaluator, request *HeadersRequest, now time.Time) *headersEvaluatorEvaluation { +func newHeadersEvaluatorEvaluation(evaluator *HeadersEvaluator, request *Request, now time.Time) *headersEvaluatorEvaluation { return &headersEvaluatorEvaluation{ evaluator: evaluator, request: request, @@ -60,16 +61,19 @@ func newHeadersEvaluatorEvaluation(evaluator *HeadersEvaluator, request *Headers } func (e *headersEvaluatorEvaluation) execute(ctx context.Context) (*HeadersResponse, error) { - e.fillHeaders(ctx) - return e.response, nil + err := e.fillHeaders(ctx) + return e.response, err } func (e *headersEvaluatorEvaluation) fillJWTAssertionHeader(ctx context.Context) { e.response.Headers.Add("x-pomerium-jwt-assertion", e.getSignedJWT(ctx)) } -func (e *headersEvaluatorEvaluation) fillJWTClaimHeaders(ctx context.Context) { - claims := e.getJWTPayload(ctx) +func (e *headersEvaluatorEvaluation) fillJWTClaimHeaders(ctx context.Context) error { + claims, err := e.getJWTPayload(ctx) + if err != nil { + return err + } for headerName, claimKey := range e.evaluator.store.GetJWTClaimHeaders() { claim, ok := claims[claimKey] if !ok { @@ -78,14 +82,20 @@ func (e *headersEvaluatorEvaluation) fillJWTClaimHeaders(ctx context.Context) { } e.response.Headers.Add(headerName, getHeaderStringValue(claim)) } + return nil } func (e *headersEvaluatorEvaluation) fillKubernetesHeaders(ctx context.Context) { - if e.request.KubernetesServiceAccountToken == "" { + if e.request.Policy == nil { return } - e.response.Headers.Add("Authorization", "Bearer "+e.request.KubernetesServiceAccountToken) + token, err := e.request.Policy.GetKubernetesServiceAccountToken() + if err != nil || token == "" { + return + } + + e.response.Headers.Add("Authorization", "Bearer "+token) impersonateUser := e.getJWTPayloadEmail(ctx) if impersonateUser != "" { e.response.Headers.Add("Impersonate-User", impersonateUser) @@ -97,26 +107,42 @@ func (e *headersEvaluatorEvaluation) fillKubernetesHeaders(ctx context.Context) } func (e *headersEvaluatorEvaluation) fillGoogleCloudServerlessHeaders(ctx context.Context) { - if e.request.EnableGoogleCloudServerlessAuthentication { - h, err := getGoogleCloudServerlessHeaders(e.evaluator.store.GetGoogleCloudServerlessAuthenticationServiceAccount(), e.request.ToAudience) - if err != nil { - log.Ctx(ctx).Error().Err(err).Msg("authorize/header-evaluator: error retrieving google cloud serverless headers") - return - } - for k, v := range h { - e.response.Headers.Add(k, v) - } + if e.request.Policy == nil || !e.request.Policy.EnableGoogleCloudServerlessAuthentication { + return + } + + var toAudience string + for _, wu := range e.request.Policy.To { + toAudience = "https://" + wu.URL.Hostname() + } + + h, err := getGoogleCloudServerlessHeaders(e.evaluator.store.GetGoogleCloudServerlessAuthenticationServiceAccount(), toAudience) + if err != nil { + log.Ctx(ctx).Error().Err(err).Msg("authorize/header-evaluator: error retrieving google cloud serverless headers") + return + } + for k, v := range h { + e.response.Headers.Add(k, v) } } func (e *headersEvaluatorEvaluation) fillRoutingKeyHeaders() { - if e.request.EnableRoutingKey { + if e.request.Policy == nil { + return + } + + if e.request.Policy.EnvoyOpts.GetLbPolicy() == envoy_config_cluster_v3.Cluster_RING_HASH || + e.request.Policy.EnvoyOpts.GetLbPolicy() == envoy_config_cluster_v3.Cluster_MAGLEV { e.response.Headers.Add("x-pomerium-routing-key", cryptoSHA256(e.request.Session.ID)) } } func (e *headersEvaluatorEvaluation) fillSetRequestHeaders(ctx context.Context) { - for k, v := range e.request.SetRequestHeaders { + if e.request.Policy == nil { + return + } + + for k, v := range e.request.Policy.SetRequestHeaders { e.response.Headers.Add(k, os.Expand(v, func(name string) string { switch name { case "$": @@ -138,13 +164,16 @@ func (e *headersEvaluatorEvaluation) fillSetRequestHeaders(ctx context.Context) } } -func (e *headersEvaluatorEvaluation) fillHeaders(ctx context.Context) { +func (e *headersEvaluatorEvaluation) fillHeaders(ctx context.Context) error { e.fillJWTAssertionHeader(ctx) - e.fillJWTClaimHeaders(ctx) + if err := e.fillJWTClaimHeaders(ctx); err != nil { + return err + } e.fillKubernetesHeaders(ctx) e.fillGoogleCloudServerlessHeaders(ctx) e.fillRoutingKeyHeaders() e.fillSetRequestHeaders(ctx) + return nil } func (e *headersEvaluatorEvaluation) getSessionOrServiceAccount(ctx context.Context) (*session.Session, *user.ServiceAccount) { @@ -182,7 +211,7 @@ func (e *headersEvaluatorEvaluation) getUser(ctx context.Context) *user.User { } func (e *headersEvaluatorEvaluation) getClientCertFingerprint() string { - cert, err := cryptutil.ParsePEMCertificate([]byte(e.request.ClientCertificate.Leaf)) + cert, err := cryptutil.ParsePEMCertificate([]byte(e.request.HTTP.ClientCertificate.Leaf)) if err != nil { return "" } @@ -212,12 +241,23 @@ func (e *headersEvaluatorEvaluation) getGroupIDs(ctx context.Context) []string { return make([]string, 0) } -func (e *headersEvaluatorEvaluation) getJWTPayloadIss() string { - return e.request.Issuer +func (e *headersEvaluatorEvaluation) getJWTPayloadIss() (string, error) { + var issuerFormat string + if e.request.Policy != nil { + issuerFormat = e.request.Policy.JWTIssuerFormat + } + switch issuerFormat { + case "uri": + return fmt.Sprintf("https://%s/", e.request.HTTP.Hostname), nil + case "", "hostOnly": + return e.request.HTTP.Hostname, nil + default: + return "", fmt.Errorf("unsupported JWT issuer format: %s", issuerFormat) + } } func (e *headersEvaluatorEvaluation) getJWTPayloadAud() string { - return e.request.Audience + return e.request.HTTP.Hostname } func (e *headersEvaluatorEvaluation) getJWTPayloadJTI() string { @@ -307,14 +347,19 @@ func (e *headersEvaluatorEvaluation) getJWTPayloadName(ctx context.Context) stri return "" } -func (e *headersEvaluatorEvaluation) getJWTPayload(ctx context.Context) map[string]any { +func (e *headersEvaluatorEvaluation) getJWTPayload(ctx context.Context) (map[string]any, error) { if e.gotJWTPayload { - return e.cachedJWTPayload + return e.cachedJWTPayload, nil + } + + iss, err := e.getJWTPayloadIss() + if err != nil { + return nil, err } e.gotJWTPayload = true e.cachedJWTPayload = map[string]any{ - "iss": e.getJWTPayloadIss(), + "iss": iss, "aud": e.getJWTPayloadAud(), "jti": e.getJWTPayloadJTI(), "iat": e.getJWTPayloadIAT(), @@ -342,7 +387,7 @@ func (e *headersEvaluatorEvaluation) getJWTPayload(ctx context.Context) map[stri e.cachedJWTPayload[claimKey] = strings.Join(vs, ",") } } - return e.cachedJWTPayload + return e.cachedJWTPayload, nil } func (e *headersEvaluatorEvaluation) getSignedJWT(ctx context.Context) string { @@ -371,7 +416,11 @@ func (e *headersEvaluatorEvaluation) getSignedJWT(ctx context.Context) string { return "" } - jwtPayload := e.getJWTPayload(ctx) + jwtPayload, err := e.getJWTPayload(ctx) + if err != nil { + log.Ctx(ctx).Error().Err(err).Msg("authorize/header-evaluator: error creating JWT payload") + return "" + } bs, err := json.Marshal(jwtPayload) if err != nil { log.Ctx(ctx).Error().Err(err).Msg("authorize/header-evaluator: error marshaling JWT payload") diff --git a/authorize/evaluator/headers_evaluator_test.go b/authorize/evaluator/headers_evaluator_test.go index 3e365695a..a5bf3c1d3 100644 --- a/authorize/evaluator/headers_evaluator_test.go +++ b/authorize/evaluator/headers_evaluator_test.go @@ -7,10 +7,12 @@ import ( "encoding/json" "fmt" "math" + "net/http" "strings" "testing" "time" + envoy_config_cluster_v3 "github.com/envoyproxy/go-control-plane/envoy/config/cluster/v3" "github.com/go-jose/go-jose/v3/jwt" "github.com/open-policy-agent/opa/rego" "github.com/stretchr/testify/assert" @@ -62,22 +64,23 @@ func BenchmarkHeadersEvaluator(b *testing.B) { e := NewHeadersEvaluator(s) - req := &HeadersRequest{ - EnableRoutingKey: true, - Issuer: "from.example.com", - Audience: "from.example.com", - KubernetesServiceAccountToken: "KUBERNETES_SERVICE_ACCOUNT_TOKEN", - ToAudience: "to.example.com", + req := &Request{ + HTTP: RequestHTTP{ + Method: "GET", + Hostname: "from.example.com", + }, + Policy: &config.Policy{ + SetRequestHeaders: map[string]string{ + "X-Custom-Header": "CUSTOM_VALUE", + "X-ID-Token": "${pomerium.id_token}", + "X-Access-Token": "${pomerium.access_token}", + "Client-Cert-Fingerprint": "${pomerium.client_cert_fingerprint}", + "Authorization": "Bearer ${pomerium.jwt}", + }, + }, Session: RequestSession{ ID: "s1", }, - SetRequestHeaders: map[string]string{ - "X-Custom-Header": "CUSTOM_VALUE", - "X-ID-Token": "${pomerium.id_token}", - "X-Access-Token": "${pomerium.access_token}", - "Client-Cert-Fingerprint": "${pomerium.client_cert_fingerprint}", - "Authorization": "Bearer ${pomerium.jwt}", - }, } b.ResetTimer() for i := 0; i < b.N; i++ { @@ -87,99 +90,6 @@ func BenchmarkHeadersEvaluator(b *testing.B) { } } -func TestNewHeadersRequestFromPolicy(t *testing.T) { - req, _ := NewHeadersRequestFromPolicy(&config.Policy{ - EnableGoogleCloudServerlessAuthentication: true, - From: "https://*.example.com", - To: config.WeightedURLs{ - { - URL: *mustParseURL("http://to.example.com"), - }, - }, - }, RequestHTTP{ - Hostname: "from.example.com", - ClientCertificate: ClientCertificateInfo{ - Leaf: "--- FAKE CERTIFICATE ---", - }, - }) - assert.Equal(t, &HeadersRequest{ - EnableGoogleCloudServerlessAuthentication: true, - Issuer: "from.example.com", - Audience: "from.example.com", - ToAudience: "https://to.example.com", - ClientCertificate: ClientCertificateInfo{ - Leaf: "--- FAKE CERTIFICATE ---", - }, - }, req) -} - -func TestNewHeadersRequestFromPolicy_IssuerFormat(t *testing.T) { - policy := &config.Policy{ - EnableGoogleCloudServerlessAuthentication: true, - From: "https://*.example.com", - To: config.WeightedURLs{ - { - URL: *mustParseURL("http://to.example.com"), - }, - }, - } - for _, tc := range []struct { - format string - expectedIssuer string - expectedAudience string - err string - }{ - { - format: "", - expectedIssuer: "from.example.com", - expectedAudience: "from.example.com", - }, - { - format: "hostOnly", - expectedIssuer: "from.example.com", - expectedAudience: "from.example.com", - }, - { - format: "uri", - expectedIssuer: "https://from.example.com/", - expectedAudience: "from.example.com", - }, - { - format: "foo", - err: `invalid issuer format: "foo"`, - }, - } { - policy.JWTIssuerFormat = tc.format - req, err := NewHeadersRequestFromPolicy(policy, RequestHTTP{ - Hostname: "from.example.com", - ClientCertificate: ClientCertificateInfo{ - Leaf: "--- FAKE CERTIFICATE ---", - }, - }) - if tc.err != "" { - assert.ErrorContains(t, err, tc.err) - } else { - assert.Equal(t, &HeadersRequest{ - EnableGoogleCloudServerlessAuthentication: true, - Issuer: tc.expectedIssuer, - Audience: tc.expectedAudience, - ToAudience: "https://to.example.com", - ClientCertificate: ClientCertificateInfo{ - Leaf: "--- FAKE CERTIFICATE ---", - }, - }, req) - } - } -} - -func TestNewHeadersRequestFromPolicy_nil(t *testing.T) { - req, _ := NewHeadersRequestFromPolicy(nil, RequestHTTP{Hostname: "from.example.com"}) - assert.Equal(t, &HeadersRequest{ - Issuer: "from.example.com", - Audience: "from.example.com", - }, req) -} - func TestHeadersEvaluator(t *testing.T) { t.Parallel() @@ -197,7 +107,7 @@ func TestHeadersEvaluator(t *testing.T) { iat := time.Unix(1686870680, 0) - eval := func(_ *testing.T, data []proto.Message, input *HeadersRequest) (*HeadersResponse, error) { + eval := func(_ *testing.T, data []proto.Message, input *Request) (*HeadersResponse, error) { ctx := context.Background() ctx = storage.WithQuerier(ctx, storage.NewStaticQuerier(data...)) store := store.New() @@ -232,10 +142,11 @@ func TestHeadersEvaluator(t *testing.T) { newDirectoryGroupRecord(directory.Group{ID: "g3", Name: "GROUP3", Email: "g3@example.com"}), newDirectoryGroupRecord(directory.Group{ID: "g4", Name: "GROUP4", Email: "g4@example.com"}), }, - &HeadersRequest{ - Issuer: "from.example.com", - Audience: "from.example.com", - ToAudience: "to.example.com", + &Request{ + HTTP: RequestHTTP{ + Hostname: "from.example.com", + }, + Policy: &config.Policy{}, Session: RequestSession{ ID: "s1", }, @@ -292,7 +203,7 @@ func TestHeadersEvaluator(t *testing.T) { }}, }}, }, - &HeadersRequest{ + &Request{ Session: RequestSession{ID: "s1"}, }) require.NoError(t, err) @@ -312,20 +223,22 @@ func TestHeadersEvaluator(t *testing.T) { AccessToken: "ACCESS_TOKEN", }}, }, - &HeadersRequest{ - Issuer: "from.example.com", - Audience: "from.example.com", - ToAudience: "to.example.com", - Session: RequestSession{ID: "s1"}, - SetRequestHeaders: map[string]string{ - "X-Custom-Header": "CUSTOM_VALUE", - "X-ID-Token": "${pomerium.id_token}", - "X-Access-Token": "${pomerium.access_token}", - "Client-Cert-Fingerprint": "${pomerium.client_cert_fingerprint}", - "Authorization": "Bearer ${pomerium.jwt}", - "Foo": "escaped $$dollar sign", + &Request{ + HTTP: RequestHTTP{ + Hostname: "from.example.com", + ClientCertificate: ClientCertificateInfo{Leaf: testValidCert}, }, - ClientCertificate: ClientCertificateInfo{Leaf: testValidCert}, + Policy: &config.Policy{ + SetRequestHeaders: map[string]string{ + "X-Custom-Header": "CUSTOM_VALUE", + "X-ID-Token": "${pomerium.id_token}", + "X-Access-Token": "${pomerium.access_token}", + "Client-Cert-Fingerprint": "${pomerium.client_cert_fingerprint}", + "Authorization": "Bearer ${pomerium.jwt}", + "Foo": "escaped $$dollar sign", + }, + }, + Session: RequestSession{ID: "s1"}, }) require.NoError(t, err) @@ -355,13 +268,12 @@ func TestHeadersEvaluator(t *testing.T) { AccessToken: "ACCESS_TOKEN", }}, }, - &HeadersRequest{ - Issuer: "from.example.com", - Audience: "from.example.com", - ToAudience: "to.example.com", - Session: RequestSession{ID: "s1"}, - SetRequestHeaders: map[string]string{ - "X-ID-Token": "${pomerium.id_token}", + &Request{ + Session: RequestSession{ID: "s1"}, + Policy: &config.Policy{ + SetRequestHeaders: map[string]string{ + "X-ID-Token": "${pomerium.id_token}", + }, }, }) require.NoError(t, err) @@ -378,14 +290,13 @@ func TestHeadersEvaluator(t *testing.T) { AccessToken: "ACCESS_TOKEN", }}, }, - &HeadersRequest{ - Issuer: "from.example.com", - Audience: "from.example.com", - ToAudience: "to.example.com", - Session: RequestSession{ID: "s1"}, - SetRequestHeaders: map[string]string{ - "Authorization": "Bearer ${pomerium.id_token}", + &Request{ + Policy: &config.Policy{ + SetRequestHeaders: map[string]string{ + "Authorization": "Bearer ${pomerium.id_token}", + }, }, + Session: RequestSession{ID: "s1"}, }) require.NoError(t, err) @@ -394,12 +305,11 @@ func TestHeadersEvaluator(t *testing.T) { t.Run("set_request_headers no client cert", func(t *testing.T) { output, err := eval(t, nil, - &HeadersRequest{ - Issuer: "from.example.com", - Audience: "from.example.com", - ToAudience: "to.example.com", - SetRequestHeaders: map[string]string{ - "fingerprint": "${pomerium.client_cert_fingerprint}", + &Request{ + Policy: &config.Policy{ + SetRequestHeaders: map[string]string{ + "fingerprint": "${pomerium.client_cert_fingerprint}", + }, }, }) require.NoError(t, err) @@ -427,12 +337,11 @@ func TestHeadersEvaluator(t *testing.T) { Name: "GROUP2", }), }, - &HeadersRequest{ - Issuer: "from.example.com", - Audience: "from.example.com", - ToAudience: "to.example.com", - KubernetesServiceAccountToken: "TOKEN", - Session: RequestSession{ID: "s1"}, + &Request{ + Policy: &config.Policy{ + KubernetesServiceAccountToken: "TOKEN", + }, + Session: RequestSession{ID: "s1"}, }) require.NoError(t, err) assert.Equal(t, "Bearer TOKEN", output.Headers.Get("Authorization")) @@ -445,18 +354,21 @@ func TestHeadersEvaluator(t *testing.T) { output, err := eval(t, []protoreflect.ProtoMessage{}, - &HeadersRequest{ - EnableRoutingKey: false, - Session: RequestSession{ID: "s1"}, + &Request{ + Session: RequestSession{ID: "s1"}, }) require.NoError(t, err) assert.Empty(t, output.Headers.Get("X-Pomerium-Routing-Key")) output, err = eval(t, []protoreflect.ProtoMessage{}, - &HeadersRequest{ - EnableRoutingKey: true, - Session: RequestSession{ID: "s1"}, + &Request{ + Policy: &config.Policy{ + EnvoyOpts: &envoy_config_cluster_v3.Cluster{ + LbPolicy: envoy_config_cluster_v3.Cluster_MAGLEV, + }, + }, + Session: RequestSession{ID: "s1"}, }) require.NoError(t, err) assert.Equal(t, "e8bc163c82eee18733288c7d4ac636db3a6deb013ef2d37b68322be20edc45cc", output.Headers.Get("X-Pomerium-Routing-Key")) @@ -470,7 +382,7 @@ func TestHeadersEvaluator(t *testing.T) { &session.Session{Id: "s1", UserId: "u1"}, &user.User{Id: "u1", Email: "user@example.com"}, }, - &HeadersRequest{ + &Request{ Session: RequestSession{ID: "s1"}, }) require.NoError(t, err) @@ -481,7 +393,7 @@ func TestHeadersEvaluator(t *testing.T) { &session.Session{Id: "s1", UserId: "u1"}, newDirectoryUserRecord(directory.User{ID: "u1", Email: "directory-user@example.com"}), }, - &HeadersRequest{ + &Request{ Session: RequestSession{ID: "s1"}, }) require.NoError(t, err) @@ -498,7 +410,7 @@ func TestHeadersEvaluator(t *testing.T) { }}, }}, }, - &HeadersRequest{ + &Request{ Session: RequestSession{ID: "s1"}, }) require.NoError(t, err) @@ -513,7 +425,7 @@ func TestHeadersEvaluator(t *testing.T) { }}, }}, }, - &HeadersRequest{ + &Request{ Session: RequestSession{ID: "s1"}, }) require.NoError(t, err) @@ -528,12 +440,54 @@ func TestHeadersEvaluator(t *testing.T) { &user.ServiceAccount{Id: "sa1", UserId: "u1"}, &user.User{Id: "u1", Email: "u1@example.com"}, }, - &HeadersRequest{ + &Request{ Session: RequestSession{ID: "sa1"}, }) require.NoError(t, err) assert.Equal(t, "u1@example.com", output.Headers.Get("X-Pomerium-Claim-Email")) }) + + t.Run("issuer format", func(t *testing.T) { + t.Parallel() + + for _, tc := range []struct { + format string + input string + output string + }{ + {"", "example.com", "example.com"}, + {"hostOnly", "host-only.example.com", "host-only.example.com"}, + {"uri", "uri.example.com", "https://uri.example.com/"}, + } { + output, err := eval(t, + nil, + &Request{ + HTTP: RequestHTTP{ + Hostname: tc.input, + }, + Policy: &config.Policy{ + JWTIssuerFormat: tc.format, + }, + }) + require.NoError(t, err) + m := decodeJWTAssertion(t, output.Headers) + assert.Equal(t, tc.output, m["iss"], "unexpected issuer for format=%s", tc.format) + } + }) +} + +func decodeJWTAssertion(t *testing.T, headers http.Header) map[string]any { + jwtHeader := headers.Get("X-Pomerium-Jwt-Assertion") + // Make sure the 'iat' and 'exp' claims can be parsed as an integer. We + // need to do some explicit decoding in order to be able to verify + // this, as by default json.Unmarshal() will make no distinction + // between numeric formats. + d := json.NewDecoder(bytes.NewReader(decodeJWSPayload(t, jwtHeader))) + d.UseNumber() + var m map[string]any + err := d.Decode(&m) + require.NoError(t, err) + return m } func decodeJWSPayload(t *testing.T, jws string) []byte {