mirror of
https://github.com/pomerium/pomerium.git
synced 2025-07-22 11:08:10 +02:00
authorize: hot path identity provider lookup optimizations
This commit is contained in:
parent
7eca911292
commit
e18c04216e
29 changed files with 387 additions and 284 deletions
|
@ -249,7 +249,7 @@ func TestAuthenticate_SignOut(t *testing.T) {
|
|||
}
|
||||
u.RawQuery = params.Encode()
|
||||
r := httptest.NewRequest(tt.method, u.String(), nil)
|
||||
state, err := tt.sessionStore.LoadSession(r)
|
||||
state, err := tt.sessionStore.LoadSession(context.TODO(), r)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -481,7 +481,7 @@ func TestAuthenticate_SessionValidatorMiddleware(t *testing.T) {
|
|||
options: config.NewAtomicOptions(),
|
||||
}
|
||||
r := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
state, err := tt.session.LoadSession(r)
|
||||
state, err := tt.session.LoadSession(context.TODO(), r)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -586,7 +586,7 @@ func TestAuthenticate_userInfo(t *testing.T) {
|
|||
}),
|
||||
}
|
||||
r := httptest.NewRequest(http.MethodGet, tt.url, nil)
|
||||
state, err := tt.sessionStore.LoadSession(r)
|
||||
state, err := tt.sessionStore.LoadSession(context.TODO(), r)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
|
|
@ -47,7 +47,7 @@ func (a *Authorize) Check(ctx context.Context, in *envoy_service_auth_v3.CheckRe
|
|||
hreq := getHTTPRequestFromCheckRequest(in)
|
||||
ctx = requestid.WithValue(ctx, requestid.FromHTTPHeader(hreq.Header))
|
||||
|
||||
sessionState, _ := state.sessionStore.LoadSessionState(hreq)
|
||||
sessionState, _ := state.sessionStore.LoadSessionState(ctx, hreq)
|
||||
|
||||
var s sessionOrServiceAccount
|
||||
var u *user.User
|
||||
|
|
|
@ -1,6 +1,13 @@
|
|||
package config
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
art "github.com/kralicky/go-adaptive-radix-tree"
|
||||
"github.com/pomerium/pomerium/internal/urlutil"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/identity"
|
||||
)
|
||||
|
@ -51,7 +58,7 @@ func (o *Options) GetIdentityProviderForPolicy(policy *Policy) (*identity.Provid
|
|||
}
|
||||
|
||||
// GetIdentityProviderForRequestURL gets the identity provider associated with the given request URL.
|
||||
func (o *Options) GetIdentityProviderForRequestURL(requestURL string) (*identity.Provider, error) {
|
||||
func (o *Options) GetIdentityProviderForRequestURL(ctx context.Context, requestURL string) (*identity.Provider, error) {
|
||||
u, err := urlutil.ParseAndValidateURL(requestURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -65,3 +72,127 @@ func (o *Options) GetIdentityProviderForRequestURL(requestURL string) (*identity
|
|||
}
|
||||
return o.GetIdentityProviderForPolicy(nil)
|
||||
}
|
||||
|
||||
type PolicyCache struct {
|
||||
domainTree art.Tree[*domainNode]
|
||||
matchPorts bool
|
||||
}
|
||||
|
||||
func NewPolicyCache(options *Options) (*PolicyCache, error) {
|
||||
tree := art.New[*domainNode]()
|
||||
shouldMatchPorts := !options.IsRuntimeFlagSet(RuntimeFlagMatchAnyIncomingPort)
|
||||
for _, policy := range options.GetAllPolicies() {
|
||||
u, err := urlutil.ParseAndValidateURL(policy.From)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
domains := urlutil.GetDomainsForURL(u, shouldMatchPorts)
|
||||
for _, domain := range domains {
|
||||
host, port, err := net.SplitHostPort(domain) // todo: this is not optimal
|
||||
if err != nil {
|
||||
host, port = domain, ""
|
||||
}
|
||||
domainKey := radixKeyForHostPort(host, port)
|
||||
tree.Update(domainKey, newDomainNode, func(dn **domainNode) {
|
||||
if policy.Prefix != "" {
|
||||
(*dn).policiesByPrefix.Insert(art.Key(policy.Prefix), policy)
|
||||
} else if policy.Path != "" {
|
||||
(*dn).policiesByPrefix.Insert(art.Key(policy.Path), policy)
|
||||
} else if policy.compiledRegex != nil {
|
||||
(*dn).policiesByRegex = append((*dn).policiesByRegex, policy)
|
||||
} else {
|
||||
(*dn).policiesNoPathMatching = append((*dn).policiesNoPathMatching, policy)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
return &PolicyCache{
|
||||
domainTree: tree,
|
||||
matchPorts: shouldMatchPorts,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (pc *PolicyCache) GetIdentityProviderForRequestURL(ctx context.Context, o *Options, requestURL string) (*identity.Provider, error) {
|
||||
u, err := urlutil.ParseAndValidateURL(requestURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
domainKey := radixKeyForHostPort(u.Hostname(), u.Port())
|
||||
domain, ok := pc.domainTree.Resolve(domainKey, wildcardResolver)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no identity provider found for request URL %s", requestURL)
|
||||
}
|
||||
var policy *Policy
|
||||
if len(u.Path) == 0 || (len(u.Path) == 1 && u.Path[0] == '/') && len(domain.policiesNoPathMatching) > 0 {
|
||||
policy = &domain.policiesNoPathMatching[0]
|
||||
} else {
|
||||
if domain.policiesByPrefix.Size() > 0 {
|
||||
pathKey := art.Key(u.Path)
|
||||
actualKey, val, found := domain.policiesByPrefix.SearchNearest(pathKey)
|
||||
if found {
|
||||
// check for prefix match or exact match
|
||||
if c := actualKey.Compare(pathKey); c < 0 {
|
||||
if val.Prefix != "" && strings.HasPrefix(u.Path, val.Prefix) {
|
||||
policy = &val
|
||||
}
|
||||
} else if c == 0 {
|
||||
if val.Path != "" || val.Prefix != "" {
|
||||
policy = &val
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if policy == nil {
|
||||
for _, p := range domain.policiesByRegex {
|
||||
if p.compiledRegex.MatchString(u.Path) {
|
||||
policy = &p
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if policy != nil {
|
||||
return o.GetIdentityProviderForPolicy(policy)
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("no identity provider found for request URL %s", requestURL)
|
||||
}
|
||||
|
||||
type domainNode struct {
|
||||
policiesByPrefix art.Tree[Policy]
|
||||
policiesByRegex []Policy
|
||||
policiesNoPathMatching []Policy
|
||||
}
|
||||
|
||||
func newDomainNode() *domainNode {
|
||||
return &domainNode{
|
||||
policiesByPrefix: art.New[Policy](),
|
||||
}
|
||||
}
|
||||
|
||||
func radixKeyForHostPort(host, port string) art.Key {
|
||||
parts := strings.Split(host, ".")
|
||||
sb := strings.Builder{}
|
||||
sb.WriteString(port)
|
||||
for i := len(parts) - 1; i >= 0; i-- {
|
||||
sb.WriteByte('.')
|
||||
sb.WriteString(parts[i])
|
||||
}
|
||||
return art.Key(sb.String())
|
||||
}
|
||||
|
||||
func wildcardResolver(key art.Key, conflictIndex int) (art.Key, int) {
|
||||
if conflictIndex >= len(key) {
|
||||
return nil, -1
|
||||
}
|
||||
c := key[conflictIndex]
|
||||
if c != '*' && c != '.' {
|
||||
nextDot := slices.Index(key[conflictIndex:], '.')
|
||||
if nextDot == -1 {
|
||||
return art.Key("*"), len(key)
|
||||
}
|
||||
return art.Key("*"), conflictIndex + nextDot
|
||||
}
|
||||
return nil, -1
|
||||
}
|
||||
|
|
184
config/identity_benchmark_test.go
Normal file
184
config/identity_benchmark_test.go
Normal file
|
@ -0,0 +1,184 @@
|
|||
package config_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/pomerium/pomerium/config"
|
||||
"github.com/pomerium/pomerium/pkg/cryptutil"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func BenchmarkGetIdentityProviderForRequestURL_Old(b *testing.B) {
|
||||
runBench := func(numPolicies int) func(b *testing.B) {
|
||||
return func(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
options := config.NewDefaultOptions()
|
||||
sharedKey := cryptutil.NewKey()
|
||||
options.SharedKey = base64.StdEncoding.EncodeToString(sharedKey)
|
||||
options.Provider = "oidc"
|
||||
options.ProviderURL = "https://oidc.example.com"
|
||||
options.ClientID = "client_id"
|
||||
options.ClientSecret = "client_secret"
|
||||
urlFormat := "https://*.foo.bar.test-%d.example.com"
|
||||
for i := range numPolicies {
|
||||
options.Policies = append(options.Policies,
|
||||
config.Policy{
|
||||
From: fmt.Sprintf(urlFormat, i),
|
||||
To: mustParseWeightedURLs(b, fmt.Sprintf("https://p2-%d", i)),
|
||||
IDPClientID: fmt.Sprintf("client_id_%d", i),
|
||||
IDPClientSecret: fmt.Sprintf("client_secret_%d", i),
|
||||
},
|
||||
)
|
||||
}
|
||||
require.NoError(b, options.Validate())
|
||||
|
||||
b.ResetTimer()
|
||||
for range b.N {
|
||||
idp, err := options.GetIdentityProviderForRequestURL(context.Background(), fmt.Sprintf(urlFormat, numPolicies-1))
|
||||
require.NoError(b, err)
|
||||
require.Equal(b, fmt.Sprintf("client_id_%d", numPolicies-1), idp.ClientId)
|
||||
require.Equal(b, fmt.Sprintf("client_secret_%d", numPolicies-1), idp.ClientSecret)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
b.Run("5 policies", runBench(5))
|
||||
b.Run("50 policies", runBench(50))
|
||||
b.Run("500 policies", runBench(500))
|
||||
b.Run("5000 policies", runBench(5000))
|
||||
}
|
||||
|
||||
var bench = func(fill func(i int, p *config.Policy) string, numPolicies int) func(b *testing.B) {
|
||||
return func(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
options := config.NewDefaultOptions()
|
||||
sharedKey := cryptutil.NewKey()
|
||||
options.SharedKey = base64.StdEncoding.EncodeToString(sharedKey)
|
||||
options.Provider = "oidc"
|
||||
options.ProviderURL = "https://oidc.example.com"
|
||||
options.ClientID = "client_id"
|
||||
options.ClientSecret = "client_secret"
|
||||
var allUrls []string
|
||||
for i := range numPolicies {
|
||||
p := config.Policy{
|
||||
To: mustParseWeightedURLs(b, fmt.Sprintf("https://p2-%d", i)),
|
||||
IDPClientID: fmt.Sprintf("client_id_%d", i),
|
||||
IDPClientSecret: fmt.Sprintf("client_secret_%d", i),
|
||||
}
|
||||
|
||||
allUrls = append(allUrls, fill(i, &p))
|
||||
options.Policies = append(options.Policies, p)
|
||||
}
|
||||
require.NoError(b, options.Validate())
|
||||
|
||||
cache, err := config.NewPolicyCache(options)
|
||||
require.NoError(b, err)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := range b.N {
|
||||
reqUrl := strings.Replace(allUrls[i%numPolicies], "*", fmt.Sprint(i), 1)
|
||||
idp, err := cache.GetIdentityProviderForRequestURL(context.Background(), options, reqUrl)
|
||||
require.NoError(b, err)
|
||||
require.Equal(b, fmt.Sprintf("client_id_%d", i%numPolicies), idp.ClientId)
|
||||
require.Equal(b, fmt.Sprintf("client_secret_%d", i%numPolicies), idp.ClientSecret)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkGetIdentityProviderForRequestURL_New_DomainMatchOnly(b *testing.B) {
|
||||
domainMatchingOnly := func(i int, p *config.Policy) string {
|
||||
p.From = fmt.Sprintf("https://*.foo.bar.test-%d.example.com", i)
|
||||
return p.From
|
||||
}
|
||||
|
||||
b.Run("5 policies (domain matching only)", bench(domainMatchingOnly, 5))
|
||||
b.Run("50 policies (domain matching only)", bench(domainMatchingOnly, 50))
|
||||
b.Run("500 policies (domain matching only)", bench(domainMatchingOnly, 500))
|
||||
b.Run("5000 policies (domain matching only)", bench(domainMatchingOnly, 5000))
|
||||
}
|
||||
|
||||
func BenchmarkGetIdentityProviderForRequestURL_New_PathMatchOnly(b *testing.B) {
|
||||
pathMatchingOnly := func(i int, p *config.Policy) string {
|
||||
p.From = "https://example.com"
|
||||
p.Path = fmt.Sprintf("/foo/bar/path%d", i)
|
||||
return p.From + p.Path
|
||||
}
|
||||
b.Run("5 policies (path matching only)", bench(pathMatchingOnly, 5))
|
||||
b.Run("50 policies (path matching only)", bench(pathMatchingOnly, 50))
|
||||
b.Run("500 policies (path matching only)", bench(pathMatchingOnly, 500))
|
||||
b.Run("5000 policies (path matching only)", bench(pathMatchingOnly, 5000))
|
||||
}
|
||||
|
||||
func BenchmarkGetIdentityProviderForRequestURL_New_DomainAndPathMatching(b *testing.B) {
|
||||
combinedMatching := func(numPathsPerDomain int) func(i int, p *config.Policy) string {
|
||||
// returns a sequence of policies (ex: numPathsPerDomain=3)
|
||||
// https://*.foo.bar.test-0.example.com
|
||||
// https://*.foo.bar.test-0.example.com/foo/bar/path1
|
||||
// https://*.foo.bar.test-0.example.com/foo/bar/path2
|
||||
// https://*.foo.bar.test-1.example.com
|
||||
// https://*.foo.bar.test-1.example.com/foo/bar/path1
|
||||
// https://*.foo.bar.test-1.example.com/foo/bar/path2
|
||||
return func(i int, p *config.Policy) string {
|
||||
domain := fmt.Sprintf("https://*.foo.bar.test-%d.example.com", i/numPathsPerDomain)
|
||||
pathIdx := i % numPathsPerDomain
|
||||
var path string
|
||||
if pathIdx == 0 {
|
||||
path = ""
|
||||
} else {
|
||||
path = fmt.Sprintf("/foo/bar/path%d", pathIdx)
|
||||
}
|
||||
p.From = domain
|
||||
p.Path = path
|
||||
return domain + path
|
||||
}
|
||||
}
|
||||
|
||||
b.Run("25 policies (5 domains, 5 paths per domain)", bench(combinedMatching(5), 25))
|
||||
b.Run("500 policies (50 domains, 10 paths per domain)", bench(combinedMatching(10), 500))
|
||||
b.Run("500 policies (10 domains, 50 paths per domain)", bench(combinedMatching(50), 500))
|
||||
b.Run("2500 policies (50 domains, 50 paths per domain)", bench(combinedMatching(50), 2500))
|
||||
b.Run("5000 policies (100 domains, 50 paths per domain)", bench(combinedMatching(50), 5000))
|
||||
b.Run("5000 policies (50 domains, 100 paths per domain)", bench(combinedMatching(100), 5000))
|
||||
b.Run("10000 policies (100 domains, 100 paths per domain)", bench(combinedMatching(100), 10000))
|
||||
}
|
||||
|
||||
func BenchmarkGetIdentityProviderForRequestURL_New_DomainAndPrefixMatching(b *testing.B) {
|
||||
combinedMatching := func(numPathsPerDomain int) func(i int, p *config.Policy) string {
|
||||
// returns a sequence of policies (ex: numPathsPerDomain=3)
|
||||
// https://*.foo.bar.test-0.example.com/0
|
||||
// https://*.foo.bar.test-0.example.com/0/1
|
||||
// https://*.foo.bar.test-0.example.com/0/1/2
|
||||
// https://*.foo.bar.test-1.example.com/0
|
||||
// https://*.foo.bar.test-1.example.com/0/1
|
||||
// https://*.foo.bar.test-1.example.com/0/1/2
|
||||
return func(i int, p *config.Policy) string {
|
||||
domain := fmt.Sprintf("https://*.foo.bar.test-%d.example.com", i/numPathsPerDomain)
|
||||
pathIdx := i % numPathsPerDomain
|
||||
var prefix strings.Builder
|
||||
for i := range pathIdx + 1 {
|
||||
fmt.Fprintf(&prefix, "/%d", i)
|
||||
}
|
||||
p.From = domain
|
||||
p.Prefix = prefix.String()
|
||||
return domain + p.Prefix
|
||||
}
|
||||
}
|
||||
|
||||
b.Run("25 policies (5 domains, 5 paths per domain)", bench(combinedMatching(5), 25))
|
||||
b.Run("500 policies (50 domains, 10 paths per domain)", bench(combinedMatching(10), 500))
|
||||
b.Run("500 policies (10 domains, 50 paths per domain)", bench(combinedMatching(50), 500))
|
||||
b.Run("2500 policies (50 domains, 50 paths per domain)", bench(combinedMatching(50), 2500))
|
||||
b.Run("5000 policies (100 domains, 50 paths per domain)", bench(combinedMatching(50), 5000))
|
||||
b.Run("5000 policies (50 domains, 100 paths per domain)", bench(combinedMatching(100), 5000))
|
||||
b.Run("10000 policies (100 domains, 100 paths per domain)", bench(combinedMatching(100), 10000))
|
||||
}
|
||||
|
||||
func mustParseWeightedURLs(t testing.TB, urls ...string) []config.WeightedURL {
|
||||
wu, err := config.ParseWeightedUrls(urls...)
|
||||
require.NoError(t, err)
|
||||
return wu
|
||||
}
|
|
@ -1,6 +1,7 @@
|
|||
package config
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
|
@ -10,20 +11,27 @@ import (
|
|||
"github.com/pomerium/pomerium/internal/sessions/cookie"
|
||||
"github.com/pomerium/pomerium/internal/sessions/header"
|
||||
"github.com/pomerium/pomerium/internal/sessions/queryparam"
|
||||
"github.com/pomerium/pomerium/internal/telemetry/trace"
|
||||
"github.com/pomerium/pomerium/internal/urlutil"
|
||||
)
|
||||
|
||||
// A SessionStore saves and loads sessions based on the options.
|
||||
type SessionStore struct {
|
||||
options *Options
|
||||
encoder encoding.MarshalUnmarshaler
|
||||
loader sessions.SessionLoader
|
||||
options *Options
|
||||
encoder encoding.MarshalUnmarshaler
|
||||
loader sessions.SessionLoader
|
||||
policyCache *PolicyCache
|
||||
}
|
||||
|
||||
// NewSessionStore creates a new SessionStore from the Options.
|
||||
func NewSessionStore(options *Options) (*SessionStore, error) {
|
||||
cache, err := NewPolicyCache(options)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
store := &SessionStore{
|
||||
options: options,
|
||||
options: options,
|
||||
policyCache: cache,
|
||||
}
|
||||
|
||||
sharedKey, err := options.GetSharedKey()
|
||||
|
@ -57,8 +65,11 @@ func NewSessionStore(options *Options) (*SessionStore, error) {
|
|||
}
|
||||
|
||||
// LoadSessionState loads the session state from a request.
|
||||
func (store *SessionStore) LoadSessionState(r *http.Request) (*sessions.State, error) {
|
||||
rawJWT, err := store.loader.LoadSession(r)
|
||||
func (store *SessionStore) LoadSessionState(ctx context.Context, r *http.Request) (*sessions.State, error) {
|
||||
ctx, span := trace.StartSpan(ctx, "session_store.load_session_state")
|
||||
defer span.End()
|
||||
|
||||
rawJWT, err := store.loader.LoadSession(ctx, r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -71,7 +82,7 @@ func (store *SessionStore) LoadSessionState(r *http.Request) (*sessions.State, e
|
|||
|
||||
// confirm that the identity provider id matches the state
|
||||
if state.IdentityProviderID != "" {
|
||||
idp, err := store.options.GetIdentityProviderForRequestURL(urlutil.GetAbsoluteURL(r).String())
|
||||
idp, err := store.policyCache.GetIdentityProviderForRequestURL(ctx, store.options, urlutil.GetAbsoluteURL(r).String())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package config
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"net/http"
|
||||
"net/url"
|
||||
|
@ -70,7 +71,7 @@ func TestSessionStore_LoadSessionState(t *testing.T) {
|
|||
t.Run("mssing", func(t *testing.T) {
|
||||
r, err := http.NewRequest(http.MethodGet, "https://p1.example.com", nil)
|
||||
require.NoError(t, err)
|
||||
s, err := store.LoadSessionState(r)
|
||||
s, err := store.LoadSessionState(context.TODO(), r)
|
||||
assert.ErrorIs(t, err, sessions.ErrNoSessionFound)
|
||||
assert.Nil(t, s)
|
||||
})
|
||||
|
@ -85,7 +86,7 @@ func TestSessionStore_LoadSessionState(t *testing.T) {
|
|||
urlutil.QuerySession: {rawJWS},
|
||||
}.Encode(), nil)
|
||||
require.NoError(t, err)
|
||||
s, err := store.LoadSessionState(r)
|
||||
s, err := store.LoadSessionState(context.TODO(), r)
|
||||
assert.NoError(t, err)
|
||||
assert.Empty(t, cmp.Diff(&sessions.State{
|
||||
Issuer: "authenticate.example.com",
|
||||
|
@ -103,7 +104,7 @@ func TestSessionStore_LoadSessionState(t *testing.T) {
|
|||
r, err := http.NewRequest(http.MethodGet, "https://p2.example.com", nil)
|
||||
require.NoError(t, err)
|
||||
r.Header.Set(httputil.HeaderPomeriumAuthorization, rawJWS)
|
||||
s, err := store.LoadSessionState(r)
|
||||
s, err := store.LoadSessionState(context.TODO(), r)
|
||||
assert.NoError(t, err)
|
||||
assert.Empty(t, cmp.Diff(&sessions.State{
|
||||
Issuer: "authenticate.example.com",
|
||||
|
@ -121,7 +122,7 @@ func TestSessionStore_LoadSessionState(t *testing.T) {
|
|||
r, err := http.NewRequest(http.MethodGet, "https://p2.example.com", nil)
|
||||
require.NoError(t, err)
|
||||
r.Header.Set(httputil.HeaderPomeriumAuthorization, rawJWS)
|
||||
s, err := store.LoadSessionState(r)
|
||||
s, err := store.LoadSessionState(context.TODO(), r)
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, s)
|
||||
})
|
||||
|
@ -134,7 +135,7 @@ func TestSessionStore_LoadSessionState(t *testing.T) {
|
|||
r, err := http.NewRequest(http.MethodGet, "https://p2.example.com", nil)
|
||||
require.NoError(t, err)
|
||||
r.Header.Set(httputil.HeaderPomeriumAuthorization, rawJWS)
|
||||
s, err := store.LoadSessionState(r)
|
||||
s, err := store.LoadSessionState(context.TODO(), r)
|
||||
assert.NoError(t, err)
|
||||
assert.Empty(t, cmp.Diff(&sessions.State{
|
||||
Issuer: "authenticate.example.com",
|
||||
|
|
2
go.mod
2
go.mod
|
@ -36,6 +36,7 @@ require (
|
|||
github.com/hashicorp/golang-lru/v2 v2.0.7
|
||||
github.com/jackc/pgx/v5 v5.6.0
|
||||
github.com/klauspost/compress v1.17.8
|
||||
github.com/kralicky/go-adaptive-radix-tree v0.0.0-20240619012453-a8f80032ba31
|
||||
github.com/martinlindhe/base36 v1.1.1
|
||||
github.com/mholt/acmez/v2 v2.0.1
|
||||
github.com/minio/minio-go/v7 v7.0.70
|
||||
|
@ -104,6 +105,7 @@ require (
|
|||
github.com/Nvveen/Gotty v0.0.0-20120604004816-cd527374f1e5 // indirect
|
||||
github.com/OneOfOne/xxhash v1.2.8 // indirect
|
||||
github.com/agnivade/levenshtein v1.1.1 // indirect
|
||||
github.com/akamensky/base58 v0.0.0-20210829145138-ce8bf8802e8f // indirect
|
||||
github.com/andybalholm/brotli v1.0.5 // indirect
|
||||
github.com/apapsch/go-jsonmerge/v2 v2.0.0 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.2 // indirect
|
||||
|
|
4
go.sum
4
go.sum
|
@ -74,6 +74,8 @@ github.com/VictoriaMetrics/fastcache v1.12.2 h1:N0y9ASrJ0F6h0QaC3o6uJb3NIZ9VKLjC
|
|||
github.com/VictoriaMetrics/fastcache v1.12.2/go.mod h1:AmC+Nzz1+3G2eCPapF6UcsnkThDcMsQicp4xDukwJYI=
|
||||
github.com/agnivade/levenshtein v1.1.1 h1:QY8M92nrzkmr798gCo3kmMyqXFzdQVpxLlGPRBij0P8=
|
||||
github.com/agnivade/levenshtein v1.1.1/go.mod h1:veldBMzWxcCG2ZvUTKD2kJNRdCk5hVbJomOvKkmgYbo=
|
||||
github.com/akamensky/base58 v0.0.0-20210829145138-ce8bf8802e8f h1:z8MkSJCUyTmW5YQlxsMLBlwA7GmjxC7L4ooicxqnhz8=
|
||||
github.com/akamensky/base58 v0.0.0-20210829145138-ce8bf8802e8f/go.mod h1:UdUwYgAXBiL+kLfcqxoQJYkHA/vl937/PbFhZM34aZs=
|
||||
github.com/alecthomas/template v0.0.0-20160405071501-a0175ee3bccc/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc=
|
||||
github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc=
|
||||
github.com/alecthomas/units v0.0.0-20151022065526-2efee857e7cf/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0=
|
||||
|
@ -414,6 +416,8 @@ github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
|
|||
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
|
||||
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
|
||||
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
|
||||
github.com/kralicky/go-adaptive-radix-tree v0.0.0-20240619012453-a8f80032ba31 h1:cOZUoNuv9OuCZKepddeIW87ScJWwveLfLcJMaii6YCA=
|
||||
github.com/kralicky/go-adaptive-radix-tree v0.0.0-20240619012453-a8f80032ba31/go.mod h1:oJwexVSshEat0E3evyKOH6QzN8GFWrhLvEoh8GiJzss=
|
||||
github.com/lib/pq v1.10.7 h1:p7ZhMD+KsSRozJr34udlUrhboJwWAgCg34+/ZZNvZZw=
|
||||
github.com/lib/pq v1.10.7/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
|
||||
github.com/libdns/libdns v0.2.2 h1:O6ws7bAfRPaBsgAYt8MDe2HcNBGC29hkZ9MX2eUSX3s=
|
||||
|
|
|
@ -2,6 +2,7 @@
|
|||
package cookie
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
@ -117,7 +118,7 @@ func getCookies(r *http.Request, name string) []*http.Cookie {
|
|||
}
|
||||
|
||||
// LoadSession returns a State from the cookie in the request.
|
||||
func (cs *Store) LoadSession(r *http.Request) (string, error) {
|
||||
func (cs *Store) LoadSession(ctx context.Context, r *http.Request) (string, error) {
|
||||
opts := cs.getOptions()
|
||||
cookies := getCookies(r, opts.Name)
|
||||
if len(cookies) == 0 {
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package cookie
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
@ -145,7 +146,7 @@ func TestStore_SaveSession(t *testing.T) {
|
|||
r.AddCookie(cookie)
|
||||
}
|
||||
|
||||
jwt, err := s.LoadSession(r)
|
||||
jwt, err := s.LoadSession(context.TODO(), r)
|
||||
if (err != nil) != tt.wantLoadErr {
|
||||
t.Errorf("LoadSession() error = %v, wantErr %v", err, tt.wantLoadErr)
|
||||
return
|
||||
|
|
|
@ -3,6 +3,7 @@
|
|||
package header
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
|
@ -31,7 +32,7 @@ func NewStore(enc encoding.Unmarshaler) *Store {
|
|||
}
|
||||
|
||||
// LoadSession tries to retrieve the token string from the Authorization header.
|
||||
func (as *Store) LoadSession(r *http.Request) (string, error) {
|
||||
func (as *Store) LoadSession(_ context.Context, r *http.Request) (string, error) {
|
||||
jwt := TokenFromHeaders(r)
|
||||
if jwt == "" {
|
||||
return "", sessions.ErrNoSessionFound
|
||||
|
|
|
@ -23,7 +23,7 @@ func retrieve(s SessionLoader) func(http.Handler) http.Handler {
|
|||
return func(next http.Handler) http.Handler {
|
||||
hfn := func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
jwt, err := s.LoadSession(r)
|
||||
jwt, err := s.LoadSession(ctx, r)
|
||||
ctx = NewContext(ctx, jwt, err)
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
}
|
||||
|
|
|
@ -2,6 +2,7 @@
|
|||
package mock
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/encoding"
|
||||
|
@ -30,7 +31,7 @@ func (ms *Store) ClearSession(http.ResponseWriter, *http.Request) {
|
|||
}
|
||||
|
||||
// LoadSession returns the session and a error
|
||||
func (ms Store) LoadSession(*http.Request) (string, error) {
|
||||
func (ms Store) LoadSession(context.Context, *http.Request) (string, error) {
|
||||
var signer encoding.MarshalUnmarshaler
|
||||
signer, _ = jws.NewHS256Signer(ms.Secret)
|
||||
jwt, _ := signer.Marshal(ms.Session)
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package mock
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
@ -40,7 +41,7 @@ func TestStore(t *testing.T) {
|
|||
t.Errorf("mockstore.SaveSession() error = %v, wantSaveErr %v", err, tt.wantSaveErr)
|
||||
return
|
||||
}
|
||||
got, err := ms.LoadSession(nil)
|
||||
got, err := ms.LoadSession(context.TODO(), nil)
|
||||
if (err != nil) != tt.wantLoadErr {
|
||||
t.Errorf("mockstore.LoadSession() error = %v, wantLoadErr %v", err, tt.wantLoadErr)
|
||||
return
|
||||
|
|
|
@ -3,6 +3,7 @@
|
|||
package queryparam
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/encoding"
|
||||
|
@ -43,7 +44,7 @@ func NewStore(enc encoding.MarshalUnmarshaler, qp string) *Store {
|
|||
}
|
||||
|
||||
// LoadSession tries to retrieve the token string from URL query parameters.
|
||||
func (qp *Store) LoadSession(r *http.Request) (string, error) {
|
||||
func (qp *Store) LoadSession(_ context.Context, r *http.Request) (string, error) {
|
||||
jwt := r.URL.Query().Get(qp.queryParamKey)
|
||||
if jwt == "" {
|
||||
return "", sessions.ErrNoSessionFound
|
||||
|
|
|
@ -3,6 +3,7 @@
|
|||
package sessions
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
)
|
||||
|
@ -16,14 +17,14 @@ type SessionStore interface {
|
|||
|
||||
// SessionLoader defines an interface for loading a session.
|
||||
type SessionLoader interface {
|
||||
LoadSession(*http.Request) (string, error)
|
||||
LoadSession(context.Context, *http.Request) (string, error)
|
||||
}
|
||||
|
||||
type multiSessionLoader []SessionLoader
|
||||
|
||||
func (l multiSessionLoader) LoadSession(r *http.Request) (string, error) {
|
||||
func (l multiSessionLoader) LoadSession(ctx context.Context, r *http.Request) (string, error) {
|
||||
for _, ll := range l {
|
||||
s, err := ll.LoadSession(r)
|
||||
s, err := ll.LoadSession(ctx, r)
|
||||
if errors.Is(err, ErrNoSessionFound) {
|
||||
continue
|
||||
}
|
||||
|
|
|
@ -7,7 +7,7 @@ import (
|
|||
"golang.org/x/crypto/curve25519"
|
||||
"golang.org/x/crypto/nacl/box"
|
||||
|
||||
"github.com/pomerium/pomerium/pkg/encoding/base58"
|
||||
"github.com/akamensky/base58"
|
||||
)
|
||||
|
||||
// A KeyEncryptionKey (KEK) is used to implement *envelope encryption*, similar to how data is stored at rest with
|
||||
|
|
|
@ -9,7 +9,7 @@ import (
|
|||
|
||||
"github.com/google/uuid"
|
||||
|
||||
"github.com/pomerium/pomerium/pkg/encoding/base58"
|
||||
"github.com/akamensky/base58"
|
||||
)
|
||||
|
||||
// TokenLength is the length of a token.
|
||||
|
@ -27,7 +27,7 @@ func NewRandomToken() (tok Token) {
|
|||
|
||||
// TokenFromString parses a base58-encoded string into a token.
|
||||
func TokenFromString(rawstr string) (tok Token, ok bool) {
|
||||
result := base58.Decode(rawstr)
|
||||
result, _ := base58.Decode(rawstr)
|
||||
if len(result) != TokenLength {
|
||||
return tok, false
|
||||
}
|
||||
|
@ -57,7 +57,7 @@ type SecretToken struct {
|
|||
|
||||
// SecretTokenFromString parses a base58-encoded string into a secret token.
|
||||
func SecretTokenFromString(rawstr string) (tok SecretToken, ok bool) {
|
||||
result := base58.Decode(rawstr)
|
||||
result, _ := base58.Decode(rawstr)
|
||||
if len(result) != TokenLength*2 {
|
||||
return tok, false
|
||||
}
|
||||
|
@ -104,7 +104,7 @@ func GenerateSecureToken(key []byte, expiry time.Time, token Token) SecureToken
|
|||
|
||||
// SecureTokenFromString parses a base58-encoded string into a SecureToken.
|
||||
func SecureTokenFromString(rawstr string) (secureToken SecureToken, ok bool) {
|
||||
result := base58.Decode(rawstr)
|
||||
result, _ := base58.Decode(rawstr)
|
||||
if len(result) != SecureTokenLength {
|
||||
return secureToken, false
|
||||
}
|
||||
|
|
|
@ -1,49 +0,0 @@
|
|||
// Copyright (c) 2015 The btcsuite developers
|
||||
// Use of this source code is governed by an ISC
|
||||
// license that can be found in the LICENSE file.
|
||||
//
|
||||
// See: https://github.com/btcsuite/btcutil/blob/master/LICENSE
|
||||
|
||||
package base58
|
||||
|
||||
const (
|
||||
// alphabet is the modified base58 alphabet used by Bitcoin.
|
||||
alphabet = "123456789ABCDEFGHJKLMNPQRSTUVWXYZabcdefghijkmnopqrstuvwxyz"
|
||||
|
||||
alphabetIdx0 = '1'
|
||||
)
|
||||
|
||||
var b58 = [256]byte{
|
||||
255, 255, 255, 255, 255, 255, 255, 255,
|
||||
255, 255, 255, 255, 255, 255, 255, 255,
|
||||
255, 255, 255, 255, 255, 255, 255, 255,
|
||||
255, 255, 255, 255, 255, 255, 255, 255,
|
||||
255, 255, 255, 255, 255, 255, 255, 255,
|
||||
255, 255, 255, 255, 255, 255, 255, 255,
|
||||
255, 0, 1, 2, 3, 4, 5, 6,
|
||||
7, 8, 255, 255, 255, 255, 255, 255,
|
||||
255, 9, 10, 11, 12, 13, 14, 15,
|
||||
16, 255, 17, 18, 19, 20, 21, 255,
|
||||
22, 23, 24, 25, 26, 27, 28, 29,
|
||||
30, 31, 32, 255, 255, 255, 255, 255,
|
||||
255, 33, 34, 35, 36, 37, 38, 39,
|
||||
40, 41, 42, 43, 255, 44, 45, 46,
|
||||
47, 48, 49, 50, 51, 52, 53, 54,
|
||||
55, 56, 57, 255, 255, 255, 255, 255,
|
||||
255, 255, 255, 255, 255, 255, 255, 255,
|
||||
255, 255, 255, 255, 255, 255, 255, 255,
|
||||
255, 255, 255, 255, 255, 255, 255, 255,
|
||||
255, 255, 255, 255, 255, 255, 255, 255,
|
||||
255, 255, 255, 255, 255, 255, 255, 255,
|
||||
255, 255, 255, 255, 255, 255, 255, 255,
|
||||
255, 255, 255, 255, 255, 255, 255, 255,
|
||||
255, 255, 255, 255, 255, 255, 255, 255,
|
||||
255, 255, 255, 255, 255, 255, 255, 255,
|
||||
255, 255, 255, 255, 255, 255, 255, 255,
|
||||
255, 255, 255, 255, 255, 255, 255, 255,
|
||||
255, 255, 255, 255, 255, 255, 255, 255,
|
||||
255, 255, 255, 255, 255, 255, 255, 255,
|
||||
255, 255, 255, 255, 255, 255, 255, 255,
|
||||
255, 255, 255, 255, 255, 255, 255, 255,
|
||||
255, 255, 255, 255, 255, 255, 255, 255,
|
||||
}
|
|
@ -1,75 +0,0 @@
|
|||
// Copyright (c) 2013-2015 The btcsuite developers
|
||||
// Use of this source code is governed by an ISC
|
||||
// license that can be found in the LICENSE file.
|
||||
//
|
||||
// See: https://github.com/btcsuite/btcutil/blob/master/LICENSE
|
||||
|
||||
package base58
|
||||
|
||||
import "math/big"
|
||||
|
||||
var (
|
||||
bigRadix = big.NewInt(58)
|
||||
bigZero = big.NewInt(0)
|
||||
)
|
||||
|
||||
// Decode decodes a modified base58 string to a byte slice.
|
||||
func Decode(b string) []byte {
|
||||
answer := big.NewInt(0)
|
||||
j := big.NewInt(1)
|
||||
|
||||
scratch := new(big.Int)
|
||||
for i := len(b) - 1; i >= 0; i-- {
|
||||
tmp := b58[b[i]]
|
||||
if tmp == 255 {
|
||||
return []byte("")
|
||||
}
|
||||
scratch.SetInt64(int64(tmp))
|
||||
scratch.Mul(j, scratch)
|
||||
answer.Add(answer, scratch)
|
||||
j.Mul(j, bigRadix)
|
||||
}
|
||||
|
||||
tmpval := answer.Bytes()
|
||||
|
||||
var numZeros int
|
||||
for numZeros = 0; numZeros < len(b); numZeros++ {
|
||||
if b[numZeros] != alphabetIdx0 {
|
||||
break
|
||||
}
|
||||
}
|
||||
flen := numZeros + len(tmpval)
|
||||
val := make([]byte, flen)
|
||||
copy(val[numZeros:], tmpval)
|
||||
|
||||
return val
|
||||
}
|
||||
|
||||
// Encode encodes a byte slice to a modified base58 string.
|
||||
func Encode(b []byte) string {
|
||||
x := new(big.Int)
|
||||
x.SetBytes(b)
|
||||
|
||||
answer := make([]byte, 0, len(b)*136/100)
|
||||
for x.Cmp(bigZero) > 0 {
|
||||
mod := new(big.Int)
|
||||
x.DivMod(x, bigRadix, mod)
|
||||
answer = append(answer, alphabet[mod.Int64()])
|
||||
}
|
||||
|
||||
// leading zero bytes
|
||||
for _, i := range b {
|
||||
if i != 0 {
|
||||
break
|
||||
}
|
||||
answer = append(answer, alphabetIdx0)
|
||||
}
|
||||
|
||||
// reverse
|
||||
alen := len(answer)
|
||||
for i := 0; i < alen/2; i++ {
|
||||
answer[i], answer[alen-1-i] = answer[alen-1-i], answer[i]
|
||||
}
|
||||
|
||||
return string(answer)
|
||||
}
|
|
@ -1,100 +0,0 @@
|
|||
// Copyright (c) 2013-2017 The btcsuite developers
|
||||
// Use of this source code is governed by an ISC
|
||||
// license that can be found in the LICENSE file.
|
||||
//
|
||||
// See: https://github.com/btcsuite/btcutil/blob/master/LICENSE
|
||||
|
||||
package base58_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/hex"
|
||||
"testing"
|
||||
|
||||
"github.com/pomerium/pomerium/pkg/encoding/base58"
|
||||
)
|
||||
|
||||
var stringTests = []struct {
|
||||
in string
|
||||
out string
|
||||
}{
|
||||
{"", ""},
|
||||
{" ", "Z"},
|
||||
{"-", "n"},
|
||||
{"0", "q"},
|
||||
{"1", "r"},
|
||||
{"-1", "4SU"},
|
||||
{"11", "4k8"},
|
||||
{"abc", "ZiCa"},
|
||||
{"1234598760", "3mJr7AoUXx2Wqd"},
|
||||
{"abcdefghijklmnopqrstuvwxyz", "3yxU3u1igY8WkgtjK92fbJQCd4BZiiT1v25f"},
|
||||
{"00000000000000000000000000000000000000000000000000000000000000", "3sN2THZeE9Eh9eYrwkvZqNstbHGvrxSAM7gXUXvyFQP8XvQLUqNCS27icwUeDT7ckHm4FUHM2mTVh1vbLmk7y"},
|
||||
}
|
||||
|
||||
var invalidStringTests = []struct {
|
||||
in string
|
||||
out string
|
||||
}{
|
||||
{"0", ""},
|
||||
{"O", ""},
|
||||
{"I", ""},
|
||||
{"l", ""},
|
||||
{"3mJr0", ""},
|
||||
{"O3yxU", ""},
|
||||
{"3sNI", ""},
|
||||
{"4kl8", ""},
|
||||
{"0OIl", ""},
|
||||
{"!@#$%^&*()-_=+~`", ""},
|
||||
}
|
||||
|
||||
var hexTests = []struct {
|
||||
in string
|
||||
out string
|
||||
}{
|
||||
{"61", "2g"},
|
||||
{"626262", "a3gV"},
|
||||
{"636363", "aPEr"},
|
||||
{"73696d706c792061206c6f6e6720737472696e67", "2cFupjhnEsSn59qHXstmK2ffpLv2"},
|
||||
{"00eb15231dfceb60925886b67d065299925915aeb172c06647", "1NS17iag9jJgTHD1VXjvLCEnZuQ3rJDE9L"},
|
||||
{"516b6fcd0f", "ABnLTmg"},
|
||||
{"bf4f89001e670274dd", "3SEo3LWLoPntC"},
|
||||
{"572e4794", "3EFU7m"},
|
||||
{"ecac89cad93923c02321", "EJDM8drfXA6uyA"},
|
||||
{"10c8511e", "Rt5zm"},
|
||||
{"00000000000000000000", "1111111111"},
|
||||
}
|
||||
|
||||
func TestBase58(t *testing.T) {
|
||||
// Encode tests
|
||||
for x, test := range stringTests {
|
||||
tmp := []byte(test.in)
|
||||
if res := base58.Encode(tmp); res != test.out {
|
||||
t.Errorf("Encode test #%d failed: got: %s want: %s",
|
||||
x, res, test.out)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
// Decode tests
|
||||
for x, test := range hexTests {
|
||||
b, err := hex.DecodeString(test.in)
|
||||
if err != nil {
|
||||
t.Errorf("hex.DecodeString failed failed #%d: got: %s", x, test.in)
|
||||
continue
|
||||
}
|
||||
if res := base58.Decode(test.out); !bytes.Equal(res, b) {
|
||||
t.Errorf("Decode test #%d failed: got: %q want: %q",
|
||||
x, res, test.in)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
// Decode with invalid input
|
||||
for x, test := range invalidStringTests {
|
||||
if res := base58.Decode(test.in); string(res) != test.out {
|
||||
t.Errorf("Decode invalidString test #%d failed: got: %q want: %q",
|
||||
x, res, test.out)
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1,22 +0,0 @@
|
|||
// Copyright (c) 2014 The btcsuite developers
|
||||
// Use of this source code is governed by an ISC
|
||||
// license that can be found in the LICENSE file.
|
||||
//
|
||||
// See: https://github.com/btcsuite/btcutil/blob/master/LICENSE
|
||||
|
||||
/*
|
||||
Package base58 provides an API for working with modified base58 and Base58Check
|
||||
encodings.
|
||||
|
||||
# Modified Base58 Encoding
|
||||
|
||||
Standard base58 encoding is similar to standard base64 encoding except, as the
|
||||
name implies, it uses a 58 character alphabet which results in an alphanumeric
|
||||
string and allows some characters which are problematic for humans to be
|
||||
excluded. Due to this, there can be various base58 alphabets.
|
||||
|
||||
The modified base58 alphabet used by Bitcoin, and hence this package, omits the
|
||||
0, O, I, and l characters that look the same in many fonts and are therefore
|
||||
hard to humans to distinguish.
|
||||
*/
|
||||
package base58
|
|
@ -9,7 +9,7 @@ import (
|
|||
"google.golang.org/grpc/status"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
|
||||
"github.com/pomerium/pomerium/pkg/encoding/base58"
|
||||
"github.com/akamensky/base58"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||
"github.com/pomerium/pomerium/pkg/protoutil"
|
||||
)
|
||||
|
|
|
@ -6,7 +6,7 @@ import (
|
|||
|
||||
"google.golang.org/protobuf/proto"
|
||||
|
||||
"github.com/pomerium/pomerium/pkg/encoding/base58"
|
||||
"github.com/akamensky/base58"
|
||||
)
|
||||
|
||||
// Clone clones the Provider.
|
||||
|
|
|
@ -6,7 +6,7 @@ import (
|
|||
|
||||
"github.com/google/uuid"
|
||||
|
||||
"github.com/pomerium/pomerium/pkg/encoding/base58"
|
||||
"github.com/akamensky/base58"
|
||||
)
|
||||
|
||||
const headerName = "x-request-id"
|
||||
|
|
|
@ -6,7 +6,7 @@ import (
|
|||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
|
||||
"github.com/pomerium/pomerium/pkg/encoding/base58"
|
||||
"github.com/akamensky/base58"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/device"
|
||||
"github.com/pomerium/webauthn"
|
||||
|
|
|
@ -34,7 +34,7 @@ func (p *Proxy) getSession(ctx context.Context, sessionID string) (s *session.Se
|
|||
func (p *Proxy) getSessionState(r *http.Request) (sessions.State, error) {
|
||||
state := p.state.Load()
|
||||
|
||||
rawJWT, err := state.sessionStore.LoadSession(r)
|
||||
rawJWT, err := state.sessionStore.LoadSession(r.Context(), r)
|
||||
if err != nil {
|
||||
return sessions.State{}, err
|
||||
}
|
||||
|
|
|
@ -113,7 +113,7 @@ func (p *Proxy) ProgrammaticLogin(w http.ResponseWriter, r *http.Request) error
|
|||
return httputil.NewError(http.StatusBadRequest, errors.New("invalid redirect uri"))
|
||||
}
|
||||
|
||||
idp, err := options.GetIdentityProviderForRequestURL(urlutil.GetAbsoluteURL(r).String())
|
||||
idp, err := p.policyCache.Load().GetIdentityProviderForRequestURL(r.Context(), options, urlutil.GetAbsoluteURL(r).String())
|
||||
if err != nil {
|
||||
return httputil.NewError(http.StatusInternalServerError, err)
|
||||
}
|
||||
|
|
|
@ -56,6 +56,7 @@ type Proxy struct {
|
|||
currentOptions *atomicutil.Value[*config.Options]
|
||||
currentRouter *atomicutil.Value[*mux.Router]
|
||||
webauthn *webauthn.Handler
|
||||
policyCache *atomicutil.Value[*config.PolicyCache]
|
||||
}
|
||||
|
||||
// New takes a Proxy service from options and a validation function.
|
||||
|
@ -66,10 +67,15 @@ func New(cfg *config.Config) (*Proxy, error) {
|
|||
return nil, err
|
||||
}
|
||||
|
||||
cache, err := config.NewPolicyCache(cfg.Options)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
p := &Proxy{
|
||||
state: atomicutil.NewValue(state),
|
||||
currentOptions: config.NewAtomicOptions(),
|
||||
currentRouter: atomicutil.NewValue(httputil.NewRouter()),
|
||||
policyCache: atomicutil.NewValue(cache),
|
||||
}
|
||||
p.webauthn = webauthn.New(p.getWebauthnState)
|
||||
|
||||
|
@ -91,6 +97,9 @@ func (p *Proxy) OnConfigChange(_ context.Context, cfg *config.Config) {
|
|||
return
|
||||
}
|
||||
|
||||
if cache, err := config.NewPolicyCache(cfg.Options); err == nil {
|
||||
p.policyCache.Store(cache)
|
||||
}
|
||||
p.currentOptions.Store(cfg.Options)
|
||||
if err := p.setHandlers(cfg.Options); err != nil {
|
||||
log.Error(context.TODO()).Err(err).Msg("proxy: failed to update proxy handlers from configuration settings")
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue