From 229c6ad84e66de64288afd18b1a974bbcd7e15c9 Mon Sep 17 00:00:00 2001 From: Caleb Doxsey Date: Thu, 2 Jan 2025 15:12:07 -0700 Subject: [PATCH] return issuer format errors --- .../evaluator/headers_evaluator_evaluation.go | 46 +++++++++++++------ 1 file changed, 31 insertions(+), 15 deletions(-) diff --git a/authorize/evaluator/headers_evaluator_evaluation.go b/authorize/evaluator/headers_evaluator_evaluation.go index c29a9f6da..3604919a3 100644 --- a/authorize/evaluator/headers_evaluator_evaluation.go +++ b/authorize/evaluator/headers_evaluator_evaluation.go @@ -61,16 +61,19 @@ func newHeadersEvaluatorEvaluation(evaluator *HeadersEvaluator, request *Request } func (e *headersEvaluatorEvaluation) execute(ctx context.Context) (*HeadersResponse, error) { - e.fillHeaders(ctx) - return e.response, nil + err := e.fillHeaders(ctx) + return e.response, err } func (e *headersEvaluatorEvaluation) fillJWTAssertionHeader(ctx context.Context) { e.response.Headers.Add("x-pomerium-jwt-assertion", e.getSignedJWT(ctx)) } -func (e *headersEvaluatorEvaluation) fillJWTClaimHeaders(ctx context.Context) { - claims := e.getJWTPayload(ctx) +func (e *headersEvaluatorEvaluation) fillJWTClaimHeaders(ctx context.Context) error { + claims, err := e.getJWTPayload(ctx) + if err != nil { + return err + } for headerName, claimKey := range e.evaluator.store.GetJWTClaimHeaders() { claim, ok := claims[claimKey] if !ok { @@ -79,6 +82,7 @@ func (e *headersEvaluatorEvaluation) fillJWTClaimHeaders(ctx context.Context) { } e.response.Headers.Add(headerName, getHeaderStringValue(claim)) } + return nil } func (e *headersEvaluatorEvaluation) fillKubernetesHeaders(ctx context.Context) { @@ -160,13 +164,16 @@ func (e *headersEvaluatorEvaluation) fillSetRequestHeaders(ctx context.Context) } } -func (e *headersEvaluatorEvaluation) fillHeaders(ctx context.Context) { +func (e *headersEvaluatorEvaluation) fillHeaders(ctx context.Context) error { e.fillJWTAssertionHeader(ctx) - e.fillJWTClaimHeaders(ctx) + if err := e.fillJWTClaimHeaders(ctx); err != nil { + return err + } e.fillKubernetesHeaders(ctx) e.fillGoogleCloudServerlessHeaders(ctx) e.fillRoutingKeyHeaders() e.fillSetRequestHeaders(ctx) + return nil } func (e *headersEvaluatorEvaluation) getSessionOrServiceAccount(ctx context.Context) (*session.Session, *user.ServiceAccount) { @@ -234,18 +241,18 @@ func (e *headersEvaluatorEvaluation) getGroupIDs(ctx context.Context) []string { return make([]string, 0) } -func (e *headersEvaluatorEvaluation) getJWTPayloadIss() string { +func (e *headersEvaluatorEvaluation) getJWTPayloadIss() (string, error) { var issuerFormat string if e.request.Policy != nil { issuerFormat = e.request.Policy.JWTIssuerFormat } switch issuerFormat { case "uri": - return fmt.Sprintf("https://%s/", e.request.HTTP.Hostname) + return fmt.Sprintf("https://%s/", e.request.HTTP.Hostname), nil case "", "hostOnly": - return e.request.HTTP.Hostname + return e.request.HTTP.Hostname, nil default: - return "" + return "", fmt.Errorf("unsupported JWT issuer format: %s", issuerFormat) } } @@ -340,14 +347,19 @@ func (e *headersEvaluatorEvaluation) getJWTPayloadName(ctx context.Context) stri return "" } -func (e *headersEvaluatorEvaluation) getJWTPayload(ctx context.Context) map[string]any { +func (e *headersEvaluatorEvaluation) getJWTPayload(ctx context.Context) (map[string]any, error) { if e.gotJWTPayload { - return e.cachedJWTPayload + return e.cachedJWTPayload, nil + } + + iss, err := e.getJWTPayloadIss() + if err != nil { + return nil, err } e.gotJWTPayload = true e.cachedJWTPayload = map[string]any{ - "iss": e.getJWTPayloadIss(), + "iss": iss, "aud": e.getJWTPayloadAud(), "jti": e.getJWTPayloadJTI(), "iat": e.getJWTPayloadIAT(), @@ -375,7 +387,7 @@ func (e *headersEvaluatorEvaluation) getJWTPayload(ctx context.Context) map[stri e.cachedJWTPayload[claimKey] = strings.Join(vs, ",") } } - return e.cachedJWTPayload + return e.cachedJWTPayload, nil } func (e *headersEvaluatorEvaluation) getSignedJWT(ctx context.Context) string { @@ -404,7 +416,11 @@ func (e *headersEvaluatorEvaluation) getSignedJWT(ctx context.Context) string { return "" } - jwtPayload := e.getJWTPayload(ctx) + jwtPayload, err := e.getJWTPayload(ctx) + if err != nil { + log.Ctx(ctx).Error().Err(err).Msg("authorize/header-evaluator: error creating JWT payload") + return "" + } bs, err := json.Marshal(jwtPayload) if err != nil { log.Ctx(ctx).Error().Err(err).Msg("authorize/header-evaluator: error marshaling JWT payload")