evaluator: halt on databroker errors during policy and header evaluation

(cherry picked from commit 53f731f241d3beec813bbcc02e99151d426e4f4c)
This commit is contained in:
Joe Kralicky 2024-12-30 23:07:24 +00:00
parent 2bb70258c3
commit dd755ede7f
No known key found for this signature in database
GPG key ID: 75C4875F34A9FB79
5 changed files with 281 additions and 116 deletions

View file

@ -13,6 +13,8 @@ import (
"github.com/hashicorp/go-set/v3" "github.com/hashicorp/go-set/v3"
"github.com/open-policy-agent/opa/rego" "github.com/open-policy-agent/opa/rego"
"golang.org/x/sync/errgroup" "golang.org/x/sync/errgroup"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"github.com/pomerium/pomerium/authorize/internal/store" "github.com/pomerium/pomerium/authorize/internal/store"
"github.com/pomerium/pomerium/config" "github.com/pomerium/pomerium/config"
@ -23,6 +25,7 @@ import (
"github.com/pomerium/pomerium/pkg/contextutil" "github.com/pomerium/pomerium/pkg/contextutil"
"github.com/pomerium/pomerium/pkg/cryptutil" "github.com/pomerium/pomerium/pkg/cryptutil"
"github.com/pomerium/pomerium/pkg/policy/criteria" "github.com/pomerium/pomerium/pkg/policy/criteria"
"github.com/pomerium/pomerium/pkg/storage"
) )
// Request contains the inputs needed for evaluation. // Request contains the inputs needed for evaluation.
@ -219,11 +222,21 @@ func (e *Evaluator) Evaluate(ctx context.Context, req *Request) (*Result, error)
eg.Go(func() error { eg.Go(func() error {
var err error var err error
headersOutput, err = e.evaluateHeaders(ctx, req) headersOutput, err = e.evaluateHeaders(ctx, req)
if storage.IsNotFound(err) {
headersOutput = &HeadersResponse{}
return nil
}
return err return err
}) })
err := eg.Wait() err := eg.Wait()
if err != nil { if err != nil {
if status.Code(err) == codes.Unavailable {
return &Result{
Allow: NewRuleResult(false, criteria.ReasonInternalServerError),
Deny: NewRuleResult(true, criteria.ReasonInternalServerError),
}, nil
}
return nil, err return nil, err
} }

View file

