cleanup headers (#5408)

* cleanup headers

* return issuer format errors

* go mod
This commit is contained in:
Caleb Doxsey 2025-01-06 09:52:29 -07:00 committed by GitHub
parent 8f36870650
commit fb7b61a677
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 197 additions and 249 deletions

View file

@ -295,12 +295,7 @@ func (e *Evaluator) evaluatePolicy(ctx context.Context, req *Request) (*PolicyRe
} }
func (e *Evaluator) evaluateHeaders(ctx context.Context, req *Request) (*HeadersResponse, error) { func (e *Evaluator) evaluateHeaders(ctx context.Context, req *Request) (*HeadersResponse, error) {
headersReq, err := NewHeadersRequestFromPolicy(req.Policy, req.HTTP) res, err := e.headersEvaluators.Evaluate(ctx, req)
if err != nil {
return nil, err
}
headersReq.Session = req.Session
res, err := e.headersEvaluators.Evaluate(ctx, headersReq)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -2,65 +2,15 @@ package evaluator
import ( import (
"context" "context"
"fmt"
"net/http" "net/http"
"time" "time"
envoy_config_cluster_v3 "github.com/envoyproxy/go-control-plane/envoy/config/cluster/v3"
"github.com/open-policy-agent/opa/rego" "github.com/open-policy-agent/opa/rego"
"github.com/pomerium/pomerium/authorize/internal/store" "github.com/pomerium/pomerium/authorize/internal/store"
"github.com/pomerium/pomerium/config"
"github.com/pomerium/pomerium/internal/telemetry/trace" "github.com/pomerium/pomerium/internal/telemetry/trace"
) )
// HeadersRequest is the input to the headers.rego script.
type HeadersRequest struct {
EnableGoogleCloudServerlessAuthentication bool `json:"enable_google_cloud_serverless_authentication"`
EnableRoutingKey bool `json:"enable_routing_key"`
Issuer string `json:"issuer"`
Audience string `json:"audience"`
KubernetesServiceAccountToken string `json:"kubernetes_service_account_token"`
ToAudience string `json:"to_audience"`
Session RequestSession `json:"session"`
ClientCertificate ClientCertificateInfo `json:"client_certificate"`
SetRequestHeaders map[string]string `json:"set_request_headers"`
}
// NewHeadersRequestFromPolicy creates a new HeadersRequest from a policy.
func NewHeadersRequestFromPolicy(policy *config.Policy, http RequestHTTP) (*HeadersRequest, error) {
input := new(HeadersRequest)
input.Audience = http.Hostname
var issuerFormat string
if policy != nil {
issuerFormat = policy.JWTIssuerFormat
}
switch issuerFormat {
case "", "hostOnly":
input.Issuer = http.Hostname
case "uri":
input.Issuer = fmt.Sprintf("https://%s/", http.Hostname)
default:
return nil, fmt.Errorf("invalid issuer format: %q", policy.JWTIssuerFormat)
}
if policy != nil {
input.EnableGoogleCloudServerlessAuthentication = policy.EnableGoogleCloudServerlessAuthentication
input.EnableRoutingKey = policy.EnvoyOpts.GetLbPolicy() == envoy_config_cluster_v3.Cluster_RING_HASH ||
policy.EnvoyOpts.GetLbPolicy() == envoy_config_cluster_v3.Cluster_MAGLEV
var err error
input.KubernetesServiceAccountToken, err = policy.GetKubernetesServiceAccountToken()
if err != nil {
return nil, err
}
for _, wu := range policy.To {
input.ToAudience = "https://" + wu.URL.Hostname()
}
input.ClientCertificate = http.ClientCertificate
input.SetRequestHeaders = policy.SetRequestHeaders
}
return input, nil
}
// HeadersResponse is the output from the headers.rego script. // HeadersResponse is the output from the headers.rego script.
type HeadersResponse struct { type HeadersResponse struct {
Headers http.Header Headers http.Header
@ -79,7 +29,7 @@ func NewHeadersEvaluator(store *store.Store) *HeadersEvaluator {
} }
// Evaluate evaluates the headers.rego script. // Evaluate evaluates the headers.rego script.
func (e *HeadersEvaluator) Evaluate(ctx context.Context, req *HeadersRequest, options ...rego.EvalOption) (*HeadersResponse, error) { func (e *HeadersEvaluator) Evaluate(ctx context.Context, req *Request, options ...rego.EvalOption) (*HeadersResponse, error) {
ctx, span := trace.StartSpan(ctx, "authorize.HeadersEvaluator.Evaluate") ctx, span := trace.StartSpan(ctx, "authorize.HeadersEvaluator.Evaluate")
defer span.End() defer span.End()

View file

@ -12,6 +12,7 @@ import (
"strings" "strings"
"time" "time"
envoy_config_cluster_v3 "github.com/envoyproxy/go-control-plane/envoy/config/cluster/v3"
"github.com/go-jose/go-jose/v3" "github.com/go-jose/go-jose/v3"
"github.com/google/uuid" "github.com/google/uuid"
"google.golang.org/protobuf/types/known/structpb" "google.golang.org/protobuf/types/known/structpb"
@ -26,7 +27,7 @@ import (
// A headersEvaluatorEvaluation is a single evaluation of the headers evaluator. // A headersEvaluatorEvaluation is a single evaluation of the headers evaluator.
type headersEvaluatorEvaluation struct { type headersEvaluatorEvaluation struct {
evaluator *HeadersEvaluator evaluator *HeadersEvaluator
request *HeadersRequest request *Request
response *HeadersResponse response *HeadersResponse
now time.Time now time.Time
@ -50,7 +51,7 @@ type headersEvaluatorEvaluation struct {
cachedSignedJWT string cachedSignedJWT string
} }
func newHeadersEvaluatorEvaluation(evaluator *HeadersEvaluator, request *HeadersRequest, now time.Time) *headersEvaluatorEvaluation { func newHeadersEvaluatorEvaluation(evaluator *HeadersEvaluator, request *Request, now time.Time) *headersEvaluatorEvaluation {
return &headersEvaluatorEvaluation{ return &headersEvaluatorEvaluation{
evaluator: evaluator, evaluator: evaluator,
request: request, request: request,
@ -60,16 +61,19 @@ func newHeadersEvaluatorEvaluation(evaluator *HeadersEvaluator, request *Headers
} }
func (e *headersEvaluatorEvaluation) execute(ctx context.Context) (*HeadersResponse, error) { func (e *headersEvaluatorEvaluation) execute(ctx context.Context) (*HeadersResponse, error) {
e.fillHeaders(ctx) err := e.fillHeaders(ctx)
return e.response, nil return e.response, err
} }
func (e *headersEvaluatorEvaluation) fillJWTAssertionHeader(ctx context.Context) { func (e *headersEvaluatorEvaluation) fillJWTAssertionHeader(ctx context.Context) {
e.response.Headers.Add("x-pomerium-jwt-assertion", e.getSignedJWT(ctx)) e.response.Headers.Add("x-pomerium-jwt-assertion", e.getSignedJWT(ctx))
} }
func (e *headersEvaluatorEvaluation) fillJWTClaimHeaders(ctx context.Context) { func (e *headersEvaluatorEvaluation) fillJWTClaimHeaders(ctx context.Context) error {
claims := e.getJWTPayload(ctx) claims, err := e.getJWTPayload(ctx)
if err != nil {
return err
}
for headerName, claimKey := range e.evaluator.store.GetJWTClaimHeaders() { for headerName, claimKey := range e.evaluator.store.GetJWTClaimHeaders() {
claim, ok := claims[claimKey] claim, ok := claims[claimKey]
if !ok { if !ok {
@ -78,14 +82,20 @@ func (e *headersEvaluatorEvaluation) fillJWTClaimHeaders(ctx context.Context) {
} }
e.response.Headers.Add(headerName, getHeaderStringValue(claim)) e.response.Headers.Add(headerName, getHeaderStringValue(claim))
} }
return nil
} }
func (e *headersEvaluatorEvaluation) fillKubernetesHeaders(ctx context.Context) { func (e *headersEvaluatorEvaluation) fillKubernetesHeaders(ctx context.Context) {
if e.request.KubernetesServiceAccountToken == "" { if e.request.Policy == nil {
return return
} }
e.response.Headers.Add("Authorization", "Bearer "+e.request.KubernetesServiceAccountToken) token, err := e.request.Policy.GetKubernetesServiceAccountToken()
if err != nil || token == "" {
return
}
e.response.Headers.Add("Authorization", "Bearer "+token)
impersonateUser := e.getJWTPayloadEmail(ctx) impersonateUser := e.getJWTPayloadEmail(ctx)
if impersonateUser != "" { if impersonateUser != "" {
e.response.Headers.Add("Impersonate-User", impersonateUser) e.response.Headers.Add("Impersonate-User", impersonateUser)
@ -97,8 +107,16 @@ func (e *headersEvaluatorEvaluation) fillKubernetesHeaders(ctx context.Context)
} }
func (e *headersEvaluatorEvaluation) fillGoogleCloudServerlessHeaders(ctx context.Context) { func (e *headersEvaluatorEvaluation) fillGoogleCloudServerlessHeaders(ctx context.Context) {
if e.request.EnableGoogleCloudServerlessAuthentication { if e.request.Policy == nil || !e.request.Policy.EnableGoogleCloudServerlessAuthentication {
h, err := getGoogleCloudServerlessHeaders(e.evaluator.store.GetGoogleCloudServerlessAuthenticationServiceAccount(), e.request.ToAudience) return
}
var toAudience string
for _, wu := range e.request.Policy.To {
toAudience = "https://" + wu.URL.Hostname()
}
h, err := getGoogleCloudServerlessHeaders(e.evaluator.store.GetGoogleCloudServerlessAuthenticationServiceAccount(), toAudience)
if err != nil { if err != nil {
log.Ctx(ctx).Error().Err(err).Msg("authorize/header-evaluator: error retrieving google cloud serverless headers") log.Ctx(ctx).Error().Err(err).Msg("authorize/header-evaluator: error retrieving google cloud serverless headers")
return return
@ -107,16 +125,24 @@ func (e *headersEvaluatorEvaluation) fillGoogleCloudServerlessHeaders(ctx contex
e.response.Headers.Add(k, v) e.response.Headers.Add(k, v)
} }
} }
}
func (e *headersEvaluatorEvaluation) fillRoutingKeyHeaders() { func (e *headersEvaluatorEvaluation) fillRoutingKeyHeaders() {
if e.request.EnableRoutingKey { if e.request.Policy == nil {
return
}
if e.request.Policy.EnvoyOpts.GetLbPolicy() == envoy_config_cluster_v3.Cluster_RING_HASH ||
e.request.Policy.EnvoyOpts.GetLbPolicy() == envoy_config_cluster_v3.Cluster_MAGLEV {
e.response.Headers.Add("x-pomerium-routing-key", cryptoSHA256(e.request.Session.ID)) e.response.Headers.Add("x-pomerium-routing-key", cryptoSHA256(e.request.Session.ID))
} }
} }
func (e *headersEvaluatorEvaluation) fillSetRequestHeaders(ctx context.Context) { func (e *headersEvaluatorEvaluation) fillSetRequestHeaders(ctx context.Context) {
for k, v := range e.request.SetRequestHeaders { if e.request.Policy == nil {
return
}
for k, v := range e.request.Policy.SetRequestHeaders {
e.response.Headers.Add(k, os.Expand(v, func(name string) string { e.response.Headers.Add(k, os.Expand(v, func(name string) string {
switch name { switch name {
case "$": case "$":
@ -138,13 +164,16 @@ func (e *headersEvaluatorEvaluation) fillSetRequestHeaders(ctx context.Context)
} }
} }
func (e *headersEvaluatorEvaluation) fillHeaders(ctx context.Context) { func (e *headersEvaluatorEvaluation) fillHeaders(ctx context.Context) error {
e.fillJWTAssertionHeader(ctx) e.fillJWTAssertionHeader(ctx)
e.fillJWTClaimHeaders(ctx) if err := e.fillJWTClaimHeaders(ctx); err != nil {
return err
}
e.fillKubernetesHeaders(ctx) e.fillKubernetesHeaders(ctx)
e.fillGoogleCloudServerlessHeaders(ctx) e.fillGoogleCloudServerlessHeaders(ctx)
e.fillRoutingKeyHeaders() e.fillRoutingKeyHeaders()
e.fillSetRequestHeaders(ctx) e.fillSetRequestHeaders(ctx)
return nil
} }
func (e *headersEvaluatorEvaluation) getSessionOrServiceAccount(ctx context.Context) (*session.Session, *user.ServiceAccount) { func (e *headersEvaluatorEvaluation) getSessionOrServiceAccount(ctx context.Context) (*session.Session, *user.ServiceAccount) {
@ -182,7 +211,7 @@ func (e *headersEvaluatorEvaluation) getUser(ctx context.Context) *user.User {
} }
func (e *headersEvaluatorEvaluation) getClientCertFingerprint() string { func (e *headersEvaluatorEvaluation) getClientCertFingerprint() string {
cert, err := cryptutil.ParsePEMCertificate([]byte(e.request.ClientCertificate.Leaf)) cert, err := cryptutil.ParsePEMCertificate([]byte(e.request.HTTP.ClientCertificate.Leaf))
if err != nil { if err != nil {
return "" return ""
} }
@ -212,12 +241,23 @@ func (e *headersEvaluatorEvaluation) getGroupIDs(ctx context.Context) []string {
return make([]string, 0) return make([]string, 0)
} }
func (e *headersEvaluatorEvaluation) getJWTPayloadIss() string { func (e *headersEvaluatorEvaluation) getJWTPayloadIss() (string, error) {
return e.request.Issuer var issuerFormat string
if e.request.Policy != nil {
issuerFormat = e.request.Policy.JWTIssuerFormat
}
switch issuerFormat {
case "uri":
return fmt.Sprintf("https://%s/", e.request.HTTP.Hostname), nil
case "", "hostOnly":
return e.request.HTTP.Hostname, nil
default:
return "", fmt.Errorf("unsupported JWT issuer format: %s", issuerFormat)
}
} }
func (e *headersEvaluatorEvaluation) getJWTPayloadAud() string { func (e *headersEvaluatorEvaluation) getJWTPayloadAud() string {
return e.request.Audience return e.request.HTTP.Hostname
} }
func (e *headersEvaluatorEvaluation) getJWTPayloadJTI() string { func (e *headersEvaluatorEvaluation) getJWTPayloadJTI() string {
@ -307,14 +347,19 @@ func (e *headersEvaluatorEvaluation) getJWTPayloadName(ctx context.Context) stri
return "" return ""
} }
func (e *headersEvaluatorEvaluation) getJWTPayload(ctx context.Context) map[string]any { func (e *headersEvaluatorEvaluation) getJWTPayload(ctx context.Context) (map[string]any, error) {
if e.gotJWTPayload { if e.gotJWTPayload {
return e.cachedJWTPayload return e.cachedJWTPayload, nil
}
iss, err := e.getJWTPayloadIss()
if err != nil {
return nil, err
} }
e.gotJWTPayload = true e.gotJWTPayload = true
e.cachedJWTPayload = map[string]any{ e.cachedJWTPayload = map[string]any{
"iss": e.getJWTPayloadIss(), "iss": iss,
"aud": e.getJWTPayloadAud(), "aud": e.getJWTPayloadAud(),
"jti": e.getJWTPayloadJTI(), "jti": e.getJWTPayloadJTI(),
"iat": e.getJWTPayloadIAT(), "iat": e.getJWTPayloadIAT(),
@ -342,7 +387,7 @@ func (e *headersEvaluatorEvaluation) getJWTPayload(ctx context.Context) map[stri
e.cachedJWTPayload[claimKey] = strings.Join(vs, ",") e.cachedJWTPayload[claimKey] = strings.Join(vs, ",")
} }
} }
return e.cachedJWTPayload return e.cachedJWTPayload, nil
} }
func (e *headersEvaluatorEvaluation) getSignedJWT(ctx context.Context) string { func (e *headersEvaluatorEvaluation) getSignedJWT(ctx context.Context) string {
@ -371,7 +416,11 @@ func (e *headersEvaluatorEvaluation) getSignedJWT(ctx context.Context) string {
return "" return ""
} }
jwtPayload := e.getJWTPayload(ctx) jwtPayload, err := e.getJWTPayload(ctx)
if err != nil {
log.Ctx(ctx).Error().Err(err).Msg("authorize/header-evaluator: error creating JWT payload")
return ""
}
bs, err := json.Marshal(jwtPayload) bs, err := json.Marshal(jwtPayload)
if err != nil { if err != nil {
log.Ctx(ctx).Error().Err(err).Msg("authorize/header-evaluator: error marshaling JWT payload") log.Ctx(ctx).Error().Err(err).Msg("authorize/header-evaluator: error marshaling JWT payload")

View file

@ -7,10 +7,12 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"math" "math"
"net/http"
"strings" "strings"
"testing" "testing"
"time" "time"
envoy_config_cluster_v3 "github.com/envoyproxy/go-control-plane/envoy/config/cluster/v3"
"github.com/go-jose/go-jose/v3/jwt" "github.com/go-jose/go-jose/v3/jwt"
"github.com/open-policy-agent/opa/rego" "github.com/open-policy-agent/opa/rego"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
@ -62,15 +64,12 @@ func BenchmarkHeadersEvaluator(b *testing.B) {
e := NewHeadersEvaluator(s) e := NewHeadersEvaluator(s)
req := &HeadersRequest{ req := &Request{
EnableRoutingKey: true, HTTP: RequestHTTP{
Issuer: "from.example.com", Method: "GET",
Audience: "from.example.com", Hostname: "from.example.com",
KubernetesServiceAccountToken: "KUBERNETES_SERVICE_ACCOUNT_TOKEN",
ToAudience: "to.example.com",
Session: RequestSession{
ID: "s1",
}, },
Policy: &config.Policy{
SetRequestHeaders: map[string]string{ SetRequestHeaders: map[string]string{
"X-Custom-Header": "CUSTOM_VALUE", "X-Custom-Header": "CUSTOM_VALUE",
"X-ID-Token": "${pomerium.id_token}", "X-ID-Token": "${pomerium.id_token}",
@ -78,6 +77,10 @@ func BenchmarkHeadersEvaluator(b *testing.B) {
"Client-Cert-Fingerprint": "${pomerium.client_cert_fingerprint}", "Client-Cert-Fingerprint": "${pomerium.client_cert_fingerprint}",
"Authorization": "Bearer ${pomerium.jwt}", "Authorization": "Bearer ${pomerium.jwt}",
}, },
},
Session: RequestSession{
ID: "s1",
},
} }
b.ResetTimer() b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
@ -87,99 +90,6 @@ func BenchmarkHeadersEvaluator(b *testing.B) {
} }
} }
func TestNewHeadersRequestFromPolicy(t *testing.T) {
req, _ := NewHeadersRequestFromPolicy(&config.Policy{
EnableGoogleCloudServerlessAuthentication: true,
From: "https://*.example.com",
To: config.WeightedURLs{
{
URL: *mustParseURL("http://to.example.com"),
},
},
}, RequestHTTP{
Hostname: "from.example.com",
ClientCertificate: ClientCertificateInfo{
Leaf: "--- FAKE CERTIFICATE ---",
},
})
assert.Equal(t, &HeadersRequest{
EnableGoogleCloudServerlessAuthentication: true,
Issuer: "from.example.com",
Audience: "from.example.com",
ToAudience: "https://to.example.com",
ClientCertificate: ClientCertificateInfo{
Leaf: "--- FAKE CERTIFICATE ---",
},
}, req)
}
func TestNewHeadersRequestFromPolicy_IssuerFormat(t *testing.T) {
policy := &config.Policy{
EnableGoogleCloudServerlessAuthentication: true,
From: "https://*.example.com",
To: config.WeightedURLs{
{
URL: *mustParseURL("http://to.example.com"),
},
},
}
for _, tc := range []struct {
format string
expectedIssuer string
expectedAudience string
err string
}{
{
format: "",
expectedIssuer: "from.example.com",
expectedAudience: "from.example.com",
},
{
format: "hostOnly",
expectedIssuer: "from.example.com",
expectedAudience: "from.example.com",
},
{
format: "uri",
expectedIssuer: "https://from.example.com/",
expectedAudience: "from.example.com",
},
{
format: "foo",
err: `invalid issuer format: "foo"`,
},
} {
policy.JWTIssuerFormat = tc.format
req, err := NewHeadersRequestFromPolicy(policy, RequestHTTP{
Hostname: "from.example.com",
ClientCertificate: ClientCertificateInfo{
Leaf: "--- FAKE CERTIFICATE ---",
},
})
if tc.err != "" {
assert.ErrorContains(t, err, tc.err)
} else {
assert.Equal(t, &HeadersRequest{
EnableGoogleCloudServerlessAuthentication: true,
Issuer: tc.expectedIssuer,
Audience: tc.expectedAudience,
ToAudience: "https://to.example.com",
ClientCertificate: ClientCertificateInfo{
Leaf: "--- FAKE CERTIFICATE ---",
},
}, req)
}
}
}
func TestNewHeadersRequestFromPolicy_nil(t *testing.T) {
req, _ := NewHeadersRequestFromPolicy(nil, RequestHTTP{Hostname: "from.example.com"})
assert.Equal(t, &HeadersRequest{
Issuer: "from.example.com",
Audience: "from.example.com",
}, req)
}
func TestHeadersEvaluator(t *testing.T) { func TestHeadersEvaluator(t *testing.T) {
t.Parallel() t.Parallel()
@ -197,7 +107,7 @@ func TestHeadersEvaluator(t *testing.T) {
iat := time.Unix(1686870680, 0) iat := time.Unix(1686870680, 0)
eval := func(_ *testing.T, data []proto.Message, input *HeadersRequest) (*HeadersResponse, error) { eval := func(_ *testing.T, data []proto.Message, input *Request) (*HeadersResponse, error) {
ctx := context.Background() ctx := context.Background()
ctx = storage.WithQuerier(ctx, storage.NewStaticQuerier(data...)) ctx = storage.WithQuerier(ctx, storage.NewStaticQuerier(data...))
store := store.New() store := store.New()
@ -232,10 +142,11 @@ func TestHeadersEvaluator(t *testing.T) {
newDirectoryGroupRecord(directory.Group{ID: "g3", Name: "GROUP3", Email: "g3@example.com"}), newDirectoryGroupRecord(directory.Group{ID: "g3", Name: "GROUP3", Email: "g3@example.com"}),
newDirectoryGroupRecord(directory.Group{ID: "g4", Name: "GROUP4", Email: "g4@example.com"}), newDirectoryGroupRecord(directory.Group{ID: "g4", Name: "GROUP4", Email: "g4@example.com"}),
}, },
&HeadersRequest{ &Request{
Issuer: "from.example.com", HTTP: RequestHTTP{
Audience: "from.example.com", Hostname: "from.example.com",
ToAudience: "to.example.com", },
Policy: &config.Policy{},
Session: RequestSession{ Session: RequestSession{
ID: "s1", ID: "s1",
}, },
@ -292,7 +203,7 @@ func TestHeadersEvaluator(t *testing.T) {
}}, }},
}}, }},
}, },
&HeadersRequest{ &Request{
Session: RequestSession{ID: "s1"}, Session: RequestSession{ID: "s1"},
}) })
require.NoError(t, err) require.NoError(t, err)
@ -312,11 +223,12 @@ func TestHeadersEvaluator(t *testing.T) {
AccessToken: "ACCESS_TOKEN", AccessToken: "ACCESS_TOKEN",
}}, }},
}, },
&HeadersRequest{ &Request{
Issuer: "from.example.com", HTTP: RequestHTTP{
Audience: "from.example.com", Hostname: "from.example.com",
ToAudience: "to.example.com", ClientCertificate: ClientCertificateInfo{Leaf: testValidCert},
Session: RequestSession{ID: "s1"}, },
Policy: &config.Policy{
SetRequestHeaders: map[string]string{ SetRequestHeaders: map[string]string{
"X-Custom-Header": "CUSTOM_VALUE", "X-Custom-Header": "CUSTOM_VALUE",
"X-ID-Token": "${pomerium.id_token}", "X-ID-Token": "${pomerium.id_token}",
@ -325,7 +237,8 @@ func TestHeadersEvaluator(t *testing.T) {
"Authorization": "Bearer ${pomerium.jwt}", "Authorization": "Bearer ${pomerium.jwt}",
"Foo": "escaped $$dollar sign", "Foo": "escaped $$dollar sign",
}, },
ClientCertificate: ClientCertificateInfo{Leaf: testValidCert}, },
Session: RequestSession{ID: "s1"},
}) })
require.NoError(t, err) require.NoError(t, err)
@ -355,14 +268,13 @@ func TestHeadersEvaluator(t *testing.T) {
AccessToken: "ACCESS_TOKEN", AccessToken: "ACCESS_TOKEN",
}}, }},
}, },
&HeadersRequest{ &Request{
Issuer: "from.example.com",
Audience: "from.example.com",
ToAudience: "to.example.com",
Session: RequestSession{ID: "s1"}, Session: RequestSession{ID: "s1"},
Policy: &config.Policy{
SetRequestHeaders: map[string]string{ SetRequestHeaders: map[string]string{
"X-ID-Token": "${pomerium.id_token}", "X-ID-Token": "${pomerium.id_token}",
}, },
},
}) })
require.NoError(t, err) require.NoError(t, err)
@ -378,14 +290,13 @@ func TestHeadersEvaluator(t *testing.T) {
AccessToken: "ACCESS_TOKEN", AccessToken: "ACCESS_TOKEN",
}}, }},
}, },
&HeadersRequest{ &Request{
Issuer: "from.example.com", Policy: &config.Policy{
Audience: "from.example.com",
ToAudience: "to.example.com",
Session: RequestSession{ID: "s1"},
SetRequestHeaders: map[string]string{ SetRequestHeaders: map[string]string{
"Authorization": "Bearer ${pomerium.id_token}", "Authorization": "Bearer ${pomerium.id_token}",
}, },
},
Session: RequestSession{ID: "s1"},
}) })
require.NoError(t, err) require.NoError(t, err)
@ -394,13 +305,12 @@ func TestHeadersEvaluator(t *testing.T) {
t.Run("set_request_headers no client cert", func(t *testing.T) { t.Run("set_request_headers no client cert", func(t *testing.T) {
output, err := eval(t, nil, output, err := eval(t, nil,
&HeadersRequest{ &Request{
Issuer: "from.example.com", Policy: &config.Policy{
Audience: "from.example.com",
ToAudience: "to.example.com",
SetRequestHeaders: map[string]string{ SetRequestHeaders: map[string]string{
"fingerprint": "${pomerium.client_cert_fingerprint}", "fingerprint": "${pomerium.client_cert_fingerprint}",
}, },
},
}) })
require.NoError(t, err) require.NoError(t, err)
@ -427,11 +337,10 @@ func TestHeadersEvaluator(t *testing.T) {
Name: "GROUP2", Name: "GROUP2",
}), }),
}, },
&HeadersRequest{ &Request{
Issuer: "from.example.com", Policy: &config.Policy{
Audience: "from.example.com",
ToAudience: "to.example.com",
KubernetesServiceAccountToken: "TOKEN", KubernetesServiceAccountToken: "TOKEN",
},
Session: RequestSession{ID: "s1"}, Session: RequestSession{ID: "s1"},
}) })
require.NoError(t, err) require.NoError(t, err)
@ -445,8 +354,7 @@ func TestHeadersEvaluator(t *testing.T) {
output, err := eval(t, output, err := eval(t,
[]protoreflect.ProtoMessage{}, []protoreflect.ProtoMessage{},
&HeadersRequest{ &Request{
EnableRoutingKey: false,
Session: RequestSession{ID: "s1"}, Session: RequestSession{ID: "s1"},
}) })
require.NoError(t, err) require.NoError(t, err)
@ -454,8 +362,12 @@ func TestHeadersEvaluator(t *testing.T) {
output, err = eval(t, output, err = eval(t,
[]protoreflect.ProtoMessage{}, []protoreflect.ProtoMessage{},
&HeadersRequest{ &Request{
EnableRoutingKey: true, Policy: &config.Policy{
EnvoyOpts: &envoy_config_cluster_v3.Cluster{
LbPolicy: envoy_config_cluster_v3.Cluster_MAGLEV,
},
},
Session: RequestSession{ID: "s1"}, Session: RequestSession{ID: "s1"},
}) })
require.NoError(t, err) require.NoError(t, err)
@ -470,7 +382,7 @@ func TestHeadersEvaluator(t *testing.T) {
&session.Session{Id: "s1", UserId: "u1"}, &session.Session{Id: "s1", UserId: "u1"},
&user.User{Id: "u1", Email: "user@example.com"}, &user.User{Id: "u1", Email: "user@example.com"},
}, },
&HeadersRequest{ &Request{
Session: RequestSession{ID: "s1"}, Session: RequestSession{ID: "s1"},
}) })
require.NoError(t, err) require.NoError(t, err)
@ -481,7 +393,7 @@ func TestHeadersEvaluator(t *testing.T) {
&session.Session{Id: "s1", UserId: "u1"}, &session.Session{Id: "s1", UserId: "u1"},
newDirectoryUserRecord(directory.User{ID: "u1", Email: "directory-user@example.com"}), newDirectoryUserRecord(directory.User{ID: "u1", Email: "directory-user@example.com"}),
}, },
&HeadersRequest{ &Request{
Session: RequestSession{ID: "s1"}, Session: RequestSession{ID: "s1"},
}) })
require.NoError(t, err) require.NoError(t, err)
@ -498,7 +410,7 @@ func TestHeadersEvaluator(t *testing.T) {
}}, }},
}}, }},
}, },
&HeadersRequest{ &Request{
Session: RequestSession{ID: "s1"}, Session: RequestSession{ID: "s1"},
}) })
require.NoError(t, err) require.NoError(t, err)
@ -513,7 +425,7 @@ func TestHeadersEvaluator(t *testing.T) {
}}, }},
}}, }},
}, },
&HeadersRequest{ &Request{
Session: RequestSession{ID: "s1"}, Session: RequestSession{ID: "s1"},
}) })
require.NoError(t, err) require.NoError(t, err)
@ -528,12 +440,54 @@ func TestHeadersEvaluator(t *testing.T) {
&user.ServiceAccount{Id: "sa1", UserId: "u1"}, &user.ServiceAccount{Id: "sa1", UserId: "u1"},
&user.User{Id: "u1", Email: "u1@example.com"}, &user.User{Id: "u1", Email: "u1@example.com"},
}, },
&HeadersRequest{ &Request{
Session: RequestSession{ID: "sa1"}, Session: RequestSession{ID: "sa1"},
}) })
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, "u1@example.com", output.Headers.Get("X-Pomerium-Claim-Email")) assert.Equal(t, "u1@example.com", output.Headers.Get("X-Pomerium-Claim-Email"))
}) })
t.Run("issuer format", func(t *testing.T) {
t.Parallel()
for _, tc := range []struct {
format string
input string
output string
}{
{"", "example.com", "example.com"},
{"hostOnly", "host-only.example.com", "host-only.example.com"},
{"uri", "uri.example.com", "https://uri.example.com/"},
} {
output, err := eval(t,
nil,
&Request{
HTTP: RequestHTTP{
Hostname: tc.input,
},
Policy: &config.Policy{
JWTIssuerFormat: tc.format,
},
})
require.NoError(t, err)
m := decodeJWTAssertion(t, output.Headers)
assert.Equal(t, tc.output, m["iss"], "unexpected issuer for format=%s", tc.format)
}
})
}
func decodeJWTAssertion(t *testing.T, headers http.Header) map[string]any {
jwtHeader := headers.Get("X-Pomerium-Jwt-Assertion")
// Make sure the 'iat' and 'exp' claims can be parsed as an integer. We
// need to do some explicit decoding in order to be able to verify
// this, as by default json.Unmarshal() will make no distinction
// between numeric formats.
d := json.NewDecoder(bytes.NewReader(decodeJWSPayload(t, jwtHeader)))
d.UseNumber()
var m map[string]any
err := d.Decode(&m)
require.NoError(t, err)
return m
} }
func decodeJWSPayload(t *testing.T, jws string) []byte { func decodeJWSPayload(t *testing.T, jws string) []byte {