pomerium/authorize/evaluator/headers_evaluator_test.go
Kenneth Jenkins b0f89fc688
authorize: log JWT groups filtering (#5432)
Add a new Authorize Log Fields option for logging the number of groups
removed during JWT groups filtering. This will be enabled by default.

Additionally, when the log level is Debug (or more verbose), store and
log the IDs of any groups removed during JWT groups filtering.
2025-01-27 12:11:52 -08:00

612 lines
20 KiB
Go

package evaluator
import (
"bytes"
"context"
"encoding/base64"
"encoding/json"
"fmt"
"math"
"net/http"
"strings"
"testing"
"time"
envoy_config_cluster_v3 "github.com/envoyproxy/go-control-plane/envoy/config/cluster/v3"
"github.com/go-jose/go-jose/v3"
"github.com/go-jose/go-jose/v3/jwt"
"github.com/open-policy-agent/opa/rego"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/reflect/protoreflect"
"google.golang.org/protobuf/types/known/structpb"
"google.golang.org/protobuf/types/known/timestamppb"
"github.com/pomerium/datasource/pkg/directory"
"github.com/pomerium/pomerium/authorize/internal/store"
"github.com/pomerium/pomerium/config"
"github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/pkg/cryptutil"
"github.com/pomerium/pomerium/pkg/grpc/databroker"
"github.com/pomerium/pomerium/pkg/grpc/session"
"github.com/pomerium/pomerium/pkg/grpc/user"
"github.com/pomerium/pomerium/pkg/storage"
)
// BenchmarkHeadersEvaluator measures a single HeadersEvaluator.Evaluate call.
// The fixture is an impersonated session (s1 -> s2) whose user u2 belongs to
// four directory groups, evaluated for a route that sets several templated
// request headers. Evaluation time is pinned via rego.EvalTime.
func BenchmarkHeadersEvaluator(b *testing.B) {
	signingKey, _ := newJWK(b)
	evalTime := time.Unix(1686870680, 0)

	fixtures := []proto.Message{
		&session.Session{Id: "s1", ImpersonateSessionId: proto.String("s2"), UserId: "u1"},
		&session.Session{
			Id:     "s2",
			UserId: "u2",
			Claims: map[string]*structpb.ListValue{
				"name": {Values: []*structpb.Value{structpb.NewStringValue("n1")}},
			},
			IssuedAt: timestamppb.New(evalTime),
		},
		&user.User{Id: "u2", Name: "USER#2"},
		newDirectoryUserRecord(directory.User{ID: "u2", GroupIDs: []string{"g1", "g2", "g3", "g4"}}),
	}
	// Groups g1..g4 named GROUP1..GROUP4, appended in order.
	for n := 1; n <= 4; n++ {
		fixtures = append(fixtures, newDirectoryGroupRecord(directory.Group{
			ID:   fmt.Sprintf("g%d", n),
			Name: fmt.Sprintf("GROUP%d", n),
		}))
	}
	ctx := storage.WithQuerier(context.Background(), storage.NewStaticQuerier(fixtures...))

	st := store.New()
	st.UpdateJWTClaimHeaders(config.NewJWTClaimHeaders("email", "groups", "user", "CUSTOM_KEY"))
	st.UpdateSigningKey(signingKey)
	he := NewHeadersEvaluator(st)

	request := &Request{
		HTTP: RequestHTTP{Method: "GET", Hostname: "from.example.com"},
		Policy: &config.Policy{
			SetRequestHeaders: map[string]string{
				"X-Custom-Header":         "CUSTOM_VALUE",
				"X-ID-Token":              "${pomerium.id_token}",
				"X-Access-Token":          "${pomerium.access_token}",
				"Client-Cert-Fingerprint": "${pomerium.client_cert_fingerprint}",
				"Authorization":           "Bearer ${pomerium.jwt}",
			},
		},
		Session: RequestSession{ID: "s1"},
	}

	// Exclude fixture setup from the measurement.
	b.ResetTimer()
	for n := 0; n < b.N; n++ {
		_, err := he.Evaluate(ctx, request, rego.EvalTime(evalTime))
		require.NoError(b, err)
	}
}
// TestHeadersEvaluator exercises HeadersEvaluator.Evaluate against a static
// databroker querier. It covers the JWT assertion header, templated
// set_request_headers, Kubernetes impersonation headers, the routing key, and
// claim headers derived from session, user, directory and service-account
// records.
func TestHeadersEvaluator(t *testing.T) {
	t.Parallel()

	type M = map[string]any

	privateJWK, publicJWK := newJWK(t)
	iat := time.Unix(1686870680, 0)

	// eval runs a freshly configured HeadersEvaluator over data for input.
	// It passes iat as the evaluation time so that time-derived claims are
	// deterministic.
	eval := func(t *testing.T, data []proto.Message, input *Request) (*HeadersResponse, error) {
		t.Helper()
		ctx := context.Background()
		ctx = storage.WithQuerier(ctx, storage.NewStaticQuerier(data...))
		// Named s (not store) to avoid shadowing the imported store package.
		s := store.New()
		s.UpdateJWTClaimHeaders(config.NewJWTClaimHeaders("name", "email", "groups", "user", "CUSTOM_KEY"))
		s.UpdateSigningKey(privateJWK)
		e := NewHeadersEvaluator(s)
		return e.Evaluate(ctx, input, rego.EvalTime(iat))
	}

	t.Run("jwt", func(t *testing.T) {
		output, err := eval(t,
			[]proto.Message{
				&session.Session{Id: "s1", ImpersonateSessionId: proto.String("s2"), UserId: "u1"},
				&session.Session{Id: "s2", UserId: "u2", Claims: map[string]*structpb.ListValue{
					"name": {Values: []*structpb.Value{
						structpb.NewStringValue("n1"),
					}},
					"CUSTOM_KEY": {Values: []*structpb.Value{
						structpb.NewStringValue("v1"),
						structpb.NewStringValue("v2"),
						structpb.NewStringValue("v3"),
					}},
				}, IssuedAt: timestamppb.New(iat)},
				&user.User{Id: "u2", Claims: map[string]*structpb.ListValue{
					"name": {Values: []*structpb.Value{
						structpb.NewStringValue("n1"),
					}},
				}},
				newDirectoryUserRecord(directory.User{ID: "u2", GroupIDs: []string{"g1", "g2", "g3", "g4"}}),
				newDirectoryGroupRecord(directory.Group{ID: "g1", Name: "GROUP1", Email: "g1@example.com"}),
				newDirectoryGroupRecord(directory.Group{ID: "g2", Name: "GROUP2", Email: "g2@example.com"}),
				newDirectoryGroupRecord(directory.Group{ID: "g3", Name: "GROUP3", Email: "g3@example.com"}),
				newDirectoryGroupRecord(directory.Group{ID: "g4", Name: "GROUP4", Email: "g4@example.com"}),
			},
			&Request{
				HTTP: RequestHTTP{
					Hostname: "from.example.com",
				},
				Policy: &config.Policy{},
				Session: RequestSession{
					ID: "s1",
				},
			})
		require.NoError(t, err)

		// decodeJWTAssertion decodes numbers as json.Number, so comparing
		// against json.Number verifies that 'iat' and 'exp' were encoded as
		// integers. Both are derived from the evaluation time passed via
		// rego.EvalTime(iat), not from the wall clock.
		jwtPayloadDecoded := decodeJWTAssertion(t, output.Headers)
		assert.Equal(t, json.Number(fmt.Sprint(iat.Unix())), jwtPayloadDecoded["iat"],
			"unexpected 'iat' timestamp format")
		assert.Equal(t, json.Number(fmt.Sprint(iat.Add(5*time.Minute).Unix())), jwtPayloadDecoded["exp"],
			"unexpected 'exp' timestamp format")

		jwtHeader := output.Headers.Get("X-Pomerium-Jwt-Assertion")
		rawJWT, err := jwt.ParseSigned(jwtHeader)
		require.NoError(t, err)

		var claims M
		err = rawJWT.Claims(publicJWK, &claims)
		require.NoError(t, err)

		assert.NotEmpty(t, claims["jti"])
		assert.Equal(t, "from.example.com", claims["iss"])
		assert.Equal(t, "from.example.com", claims["aud"])
		assert.Equal(t, claims["exp"], math.Round(claims["exp"].(float64)))
		// Compare against the pinned evaluation time; a wall-clock bound would
		// be vacuous because iat is a fixed date in the past.
		assert.Equal(t, float64(iat.Add(5*time.Minute).Unix()), claims["exp"],
			"JWT should expire 5 minutes after the evaluation time, but got: %v", claims["exp"])
		assert.Equal(t, "s1", claims["sid"], "should set session id to input session id")
		assert.Equal(t, "u2", claims["sub"], "should set subject to user id")
		assert.Equal(t, "u2", claims["user"], "should set user to user id")
		assert.Equal(t, "n1", claims["name"], "should set name")
		assert.Equal(t, "v1,v2,v3", claims["CUSTOM_KEY"], "should set CUSTOM_KEY")
		assert.Equal(t, []any{"g1", "g2", "g3", "g4", "GROUP1", "GROUP2", "GROUP3", "GROUP4"}, claims["groups"])
	})

	t.Run("jwt no groups", func(t *testing.T) {
		t.Parallel()

		output, err := eval(t,
			[]protoreflect.ProtoMessage{
				&session.Session{Id: "s1", UserId: "u1", Claims: map[string]*structpb.ListValue{
					"name": {Values: []*structpb.Value{
						structpb.NewStringValue("User Name"),
					}},
				}},
			},
			&Request{
				Session: RequestSession{ID: "s1"},
			})
		require.NoError(t, err)

		decoded := decodeJWTAssertion(t, output.Headers)
		assert.Equal(t, []any{}, decoded["groups"])
	})

	t.Run("set_request_headers", func(t *testing.T) {
		output, err := eval(t,
			[]proto.Message{
				&session.Session{Id: "s1", IdToken: &session.IDToken{
					Raw: "ID_TOKEN",
				}, OauthToken: &session.OAuthToken{
					AccessToken: "ACCESS_TOKEN",
				}},
			},
			&Request{
				HTTP: RequestHTTP{
					Hostname:          "from.example.com",
					ClientCertificate: ClientCertificateInfo{Leaf: testValidCert},
				},
				Policy: &config.Policy{
					SetRequestHeaders: map[string]string{
						"X-Custom-Header":         "CUSTOM_VALUE",
						"X-ID-Token":              "${pomerium.id_token}",
						"X-Access-Token":          "${pomerium.access_token}",
						"Client-Cert-Fingerprint": "${pomerium.client_cert_fingerprint}",
						"Authorization":           "Bearer ${pomerium.jwt}",
						"Foo":                     "escaped $$dollar sign",
					},
				},
				Session: RequestSession{ID: "s1"},
			})
		require.NoError(t, err)

		assert.Equal(t, "CUSTOM_VALUE", output.Headers.Get("X-Custom-Header"))
		assert.Equal(t, "ID_TOKEN", output.Headers.Get("X-ID-Token"))
		assert.Equal(t, "ACCESS_TOKEN", output.Headers.Get("X-Access-Token"))
		assert.Equal(t, "3febe6467787e93f0a01030e0803072feaa710f724a9dc74de05cfba3d4a6d23",
			output.Headers.Get("Client-Cert-Fingerprint"))
		assert.Equal(t, "escaped $dollar sign", output.Headers.Get("Foo"))

		authHeader := output.Headers.Get("Authorization")
		assert.True(t, strings.HasPrefix(authHeader, "Bearer "))
		authHeader = strings.TrimPrefix(authHeader, "Bearer ")
		token, err := jwt.ParseSigned(authHeader)
		require.NoError(t, err)

		var claims jwt.Claims
		require.NoError(t, token.Claims(publicJWK, &claims))
		assert.Equal(t, "from.example.com", claims.Issuer)
		assert.Equal(t, jwt.Audience{"from.example.com"}, claims.Audience)
	})

	t.Run("set_request_headers no repeated substitution", func(t *testing.T) {
		output, err := eval(t,
			[]proto.Message{
				&session.Session{Id: "s1", IdToken: &session.IDToken{
					Raw: "$pomerium.access_token",
				}, OauthToken: &session.OAuthToken{
					AccessToken: "ACCESS_TOKEN",
				}},
			},
			&Request{
				Session: RequestSession{ID: "s1"},
				Policy: &config.Policy{
					SetRequestHeaders: map[string]string{
						"X-ID-Token": "${pomerium.id_token}",
					},
				},
			})
		require.NoError(t, err)

		assert.Equal(t, "$pomerium.access_token", output.Headers.Get("X-ID-Token"))
	})

	t.Run("set_request_headers original behavior", func(t *testing.T) {
		output, err := eval(t,
			[]proto.Message{
				&session.Session{Id: "s1", IdToken: &session.IDToken{
					Raw: "ID_TOKEN",
				}, OauthToken: &session.OAuthToken{
					AccessToken: "ACCESS_TOKEN",
				}},
			},
			&Request{
				Policy: &config.Policy{
					SetRequestHeaders: map[string]string{
						"Authorization": "Bearer ${pomerium.id_token}",
					},
				},
				Session: RequestSession{ID: "s1"},
			})
		require.NoError(t, err)

		assert.Equal(t, "Bearer ID_TOKEN", output.Headers.Get("Authorization"))
	})

	t.Run("set_request_headers no client cert", func(t *testing.T) {
		output, err := eval(t, nil,
			&Request{
				Policy: &config.Policy{
					SetRequestHeaders: map[string]string{
						"fingerprint": "${pomerium.client_cert_fingerprint}",
					},
				},
			})
		require.NoError(t, err)

		assert.Equal(t, "", output.Headers.Get("fingerprint"))
	})

	t.Run("kubernetes", func(t *testing.T) {
		t.Parallel()

		output, err := eval(t,
			[]protoreflect.ProtoMessage{
				&session.Session{Id: "s1", UserId: "u1"},
				&user.User{Id: "u1", Email: "u1@example.com"},
				newDirectoryUserRecord(directory.User{
					ID:       "u1",
					GroupIDs: []string{"g1", "g2", "g3"},
				}),
				newDirectoryGroupRecord(directory.Group{
					ID:   "g1",
					Name: "GROUP1",
				}),
				newDirectoryGroupRecord(directory.Group{
					ID:   "g2",
					Name: "GROUP2",
				}),
			},
			&Request{
				Policy: &config.Policy{
					KubernetesServiceAccountToken: "TOKEN",
				},
				Session: RequestSession{ID: "s1"},
			})
		require.NoError(t, err)

		assert.Equal(t, "Bearer TOKEN", output.Headers.Get("Authorization"))
		assert.Equal(t, "u1@example.com", output.Headers.Get("Impersonate-User"))
		assert.Equal(t, "g1,g2,g3,GROUP1,GROUP2", output.Headers.Get("Impersonate-Group"))
	})

	t.Run("routing key", func(t *testing.T) {
		t.Parallel()

		output, err := eval(t,
			[]protoreflect.ProtoMessage{},
			&Request{
				Session: RequestSession{ID: "s1"},
			})
		require.NoError(t, err)
		assert.Empty(t, output.Headers.Get("X-Pomerium-Routing-Key"))

		output, err = eval(t,
			[]protoreflect.ProtoMessage{},
			&Request{
				Policy: &config.Policy{
					EnvoyOpts: &envoy_config_cluster_v3.Cluster{
						LbPolicy: envoy_config_cluster_v3.Cluster_MAGLEV,
					},
				},
				Session: RequestSession{ID: "s1"},
			})
		require.NoError(t, err)
		assert.Equal(t, "e8bc163c82eee18733288c7d4ac636db3a6deb013ef2d37b68322be20edc45cc", output.Headers.Get("X-Pomerium-Routing-Key"))
	})

	t.Run("jwt payload email", func(t *testing.T) {
		t.Parallel()

		output, err := eval(t,
			[]protoreflect.ProtoMessage{
				&session.Session{Id: "s1", UserId: "u1"},
				&user.User{Id: "u1", Email: "user@example.com"},
			},
			&Request{
				Session: RequestSession{ID: "s1"},
			})
		require.NoError(t, err)
		assert.Equal(t, "user@example.com", output.Headers.Get("X-Pomerium-Claim-Email"))

		output, err = eval(t,
			[]protoreflect.ProtoMessage{
				&session.Session{Id: "s1", UserId: "u1"},
				newDirectoryUserRecord(directory.User{ID: "u1", Email: "directory-user@example.com"}),
			},
			&Request{
				Session: RequestSession{ID: "s1"},
			})
		require.NoError(t, err)
		assert.Equal(t, "directory-user@example.com", output.Headers.Get("X-Pomerium-Claim-Email"))
	})

	t.Run("jwt payload name", func(t *testing.T) {
		t.Parallel()

		output, err := eval(t,
			[]protoreflect.ProtoMessage{
				&session.Session{Id: "s1", UserId: "u1", Claims: map[string]*structpb.ListValue{
					"name": {Values: []*structpb.Value{
						structpb.NewStringValue("NAME_FROM_SESSION"),
					}},
				}},
			},
			&Request{
				Session: RequestSession{ID: "s1"},
			})
		require.NoError(t, err)
		assert.Equal(t, "NAME_FROM_SESSION", output.Headers.Get("X-Pomerium-Claim-Name"))

		output, err = eval(t,
			[]protoreflect.ProtoMessage{
				&session.Session{Id: "s1", UserId: "u1"},
				&user.User{Id: "u1", Claims: map[string]*structpb.ListValue{
					"name": {Values: []*structpb.Value{
						structpb.NewStringValue("NAME_FROM_USER"),
					}},
				}},
			},
			&Request{
				Session: RequestSession{ID: "s1"},
			})
		require.NoError(t, err)
		assert.Equal(t, "NAME_FROM_USER", output.Headers.Get("X-Pomerium-Claim-Name"))
	})

	t.Run("service account", func(t *testing.T) {
		t.Parallel()

		output, err := eval(t,
			[]protoreflect.ProtoMessage{
				&user.ServiceAccount{Id: "sa1", UserId: "u1"},
				&user.User{Id: "u1", Email: "u1@example.com"},
			},
			&Request{
				Session: RequestSession{ID: "sa1"},
			})
		require.NoError(t, err)
		assert.Equal(t, "u1@example.com", output.Headers.Get("X-Pomerium-Claim-Email"))
	})

	t.Run("issuer format", func(t *testing.T) {
		t.Parallel()

		for _, tc := range []struct {
			format string
			input  string
			output string
		}{
			{"", "example.com", "example.com"},
			{"hostOnly", "host-only.example.com", "host-only.example.com"},
			{"uri", "uri.example.com", "https://uri.example.com/"},
		} {
			output, err := eval(t,
				nil,
				&Request{
					HTTP: RequestHTTP{
						Hostname: tc.input,
					},
					Policy: &config.Policy{
						JWTIssuerFormat: tc.format,
					},
				})
			require.NoError(t, err)
			m := decodeJWTAssertion(t, output.Headers)
			assert.Equal(t, tc.output, m["iss"], "unexpected issuer for format=%s", tc.format)
		}
	})
}
// TestHeadersEvaluator_JWTGroupsFilter checks that the global and per-route
// JWT groups filters restrict the "groups" claim to the configured group IDs.
// It also checks the number of removed groups reported under
// log.AuthorizeLogFieldRemovedGroupsCount.
//
// Fixture: directory groups "1".."50" named "GROUP-1".."GROUP-50". For
// i in 1..10, USER-i (session SESSION-i) is a member of every i-th group.
// SESSION-11 carries an IdP "groups" claim instead of directory groups.
func TestHeadersEvaluator_JWTGroupsFilter(t *testing.T) {
	t.Parallel()

	privateJWK, _ := newJWK(t)

	// Create some user and groups data.
	var records []proto.Message
	groupsCount := 50
	for i := 1; i <= groupsCount; i++ {
		id := fmt.Sprint(i)
		records = append(records, newDirectoryGroupRecord(directory.Group{ID: id, Name: "GROUP-" + id}))
	}
	for i := 1; i <= 10; i++ {
		id := fmt.Sprintf("USER-%d", i)
		// User 1 will be in every group, user 2 in every other group, user 3 in every third group, etc.
		var groups []string
		for j := i; j <= groupsCount; j += i {
			groups = append(groups, fmt.Sprint(j))
		}
		records = append(records,
			&session.Session{Id: fmt.Sprintf("SESSION-%d", i), UserId: id},
			newDirectoryUserRecord(directory.User{ID: id, GroupIDs: groups}),
		)
	}
	// Also add a user session with an upstream "groups" claim from the IdP.
	records = append(records,
		&session.Session{Id: "SESSION-11", UserId: "USER-11", Claims: map[string]*structpb.ListValue{
			"groups": newList("foo", "bar", "baz"),
		}},
	)

	cases := []struct {
		name         string
		globalFilter []string
		routeFilter  []string // nil means the request has no route policy at all
		sessionID    string
		expected     []any
		removed      int // 0 means the removed-count log field must be absent
	}{
		{"global filter 1", []string{"42", "1"}, nil, "SESSION-1", []any{"1", "42", "GROUP-1", "GROUP-42"}, 48},
		{"global filter 2", []string{"42", "1"}, nil, "SESSION-2", []any{"42", "GROUP-42"}, 24},
		{"route filter 1", nil, []string{"42", "1"}, "SESSION-1", []any{"1", "42", "GROUP-1", "GROUP-42"}, 48},
		{"route filter 2", nil, []string{"42", "1"}, "SESSION-2", []any{"42", "GROUP-42"}, 24},
		{"both filters 1", []string{"1"}, []string{"42"}, "SESSION-1", []any{"1", "42", "GROUP-1", "GROUP-42"}, 48},
		{"both filters 2", []string{"1"}, []string{"42"}, "SESSION-2", []any{"42", "GROUP-42"}, 24},
		{"cannot filter by name", []string{"GROUP-1"}, nil, "SESSION-1", []any{}, 50},
		{"overlapping", []string{"1"}, []string{"1"}, "SESSION-1", []any{"1", "GROUP-1"}, 49},
		{"empty route filter", []string{"1", "2", "3"}, []string{}, "SESSION-1", []any{"1", "2", "3", "GROUP-1", "GROUP-2", "GROUP-3"}, 47},
		{
			"no filtering", nil, nil, "SESSION-10",
			[]any{"10", "20", "30", "40", "50", "GROUP-10", "GROUP-20", "GROUP-30", "GROUP-40", "GROUP-50"},
			0,
		},
		// filtering has no effect on groups from an IdP "groups" claim
		{"groups claim", []string{"foo", "quux"}, nil, "SESSION-11", []any{"foo", "bar", "baz"}, 0},
	}

	ctx := storage.WithQuerier(context.Background(), storage.NewStaticQuerier(records...))
	for _, c := range cases {
		t.Run(c.name, func(t *testing.T) {
			// Named s (not store) to avoid shadowing the imported store package.
			s := store.New()
			s.UpdateSigningKey(privateJWK)
			s.UpdateJWTGroupsFilter(config.NewJWTGroupsFilter(c.globalFilter))

			req := &Request{Session: RequestSession{ID: c.sessionID}}
			if c.routeFilter != nil {
				req.Policy = &config.Policy{
					JWTGroupsFilter: config.NewJWTGroupsFilter(c.routeFilter),
				}
			}

			e := NewHeadersEvaluator(s)
			resp, err := e.Evaluate(ctx, req)
			require.NoError(t, err)

			decoded := decodeJWTAssertion(t, resp.Headers)
			assert.Equal(t, c.expected, decoded["groups"])
			if c.removed > 0 {
				assert.Equal(t, c.removed, resp.AdditionalLogFields[log.AuthorizeLogFieldRemovedGroupsCount])
			} else {
				assert.Nil(t, resp.AdditionalLogFields[log.AuthorizeLogFieldRemovedGroupsCount])
			}
		})
	}
}
// newJWK generates a fresh signing key and returns it twice. The first result
// is the private JWK, built via cryptutil.PrivateJWKFromBytes. The second is
// the matching public JWK, built via cryptutil.PublicJWKFromBytes from the same
// encoded key. Any failure aborts the calling test or benchmark via require.
func newJWK(t testing.TB) (privateJWK, publicJWK *jose.JSONWebKey) {
	t.Helper()

	key, err := cryptutil.NewSigningKey()
	require.NoError(t, err)
	encodedKey, err := cryptutil.EncodePrivateKey(key)
	require.NoError(t, err)

	priv, err := cryptutil.PrivateJWKFromBytes(encodedKey)
	require.NoError(t, err)
	pub, err := cryptutil.PublicJWKFromBytes(encodedKey)
	require.NoError(t, err)

	return priv, pub
}
// decodeJWTAssertion returns the decoded payload (claims) of the JWT carried in
// the X-Pomerium-Jwt-Assertion header. The JWS signature is not verified.
//
// Numbers are decoded as json.Number rather than float64. This lets callers
// check whether numeric claims such as 'iat' and 'exp' were encoded as
// integers. This helper performs no such check itself.
func decodeJWTAssertion(t *testing.T, headers http.Header) map[string]any {
	t.Helper()

	payload := decodeJWSPayload(t, headers.Get("X-Pomerium-Jwt-Assertion"))
	d := json.NewDecoder(bytes.NewReader(payload))
	d.UseNumber()

	var m map[string]any
	require.NoError(t, d.Decode(&m))
	return m
}
// decodeJWSPayload extracts and base64url-decodes the payload segment of a
// compact-serialized JWS. The signature is not verified.
//
// A compact JWS is three base64url-encoded segments joined by '.', and the
// payload is the middle one.
// cf. https://www.rfc-editor.org/rfc/rfc7515#section-7.1
func decodeJWSPayload(t *testing.T, jws string) []byte {
	t.Helper()

	segments := strings.Split(jws, ".")
	require.Equal(t, 3, len(segments), "jws should have 3 parts: %s", jws)

	const payloadIndex = 1
	decoded, err := base64.RawURLEncoding.DecodeString(segments[payloadIndex])
	require.NoError(t, err)
	return decoded
}
// newDirectoryGroupRecord wraps directoryGroup in a static databroker record of
// type directory.GroupRecordType. The record data is the group's JSON encoding.
func newDirectoryGroupRecord(directoryGroup directory.Group) *databroker.Record {
	return storage.NewStaticRecord(directory.GroupRecordType, mustJSONRoundTripStruct(directoryGroup))
}

// newDirectoryUserRecord wraps directoryUser in a static databroker record of
// type directory.UserRecordType. The record data is the user's JSON encoding.
func newDirectoryUserRecord(directoryUser directory.User) *databroker.Record {
	return storage.NewStaticRecord(directory.UserRecordType, mustJSONRoundTripStruct(directoryUser))
}

// mustJSONRoundTripStruct converts v to a *structpb.Struct by encoding it as a
// JSON object and decoding that into a map. It panics on failure. These are
// test fixtures, and silently returning an empty or nil struct would make a
// later assertion fail far from the real cause.
func mustJSONRoundTripStruct(v any) *structpb.Struct {
	bs, err := json.Marshal(v)
	if err != nil {
		panic(fmt.Sprintf("marshaling %T to JSON: %v", v, err))
	}
	m := map[string]any{}
	if err := json.Unmarshal(bs, &m); err != nil {
		panic(fmt.Sprintf("unmarshaling %T JSON into a map: %v", v, err))
	}
	s, err := structpb.NewStruct(m)
	if err != nil {
		panic(fmt.Sprintf("converting %T to structpb.Struct: %v", v, err))
	}
	return s
}
// newList builds a *structpb.ListValue from the given Go values. It panics if
// any value cannot be represented as a structpb.Value, instead of silently
// returning nil as a test fixture.
func newList(v ...any) *structpb.ListValue {
	lv, err := structpb.NewList(v)
	if err != nil {
		panic(fmt.Sprintf("newList: %v", err))
	}
	return lv
}