From e18c04216e143392217d88524008cb08375e8fe5 Mon Sep 17 00:00:00 2001 From: Joe Kralicky Date: Tue, 18 Jun 2024 21:40:03 -0400 Subject: [PATCH] authorize: hot path identity provider lookup optimizations --- authenticate/handlers_test.go | 6 +- authorize/grpc.go | 2 +- config/identity.go | 133 ++++++++++++- config/identity_benchmark_test.go | 184 ++++++++++++++++++ config/session.go | 25 ++- config/session_test.go | 11 +- go.mod | 2 + go.sum | 4 + internal/sessions/cookie/cookie_store.go | 3 +- internal/sessions/cookie/cookie_store_test.go | 3 +- internal/sessions/header/header_store.go | 3 +- internal/sessions/middleware.go | 2 +- internal/sessions/mock/mock_store.go | 3 +- internal/sessions/mock/mock_store_test.go | 3 +- internal/sessions/queryparam/query_store.go | 3 +- internal/sessions/store.go | 7 +- pkg/cryptutil/kek.go | 2 +- pkg/cryptutil/token.go | 8 +- pkg/encoding/base58/alphabet.go | 49 ----- pkg/encoding/base58/base58.go | 75 ------- pkg/encoding/base58/base58_test.go | 100 ---------- pkg/encoding/base58/doc.go | 22 --- pkg/grpc/device/device.go | 2 +- pkg/grpc/identity/identity.go | 2 +- pkg/telemetry/requestid/requestid.go | 2 +- pkg/webauthnutil/credential_storage.go | 2 +- proxy/data.go | 2 +- proxy/handlers.go | 2 +- proxy/proxy.go | 9 + 29 files changed, 387 insertions(+), 284 deletions(-) create mode 100644 config/identity_benchmark_test.go delete mode 100644 pkg/encoding/base58/alphabet.go delete mode 100644 pkg/encoding/base58/base58.go delete mode 100644 pkg/encoding/base58/base58_test.go delete mode 100644 pkg/encoding/base58/doc.go diff --git a/authenticate/handlers_test.go b/authenticate/handlers_test.go index c8f5342a5..6fa0948ee 100644 --- a/authenticate/handlers_test.go +++ b/authenticate/handlers_test.go @@ -249,7 +249,7 @@ func TestAuthenticate_SignOut(t *testing.T) { } u.RawQuery = params.Encode() r := httptest.NewRequest(tt.method, u.String(), nil) - state, err := tt.sessionStore.LoadSession(r) + state, err := tt.sessionStore.LoadSession(context.TODO(), r) if err != nil { t.Fatal(err) } @@ -481,7 +481,7 @@ func TestAuthenticate_SessionValidatorMiddleware(t *testing.T) { options: config.NewAtomicOptions(), } r := httptest.NewRequest(http.MethodGet, "/", nil) - state, err := tt.session.LoadSession(r) + state, err := tt.session.LoadSession(context.TODO(), r) if err != nil { t.Fatal(err) } @@ -586,7 +586,7 @@ func TestAuthenticate_userInfo(t *testing.T) { }), } r := httptest.NewRequest(http.MethodGet, tt.url, nil) - state, err := tt.sessionStore.LoadSession(r) + state, err := tt.sessionStore.LoadSession(context.TODO(), r) if err != nil { t.Fatal(err) } diff --git a/authorize/grpc.go b/authorize/grpc.go index 85aafbb62..6f865efc7 100644 --- a/authorize/grpc.go +++ b/authorize/grpc.go @@ -47,7 +47,7 @@ func (a *Authorize) Check(ctx context.Context, in *envoy_service_auth_v3.CheckRe hreq := getHTTPRequestFromCheckRequest(in) ctx = requestid.WithValue(ctx, requestid.FromHTTPHeader(hreq.Header)) - sessionState, _ := state.sessionStore.LoadSessionState(hreq) + sessionState, _ := state.sessionStore.LoadSessionState(ctx, hreq) var s sessionOrServiceAccount var u *user.User diff --git a/config/identity.go b/config/identity.go index 4fd3a25bd..94513f660 100644 --- a/config/identity.go +++ b/config/identity.go @@ -1,6 +1,13 @@ package config import ( + "context" + "fmt" + "net" + "slices" + "strings" + + art "github.com/kralicky/go-adaptive-radix-tree" "github.com/pomerium/pomerium/internal/urlutil" "github.com/pomerium/pomerium/pkg/grpc/identity" ) @@ -51,7 +58,7 @@ func (o *Options) GetIdentityProviderForPolicy(policy *Policy) (*identity.Provid } // GetIdentityProviderForRequestURL gets the identity provider associated with the given request URL. -func (o *Options) GetIdentityProviderForRequestURL(requestURL string) (*identity.Provider, error) { +func (o *Options) GetIdentityProviderForRequestURL(ctx context.Context, requestURL string) (*identity.Provider, error) { u, err := urlutil.ParseAndValidateURL(requestURL) if err != nil { return nil, err @@ -65,3 +72,127 @@ func (o *Options) GetIdentityProviderForRequestURL(requestURL string) (*identity } return o.GetIdentityProviderForPolicy(nil) } + +type PolicyCache struct { + domainTree art.Tree[*domainNode] + matchPorts bool +} + +func NewPolicyCache(options *Options) (*PolicyCache, error) { + tree := art.New[*domainNode]() + shouldMatchPorts := !options.IsRuntimeFlagSet(RuntimeFlagMatchAnyIncomingPort) + for _, policy := range options.GetAllPolicies() { + u, err := urlutil.ParseAndValidateURL(policy.From) + if err != nil { + return nil, err + } + domains := urlutil.GetDomainsForURL(u, shouldMatchPorts) + for _, domain := range domains { + host, port, err := net.SplitHostPort(domain) // todo: this is not optimal + if err != nil { + host, port = domain, "" + } + domainKey := radixKeyForHostPort(host, port) + tree.Update(domainKey, newDomainNode, func(dn **domainNode) { + if policy.Prefix != "" { + (*dn).policiesByPrefix.Insert(art.Key(policy.Prefix), policy) + } else if policy.Path != "" { + (*dn).policiesByPrefix.Insert(art.Key(policy.Path), policy) + } else if policy.compiledRegex != nil { + (*dn).policiesByRegex = append((*dn).policiesByRegex, policy) + } else { + (*dn).policiesNoPathMatching = append((*dn).policiesNoPathMatching, policy) + } + }) + } + } + return &PolicyCache{ + domainTree: tree, + matchPorts: shouldMatchPorts, + }, nil +} + +func (pc *PolicyCache) GetIdentityProviderForRequestURL(ctx context.Context, o *Options, requestURL string) (*identity.Provider, error) { + u, err := urlutil.ParseAndValidateURL(requestURL) + if err != nil { + return nil, err + } + + domainKey := radixKeyForHostPort(u.Hostname(), u.Port()) + domain, ok := pc.domainTree.Resolve(domainKey, wildcardResolver) + if !ok { + return nil, fmt.Errorf("no identity provider found for request URL %s", requestURL) + } + var policy *Policy + if len(u.Path) == 0 || (len(u.Path) == 1 && u.Path[0] == '/') && len(domain.policiesNoPathMatching) > 0 { + policy = &domain.policiesNoPathMatching[0] + } else { + if domain.policiesByPrefix.Size() > 0 { + pathKey := art.Key(u.Path) + actualKey, val, found := domain.policiesByPrefix.SearchNearest(pathKey) + if found { + // check for prefix match or exact match + if c := actualKey.Compare(pathKey); c < 0 { + if val.Prefix != "" && strings.HasPrefix(u.Path, val.Prefix) { + policy = &val + } + } else if c == 0 { + if val.Path != "" || val.Prefix != "" { + policy = &val + } + } + } + } + if policy == nil { + for _, p := range domain.policiesByRegex { + if p.compiledRegex.MatchString(u.Path) { + policy = &p + break + } + } + } + } + if policy != nil { + return o.GetIdentityProviderForPolicy(policy) + } + + return nil, fmt.Errorf("no identity provider found for request URL %s", requestURL) +} + +type domainNode struct { + policiesByPrefix art.Tree[Policy] + policiesByRegex []Policy + policiesNoPathMatching []Policy +} + +func newDomainNode() *domainNode { + return &domainNode{ + policiesByPrefix: art.New[Policy](), + } +} + +func radixKeyForHostPort(host, port string) art.Key { + parts := strings.Split(host, ".") + sb := strings.Builder{} + sb.WriteString(port) + for i := len(parts) - 1; i >= 0; i-- { + sb.WriteByte('.') + sb.WriteString(parts[i]) + } + return art.Key(sb.String()) +} + +func wildcardResolver(key art.Key, conflictIndex int) (art.Key, int) { + if conflictIndex >= len(key) { + return nil, -1 + } + c := key[conflictIndex] + if c != '*' && c != '.' { + nextDot := slices.Index(key[conflictIndex:], '.') + if nextDot == -1 { + return art.Key("*"), len(key) + } + return art.Key("*"), conflictIndex + nextDot + } + return nil, -1 +} diff --git a/config/identity_benchmark_test.go b/config/identity_benchmark_test.go new file mode 100644 index 000000000..1c1dda52b --- /dev/null +++ b/config/identity_benchmark_test.go @@ -0,0 +1,184 @@ +package config_test + +import ( + "context" + "encoding/base64" + "fmt" + "strings" + "testing" + + "github.com/pomerium/pomerium/config" + "github.com/pomerium/pomerium/pkg/cryptutil" + "github.com/stretchr/testify/require" +) + +func BenchmarkGetIdentityProviderForRequestURL_Old(b *testing.B) { + runBench := func(numPolicies int) func(b *testing.B) { + return func(b *testing.B) { + b.ReportAllocs() + options := config.NewDefaultOptions() + sharedKey := cryptutil.NewKey() + options.SharedKey = base64.StdEncoding.EncodeToString(sharedKey) + options.Provider = "oidc" + options.ProviderURL = "https://oidc.example.com" + options.ClientID = "client_id" + options.ClientSecret = "client_secret" + urlFormat := "https://*.foo.bar.test-%d.example.com" + for i := range numPolicies { + options.Policies = append(options.Policies, + config.Policy{ + From: fmt.Sprintf(urlFormat, i), + To: mustParseWeightedURLs(b, fmt.Sprintf("https://p2-%d", i)), + IDPClientID: fmt.Sprintf("client_id_%d", i), + IDPClientSecret: fmt.Sprintf("client_secret_%d", i), + }, + ) + } + require.NoError(b, options.Validate()) + + b.ResetTimer() + for range b.N { + idp, err := options.GetIdentityProviderForRequestURL(context.Background(), fmt.Sprintf(urlFormat, numPolicies-1)) + require.NoError(b, err) + require.Equal(b, fmt.Sprintf("client_id_%d", numPolicies-1), idp.ClientId) + require.Equal(b, fmt.Sprintf("client_secret_%d", numPolicies-1), idp.ClientSecret) + } + } + } + + b.Run("5 policies", runBench(5)) + b.Run("50 policies", runBench(50)) + b.Run("500 policies", runBench(500)) + b.Run("5000 policies", runBench(5000)) +} + +var bench = func(fill func(i int, p *config.Policy) string, numPolicies int) func(b *testing.B) { + return func(b *testing.B) { + b.ReportAllocs() + options := config.NewDefaultOptions() + sharedKey := cryptutil.NewKey() + options.SharedKey = base64.StdEncoding.EncodeToString(sharedKey) + options.Provider = "oidc" + options.ProviderURL = "https://oidc.example.com" + options.ClientID = "client_id" + options.ClientSecret = "client_secret" + var allUrls []string + for i := range numPolicies { + p := config.Policy{ + To: mustParseWeightedURLs(b, fmt.Sprintf("https://p2-%d", i)), + IDPClientID: fmt.Sprintf("client_id_%d", i), + IDPClientSecret: fmt.Sprintf("client_secret_%d", i), + } + + allUrls = append(allUrls, fill(i, &p)) + options.Policies = append(options.Policies, p) + } + require.NoError(b, options.Validate()) + + cache, err := config.NewPolicyCache(options) + require.NoError(b, err) + + b.ResetTimer() + for i := range b.N { + reqUrl := strings.Replace(allUrls[i%numPolicies], "*", fmt.Sprint(i), 1) + idp, err := cache.GetIdentityProviderForRequestURL(context.Background(), options, reqUrl) + require.NoError(b, err) + require.Equal(b, fmt.Sprintf("client_id_%d", i%numPolicies), idp.ClientId) + require.Equal(b, fmt.Sprintf("client_secret_%d", i%numPolicies), idp.ClientSecret) + } + } +} + +func BenchmarkGetIdentityProviderForRequestURL_New_DomainMatchOnly(b *testing.B) { + domainMatchingOnly := func(i int, p *config.Policy) string { + p.From = fmt.Sprintf("https://*.foo.bar.test-%d.example.com", i) + return p.From + } + + b.Run("5 policies (domain matching only)", bench(domainMatchingOnly, 5)) + b.Run("50 policies (domain matching only)", bench(domainMatchingOnly, 50)) + b.Run("500 policies (domain matching only)", bench(domainMatchingOnly, 500)) + b.Run("5000 policies (domain matching only)", bench(domainMatchingOnly, 5000)) +} + +func BenchmarkGetIdentityProviderForRequestURL_New_PathMatchOnly(b *testing.B) { + pathMatchingOnly := func(i int, p *config.Policy) string { + p.From = "https://example.com" + p.Path = fmt.Sprintf("/foo/bar/path%d", i) + return p.From + p.Path + } + b.Run("5 policies (path matching only)", bench(pathMatchingOnly, 5)) + b.Run("50 policies (path matching only)", bench(pathMatchingOnly, 50)) + b.Run("500 policies (path matching only)", bench(pathMatchingOnly, 500)) + b.Run("5000 policies (path matching only)", bench(pathMatchingOnly, 5000)) +} + +func BenchmarkGetIdentityProviderForRequestURL_New_DomainAndPathMatching(b *testing.B) { + combinedMatching := func(numPathsPerDomain int) func(i int, p *config.Policy) string { + // returns a sequence of policies (ex: numPathsPerDomain=3) + // https://*.foo.bar.test-0.example.com + // https://*.foo.bar.test-0.example.com/foo/bar/path1 + // https://*.foo.bar.test-0.example.com/foo/bar/path2 + // https://*.foo.bar.test-1.example.com + // https://*.foo.bar.test-1.example.com/foo/bar/path1 + // https://*.foo.bar.test-1.example.com/foo/bar/path2 + return func(i int, p *config.Policy) string { + domain := fmt.Sprintf("https://*.foo.bar.test-%d.example.com", i/numPathsPerDomain) + pathIdx := i % numPathsPerDomain + var path string + if pathIdx == 0 { + path = "" + } else { + path = fmt.Sprintf("/foo/bar/path%d", pathIdx) + } + p.From = domain + p.Path = path + return domain + path + } + } + + b.Run("25 policies (5 domains, 5 paths per domain)", bench(combinedMatching(5), 25)) + b.Run("500 policies (50 domains, 10 paths per domain)", bench(combinedMatching(10), 500)) + b.Run("500 policies (10 domains, 50 paths per domain)", bench(combinedMatching(50), 500)) + b.Run("2500 policies (50 domains, 50 paths per domain)", bench(combinedMatching(50), 2500)) + b.Run("5000 policies (100 domains, 50 paths per domain)", bench(combinedMatching(50), 5000)) + b.Run("5000 policies (50 domains, 100 paths per domain)", bench(combinedMatching(100), 5000)) + b.Run("10000 policies (100 domains, 100 paths per domain)", bench(combinedMatching(100), 10000)) +} + +func BenchmarkGetIdentityProviderForRequestURL_New_DomainAndPrefixMatching(b *testing.B) { + combinedMatching := func(numPathsPerDomain int) func(i int, p *config.Policy) string { + // returns a sequence of policies (ex: numPathsPerDomain=3) + // https://*.foo.bar.test-0.example.com/0 + // https://*.foo.bar.test-0.example.com/0/1 + // https://*.foo.bar.test-0.example.com/0/1/2 + // https://*.foo.bar.test-1.example.com/0 + // https://*.foo.bar.test-1.example.com/0/1 + // https://*.foo.bar.test-1.example.com/0/1/2 + return func(i int, p *config.Policy) string { + domain := fmt.Sprintf("https://*.foo.bar.test-%d.example.com", i/numPathsPerDomain) + pathIdx := i % numPathsPerDomain + var prefix strings.Builder + for i := range pathIdx + 1 { + fmt.Fprintf(&prefix, "/%d", i) + } + p.From = domain + p.Prefix = prefix.String() + return domain + p.Prefix + } + } + + b.Run("25 policies (5 domains, 5 paths per domain)", bench(combinedMatching(5), 25)) + b.Run("500 policies (50 domains, 10 paths per domain)", bench(combinedMatching(10), 500)) + b.Run("500 policies (10 domains, 50 paths per domain)", bench(combinedMatching(50), 500)) + b.Run("2500 policies (50 domains, 50 paths per domain)", bench(combinedMatching(50), 2500)) + b.Run("5000 policies (100 domains, 50 paths per domain)", bench(combinedMatching(50), 5000)) + b.Run("5000 policies (50 domains, 100 paths per domain)", bench(combinedMatching(100), 5000)) + b.Run("10000 policies (100 domains, 100 paths per domain)", bench(combinedMatching(100), 10000)) +} + +func mustParseWeightedURLs(t testing.TB, urls ...string) []config.WeightedURL { + wu, err := config.ParseWeightedUrls(urls...) + require.NoError(t, err) + return wu +} diff --git a/config/session.go b/config/session.go index f0268d611..067394455 100644 --- a/config/session.go +++ b/config/session.go @@ -1,6 +1,7 @@ package config import ( + "context" "fmt" "net/http" @@ -10,20 +11,27 @@ import ( "github.com/pomerium/pomerium/internal/sessions/cookie" "github.com/pomerium/pomerium/internal/sessions/header" "github.com/pomerium/pomerium/internal/sessions/queryparam" + "github.com/pomerium/pomerium/internal/telemetry/trace" "github.com/pomerium/pomerium/internal/urlutil" ) // A SessionStore saves and loads sessions based on the options. type SessionStore struct { - options *Options - encoder encoding.MarshalUnmarshaler - loader sessions.SessionLoader + options *Options + encoder encoding.MarshalUnmarshaler + loader sessions.SessionLoader + policyCache *PolicyCache } // NewSessionStore creates a new SessionStore from the Options. func NewSessionStore(options *Options) (*SessionStore, error) { + cache, err := NewPolicyCache(options) + if err != nil { + return nil, err + } store := &SessionStore{ - options: options, + options: options, + policyCache: cache, } sharedKey, err := options.GetSharedKey() @@ -57,8 +65,11 @@ func NewSessionStore(options *Options) (*SessionStore, error) { } // LoadSessionState loads the session state from a request. -func (store *SessionStore) LoadSessionState(r *http.Request) (*sessions.State, error) { - rawJWT, err := store.loader.LoadSession(r) +func (store *SessionStore) LoadSessionState(ctx context.Context, r *http.Request) (*sessions.State, error) { + ctx, span := trace.StartSpan(ctx, "session_store.load_session_state") + defer span.End() + + rawJWT, err := store.loader.LoadSession(ctx, r) if err != nil { return nil, err } @@ -71,7 +82,7 @@ func (store *SessionStore) LoadSessionState(r *http.Request) (*sessions.State, e // confirm that the identity provider id matches the state if state.IdentityProviderID != "" { - idp, err := store.options.GetIdentityProviderForRequestURL(urlutil.GetAbsoluteURL(r).String()) + idp, err := store.policyCache.GetIdentityProviderForRequestURL(ctx, store.options, urlutil.GetAbsoluteURL(r).String()) if err != nil { return nil, err } diff --git a/config/session_test.go b/config/session_test.go index 058850b86..93b3c792d 100644 --- a/config/session_test.go +++ b/config/session_test.go @@ -1,6 +1,7 @@ package config import ( + "context" "encoding/base64" "net/http" "net/url" @@ -70,7 +71,7 @@ func TestSessionStore_LoadSessionState(t *testing.T) { t.Run("mssing", func(t *testing.T) { r, err := http.NewRequest(http.MethodGet, "https://p1.example.com", nil) require.NoError(t, err) - s, err := store.LoadSessionState(r) + s, err := store.LoadSessionState(context.TODO(), r) assert.ErrorIs(t, err, sessions.ErrNoSessionFound) assert.Nil(t, s) }) @@ -85,7 +86,7 @@ func TestSessionStore_LoadSessionState(t *testing.T) { urlutil.QuerySession: {rawJWS}, }.Encode(), nil) require.NoError(t, err) - s, err := store.LoadSessionState(r) + s, err := store.LoadSessionState(context.TODO(), r) assert.NoError(t, err) assert.Empty(t, cmp.Diff(&sessions.State{ Issuer: "authenticate.example.com", @@ -103,7 +104,7 @@ func TestSessionStore_LoadSessionState(t *testing.T) { r, err := http.NewRequest(http.MethodGet, "https://p2.example.com", nil) require.NoError(t, err) r.Header.Set(httputil.HeaderPomeriumAuthorization, rawJWS) - s, err := store.LoadSessionState(r) + s, err := store.LoadSessionState(context.TODO(), r) assert.NoError(t, err) assert.Empty(t, cmp.Diff(&sessions.State{ Issuer: "authenticate.example.com", @@ -121,7 +122,7 @@ func TestSessionStore_LoadSessionState(t *testing.T) { r, err := http.NewRequest(http.MethodGet, "https://p2.example.com", nil) require.NoError(t, err) r.Header.Set(httputil.HeaderPomeriumAuthorization, rawJWS) - s, err := store.LoadSessionState(r) + s, err := store.LoadSessionState(context.TODO(), r) assert.Error(t, err) assert.Nil(t, s) }) @@ -134,7 +135,7 @@ func TestSessionStore_LoadSessionState(t *testing.T) { r, err := http.NewRequest(http.MethodGet, "https://p2.example.com", nil) require.NoError(t, err) r.Header.Set(httputil.HeaderPomeriumAuthorization, rawJWS) - s, err := store.LoadSessionState(r) + s, err := store.LoadSessionState(context.TODO(), r) assert.NoError(t, err) assert.Empty(t, cmp.Diff(&sessions.State{ Issuer: "authenticate.example.com", diff --git a/go.mod b/go.mod index 189b52900..20c66e6d6 100644 --- a/go.mod +++ b/go.mod @@ -36,6 +36,7 @@ require ( github.com/hashicorp/golang-lru/v2 v2.0.7 github.com/jackc/pgx/v5 v5.6.0 github.com/klauspost/compress v1.17.8 + github.com/kralicky/go-adaptive-radix-tree v0.0.0-20240619012453-a8f80032ba31 github.com/martinlindhe/base36 v1.1.1 github.com/mholt/acmez/v2 v2.0.1 github.com/minio/minio-go/v7 v7.0.70 @@ -104,6 +105,7 @@ require ( github.com/Nvveen/Gotty v0.0.0-20120604004816-cd527374f1e5 // indirect github.com/OneOfOne/xxhash v1.2.8 // indirect github.com/agnivade/levenshtein v1.1.1 // indirect + github.com/akamensky/base58 v0.0.0-20210829145138-ce8bf8802e8f // indirect github.com/andybalholm/brotli v1.0.5 // indirect github.com/apapsch/go-jsonmerge/v2 v2.0.0 // indirect github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.2 // indirect diff --git a/go.sum b/go.sum index 23ca4df22..5824fbbeb 100644 --- a/go.sum +++ b/go.sum @@ -74,6 +74,8 @@ github.com/VictoriaMetrics/fastcache v1.12.2 h1:N0y9ASrJ0F6h0QaC3o6uJb3NIZ9VKLjC github.com/VictoriaMetrics/fastcache v1.12.2/go.mod h1:AmC+Nzz1+3G2eCPapF6UcsnkThDcMsQicp4xDukwJYI= github.com/agnivade/levenshtein v1.1.1 h1:QY8M92nrzkmr798gCo3kmMyqXFzdQVpxLlGPRBij0P8= github.com/agnivade/levenshtein v1.1.1/go.mod h1:veldBMzWxcCG2ZvUTKD2kJNRdCk5hVbJomOvKkmgYbo= +github.com/akamensky/base58 v0.0.0-20210829145138-ce8bf8802e8f h1:z8MkSJCUyTmW5YQlxsMLBlwA7GmjxC7L4ooicxqnhz8= +github.com/akamensky/base58 v0.0.0-20210829145138-ce8bf8802e8f/go.mod h1:UdUwYgAXBiL+kLfcqxoQJYkHA/vl937/PbFhZM34aZs= github.com/alecthomas/template v0.0.0-20160405071501-a0175ee3bccc/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc= github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc= github.com/alecthomas/units v0.0.0-20151022065526-2efee857e7cf/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0= @@ -414,6 +416,8 @@ github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/kralicky/go-adaptive-radix-tree v0.0.0-20240619012453-a8f80032ba31 h1:cOZUoNuv9OuCZKepddeIW87ScJWwveLfLcJMaii6YCA= +github.com/kralicky/go-adaptive-radix-tree v0.0.0-20240619012453-a8f80032ba31/go.mod h1:oJwexVSshEat0E3evyKOH6QzN8GFWrhLvEoh8GiJzss= github.com/lib/pq v1.10.7 h1:p7ZhMD+KsSRozJr34udlUrhboJwWAgCg34+/ZZNvZZw= github.com/lib/pq v1.10.7/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= github.com/libdns/libdns v0.2.2 h1:O6ws7bAfRPaBsgAYt8MDe2HcNBGC29hkZ9MX2eUSX3s= diff --git a/internal/sessions/cookie/cookie_store.go b/internal/sessions/cookie/cookie_store.go index 70318a6cb..b48cdbdc4 100644 --- a/internal/sessions/cookie/cookie_store.go +++ b/internal/sessions/cookie/cookie_store.go @@ -2,6 +2,7 @@ package cookie import ( + "context" "errors" "fmt" "net/http" @@ -117,7 +118,7 @@ func getCookies(r *http.Request, name string) []*http.Cookie { } // LoadSession returns a State from the cookie in the request. -func (cs *Store) LoadSession(r *http.Request) (string, error) { +func (cs *Store) LoadSession(ctx context.Context, r *http.Request) (string, error) { opts := cs.getOptions() cookies := getCookies(r, opts.Name) if len(cookies) == 0 { diff --git a/internal/sessions/cookie/cookie_store_test.go b/internal/sessions/cookie/cookie_store_test.go index 90232608a..0cb75d0bf 100644 --- a/internal/sessions/cookie/cookie_store_test.go +++ b/internal/sessions/cookie/cookie_store_test.go @@ -1,6 +1,7 @@ package cookie import ( + "context" "crypto/rand" "errors" "fmt" @@ -145,7 +146,7 @@ func TestStore_SaveSession(t *testing.T) { r.AddCookie(cookie) } - jwt, err := s.LoadSession(r) + jwt, err := s.LoadSession(context.TODO(), r) if (err != nil) != tt.wantLoadErr { t.Errorf("LoadSession() error = %v, wantErr %v", err, tt.wantLoadErr) return diff --git a/internal/sessions/header/header_store.go b/internal/sessions/header/header_store.go index f04ab069c..686cbc3fb 100644 --- a/internal/sessions/header/header_store.go +++ b/internal/sessions/header/header_store.go @@ -3,6 +3,7 @@ package header import ( + "context" "net/http" "strings" @@ -31,7 +32,7 @@ func NewStore(enc encoding.Unmarshaler) *Store { } // LoadSession tries to retrieve the token string from the Authorization header. -func (as *Store) LoadSession(r *http.Request) (string, error) { +func (as *Store) LoadSession(_ context.Context, r *http.Request) (string, error) { jwt := TokenFromHeaders(r) if jwt == "" { return "", sessions.ErrNoSessionFound diff --git a/internal/sessions/middleware.go b/internal/sessions/middleware.go index 6d584e9b2..c16861430 100644 --- a/internal/sessions/middleware.go +++ b/internal/sessions/middleware.go @@ -23,7 +23,7 @@ func retrieve(s SessionLoader) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { hfn := func(w http.ResponseWriter, r *http.Request) { ctx := r.Context() - jwt, err := s.LoadSession(r) + jwt, err := s.LoadSession(ctx, r) ctx = NewContext(ctx, jwt, err) next.ServeHTTP(w, r.WithContext(ctx)) } diff --git a/internal/sessions/mock/mock_store.go b/internal/sessions/mock/mock_store.go index 0b2f76a7d..a1d25dde0 100644 --- a/internal/sessions/mock/mock_store.go +++ b/internal/sessions/mock/mock_store.go @@ -2,6 +2,7 @@ package mock import ( + "context" "net/http" "github.com/pomerium/pomerium/internal/encoding" @@ -30,7 +31,7 @@ func (ms *Store) ClearSession(http.ResponseWriter, *http.Request) { } // LoadSession returns the session and a error -func (ms Store) LoadSession(*http.Request) (string, error) { +func (ms Store) LoadSession(context.Context, *http.Request) (string, error) { var signer encoding.MarshalUnmarshaler signer, _ = jws.NewHS256Signer(ms.Secret) jwt, _ := signer.Marshal(ms.Session) diff --git a/internal/sessions/mock/mock_store_test.go b/internal/sessions/mock/mock_store_test.go index bea3befd5..e228dd5f4 100644 --- a/internal/sessions/mock/mock_store_test.go +++ b/internal/sessions/mock/mock_store_test.go @@ -1,6 +1,7 @@ package mock import ( + "context" "testing" "github.com/stretchr/testify/assert" @@ -40,7 +41,7 @@ func TestStore(t *testing.T) { t.Errorf("mockstore.SaveSession() error = %v, wantSaveErr %v", err, tt.wantSaveErr) return } - got, err := ms.LoadSession(nil) + got, err := ms.LoadSession(context.TODO(), nil) if (err != nil) != tt.wantLoadErr { t.Errorf("mockstore.LoadSession() error = %v, wantLoadErr %v", err, tt.wantLoadErr) return diff --git a/internal/sessions/queryparam/query_store.go b/internal/sessions/queryparam/query_store.go index b86d18f3c..cbae4b13c 100644 --- a/internal/sessions/queryparam/query_store.go +++ b/internal/sessions/queryparam/query_store.go @@ -3,6 +3,7 @@ package queryparam import ( + "context" "net/http" "github.com/pomerium/pomerium/internal/encoding" @@ -43,7 +44,7 @@ func NewStore(enc encoding.MarshalUnmarshaler, qp string) *Store { } // LoadSession tries to retrieve the token string from URL query parameters. -func (qp *Store) LoadSession(r *http.Request) (string, error) { +func (qp *Store) LoadSession(_ context.Context, r *http.Request) (string, error) { jwt := r.URL.Query().Get(qp.queryParamKey) if jwt == "" { return "", sessions.ErrNoSessionFound diff --git a/internal/sessions/store.go b/internal/sessions/store.go index 32044fe98..6e6d04b05 100644 --- a/internal/sessions/store.go +++ b/internal/sessions/store.go @@ -3,6 +3,7 @@ package sessions import ( + "context" "errors" "net/http" ) @@ -16,14 +17,14 @@ type SessionStore interface { // SessionLoader defines an interface for loading a session. type SessionLoader interface { - LoadSession(*http.Request) (string, error) + LoadSession(context.Context, *http.Request) (string, error) } type multiSessionLoader []SessionLoader -func (l multiSessionLoader) LoadSession(r *http.Request) (string, error) { +func (l multiSessionLoader) LoadSession(ctx context.Context, r *http.Request) (string, error) { for _, ll := range l { - s, err := ll.LoadSession(r) + s, err := ll.LoadSession(ctx, r) if errors.Is(err, ErrNoSessionFound) { continue } diff --git a/pkg/cryptutil/kek.go b/pkg/cryptutil/kek.go index 7689cc546..2f29df5e2 100644 --- a/pkg/cryptutil/kek.go +++ b/pkg/cryptutil/kek.go @@ -7,7 +7,7 @@ import ( "golang.org/x/crypto/curve25519" "golang.org/x/crypto/nacl/box" - "github.com/pomerium/pomerium/pkg/encoding/base58" + "github.com/akamensky/base58" ) // A KeyEncryptionKey (KEK) is used to implement *envelope encryption*, similar to how data is stored at rest with diff --git a/pkg/cryptutil/token.go b/pkg/cryptutil/token.go index f610cc0d9..1c6dc806e 100644 --- a/pkg/cryptutil/token.go +++ b/pkg/cryptutil/token.go @@ -9,7 +9,7 @@ import ( "github.com/google/uuid" - "github.com/pomerium/pomerium/pkg/encoding/base58" + "github.com/akamensky/base58" ) // TokenLength is the length of a token. @@ -27,7 +27,7 @@ func NewRandomToken() (tok Token) { // TokenFromString parses a base58-encoded string into a token. func TokenFromString(rawstr string) (tok Token, ok bool) { - result := base58.Decode(rawstr) + result, _ := base58.Decode(rawstr) if len(result) != TokenLength { return tok, false } @@ -57,7 +57,7 @@ type SecretToken struct { // SecretTokenFromString parses a base58-encoded string into a secret token. func SecretTokenFromString(rawstr string) (tok SecretToken, ok bool) { - result := base58.Decode(rawstr) + result, _ := base58.Decode(rawstr) if len(result) != TokenLength*2 { return tok, false } @@ -104,7 +104,7 @@ func GenerateSecureToken(key []byte, expiry time.Time, token Token) SecureToken // SecureTokenFromString parses a base58-encoded string into a SecureToken. func SecureTokenFromString(rawstr string) (secureToken SecureToken, ok bool) { - result := base58.Decode(rawstr) + result, _ := base58.Decode(rawstr) if len(result) != SecureTokenLength { return secureToken, false } diff --git a/pkg/encoding/base58/alphabet.go b/pkg/encoding/base58/alphabet.go deleted file mode 100644 index d0f73c9b0..000000000 --- a/pkg/encoding/base58/alphabet.go +++ /dev/null @@ -1,49 +0,0 @@ -// Copyright (c) 2015 The btcsuite developers -// Use of this source code is governed by an ISC -// license that can be found in the LICENSE file. -// -// See: https://github.com/btcsuite/btcutil/blob/master/LICENSE - -package base58 - -const ( - // alphabet is the modified base58 alphabet used by Bitcoin. - alphabet = "123456789ABCDEFGHJKLMNPQRSTUVWXYZabcdefghijkmnopqrstuvwxyz" - - alphabetIdx0 = '1' -) - -var b58 = [256]byte{ - 255, 255, 255, 255, 255, 255, 255, 255, - 255, 255, 255, 255, 255, 255, 255, 255, - 255, 255, 255, 255, 255, 255, 255, 255, - 255, 255, 255, 255, 255, 255, 255, 255, - 255, 255, 255, 255, 255, 255, 255, 255, - 255, 255, 255, 255, 255, 255, 255, 255, - 255, 0, 1, 2, 3, 4, 5, 6, - 7, 8, 255, 255, 255, 255, 255, 255, - 255, 9, 10, 11, 12, 13, 14, 15, - 16, 255, 17, 18, 19, 20, 21, 255, - 22, 23, 24, 25, 26, 27, 28, 29, - 30, 31, 32, 255, 255, 255, 255, 255, - 255, 33, 34, 35, 36, 37, 38, 39, - 40, 41, 42, 43, 255, 44, 45, 46, - 47, 48, 49, 50, 51, 52, 53, 54, - 55, 56, 57, 255, 255, 255, 255, 255, - 255, 255, 255, 255, 255, 255, 255, 255, - 255, 255, 255, 255, 255, 255, 255, 255, - 255, 255, 255, 255, 255, 255, 255, 255, - 255, 255, 255, 255, 255, 255, 255, 255, - 255, 255, 255, 255, 255, 255, 255, 255, - 255, 255, 255, 255, 255, 255, 255, 255, - 255, 255, 255, 255, 255, 255, 255, 255, - 255, 255, 255, 255, 255, 255, 255, 255, - 255, 255, 255, 255, 255, 255, 255, 255, - 255, 255, 255, 255, 255, 255, 255, 255, - 255, 255, 255, 255, 255, 255, 255, 255, - 255, 255, 255, 255, 255, 255, 255, 255, - 255, 255, 255, 255, 255, 255, 255, 255, - 255, 255, 255, 255, 255, 255, 255, 255, - 255, 255, 255, 255, 255, 255, 255, 255, - 255, 255, 255, 255, 255, 255, 255, 255, -} diff --git a/pkg/encoding/base58/base58.go b/pkg/encoding/base58/base58.go deleted file mode 100644 index 928cb9a32..000000000 --- a/pkg/encoding/base58/base58.go +++ /dev/null @@ -1,75 +0,0 @@ -// Copyright (c) 2013-2015 The btcsuite developers -// Use of this source code is governed by an ISC -// license that can be found in the LICENSE file. -// -// See: https://github.com/btcsuite/btcutil/blob/master/LICENSE - -package base58 - -import "math/big" - -var ( - bigRadix = big.NewInt(58) - bigZero = big.NewInt(0) -) - -// Decode decodes a modified base58 string to a byte slice. -func Decode(b string) []byte { - answer := big.NewInt(0) - j := big.NewInt(1) - - scratch := new(big.Int) - for i := len(b) - 1; i >= 0; i-- { - tmp := b58[b[i]] - if tmp == 255 { - return []byte("") - } - scratch.SetInt64(int64(tmp)) - scratch.Mul(j, scratch) - answer.Add(answer, scratch) - j.Mul(j, bigRadix) - } - - tmpval := answer.Bytes() - - var numZeros int - for numZeros = 0; numZeros < len(b); numZeros++ { - if b[numZeros] != alphabetIdx0 { - break - } - } - flen := numZeros + len(tmpval) - val := make([]byte, flen) - copy(val[numZeros:], tmpval) - - return val -} - -// Encode encodes a byte slice to a modified base58 string. -func Encode(b []byte) string { - x := new(big.Int) - x.SetBytes(b) - - answer := make([]byte, 0, len(b)*136/100) - for x.Cmp(bigZero) > 0 { - mod := new(big.Int) - x.DivMod(x, bigRadix, mod) - answer = append(answer, alphabet[mod.Int64()]) - } - - // leading zero bytes - for _, i := range b { - if i != 0 { - break - } - answer = append(answer, alphabetIdx0) - } - - // reverse - alen := len(answer) - for i := 0; i < alen/2; i++ { - answer[i], answer[alen-1-i] = answer[alen-1-i], answer[i] - } - - return string(answer) -} diff --git a/pkg/encoding/base58/base58_test.go b/pkg/encoding/base58/base58_test.go deleted file mode 100644 index 2f75b00b8..000000000 --- a/pkg/encoding/base58/base58_test.go +++ /dev/null @@ -1,100 +0,0 @@ -// Copyright (c) 2013-2017 The btcsuite developers -// Use of this source code is governed by an ISC -// license that can be found in the LICENSE file. -// -// See: https://github.com/btcsuite/btcutil/blob/master/LICENSE - -package base58_test - -import ( - "bytes" - "encoding/hex" - "testing" - - "github.com/pomerium/pomerium/pkg/encoding/base58" -) - -var stringTests = []struct { - in string - out string -}{ - {"", ""}, - {" ", "Z"}, - {"-", "n"}, - {"0", "q"}, - {"1", "r"}, - {"-1", "4SU"}, - {"11", "4k8"}, - {"abc", "ZiCa"}, - {"1234598760", "3mJr7AoUXx2Wqd"}, - {"abcdefghijklmnopqrstuvwxyz", "3yxU3u1igY8WkgtjK92fbJQCd4BZiiT1v25f"}, - {"00000000000000000000000000000000000000000000000000000000000000", "3sN2THZeE9Eh9eYrwkvZqNstbHGvrxSAM7gXUXvyFQP8XvQLUqNCS27icwUeDT7ckHm4FUHM2mTVh1vbLmk7y"}, -} - -var invalidStringTests = []struct { - in string - out string -}{ - {"0", ""}, - {"O", ""}, - {"I", ""}, - {"l", ""}, - {"3mJr0", ""}, - {"O3yxU", ""}, - {"3sNI", ""}, - {"4kl8", ""}, - {"0OIl", ""}, - {"!@#$%^&*()-_=+~`", ""}, -} - -var hexTests = []struct { - in string - out string -}{ - {"61", "2g"}, - {"626262", "a3gV"}, - {"636363", "aPEr"}, - {"73696d706c792061206c6f6e6720737472696e67", "2cFupjhnEsSn59qHXstmK2ffpLv2"}, - {"00eb15231dfceb60925886b67d065299925915aeb172c06647", "1NS17iag9jJgTHD1VXjvLCEnZuQ3rJDE9L"}, - {"516b6fcd0f", "ABnLTmg"}, - {"bf4f89001e670274dd", "3SEo3LWLoPntC"}, - {"572e4794", "3EFU7m"}, - {"ecac89cad93923c02321", "EJDM8drfXA6uyA"}, - {"10c8511e", "Rt5zm"}, - {"00000000000000000000", "1111111111"}, -} - -func TestBase58(t *testing.T) { - // Encode tests - for x, test := range stringTests { - tmp := []byte(test.in) - if res := base58.Encode(tmp); res != test.out { - t.Errorf("Encode test #%d failed: got: %s want: %s", - x, res, test.out) - continue - } - } - - // Decode tests - for x, test := range hexTests { - b, err := hex.DecodeString(test.in) - if err != nil { - t.Errorf("hex.DecodeString failed failed #%d: got: %s", x, test.in) - continue - } - if res := base58.Decode(test.out); !bytes.Equal(res, b) { - t.Errorf("Decode test #%d failed: got: %q want: %q", - x, res, test.in) - continue - } - } - - // Decode with invalid input - for x, test := range invalidStringTests { - if res := base58.Decode(test.in); string(res) != test.out { - t.Errorf("Decode invalidString test #%d failed: got: %q want: %q", - x, res, test.out) - continue - } - } -} diff --git a/pkg/encoding/base58/doc.go b/pkg/encoding/base58/doc.go deleted file mode 100644 index edd5f6df3..000000000 --- a/pkg/encoding/base58/doc.go +++ /dev/null @@ -1,22 +0,0 @@ -// Copyright (c) 2014 The btcsuite developers -// Use of this source code is governed by an ISC -// license that can be found in the LICENSE file. -// -// See: https://github.com/btcsuite/btcutil/blob/master/LICENSE - -/* -Package base58 provides an API for working with modified base58 and Base58Check -encodings. - -# Modified Base58 Encoding - -Standard base58 encoding is similar to standard base64 encoding except, as the -name implies, it uses a 58 character alphabet which results in an alphanumeric -string and allows some characters which are problematic for humans to be -excluded. Due to this, there can be various base58 alphabets. - -The modified base58 alphabet used by Bitcoin, and hence this package, omits the -0, O, I, and l characters that look the same in many fonts and are therefore -hard to humans to distinguish. -*/ -package base58 diff --git a/pkg/grpc/device/device.go b/pkg/grpc/device/device.go index 3a67eac71..bc6744017 100644 --- a/pkg/grpc/device/device.go +++ b/pkg/grpc/device/device.go @@ -9,7 +9,7 @@ import ( "google.golang.org/grpc/status" "google.golang.org/protobuf/types/known/timestamppb" - "github.com/pomerium/pomerium/pkg/encoding/base58" + "github.com/akamensky/base58" "github.com/pomerium/pomerium/pkg/grpc/databroker" "github.com/pomerium/pomerium/pkg/protoutil" ) diff --git a/pkg/grpc/identity/identity.go b/pkg/grpc/identity/identity.go index 20bc9c475..065e4c9c1 100644 --- a/pkg/grpc/identity/identity.go +++ b/pkg/grpc/identity/identity.go @@ -6,7 +6,7 @@ import ( "google.golang.org/protobuf/proto" - "github.com/pomerium/pomerium/pkg/encoding/base58" + "github.com/akamensky/base58" ) // Clone clones the Provider. diff --git a/pkg/telemetry/requestid/requestid.go b/pkg/telemetry/requestid/requestid.go index df780b24d..91868c0a4 100644 --- a/pkg/telemetry/requestid/requestid.go +++ b/pkg/telemetry/requestid/requestid.go @@ -6,7 +6,7 @@ import ( "github.com/google/uuid" - "github.com/pomerium/pomerium/pkg/encoding/base58" + "github.com/akamensky/base58" ) const headerName = "x-request-id" diff --git a/pkg/webauthnutil/credential_storage.go b/pkg/webauthnutil/credential_storage.go index 84a520ba1..b3f1cbf41 100644 --- a/pkg/webauthnutil/credential_storage.go +++ b/pkg/webauthnutil/credential_storage.go @@ -6,7 +6,7 @@ import ( "google.golang.org/grpc/codes" "google.golang.org/grpc/status" - "github.com/pomerium/pomerium/pkg/encoding/base58" + "github.com/akamensky/base58" "github.com/pomerium/pomerium/pkg/grpc/databroker" "github.com/pomerium/pomerium/pkg/grpc/device" "github.com/pomerium/webauthn" diff --git a/proxy/data.go b/proxy/data.go index ce13ce53e..076d393f1 100644 --- a/proxy/data.go +++ b/proxy/data.go @@ -34,7 +34,7 @@ func (p *Proxy) getSession(ctx context.Context, sessionID string) (s *session.Se func (p *Proxy) getSessionState(r *http.Request) (sessions.State, error) { state := p.state.Load() - rawJWT, err := state.sessionStore.LoadSession(r) + rawJWT, err := state.sessionStore.LoadSession(r.Context(), r) if err != nil { return sessions.State{}, err } diff --git a/proxy/handlers.go b/proxy/handlers.go index 682e0d100..ecc5b9f5a 100644 --- a/proxy/handlers.go +++ b/proxy/handlers.go @@ -113,7 +113,7 @@ func (p *Proxy) ProgrammaticLogin(w http.ResponseWriter, r *http.Request) error return httputil.NewError(http.StatusBadRequest, errors.New("invalid redirect uri")) } - idp, err := options.GetIdentityProviderForRequestURL(urlutil.GetAbsoluteURL(r).String()) + idp, err := p.policyCache.Load().GetIdentityProviderForRequestURL(r.Context(), options, urlutil.GetAbsoluteURL(r).String()) if err != nil { return httputil.NewError(http.StatusInternalServerError, err) } diff --git a/proxy/proxy.go b/proxy/proxy.go index 8d3c78c18..8e7c131fd 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -56,6 +56,7 @@ type Proxy struct { currentOptions *atomicutil.Value[*config.Options] currentRouter *atomicutil.Value[*mux.Router] webauthn *webauthn.Handler + policyCache *atomicutil.Value[*config.PolicyCache] } // New takes a Proxy service from options and a validation function. @@ -66,10 +67,15 @@ func New(cfg *config.Config) (*Proxy, error) { return nil, err } + cache, err := config.NewPolicyCache(cfg.Options) + if err != nil { + return nil, err + } p := &Proxy{ state: atomicutil.NewValue(state), currentOptions: config.NewAtomicOptions(), currentRouter: atomicutil.NewValue(httputil.NewRouter()), + policyCache: atomicutil.NewValue(cache), } p.webauthn = webauthn.New(p.getWebauthnState) @@ -91,6 +97,9 @@ func (p *Proxy) OnConfigChange(_ context.Context, cfg *config.Config) { return } + if cache, err := config.NewPolicyCache(cfg.Options); err == nil { + p.policyCache.Store(cache) + } p.currentOptions.Store(cfg.Options) if err := p.setHandlers(cfg.Options); err != nil { log.Error(context.TODO()).Err(err).Msg("proxy: failed to update proxy handlers from configuration settings")