return issuer format errors

This commit is contained in:
Caleb Doxsey 2025-01-02 15:12:07 -07:00
parent e39416a872
commit 229c6ad84e

View file

@ -61,16 +61,19 @@ func newHeadersEvaluatorEvaluation(evaluator *HeadersEvaluator, request *Request
} }
func (e *headersEvaluatorEvaluation) execute(ctx context.Context) (*HeadersResponse, error) { func (e *headersEvaluatorEvaluation) execute(ctx context.Context) (*HeadersResponse, error) {
e.fillHeaders(ctx) err := e.fillHeaders(ctx)
return e.response, nil return e.response, err
} }
func (e *headersEvaluatorEvaluation) fillJWTAssertionHeader(ctx context.Context) { func (e *headersEvaluatorEvaluation) fillJWTAssertionHeader(ctx context.Context) {
e.response.Headers.Add("x-pomerium-jwt-assertion", e.getSignedJWT(ctx)) e.response.Headers.Add("x-pomerium-jwt-assertion", e.getSignedJWT(ctx))
} }
func (e *headersEvaluatorEvaluation) fillJWTClaimHeaders(ctx context.Context) { func (e *headersEvaluatorEvaluation) fillJWTClaimHeaders(ctx context.Context) error {
claims := e.getJWTPayload(ctx) claims, err := e.getJWTPayload(ctx)
if err != nil {
return err
}
for headerName, claimKey := range e.evaluator.store.GetJWTClaimHeaders() { for headerName, claimKey := range e.evaluator.store.GetJWTClaimHeaders() {
claim, ok := claims[claimKey] claim, ok := claims[claimKey]
if !ok { if !ok {
@ -79,6 +82,7 @@ func (e *headersEvaluatorEvaluation) fillJWTClaimHeaders(ctx context.Context) {
} }
e.response.Headers.Add(headerName, getHeaderStringValue(claim)) e.response.Headers.Add(headerName, getHeaderStringValue(claim))
} }
return nil
} }
func (e *headersEvaluatorEvaluation) fillKubernetesHeaders(ctx context.Context) { func (e *headersEvaluatorEvaluation) fillKubernetesHeaders(ctx context.Context) {
@ -160,13 +164,16 @@ func (e *headersEvaluatorEvaluation) fillSetRequestHeaders(ctx context.Context)
} }
} }
func (e *headersEvaluatorEvaluation) fillHeaders(ctx context.Context) { func (e *headersEvaluatorEvaluation) fillHeaders(ctx context.Context) error {
e.fillJWTAssertionHeader(ctx) e.fillJWTAssertionHeader(ctx)
e.fillJWTClaimHeaders(ctx) if err := e.fillJWTClaimHeaders(ctx); err != nil {
return err
}
e.fillKubernetesHeaders(ctx) e.fillKubernetesHeaders(ctx)
e.fillGoogleCloudServerlessHeaders(ctx) e.fillGoogleCloudServerlessHeaders(ctx)
e.fillRoutingKeyHeaders() e.fillRoutingKeyHeaders()
e.fillSetRequestHeaders(ctx) e.fillSetRequestHeaders(ctx)
return nil
} }
func (e *headersEvaluatorEvaluation) getSessionOrServiceAccount(ctx context.Context) (*session.Session, *user.ServiceAccount) { func (e *headersEvaluatorEvaluation) getSessionOrServiceAccount(ctx context.Context) (*session.Session, *user.ServiceAccount) {
@ -234,18 +241,18 @@ func (e *headersEvaluatorEvaluation) getGroupIDs(ctx context.Context) []string {
return make([]string, 0) return make([]string, 0)
} }
func (e *headersEvaluatorEvaluation) getJWTPayloadIss() string { func (e *headersEvaluatorEvaluation) getJWTPayloadIss() (string, error) {
var issuerFormat string var issuerFormat string
if e.request.Policy != nil { if e.request.Policy != nil {
issuerFormat = e.request.Policy.JWTIssuerFormat issuerFormat = e.request.Policy.JWTIssuerFormat
} }
switch issuerFormat { switch issuerFormat {
case "uri": case "uri":
return fmt.Sprintf("https://%s/", e.request.HTTP.Hostname) return fmt.Sprintf("https://%s/", e.request.HTTP.Hostname), nil
case "", "hostOnly": case "", "hostOnly":
return e.request.HTTP.Hostname return e.request.HTTP.Hostname, nil
default: default:
return "" return "", fmt.Errorf("unsupported JWT issuer format: %s", issuerFormat)
} }
} }
@ -340,14 +347,19 @@ func (e *headersEvaluatorEvaluation) getJWTPayloadName(ctx context.Context) stri
return "" return ""
} }
func (e *headersEvaluatorEvaluation) getJWTPayload(ctx context.Context) map[string]any { func (e *headersEvaluatorEvaluation) getJWTPayload(ctx context.Context) (map[string]any, error) {
if e.gotJWTPayload { if e.gotJWTPayload {
return e.cachedJWTPayload return e.cachedJWTPayload, nil
}
iss, err := e.getJWTPayloadIss()
if err != nil {
return nil, err
} }
e.gotJWTPayload = true e.gotJWTPayload = true
e.cachedJWTPayload = map[string]any{ e.cachedJWTPayload = map[string]any{
"iss": e.getJWTPayloadIss(), "iss": iss,
"aud": e.getJWTPayloadAud(), "aud": e.getJWTPayloadAud(),
"jti": e.getJWTPayloadJTI(), "jti": e.getJWTPayloadJTI(),
"iat": e.getJWTPayloadIAT(), "iat": e.getJWTPayloadIAT(),
@ -375,7 +387,7 @@ func (e *headersEvaluatorEvaluation) getJWTPayload(ctx context.Context) map[stri
e.cachedJWTPayload[claimKey] = strings.Join(vs, ",") e.cachedJWTPayload[claimKey] = strings.Join(vs, ",")
} }
} }
return e.cachedJWTPayload return e.cachedJWTPayload, nil
} }
func (e *headersEvaluatorEvaluation) getSignedJWT(ctx context.Context) string { func (e *headersEvaluatorEvaluation) getSignedJWT(ctx context.Context) string {
@ -404,7 +416,11 @@ func (e *headersEvaluatorEvaluation) getSignedJWT(ctx context.Context) string {
return "" return ""
} }
jwtPayload := e.getJWTPayload(ctx) jwtPayload, err := e.getJWTPayload(ctx)
if err != nil {
log.Ctx(ctx).Error().Err(err).Msg("authorize/header-evaluator: error creating JWT payload")
return ""
}
bs, err := json.Marshal(jwtPayload) bs, err := json.Marshal(jwtPayload)
if err != nil { if err != nil {
log.Ctx(ctx).Error().Err(err).Msg("authorize/header-evaluator: error marshaling JWT payload") log.Ctx(ctx).Error().Err(err).Msg("authorize/header-evaluator: error marshaling JWT payload")