authorize: hot path identity provider lookup optimizations

This commit is contained in:
Joe Kralicky 2024-06-18 21:40:03 -04:00
parent 7eca911292
commit e18c04216e
No known key found for this signature in database
GPG key ID: 75C4875F34A9FB79
29 changed files with 387 additions and 284 deletions

View file

@ -249,7 +249,7 @@ func TestAuthenticate_SignOut(t *testing.T) {
} }
u.RawQuery = params.Encode() u.RawQuery = params.Encode()
r := httptest.NewRequest(tt.method, u.String(), nil) r := httptest.NewRequest(tt.method, u.String(), nil)
state, err := tt.sessionStore.LoadSession(r) state, err := tt.sessionStore.LoadSession(context.TODO(), r)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -481,7 +481,7 @@ func TestAuthenticate_SessionValidatorMiddleware(t *testing.T) {
options: config.NewAtomicOptions(), options: config.NewAtomicOptions(),
} }
r := httptest.NewRequest(http.MethodGet, "/", nil) r := httptest.NewRequest(http.MethodGet, "/", nil)
state, err := tt.session.LoadSession(r) state, err := tt.session.LoadSession(context.TODO(), r)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -586,7 +586,7 @@ func TestAuthenticate_userInfo(t *testing.T) {
}), }),
} }
r := httptest.NewRequest(http.MethodGet, tt.url, nil) r := httptest.NewRequest(http.MethodGet, tt.url, nil)
state, err := tt.sessionStore.LoadSession(r) state, err := tt.sessionStore.LoadSession(context.TODO(), r)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

View file