@ -60,16 +60,26 @@ func newHeadersEvaluatorEvaluation(evaluator *HeadersEvaluator, request *Headers
} }
func (e *headersEvaluatorEvaluation) execute(ctx context.Context) (*HeadersResponse, error) { func (e *headersEvaluatorEvaluation) execute(ctx context.Context) (*HeadersResponse, error) {
e.fillHeaders(ctx) if err := e.fillHeaders(ctx); err != nil {
return nil, err
}
return e.response, nil return e.response, nil
} }
func (e *headersEvaluatorEvaluation) fillJWTAssertionHeader(ctx context.Context) { func (e *headersEvaluatorEvaluation) fillJWTAssertionHeader(ctx context.Context) error {
e.response.Headers.Add("x-pomerium-jwt-assertion", e.getSignedJWT(ctx)) jwt, err := e.getSignedJWT(ctx)
if err != nil {
return err
}
e.response.Headers.Add("x-pomerium-jwt-assertion", jwt)
return nil
} }
func (e *headersEvaluatorEvaluation) fillJWTClaimHeaders(ctx context.Context) { func (e *headersEvaluatorEvaluation) fillJWTClaimHeaders(ctx context.Context) error {
claims := e.getJWTPayload(ctx) claims, err := e.getJWTPayload(ctx)
if err != nil {
return err
}
for headerName, claimKey := range e.evaluator.store.GetJWTClaimHeaders() { for headerName, claimKey := range e.evaluator.store.GetJWTClaimHeaders() {
claim, ok := claims[claimKey] claim, ok := claims[claimKey]
if !ok { if !ok {
@ -78,107 +88,171 @@ func (e *headersEvaluatorEvaluation) fillJWTClaimHeaders(ctx context.Context) {
} }
e.response.Headers.Add(headerName, getHeaderStringValue(claim)) e.response.Headers.Add(headerName, getHeaderStringValue(claim))
} }
return nil
} }
func (e *headersEvaluatorEvaluation) fillKubernetesHeaders(ctx context.Context) { func (e *headersEvaluatorEvaluation) fillKubernetesHeaders(ctx context.Context) error {
if e.request.KubernetesServiceAccountToken == "" { if e.request.KubernetesServiceAccountToken == "" {
return return nil
} }
e.response.Headers.Add("Authorization", "Bearer "+e.request.KubernetesServiceAccountToken) e.response.Headers.Add("Authorization", "Bearer "+e.request.KubernetesServiceAccountToken)
impersonateUser := e.getJWTPayloadEmail(ctx) impersonateUser, err := e.getJWTPayloadEmail(ctx)
if err != nil {
return err
}
if impersonateUser != "" { if impersonateUser != "" {
e.response.Headers.Add("Impersonate-User", impersonateUser) e.response.Headers.Add("Impersonate-User", impersonateUser)
} }
impersonateGroups := e.getJWTPayloadGroups(ctx) impersonateGroups, err := e.getJWTPayloadGroups(ctx)
if err != nil {
return err
}
if len(impersonateGroups) > 0 { if len(impersonateGroups) > 0 {
e.response.Headers.Add("Impersonate-Group", strings.Join(impersonateGroups, ",")) e.response.Headers.Add("Impersonate-Group", strings.Join(impersonateGroups, ","))
} }
return nil
} }
func (e *headersEvaluatorEvaluation) fillGoogleCloudServerlessHeaders(ctx context.Context) { func (e *headersEvaluatorEvaluation) fillGoogleCloudServerlessHeaders(ctx context.Context) error {
if e.request.EnableGoogleCloudServerlessAuthentication { if e.request.EnableGoogleCloudServerlessAuthentication {
h, err := getGoogleCloudServerlessHeaders(e.evaluator.store.GetGoogleCloudServerlessAuthenticationServiceAccount(), e.request.ToAudience) h, err := getGoogleCloudServerlessHeaders(e.evaluator.store.GetGoogleCloudServerlessAuthenticationServiceAccount(), e.request.ToAudience)
if err != nil { if err != nil {
log.Ctx(ctx).Error().Err(err).Msg("authorize/header-evaluator: error retrieving google cloud serverless headers") log.Ctx(ctx).Error().Err(err).Msg("authorize/header-evaluator: error retrieving google cloud serverless headers")
return return err
} }
for k, v := range h { for k, v := range h {
e.response.Headers.Add(k, v) e.response.Headers.Add(k, v)
} }
} }
return nil
} }
func (e *headersEvaluatorEvaluation) fillRoutingKeyHeaders() { func (e *headersEvaluatorEvaluation) fillRoutingKeyHeaders() error {
if e.request.EnableRoutingKey { if e.request.EnableRoutingKey {
e.response.Headers.Add("x-pomerium-routing-key", cryptoSHA256(e.request.Session.ID)) e.response.Headers.Add("x-pomerium-routing-key", cryptoSHA256(e.request.Session.ID))
} }
return nil
} }
func (e *headersEvaluatorEvaluation) fillSetRequestHeaders(ctx context.Context) { func (e *headersEvaluatorEvaluation) fillSetRequestHeaders(ctx context.Context) error {
for k, v := range e.request.SetRequestHeaders { for k, v := range e.request.SetRequestHeaders {
var retErr error
e.response.Headers.Add(k, os.Expand(v, func(name string) string { e.response.Headers.Add(k, os.Expand(v, func(name string) string {
switch name { switch name {
case "$": case "$":
return "$" return "$"
case "pomerium.access_token": case "pomerium.access_token":
s, _ := e.getSessionOrServiceAccount(ctx) s, _, err := e.getSessionOrServiceAccount(ctx)
if err != nil {
retErr = err
return ""
}
return s.GetOauthToken().GetAccessToken() return s.GetOauthToken().GetAccessToken()
case "pomerium.client_cert_fingerprint": case "pomerium.client_cert_fingerprint":
return e.getClientCertFingerprint() return e.getClientCertFingerprint()
case "pomerium.id_token": case "pomerium.id_token":
s, _ := e.getSessionOrServiceAccount(ctx) s, _, err := e.getSessionOrServiceAccount(ctx)
if err != nil {
retErr = err
return ""
}
return s.GetIdToken().GetRaw() return s.GetIdToken().GetRaw()
case "pomerium.jwt": case "pomerium.jwt":
return e.getSignedJWT(ctx) jwt, err := e.getSignedJWT(ctx)
if err != nil {
retErr = err
return ""
}
return jwt
} }
return "" return ""
})) }))
if retErr != nil {
return retErr
}
} }
return nil
} }
func (e *headersEvaluatorEvaluation) fillHeaders(ctx context.Context) { func (e *headersEvaluatorEvaluation) fillHeaders(ctx context.Context) error {
e.fillJWTAssertionHeader(ctx) if err := e.fillJWTAssertionHeader(ctx); err != nil {
e.fillJWTClaimHeaders(ctx) return err
e.fillKubernetesHeaders(ctx) }
e.fillGoogleCloudServerlessHeaders(ctx) if err := e.fillJWTClaimHeaders(ctx); err != nil {
e.fillRoutingKeyHeaders() return err
e.fillSetRequestHeaders(ctx) }
if err := e.fillKubernetesHeaders(ctx); err != nil {
return err
}
if err := e.fillGoogleCloudServerlessHeaders(ctx); err != nil {
return err
}
if err := e.fillRoutingKeyHeaders(); err != nil {
return err
}
if err := e.fillSetRequestHeaders(ctx); err != nil {
return err
}
return nil
} }
func (e *headersEvaluatorEvaluation) getSessionOrServiceAccount(ctx context.Context) (*session.Session, *user.ServiceAccount) { func (e *headersEvaluatorEvaluation) getSessionOrServiceAccount(ctx context.Context) (*session.Session, *user.ServiceAccount, error) {
if e.gotSessionOrServiceAccount { if e.gotSessionOrServiceAccount {
return e.cachedSession, e.cachedServiceAccount return e.cachedSession, e.cachedServiceAccount, nil
} }
e.gotSessionOrServiceAccount = true
if e.request.Session.ID != "" { if e.request.Session.ID != "" {
e.cachedServiceAccount, _ = e.evaluator.store.GetDataBrokerRecord(ctx, "type.googleapis.com/user.ServiceAccount", e.request.Session.ID).(*user.ServiceAccount) msg, err := e.evaluator.store.GetDataBrokerRecord(ctx, "type.googleapis.com/user.ServiceAccount", e.request.Session.ID)
if err != nil {
return nil, nil, fmt.Errorf("error looking up service account: %w", err)
}
e.cachedServiceAccount, _ = msg.(*user.ServiceAccount)
} }
if e.request.Session.ID != "" && e.cachedServiceAccount == nil { if e.request.Session.ID != "" && e.cachedServiceAccount == nil {
e.cachedSession, _ = e.evaluator.store.GetDataBrokerRecord(ctx, "type.googleapis.com/session.Session", e.request.Session.ID).(*session.Session) msg, err := e.evaluator.store.GetDataBrokerRecord(ctx, "type.googleapis.com/session.Session", e.request.Session.ID)
if err != nil {
return nil, nil, fmt.Errorf("error looking up session: %w", err)
}
e.cachedSession, _ = msg.(*session.Session)
} }
if e.cachedSession != nil && e.cachedSession.GetImpersonateSessionId() != "" { if e.cachedSession != nil && e.cachedSession.GetImpersonateSessionId() != "" {
e.cachedSession, _ = e.evaluator.store.GetDataBrokerRecord(ctx, "type.googleapis.com/session.Session", e.cachedSession.GetImpersonateSessionId()).(*session.Session) msg, err := e.evaluator.store.GetDataBrokerRecord(ctx, "type.googleapis.com/session.Session", e.cachedSession.GetImpersonateSessionId())
if err != nil {
return nil, nil, fmt.Errorf("error looking up session: %w", err)
}
e.cachedSession, _ = msg.(*session.Session)
} }
return e.cachedSession, e.cachedServiceAccount e.gotSessionOrServiceAccount = true
return e.cachedSession, e.cachedServiceAccount, nil
} }
func (e *headersEvaluatorEvaluation) getUser(ctx context.Context) *user.User { func (e *headersEvaluatorEvaluation) getUser(ctx context.Context) (*user.User, error) {
if e.gotUser { if e.gotUser {
return e.cachedUser return e.cachedUser, nil
} }
e.gotUser = true s, sa, err := e.getSessionOrServiceAccount(ctx)
s, sa := e.getSessionOrServiceAccount(ctx) if err != nil {
if sa != nil && sa.UserId != "" { return nil, err
e.cachedUser, _ = e.evaluator.store.GetDataBrokerRecord(ctx, "type.googleapis.com/user.User", sa.UserId).(*user.User)
} else if s != nil && s.UserId != "" {
e.cachedUser, _ = e.evaluator.store.GetDataBrokerRecord(ctx, "type.googleapis.com/user.User", s.UserId).(*user.User)
} }
return e.cachedUser if sa != nil && sa.UserId != "" {
msg, err := e.evaluator.store.GetDataBrokerRecord(ctx, "type.googleapis.com/user.User", sa.UserId)
if err != nil {
return nil, fmt.Errorf("error looking up user: %w", err)
}
e.cachedUser, _ = msg.(*user.User)
} else if s != nil && s.UserId != "" {
msg, err := e.evaluator.store.GetDataBrokerRecord(ctx, "type.googleapis.com/user.User", s.UserId)
if err != nil {
return nil, fmt.Errorf("error looking up user: %w", err)
}
e.cachedUser, _ = msg.(*user.User)
}
e.gotUser = true
return e.cachedUser, nil
} }
func (e *headersEvaluatorEvaluation) getClientCertFingerprint() string { func (e *headersEvaluatorEvaluation) getClientCertFingerprint() string {
@ -189,27 +263,41 @@ func (e *headersEvaluatorEvaluation) getClientCertFingerprint() string {
return cryptoSHA256(cert.Raw) return cryptoSHA256(cert.Raw)
} }
func (e *headersEvaluatorEvaluation) getDirectoryUser(ctx context.Context) *structpb.Struct { func (e *headersEvaluatorEvaluation) getDirectoryUser(ctx context.Context) (*structpb.Struct, error) {
if e.gotDirectoryUser { if e.gotDirectoryUser {
return e.cachedDirectoryUser return e.cachedDirectoryUser, nil
} }
e.gotDirectoryUser = true s, sa, err := e.getSessionOrServiceAccount(ctx)
s, sa := e.getSessionOrServiceAccount(ctx) if err != nil {
if sa != nil && sa.UserId != "" { return nil, err
e.cachedDirectoryUser, _ = e.evaluator.store.GetDataBrokerRecord(ctx, directory.UserRecordType, sa.UserId).(*structpb.Struct)
} else if s != nil && s.UserId != "" {
e.cachedDirectoryUser, _ = e.evaluator.store.GetDataBrokerRecord(ctx, directory.UserRecordType, s.UserId).(*structpb.Struct)
} }
return e.cachedDirectoryUser if sa != nil && sa.UserId != "" {
msg, err := e.evaluator.store.GetDataBrokerRecord(ctx, directory.UserRecordType, sa.UserId)
if err != nil {
return nil, fmt.Errorf("error looking up directory user: %w", err)
}
e.cachedDirectoryUser, _ = msg.(*structpb.Struct)
} else if s != nil && s.UserId != "" {
msg, err := e.evaluator.store.GetDataBrokerRecord(ctx, directory.UserRecordType, s.UserId)
if err != nil {
return nil, fmt.Errorf("error looking up directory user: %w", err)
}
e.cachedDirectoryUser, _ = msg.(*structpb.Struct)
}
e.gotDirectoryUser = true
return e.cachedDirectoryUser, nil
} }
func (e *headersEvaluatorEvaluation) getGroupIDs(ctx context.Context) []string { func (e *headersEvaluatorEvaluation) getGroupIDs(ctx context.Context) ([]string, error) {
du := e.getDirectoryUser(ctx) du, err := e.getDirectoryUser(ctx)
if groupIDs, ok := getStructStringSlice(du, "group_ids"); ok { if err != nil {
return groupIDs return nil, err
} }
return make([]string, 0) if groupIDs, ok := getStructStringSlice(du, "group_ids"); ok {
return groupIDs, nil
}
return make([]string, 0), nil
} }
func (e *headersEvaluatorEvaluation) getJWTPayloadIss() string { func (e *headersEvaluatorEvaluation) getJWTPayloadIss() string {
@ -238,91 +326,138 @@ func (e *headersEvaluatorEvaluation) getJWTPayloadExp() int64 {
return e.now.Add(5 * time.Minute).Unix() return e.now.Add(5 * time.Minute).Unix()
} }
func (e *headersEvaluatorEvaluation) getJWTPayloadSub(ctx context.Context) string { func (e *headersEvaluatorEvaluation) getJWTPayloadSub(ctx context.Context) (string, error) {
return e.getJWTPayloadUser(ctx) return e.getJWTPayloadUser(ctx)
} }
func (e *headersEvaluatorEvaluation) getJWTPayloadUser(ctx context.Context) string { func (e *headersEvaluatorEvaluation) getJWTPayloadUser(ctx context.Context) (string, error) {
s, sa := e.getSessionOrServiceAccount(ctx) s, sa, err := e.getSessionOrServiceAccount(ctx)
if err != nil {
return "", err
}
if sa != nil { if sa != nil {
return sa.UserId return sa.UserId, nil
} }
if s != nil { if s != nil {
return s.UserId return s.UserId, nil
} }
return "" return "", nil
} }
func (e *headersEvaluatorEvaluation) getJWTPayloadEmail(ctx context.Context) string { func (e *headersEvaluatorEvaluation) getJWTPayloadEmail(ctx context.Context) (string, error) {
du := e.getDirectoryUser(ctx) du, err := e.getDirectoryUser(ctx)
if err != nil {
return "", err
}
if v, ok := getStructString(du, "email"); ok { if v, ok := getStructString(du, "email"); ok {
return v return v, nil
} }
u := e.getUser(ctx) u, err := e.getUser(ctx)
if err != nil {
return "", err
}
if u != nil { if u != nil {
return u.Email return u.Email, nil
} }
return "" return "", nil
} }
func (e *headersEvaluatorEvaluation) getJWTPayloadGroups(ctx context.Context) []string { func (e *headersEvaluatorEvaluation) getJWTPayloadGroups(ctx context.Context) ([]string, error) {
groupIDs := e.getGroupIDs(ctx) groupIDs, err := e.getGroupIDs(ctx)
if err != nil {
return nil, err
}
if len(groupIDs) > 0 { if len(groupIDs) > 0 {
groups := make([]string, 0, len(groupIDs)*2) groups := make([]string, 0, len(groupIDs)*2)
groups = append(groups, groupIDs...) groups = append(groups, groupIDs...)
groups = append(groups, e.getDataBrokerGroupNames(ctx, groupIDs)...) groupNames, err := e.getDataBrokerGroupNames(ctx, groupIDs)
return groups if err != nil {
return nil, err
}
groups = append(groups, groupNames...)
return groups, nil
} }
s, _ := e.getSessionOrServiceAccount(ctx) s, _, err := e.getSessionOrServiceAccount(ctx)
if err != nil {
return nil, err
}
groups, _ := getClaimStringSlice(s, "groups") groups, _ := getClaimStringSlice(s, "groups")
return groups return groups, nil
} }
func (e *headersEvaluatorEvaluation) getJWTPayloadSID() string { func (e *headersEvaluatorEvaluation) getJWTPayloadSID() string {
return e.request.Session.ID return e.request.Session.ID
} }
func (e *headersEvaluatorEvaluation) getJWTPayloadName(ctx context.Context) string { func (e *headersEvaluatorEvaluation) getJWTPayloadName(ctx context.Context) (string, error) {
s, _ := e.getSessionOrServiceAccount(ctx) s, _, err := e.getSessionOrServiceAccount(ctx)
if err != nil {
return "", err
}
if names, ok := getClaimStringSlice(s, "name"); ok { if names, ok := getClaimStringSlice(s, "name"); ok {
return strings.Join(names, ",") return strings.Join(names, ","), nil
} }
u := e.getUser(ctx) u, err := e.getUser(ctx)
if err != nil {
return "", err
}
if names, ok := getClaimStringSlice(u, "name"); ok { if names, ok := getClaimStringSlice(u, "name"); ok {
return strings.Join(names, ",") return strings.Join(names, ","), nil
} }
return "" return "", nil
} }
func (e *headersEvaluatorEvaluation) getJWTPayload(ctx context.Context) map[string]any { func (e *headersEvaluatorEvaluation) getJWTPayload(ctx context.Context) (map[string]any, error) {
if e.gotJWTPayload { if e.gotJWTPayload {
return e.cachedJWTPayload return e.cachedJWTPayload, nil
} }
e.gotJWTPayload = true
e.cachedJWTPayload = map[string]any{ e.cachedJWTPayload = map[string]any{
"iss": e.getJWTPayloadIss(), "iss": e.getJWTPayloadIss(),
"aud": e.getJWTPayloadAud(), "aud": e.getJWTPayloadAud(),
"jti": e.getJWTPayloadJTI(), "jti": e.getJWTPayloadJTI(),
"iat": e.getJWTPayloadIAT(), "iat": e.getJWTPayloadIAT(),
"exp": e.getJWTPayloadExp(), "exp": e.getJWTPayloadExp(),
"sub": e.getJWTPayloadSub(ctx), "sid": e.getJWTPayloadSID(),
"user": e.getJWTPayloadUser(ctx),
"email": e.getJWTPayloadEmail(ctx),
"groups": e.getJWTPayloadGroups(ctx),
"sid": e.getJWTPayloadSID(),
"name": e.getJWTPayloadName(ctx),
} }
s, _ := e.getSessionOrServiceAccount(ctx) var err error
u := e.getUser(ctx) e.cachedJWTPayload["sub"], err = e.getJWTPayloadSub(ctx)
if err != nil {
return nil, err
}
e.cachedJWTPayload["user"], err = e.getJWTPayloadUser(ctx)
if err != nil {
return nil, err
}
e.cachedJWTPayload["email"], err = e.getJWTPayloadEmail(ctx)
if err != nil {
return nil, err
}
e.cachedJWTPayload["groups"], err = e.getJWTPayloadGroups(ctx)
if err != nil {
return nil, err
}
e.cachedJWTPayload["name"], err = e.getJWTPayloadName(ctx)
if err != nil {
return nil, err
}
s, _, err := e.getSessionOrServiceAccount(ctx)
if err != nil {
return nil, err
}
u, err := e.getUser(ctx)
if err != nil {
return nil, err
}
for _, claimKey := range e.evaluator.store.GetJWTClaimHeaders() { for _, claimKey := range e.evaluator.store.GetJWTClaimHeaders() {
// ignore base claims // ignore base claims
@ -336,19 +471,20 @@ func (e *headersEvaluatorEvaluation) getJWTPayload(ctx context.Context) map[stri
e.cachedJWTPayload[claimKey] = strings.Join(vs, ",") e.cachedJWTPayload[claimKey] = strings.Join(vs, ",")
} }
} }
return e.cachedJWTPayload
e.gotJWTPayload = true
return e.cachedJWTPayload, nil
} }
func (e *headersEvaluatorEvaluation) getSignedJWT(ctx context.Context) string { func (e *headersEvaluatorEvaluation) getSignedJWT(ctx context.Context) (string, error) {
if e.gotSignedJWT { if e.gotSignedJWT {
return e.cachedSignedJWT return e.cachedSignedJWT, nil
} }
e.gotSignedJWT = true
signingKey := e.evaluator.store.GetSigningKey() signingKey := e.evaluator.store.GetSigningKey()
if signingKey == nil { if signingKey == nil {
log.Ctx(ctx).Error().Msg("authorize/header-evaluator: missing signing key") log.Ctx(ctx).Error().Msg("authorize/header-evaluator: missing signing key")
return "" return "", nil
} }
signingOptions := new(jose.SignerOptions). signingOptions := new(jose.SignerOptions).
@ -362,39 +498,48 @@ func (e *headersEvaluatorEvaluation) getSignedJWT(ctx context.Context) string {
}, signingOptions) }, signingOptions)
if err != nil { if err != nil {
log.Ctx(ctx).Error().Err(err).Msg("authorize/header-evaluator: error creating JWT signer") log.Ctx(ctx).Error().Err(err).Msg("authorize/header-evaluator: error creating JWT signer")
return "" return "", err
} }
jwtPayload := e.getJWTPayload(ctx) jwtPayload, err := e.getJWTPayload(ctx)
if err != nil {
return "", err
}
bs, err := json.Marshal(jwtPayload) bs, err := json.Marshal(jwtPayload)
if err != nil { if err != nil {
log.Ctx(ctx).Error().Err(err).Msg("authorize/header-evaluator: error marshaling JWT payload") log.Ctx(ctx).Error().Err(err).Msg("authorize/header-evaluator: error marshaling JWT payload")
return "" return "", err
} }
jwt, err := signer.Sign(bs) jwt, err := signer.Sign(bs)
if err != nil { if err != nil {
log.Ctx(ctx).Error().Err(err).Msg("authorize/header-evaluator: error signing JWT") log.Ctx(ctx).Error().Err(err).Msg("authorize/header-evaluator: error signing JWT")
return "" return "", err
} }
e.cachedSignedJWT, err = jwt.CompactSerialize() e.cachedSignedJWT, err = jwt.CompactSerialize()
if err != nil { if err != nil {
log.Ctx(ctx).Error().Err(err).Msg("authorize/header-evaluator: error serializing JWT") log.Ctx(ctx).Error().Err(err).Msg("authorize/header-evaluator: error serializing JWT")
return "" return "", err
} }
return e.cachedSignedJWT
e.gotSignedJWT = true
return e.cachedSignedJWT, nil
} }
func (e *headersEvaluatorEvaluation) getDataBrokerGroupNames(ctx context.Context, groupIDs []string) []string { func (e *headersEvaluatorEvaluation) getDataBrokerGroupNames(ctx context.Context, groupIDs []string) ([]string, error) {
groupNames := make([]string, 0, len(groupIDs)) groupNames := make([]string, 0, len(groupIDs))
for _, groupID := range groupIDs { for _, groupID := range groupIDs {
dg, _ := e.evaluator.store.GetDataBrokerRecord(ctx, directory.GroupRecordType, groupID).(*structpb.Struct) msg, err := e.evaluator.store.GetDataBrokerRecord(ctx, directory.GroupRecordType, groupID)
if err != nil {
return nil, fmt.Errorf("error looking up directory group: %w", err)
}
dg, _ := msg.(*structpb.Struct)
if name, ok := getStructString(dg, "name"); ok { if name, ok := getStructString(dg, "name"); ok {
groupNames = append(groupNames, name) groupNames = append(groupNames, name)
} }
} }
return groupNames return groupNames, nil
} }
type hasGetClaims interface { type hasGetClaims interface {

View file

@ -50,11 +50,14 @@ func (a *Authorize) Check(ctx context.Context, in *envoy_service_auth_v3.CheckRe
requestID := requestid.FromHTTPHeader(hreq.Header) requestID := requestid.FromHTTPHeader(hreq.Header)
ctx = requestid.WithValue(ctx, requestID) ctx = requestid.WithValue(ctx, requestID)
sessionState, _ := state.sessionStore.LoadSessionStateAndCheckIDP(hreq) sessionState, err := state.sessionStore.LoadSessionStateAndCheckIDP(hreq)
if status.Code(err) == codes.Unavailable {
log.Ctx(ctx).Debug().Str("request-id", requestID).Err(err).Msg("temporary error checking authorization: data broker unavailable")
return nil, err
}
var s sessionOrServiceAccount var s sessionOrServiceAccount
var u *user.User var u *user.User
var err error
if sessionState != nil { if sessionState != nil {
s, err = a.getDataBrokerSessionOrServiceAccount(ctx, sessionState.ID, sessionState.DatabrokerRecordVersion) s, err = a.getDataBrokerSessionOrServiceAccount(ctx, sessionState.ID, sessionState.DatabrokerRecordVersion)
if status.Code(err) == codes.Unavailable { if status.Code(err) == codes.Unavailable {

View file

@ -145,7 +145,10 @@ func (s *Store) GetDataBrokerRecordOption() func(*rego.Rego) {
} }
span.AddAttributes(octrace.StringAttribute("record_id", recordIDOrIndex.String())) span.AddAttributes(octrace.StringAttribute("record_id", recordIDOrIndex.String()))
msg := s.GetDataBrokerRecord(ctx, string(recordType), string(recordIDOrIndex)) msg, err := s.GetDataBrokerRecord(ctx, string(recordType), string(recordIDOrIndex))
if err != nil {
return nil, rego.NewHaltError(err)
}
if msg == nil { if msg == nil {
return ast.NullTerm(), nil return ast.NullTerm(), nil
} }
@ -162,7 +165,7 @@ func (s *Store) GetDataBrokerRecordOption() func(*rego.Rego) {
}) })
} }
func (s *Store) GetDataBrokerRecord(ctx context.Context, recordType, recordIDOrIndex string) proto.Message { func (s *Store) GetDataBrokerRecord(ctx context.Context, recordType, recordIDOrIndex string) (proto.Message, error) {
req := &databroker.QueryRequest{ req := &databroker.QueryRequest{
Type: recordType, Type: recordType,
Limit: 1, Limit: 1,
@ -172,26 +175,26 @@ func (s *Store) GetDataBrokerRecord(ctx context.Context, recordType, recordIDOrI
res, err := storage.GetQuerier(ctx).Query(ctx, req) res, err := storage.GetQuerier(ctx).Query(ctx, req)
if err != nil { if err != nil {
log.Ctx(ctx).Error().Err(err).Msg("authorize/store: error retrieving record") log.Ctx(ctx).Error().Err(err).Msg("authorize/store: error retrieving record")
return nil return nil, err
} }
if len(res.GetRecords()) == 0 { if len(res.GetRecords()) == 0 {
return nil return nil, nil
} }
msg, _ := res.GetRecords()[0].GetData().UnmarshalNew() msg, _ := res.GetRecords()[0].GetData().UnmarshalNew()
if msg == nil { if msg == nil {
return nil return nil, nil
} }
// exclude expired records // exclude expired records
if hasExpiresAt, ok := msg.(interface{ GetExpiresAt() *timestamppb.Timestamp }); ok && hasExpiresAt.GetExpiresAt() != nil { if hasExpiresAt, ok := msg.(interface{ GetExpiresAt() *timestamppb.Timestamp }); ok && hasExpiresAt.GetExpiresAt() != nil {
if hasExpiresAt.GetExpiresAt().AsTime().Before(time.Now()) { if hasExpiresAt.GetExpiresAt().AsTime().Before(time.Now()) {
return nil return nil, nil
} }
} }
return msg return msg, nil
} }
func toMap(msg proto.Message) map[string]any { func toMap(msg proto.Message) map[string]any {

View file

@ -37,6 +37,7 @@ const (
ReasonUserUnauthenticated = "user-unauthenticated" // user needs to log in ReasonUserUnauthenticated = "user-unauthenticated" // user needs to log in
ReasonUserUnauthorized = "user-unauthorized" // user does not have access ReasonUserUnauthorized = "user-unauthorized" // user does not have access
ReasonValidClientCertificate = "valid-client-certificate" ReasonValidClientCertificate = "valid-client-certificate"
ReasonInternalServerError = "internal-server-error"
) )
// Reasons is a collection of reasons. // Reasons is a collection of reasons.