authorize: rewrite header evaluator to use go instead of rego (#5362)

* authorize: rewrite header evaluator to use go instead of rego

* cache signed jwt

* re-add missing trace

* address comments
This commit is contained in:
Caleb Doxsey 2024-11-07 13:07:16 -07:00 committed by GitHub
parent 177f789e63
commit 37017e2a5b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 576 additions and 411 deletions

View file

@ -129,10 +129,7 @@ func New(
e.headersEvaluators = previous.headersEvaluators
cachedPolicyEvaluators = previous.policyEvaluators
} else {
e.headersEvaluators, err = NewHeadersEvaluator(ctx, store)
if err != nil {
return nil, err
}
e.headersEvaluators = NewHeadersEvaluator(store)
}
e.policyEvaluators, err = getOrCreatePolicyEvaluators(ctx, cfg, store, cachedPolicyEvaluators)
if err != nil {

View file

@ -4,14 +4,11 @@ import (
"context"
"fmt"
"net/http"
"os"
"time"
envoy_config_cluster_v3 "github.com/envoyproxy/go-control-plane/envoy/config/cluster/v3"
"github.com/open-policy-agent/opa/ast"
"github.com/open-policy-agent/opa/rego"
"github.com/open-policy-agent/opa/types"
"github.com/pomerium/pomerium/authorize/evaluator/opa"
"github.com/pomerium/pomerium/authorize/internal/store"
"github.com/pomerium/pomerium/config"
"github.com/pomerium/pomerium/internal/telemetry/trace"
@ -69,115 +66,30 @@ type HeadersResponse struct {
Headers http.Header
}
var variableSubstitutionFunctionRegoOption = rego.Function2(&rego.Function{
Name: "pomerium.variable_substitution",
Decl: types.NewFunction(
types.Args(
types.Named("input_string", types.S),
types.Named("replacements",
types.NewObject(nil, types.NewDynamicProperty(types.S, types.S))),
),
types.Named("output", types.S),
),
}, func(_ rego.BuiltinContext, op1 *ast.Term, op2 *ast.Term) (*ast.Term, error) {
inputString, ok := op1.Value.(ast.String)
if !ok {
return nil, fmt.Errorf("invalid input_string type: %T", op1.Value)
}
replacements, ok := op2.Value.(ast.Object)
if !ok {
return nil, fmt.Errorf("invalid replacements type: %T", op2.Value)
}
var err error
output := os.Expand(string(inputString), func(key string) string {
if key == "$" {
return "$" // allow a dollar sign to be escaped using $$
}
r := replacements.Get(ast.StringTerm(key))
if r == nil {
return ""
}
s, ok := r.Value.(ast.String)
if !ok {
err = fmt.Errorf("invalid replacement value type for key %q: %T", key, r.Value)
}
return string(s)
})
if err != nil {
return nil, err
}
return ast.StringTerm(output), nil
})
// A HeadersEvaluator evaluates the headers.rego script.
type HeadersEvaluator struct {
q rego.PreparedEvalQuery
store *store.Store
}
// NewHeadersEvaluator creates a new HeadersEvaluator.
func NewHeadersEvaluator(ctx context.Context, store *store.Store, options ...func(rego *rego.Rego)) (*HeadersEvaluator, error) {
r := rego.New(append([]func(*rego.Rego){
rego.Store(store),
rego.Module("pomerium.headers", opa.HeadersRego),
rego.Query("result := data.pomerium.headers"),
rego.EnablePrintStatements(true),
getGoogleCloudServerlessHeadersRegoOption,
variableSubstitutionFunctionRegoOption,
store.GetDataBrokerRecordOption(),
rego.SetRegoVersion(ast.RegoV1),
}, options...)...)
q, err := r.PrepareForEval(ctx)
if err != nil {
return nil, err
}
func NewHeadersEvaluator(store *store.Store) *HeadersEvaluator {
return &HeadersEvaluator{
q: q,
}, nil
store: store,
}
}
// Evaluate evaluates the headers.rego script.
func (e *HeadersEvaluator) Evaluate(ctx context.Context, req *HeadersRequest, options ...rego.EvalOption) (*HeadersResponse, error) {
ctx, span := trace.StartSpan(ctx, "authorize.HeadersEvaluator.Evaluate")
defer span.End()
rs, err := safeEval(ctx, e.q, append([]rego.EvalOption{rego.EvalInput(req)}, options...)...)
if err != nil {
return nil, fmt.Errorf("authorize: error evaluating headers.rego: %w", err)
}
if len(rs) == 0 {
return nil, fmt.Errorf("authorize: unexpected empty result from evaluating headers.rego")
ectx := new(rego.EvalContext)
for _, option := range options {
option(ectx)
}
return &HeadersResponse{
Headers: e.getHeader(rs[0].Bindings),
}, nil
}
func (e *HeadersEvaluator) getHeader(vars rego.Vars) http.Header {
h := make(http.Header)
m, ok := vars["result"].(map[string]any)
if !ok {
return h
}
m, ok = m["identity_headers"].(map[string]any)
if !ok {
return h
}
for k := range m {
vs, ok := m[k].([]any)
if !ok {
continue
}
for _, v := range vs {
h.Add(k, fmt.Sprintf("%v", v))
}
}
return h
now := ectx.Time()
if now.IsZero() {
now = time.Now()
}
return newHeadersEvaluatorEvaluation(e, req, now).execute(ctx)
}

View file

@ -0,0 +1,496 @@
package evaluator
import (
"context"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"fmt"
"net/http"
"os"
"reflect"
"strings"
"time"
"github.com/go-jose/go-jose/v3"
"github.com/google/uuid"
"google.golang.org/protobuf/types/known/structpb"
"github.com/pomerium/datasource/pkg/directory"
"github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/pkg/cryptutil"
"github.com/pomerium/pomerium/pkg/grpc/session"
"github.com/pomerium/pomerium/pkg/grpc/user"
)
// A headersEvaluatorEvaluation is a single evaluation of the headers evaluator.
type headersEvaluatorEvaluation struct {
evaluator *HeadersEvaluator
request *HeadersRequest
response *HeadersResponse
now time.Time
gotSessionOrServiceAccount bool
cachedSession *session.Session
cachedServiceAccount *user.ServiceAccount
gotUser bool
cachedUser *user.User
gotDirectoryUser bool
cachedDirectoryUser *structpb.Struct
gotJWTPayloadJTI bool
cachedJWTPayloadJTI string
gotJWTPayload bool
cachedJWTPayload map[string]any
gotSignedJWT bool
cachedSignedJWT string
}
func newHeadersEvaluatorEvaluation(evaluator *HeadersEvaluator, request *HeadersRequest, now time.Time) *headersEvaluatorEvaluation {
return &headersEvaluatorEvaluation{
evaluator: evaluator,
request: request,
response: &HeadersResponse{Headers: make(http.Header)},
now: now,
}
}
func (e *headersEvaluatorEvaluation) execute(ctx context.Context) (*HeadersResponse, error) {
e.fillHeaders(ctx)
return e.response, nil
}
func (e *headersEvaluatorEvaluation) fillJWTAssertionHeader(ctx context.Context) {
e.response.Headers.Add("x-pomerium-jwt-assertion", e.getSignedJWT(ctx))
}
func (e *headersEvaluatorEvaluation) fillJWTClaimHeaders(ctx context.Context) {
claims := e.getJWTPayload(ctx)
for headerName, claimKey := range e.evaluator.store.GetJWTClaimHeaders() {
claim, ok := claims[claimKey]
if !ok {
e.response.Headers.Add(headerName, "")
continue
}
e.response.Headers.Add(headerName, getHeaderStringValue(claim))
}
}
func (e *headersEvaluatorEvaluation) fillKubernetesHeaders(ctx context.Context) {
if e.request.KubernetesServiceAccountToken == "" {
return
}
e.response.Headers.Add("Authorization", "Bearer "+e.request.KubernetesServiceAccountToken)
impersonateUser := e.getJWTPayloadEmail(ctx)
if impersonateUser != "" {
e.response.Headers.Add("Impersonate-User", impersonateUser)
}
impersonateGroups := e.getJWTPayloadGroups(ctx)
if len(impersonateGroups) > 0 {
e.response.Headers.Add("Impersonate-Group", strings.Join(impersonateGroups, ","))
}
}
func (e *headersEvaluatorEvaluation) fillGoogleCloudServerlessHeaders(ctx context.Context) {
if e.request.EnableGoogleCloudServerlessAuthentication {
h, err := getGoogleCloudServerlessHeaders(e.evaluator.store.GetGoogleCloudServerlessAuthenticationServiceAccount(), e.request.ToAudience)
if err != nil {
log.Ctx(ctx).Error().Err(err).Msg("authorize/header-evaluator: error retrieving google cloud serverless headers")
return
}
for k, v := range h {
e.response.Headers.Add(k, v)
}
}
}
func (e *headersEvaluatorEvaluation) fillRoutingKeyHeaders() {
if e.request.EnableRoutingKey {
e.response.Headers.Add("x-pomerium-routing-key", cryptoSHA256(e.request.Session.ID))
}
}
func (e *headersEvaluatorEvaluation) fillSetRequestHeaders(ctx context.Context) {
for k, v := range e.request.SetRequestHeaders {
e.response.Headers.Add(k, os.Expand(v, func(name string) string {
switch name {
case "$":
return "$"
case "pomerium.access_token":
s, _ := e.getSessionOrServiceAccount(ctx)
return s.GetOauthToken().GetAccessToken()
case "pomerium.client_cert_fingerprint":
return e.getClientCertFingerprint()
case "pomerium.id_token":
s, _ := e.getSessionOrServiceAccount(ctx)
return s.GetIdToken().GetRaw()
case "pomerium.jwt":
return e.getSignedJWT(ctx)
}
return ""
}))
}
}
func (e *headersEvaluatorEvaluation) fillHeaders(ctx context.Context) {
e.fillJWTAssertionHeader(ctx)
e.fillJWTClaimHeaders(ctx)
e.fillKubernetesHeaders(ctx)
e.fillGoogleCloudServerlessHeaders(ctx)
e.fillRoutingKeyHeaders()
e.fillSetRequestHeaders(ctx)
}
func (e *headersEvaluatorEvaluation) getSessionOrServiceAccount(ctx context.Context) (*session.Session, *user.ServiceAccount) {
if e.gotSessionOrServiceAccount {
return e.cachedSession, e.cachedServiceAccount
}
e.gotSessionOrServiceAccount = true
if e.request.Session.ID != "" {
e.cachedServiceAccount, _ = e.evaluator.store.GetDataBrokerRecord(ctx, "type.googleapis.com/user.ServiceAccount", e.request.Session.ID).(*user.ServiceAccount)
}
if e.request.Session.ID != "" && e.cachedServiceAccount == nil {
e.cachedSession, _ = e.evaluator.store.GetDataBrokerRecord(ctx, "type.googleapis.com/session.Session", e.request.Session.ID).(*session.Session)
}
if e.cachedSession != nil && e.cachedSession.GetImpersonateSessionId() != "" {
e.cachedSession, _ = e.evaluator.store.GetDataBrokerRecord(ctx, "type.googleapis.com/session.Session", e.cachedSession.GetImpersonateSessionId()).(*session.Session)
}
return e.cachedSession, e.cachedServiceAccount
}
func (e *headersEvaluatorEvaluation) getUser(ctx context.Context) *user.User {
if e.gotUser {
return e.cachedUser
}
e.gotUser = true
s, sa := e.getSessionOrServiceAccount(ctx)
if sa != nil && sa.UserId != "" {
e.cachedUser, _ = e.evaluator.store.GetDataBrokerRecord(ctx, "type.googleapis.com/user.User", sa.UserId).(*user.User)
} else if s != nil && s.UserId != "" {
e.cachedUser, _ = e.evaluator.store.GetDataBrokerRecord(ctx, "type.googleapis.com/user.User", s.UserId).(*user.User)
}
return e.cachedUser
}
func (e *headersEvaluatorEvaluation) getClientCertFingerprint() string {
cert, err := cryptutil.ParsePEMCertificate([]byte(e.request.ClientCertificate.Leaf))
if err != nil {
return ""
}
return cryptoSHA256(cert.Raw)
}
func (e *headersEvaluatorEvaluation) getDirectoryUser(ctx context.Context) *structpb.Struct {
if e.gotDirectoryUser {
return e.cachedDirectoryUser
}
e.gotDirectoryUser = true
s, sa := e.getSessionOrServiceAccount(ctx)
if sa != nil && sa.UserId != "" {
e.cachedDirectoryUser, _ = e.evaluator.store.GetDataBrokerRecord(ctx, directory.UserRecordType, sa.UserId).(*structpb.Struct)
} else if s != nil && s.UserId != "" {
e.cachedDirectoryUser, _ = e.evaluator.store.GetDataBrokerRecord(ctx, directory.UserRecordType, s.UserId).(*structpb.Struct)
}
return e.cachedDirectoryUser
}
func (e *headersEvaluatorEvaluation) getGroupIDs(ctx context.Context) []string {
du := e.getDirectoryUser(ctx)
if groupIDs, ok := getStructStringSlice(du, "group_ids"); ok {
return groupIDs
}
return make([]string, 0)
}
func (e *headersEvaluatorEvaluation) getJWTPayloadIss() string {
return e.request.Issuer
}
func (e *headersEvaluatorEvaluation) getJWTPayloadAud() string {
return e.request.Audience
}
func (e *headersEvaluatorEvaluation) getJWTPayloadJTI() string {
if e.gotJWTPayloadJTI {
return e.cachedJWTPayloadJTI
}
e.gotJWTPayloadJTI = true
e.cachedJWTPayloadJTI = uuid.New().String()
return e.cachedJWTPayloadJTI
}
func (e *headersEvaluatorEvaluation) getJWTPayloadIAT() int64 {
return e.now.Unix()
}
func (e *headersEvaluatorEvaluation) getJWTPayloadExp() int64 {
return e.now.Add(5 * time.Minute).Unix()
}
func (e *headersEvaluatorEvaluation) getJWTPayloadSub(ctx context.Context) string {
return e.getJWTPayloadUser(ctx)
}
func (e *headersEvaluatorEvaluation) getJWTPayloadUser(ctx context.Context) string {
s, sa := e.getSessionOrServiceAccount(ctx)
if sa != nil {
return sa.UserId
}
if s != nil {
return s.UserId
}
return ""
}
func (e *headersEvaluatorEvaluation) getJWTPayloadEmail(ctx context.Context) string {
du := e.getDirectoryUser(ctx)
if v, ok := getStructString(du, "email"); ok {
return v
}
u := e.getUser(ctx)
if u != nil {
return u.Email
}
return ""
}
func (e *headersEvaluatorEvaluation) getJWTPayloadGroups(ctx context.Context) []string {
groupIDs := e.getGroupIDs(ctx)
if len(groupIDs) > 0 {
groups := make([]string, 0, len(groupIDs)*2)
groups = append(groups, groupIDs...)
groups = append(groups, e.getDataBrokerGroupNames(ctx, groupIDs)...)
return groups
}
s, _ := e.getSessionOrServiceAccount(ctx)
groups, _ := getClaimStringSlice(s, "groups")
return groups
}
func (e *headersEvaluatorEvaluation) getJWTPayloadSID() string {
return e.request.Session.ID
}
func (e *headersEvaluatorEvaluation) getJWTPayloadName(ctx context.Context) string {
s, _ := e.getSessionOrServiceAccount(ctx)
if names, ok := getClaimStringSlice(s, "name"); ok {
return strings.Join(names, ",")
}
u := e.getUser(ctx)
if names, ok := getClaimStringSlice(u, "name"); ok {
return strings.Join(names, ",")
}
return ""
}
func (e *headersEvaluatorEvaluation) getJWTPayload(ctx context.Context) map[string]any {
if e.gotJWTPayload {
return e.cachedJWTPayload
}
e.gotJWTPayload = true
e.cachedJWTPayload = map[string]any{
"iss": e.getJWTPayloadIss(),
"aud": e.getJWTPayloadAud(),
"jti": e.getJWTPayloadJTI(),
"iat": e.getJWTPayloadIAT(),
"exp": e.getJWTPayloadExp(),
"sub": e.getJWTPayloadSub(ctx),
"user": e.getJWTPayloadUser(ctx),
"email": e.getJWTPayloadEmail(ctx),
"groups": e.getJWTPayloadGroups(ctx),
"sid": e.getJWTPayloadSID(),
"name": e.getJWTPayloadName(ctx),
}
s, _ := e.getSessionOrServiceAccount(ctx)
u := e.getUser(ctx)
for _, claimKey := range e.evaluator.store.GetJWTClaimHeaders() {
// ignore base claims
if _, ok := e.cachedJWTPayload[claimKey]; ok {
continue
}
if vs, ok := getClaimStringSlice(s, claimKey); ok {
e.cachedJWTPayload[claimKey] = strings.Join(vs, ",")
} else if vs, ok := getClaimStringSlice(u, claimKey); ok {
e.cachedJWTPayload[claimKey] = strings.Join(vs, ",")
}
}
return e.cachedJWTPayload
}
func (e *headersEvaluatorEvaluation) getSignedJWT(ctx context.Context) string {
if e.gotSignedJWT {
return e.cachedSignedJWT
}
e.gotSignedJWT = true
signingKey := e.evaluator.store.GetSigningKey()
if signingKey == nil {
log.Ctx(ctx).Error().Msg("authorize/header-evaluator: missing signing key")
return ""
}
signingOptions := new(jose.SignerOptions).
WithType("JWT").
WithHeader("kid", signingKey.KeyID).
WithHeader("alg", signingKey.Algorithm)
signer, err := jose.NewSigner(jose.SigningKey{
Algorithm: jose.SignatureAlgorithm(signingKey.Algorithm),
Key: signingKey.Key,
}, signingOptions)
if err != nil {
log.Ctx(ctx).Error().Err(err).Msg("authorize/header-evaluator: error creating JWT signer")
return ""
}
jwtPayload := e.getJWTPayload(ctx)
bs, err := json.Marshal(jwtPayload)
if err != nil {
log.Ctx(ctx).Error().Err(err).Msg("authorize/header-evaluator: error marshaling JWT payload")
return ""
}
jwt, err := signer.Sign(bs)
if err != nil {
log.Ctx(ctx).Error().Err(err).Msg("authorize/header-evaluator: error signing JWT")
return ""
}
e.cachedSignedJWT, err = jwt.CompactSerialize()
if err != nil {
log.Ctx(ctx).Error().Err(err).Msg("authorize/header-evaluator: error serializing JWT")
return ""
}
return e.cachedSignedJWT
}
func (e *headersEvaluatorEvaluation) getDataBrokerGroupNames(ctx context.Context, groupIDs []string) []string {
groupNames := make([]string, 0, len(groupIDs))
for _, groupID := range groupIDs {
dg, _ := e.evaluator.store.GetDataBrokerRecord(ctx, directory.GroupRecordType, groupID).(*structpb.Struct)
if name, ok := getStructString(dg, "name"); ok {
groupNames = append(groupNames, name)
}
}
return groupNames
}
type hasGetClaims interface {
GetClaims() map[string]*structpb.ListValue
}
func getClaimStringSlice(msg hasGetClaims, field string) ([]string, bool) {
if msg == nil {
return nil, false
}
claims := msg.GetClaims()
if claims == nil {
return nil, false
}
lv, ok := claims[field]
if !ok {
return nil, false
}
strs := make([]string, 0, len(lv.Values))
for _, v := range lv.Values {
switch v := v.GetKind().(type) {
case *structpb.Value_NumberValue:
strs = append(strs, fmt.Sprint(v.NumberValue))
case *structpb.Value_StringValue:
strs = append(strs, v.StringValue)
case *structpb.Value_BoolValue:
strs = append(strs, fmt.Sprint(v.BoolValue))
// just ignore these types
case *structpb.Value_NullValue:
case *structpb.Value_StructValue:
case *structpb.Value_ListValue:
}
}
return strs, true
}
func getStructString(s *structpb.Struct, field string) (string, bool) {
if s == nil || s.Fields == nil {
return "", false
}
v, ok := s.Fields[field]
if !ok {
return "", false
}
return fmt.Sprint(v.AsInterface()), true
}
func getStructStringSlice(s *structpb.Struct, field string) ([]string, bool) {
if s == nil || s.Fields == nil {
return nil, false
}
v, ok := s.Fields[field]
if !ok {
return nil, false
}
lv := v.GetListValue()
if lv == nil {
return nil, false
}
strs := make([]string, len(lv.Values))
for i, vv := range lv.Values {
sv, ok := vv.Kind.(*structpb.Value_StringValue)
if !ok {
return nil, false
}
strs[i] = sv.StringValue
}
return strs, true
}
func cryptoSHA256[T string | []byte](input T) string {
output := sha256.Sum256([]byte(input))
return hex.EncodeToString(output[:])
}
func getHeaderStringValue(obj any) string {
v := reflect.ValueOf(obj)
switch v.Kind() {
case reflect.Slice:
var str strings.Builder
for i := 0; i < v.Len(); i++ {
if i > 0 {
str.WriteByte(',')
}
str.WriteString(getHeaderStringValue(v.Index(i).Interface()))
}
return str.String()
}
return fmt.Sprint(obj)
}

View file

@ -60,8 +60,7 @@ func BenchmarkHeadersEvaluator(b *testing.B) {
s.UpdateJWTClaimHeaders(config.NewJWTClaimHeaders("email", "groups", "user", "CUSTOM_KEY"))
s.UpdateSigningKey(privateJWK)
e, err := NewHeadersEvaluator(ctx, s, rego.Time(iat))
require.NoError(b, err)
e := NewHeadersEvaluator(s)
req := &HeadersRequest{
EnableRoutingKey: true,
@ -198,14 +197,13 @@ func TestHeadersEvaluator(t *testing.T) {
iat := time.Unix(1686870680, 0)
eval := func(t *testing.T, data []proto.Message, input *HeadersRequest) (*HeadersResponse, error) {
eval := func(_ *testing.T, data []proto.Message, input *HeadersRequest) (*HeadersResponse, error) {
ctx := context.Background()
ctx = storage.WithQuerier(ctx, storage.NewStaticQuerier(data...))
store := store.New()
store.UpdateJWTClaimHeaders(config.NewJWTClaimHeaders("name", "email", "groups", "user", "CUSTOM_KEY"))
store.UpdateSigningKey(privateJWK)
e, err := NewHeadersEvaluator(ctx, store, rego.Time(iat))
require.NoError(t, err)
e := NewHeadersEvaluator(store)
return e.Evaluate(ctx, input, rego.EvalTime(iat))
}

View file

@ -1,10 +0,0 @@
// Package opa implements the policy evaluator interface to make authorization
// decisions.
package opa
import _ "embed" // to embed files
// HeadersRego is the headers.rego script.
//
//go:embed policy/headers.rego
var HeadersRego string

View file

@ -1,265 +0,0 @@
package pomerium.headers
import rego.v1
# input:
# enable_google_cloud_serverless_authentication: boolean
# enable_routing_key: boolean
# client_certificate:
# leaf: string
# issuer: string
# kubernetes_service_account_token: string
# session:
# id: string
# to_audience: string
# set_request_headers: map[string]string
#
# data:
# jwt_claim_headers: map[string]string
# signing_key:
# alg: string
# kid: string
#
# functions:
# get_databroker_record
# get_google_cloud_serverless_headers
#
#
# output:
# identity_headers: map[string][]string
now_s := round(time.now_ns() / 1e9)
# get the session
session := v if {
# try a service account
v = get_databroker_record("type.googleapis.com/user.ServiceAccount", input.session.id)
v != null
} else := iv if {
# try an impersonated session
v = get_databroker_record("type.googleapis.com/session.Session", input.session.id)
v != null
object.get(v, "impersonate_session_id", "") != ""
iv = get_databroker_record("type.googleapis.com/session.Session", v.impersonate_session_id)
iv != null
} else := v if {
# try a normal session
v = get_databroker_record("type.googleapis.com/session.Session", input.session.id)
v != null
object.get(v, "impersonate_session_id", "") == ""
} else := {}
user := u if {
u = get_databroker_record("type.googleapis.com/user.User", session.user_id)
u != null
} else := {}
directory_user := du if {
du = get_databroker_record("pomerium.io/DirectoryUser", session.user_id)
du != null
} else := {}
group_ids := gs if {
gs = directory_user.group_ids
gs != null
} else := []
groups := array.concat(group_ids, array.concat(get_databroker_group_names(group_ids), get_databroker_group_emails(group_ids)))
jwt_headers := {
"typ": "JWT",
"alg": data.signing_key.alg,
"kid": data.signing_key.kid,
}
jwt_payload_aud := v if {
v := input.audience
} else := ""
jwt_payload_iss := v if {
v := input.issuer
} else := ""
jwt_payload_jti := uuid.rfc4122("jti")
jwt_payload_iat := now_s
jwt_payload_exp := now_s + (5*60) # 5 minutes from now
jwt_payload_sub := v if {
v = session.user_id
} else := ""
jwt_payload_user := v if {
v = session.user_id
} else := ""
jwt_payload_email := v if {
v = directory_user.email
} else := v if {
v = user.email
} else := ""
jwt_payload_groups := v if {
v = array.concat(group_ids, get_databroker_group_names(group_ids))
v != []
} else := v if {
v = session.claims.groups
v != null
} else := []
jwt_payload_name := v if {
v = get_header_string_value(session.claims.name)
} else := v if {
v = get_header_string_value(user.claims.name)
} else := ""
# the session id is always set to the input session id, even if impersonating
jwt_payload_sid := input.session.id
base_jwt_claims := [
["iss", jwt_payload_iss],
["aud", jwt_payload_aud],
["jti", jwt_payload_jti],
["iat", jwt_payload_iat],
["exp", jwt_payload_exp],
["sub", jwt_payload_sub],
["user", jwt_payload_user],
["email", jwt_payload_email],
["groups", jwt_payload_groups],
["sid", jwt_payload_sid],
["name", jwt_payload_name],
]
additional_jwt_claims := [[k, v] |
some header_name
claim_key := data.jwt_claim_headers[header_name]
# exclude base_jwt_claims
count([1 |
[xk, xv] := base_jwt_claims[_]
xk == claim_key
]) == 0
# the claim value can come from session claims or user claims
claim_value := object.get(session.claims, claim_key, object.get(user.claims, claim_key, null))
k := claim_key
v := get_header_string_value(claim_value)
]
jwt_claims := array.concat(base_jwt_claims, additional_jwt_claims)
jwt_payload := {key: value |
# use a comprehension over an array to remove nil values
[key, value] := jwt_claims[_]
value != null
}
signed_jwt := io.jwt.encode_sign(jwt_headers, jwt_payload, data.signing_key)
kubernetes_headers := h if {
input.kubernetes_service_account_token != ""
h := remove_empty_header_values([
["Authorization", concat(" ", ["Bearer", input.kubernetes_service_account_token])],
["Impersonate-User", jwt_payload_email],
["Impersonate-Group", get_header_string_value(jwt_payload_groups)],
])
} else := []
google_cloud_serverless_authentication_service_account := s if {
s := data.google_cloud_serverless_authentication_service_account
} else := ""
google_cloud_serverless_headers := h if {
input.enable_google_cloud_serverless_authentication
h := get_google_cloud_serverless_headers(google_cloud_serverless_authentication_service_account, input.to_audience)
} else := {}
routing_key_headers := h if {
input.enable_routing_key
h := [["x-pomerium-routing-key", crypto.sha256(input.session.id)]]
} else := []
session_id_token := v if {
v := session.id_token.raw
} else := ""
session_access_token := v if {
v := session.oauth_token.access_token
} else := ""
client_cert_fingerprint := v if {
cert := crypto.x509.parse_certificates(trim_space(input.client_certificate.leaf))[0]
v := crypto.sha256(base64.decode(cert.Raw))
} else := ""
set_request_headers := h if {
replacements := {
"pomerium.id_token": session_id_token,
"pomerium.access_token": session_access_token,
"pomerium.client_cert_fingerprint": client_cert_fingerprint,
"pomerium.jwt": signed_jwt,
}
h := [[header_name, header_value] |
some header_name
v := input.set_request_headers[header_name]
header_value := pomerium.variable_substitution(v, replacements)
]
} else := []
identity_headers := {key: values |
h1 := [["x-pomerium-jwt-assertion", signed_jwt]]
h2 := [[header_name, header_value] |
some header_name
k := data.jwt_claim_headers[header_name]
raw_header_value := array.concat(
[cv |
[ck, cv] := jwt_claims[_]
ck == k
],
[""],
)[0]
header_value := get_header_string_value(raw_header_value)
]
h3 := kubernetes_headers
h4 := [[k, v] | v := google_cloud_serverless_headers[k]]
h5 := routing_key_headers
h6 := set_request_headers
h := array.concat(array.concat(array.concat(array.concat(array.concat(h1, h2), h3), h4), h5), h6)
some i
[key, v1] := h[i]
values := [v2 |
some j
[k2, v2] := h[j]
key == k2
]
}
get_databroker_group_names(ids) := gs if {
gs := [name | id := ids[i]; group := get_databroker_record("pomerium.io/DirectoryGroup", id); name := group.name]
}
get_databroker_group_emails(ids) := gs if {
gs := [email | id := ids[i]; group := get_databroker_record("pomerium.io/DirectoryGroup", id); email := group.email]
}
get_header_string_value(obj) := s if {
is_array(obj)
s := concat(",", obj)
} else := s if {
s := concat(",", [obj])
}
remove_empty_header_values(arr) := [[k, v] |
some idx
k := arr[idx][0]
v := arr[idx][1]
v != ""
]

View file

@ -5,6 +5,7 @@ import (
"context"
"encoding/json"
"fmt"
"sync/atomic"
"time"
"github.com/go-jose/go-jose/v3"
@ -27,6 +28,10 @@ import (
// A Store stores data for the OPA rego policy evaluation.
type Store struct {
opastorage.Store
googleCloudServerlessAuthenticationServiceAccount atomic.Pointer[string]
jwtClaimHeaders atomic.Pointer[map[string]string]
signingKey atomic.Pointer[jose.JSONWebKey]
}
// New creates a new Store.
@ -36,15 +41,37 @@ func New() *Store {
}
}
func (s *Store) GetGoogleCloudServerlessAuthenticationServiceAccount() string {
v := s.googleCloudServerlessAuthenticationServiceAccount.Load()
if v == nil {
return ""
}
return *v
}
func (s *Store) GetJWTClaimHeaders() map[string]string {
m := s.jwtClaimHeaders.Load()
if m == nil {
return nil
}
return *m
}
func (s *Store) GetSigningKey() *jose.JSONWebKey {
return s.signingKey.Load()
}
// UpdateGoogleCloudServerlessAuthenticationServiceAccount updates the google cloud serverless authentication
// service account in the store.
func (s *Store) UpdateGoogleCloudServerlessAuthenticationServiceAccount(serviceAccount string) {
s.write("/google_cloud_serverless_authentication_service_account", serviceAccount)
s.googleCloudServerlessAuthenticationServiceAccount.Store(&serviceAccount)
}
// UpdateJWTClaimHeaders updates the jwt claim headers in the store.
func (s *Store) UpdateJWTClaimHeaders(jwtClaimHeaders map[string]string) {
s.write("/jwt_claim_headers", jwtClaimHeaders)
s.jwtClaimHeaders.Store(&jwtClaimHeaders)
}
// UpdateRoutePolicies updates the route policies in the store.
@ -56,6 +83,7 @@ func (s *Store) UpdateRoutePolicies(routePolicies []*config.Policy) {
// in rego use JWKs, so we take in that format.
func (s *Store) UpdateSigningKey(signingKey *jose.JSONWebKey) {
s.write("/signing_key", signingKey)
s.signingKey.Store(signingKey)
}
func (s *Store) write(rawPath string, value any) {
@ -111,40 +139,17 @@ func (s *Store) GetDataBrokerRecordOption() func(*rego.Rego) {
}
span.AddAttributes(octrace.StringAttribute("record_type", recordType.String()))
value, ok := op2.Value.(ast.String)
recordIDOrIndex, ok := op2.Value.(ast.String)
if !ok {
return nil, fmt.Errorf("invalid record id: %T", op2)
}
span.AddAttributes(octrace.StringAttribute("record_id", value.String()))
span.AddAttributes(octrace.StringAttribute("record_id", recordIDOrIndex.String()))
req := &databroker.QueryRequest{
Type: string(recordType),
Limit: 1,
}
req.SetFilterByIDOrIndex(string(value))
res, err := storage.GetQuerier(ctx).Query(ctx, req)
if err != nil {
log.Ctx(ctx).Error().Err(err).Msg("authorize/store: error retrieving record")
return ast.NullTerm(), nil
}
if len(res.GetRecords()) == 0 {
return ast.NullTerm(), nil
}
msg, _ := res.GetRecords()[0].GetData().UnmarshalNew()
msg := s.GetDataBrokerRecord(ctx, string(recordType), string(recordIDOrIndex))
if msg == nil {
return ast.NullTerm(), nil
}
// exclude expired records
if hasExpiresAt, ok := msg.(interface{ GetExpiresAt() *timestamppb.Timestamp }); ok && hasExpiresAt.GetExpiresAt() != nil {
if hasExpiresAt.GetExpiresAt().AsTime().Before(time.Now()) {
return ast.NullTerm(), nil
}
}
obj := toMap(msg)
regoValue, err := ast.InterfaceToValue(obj)
@ -157,6 +162,38 @@ func (s *Store) GetDataBrokerRecordOption() func(*rego.Rego) {
})
}
func (s *Store) GetDataBrokerRecord(ctx context.Context, recordType, recordIDOrIndex string) proto.Message {
req := &databroker.QueryRequest{
Type: recordType,
Limit: 1,
}
req.SetFilterByIDOrIndex(recordIDOrIndex)
res, err := storage.GetQuerier(ctx).Query(ctx, req)
if err != nil {
log.Ctx(ctx).Error().Err(err).Msg("authorize/store: error retrieving record")
return nil
}
if len(res.GetRecords()) == 0 {
return nil
}
msg, _ := res.GetRecords()[0].GetData().UnmarshalNew()
if msg == nil {
return nil
}
// exclude expired records
if hasExpiresAt, ok := msg.(interface{ GetExpiresAt() *timestamppb.Timestamp }); ok && hasExpiresAt.GetExpiresAt() != nil {
if hasExpiresAt.GetExpiresAt().AsTime().Before(time.Now()) {
return nil
}
}
return msg
}
func toMap(msg proto.Message) map[string]any {
bs, _ := json.Marshal(msg)
var obj map[string]any