@ -47,7 +47,7 @@ func (a *Authorize) Check(ctx context.Context, in *envoy_service_auth_v3.CheckRe
hreq := getHTTPRequestFromCheckRequest(in) hreq := getHTTPRequestFromCheckRequest(in)
ctx = requestid.WithValue(ctx, requestid.FromHTTPHeader(hreq.Header)) ctx = requestid.WithValue(ctx, requestid.FromHTTPHeader(hreq.Header))
sessionState, _ := state.sessionStore.LoadSessionState(hreq) sessionState, _ := state.sessionStore.LoadSessionState(ctx, hreq)
var s sessionOrServiceAccount var s sessionOrServiceAccount
var u *user.User var u *user.User

View file

@ -1,6 +1,13 @@
package config package config
import ( import (
"context"
"fmt"
"net"
"slices"
"strings"
art "github.com/kralicky/go-adaptive-radix-tree"
"github.com/pomerium/pomerium/internal/urlutil" "github.com/pomerium/pomerium/internal/urlutil"
"github.com/pomerium/pomerium/pkg/grpc/identity" "github.com/pomerium/pomerium/pkg/grpc/identity"
) )
@ -51,7 +58,7 @@ func (o *Options) GetIdentityProviderForPolicy(policy *Policy) (*identity.Provid
} }
// GetIdentityProviderForRequestURL gets the identity provider associated with the given request URL. // GetIdentityProviderForRequestURL gets the identity provider associated with the given request URL.
func (o *Options) GetIdentityProviderForRequestURL(requestURL string) (*identity.Provider, error) { func (o *Options) GetIdentityProviderForRequestURL(ctx context.Context, requestURL string) (*identity.Provider, error) {
u, err := urlutil.ParseAndValidateURL(requestURL) u, err := urlutil.ParseAndValidateURL(requestURL)
if err != nil { if err != nil {
return nil, err return nil, err
@ -65,3 +72,127 @@ func (o *Options) GetIdentityProviderForRequestURL(requestURL string) (*identity
} }
return o.GetIdentityProviderForPolicy(nil) return o.GetIdentityProviderForPolicy(nil)
} }
type PolicyCache struct {
domainTree art.Tree[*domainNode]
matchPorts bool
}
func NewPolicyCache(options *Options) (*PolicyCache, error) {
tree := art.New[*domainNode]()
shouldMatchPorts := !options.IsRuntimeFlagSet(RuntimeFlagMatchAnyIncomingPort)
for _, policy := range options.GetAllPolicies() {
u, err := urlutil.ParseAndValidateURL(policy.From)
if err != nil {
return nil, err
}
domains := urlutil.GetDomainsForURL(u, shouldMatchPorts)
for _, domain := range domains {
host, port, err := net.SplitHostPort(domain) // todo: this is not optimal
if err != nil {
host, port = domain, ""
}
domainKey := radixKeyForHostPort(host, port)
tree.Update(domainKey, newDomainNode, func(dn **domainNode) {
if policy.Prefix != "" {
(*dn).policiesByPrefix.Insert(art.Key(policy.Prefix), policy)
} else if policy.Path != "" {
(*dn).policiesByPrefix.Insert(art.Key(policy.Path), policy)
} else if policy.compiledRegex != nil {
(*dn).policiesByRegex = append((*dn).policiesByRegex, policy)
} else {
(*dn).policiesNoPathMatching = append((*dn).policiesNoPathMatching, policy)
}
})
}
}
return &PolicyCache{
domainTree: tree,
matchPorts: shouldMatchPorts,
}, nil
}
func (pc *PolicyCache) GetIdentityProviderForRequestURL(ctx context.Context, o *Options, requestURL string) (*identity.Provider, error) {
u, err := urlutil.ParseAndValidateURL(requestURL)
if err != nil {
return nil, err
}
domainKey := radixKeyForHostPort(u.Hostname(), u.Port())
domain, ok := pc.domainTree.Resolve(domainKey, wildcardResolver)
if !ok {
return nil, fmt.Errorf("no identity provider found for request URL %s", requestURL)
}
var policy *Policy
if len(u.Path) == 0 || (len(u.Path) == 1 && u.Path[0] == '/') && len(domain.policiesNoPathMatching) > 0 {
policy = &domain.policiesNoPathMatching[0]
} else {
if domain.policiesByPrefix.Size() > 0 {
pathKey := art.Key(u.Path)
actualKey, val, found := domain.policiesByPrefix.SearchNearest(pathKey)
if found {
// check for prefix match or exact match
if c := actualKey.Compare(pathKey); c < 0 {
if val.Prefix != "" && strings.HasPrefix(u.Path, val.Prefix) {
policy = &val
}
} else if c == 0 {
if val.Path != "" || val.Prefix != "" {
policy = &val
}
}
}
}
if policy == nil {
for _, p := range domain.policiesByRegex {
if p.compiledRegex.MatchString(u.Path) {
policy = &p
break
}
}
}
}
if policy != nil {
return o.GetIdentityProviderForPolicy(policy)
}
return nil, fmt.Errorf("no identity provider found for request URL %s", requestURL)
}
type domainNode struct {
policiesByPrefix art.Tree[Policy]
policiesByRegex []Policy
policiesNoPathMatching []Policy
}
func newDomainNode() *domainNode {
return &domainNode{
policiesByPrefix: art.New[Policy](),
}
}
func radixKeyForHostPort(host, port string) art.Key {
parts := strings.Split(host, ".")
sb := strings.Builder{}
sb.WriteString(port)
for i := len(parts) - 1; i >= 0; i-- {
sb.WriteByte('.')
sb.WriteString(parts[i])
}
return art.Key(sb.String())
}
func wildcardResolver(key art.Key, conflictIndex int) (art.Key, int) {
if conflictIndex >= len(key) {
return nil, -1
}
c := key[conflictIndex]
if c != '*' && c != '.' {
nextDot := slices.Index(key[conflictIndex:], '.')
if nextDot == -1 {
return art.Key("*"), len(key)
}
return art.Key("*"), conflictIndex + nextDot
}
return nil, -1
}

View file

@ -0,0 +1,184 @@
package config_test
import (
"context"
"encoding/base64"
"fmt"
"strings"
"testing"
"github.com/pomerium/pomerium/config"
"github.com/pomerium/pomerium/pkg/cryptutil"
"github.com/stretchr/testify/require"
)
func BenchmarkGetIdentityProviderForRequestURL_Old(b *testing.B) {
runBench := func(numPolicies int) func(b *testing.B) {
return func(b *testing.B) {
b.ReportAllocs()
options := config.NewDefaultOptions()
sharedKey := cryptutil.NewKey()
options.SharedKey = base64.StdEncoding.EncodeToString(sharedKey)
options.Provider = "oidc"
options.ProviderURL = "https://oidc.example.com"
options.ClientID = "client_id"
options.ClientSecret = "client_secret"
urlFormat := "https://*.foo.bar.test-%d.example.com"
for i := range numPolicies {
options.Policies = append(options.Policies,
config.Policy{
From: fmt.Sprintf(urlFormat, i),
To: mustParseWeightedURLs(b, fmt.Sprintf("https://p2-%d", i)),
IDPClientID: fmt.Sprintf("client_id_%d", i),
IDPClientSecret: fmt.Sprintf("client_secret_%d", i),
},
)
}
require.NoError(b, options.Validate())
b.ResetTimer()
for range b.N {
idp, err := options.GetIdentityProviderForRequestURL(context.Background(), fmt.Sprintf(urlFormat, numPolicies-1))
require.NoError(b, err)
require.Equal(b, fmt.Sprintf("client_id_%d", numPolicies-1), idp.ClientId)
require.Equal(b, fmt.Sprintf("client_secret_%d", numPolicies-1), idp.ClientSecret)
}
}
}
b.Run("5 policies", runBench(5))
b.Run("50 policies", runBench(50))
b.Run("500 policies", runBench(500))
b.Run("5000 policies", runBench(5000))
}
var bench = func(fill func(i int, p *config.Policy) string, numPolicies int) func(b *testing.B) {
return func(b *testing.B) {
b.ReportAllocs()
options := config.NewDefaultOptions()
sharedKey := cryptutil.NewKey()
options.SharedKey = base64.StdEncoding.EncodeToString(sharedKey)
options.Provider = "oidc"
options.ProviderURL = "https://oidc.example.com"
options.ClientID = "client_id"
options.ClientSecret = "client_secret"
var allUrls []string
for i := range numPolicies {
p := config.Policy{
To: mustParseWeightedURLs(b, fmt.Sprintf("https://p2-%d", i)),
IDPClientID: fmt.Sprintf("client_id_%d", i),
IDPClientSecret: fmt.Sprintf("client_secret_%d", i),
}
allUrls = append(allUrls, fill(i, &p))
options.Policies = append(options.Policies, p)
}
require.NoError(b, options.Validate())
cache, err := config.NewPolicyCache(options)
require.NoError(b, err)
b.ResetTimer()
for i := range b.N {
reqUrl := strings.Replace(allUrls[i%numPolicies], "*", fmt.Sprint(i), 1)
idp, err := cache.GetIdentityProviderForRequestURL(context.Background(), options, reqUrl)
require.NoError(b, err)
require.Equal(b, fmt.Sprintf("client_id_%d", i%numPolicies), idp.ClientId)
require.Equal(b, fmt.Sprintf("client_secret_%d", i%numPolicies), idp.ClientSecret)
}
}
}
func BenchmarkGetIdentityProviderForRequestURL_New_DomainMatchOnly(b *testing.B) {
domainMatchingOnly := func(i int, p *config.Policy) string {
p.From = fmt.Sprintf("https://*.foo.bar.test-%d.example.com", i)
return p.From
}
b.Run("5 policies (domain matching only)", bench(domainMatchingOnly, 5))
b.Run("50 policies (domain matching only)", bench(domainMatchingOnly, 50))
b.Run("500 policies (domain matching only)", bench(domainMatchingOnly, 500))
b.Run("5000 policies (domain matching only)", bench(domainMatchingOnly, 5000))
}
func BenchmarkGetIdentityProviderForRequestURL_New_PathMatchOnly(b *testing.B) {
pathMatchingOnly := func(i int, p *config.Policy) string {
p.From = "https://example.com"
p.Path = fmt.Sprintf("/foo/bar/path%d", i)
return p.From + p.Path
}
b.Run("5 policies (path matching only)", bench(pathMatchingOnly, 5))
b.Run("50 policies (path matching only)", bench(pathMatchingOnly, 50))
b.Run("500 policies (path matching only)", bench(pathMatchingOnly, 500))
b.Run("5000 policies (path matching only)", bench(pathMatchingOnly, 5000))
}
func BenchmarkGetIdentityProviderForRequestURL_New_DomainAndPathMatching(b *testing.B) {
combinedMatching := func(numPathsPerDomain int) func(i int, p *config.Policy) string {
// returns a sequence of policies (ex: numPathsPerDomain=3)
// https://*.foo.bar.test-0.example.com
// https://*.foo.bar.test-0.example.com/foo/bar/path1
// https://*.foo.bar.test-0.example.com/foo/bar/path2
// https://*.foo.bar.test-1.example.com
// https://*.foo.bar.test-1.example.com/foo/bar/path1
// https://*.foo.bar.test-1.example.com/foo/bar/path2
return func(i int, p *config.Policy) string {
domain := fmt.Sprintf("https://*.foo.bar.test-%d.example.com", i/numPathsPerDomain)
pathIdx := i % numPathsPerDomain
var path string
if pathIdx == 0 {
path = ""
} else {
path = fmt.Sprintf("/foo/bar/path%d", pathIdx)
}
p.From = domain
p.Path = path
return domain + path
}
}
b.Run("25 policies (5 domains, 5 paths per domain)", bench(combinedMatching(5), 25))
b.Run("500 policies (50 domains, 10 paths per domain)", bench(combinedMatching(10), 500))
b.Run("500 policies (10 domains, 50 paths per domain)", bench(combinedMatching(50), 500))
b.Run("2500 policies (50 domains, 50 paths per domain)", bench(combinedMatching(50), 2500))
b.Run("5000 policies (100 domains, 50 paths per domain)", bench(combinedMatching(50), 5000))
b.Run("5000 policies (50 domains, 100 paths per domain)", bench(combinedMatching(100), 5000))
b.Run("10000 policies (100 domains, 100 paths per domain)", bench(combinedMatching(100), 10000))
}
func BenchmarkGetIdentityProviderForRequestURL_New_DomainAndPrefixMatching(b *testing.B) {
combinedMatching := func(numPathsPerDomain int) func(i int, p *config.Policy) string {
// returns a sequence of policies (ex: numPathsPerDomain=3)
// https://*.foo.bar.test-0.example.com/0
// https://*.foo.bar.test-0.example.com/0/1
// https://*.foo.bar.test-0.example.com/0/1/2
// https://*.foo.bar.test-1.example.com/0
// https://*.foo.bar.test-1.example.com/0/1
// https://*.foo.bar.test-1.example.com/0/1/2
return func(i int, p *config.Policy) string {
domain := fmt.Sprintf("https://*.foo.bar.test-%d.example.com", i/numPathsPerDomain)
pathIdx := i % numPathsPerDomain
var prefix strings.Builder
for i := range pathIdx + 1 {
fmt.Fprintf(&prefix, "/%d", i)
}
p.From = domain
p.Prefix = prefix.String()
return domain + p.Prefix
}
}
b.Run("25 policies (5 domains, 5 paths per domain)", bench(combinedMatching(5), 25))
b.Run("500 policies (50 domains, 10 paths per domain)", bench(combinedMatching(10), 500))
b.Run("500 policies (10 domains, 50 paths per domain)", bench(combinedMatching(50), 500))
b.Run("2500 policies (50 domains, 50 paths per domain)", bench(combinedMatching(50), 2500))
b.Run("5000 policies (100 domains, 50 paths per domain)", bench(combinedMatching(50), 5000))
b.Run("5000 policies (50 domains, 100 paths per domain)", bench(combinedMatching(100), 5000))
b.Run("10000 policies (100 domains, 100 paths per domain)", bench(combinedMatching(100), 10000))
}
func mustParseWeightedURLs(t testing.TB, urls ...string) []config.WeightedURL {
wu, err := config.ParseWeightedUrls(urls...)
require.NoError(t, err)
return wu
}

View file

@ -1,6 +1,7 @@
package config package config
import ( import (
"context"
"fmt" "fmt"
"net/http" "net/http"
@ -10,20 +11,27 @@ import (
"github.com/pomerium/pomerium/internal/sessions/cookie" "github.com/pomerium/pomerium/internal/sessions/cookie"
"github.com/pomerium/pomerium/internal/sessions/header" "github.com/pomerium/pomerium/internal/sessions/header"
"github.com/pomerium/pomerium/internal/sessions/queryparam" "github.com/pomerium/pomerium/internal/sessions/queryparam"
"github.com/pomerium/pomerium/internal/telemetry/trace"
"github.com/pomerium/pomerium/internal/urlutil" "github.com/pomerium/pomerium/internal/urlutil"
) )
// A SessionStore saves and loads sessions based on the options. // A SessionStore saves and loads sessions based on the options.
type SessionStore struct { type SessionStore struct {
options *Options options *Options
encoder encoding.MarshalUnmarshaler encoder encoding.MarshalUnmarshaler
loader sessions.SessionLoader loader sessions.SessionLoader
policyCache *PolicyCache
} }
// NewSessionStore creates a new SessionStore from the Options. // NewSessionStore creates a new SessionStore from the Options.
func NewSessionStore(options *Options) (*SessionStore, error) { func NewSessionStore(options *Options) (*SessionStore, error) {
cache, err := NewPolicyCache(options)
if err != nil {
return nil, err
}
store := &SessionStore{ store := &SessionStore{
options: options, options: options,
policyCache: cache,
} }
sharedKey, err := options.GetSharedKey() sharedKey, err := options.GetSharedKey()
@ -57,8 +65,11 @@ func NewSessionStore(options *Options) (*SessionStore, error) {
} }
// LoadSessionState loads the session state from a request. // LoadSessionState loads the session state from a request.
func (store *SessionStore) LoadSessionState(r *http.Request) (*sessions.State, error) { func (store *SessionStore) LoadSessionState(ctx context.Context, r *http.Request) (*sessions.State, error) {
rawJWT, err := store.loader.LoadSession(r) ctx, span := trace.StartSpan(ctx, "session_store.load_session_state")
defer span.End()
rawJWT, err := store.loader.LoadSession(ctx, r)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -71,7 +82,7 @@ func (store *SessionStore) LoadSessionState(r *http.Request) (*sessions.State, e
// confirm that the identity provider id matches the state // confirm that the identity provider id matches the state
if state.IdentityProviderID != "" { if state.IdentityProviderID != "" {
idp, err := store.options.GetIdentityProviderForRequestURL(urlutil.GetAbsoluteURL(r).String()) idp, err := store.policyCache.GetIdentityProviderForRequestURL(ctx, store.options, urlutil.GetAbsoluteURL(r).String())
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -1,6 +1,7 @@
package config package config
import ( import (
"context"
"encoding/base64" "encoding/base64"
"net/http" "net/http"
"net/url" "net/url"
@ -70,7 +71,7 @@ func TestSessionStore_LoadSessionState(t *testing.T) {
t.Run("mssing", func(t *testing.T) { t.Run("mssing", func(t *testing.T) {
r, err := http.NewRequest(http.MethodGet, "https://p1.example.com", nil) r, err := http.NewRequest(http.MethodGet, "https://p1.example.com", nil)
require.NoError(t, err) require.NoError(t, err)
s, err := store.LoadSessionState(r) s, err := store.LoadSessionState(context.TODO(), r)
assert.ErrorIs(t, err, sessions.ErrNoSessionFound) assert.ErrorIs(t, err, sessions.ErrNoSessionFound)
assert.Nil(t, s) assert.Nil(t, s)
}) })
@ -85,7 +86,7 @@ func TestSessionStore_LoadSessionState(t *testing.T) {
urlutil.QuerySession: {rawJWS}, urlutil.QuerySession: {rawJWS},
}.Encode(), nil) }.Encode(), nil)
require.NoError(t, err) require.NoError(t, err)
s, err := store.LoadSessionState(r) s, err := store.LoadSessionState(context.TODO(), r)
assert.NoError(t, err) assert.NoError(t, err)
assert.Empty(t, cmp.Diff(&sessions.State{ assert.Empty(t, cmp.Diff(&sessions.State{
Issuer: "authenticate.example.com", Issuer: "authenticate.example.com",
@ -103,7 +104,7 @@ func TestSessionStore_LoadSessionState(t *testing.T) {
r, err := http.NewRequest(http.MethodGet, "https://p2.example.com", nil) r, err := http.NewRequest(http.MethodGet, "https://p2.example.com", nil)
require.NoError(t, err) require.NoError(t, err)
r.Header.Set(httputil.HeaderPomeriumAuthorization, rawJWS) r.Header.Set(httputil.HeaderPomeriumAuthorization, rawJWS)
s, err := store.LoadSessionState(r) s, err := store.LoadSessionState(context.TODO(), r)
assert.NoError(t, err) assert.NoError(t, err)
assert.Empty(t, cmp.Diff(&sessions.State{ assert.Empty(t, cmp.Diff(&sessions.State{
Issuer: "authenticate.example.com", Issuer: "authenticate.example.com",
@ -121,7 +122,7 @@ func TestSessionStore_LoadSessionState(t *testing.T) {
r, err := http.NewRequest(http.MethodGet, "https://p2.example.com", nil) r, err := http.NewRequest(http.MethodGet, "https://p2.example.com", nil)
require.NoError(t, err) require.NoError(t, err)
r.Header.Set(httputil.HeaderPomeriumAuthorization, rawJWS) r.Header.Set(httputil.HeaderPomeriumAuthorization, rawJWS)
s, err := store.LoadSessionState(r) s, err := store.LoadSessionState(context.TODO(), r)
assert.Error(t, err) assert.Error(t, err)
assert.Nil(t, s) assert.Nil(t, s)
}) })
@ -134,7 +135,7 @@ func TestSessionStore_LoadSessionState(t *testing.T) {
r, err := http.NewRequest(http.MethodGet, "https://p2.example.com", nil) r, err := http.NewRequest(http.MethodGet, "https://p2.example.com", nil)
require.NoError(t, err) require.NoError(t, err)
r.Header.Set(httputil.HeaderPomeriumAuthorization, rawJWS) r.Header.Set(httputil.HeaderPomeriumAuthorization, rawJWS)
s, err := store.LoadSessionState(r) s, err := store.LoadSessionState(context.TODO(), r)
assert.NoError(t, err) assert.NoError(t, err)
assert.Empty(t, cmp.Diff(&sessions.State{ assert.Empty(t, cmp.Diff(&sessions.State{
Issuer: "authenticate.example.com", Issuer: "authenticate.example.com",

2
go.mod
View file

@ -36,6 +36,7 @@ require (
github.com/hashicorp/golang-lru/v2 v2.0.7 github.com/hashicorp/golang-lru/v2 v2.0.7
github.com/jackc/pgx/v5 v5.6.0 github.com/jackc/pgx/v5 v5.6.0
github.com/klauspost/compress v1.17.8 github.com/klauspost/compress v1.17.8
github.com/kralicky/go-adaptive-radix-tree v0.0.0-20240619012453-a8f80032ba31
github.com/martinlindhe/base36 v1.1.1 github.com/martinlindhe/base36 v1.1.1
github.com/mholt/acmez/v2 v2.0.1 github.com/mholt/acmez/v2 v2.0.1
github.com/minio/minio-go/v7 v7.0.70 github.com/minio/minio-go/v7 v7.0.70
@ -104,6 +105,7 @@ require (
github.com/Nvveen/Gotty v0.0.0-20120604004816-cd527374f1e5 // indirect github.com/Nvveen/Gotty v0.0.0-20120604004816-cd527374f1e5 // indirect
github.com/OneOfOne/xxhash v1.2.8 // indirect github.com/OneOfOne/xxhash v1.2.8 // indirect
github.com/agnivade/levenshtein v1.1.1 // indirect github.com/agnivade/levenshtein v1.1.1 // indirect
github.com/akamensky/base58 v0.0.0-20210829145138-ce8bf8802e8f // indirect
github.com/andybalholm/brotli v1.0.5 // indirect github.com/andybalholm/brotli v1.0.5 // indirect
github.com/apapsch/go-jsonmerge/v2 v2.0.0 // indirect github.com/apapsch/go-jsonmerge/v2 v2.0.0 // indirect
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.2 // indirect github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.2 // indirect

4
go.sum
View file

@ -74,6 +74,8 @@ github.com/VictoriaMetrics/fastcache v1.12.2 h1:N0y9ASrJ0F6h0QaC3o6uJb3NIZ9VKLjC
github.com/VictoriaMetrics/fastcache v1.12.2/go.mod h1:AmC+Nzz1+3G2eCPapF6UcsnkThDcMsQicp4xDukwJYI= github.com/VictoriaMetrics/fastcache v1.12.2/go.mod h1:AmC+Nzz1+3G2eCPapF6UcsnkThDcMsQicp4xDukwJYI=
github.com/agnivade/levenshtein v1.1.1 h1:QY8M92nrzkmr798gCo3kmMyqXFzdQVpxLlGPRBij0P8= github.com/agnivade/levenshtein v1.1.1 h1:QY8M92nrzkmr798gCo3kmMyqXFzdQVpxLlGPRBij0P8=
github.com/agnivade/levenshtein v1.1.1/go.mod h1:veldBMzWxcCG2ZvUTKD2kJNRdCk5hVbJomOvKkmgYbo= github.com/agnivade/levenshtein v1.1.1/go.mod h1:veldBMzWxcCG2ZvUTKD2kJNRdCk5hVbJomOvKkmgYbo=
github.com/akamensky/base58 v0.0.0-20210829145138-ce8bf8802e8f h1:z8MkSJCUyTmW5YQlxsMLBlwA7GmjxC7L4ooicxqnhz8=
github.com/akamensky/base58 v0.0.0-20210829145138-ce8bf8802e8f/go.mod h1:UdUwYgAXBiL+kLfcqxoQJYkHA/vl937/PbFhZM34aZs=
github.com/alecthomas/template v0.0.0-20160405071501-a0175ee3bccc/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc= github.com/alecthomas/template v0.0.0-20160405071501-a0175ee3bccc/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc=
github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc= github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc=
github.com/alecthomas/units v0.0.0-20151022065526-2efee857e7cf/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0= github.com/alecthomas/units v0.0.0-20151022065526-2efee857e7cf/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0=
@ -414,6 +416,8 @@ github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/kralicky/go-adaptive-radix-tree v0.0.0-20240619012453-a8f80032ba31 h1:cOZUoNuv9OuCZKepddeIW87ScJWwveLfLcJMaii6YCA=
github.com/kralicky/go-adaptive-radix-tree v0.0.0-20240619012453-a8f80032ba31/go.mod h1:oJwexVSshEat0E3evyKOH6QzN8GFWrhLvEoh8GiJzss=
github.com/lib/pq v1.10.7 h1:p7ZhMD+KsSRozJr34udlUrhboJwWAgCg34+/ZZNvZZw= github.com/lib/pq v1.10.7 h1:p7ZhMD+KsSRozJr34udlUrhboJwWAgCg34+/ZZNvZZw=
github.com/lib/pq v1.10.7/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= github.com/lib/pq v1.10.7/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
github.com/libdns/libdns v0.2.2 h1:O6ws7bAfRPaBsgAYt8MDe2HcNBGC29hkZ9MX2eUSX3s= github.com/libdns/libdns v0.2.2 h1:O6ws7bAfRPaBsgAYt8MDe2HcNBGC29hkZ9MX2eUSX3s=

View file

@ -2,6 +2,7 @@
package cookie package cookie
import ( import (
"context"
"errors" "errors"
"fmt" "fmt"
"net/http" "net/http"
@ -117,7 +118,7 @@ func getCookies(r *http.Request, name string) []*http.Cookie {
} }
// LoadSession returns a State from the cookie in the request. // LoadSession returns a State from the cookie in the request.
func (cs *Store) LoadSession(r *http.Request) (string, error) { func (cs *Store) LoadSession(ctx context.Context, r *http.Request) (string, error) {
opts := cs.getOptions() opts := cs.getOptions()
cookies := getCookies(r, opts.Name) cookies := getCookies(r, opts.Name)
if len(cookies) == 0 { if len(cookies) == 0 {

View file

@ -1,6 +1,7 @@
package cookie package cookie
import ( import (
"context"
"crypto/rand" "crypto/rand"
"errors" "errors"
"fmt" "fmt"
@ -145,7 +146,7 @@ func TestStore_SaveSession(t *testing.T) {
r.AddCookie(cookie) r.AddCookie(cookie)
} }
jwt, err := s.LoadSession(r) jwt, err := s.LoadSession(context.TODO(), r)
if (err != nil) != tt.wantLoadErr { if (err != nil) != tt.wantLoadErr {
t.Errorf("LoadSession() error = %v, wantErr %v", err, tt.wantLoadErr) t.Errorf("LoadSession() error = %v, wantErr %v", err, tt.wantLoadErr)
return return

View file

@ -3,6 +3,7 @@
package header package header
import ( import (
"context"
"net/http" "net/http"
"strings" "strings"
@ -31,7 +32,7 @@ func NewStore(enc encoding.Unmarshaler) *Store {
} }
// LoadSession tries to retrieve the token string from the Authorization header. // LoadSession tries to retrieve the token string from the Authorization header.
func (as *Store) LoadSession(r *http.Request) (string, error) { func (as *Store) LoadSession(_ context.Context, r *http.Request) (string, error) {
jwt := TokenFromHeaders(r) jwt := TokenFromHeaders(r)
if jwt == "" { if jwt == "" {
return "", sessions.ErrNoSessionFound return "", sessions.ErrNoSessionFound

View file

@ -23,7 +23,7 @@ func retrieve(s SessionLoader) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler { return func(next http.Handler) http.Handler {
hfn := func(w http.ResponseWriter, r *http.Request) { hfn := func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context() ctx := r.Context()
jwt, err := s.LoadSession(r) jwt, err := s.LoadSession(ctx, r)
ctx = NewContext(ctx, jwt, err) ctx = NewContext(ctx, jwt, err)
next.ServeHTTP(w, r.WithContext(ctx)) next.ServeHTTP(w, r.WithContext(ctx))
} }

View file

@ -2,6 +2,7 @@
package mock package mock
import ( import (
"context"
"net/http" "net/http"
"github.com/pomerium/pomerium/internal/encoding" "github.com/pomerium/pomerium/internal/encoding"
@ -30,7 +31,7 @@ func (ms *Store) ClearSession(http.ResponseWriter, *http.Request) {
} }
// LoadSession returns the session and a error // LoadSession returns the session and a error
func (ms Store) LoadSession(*http.Request) (string, error) { func (ms Store) LoadSession(context.Context, *http.Request) (string, error) {
var signer encoding.MarshalUnmarshaler var signer encoding.MarshalUnmarshaler
signer, _ = jws.NewHS256Signer(ms.Secret) signer, _ = jws.NewHS256Signer(ms.Secret)
jwt, _ := signer.Marshal(ms.Session) jwt, _ := signer.Marshal(ms.Session)

View file

@ -1,6 +1,7 @@
package mock package mock
import ( import (
"context"
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
@ -40,7 +41,7 @@ func TestStore(t *testing.T) {
t.Errorf("mockstore.SaveSession() error = %v, wantSaveErr %v", err, tt.wantSaveErr) t.Errorf("mockstore.SaveSession() error = %v, wantSaveErr %v", err, tt.wantSaveErr)
return return
} }
got, err := ms.LoadSession(nil) got, err := ms.LoadSession(context.TODO(), nil)
if (err != nil) != tt.wantLoadErr { if (err != nil) != tt.wantLoadErr {
t.Errorf("mockstore.LoadSession() error = %v, wantLoadErr %v", err, tt.wantLoadErr) t.Errorf("mockstore.LoadSession() error = %v, wantLoadErr %v", err, tt.wantLoadErr)
return return

View file

@ -3,6 +3,7 @@
package queryparam package queryparam
import ( import (
"context"
"net/http" "net/http"
"github.com/pomerium/pomerium/internal/encoding" "github.com/pomerium/pomerium/internal/encoding"
@ -43,7 +44,7 @@ func NewStore(enc encoding.MarshalUnmarshaler, qp string) *Store {
} }
// LoadSession tries to retrieve the token string from URL query parameters. // LoadSession tries to retrieve the token string from URL query parameters.
func (qp *Store) LoadSession(r *http.Request) (string, error) { func (qp *Store) LoadSession(_ context.Context, r *http.Request) (string, error) {
jwt := r.URL.Query().Get(qp.queryParamKey) jwt := r.URL.Query().Get(qp.queryParamKey)
if jwt == "" { if jwt == "" {
return "", sessions.ErrNoSessionFound return "", sessions.ErrNoSessionFound

View file

@ -3,6 +3,7 @@
package sessions package sessions
import ( import (
"context"
"errors" "errors"
"net/http" "net/http"
) )
@ -16,14 +17,14 @@ type SessionStore interface {
// SessionLoader defines an interface for loading a session. // SessionLoader defines an interface for loading a session.
type SessionLoader interface { type SessionLoader interface {
LoadSession(*http.Request) (string, error) LoadSession(context.Context, *http.Request) (string, error)
} }
type multiSessionLoader []SessionLoader type multiSessionLoader []SessionLoader
func (l multiSessionLoader) LoadSession(r *http.Request) (string, error) { func (l multiSessionLoader) LoadSession(ctx context.Context, r *http.Request) (string, error) {
for _, ll := range l { for _, ll := range l {
s, err := ll.LoadSession(r) s, err := ll.LoadSession(ctx, r)
if errors.Is(err, ErrNoSessionFound) { if errors.Is(err, ErrNoSessionFound) {
continue continue
} }

View file

@ -7,7 +7,7 @@ import (
"golang.org/x/crypto/curve25519" "golang.org/x/crypto/curve25519"
"golang.org/x/crypto/nacl/box" "golang.org/x/crypto/nacl/box"
"github.com/pomerium/pomerium/pkg/encoding/base58" "github.com/akamensky/base58"
) )
// A KeyEncryptionKey (KEK) is used to implement *envelope encryption*, similar to how data is stored at rest with // A KeyEncryptionKey (KEK) is used to implement *envelope encryption*, similar to how data is stored at rest with

View file

@ -9,7 +9,7 @@ import (
"github.com/google/uuid" "github.com/google/uuid"
"github.com/pomerium/pomerium/pkg/encoding/base58" "github.com/akamensky/base58"
) )
// TokenLength is the length of a token. // TokenLength is the length of a token.
@ -27,7 +27,7 @@ func NewRandomToken() (tok Token) {
// TokenFromString parses a base58-encoded string into a token. // TokenFromString parses a base58-encoded string into a token.
func TokenFromString(rawstr string) (tok Token, ok bool) { func TokenFromString(rawstr string) (tok Token, ok bool) {
result := base58.Decode(rawstr) result, _ := base58.Decode(rawstr)
if len(result) != TokenLength { if len(result) != TokenLength {
return tok, false return tok, false
} }
@ -57,7 +57,7 @@ type SecretToken struct {
// SecretTokenFromString parses a base58-encoded string into a secret token. // SecretTokenFromString parses a base58-encoded string into a secret token.
func SecretTokenFromString(rawstr string) (tok SecretToken, ok bool) { func SecretTokenFromString(rawstr string) (tok SecretToken, ok bool) {
result := base58.Decode(rawstr) result, _ := base58.Decode(rawstr)
if len(result) != TokenLength*2 { if len(result) != TokenLength*2 {
return tok, false return tok, false
} }
@ -104,7 +104,7 @@ func GenerateSecureToken(key []byte, expiry time.Time, token Token) SecureToken
// SecureTokenFromString parses a base58-encoded string into a SecureToken. // SecureTokenFromString parses a base58-encoded string into a SecureToken.
func SecureTokenFromString(rawstr string) (secureToken SecureToken, ok bool) { func SecureTokenFromString(rawstr string) (secureToken SecureToken, ok bool) {
result := base58.Decode(rawstr) result, _ := base58.Decode(rawstr)
if len(result) != SecureTokenLength { if len(result) != SecureTokenLength {
return secureToken, false return secureToken, false
} }

View file

@ -1,49 +0,0 @@
// Copyright (c) 2015 The btcsuite developers
// Use of this source code is governed by an ISC
// license that can be found in the LICENSE file.
//
// See: https://github.com/btcsuite/btcutil/blob/master/LICENSE
package base58
const (
// alphabet is the modified base58 alphabet used by Bitcoin.
alphabet = "123456789ABCDEFGHJKLMNPQRSTUVWXYZabcdefghijkmnopqrstuvwxyz"
alphabetIdx0 = '1'
)
var b58 = [256]byte{
255, 255, 255, 255, 255, 255, 255, 255,
255, 255, 255, 255, 255, 255, 255, 255,
255, 255, 255, 255, 255, 255, 255, 255,
255, 255, 255, 255, 255, 255, 255, 255,
255, 255, 255, 255, 255, 255, 255, 255,
255, 255, 255, 255, 255, 255, 255, 255,
255, 0, 1, 2, 3, 4, 5, 6,
7, 8, 255, 255, 255, 255, 255, 255,
255, 9, 10, 11, 12, 13, 14, 15,
16, 255, 17, 18, 19, 20, 21, 255,
22, 23, 24, 25, 26, 27, 28, 29,
30, 31, 32, 255, 255, 255, 255, 255,
255, 33, 34, 35, 36, 37, 38, 39,
40, 41, 42, 43, 255, 44, 45, 46,
47, 48, 49, 50, 51, 52, 53, 54,
55, 56, 57, 255, 255, 255, 255, 255,
255, 255, 255, 255, 255, 255, 255, 255,
255, 255, 255, 255, 255, 255, 255, 255,
255, 255, 255, 255, 255, 255, 255, 255,
255, 255, 255, 255, 255, 255, 255, 255,
255, 255, 255, 255, 255, 255, 255, 255,
255, 255, 255, 255, 255, 255, 255, 255,
255, 255, 255, 255, 255, 255, 255, 255,
255, 255, 255, 255, 255, 255, 255, 255,
255, 255, 255, 255, 255, 255, 255, 255,
255, 255, 255, 255, 255, 255, 255, 255,
255, 255, 255, 255, 255, 255, 255, 255,
255, 255, 255, 255, 255, 255, 255, 255,
255, 255, 255, 255, 255, 255, 255, 255,
255, 255, 255, 255, 255, 255, 255, 255,
255, 255, 255, 255, 255, 255, 255, 255,
255, 255, 255, 255, 255, 255, 255, 255,
}

View file

@ -1,75 +0,0 @@
// Copyright (c) 2013-2015 The btcsuite developers
// Use of this source code is governed by an ISC
// license that can be found in the LICENSE file.
//
// See: https://github.com/btcsuite/btcutil/blob/master/LICENSE
package base58
import "math/big"
var (
bigRadix = big.NewInt(58)
bigZero = big.NewInt(0)
)
// Decode decodes a modified base58 string to a byte slice.
func Decode(b string) []byte {
answer := big.NewInt(0)
j := big.NewInt(1)
scratch := new(big.Int)
for i := len(b) - 1; i >= 0; i-- {
tmp := b58[b[i]]
if tmp == 255 {
return []byte("")
}
scratch.SetInt64(int64(tmp))
scratch.Mul(j, scratch)
answer.Add(answer, scratch)
j.Mul(j, bigRadix)
}
tmpval := answer.Bytes()
var numZeros int
for numZeros = 0; numZeros < len(b); numZeros++ {
if b[numZeros] != alphabetIdx0 {
break
}
}
flen := numZeros + len(tmpval)
val := make([]byte, flen)
copy(val[numZeros:], tmpval)
return val
}
// Encode encodes a byte slice to a modified base58 string.
func Encode(b []byte) string {
x := new(big.Int)
x.SetBytes(b)
answer := make([]byte, 0, len(b)*136/100)
for x.Cmp(bigZero) > 0 {
mod := new(big.Int)
x.DivMod(x, bigRadix, mod)
answer = append(answer, alphabet[mod.Int64()])
}
// leading zero bytes
for _, i := range b {
if i != 0 {
break
}
answer = append(answer, alphabetIdx0)
}
// reverse
alen := len(answer)
for i := 0; i < alen/2; i++ {
answer[i], answer[alen-1-i] = answer[alen-1-i], answer[i]
}
return string(answer)
}

View file

@ -1,100 +0,0 @@
// Copyright (c) 2013-2017 The btcsuite developers
// Use of this source code is governed by an ISC
// license that can be found in the LICENSE file.
//
// See: https://github.com/btcsuite/btcutil/blob/master/LICENSE
package base58_test
import (
"bytes"
"encoding/hex"
"testing"
"github.com/pomerium/pomerium/pkg/encoding/base58"
)
var stringTests = []struct {
in string
out string
}{
{"", ""},
{" ", "Z"},
{"-", "n"},
{"0", "q"},
{"1", "r"},
{"-1", "4SU"},
{"11", "4k8"},
{"abc", "ZiCa"},
{"1234598760", "3mJr7AoUXx2Wqd"},
{"abcdefghijklmnopqrstuvwxyz", "3yxU3u1igY8WkgtjK92fbJQCd4BZiiT1v25f"},
{"00000000000000000000000000000000000000000000000000000000000000", "3sN2THZeE9Eh9eYrwkvZqNstbHGvrxSAM7gXUXvyFQP8XvQLUqNCS27icwUeDT7ckHm4FUHM2mTVh1vbLmk7y"},
}
var invalidStringTests = []struct {
in string
out string
}{
{"0", ""},
{"O", ""},
{"I", ""},
{"l", ""},
{"3mJr0", ""},
{"O3yxU", ""},
{"3sNI", ""},
{"4kl8", ""},
{"0OIl", ""},
{"!@#$%^&*()-_=+~`", ""},
}
var hexTests = []struct {
in string
out string
}{
{"61", "2g"},
{"626262", "a3gV"},
{"636363", "aPEr"},
{"73696d706c792061206c6f6e6720737472696e67", "2cFupjhnEsSn59qHXstmK2ffpLv2"},
{"00eb15231dfceb60925886b67d065299925915aeb172c06647", "1NS17iag9jJgTHD1VXjvLCEnZuQ3rJDE9L"},
{"516b6fcd0f", "ABnLTmg"},
{"bf4f89001e670274dd", "3SEo3LWLoPntC"},
{"572e4794", "3EFU7m"},
{"ecac89cad93923c02321", "EJDM8drfXA6uyA"},
{"10c8511e", "Rt5zm"},
{"00000000000000000000", "1111111111"},
}
func TestBase58(t *testing.T) {
// Encode tests
for x, test := range stringTests {
tmp := []byte(test.in)
if res := base58.Encode(tmp); res != test.out {
t.Errorf("Encode test #%d failed: got: %s want: %s",
x, res, test.out)
continue
}
}
// Decode tests
for x, test := range hexTests {
b, err := hex.DecodeString(test.in)
if err != nil {
t.Errorf("hex.DecodeString failed failed #%d: got: %s", x, test.in)
continue
}
if res := base58.Decode(test.out); !bytes.Equal(res, b) {
t.Errorf("Decode test #%d failed: got: %q want: %q",
x, res, test.in)
continue
}
}
// Decode with invalid input
for x, test := range invalidStringTests {
if res := base58.Decode(test.in); string(res) != test.out {
t.Errorf("Decode invalidString test #%d failed: got: %q want: %q",
x, res, test.out)
continue
}
}
}

View file

@ -1,22 +0,0 @@
// Copyright (c) 2014 The btcsuite developers
// Use of this source code is governed by an ISC
// license that can be found in the LICENSE file.
//
// See: https://github.com/btcsuite/btcutil/blob/master/LICENSE
/*
Package base58 provides an API for working with modified base58 and Base58Check
encodings.
# Modified Base58 Encoding
Standard base58 encoding is similar to standard base64 encoding except, as the
name implies, it uses a 58 character alphabet which results in an alphanumeric
string and allows some characters which are problematic for humans to be
excluded. Due to this, there can be various base58 alphabets.
The modified base58 alphabet used by Bitcoin, and hence this package, omits the
0, O, I, and l characters that look the same in many fonts and are therefore
hard to humans to distinguish.
*/
package base58

View file

@ -9,7 +9,7 @@ import (
"google.golang.org/grpc/status" "google.golang.org/grpc/status"
"google.golang.org/protobuf/types/known/timestamppb" "google.golang.org/protobuf/types/known/timestamppb"
"github.com/pomerium/pomerium/pkg/encoding/base58" "github.com/akamensky/base58"
"github.com/pomerium/pomerium/pkg/grpc/databroker" "github.com/pomerium/pomerium/pkg/grpc/databroker"
"github.com/pomerium/pomerium/pkg/protoutil" "github.com/pomerium/pomerium/pkg/protoutil"
) )

View file

@ -6,7 +6,7 @@ import (
"google.golang.org/protobuf/proto" "google.golang.org/protobuf/proto"
"github.com/pomerium/pomerium/pkg/encoding/base58" "github.com/akamensky/base58"
) )
// Clone clones the Provider. // Clone clones the Provider.

View file

@ -6,7 +6,7 @@ import (
"github.com/google/uuid" "github.com/google/uuid"
"github.com/pomerium/pomerium/pkg/encoding/base58" "github.com/akamensky/base58"
) )
const headerName = "x-request-id" const headerName = "x-request-id"

View file

@ -6,7 +6,7 @@ import (
"google.golang.org/grpc/codes" "google.golang.org/grpc/codes"
"google.golang.org/grpc/status" "google.golang.org/grpc/status"
"github.com/pomerium/pomerium/pkg/encoding/base58" "github.com/akamensky/base58"
"github.com/pomerium/pomerium/pkg/grpc/databroker" "github.com/pomerium/pomerium/pkg/grpc/databroker"
"github.com/pomerium/pomerium/pkg/grpc/device" "github.com/pomerium/pomerium/pkg/grpc/device"
"github.com/pomerium/webauthn" "github.com/pomerium/webauthn"

View file

@ -34,7 +34,7 @@ func (p *Proxy) getSession(ctx context.Context, sessionID string) (s *session.Se
func (p *Proxy) getSessionState(r *http.Request) (sessions.State, error) { func (p *Proxy) getSessionState(r *http.Request) (sessions.State, error) {
state := p.state.Load() state := p.state.Load()
rawJWT, err := state.sessionStore.LoadSession(r) rawJWT, err := state.sessionStore.LoadSession(r.Context(), r)
if err != nil { if err != nil {
return sessions.State{}, err return sessions.State{}, err
} }

View file

@ -113,7 +113,7 @@ func (p *Proxy) ProgrammaticLogin(w http.ResponseWriter, r *http.Request) error
return httputil.NewError(http.StatusBadRequest, errors.New("invalid redirect uri")) return httputil.NewError(http.StatusBadRequest, errors.New("invalid redirect uri"))
} }
idp, err := options.GetIdentityProviderForRequestURL(urlutil.GetAbsoluteURL(r).String()) idp, err := p.policyCache.Load().GetIdentityProviderForRequestURL(r.Context(), options, urlutil.GetAbsoluteURL(r).String())
if err != nil { if err != nil {
return httputil.NewError(http.StatusInternalServerError, err) return httputil.NewError(http.StatusInternalServerError, err)
} }

View file

@ -56,6 +56,7 @@ type Proxy struct {
currentOptions *atomicutil.Value[*config.Options] currentOptions *atomicutil.Value[*config.Options]
currentRouter *atomicutil.Value[*mux.Router] currentRouter *atomicutil.Value[*mux.Router]
webauthn *webauthn.Handler webauthn *webauthn.Handler
policyCache *atomicutil.Value[*config.PolicyCache]
} }
// New takes a Proxy service from options and a validation function. // New takes a Proxy service from options and a validation function.
@ -66,10 +67,15 @@ func New(cfg *config.Config) (*Proxy, error) {
return nil, err return nil, err
} }
cache, err := config.NewPolicyCache(cfg.Options)
if err != nil {
return nil, err
}
p := &Proxy{ p := &Proxy{
state: atomicutil.NewValue(state), state: atomicutil.NewValue(state),
currentOptions: config.NewAtomicOptions(), currentOptions: config.NewAtomicOptions(),
currentRouter: atomicutil.NewValue(httputil.NewRouter()), currentRouter: atomicutil.NewValue(httputil.NewRouter()),
policyCache: atomicutil.NewValue(cache),
} }
p.webauthn = webauthn.New(p.getWebauthnState) p.webauthn = webauthn.New(p.getWebauthnState)
@ -91,6 +97,9 @@ func (p *Proxy) OnConfigChange(_ context.Context, cfg *config.Config) {
return return
} }
if cache, err := config.NewPolicyCache(cfg.Options); err == nil {
p.policyCache.Store(cache)
}
p.currentOptions.Store(cfg.Options) p.currentOptions.Store(cfg.Options)
if err := p.setHandlers(cfg.Options); err != nil { if err := p.setHandlers(cfg.Options); err != nil {
log.Error(context.TODO()).Err(err).Msg("proxy: failed to update proxy handlers from configuration settings") log.Error(context.TODO()).Err(err).Msg("proxy: failed to update proxy handlers from configuration settings")