pomerium/authorize/evaluator/headers_evaluator_evaluation.go
Kenneth Jenkins 21b9e7890c
authorize: add filter options for JWT groups (#5417)
Add a new option for filtering to a subset of directory groups in the
Pomerium JWT and Impersonate-Group headers. Add a JWTGroupsFilter field
to both the Options struct (for a global filter) and to the Policy
struct (for per-route filter). These will be populated only from the
config protos, and not from a config file.

If either filter is set, then for each of a user's groups, the group
name or group ID will be added to the JWT groups claim only if it is an
exact string match with one of the elements of either filter.
2025-01-08 13:57:57 -08:00

585 lines
16 KiB
Go

package evaluator
import (
"context"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"fmt"
"net/http"
"os"
"reflect"
"strings"
"time"
envoy_config_cluster_v3 "github.com/envoyproxy/go-control-plane/envoy/config/cluster/v3"
"github.com/go-jose/go-jose/v3"
"github.com/google/uuid"
"google.golang.org/protobuf/types/known/structpb"
"github.com/pomerium/datasource/pkg/directory"
"github.com/pomerium/pomerium/config"
"github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/pkg/cryptutil"
"github.com/pomerium/pomerium/pkg/grpc/session"
"github.com/pomerium/pomerium/pkg/grpc/user"
"github.com/pomerium/pomerium/pkg/slices"
)
// A headersEvaluatorEvaluation is a single evaluation of the headers evaluator.
type headersEvaluatorEvaluation struct {
evaluator *HeadersEvaluator
request *Request
response *HeadersResponse
now time.Time
gotSessionOrServiceAccount bool
cachedSession *session.Session
cachedServiceAccount *user.ServiceAccount
gotUser bool
cachedUser *user.User
gotDirectoryUser bool
cachedDirectoryUser *structpb.Struct
gotJWTPayloadJTI bool
cachedJWTPayloadJTI string
gotJWTPayload bool
cachedJWTPayload map[string]any
gotSignedJWT bool
cachedSignedJWT string
}
func newHeadersEvaluatorEvaluation(evaluator *HeadersEvaluator, request *Request, now time.Time) *headersEvaluatorEvaluation {
return &headersEvaluatorEvaluation{
evaluator: evaluator,
request: request,
response: &HeadersResponse{Headers: make(http.Header)},
now: now,
}
}
func (e *headersEvaluatorEvaluation) execute(ctx context.Context) (*HeadersResponse, error) {
err := e.fillHeaders(ctx)
return e.response, err
}
func (e *headersEvaluatorEvaluation) fillJWTAssertionHeader(ctx context.Context) {
e.response.Headers.Add("x-pomerium-jwt-assertion", e.getSignedJWT(ctx))
}
func (e *headersEvaluatorEvaluation) fillJWTClaimHeaders(ctx context.Context) error {
claims, err := e.getJWTPayload(ctx)
if err != nil {
return err
}
for headerName, claimKey := range e.evaluator.store.GetJWTClaimHeaders() {
claim, ok := claims[claimKey]
if !ok {
e.response.Headers.Add(headerName, "")
continue
}
e.response.Headers.Add(headerName, getHeaderStringValue(claim))
}
return nil
}
func (e *headersEvaluatorEvaluation) fillKubernetesHeaders(ctx context.Context) {
if e.request.Policy == nil {
return
}
token, err := e.request.Policy.GetKubernetesServiceAccountToken()
if err != nil || token == "" {
return
}
e.response.Headers.Add("Authorization", "Bearer "+token)
impersonateUser := e.getJWTPayloadEmail(ctx)
if impersonateUser != "" {
e.response.Headers.Add("Impersonate-User", impersonateUser)
}
impersonateGroups := e.getJWTPayloadGroups(ctx)
if len(impersonateGroups) > 0 {
e.response.Headers.Add("Impersonate-Group", strings.Join(impersonateGroups, ","))
}
}
func (e *headersEvaluatorEvaluation) fillGoogleCloudServerlessHeaders(ctx context.Context) {
if e.request.Policy == nil || !e.request.Policy.EnableGoogleCloudServerlessAuthentication {
return
}
var toAudience string
for _, wu := range e.request.Policy.To {
toAudience = "https://" + wu.URL.Hostname()
}
h, err := getGoogleCloudServerlessHeaders(e.evaluator.store.GetGoogleCloudServerlessAuthenticationServiceAccount(), toAudience)
if err != nil {
log.Ctx(ctx).Error().Err(err).Msg("authorize/header-evaluator: error retrieving google cloud serverless headers")
return
}
for k, v := range h {
e.response.Headers.Add(k, v)
}
}
func (e *headersEvaluatorEvaluation) fillRoutingKeyHeaders() {
if e.request.Policy == nil {
return
}
if e.request.Policy.EnvoyOpts.GetLbPolicy() == envoy_config_cluster_v3.Cluster_RING_HASH ||
e.request.Policy.EnvoyOpts.GetLbPolicy() == envoy_config_cluster_v3.Cluster_MAGLEV {
e.response.Headers.Add("x-pomerium-routing-key", cryptoSHA256(e.request.Session.ID))
}
}
func (e *headersEvaluatorEvaluation) fillSetRequestHeaders(ctx context.Context) {
if e.request.Policy == nil {
return
}
for k, v := range e.request.Policy.SetRequestHeaders {
e.response.Headers.Add(k, os.Expand(v, func(name string) string {
switch name {
case "$":
return "$"
case "pomerium.access_token":
s, _ := e.getSessionOrServiceAccount(ctx)
return s.GetOauthToken().GetAccessToken()
case "pomerium.client_cert_fingerprint":
return e.getClientCertFingerprint()
case "pomerium.id_token":
s, _ := e.getSessionOrServiceAccount(ctx)
return s.GetIdToken().GetRaw()
case "pomerium.jwt":
return e.getSignedJWT(ctx)
}
return ""
}))
}
}
func (e *headersEvaluatorEvaluation) fillHeaders(ctx context.Context) error {
e.fillJWTAssertionHeader(ctx)
if err := e.fillJWTClaimHeaders(ctx); err != nil {
return err
}
e.fillKubernetesHeaders(ctx)
e.fillGoogleCloudServerlessHeaders(ctx)
e.fillRoutingKeyHeaders()
e.fillSetRequestHeaders(ctx)
return nil
}
func (e *headersEvaluatorEvaluation) getSessionOrServiceAccount(ctx context.Context) (*session.Session, *user.ServiceAccount) {
if e.gotSessionOrServiceAccount {
return e.cachedSession, e.cachedServiceAccount
}
e.gotSessionOrServiceAccount = true
if e.request.Session.ID != "" {
e.cachedServiceAccount, _ = e.evaluator.store.GetDataBrokerRecord(ctx, "type.googleapis.com/user.ServiceAccount", e.request.Session.ID).(*user.ServiceAccount)
}
if e.request.Session.ID != "" && e.cachedServiceAccount == nil {
e.cachedSession, _ = e.evaluator.store.GetDataBrokerRecord(ctx, "type.googleapis.com/session.Session", e.request.Session.ID).(*session.Session)
}
if e.cachedSession != nil && e.cachedSession.GetImpersonateSessionId() != "" {
e.cachedSession, _ = e.evaluator.store.GetDataBrokerRecord(ctx, "type.googleapis.com/session.Session", e.cachedSession.GetImpersonateSessionId()).(*session.Session)
}
return e.cachedSession, e.cachedServiceAccount
}
func (e *headersEvaluatorEvaluation) getUser(ctx context.Context) *user.User {
if e.gotUser {
return e.cachedUser
}
e.gotUser = true
s, sa := e.getSessionOrServiceAccount(ctx)
if sa != nil && sa.UserId != "" {
e.cachedUser, _ = e.evaluator.store.GetDataBrokerRecord(ctx, "type.googleapis.com/user.User", sa.UserId).(*user.User)
} else if s != nil && s.UserId != "" {
e.cachedUser, _ = e.evaluator.store.GetDataBrokerRecord(ctx, "type.googleapis.com/user.User", s.UserId).(*user.User)
}
return e.cachedUser
}
func (e *headersEvaluatorEvaluation) getClientCertFingerprint() string {
cert, err := cryptutil.ParsePEMCertificate([]byte(e.request.HTTP.ClientCertificate.Leaf))
if err != nil {
return ""
}
return cryptoSHA256(cert.Raw)
}
func (e *headersEvaluatorEvaluation) getDirectoryUser(ctx context.Context) *structpb.Struct {
if e.gotDirectoryUser {
return e.cachedDirectoryUser
}
e.gotDirectoryUser = true
s, sa := e.getSessionOrServiceAccount(ctx)
if sa != nil && sa.UserId != "" {
e.cachedDirectoryUser, _ = e.evaluator.store.GetDataBrokerRecord(ctx, directory.UserRecordType, sa.UserId).(*structpb.Struct)
} else if s != nil && s.UserId != "" {
e.cachedDirectoryUser, _ = e.evaluator.store.GetDataBrokerRecord(ctx, directory.UserRecordType, s.UserId).(*structpb.Struct)
}
return e.cachedDirectoryUser
}
func (e *headersEvaluatorEvaluation) getGroupIDs(ctx context.Context) []string {
du := e.getDirectoryUser(ctx)
if groupIDs, ok := getStructStringSlice(du, "group_ids"); ok {
return groupIDs
}
return make([]string, 0)
}
func (e *headersEvaluatorEvaluation) getJWTPayloadIss() (string, error) {
var issuerFormat string
if e.request.Policy != nil {
issuerFormat = e.request.Policy.JWTIssuerFormat
}
switch issuerFormat {
case "uri":
return fmt.Sprintf("https://%s/", e.request.HTTP.Hostname), nil
case "", "hostOnly":
return e.request.HTTP.Hostname, nil
default:
return "", fmt.Errorf("unsupported JWT issuer format: %s", issuerFormat)
}
}
func (e *headersEvaluatorEvaluation) getJWTPayloadAud() string {
return e.request.HTTP.Hostname
}
func (e *headersEvaluatorEvaluation) getJWTPayloadJTI() string {
if e.gotJWTPayloadJTI {
return e.cachedJWTPayloadJTI
}
e.gotJWTPayloadJTI = true
e.cachedJWTPayloadJTI = uuid.New().String()
return e.cachedJWTPayloadJTI
}
func (e *headersEvaluatorEvaluation) getJWTPayloadIAT() int64 {
return e.now.Unix()
}
func (e *headersEvaluatorEvaluation) getJWTPayloadExp() int64 {
return e.now.Add(5 * time.Minute).Unix()
}
func (e *headersEvaluatorEvaluation) getJWTPayloadSub(ctx context.Context) string {
return e.getJWTPayloadUser(ctx)
}
func (e *headersEvaluatorEvaluation) getJWTPayloadUser(ctx context.Context) string {
s, sa := e.getSessionOrServiceAccount(ctx)
if sa != nil {
return sa.UserId
}
if s != nil {
return s.UserId
}
return ""
}
func (e *headersEvaluatorEvaluation) getJWTPayloadEmail(ctx context.Context) string {
du := e.getDirectoryUser(ctx)
if v, ok := getStructString(du, "email"); ok {
return v
}
u := e.getUser(ctx)
if u != nil {
return u.Email
}
return ""
}
func (e *headersEvaluatorEvaluation) getJWTPayloadGroups(ctx context.Context) []string {
groups := e.getFilteredGroups(ctx)
if groups == nil {
// If there are no groups, marshal this claim as an empty list rather than a JSON null,
// for better compatibility with third-party libraries.
// See https://github.com/pomerium/pomerium/issues/5393 for one example.
groups = []string{}
}
return groups
}
func (e *headersEvaluatorEvaluation) getFilteredGroups(ctx context.Context) []string {
groups := e.getAllGroups(ctx)
// Apply the global groups filter or the per-route groups filter, if either is enabled.
filters := make([]config.JWTGroupsFilter, 0, 2)
if f := e.evaluator.store.GetJWTGroupsFilter(); f.Enabled() {
filters = append(filters, f)
}
if e.request.Policy != nil && e.request.Policy.JWTGroupsFilter.Enabled() {
filters = append(filters, e.request.Policy.JWTGroupsFilter)
}
if len(filters) == 0 {
return groups
}
return slices.Filter(groups, func(g string) bool {
// A group should be included if it appears in either the global or the route-level filter list.
for _, f := range filters {
if f.IsAllowed(g) {
return true
}
}
return false
})
}
// getAllGroups returns the full group names/IDs list (without any filtering).
func (e *headersEvaluatorEvaluation) getAllGroups(ctx context.Context) []string {
groupIDs := e.getGroupIDs(ctx)
if len(groupIDs) > 0 {
groups := make([]string, 0, len(groupIDs)*2)
groups = append(groups, groupIDs...)
groups = append(groups, e.getDataBrokerGroupNames(ctx, groupIDs)...)
return groups
}
s, _ := e.getSessionOrServiceAccount(ctx)
groups, _ := getClaimStringSlice(s, "groups")
return groups
}
func (e *headersEvaluatorEvaluation) getJWTPayloadSID() string {
return e.request.Session.ID
}
func (e *headersEvaluatorEvaluation) getJWTPayloadName(ctx context.Context) string {
s, _ := e.getSessionOrServiceAccount(ctx)
if names, ok := getClaimStringSlice(s, "name"); ok {
return strings.Join(names, ",")
}
u := e.getUser(ctx)
if names, ok := getClaimStringSlice(u, "name"); ok {
return strings.Join(names, ",")
}
return ""
}
func (e *headersEvaluatorEvaluation) getJWTPayload(ctx context.Context) (map[string]any, error) {
if e.gotJWTPayload {
return e.cachedJWTPayload, nil
}
iss, err := e.getJWTPayloadIss()
if err != nil {
return nil, err
}
e.gotJWTPayload = true
e.cachedJWTPayload = map[string]any{
"iss": iss,
"aud": e.getJWTPayloadAud(),
"jti": e.getJWTPayloadJTI(),
"iat": e.getJWTPayloadIAT(),
"exp": e.getJWTPayloadExp(),
"sub": e.getJWTPayloadSub(ctx),
"user": e.getJWTPayloadUser(ctx),
"email": e.getJWTPayloadEmail(ctx),
"groups": e.getJWTPayloadGroups(ctx),
"sid": e.getJWTPayloadSID(),
"name": e.getJWTPayloadName(ctx),
}
s, _ := e.getSessionOrServiceAccount(ctx)
u := e.getUser(ctx)
for _, claimKey := range e.evaluator.store.GetJWTClaimHeaders() {
// ignore base claims
if _, ok := e.cachedJWTPayload[claimKey]; ok {
continue
}
if vs, ok := getClaimStringSlice(s, claimKey); ok {
e.cachedJWTPayload[claimKey] = strings.Join(vs, ",")
} else if vs, ok := getClaimStringSlice(u, claimKey); ok {
e.cachedJWTPayload[claimKey] = strings.Join(vs, ",")
}
}
return e.cachedJWTPayload, nil
}
func (e *headersEvaluatorEvaluation) getSignedJWT(ctx context.Context) string {
if e.gotSignedJWT {
return e.cachedSignedJWT
}
e.gotSignedJWT = true
signingKey := e.evaluator.store.GetSigningKey()
if signingKey == nil {
log.Ctx(ctx).Error().Msg("authorize/header-evaluator: missing signing key")
return ""
}
signingOptions := new(jose.SignerOptions).
WithType("JWT").
WithHeader("kid", signingKey.KeyID).
WithHeader("alg", signingKey.Algorithm)
signer, err := jose.NewSigner(jose.SigningKey{
Algorithm: jose.SignatureAlgorithm(signingKey.Algorithm),
Key: signingKey.Key,
}, signingOptions)
if err != nil {
log.Ctx(ctx).Error().Err(err).Msg("authorize/header-evaluator: error creating JWT signer")
return ""
}
jwtPayload, err := e.getJWTPayload(ctx)
if err != nil {
log.Ctx(ctx).Error().Err(err).Msg("authorize/header-evaluator: error creating JWT payload")
return ""
}
bs, err := json.Marshal(jwtPayload)
if err != nil {
log.Ctx(ctx).Error().Err(err).Msg("authorize/header-evaluator: error marshaling JWT payload")
return ""
}
jwt, err := signer.Sign(bs)
if err != nil {
log.Ctx(ctx).Error().Err(err).Msg("authorize/header-evaluator: error signing JWT")
return ""
}
e.cachedSignedJWT, err = jwt.CompactSerialize()
if err != nil {
log.Ctx(ctx).Error().Err(err).Msg("authorize/header-evaluator: error serializing JWT")
return ""
}
return e.cachedSignedJWT
}
func (e *headersEvaluatorEvaluation) getDataBrokerGroupNames(ctx context.Context, groupIDs []string) []string {
groupNames := make([]string, 0, len(groupIDs))
for _, groupID := range groupIDs {
dg, _ := e.evaluator.store.GetDataBrokerRecord(ctx, directory.GroupRecordType, groupID).(*structpb.Struct)
if name, ok := getStructString(dg, "name"); ok {
groupNames = append(groupNames, name)
}
}
return groupNames
}
type hasGetClaims interface {
GetClaims() map[string]*structpb.ListValue
}
func getClaimStringSlice(msg hasGetClaims, field string) ([]string, bool) {
if msg == nil {
return nil, false
}
claims := msg.GetClaims()
if claims == nil {
return nil, false
}
lv, ok := claims[field]
if !ok {
return nil, false
}
strs := make([]string, 0, len(lv.Values))
for _, v := range lv.Values {
switch v := v.GetKind().(type) {
case *structpb.Value_NumberValue:
strs = append(strs, fmt.Sprint(v.NumberValue))
case *structpb.Value_StringValue:
strs = append(strs, v.StringValue)
case *structpb.Value_BoolValue:
strs = append(strs, fmt.Sprint(v.BoolValue))
// just ignore these types
case *structpb.Value_NullValue:
case *structpb.Value_StructValue:
case *structpb.Value_ListValue:
}
}
return strs, true
}
func getStructString(s *structpb.Struct, field string) (string, bool) {
if s == nil || s.Fields == nil {
return "", false
}
v, ok := s.Fields[field]
if !ok {
return "", false
}
return fmt.Sprint(v.AsInterface()), true
}
func getStructStringSlice(s *structpb.Struct, field string) ([]string, bool) {
if s == nil || s.Fields == nil {
return nil, false
}
v, ok := s.Fields[field]
if !ok {
return nil, false
}
lv := v.GetListValue()
if lv == nil {
return nil, false
}
strs := make([]string, len(lv.Values))
for i, vv := range lv.Values {
sv, ok := vv.Kind.(*structpb.Value_StringValue)
if !ok {
return nil, false
}
strs[i] = sv.StringValue
}
return strs, true
}
func cryptoSHA256[T string | []byte](input T) string {
output := sha256.Sum256([]byte(input))
return hex.EncodeToString(output[:])
}
func getHeaderStringValue(obj any) string {
v := reflect.ValueOf(obj)
switch v.Kind() {
case reflect.Slice:
var str strings.Builder
for i := 0; i < v.Len(); i++ {
if i > 0 {
str.WriteByte(',')
}
str.WriteString(getHeaderStringValue(v.Index(i).Interface()))
}
return str.String()
}
return fmt.Sprint(obj)
}