mirror of
https://github.com/pomerium/pomerium.git
synced 2025-08-03 08:50:42 +02:00
authorize: rewrite header evaluator to use go instead of rego (#5362)
* authorize: rewrite header evaluator to use go instead of rego * cache signed jwt * re-add missing trace * address comments
This commit is contained in:
parent
177f789e63
commit
37017e2a5b
7 changed files with 576 additions and 411 deletions
|
@ -129,10 +129,7 @@ func New(
|
|||
e.headersEvaluators = previous.headersEvaluators
|
||||
cachedPolicyEvaluators = previous.policyEvaluators
|
||||
} else {
|
||||
e.headersEvaluators, err = NewHeadersEvaluator(ctx, store)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
e.headersEvaluators = NewHeadersEvaluator(store)
|
||||
}
|
||||
e.policyEvaluators, err = getOrCreatePolicyEvaluators(ctx, cfg, store, cachedPolicyEvaluators)
|
||||
if err != nil {
|
||||
|
|
|
@ -4,14 +4,11 @@ import (
|
|||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
envoy_config_cluster_v3 "github.com/envoyproxy/go-control-plane/envoy/config/cluster/v3"
|
||||
"github.com/open-policy-agent/opa/ast"
|
||||
"github.com/open-policy-agent/opa/rego"
|
||||
"github.com/open-policy-agent/opa/types"
|
||||
|
||||
"github.com/pomerium/pomerium/authorize/evaluator/opa"
|
||||
"github.com/pomerium/pomerium/authorize/internal/store"
|
||||
"github.com/pomerium/pomerium/config"
|
||||
"github.com/pomerium/pomerium/internal/telemetry/trace"
|
||||
|
@ -69,115 +66,30 @@ type HeadersResponse struct {
|
|||
Headers http.Header
|
||||
}
|
||||
|
||||
var variableSubstitutionFunctionRegoOption = rego.Function2(®o.Function{
|
||||
Name: "pomerium.variable_substitution",
|
||||
Decl: types.NewFunction(
|
||||
types.Args(
|
||||
types.Named("input_string", types.S),
|
||||
types.Named("replacements",
|
||||
types.NewObject(nil, types.NewDynamicProperty(types.S, types.S))),
|
||||
),
|
||||
types.Named("output", types.S),
|
||||
),
|
||||
}, func(_ rego.BuiltinContext, op1 *ast.Term, op2 *ast.Term) (*ast.Term, error) {
|
||||
inputString, ok := op1.Value.(ast.String)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid input_string type: %T", op1.Value)
|
||||
}
|
||||
|
||||
replacements, ok := op2.Value.(ast.Object)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid replacements type: %T", op2.Value)
|
||||
}
|
||||
|
||||
var err error
|
||||
output := os.Expand(string(inputString), func(key string) string {
|
||||
if key == "$" {
|
||||
return "$" // allow a dollar sign to be escaped using $$
|
||||
}
|
||||
r := replacements.Get(ast.StringTerm(key))
|
||||
if r == nil {
|
||||
return ""
|
||||
}
|
||||
s, ok := r.Value.(ast.String)
|
||||
if !ok {
|
||||
err = fmt.Errorf("invalid replacement value type for key %q: %T", key, r.Value)
|
||||
}
|
||||
return string(s)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return ast.StringTerm(output), nil
|
||||
})
|
||||
|
||||
// A HeadersEvaluator evaluates the headers.rego script.
|
||||
type HeadersEvaluator struct {
|
||||
q rego.PreparedEvalQuery
|
||||
store *store.Store
|
||||
}
|
||||
|
||||
// NewHeadersEvaluator creates a new HeadersEvaluator.
|
||||
func NewHeadersEvaluator(ctx context.Context, store *store.Store, options ...func(rego *rego.Rego)) (*HeadersEvaluator, error) {
|
||||
r := rego.New(append([]func(*rego.Rego){
|
||||
rego.Store(store),
|
||||
rego.Module("pomerium.headers", opa.HeadersRego),
|
||||
rego.Query("result := data.pomerium.headers"),
|
||||
rego.EnablePrintStatements(true),
|
||||
getGoogleCloudServerlessHeadersRegoOption,
|
||||
variableSubstitutionFunctionRegoOption,
|
||||
store.GetDataBrokerRecordOption(),
|
||||
rego.SetRegoVersion(ast.RegoV1),
|
||||
}, options...)...)
|
||||
|
||||
q, err := r.PrepareForEval(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
func NewHeadersEvaluator(store *store.Store) *HeadersEvaluator {
|
||||
return &HeadersEvaluator{
|
||||
q: q,
|
||||
}, nil
|
||||
store: store,
|
||||
}
|
||||
}
|
||||
|
||||
// Evaluate evaluates the headers.rego script.
|
||||
func (e *HeadersEvaluator) Evaluate(ctx context.Context, req *HeadersRequest, options ...rego.EvalOption) (*HeadersResponse, error) {
|
||||
ctx, span := trace.StartSpan(ctx, "authorize.HeadersEvaluator.Evaluate")
|
||||
defer span.End()
|
||||
rs, err := safeEval(ctx, e.q, append([]rego.EvalOption{rego.EvalInput(req)}, options...)...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("authorize: error evaluating headers.rego: %w", err)
|
||||
}
|
||||
|
||||
if len(rs) == 0 {
|
||||
return nil, fmt.Errorf("authorize: unexpected empty result from evaluating headers.rego")
|
||||
ectx := new(rego.EvalContext)
|
||||
for _, option := range options {
|
||||
option(ectx)
|
||||
}
|
||||
|
||||
return &HeadersResponse{
|
||||
Headers: e.getHeader(rs[0].Bindings),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (e *HeadersEvaluator) getHeader(vars rego.Vars) http.Header {
|
||||
h := make(http.Header)
|
||||
|
||||
m, ok := vars["result"].(map[string]any)
|
||||
if !ok {
|
||||
return h
|
||||
}
|
||||
|
||||
m, ok = m["identity_headers"].(map[string]any)
|
||||
if !ok {
|
||||
return h
|
||||
}
|
||||
|
||||
for k := range m {
|
||||
vs, ok := m[k].([]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
for _, v := range vs {
|
||||
h.Add(k, fmt.Sprintf("%v", v))
|
||||
}
|
||||
}
|
||||
return h
|
||||
now := ectx.Time()
|
||||
if now.IsZero() {
|
||||
now = time.Now()
|
||||
}
|
||||
return newHeadersEvaluatorEvaluation(e, req, now).execute(ctx)
|
||||
}
|
||||
|
|
496
authorize/evaluator/headers_evaluator_evaluation.go
Normal file
496
authorize/evaluator/headers_evaluator_evaluation.go
Normal file
|
@ -0,0 +1,496 @@
|
|||
package evaluator
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
"reflect"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/go-jose/go-jose/v3"
|
||||
"github.com/google/uuid"
|
||||
"google.golang.org/protobuf/types/known/structpb"
|
||||
|
||||
"github.com/pomerium/datasource/pkg/directory"
|
||||
"github.com/pomerium/pomerium/internal/log"
|
||||
"github.com/pomerium/pomerium/pkg/cryptutil"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/session"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/user"
|
||||
)
|
||||
|
||||
// A headersEvaluatorEvaluation is a single evaluation of the headers evaluator.
|
||||
type headersEvaluatorEvaluation struct {
|
||||
evaluator *HeadersEvaluator
|
||||
request *HeadersRequest
|
||||
response *HeadersResponse
|
||||
now time.Time
|
||||
|
||||
gotSessionOrServiceAccount bool
|
||||
cachedSession *session.Session
|
||||
cachedServiceAccount *user.ServiceAccount
|
||||
|
||||
gotUser bool
|
||||
cachedUser *user.User
|
||||
|
||||
gotDirectoryUser bool
|
||||
cachedDirectoryUser *structpb.Struct
|
||||
|
||||
gotJWTPayloadJTI bool
|
||||
cachedJWTPayloadJTI string
|
||||
|
||||
gotJWTPayload bool
|
||||
cachedJWTPayload map[string]any
|
||||
|
||||
gotSignedJWT bool
|
||||
cachedSignedJWT string
|
||||
}
|
||||
|
||||
func newHeadersEvaluatorEvaluation(evaluator *HeadersEvaluator, request *HeadersRequest, now time.Time) *headersEvaluatorEvaluation {
|
||||
return &headersEvaluatorEvaluation{
|
||||
evaluator: evaluator,
|
||||
request: request,
|
||||
response: &HeadersResponse{Headers: make(http.Header)},
|
||||
now: now,
|
||||
}
|
||||
}
|
||||
|
||||
func (e *headersEvaluatorEvaluation) execute(ctx context.Context) (*HeadersResponse, error) {
|
||||
e.fillHeaders(ctx)
|
||||
return e.response, nil
|
||||
}
|
||||
|
||||
func (e *headersEvaluatorEvaluation) fillJWTAssertionHeader(ctx context.Context) {
|
||||
e.response.Headers.Add("x-pomerium-jwt-assertion", e.getSignedJWT(ctx))
|
||||
}
|
||||
|
||||
func (e *headersEvaluatorEvaluation) fillJWTClaimHeaders(ctx context.Context) {
|
||||
claims := e.getJWTPayload(ctx)
|
||||
for headerName, claimKey := range e.evaluator.store.GetJWTClaimHeaders() {
|
||||
claim, ok := claims[claimKey]
|
||||
if !ok {
|
||||
e.response.Headers.Add(headerName, "")
|
||||
continue
|
||||
}
|
||||
e.response.Headers.Add(headerName, getHeaderStringValue(claim))
|
||||
}
|
||||
}
|
||||
|
||||
func (e *headersEvaluatorEvaluation) fillKubernetesHeaders(ctx context.Context) {
|
||||
if e.request.KubernetesServiceAccountToken == "" {
|
||||
return
|
||||
}
|
||||
|
||||
e.response.Headers.Add("Authorization", "Bearer "+e.request.KubernetesServiceAccountToken)
|
||||
impersonateUser := e.getJWTPayloadEmail(ctx)
|
||||
if impersonateUser != "" {
|
||||
e.response.Headers.Add("Impersonate-User", impersonateUser)
|
||||
}
|
||||
impersonateGroups := e.getJWTPayloadGroups(ctx)
|
||||
if len(impersonateGroups) > 0 {
|
||||
e.response.Headers.Add("Impersonate-Group", strings.Join(impersonateGroups, ","))
|
||||
}
|
||||
}
|
||||
|
||||
func (e *headersEvaluatorEvaluation) fillGoogleCloudServerlessHeaders(ctx context.Context) {
|
||||
if e.request.EnableGoogleCloudServerlessAuthentication {
|
||||
h, err := getGoogleCloudServerlessHeaders(e.evaluator.store.GetGoogleCloudServerlessAuthenticationServiceAccount(), e.request.ToAudience)
|
||||
if err != nil {
|
||||
log.Ctx(ctx).Error().Err(err).Msg("authorize/header-evaluator: error retrieving google cloud serverless headers")
|
||||
return
|
||||
}
|
||||
for k, v := range h {
|
||||
e.response.Headers.Add(k, v)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (e *headersEvaluatorEvaluation) fillRoutingKeyHeaders() {
|
||||
if e.request.EnableRoutingKey {
|
||||
e.response.Headers.Add("x-pomerium-routing-key", cryptoSHA256(e.request.Session.ID))
|
||||
}
|
||||
}
|
||||
|
||||
func (e *headersEvaluatorEvaluation) fillSetRequestHeaders(ctx context.Context) {
|
||||
for k, v := range e.request.SetRequestHeaders {
|
||||
e.response.Headers.Add(k, os.Expand(v, func(name string) string {
|
||||
switch name {
|
||||
case "$":
|
||||
return "$"
|
||||
case "pomerium.access_token":
|
||||
s, _ := e.getSessionOrServiceAccount(ctx)
|
||||
return s.GetOauthToken().GetAccessToken()
|
||||
case "pomerium.client_cert_fingerprint":
|
||||
return e.getClientCertFingerprint()
|
||||
case "pomerium.id_token":
|
||||
s, _ := e.getSessionOrServiceAccount(ctx)
|
||||
return s.GetIdToken().GetRaw()
|
||||
case "pomerium.jwt":
|
||||
return e.getSignedJWT(ctx)
|
||||
}
|
||||
|
||||
return ""
|
||||
}))
|
||||
}
|
||||
}
|
||||
|
||||
func (e *headersEvaluatorEvaluation) fillHeaders(ctx context.Context) {
|
||||
e.fillJWTAssertionHeader(ctx)
|
||||
e.fillJWTClaimHeaders(ctx)
|
||||
e.fillKubernetesHeaders(ctx)
|
||||
e.fillGoogleCloudServerlessHeaders(ctx)
|
||||
e.fillRoutingKeyHeaders()
|
||||
e.fillSetRequestHeaders(ctx)
|
||||
}
|
||||
|
||||
func (e *headersEvaluatorEvaluation) getSessionOrServiceAccount(ctx context.Context) (*session.Session, *user.ServiceAccount) {
|
||||
if e.gotSessionOrServiceAccount {
|
||||
return e.cachedSession, e.cachedServiceAccount
|
||||
}
|
||||
|
||||
e.gotSessionOrServiceAccount = true
|
||||
if e.request.Session.ID != "" {
|
||||
e.cachedServiceAccount, _ = e.evaluator.store.GetDataBrokerRecord(ctx, "type.googleapis.com/user.ServiceAccount", e.request.Session.ID).(*user.ServiceAccount)
|
||||
}
|
||||
|
||||
if e.request.Session.ID != "" && e.cachedServiceAccount == nil {
|
||||
e.cachedSession, _ = e.evaluator.store.GetDataBrokerRecord(ctx, "type.googleapis.com/session.Session", e.request.Session.ID).(*session.Session)
|
||||
}
|
||||
if e.cachedSession != nil && e.cachedSession.GetImpersonateSessionId() != "" {
|
||||
e.cachedSession, _ = e.evaluator.store.GetDataBrokerRecord(ctx, "type.googleapis.com/session.Session", e.cachedSession.GetImpersonateSessionId()).(*session.Session)
|
||||
}
|
||||
return e.cachedSession, e.cachedServiceAccount
|
||||
}
|
||||
|
||||
func (e *headersEvaluatorEvaluation) getUser(ctx context.Context) *user.User {
|
||||
if e.gotUser {
|
||||
return e.cachedUser
|
||||
}
|
||||
|
||||
e.gotUser = true
|
||||
s, sa := e.getSessionOrServiceAccount(ctx)
|
||||
if sa != nil && sa.UserId != "" {
|
||||
e.cachedUser, _ = e.evaluator.store.GetDataBrokerRecord(ctx, "type.googleapis.com/user.User", sa.UserId).(*user.User)
|
||||
} else if s != nil && s.UserId != "" {
|
||||
e.cachedUser, _ = e.evaluator.store.GetDataBrokerRecord(ctx, "type.googleapis.com/user.User", s.UserId).(*user.User)
|
||||
}
|
||||
return e.cachedUser
|
||||
}
|
||||
|
||||
func (e *headersEvaluatorEvaluation) getClientCertFingerprint() string {
|
||||
cert, err := cryptutil.ParsePEMCertificate([]byte(e.request.ClientCertificate.Leaf))
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
return cryptoSHA256(cert.Raw)
|
||||
}
|
||||
|
||||
func (e *headersEvaluatorEvaluation) getDirectoryUser(ctx context.Context) *structpb.Struct {
|
||||
if e.gotDirectoryUser {
|
||||
return e.cachedDirectoryUser
|
||||
}
|
||||
|
||||
e.gotDirectoryUser = true
|
||||
s, sa := e.getSessionOrServiceAccount(ctx)
|
||||
if sa != nil && sa.UserId != "" {
|
||||
e.cachedDirectoryUser, _ = e.evaluator.store.GetDataBrokerRecord(ctx, directory.UserRecordType, sa.UserId).(*structpb.Struct)
|
||||
} else if s != nil && s.UserId != "" {
|
||||
e.cachedDirectoryUser, _ = e.evaluator.store.GetDataBrokerRecord(ctx, directory.UserRecordType, s.UserId).(*structpb.Struct)
|
||||
}
|
||||
return e.cachedDirectoryUser
|
||||
}
|
||||
|
||||
func (e *headersEvaluatorEvaluation) getGroupIDs(ctx context.Context) []string {
|
||||
du := e.getDirectoryUser(ctx)
|
||||
if groupIDs, ok := getStructStringSlice(du, "group_ids"); ok {
|
||||
return groupIDs
|
||||
}
|
||||
return make([]string, 0)
|
||||
}
|
||||
|
||||
func (e *headersEvaluatorEvaluation) getJWTPayloadIss() string {
|
||||
return e.request.Issuer
|
||||
}
|
||||
|
||||
func (e *headersEvaluatorEvaluation) getJWTPayloadAud() string {
|
||||
return e.request.Audience
|
||||
}
|
||||
|
||||
func (e *headersEvaluatorEvaluation) getJWTPayloadJTI() string {
|
||||
if e.gotJWTPayloadJTI {
|
||||
return e.cachedJWTPayloadJTI
|
||||
}
|
||||
|
||||
e.gotJWTPayloadJTI = true
|
||||
e.cachedJWTPayloadJTI = uuid.New().String()
|
||||
return e.cachedJWTPayloadJTI
|
||||
}
|
||||
|
||||
func (e *headersEvaluatorEvaluation) getJWTPayloadIAT() int64 {
|
||||
return e.now.Unix()
|
||||
}
|
||||
|
||||
func (e *headersEvaluatorEvaluation) getJWTPayloadExp() int64 {
|
||||
return e.now.Add(5 * time.Minute).Unix()
|
||||
}
|
||||
|
||||
func (e *headersEvaluatorEvaluation) getJWTPayloadSub(ctx context.Context) string {
|
||||
return e.getJWTPayloadUser(ctx)
|
||||
}
|
||||
|
||||
func (e *headersEvaluatorEvaluation) getJWTPayloadUser(ctx context.Context) string {
|
||||
s, sa := e.getSessionOrServiceAccount(ctx)
|
||||
if sa != nil {
|
||||
return sa.UserId
|
||||
}
|
||||
|
||||
if s != nil {
|
||||
return s.UserId
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
func (e *headersEvaluatorEvaluation) getJWTPayloadEmail(ctx context.Context) string {
|
||||
du := e.getDirectoryUser(ctx)
|
||||
if v, ok := getStructString(du, "email"); ok {
|
||||
return v
|
||||
}
|
||||
|
||||
u := e.getUser(ctx)
|
||||
if u != nil {
|
||||
return u.Email
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
func (e *headersEvaluatorEvaluation) getJWTPayloadGroups(ctx context.Context) []string {
|
||||
groupIDs := e.getGroupIDs(ctx)
|
||||
if len(groupIDs) > 0 {
|
||||
groups := make([]string, 0, len(groupIDs)*2)
|
||||
groups = append(groups, groupIDs...)
|
||||
groups = append(groups, e.getDataBrokerGroupNames(ctx, groupIDs)...)
|
||||
return groups
|
||||
}
|
||||
|
||||
s, _ := e.getSessionOrServiceAccount(ctx)
|
||||
groups, _ := getClaimStringSlice(s, "groups")
|
||||
return groups
|
||||
}
|
||||
|
||||
func (e *headersEvaluatorEvaluation) getJWTPayloadSID() string {
|
||||
return e.request.Session.ID
|
||||
}
|
||||
|
||||
func (e *headersEvaluatorEvaluation) getJWTPayloadName(ctx context.Context) string {
|
||||
s, _ := e.getSessionOrServiceAccount(ctx)
|
||||
if names, ok := getClaimStringSlice(s, "name"); ok {
|
||||
return strings.Join(names, ",")
|
||||
}
|
||||
|
||||
u := e.getUser(ctx)
|
||||
if names, ok := getClaimStringSlice(u, "name"); ok {
|
||||
return strings.Join(names, ",")
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
func (e *headersEvaluatorEvaluation) getJWTPayload(ctx context.Context) map[string]any {
|
||||
if e.gotJWTPayload {
|
||||
return e.cachedJWTPayload
|
||||
}
|
||||
|
||||
e.gotJWTPayload = true
|
||||
e.cachedJWTPayload = map[string]any{
|
||||
"iss": e.getJWTPayloadIss(),
|
||||
"aud": e.getJWTPayloadAud(),
|
||||
"jti": e.getJWTPayloadJTI(),
|
||||
"iat": e.getJWTPayloadIAT(),
|
||||
"exp": e.getJWTPayloadExp(),
|
||||
"sub": e.getJWTPayloadSub(ctx),
|
||||
"user": e.getJWTPayloadUser(ctx),
|
||||
"email": e.getJWTPayloadEmail(ctx),
|
||||
"groups": e.getJWTPayloadGroups(ctx),
|
||||
"sid": e.getJWTPayloadSID(),
|
||||
"name": e.getJWTPayloadName(ctx),
|
||||
}
|
||||
|
||||
s, _ := e.getSessionOrServiceAccount(ctx)
|
||||
u := e.getUser(ctx)
|
||||
|
||||
for _, claimKey := range e.evaluator.store.GetJWTClaimHeaders() {
|
||||
// ignore base claims
|
||||
if _, ok := e.cachedJWTPayload[claimKey]; ok {
|
||||
continue
|
||||
}
|
||||
|
||||
if vs, ok := getClaimStringSlice(s, claimKey); ok {
|
||||
e.cachedJWTPayload[claimKey] = strings.Join(vs, ",")
|
||||
} else if vs, ok := getClaimStringSlice(u, claimKey); ok {
|
||||
e.cachedJWTPayload[claimKey] = strings.Join(vs, ",")
|
||||
}
|
||||
}
|
||||
return e.cachedJWTPayload
|
||||
}
|
||||
|
||||
func (e *headersEvaluatorEvaluation) getSignedJWT(ctx context.Context) string {
|
||||
if e.gotSignedJWT {
|
||||
return e.cachedSignedJWT
|
||||
}
|
||||
|
||||
e.gotSignedJWT = true
|
||||
signingKey := e.evaluator.store.GetSigningKey()
|
||||
if signingKey == nil {
|
||||
log.Ctx(ctx).Error().Msg("authorize/header-evaluator: missing signing key")
|
||||
return ""
|
||||
}
|
||||
|
||||
signingOptions := new(jose.SignerOptions).
|
||||
WithType("JWT").
|
||||
WithHeader("kid", signingKey.KeyID).
|
||||
WithHeader("alg", signingKey.Algorithm)
|
||||
|
||||
signer, err := jose.NewSigner(jose.SigningKey{
|
||||
Algorithm: jose.SignatureAlgorithm(signingKey.Algorithm),
|
||||
Key: signingKey.Key,
|
||||
}, signingOptions)
|
||||
if err != nil {
|
||||
log.Ctx(ctx).Error().Err(err).Msg("authorize/header-evaluator: error creating JWT signer")
|
||||
return ""
|
||||
}
|
||||
|
||||
jwtPayload := e.getJWTPayload(ctx)
|
||||
bs, err := json.Marshal(jwtPayload)
|
||||
if err != nil {
|
||||
log.Ctx(ctx).Error().Err(err).Msg("authorize/header-evaluator: error marshaling JWT payload")
|
||||
return ""
|
||||
}
|
||||
|
||||
jwt, err := signer.Sign(bs)
|
||||
if err != nil {
|
||||
log.Ctx(ctx).Error().Err(err).Msg("authorize/header-evaluator: error signing JWT")
|
||||
return ""
|
||||
}
|
||||
|
||||
e.cachedSignedJWT, err = jwt.CompactSerialize()
|
||||
if err != nil {
|
||||
log.Ctx(ctx).Error().Err(err).Msg("authorize/header-evaluator: error serializing JWT")
|
||||
return ""
|
||||
}
|
||||
return e.cachedSignedJWT
|
||||
}
|
||||
|
||||
func (e *headersEvaluatorEvaluation) getDataBrokerGroupNames(ctx context.Context, groupIDs []string) []string {
|
||||
groupNames := make([]string, 0, len(groupIDs))
|
||||
for _, groupID := range groupIDs {
|
||||
dg, _ := e.evaluator.store.GetDataBrokerRecord(ctx, directory.GroupRecordType, groupID).(*structpb.Struct)
|
||||
if name, ok := getStructString(dg, "name"); ok {
|
||||
groupNames = append(groupNames, name)
|
||||
}
|
||||
}
|
||||
return groupNames
|
||||
}
|
||||
|
||||
type hasGetClaims interface {
|
||||
GetClaims() map[string]*structpb.ListValue
|
||||
}
|
||||
|
||||
func getClaimStringSlice(msg hasGetClaims, field string) ([]string, bool) {
|
||||
if msg == nil {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
claims := msg.GetClaims()
|
||||
if claims == nil {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
lv, ok := claims[field]
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
strs := make([]string, 0, len(lv.Values))
|
||||
for _, v := range lv.Values {
|
||||
switch v := v.GetKind().(type) {
|
||||
case *structpb.Value_NumberValue:
|
||||
strs = append(strs, fmt.Sprint(v.NumberValue))
|
||||
case *structpb.Value_StringValue:
|
||||
strs = append(strs, v.StringValue)
|
||||
case *structpb.Value_BoolValue:
|
||||
strs = append(strs, fmt.Sprint(v.BoolValue))
|
||||
|
||||
// just ignore these types
|
||||
case *structpb.Value_NullValue:
|
||||
case *structpb.Value_StructValue:
|
||||
case *structpb.Value_ListValue:
|
||||
}
|
||||
}
|
||||
return strs, true
|
||||
}
|
||||
|
||||
func getStructString(s *structpb.Struct, field string) (string, bool) {
|
||||
if s == nil || s.Fields == nil {
|
||||
return "", false
|
||||
}
|
||||
|
||||
v, ok := s.Fields[field]
|
||||
if !ok {
|
||||
return "", false
|
||||
}
|
||||
|
||||
return fmt.Sprint(v.AsInterface()), true
|
||||
}
|
||||
|
||||
func getStructStringSlice(s *structpb.Struct, field string) ([]string, bool) {
|
||||
if s == nil || s.Fields == nil {
|
||||
return nil, false
|
||||
}
|
||||
v, ok := s.Fields[field]
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
lv := v.GetListValue()
|
||||
if lv == nil {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
strs := make([]string, len(lv.Values))
|
||||
for i, vv := range lv.Values {
|
||||
sv, ok := vv.Kind.(*structpb.Value_StringValue)
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
strs[i] = sv.StringValue
|
||||
}
|
||||
return strs, true
|
||||
}
|
||||
|
||||
func cryptoSHA256[T string | []byte](input T) string {
|
||||
output := sha256.Sum256([]byte(input))
|
||||
return hex.EncodeToString(output[:])
|
||||
}
|
||||
|
||||
func getHeaderStringValue(obj any) string {
|
||||
v := reflect.ValueOf(obj)
|
||||
switch v.Kind() {
|
||||
case reflect.Slice:
|
||||
var str strings.Builder
|
||||
for i := 0; i < v.Len(); i++ {
|
||||
if i > 0 {
|
||||
str.WriteByte(',')
|
||||
}
|
||||
str.WriteString(getHeaderStringValue(v.Index(i).Interface()))
|
||||
}
|
||||
return str.String()
|
||||
}
|
||||
|
||||
return fmt.Sprint(obj)
|
||||
}
|
|
@ -60,8 +60,7 @@ func BenchmarkHeadersEvaluator(b *testing.B) {
|
|||
s.UpdateJWTClaimHeaders(config.NewJWTClaimHeaders("email", "groups", "user", "CUSTOM_KEY"))
|
||||
s.UpdateSigningKey(privateJWK)
|
||||
|
||||
e, err := NewHeadersEvaluator(ctx, s, rego.Time(iat))
|
||||
require.NoError(b, err)
|
||||
e := NewHeadersEvaluator(s)
|
||||
|
||||
req := &HeadersRequest{
|
||||
EnableRoutingKey: true,
|
||||
|
@ -198,14 +197,13 @@ func TestHeadersEvaluator(t *testing.T) {
|
|||
|
||||
iat := time.Unix(1686870680, 0)
|
||||
|
||||
eval := func(t *testing.T, data []proto.Message, input *HeadersRequest) (*HeadersResponse, error) {
|
||||
eval := func(_ *testing.T, data []proto.Message, input *HeadersRequest) (*HeadersResponse, error) {
|
||||
ctx := context.Background()
|
||||
ctx = storage.WithQuerier(ctx, storage.NewStaticQuerier(data...))
|
||||
store := store.New()
|
||||
store.UpdateJWTClaimHeaders(config.NewJWTClaimHeaders("name", "email", "groups", "user", "CUSTOM_KEY"))
|
||||
store.UpdateSigningKey(privateJWK)
|
||||
e, err := NewHeadersEvaluator(ctx, store, rego.Time(iat))
|
||||
require.NoError(t, err)
|
||||
e := NewHeadersEvaluator(store)
|
||||
return e.Evaluate(ctx, input, rego.EvalTime(iat))
|
||||
}
|
||||
|
||||
|
|
|
@ -1,10 +0,0 @@
|
|||
// Package opa implements the policy evaluator interface to make authorization
|
||||
// decisions.
|
||||
package opa
|
||||
|
||||
import _ "embed" // to embed files
|
||||
|
||||
// HeadersRego is the headers.rego script.
|
||||
//
|
||||
//go:embed policy/headers.rego
|
||||
var HeadersRego string
|
|
@ -1,265 +0,0 @@
|
|||
package pomerium.headers
|
||||
|
||||
import rego.v1
|
||||
|
||||
# input:
|
||||
# enable_google_cloud_serverless_authentication: boolean
|
||||
# enable_routing_key: boolean
|
||||
# client_certificate:
|
||||
# leaf: string
|
||||
# issuer: string
|
||||
# kubernetes_service_account_token: string
|
||||
# session:
|
||||
# id: string
|
||||
# to_audience: string
|
||||
# set_request_headers: map[string]string
|
||||
#
|
||||
# data:
|
||||
# jwt_claim_headers: map[string]string
|
||||
# signing_key:
|
||||
# alg: string
|
||||
# kid: string
|
||||
#
|
||||
# functions:
|
||||
# get_databroker_record
|
||||
# get_google_cloud_serverless_headers
|
||||
#
|
||||
#
|
||||
# output:
|
||||
# identity_headers: map[string][]string
|
||||
|
||||
now_s := round(time.now_ns() / 1e9)
|
||||
|
||||
# get the session
|
||||
session := v if {
|
||||
# try a service account
|
||||
v = get_databroker_record("type.googleapis.com/user.ServiceAccount", input.session.id)
|
||||
v != null
|
||||
} else := iv if {
|
||||
# try an impersonated session
|
||||
v = get_databroker_record("type.googleapis.com/session.Session", input.session.id)
|
||||
v != null
|
||||
object.get(v, "impersonate_session_id", "") != ""
|
||||
|
||||
iv = get_databroker_record("type.googleapis.com/session.Session", v.impersonate_session_id)
|
||||
iv != null
|
||||
} else := v if {
|
||||
# try a normal session
|
||||
v = get_databroker_record("type.googleapis.com/session.Session", input.session.id)
|
||||
v != null
|
||||
object.get(v, "impersonate_session_id", "") == ""
|
||||
} else := {}
|
||||
|
||||
user := u if {
|
||||
u = get_databroker_record("type.googleapis.com/user.User", session.user_id)
|
||||
u != null
|
||||
} else := {}
|
||||
|
||||
directory_user := du if {
|
||||
du = get_databroker_record("pomerium.io/DirectoryUser", session.user_id)
|
||||
du != null
|
||||
} else := {}
|
||||
|
||||
group_ids := gs if {
|
||||
gs = directory_user.group_ids
|
||||
gs != null
|
||||
} else := []
|
||||
|
||||
groups := array.concat(group_ids, array.concat(get_databroker_group_names(group_ids), get_databroker_group_emails(group_ids)))
|
||||
|
||||
jwt_headers := {
|
||||
"typ": "JWT",
|
||||
"alg": data.signing_key.alg,
|
||||
"kid": data.signing_key.kid,
|
||||
}
|
||||
|
||||
jwt_payload_aud := v if {
|
||||
v := input.audience
|
||||
} else := ""
|
||||
|
||||
jwt_payload_iss := v if {
|
||||
v := input.issuer
|
||||
} else := ""
|
||||
|
||||
jwt_payload_jti := uuid.rfc4122("jti")
|
||||
|
||||
jwt_payload_iat := now_s
|
||||
|
||||
jwt_payload_exp := now_s + (5*60) # 5 minutes from now
|
||||
|
||||
jwt_payload_sub := v if {
|
||||
v = session.user_id
|
||||
} else := ""
|
||||
|
||||
jwt_payload_user := v if {
|
||||
v = session.user_id
|
||||
} else := ""
|
||||
|
||||
jwt_payload_email := v if {
|
||||
v = directory_user.email
|
||||
} else := v if {
|
||||
v = user.email
|
||||
} else := ""
|
||||
|
||||
jwt_payload_groups := v if {
|
||||
v = array.concat(group_ids, get_databroker_group_names(group_ids))
|
||||
v != []
|
||||
} else := v if {
|
||||
v = session.claims.groups
|
||||
v != null
|
||||
} else := []
|
||||
|
||||
jwt_payload_name := v if {
|
||||
v = get_header_string_value(session.claims.name)
|
||||
} else := v if {
|
||||
v = get_header_string_value(user.claims.name)
|
||||
} else := ""
|
||||
|
||||
# the session id is always set to the input session id, even if impersonating
|
||||
jwt_payload_sid := input.session.id
|
||||
|
||||
base_jwt_claims := [
|
||||
["iss", jwt_payload_iss],
|
||||
["aud", jwt_payload_aud],
|
||||
["jti", jwt_payload_jti],
|
||||
["iat", jwt_payload_iat],
|
||||
["exp", jwt_payload_exp],
|
||||
["sub", jwt_payload_sub],
|
||||
["user", jwt_payload_user],
|
||||
["email", jwt_payload_email],
|
||||
["groups", jwt_payload_groups],
|
||||
["sid", jwt_payload_sid],
|
||||
["name", jwt_payload_name],
|
||||
]
|
||||
|
||||
additional_jwt_claims := [[k, v] |
|
||||
some header_name
|
||||
claim_key := data.jwt_claim_headers[header_name]
|
||||
|
||||
# exclude base_jwt_claims
|
||||
count([1 |
|
||||
[xk, xv] := base_jwt_claims[_]
|
||||
xk == claim_key
|
||||
]) == 0
|
||||
|
||||
# the claim value can come from session claims or user claims
|
||||
claim_value := object.get(session.claims, claim_key, object.get(user.claims, claim_key, null))
|
||||
|
||||
k := claim_key
|
||||
v := get_header_string_value(claim_value)
|
||||
]
|
||||
|
||||
jwt_claims := array.concat(base_jwt_claims, additional_jwt_claims)
|
||||
|
||||
jwt_payload := {key: value |
|
||||
# use a comprehension over an array to remove nil values
|
||||
[key, value] := jwt_claims[_]
|
||||
value != null
|
||||
}
|
||||
|
||||
signed_jwt := io.jwt.encode_sign(jwt_headers, jwt_payload, data.signing_key)
|
||||
|
||||
kubernetes_headers := h if {
|
||||
input.kubernetes_service_account_token != ""
|
||||
|
||||
h := remove_empty_header_values([
|
||||
["Authorization", concat(" ", ["Bearer", input.kubernetes_service_account_token])],
|
||||
["Impersonate-User", jwt_payload_email],
|
||||
["Impersonate-Group", get_header_string_value(jwt_payload_groups)],
|
||||
])
|
||||
} else := []
|
||||
|
||||
google_cloud_serverless_authentication_service_account := s if {
|
||||
s := data.google_cloud_serverless_authentication_service_account
|
||||
} else := ""
|
||||
|
||||
google_cloud_serverless_headers := h if {
|
||||
input.enable_google_cloud_serverless_authentication
|
||||
h := get_google_cloud_serverless_headers(google_cloud_serverless_authentication_service_account, input.to_audience)
|
||||
} else := {}
|
||||
|
||||
routing_key_headers := h if {
|
||||
input.enable_routing_key
|
||||
h := [["x-pomerium-routing-key", crypto.sha256(input.session.id)]]
|
||||
} else := []
|
||||
|
||||
session_id_token := v if {
|
||||
v := session.id_token.raw
|
||||
} else := ""
|
||||
|
||||
session_access_token := v if {
|
||||
v := session.oauth_token.access_token
|
||||
} else := ""
|
||||
|
||||
client_cert_fingerprint := v if {
|
||||
cert := crypto.x509.parse_certificates(trim_space(input.client_certificate.leaf))[0]
|
||||
v := crypto.sha256(base64.decode(cert.Raw))
|
||||
} else := ""
|
||||
|
||||
set_request_headers := h if {
|
||||
replacements := {
|
||||
"pomerium.id_token": session_id_token,
|
||||
"pomerium.access_token": session_access_token,
|
||||
"pomerium.client_cert_fingerprint": client_cert_fingerprint,
|
||||
"pomerium.jwt": signed_jwt,
|
||||
}
|
||||
h := [[header_name, header_value] |
|
||||
some header_name
|
||||
v := input.set_request_headers[header_name]
|
||||
header_value := pomerium.variable_substitution(v, replacements)
|
||||
]
|
||||
} else := []
|
||||
|
||||
identity_headers := {key: values |
|
||||
h1 := [["x-pomerium-jwt-assertion", signed_jwt]]
|
||||
h2 := [[header_name, header_value] |
|
||||
some header_name
|
||||
k := data.jwt_claim_headers[header_name]
|
||||
raw_header_value := array.concat(
|
||||
[cv |
|
||||
[ck, cv] := jwt_claims[_]
|
||||
ck == k
|
||||
],
|
||||
[""],
|
||||
)[0]
|
||||
|
||||
header_value := get_header_string_value(raw_header_value)
|
||||
]
|
||||
|
||||
h3 := kubernetes_headers
|
||||
h4 := [[k, v] | v := google_cloud_serverless_headers[k]]
|
||||
h5 := routing_key_headers
|
||||
h6 := set_request_headers
|
||||
|
||||
h := array.concat(array.concat(array.concat(array.concat(array.concat(h1, h2), h3), h4), h5), h6)
|
||||
|
||||
some i
|
||||
[key, v1] := h[i]
|
||||
values := [v2 |
|
||||
some j
|
||||
[k2, v2] := h[j]
|
||||
key == k2
|
||||
]
|
||||
}
|
||||
|
||||
get_databroker_group_names(ids) := gs if {
|
||||
gs := [name | id := ids[i]; group := get_databroker_record("pomerium.io/DirectoryGroup", id); name := group.name]
|
||||
}
|
||||
|
||||
get_databroker_group_emails(ids) := gs if {
|
||||
gs := [email | id := ids[i]; group := get_databroker_record("pomerium.io/DirectoryGroup", id); email := group.email]
|
||||
}
|
||||
|
||||
get_header_string_value(obj) := s if {
|
||||
is_array(obj)
|
||||
s := concat(",", obj)
|
||||
} else := s if {
|
||||
s := concat(",", [obj])
|
||||
}
|
||||
|
||||
remove_empty_header_values(arr) := [[k, v] |
|
||||
some idx
|
||||
k := arr[idx][0]
|
||||
v := arr[idx][1]
|
||||
v != ""
|
||||
]
|
|
@ -5,6 +5,7 @@ import (
|
|||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/go-jose/go-jose/v3"
|
||||
|
@ -27,6 +28,10 @@ import (
|
|||
// A Store stores data for the OPA rego policy evaluation.
|
||||
type Store struct {
|
||||
opastorage.Store
|
||||
|
||||
googleCloudServerlessAuthenticationServiceAccount atomic.Pointer[string]
|
||||
jwtClaimHeaders atomic.Pointer[map[string]string]
|
||||
signingKey atomic.Pointer[jose.JSONWebKey]
|
||||
}
|
||||
|
||||
// New creates a new Store.
|
||||
|
@ -36,15 +41,37 @@ func New() *Store {
|
|||
}
|
||||
}
|
||||
|
||||
func (s *Store) GetGoogleCloudServerlessAuthenticationServiceAccount() string {
|
||||
v := s.googleCloudServerlessAuthenticationServiceAccount.Load()
|
||||
if v == nil {
|
||||
return ""
|
||||
}
|
||||
return *v
|
||||
}
|
||||
|
||||
func (s *Store) GetJWTClaimHeaders() map[string]string {
|
||||
m := s.jwtClaimHeaders.Load()
|
||||
if m == nil {
|
||||
return nil
|
||||
}
|
||||
return *m
|
||||
}
|
||||
|
||||
func (s *Store) GetSigningKey() *jose.JSONWebKey {
|
||||
return s.signingKey.Load()
|
||||
}
|
||||
|
||||
// UpdateGoogleCloudServerlessAuthenticationServiceAccount updates the google cloud serverless authentication
|
||||
// service account in the store.
|
||||
func (s *Store) UpdateGoogleCloudServerlessAuthenticationServiceAccount(serviceAccount string) {
|
||||
s.write("/google_cloud_serverless_authentication_service_account", serviceAccount)
|
||||
s.googleCloudServerlessAuthenticationServiceAccount.Store(&serviceAccount)
|
||||
}
|
||||
|
||||
// UpdateJWTClaimHeaders updates the jwt claim headers in the store.
|
||||
func (s *Store) UpdateJWTClaimHeaders(jwtClaimHeaders map[string]string) {
|
||||
s.write("/jwt_claim_headers", jwtClaimHeaders)
|
||||
s.jwtClaimHeaders.Store(&jwtClaimHeaders)
|
||||
}
|
||||
|
||||
// UpdateRoutePolicies updates the route policies in the store.
|
||||
|
@ -56,6 +83,7 @@ func (s *Store) UpdateRoutePolicies(routePolicies []*config.Policy) {
|
|||
// in rego use JWKs, so we take in that format.
|
||||
func (s *Store) UpdateSigningKey(signingKey *jose.JSONWebKey) {
|
||||
s.write("/signing_key", signingKey)
|
||||
s.signingKey.Store(signingKey)
|
||||
}
|
||||
|
||||
func (s *Store) write(rawPath string, value any) {
|
||||
|
@ -111,40 +139,17 @@ func (s *Store) GetDataBrokerRecordOption() func(*rego.Rego) {
|
|||
}
|
||||
span.AddAttributes(octrace.StringAttribute("record_type", recordType.String()))
|
||||
|
||||
value, ok := op2.Value.(ast.String)
|
||||
recordIDOrIndex, ok := op2.Value.(ast.String)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid record id: %T", op2)
|
||||
}
|
||||
span.AddAttributes(octrace.StringAttribute("record_id", value.String()))
|
||||
span.AddAttributes(octrace.StringAttribute("record_id", recordIDOrIndex.String()))
|
||||
|
||||
req := &databroker.QueryRequest{
|
||||
Type: string(recordType),
|
||||
Limit: 1,
|
||||
}
|
||||
req.SetFilterByIDOrIndex(string(value))
|
||||
|
||||
res, err := storage.GetQuerier(ctx).Query(ctx, req)
|
||||
if err != nil {
|
||||
log.Ctx(ctx).Error().Err(err).Msg("authorize/store: error retrieving record")
|
||||
return ast.NullTerm(), nil
|
||||
}
|
||||
|
||||
if len(res.GetRecords()) == 0 {
|
||||
return ast.NullTerm(), nil
|
||||
}
|
||||
|
||||
msg, _ := res.GetRecords()[0].GetData().UnmarshalNew()
|
||||
msg := s.GetDataBrokerRecord(ctx, string(recordType), string(recordIDOrIndex))
|
||||
if msg == nil {
|
||||
return ast.NullTerm(), nil
|
||||
}
|
||||
|
||||
// exclude expired records
|
||||
if hasExpiresAt, ok := msg.(interface{ GetExpiresAt() *timestamppb.Timestamp }); ok && hasExpiresAt.GetExpiresAt() != nil {
|
||||
if hasExpiresAt.GetExpiresAt().AsTime().Before(time.Now()) {
|
||||
return ast.NullTerm(), nil
|
||||
}
|
||||
}
|
||||
|
||||
obj := toMap(msg)
|
||||
|
||||
regoValue, err := ast.InterfaceToValue(obj)
|
||||
|
@ -157,6 +162,38 @@ func (s *Store) GetDataBrokerRecordOption() func(*rego.Rego) {
|
|||
})
|
||||
}
|
||||
|
||||
func (s *Store) GetDataBrokerRecord(ctx context.Context, recordType, recordIDOrIndex string) proto.Message {
|
||||
req := &databroker.QueryRequest{
|
||||
Type: recordType,
|
||||
Limit: 1,
|
||||
}
|
||||
req.SetFilterByIDOrIndex(recordIDOrIndex)
|
||||
|
||||
res, err := storage.GetQuerier(ctx).Query(ctx, req)
|
||||
if err != nil {
|
||||
log.Ctx(ctx).Error().Err(err).Msg("authorize/store: error retrieving record")
|
||||
return nil
|
||||
}
|
||||
|
||||
if len(res.GetRecords()) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
msg, _ := res.GetRecords()[0].GetData().UnmarshalNew()
|
||||
if msg == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// exclude expired records
|
||||
if hasExpiresAt, ok := msg.(interface{ GetExpiresAt() *timestamppb.Timestamp }); ok && hasExpiresAt.GetExpiresAt() != nil {
|
||||
if hasExpiresAt.GetExpiresAt().AsTime().Before(time.Now()) {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
return msg
|
||||
}
|
||||
|
||||
func toMap(msg proto.Message) map[string]any {
|
||||
bs, _ := json.Marshal(msg)
|
||||
var obj map[string]any
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue