diff --git a/authorize/evaluator/evaluator.go b/authorize/evaluator/evaluator.go index 0ca5f9a7e..5a047e8c4 100644 --- a/authorize/evaluator/evaluator.go +++ b/authorize/evaluator/evaluator.go @@ -7,6 +7,7 @@ import ( "fmt" "net/http" "net/url" + "time" "github.com/go-jose/go-jose/v3" "github.com/open-policy-agent/opa/rego" @@ -14,6 +15,7 @@ import ( "github.com/pomerium/pomerium/authorize/internal/store" "github.com/pomerium/pomerium/config" + "github.com/pomerium/pomerium/internal/errgrouputil" "github.com/pomerium/pomerium/internal/httputil" "github.com/pomerium/pomerium/internal/log" "github.com/pomerium/pomerium/internal/telemetry/trace" @@ -139,12 +141,20 @@ func New( return e, nil } +type routeEvaluator struct { + id uint64 + evaluator *PolicyEvaluator +} + func getOrCreatePolicyEvaluators( ctx context.Context, cfg *evaluatorConfig, store *store.Store, cachedPolicyEvaluators map[uint64]*PolicyEvaluator, ) (map[uint64]*PolicyEvaluator, error) { - var newCount, reusedCount int + now := time.Now() + + var reusedCount int m := make(map[uint64]*PolicyEvaluator) + var builders []errgrouputil.BuilderFunc[routeEvaluator] for i := range cfg.Policies { configPolicy := cfg.Policies[i] id, err := configPolicy.RouteID() @@ -157,15 +167,35 @@ func getOrCreatePolicyEvaluators( reusedCount++ continue } - policyEvaluator, err := - NewPolicyEvaluator(ctx, store, &configPolicy, cfg.AddDefaultClientCertificateRule) - if err != nil { - return nil, err - } - m[id] = policyEvaluator - newCount++ + builders = append(builders, func(ctx context.Context) (*routeEvaluator, error) { + evaluator, err := NewPolicyEvaluator(ctx, store, &configPolicy, cfg.AddDefaultClientCertificateRule) + if err != nil { + return nil, fmt.Errorf("authorize: error building evaluator for route id=%s: %w", configPolicy.ID, err) + } + return &routeEvaluator{ + id: id, + evaluator: evaluator, + }, nil + }) } - log.Info(ctx).Msgf("updated policy evaluators: %d created, %d reused", newCount, reusedCount) + + evals, errs := errgrouputil.Build(ctx, builders...) + if len(errs) > 0 { + for _, err := range errs { + log.Error(ctx).Msg(err.Error()) + } + return nil, fmt.Errorf("authorize: error building policy evaluators") + } + + for _, p := range evals { + m[p.id] = p.evaluator + } + + log.Info(ctx). + Dur("duration", time.Since(now)). + Int("reused-policies", reusedCount). + Int("created-policies", len(cfg.Policies)-reusedCount). + Msg("updated policy evaluators") return m, nil } diff --git a/config/envoyconfig/lua_test.go b/config/envoyconfig/lua_test.go index b3badf32e..8b3f8d7d2 100644 --- a/config/envoyconfig/lua_test.go +++ b/config/envoyconfig/lua_test.go @@ -24,9 +24,11 @@ func TestLuaCleanUpstream(t *testing.T) { "context-type": "text/plain", "authorization": "Pomerium JWT", "x-pomerium-authorization": "JWT", + "cookie": "cookieA=aaa_pomerium=123; cookieb=bbb; _pomerium=ey;_pomerium_test1=stillhere ; _pomerium_test2=stillhere", } metadata := map[string]interface{}{ "remove_pomerium_authorization": true, + "remove_pomerium_cookie": "_pomerium", } dynamicMetadata := map[string]map[string]interface{}{} handle := newLuaResponseHandle(L, headers, metadata, dynamicMetadata) @@ -40,6 +42,7 @@ func TestLuaCleanUpstream(t *testing.T) { assert.Equal(t, map[string]string{ "context-type": "text/plain", + "cookie": "cookieA=aaa_pomerium=123; cookieb=bbb; _pomerium_test1=stillhere ; _pomerium_test2=stillhere", }, headers) } diff --git a/config/envoyconfig/luascripts/clean-upstream.lua b/config/envoyconfig/luascripts/clean-upstream.lua index 64bd60315..f12079f8d 100644 --- a/config/envoyconfig/luascripts/clean-upstream.lua +++ b/config/envoyconfig/luascripts/clean-upstream.lua @@ -1,15 +1,23 @@ -function remove_pomerium_cookie(cookie_name, cookie) - -- lua doesn't support optional capture groups - -- so we replace twice to handle pomerium=xyz at the end of the string - cookie = cookie:gsub(cookie_name .. "=[^;]+; ", "") - cookie = cookie:gsub(cookie_name .. "=[^;]+", "") - return cookie -end - function has_prefix(str, prefix) return str ~= nil and str:sub(1, #prefix) == prefix end +function remove_pomerium_cookie(cookie_name, cookie) + local result = "" + for c in cookie:gmatch("([^;]+)") do + c = c:gsub("^ +","") + local name = c:match("^([^=]+)") + if name ~= cookie_name then + if string.len(result) > 0 then + result = result .. "; " .. c + else + result = result .. c + end + end + end + return result +end + function envoy_on_request(request_handle) local headers = request_handle:headers() local metadata = request_handle:metadata() @@ -18,7 +26,7 @@ function envoy_on_request(request_handle) if remove_cookie_name then local cookie = headers:get("cookie") if cookie ~= nil then - newcookie = remove_pomerium_cookie(remove_cookie_name, cookie) + local newcookie = remove_pomerium_cookie(remove_cookie_name, cookie) headers:replace("cookie", newcookie) end end diff --git a/config/envoyconfig/testdata/main_http_connection_manager_filter.json b/config/envoyconfig/testdata/main_http_connection_manager_filter.json index cc9cc0bb2..82a78649e 100644 --- a/config/envoyconfig/testdata/main_http_connection_manager_filter.json +++ b/config/envoyconfig/testdata/main_http_connection_manager_filter.json @@ -75,7 +75,7 @@ "typedConfig": { "@type": "type.googleapis.com/envoy.extensions.filters.http.lua.v3.Lua", "defaultSourceCode": { - "inlineString": "function remove_pomerium_cookie(cookie_name, cookie)\n -- lua doesn't support optional capture groups\n -- so we replace twice to handle pomerium=xyz at the end of the string\n cookie = cookie:gsub(cookie_name .. \"=[^;]+; \", \"\")\n cookie = cookie:gsub(cookie_name .. \"=[^;]+\", \"\")\n return cookie\nend\n\nfunction has_prefix(str, prefix)\n return str ~= nil and str:sub(1, #prefix) == prefix\nend\n\nfunction envoy_on_request(request_handle)\n local headers = request_handle:headers()\n local metadata = request_handle:metadata()\n\n local remove_cookie_name = metadata:get(\"remove_pomerium_cookie\")\n if remove_cookie_name then\n local cookie = headers:get(\"cookie\")\n if cookie ~= nil then\n newcookie = remove_pomerium_cookie(remove_cookie_name, cookie)\n headers:replace(\"cookie\", newcookie)\n end\n end\n\n local remove_authorization = metadata:get(\"remove_pomerium_authorization\")\n if remove_authorization then\n local authorization = headers:get(\"authorization\")\n local authorization_prefix = \"Pomerium \"\n if has_prefix(authorization, authorization_prefix) then\n headers:remove(\"authorization\")\n end\n\n headers:remove('x-pomerium-authorization')\n end\nend\n\nfunction envoy_on_response(response_handle) end\n" + "inlineString": "function has_prefix(str, prefix)\n return str ~= nil and str:sub(1, #prefix) == prefix\nend\n\nfunction remove_pomerium_cookie(cookie_name, cookie)\n local result = \"\"\n for c in cookie:gmatch(\"([^;]+)\") do\n c = c:gsub(\"^ +\",\"\")\n local name = c:match(\"^([^=]+)\")\n if name ~= cookie_name then\n if string.len(result) \u003e 0 then\n result = result .. \"; \" .. c\n else\n result = result .. c\n end\n end\n end\n return result\nend\n\nfunction envoy_on_request(request_handle)\n local headers = request_handle:headers()\n local metadata = request_handle:metadata()\n\n local remove_cookie_name = metadata:get(\"remove_pomerium_cookie\")\n if remove_cookie_name then\n local cookie = headers:get(\"cookie\")\n if cookie ~= nil then\n local newcookie = remove_pomerium_cookie(remove_cookie_name, cookie)\n headers:replace(\"cookie\", newcookie)\n end\n end\n\n local remove_authorization = metadata:get(\"remove_pomerium_authorization\")\n if remove_authorization then\n local authorization = headers:get(\"authorization\")\n local authorization_prefix = \"Pomerium \"\n if has_prefix(authorization, authorization_prefix) then\n headers:remove(\"authorization\")\n end\n\n headers:remove('x-pomerium-authorization')\n end\nend\n\nfunction envoy_on_response(response_handle) end\n" } } }, diff --git a/internal/databroker/config_source.go b/internal/databroker/config_source.go index 7456328ae..dc2a52244 100644 --- a/internal/databroker/config_source.go +++ b/internal/databroker/config_source.go @@ -3,7 +3,6 @@ package databroker import ( "context" "fmt" - "runtime" "sort" "sync" "time" @@ -12,6 +11,7 @@ import ( "golang.org/x/sync/errgroup" "github.com/pomerium/pomerium/config" + "github.com/pomerium/pomerium/internal/errgrouputil" "github.com/pomerium/pomerium/internal/hashutil" "github.com/pomerium/pomerium/internal/log" "github.com/pomerium/pomerium/internal/telemetry/metrics" @@ -115,7 +115,6 @@ func (src *ConfigSource) rebuild(ctx context.Context, firstTime firstTime) { func (src *ConfigSource) buildNewConfigLocked(ctx context.Context, cfg *config.Config) error { eg, ctx := errgroup.WithContext(ctx) - eg.SetLimit(runtime.NumCPU()/2 + 1) eg.Go(func() error { src.applySettingsLocked(ctx, cfg) err := cfg.Options.Validate() @@ -125,30 +124,32 @@ func (src *ConfigSource) buildNewConfigLocked(ctx context.Context, cfg *config.C return nil }) - var policies []*config.Policy - var builders []func() error - buildPolicy := func(i int, routepb *configpb.Route) func() error { - return func() error { - policy, err := src.buildPolicyFromProto(routepb) - if err != nil { - log.Ctx(ctx).Err(err).Msg("databroker: error building policy from protobuf") - return nil - } - policies[i] = policy - return nil - } - } - + var policyBuilders []errgrouputil.BuilderFunc[config.Policy] for _, cfgpb := range src.dbConfigs { for _, routepb := range cfgpb.GetRoutes() { - builders = append(builders, buildPolicy(len(builders), routepb)) + routepb := routepb + policyBuilders = append(policyBuilders, func(ctx context.Context) (*config.Policy, error) { + p, err := src.buildPolicyFromProto(ctx, routepb) + if err != nil { + return nil, fmt.Errorf("error building route id=%s: %w", routepb.GetId(), err) + } + return p, nil + }) } } - policies = make([]*config.Policy, len(builders)) - for _, builder := range builders { - eg.Go(builder) - } + var policies []*config.Policy + eg.Go(func() error { + var errs []error + policies, errs = errgrouputil.Build(ctx, policyBuilders...) + if len(errs) > 0 { + for _, err := range errs { + log.Error(ctx).Msg(err.Error()) + } + return fmt.Errorf("error building policies") + } + return nil + }) err := eg.Wait() if err != nil { @@ -177,7 +178,7 @@ func (src *ConfigSource) applySettingsLocked(ctx context.Context, cfg *config.Co } } -func (src *ConfigSource) buildPolicyFromProto(routepb *configpb.Route) (*config.Policy, error) { +func (src *ConfigSource) buildPolicyFromProto(_ context.Context, routepb *configpb.Route) (*config.Policy, error) { policy, err := config.NewPolicyFromProto(routepb) if err != nil { return nil, fmt.Errorf("error building policy from protobuf: %w", err) diff --git a/internal/errgrouputil/builder.go b/internal/errgrouputil/builder.go new file mode 100644 index 000000000..87c2489a4 --- /dev/null +++ b/internal/errgrouputil/builder.go @@ -0,0 +1,54 @@ +// Package errgrouputil contains methods for working with errgroup code. +package errgrouputil + +import ( + "context" + "runtime" + + "golang.org/x/sync/errgroup" + + "github.com/pomerium/pomerium/pkg/slices" +) + +// BuilderFunc is a function that builds a value of type T +type BuilderFunc[T any] func(ctx context.Context) (*T, error) + +// Build builds a slice of values of type T using the provided builders concurrently +// and returns the results and any errors. +func Build[T any]( + ctx context.Context, + builders ...BuilderFunc[T], +) ([]*T, []error) { + eg, ctx := errgroup.WithContext(ctx) + eg.SetLimit(runtime.NumCPU()/2 + 1) + + results := make([]*T, len(builders)) + errors := make([]error, len(builders)) + + fn := func(i int) func() error { + return func() error { + result, err := builders[i](ctx) + if err != nil { + errors[i] = err + return nil + } + results[i] = result + return nil + } + } + + for i := range builders { + eg.Go(fn(i)) + } + + err := eg.Wait() + if err != nil { + return nil, []error{err} // not happening + } + + return slices.Filter(results, func(t *T) bool { + return t != nil + }), slices.Filter(errors, func(err error) bool { + return err != nil + }) +}