mirror of
https://github.com/pomerium/pomerium.git
synced 2025-07-04 18:38:12 +02:00
Merge remote-tracking branch 'origin/main' into cdoxsey/fix-flaky-test
This commit is contained in:
commit
21fe0d66b4
6 changed files with 137 additions and 41 deletions
|
@ -7,6 +7,7 @@ import (
|
|||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"time"
|
||||
|
||||
"github.com/go-jose/go-jose/v3"
|
||||
"github.com/open-policy-agent/opa/rego"
|
||||
|
@ -14,6 +15,7 @@ import (
|
|||
|
||||
"github.com/pomerium/pomerium/authorize/internal/store"
|
||||
"github.com/pomerium/pomerium/config"
|
||||
"github.com/pomerium/pomerium/internal/errgrouputil"
|
||||
"github.com/pomerium/pomerium/internal/httputil"
|
||||
"github.com/pomerium/pomerium/internal/log"
|
||||
"github.com/pomerium/pomerium/internal/telemetry/trace"
|
||||
|
@ -139,12 +141,20 @@ func New(
|
|||
return e, nil
|
||||
}
|
||||
|
||||
type routeEvaluator struct {
|
||||
id uint64
|
||||
evaluator *PolicyEvaluator
|
||||
}
|
||||
|
||||
func getOrCreatePolicyEvaluators(
|
||||
ctx context.Context, cfg *evaluatorConfig, store *store.Store,
|
||||
cachedPolicyEvaluators map[uint64]*PolicyEvaluator,
|
||||
) (map[uint64]*PolicyEvaluator, error) {
|
||||
var newCount, reusedCount int
|
||||
now := time.Now()
|
||||
|
||||
var reusedCount int
|
||||
m := make(map[uint64]*PolicyEvaluator)
|
||||
var builders []errgrouputil.BuilderFunc[routeEvaluator]
|
||||
for i := range cfg.Policies {
|
||||
configPolicy := cfg.Policies[i]
|
||||
id, err := configPolicy.RouteID()
|
||||
|
@ -157,15 +167,35 @@ func getOrCreatePolicyEvaluators(
|
|||
reusedCount++
|
||||
continue
|
||||
}
|
||||
policyEvaluator, err :=
|
||||
NewPolicyEvaluator(ctx, store, &configPolicy, cfg.AddDefaultClientCertificateRule)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
m[id] = policyEvaluator
|
||||
newCount++
|
||||
builders = append(builders, func(ctx context.Context) (*routeEvaluator, error) {
|
||||
evaluator, err := NewPolicyEvaluator(ctx, store, &configPolicy, cfg.AddDefaultClientCertificateRule)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("authorize: error building evaluator for route id=%s: %w", configPolicy.ID, err)
|
||||
}
|
||||
return &routeEvaluator{
|
||||
id: id,
|
||||
evaluator: evaluator,
|
||||
}, nil
|
||||
})
|
||||
}
|
||||
log.Info(ctx).Msgf("updated policy evaluators: %d created, %d reused", newCount, reusedCount)
|
||||
|
||||
evals, errs := errgrouputil.Build(ctx, builders...)
|
||||
if len(errs) > 0 {
|
||||
for _, err := range errs {
|
||||
log.Error(ctx).Msg(err.Error())
|
||||
}
|
||||
return nil, fmt.Errorf("authorize: error building policy evaluators")
|
||||
}
|
||||
|
||||
for _, p := range evals {
|
||||
m[p.id] = p.evaluator
|
||||
}
|
||||
|
||||
log.Info(ctx).
|
||||
Dur("duration", time.Since(now)).
|
||||
Int("reused-policies", reusedCount).
|
||||
Int("created-policies", len(cfg.Policies)-reusedCount).
|
||||
Msg("updated policy evaluators")
|
||||
return m, nil
|
||||
}
|
||||
|
||||
|
|
|
@ -24,9 +24,11 @@ func TestLuaCleanUpstream(t *testing.T) {
|
|||
"context-type": "text/plain",
|
||||
"authorization": "Pomerium JWT",
|
||||
"x-pomerium-authorization": "JWT",
|
||||
"cookie": "cookieA=aaa_pomerium=123; cookieb=bbb; _pomerium=ey;_pomerium_test1=stillhere ; _pomerium_test2=stillhere",
|
||||
}
|
||||
metadata := map[string]interface{}{
|
||||
"remove_pomerium_authorization": true,
|
||||
"remove_pomerium_cookie": "_pomerium",
|
||||
}
|
||||
dynamicMetadata := map[string]map[string]interface{}{}
|
||||
handle := newLuaResponseHandle(L, headers, metadata, dynamicMetadata)
|
||||
|
@ -40,6 +42,7 @@ func TestLuaCleanUpstream(t *testing.T) {
|
|||
|
||||
assert.Equal(t, map[string]string{
|
||||
"context-type": "text/plain",
|
||||
"cookie": "cookieA=aaa_pomerium=123; cookieb=bbb; _pomerium_test1=stillhere ; _pomerium_test2=stillhere",
|
||||
}, headers)
|
||||
}
|
||||
|
||||
|
|
|
@ -1,15 +1,23 @@
|
|||
function remove_pomerium_cookie(cookie_name, cookie)
|
||||
-- lua doesn't support optional capture groups
|
||||
-- so we replace twice to handle pomerium=xyz at the end of the string
|
||||
cookie = cookie:gsub(cookie_name .. "=[^;]+; ", "")
|
||||
cookie = cookie:gsub(cookie_name .. "=[^;]+", "")
|
||||
return cookie
|
||||
end
|
||||
|
||||
function has_prefix(str, prefix)
|
||||
return str ~= nil and str:sub(1, #prefix) == prefix
|
||||
end
|
||||
|
||||
function remove_pomerium_cookie(cookie_name, cookie)
|
||||
local result = ""
|
||||
for c in cookie:gmatch("([^;]+)") do
|
||||
c = c:gsub("^ +","")
|
||||
local name = c:match("^([^=]+)")
|
||||
if name ~= cookie_name then
|
||||
if string.len(result) > 0 then
|
||||
result = result .. "; " .. c
|
||||
else
|
||||
result = result .. c
|
||||
end
|
||||
end
|
||||
end
|
||||
return result
|
||||
end
|
||||
|
||||
function envoy_on_request(request_handle)
|
||||
local headers = request_handle:headers()
|
||||
local metadata = request_handle:metadata()
|
||||
|
@ -18,7 +26,7 @@ function envoy_on_request(request_handle)
|
|||
if remove_cookie_name then
|
||||
local cookie = headers:get("cookie")
|
||||
if cookie ~= nil then
|
||||
newcookie = remove_pomerium_cookie(remove_cookie_name, cookie)
|
||||
local newcookie = remove_pomerium_cookie(remove_cookie_name, cookie)
|
||||
headers:replace("cookie", newcookie)
|
||||
end
|
||||
end
|
||||
|
|
|
@ -75,7 +75,7 @@
|
|||
"typedConfig": {
|
||||
"@type": "type.googleapis.com/envoy.extensions.filters.http.lua.v3.Lua",
|
||||
"defaultSourceCode": {
|
||||
"inlineString": "function remove_pomerium_cookie(cookie_name, cookie)\n -- lua doesn't support optional capture groups\n -- so we replace twice to handle pomerium=xyz at the end of the string\n cookie = cookie:gsub(cookie_name .. \"=[^;]+; \", \"\")\n cookie = cookie:gsub(cookie_name .. \"=[^;]+\", \"\")\n return cookie\nend\n\nfunction has_prefix(str, prefix)\n return str ~= nil and str:sub(1, #prefix) == prefix\nend\n\nfunction envoy_on_request(request_handle)\n local headers = request_handle:headers()\n local metadata = request_handle:metadata()\n\n local remove_cookie_name = metadata:get(\"remove_pomerium_cookie\")\n if remove_cookie_name then\n local cookie = headers:get(\"cookie\")\n if cookie ~= nil then\n newcookie = remove_pomerium_cookie(remove_cookie_name, cookie)\n headers:replace(\"cookie\", newcookie)\n end\n end\n\n local remove_authorization = metadata:get(\"remove_pomerium_authorization\")\n if remove_authorization then\n local authorization = headers:get(\"authorization\")\n local authorization_prefix = \"Pomerium \"\n if has_prefix(authorization, authorization_prefix) then\n headers:remove(\"authorization\")\n end\n\n headers:remove('x-pomerium-authorization')\n end\nend\n\nfunction envoy_on_response(response_handle) end\n"
|
||||
"inlineString": "function has_prefix(str, prefix)\n return str ~= nil and str:sub(1, #prefix) == prefix\nend\n\nfunction remove_pomerium_cookie(cookie_name, cookie)\n local result = \"\"\n for c in cookie:gmatch(\"([^;]+)\") do\n c = c:gsub(\"^ +\",\"\")\n local name = c:match(\"^([^=]+)\")\n if name ~= cookie_name then\n if string.len(result) \u003e 0 then\n result = result .. \"; \" .. c\n else\n result = result .. c\n end\n end\n end\n return result\nend\n\nfunction envoy_on_request(request_handle)\n local headers = request_handle:headers()\n local metadata = request_handle:metadata()\n\n local remove_cookie_name = metadata:get(\"remove_pomerium_cookie\")\n if remove_cookie_name then\n local cookie = headers:get(\"cookie\")\n if cookie ~= nil then\n local newcookie = remove_pomerium_cookie(remove_cookie_name, cookie)\n headers:replace(\"cookie\", newcookie)\n end\n end\n\n local remove_authorization = metadata:get(\"remove_pomerium_authorization\")\n if remove_authorization then\n local authorization = headers:get(\"authorization\")\n local authorization_prefix = \"Pomerium \"\n if has_prefix(authorization, authorization_prefix) then\n headers:remove(\"authorization\")\n end\n\n headers:remove('x-pomerium-authorization')\n end\nend\n\nfunction envoy_on_response(response_handle) end\n"
|
||||
}
|
||||
}
|
||||
},
|
||||
|
|
|
@ -3,7 +3,6 @@ package databroker
|
|||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"runtime"
|
||||
"sort"
|
||||
"sync"
|
||||
"time"
|
||||
|
@ -12,6 +11,7 @@ import (
|
|||
"golang.org/x/sync/errgroup"
|
||||
|
||||
"github.com/pomerium/pomerium/config"
|
||||
"github.com/pomerium/pomerium/internal/errgrouputil"
|
||||
"github.com/pomerium/pomerium/internal/hashutil"
|
||||
"github.com/pomerium/pomerium/internal/log"
|
||||
"github.com/pomerium/pomerium/internal/telemetry/metrics"
|
||||
|
@ -115,7 +115,6 @@ func (src *ConfigSource) rebuild(ctx context.Context, firstTime firstTime) {
|
|||
|
||||
func (src *ConfigSource) buildNewConfigLocked(ctx context.Context, cfg *config.Config) error {
|
||||
eg, ctx := errgroup.WithContext(ctx)
|
||||
eg.SetLimit(runtime.NumCPU()/2 + 1)
|
||||
eg.Go(func() error {
|
||||
src.applySettingsLocked(ctx, cfg)
|
||||
err := cfg.Options.Validate()
|
||||
|
@ -125,30 +124,32 @@ func (src *ConfigSource) buildNewConfigLocked(ctx context.Context, cfg *config.C
|
|||
return nil
|
||||
})
|
||||
|
||||
var policies []*config.Policy
|
||||
var builders []func() error
|
||||
buildPolicy := func(i int, routepb *configpb.Route) func() error {
|
||||
return func() error {
|
||||
policy, err := src.buildPolicyFromProto(routepb)
|
||||
if err != nil {
|
||||
log.Ctx(ctx).Err(err).Msg("databroker: error building policy from protobuf")
|
||||
return nil
|
||||
}
|
||||
policies[i] = policy
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
var policyBuilders []errgrouputil.BuilderFunc[config.Policy]
|
||||
for _, cfgpb := range src.dbConfigs {
|
||||
for _, routepb := range cfgpb.GetRoutes() {
|
||||
builders = append(builders, buildPolicy(len(builders), routepb))
|
||||
routepb := routepb
|
||||
policyBuilders = append(policyBuilders, func(ctx context.Context) (*config.Policy, error) {
|
||||
p, err := src.buildPolicyFromProto(ctx, routepb)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error building route id=%s: %w", routepb.GetId(), err)
|
||||
}
|
||||
return p, nil
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
policies = make([]*config.Policy, len(builders))
|
||||
for _, builder := range builders {
|
||||
eg.Go(builder)
|
||||
}
|
||||
var policies []*config.Policy
|
||||
eg.Go(func() error {
|
||||
var errs []error
|
||||
policies, errs = errgrouputil.Build(ctx, policyBuilders...)
|
||||
if len(errs) > 0 {
|
||||
for _, err := range errs {
|
||||
log.Error(ctx).Msg(err.Error())
|
||||
}
|
||||
return fmt.Errorf("error building policies")
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
err := eg.Wait()
|
||||
if err != nil {
|
||||
|
@ -177,7 +178,7 @@ func (src *ConfigSource) applySettingsLocked(ctx context.Context, cfg *config.Co
|
|||
}
|
||||
}
|
||||
|
||||
func (src *ConfigSource) buildPolicyFromProto(routepb *configpb.Route) (*config.Policy, error) {
|
||||
func (src *ConfigSource) buildPolicyFromProto(_ context.Context, routepb *configpb.Route) (*config.Policy, error) {
|
||||
policy, err := config.NewPolicyFromProto(routepb)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error building policy from protobuf: %w", err)
|
||||
|
|
54
internal/errgrouputil/builder.go
Normal file
54
internal/errgrouputil/builder.go
Normal file
|
@ -0,0 +1,54 @@
|
|||
// Package errgrouputil contains methods for working with errgroup code.
|
||||
package errgrouputil
|
||||
|
||||
import (
|
||||
"context"
|
||||
"runtime"
|
||||
|
||||
"golang.org/x/sync/errgroup"
|
||||
|
||||
"github.com/pomerium/pomerium/pkg/slices"
|
||||
)
|
||||
|
||||
// BuilderFunc is a function that builds a value of type T
|
||||
type BuilderFunc[T any] func(ctx context.Context) (*T, error)
|
||||
|
||||
// Build builds a slice of values of type T using the provided builders concurrently
|
||||
// and returns the results and any errors.
|
||||
func Build[T any](
|
||||
ctx context.Context,
|
||||
builders ...BuilderFunc[T],
|
||||
) ([]*T, []error) {
|
||||
eg, ctx := errgroup.WithContext(ctx)
|
||||
eg.SetLimit(runtime.NumCPU()/2 + 1)
|
||||
|
||||
results := make([]*T, len(builders))
|
||||
errors := make([]error, len(builders))
|
||||
|
||||
fn := func(i int) func() error {
|
||||
return func() error {
|
||||
result, err := builders[i](ctx)
|
||||
if err != nil {
|
||||
errors[i] = err
|
||||
return nil
|
||||
}
|
||||
results[i] = result
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
for i := range builders {
|
||||
eg.Go(fn(i))
|
||||
}
|
||||
|
||||
err := eg.Wait()
|
||||
if err != nil {
|
||||
return nil, []error{err} // not happening
|
||||
}
|
||||
|
||||
return slices.Filter(results, func(t *T) bool {
|
||||
return t != nil
|
||||
}), slices.Filter(errors, func(err error) bool {
|
||||
return err != nil
|
||||
})
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue