mirror of
https://github.com/pomerium/pomerium.git
synced 2025-05-18 11:37:08 +02:00
core/go: use hashicorp/go-set (#5278)
This commit is contained in:
parent
5679589014
commit
410354bc00
12 changed files with 66 additions and 176 deletions
|
@ -63,10 +63,10 @@ func (tracker *AccessTracker) Run(ctx context.Context) {
|
||||||
sessionAccesses := sets.NewSizeLimited[string](tracker.maxSize)
|
sessionAccesses := sets.NewSizeLimited[string](tracker.maxSize)
|
||||||
serviceAccountAccesses := sets.NewSizeLimited[string](tracker.maxSize)
|
serviceAccountAccesses := sets.NewSizeLimited[string](tracker.maxSize)
|
||||||
runTrackSessionAccess := func(sessionID string) {
|
runTrackSessionAccess := func(sessionID string) {
|
||||||
sessionAccesses.Add(sessionID)
|
sessionAccesses.Insert(sessionID)
|
||||||
}
|
}
|
||||||
runTrackServiceAccountAccess := func(serviceAccountID string) {
|
runTrackServiceAccountAccess := func(serviceAccountID string) {
|
||||||
serviceAccountAccesses.Add(serviceAccountID)
|
serviceAccountAccesses.Insert(serviceAccountID)
|
||||||
}
|
}
|
||||||
runSubmit := func() {
|
runSubmit := func() {
|
||||||
if dropped := atomic.SwapInt64(&tracker.droppedAccesses, 0); dropped > 0 {
|
if dropped := atomic.SwapInt64(&tracker.droppedAccesses, 0); dropped > 0 {
|
||||||
|
@ -77,24 +77,20 @@ func (tracker *AccessTracker) Run(ctx context.Context) {
|
||||||
|
|
||||||
client := tracker.provider.GetDataBrokerServiceClient()
|
client := tracker.provider.GetDataBrokerServiceClient()
|
||||||
|
|
||||||
var err error
|
for sessionID := range sessionAccesses.Items() {
|
||||||
|
err := tracker.updateSession(ctx, client, sessionID)
|
||||||
sessionAccesses.ForEach(func(sessionID string) bool {
|
if err != nil {
|
||||||
err = tracker.updateSession(ctx, client, sessionID)
|
log.Ctx(ctx).Error().Err(err).Msg("authorize: error updating session last access timestamp")
|
||||||
return err == nil
|
return
|
||||||
})
|
}
|
||||||
if err != nil {
|
|
||||||
log.Ctx(ctx).Error().Err(err).Msg("authorize: error updating session last access timestamp")
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
serviceAccountAccesses.ForEach(func(serviceAccountID string) bool {
|
for serviceAccountID := range serviceAccountAccesses.Items() {
|
||||||
err = tracker.updateServiceAccount(ctx, client, serviceAccountID)
|
err := tracker.updateServiceAccount(ctx, client, serviceAccountID)
|
||||||
return err == nil
|
if err != nil {
|
||||||
})
|
log.Ctx(ctx).Error().Err(err).Msg("authorize: error updating service account last access timestamp")
|
||||||
if err != nil {
|
return
|
||||||
log.Ctx(ctx).Error().Err(err).Msg("authorize: error updating service account last access timestamp")
|
}
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
sessionAccesses = sets.NewSizeLimited[string](tracker.maxSize)
|
sessionAccesses = sets.NewSizeLimited[string](tracker.maxSize)
|
||||||
|
|
|
@ -10,6 +10,7 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/go-jose/go-jose/v3"
|
"github.com/go-jose/go-jose/v3"
|
||||||
|
"github.com/hashicorp/go-set/v3"
|
||||||
"github.com/open-policy-agent/opa/rego"
|
"github.com/open-policy-agent/opa/rego"
|
||||||
"golang.org/x/sync/errgroup"
|
"golang.org/x/sync/errgroup"
|
||||||
|
|
||||||
|
@ -18,7 +19,6 @@ import (
|
||||||
"github.com/pomerium/pomerium/internal/errgrouputil"
|
"github.com/pomerium/pomerium/internal/errgrouputil"
|
||||||
"github.com/pomerium/pomerium/internal/httputil"
|
"github.com/pomerium/pomerium/internal/httputil"
|
||||||
"github.com/pomerium/pomerium/internal/log"
|
"github.com/pomerium/pomerium/internal/log"
|
||||||
"github.com/pomerium/pomerium/internal/sets"
|
|
||||||
"github.com/pomerium/pomerium/internal/telemetry/trace"
|
"github.com/pomerium/pomerium/internal/telemetry/trace"
|
||||||
"github.com/pomerium/pomerium/pkg/contextutil"
|
"github.com/pomerium/pomerium/pkg/contextutil"
|
||||||
"github.com/pomerium/pomerium/pkg/cryptutil"
|
"github.com/pomerium/pomerium/pkg/cryptutil"
|
||||||
|
@ -240,14 +240,14 @@ func (e *Evaluator) Evaluate(ctx context.Context, req *Request) (*Result, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Internal endpoints that require a logged-in user.
|
// Internal endpoints that require a logged-in user.
|
||||||
var internalPathsNeedingLogin = sets.NewHash(
|
var internalPathsNeedingLogin = set.From([]string{
|
||||||
"/.pomerium/jwt",
|
"/.pomerium/jwt",
|
||||||
"/.pomerium/user",
|
"/.pomerium/user",
|
||||||
"/.pomerium/webauthn",
|
"/.pomerium/webauthn",
|
||||||
)
|
})
|
||||||
|
|
||||||
func (e *Evaluator) evaluateInternal(_ context.Context, req *Request) (*PolicyResponse, error) {
|
func (e *Evaluator) evaluateInternal(_ context.Context, req *Request) (*PolicyResponse, error) {
|
||||||
if internalPathsNeedingLogin.Has(req.HTTP.Path) {
|
if internalPathsNeedingLogin.Contains(req.HTTP.Path) {
|
||||||
if req.Session.ID == "" {
|
if req.Session.ID == "" {
|
||||||
return &PolicyResponse{
|
return &PolicyResponse{
|
||||||
Allow: NewRuleResult(false, criteria.ReasonUserUnauthenticated),
|
Allow: NewRuleResult(false, criteria.ReasonUserUnauthenticated),
|
||||||
|
|
|
@ -2,6 +2,7 @@ package envoyconfig
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"cmp"
|
||||||
"context"
|
"context"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
|
@ -20,13 +21,13 @@ import (
|
||||||
envoy_http_connection_manager "github.com/envoyproxy/go-control-plane/envoy/extensions/filters/network/http_connection_manager/v3"
|
envoy_http_connection_manager "github.com/envoyproxy/go-control-plane/envoy/extensions/filters/network/http_connection_manager/v3"
|
||||||
envoy_extensions_transport_sockets_tls_v3 "github.com/envoyproxy/go-control-plane/envoy/extensions/transport_sockets/tls/v3"
|
envoy_extensions_transport_sockets_tls_v3 "github.com/envoyproxy/go-control-plane/envoy/extensions/transport_sockets/tls/v3"
|
||||||
envoy_type_v3 "github.com/envoyproxy/go-control-plane/envoy/type/v3"
|
envoy_type_v3 "github.com/envoyproxy/go-control-plane/envoy/type/v3"
|
||||||
|
"github.com/hashicorp/go-set/v3"
|
||||||
"google.golang.org/protobuf/types/known/durationpb"
|
"google.golang.org/protobuf/types/known/durationpb"
|
||||||
"google.golang.org/protobuf/types/known/wrapperspb"
|
"google.golang.org/protobuf/types/known/wrapperspb"
|
||||||
|
|
||||||
"github.com/pomerium/pomerium/config"
|
"github.com/pomerium/pomerium/config"
|
||||||
"github.com/pomerium/pomerium/internal/hashutil"
|
"github.com/pomerium/pomerium/internal/hashutil"
|
||||||
"github.com/pomerium/pomerium/internal/log"
|
"github.com/pomerium/pomerium/internal/log"
|
||||||
"github.com/pomerium/pomerium/internal/sets"
|
|
||||||
"github.com/pomerium/pomerium/internal/telemetry/metrics"
|
"github.com/pomerium/pomerium/internal/telemetry/metrics"
|
||||||
"github.com/pomerium/pomerium/internal/telemetry/trace"
|
"github.com/pomerium/pomerium/internal/telemetry/trace"
|
||||||
"github.com/pomerium/pomerium/internal/urlutil"
|
"github.com/pomerium/pomerium/internal/urlutil"
|
||||||
|
@ -648,14 +649,14 @@ func addCAToBundle(bundle *bytes.Buffer, ca []byte) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func getAllRouteableHosts(options *config.Options, addr string) ([]string, error) {
|
func getAllRouteableHosts(options *config.Options, addr string) ([]string, error) {
|
||||||
allHosts := sets.NewSorted[string]()
|
allHosts := set.NewTreeSet(cmp.Compare[string])
|
||||||
|
|
||||||
if addr == options.Addr {
|
if addr == options.Addr {
|
||||||
hosts, err := options.GetAllRouteableHTTPHosts()
|
hosts, err := options.GetAllRouteableHTTPHosts()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
allHosts.Add(hosts...)
|
allHosts.InsertSlice(hosts)
|
||||||
}
|
}
|
||||||
|
|
||||||
if addr == options.GetGRPCAddr() {
|
if addr == options.GetGRPCAddr() {
|
||||||
|
@ -663,11 +664,11 @@ func getAllRouteableHosts(options *config.Options, addr string) ([]string, error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
allHosts.Add(hosts...)
|
allHosts.InsertSlice(hosts)
|
||||||
}
|
}
|
||||||
|
|
||||||
var filtered []string
|
var filtered []string
|
||||||
for _, host := range allHosts.ToSlice() {
|
for host := range allHosts.Items() {
|
||||||
if !strings.Contains(host, "*") {
|
if !strings.Contains(host, "*") {
|
||||||
filtered = append(filtered, host)
|
filtered = append(filtered, host)
|
||||||
}
|
}
|
||||||
|
|
|
@ -2,6 +2,7 @@ package config
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"cmp"
|
||||||
"context"
|
"context"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
|
@ -18,6 +19,7 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
envoy_http_connection_manager "github.com/envoyproxy/go-control-plane/envoy/extensions/filters/network/http_connection_manager/v3"
|
envoy_http_connection_manager "github.com/envoyproxy/go-control-plane/envoy/extensions/filters/network/http_connection_manager/v3"
|
||||||
|
goset "github.com/hashicorp/go-set/v3"
|
||||||
"github.com/mitchellh/mapstructure"
|
"github.com/mitchellh/mapstructure"
|
||||||
"github.com/rs/zerolog"
|
"github.com/rs/zerolog"
|
||||||
"github.com/spf13/viper"
|
"github.com/spf13/viper"
|
||||||
|
@ -29,7 +31,6 @@ import (
|
||||||
"github.com/pomerium/pomerium/internal/hashutil"
|
"github.com/pomerium/pomerium/internal/hashutil"
|
||||||
"github.com/pomerium/pomerium/internal/httputil"
|
"github.com/pomerium/pomerium/internal/httputil"
|
||||||
"github.com/pomerium/pomerium/internal/log"
|
"github.com/pomerium/pomerium/internal/log"
|
||||||
"github.com/pomerium/pomerium/internal/sets"
|
|
||||||
"github.com/pomerium/pomerium/internal/telemetry"
|
"github.com/pomerium/pomerium/internal/telemetry"
|
||||||
"github.com/pomerium/pomerium/internal/telemetry/metrics"
|
"github.com/pomerium/pomerium/internal/telemetry/metrics"
|
||||||
"github.com/pomerium/pomerium/internal/urlutil"
|
"github.com/pomerium/pomerium/internal/urlutil"
|
||||||
|
@ -1237,7 +1238,7 @@ func (o *Options) GetCodecType() CodecType {
|
||||||
|
|
||||||
// GetAllRouteableGRPCHosts returns all the possible gRPC hosts handled by the Pomerium options.
|
// GetAllRouteableGRPCHosts returns all the possible gRPC hosts handled by the Pomerium options.
|
||||||
func (o *Options) GetAllRouteableGRPCHosts() ([]string, error) {
|
func (o *Options) GetAllRouteableGRPCHosts() ([]string, error) {
|
||||||
hosts := sets.NewSorted[string]()
|
hosts := goset.NewTreeSet(cmp.Compare[string])
|
||||||
|
|
||||||
// authorize urls
|
// authorize urls
|
||||||
if IsAll(o.Services) {
|
if IsAll(o.Services) {
|
||||||
|
@ -1246,7 +1247,7 @@ func (o *Options) GetAllRouteableGRPCHosts() ([]string, error) {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
for _, u := range authorizeURLs {
|
for _, u := range authorizeURLs {
|
||||||
hosts.Add(urlutil.GetDomainsForURL(u, true)...)
|
hosts.InsertSlice(urlutil.GetDomainsForURL(u, true))
|
||||||
}
|
}
|
||||||
} else if IsAuthorize(o.Services) {
|
} else if IsAuthorize(o.Services) {
|
||||||
authorizeURLs, err := o.GetInternalAuthorizeURLs()
|
authorizeURLs, err := o.GetInternalAuthorizeURLs()
|
||||||
|
@ -1254,7 +1255,7 @@ func (o *Options) GetAllRouteableGRPCHosts() ([]string, error) {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
for _, u := range authorizeURLs {
|
for _, u := range authorizeURLs {
|
||||||
hosts.Add(urlutil.GetDomainsForURL(u, true)...)
|
hosts.InsertSlice(urlutil.GetDomainsForURL(u, true))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1265,7 +1266,7 @@ func (o *Options) GetAllRouteableGRPCHosts() ([]string, error) {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
for _, u := range dataBrokerURLs {
|
for _, u := range dataBrokerURLs {
|
||||||
hosts.Add(urlutil.GetDomainsForURL(u, true)...)
|
hosts.InsertSlice(urlutil.GetDomainsForURL(u, true))
|
||||||
}
|
}
|
||||||
} else if IsDataBroker(o.Services) {
|
} else if IsDataBroker(o.Services) {
|
||||||
dataBrokerURLs, err := o.GetInternalDataBrokerURLs()
|
dataBrokerURLs, err := o.GetInternalDataBrokerURLs()
|
||||||
|
@ -1273,23 +1274,23 @@ func (o *Options) GetAllRouteableGRPCHosts() ([]string, error) {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
for _, u := range dataBrokerURLs {
|
for _, u := range dataBrokerURLs {
|
||||||
hosts.Add(urlutil.GetDomainsForURL(u, true)...)
|
hosts.InsertSlice(urlutil.GetDomainsForURL(u, true))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return hosts.ToSlice(), nil
|
return hosts.Slice(), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetAllRouteableHTTPHosts returns all the possible HTTP hosts handled by the Pomerium options.
|
// GetAllRouteableHTTPHosts returns all the possible HTTP hosts handled by the Pomerium options.
|
||||||
func (o *Options) GetAllRouteableHTTPHosts() ([]string, error) {
|
func (o *Options) GetAllRouteableHTTPHosts() ([]string, error) {
|
||||||
hosts := sets.NewSorted[string]()
|
hosts := goset.NewTreeSet(cmp.Compare[string])
|
||||||
if IsAuthenticate(o.Services) {
|
if IsAuthenticate(o.Services) {
|
||||||
if o.AuthenticateInternalURLString != "" {
|
if o.AuthenticateInternalURLString != "" {
|
||||||
authenticateURL, err := o.GetInternalAuthenticateURL()
|
authenticateURL, err := o.GetInternalAuthenticateURL()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
hosts.Add(urlutil.GetDomainsForURL(authenticateURL, !o.IsRuntimeFlagSet(RuntimeFlagMatchAnyIncomingPort))...)
|
hosts.InsertSlice(urlutil.GetDomainsForURL(authenticateURL, !o.IsRuntimeFlagSet(RuntimeFlagMatchAnyIncomingPort)))
|
||||||
}
|
}
|
||||||
|
|
||||||
if o.AuthenticateURLString != "" {
|
if o.AuthenticateURLString != "" {
|
||||||
|
@ -1297,7 +1298,7 @@ func (o *Options) GetAllRouteableHTTPHosts() ([]string, error) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
hosts.Add(urlutil.GetDomainsForURL(authenticateURL, !o.IsRuntimeFlagSet(RuntimeFlagMatchAnyIncomingPort))...)
|
hosts.InsertSlice(urlutil.GetDomainsForURL(authenticateURL, !o.IsRuntimeFlagSet(RuntimeFlagMatchAnyIncomingPort)))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1309,15 +1310,15 @@ func (o *Options) GetAllRouteableHTTPHosts() ([]string, error) {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
hosts.Add(urlutil.GetDomainsForURL(fromURL, !o.IsRuntimeFlagSet(RuntimeFlagMatchAnyIncomingPort))...)
|
hosts.InsertSlice(urlutil.GetDomainsForURL(fromURL, !o.IsRuntimeFlagSet(RuntimeFlagMatchAnyIncomingPort)))
|
||||||
if policy.TLSDownstreamServerName != "" {
|
if policy.TLSDownstreamServerName != "" {
|
||||||
tlsURL := fromURL.ResolveReference(&url.URL{Host: policy.TLSDownstreamServerName})
|
tlsURL := fromURL.ResolveReference(&url.URL{Host: policy.TLSDownstreamServerName})
|
||||||
hosts.Add(urlutil.GetDomainsForURL(tlsURL, !o.IsRuntimeFlagSet(RuntimeFlagMatchAnyIncomingPort))...)
|
hosts.InsertSlice(urlutil.GetDomainsForURL(tlsURL, !o.IsRuntimeFlagSet(RuntimeFlagMatchAnyIncomingPort)))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return hosts.ToSlice(), nil
|
return hosts.Slice(), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetClientSecret gets the client secret.
|
// GetClientSecret gets the client secret.
|
||||||
|
|
2
go.mod
2
go.mod
|
@ -78,7 +78,6 @@ require (
|
||||||
go.uber.org/mock v0.4.0
|
go.uber.org/mock v0.4.0
|
||||||
go.uber.org/zap v1.27.0
|
go.uber.org/zap v1.27.0
|
||||||
golang.org/x/crypto v0.27.0
|
golang.org/x/crypto v0.27.0
|
||||||
golang.org/x/exp v0.0.0-20240808152545-0cdaa3abc0fa
|
|
||||||
golang.org/x/net v0.29.0
|
golang.org/x/net v0.29.0
|
||||||
golang.org/x/oauth2 v0.23.0
|
golang.org/x/oauth2 v0.23.0
|
||||||
golang.org/x/sync v0.8.0
|
golang.org/x/sync v0.8.0
|
||||||
|
@ -217,6 +216,7 @@ require (
|
||||||
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.28.0 // indirect
|
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.28.0 // indirect
|
||||||
go.opentelemetry.io/proto/otlp v1.3.1 // indirect
|
go.opentelemetry.io/proto/otlp v1.3.1 // indirect
|
||||||
go.uber.org/multierr v1.11.0 // indirect
|
go.uber.org/multierr v1.11.0 // indirect
|
||||||
|
golang.org/x/exp v0.0.0-20240808152545-0cdaa3abc0fa // indirect
|
||||||
golang.org/x/mod v0.20.0 // indirect
|
golang.org/x/mod v0.20.0 // indirect
|
||||||
golang.org/x/text v0.18.0 // indirect
|
golang.org/x/text v0.18.0 // indirect
|
||||||
golang.org/x/tools v0.24.0 // indirect
|
golang.org/x/tools v0.24.0 // indirect
|
||||||
|
|
|
@ -4,9 +4,8 @@ import (
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
"github.com/hashicorp/go-set/v3"
|
||||||
"github.com/rs/zerolog"
|
"github.com/rs/zerolog"
|
||||||
|
|
||||||
"github.com/pomerium/pomerium/internal/sets"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
@ -30,13 +29,13 @@ func HTTPHeaders[TField interface{ ~string }](
|
||||||
src map[string]string,
|
src map[string]string,
|
||||||
) *zerolog.Event {
|
) *zerolog.Event {
|
||||||
all := false
|
all := false
|
||||||
include := sets.NewHash[string]()
|
include := set.New[string](len(fields))
|
||||||
for _, field := range fields {
|
for _, field := range fields {
|
||||||
if field == headersFieldName {
|
if field == headersFieldName {
|
||||||
all = true
|
all = true
|
||||||
break
|
break
|
||||||
} else if strings.HasPrefix(string(field), headersFieldPrefix) {
|
} else if strings.HasPrefix(string(field), headersFieldPrefix) {
|
||||||
include.Add(CanonicalHeaderKey(string(field[len(headersFieldPrefix):])))
|
include.Insert(CanonicalHeaderKey(string(field[len(headersFieldPrefix):])))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -48,7 +47,7 @@ func HTTPHeaders[TField interface{ ~string }](
|
||||||
hdrs := map[string]string{}
|
hdrs := map[string]string{}
|
||||||
for k, v := range src {
|
for k, v := range src {
|
||||||
h := CanonicalHeaderKey(k)
|
h := CanonicalHeaderKey(k)
|
||||||
if all || include.Has(h) {
|
if all || include.Contains(h) {
|
||||||
hdrs[h] = v
|
hdrs[h] = v
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,42 +0,0 @@
|
||||||
package sets
|
|
||||||
|
|
||||||
// A Hash is a set implemented via a map.
|
|
||||||
type Hash[T comparable] struct {
|
|
||||||
m map[T]struct{}
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewHash creates a new Hash set.
|
|
||||||
func NewHash[T comparable](initialValues ...T) *Hash[T] {
|
|
||||||
s := &Hash[T]{
|
|
||||||
m: make(map[T]struct{}),
|
|
||||||
}
|
|
||||||
s.Add(initialValues...)
|
|
||||||
return s
|
|
||||||
}
|
|
||||||
|
|
||||||
// Add adds a value to the set.
|
|
||||||
func (s *Hash[T]) Add(elements ...T) {
|
|
||||||
for _, element := range elements {
|
|
||||||
s.m[element] = struct{}{}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Has returns true if the element is in the set.
|
|
||||||
func (s *Hash[T]) Has(element T) bool {
|
|
||||||
_, ok := s.m[element]
|
|
||||||
return ok
|
|
||||||
}
|
|
||||||
|
|
||||||
// Size returns the size of the set.
|
|
||||||
func (s *Hash[T]) Size() int {
|
|
||||||
return len(s.m)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Items returns the set's elements as a slice.
|
|
||||||
func (s *Hash[T]) Items() []T {
|
|
||||||
items := make([]T, 0, len(s.m))
|
|
||||||
for item := range s.m {
|
|
||||||
items = append(items, item)
|
|
||||||
}
|
|
||||||
return items
|
|
||||||
}
|
|
|
@ -1,5 +1,10 @@
|
||||||
package sets
|
package sets
|
||||||
|
|
||||||
|
import (
|
||||||
|
"iter"
|
||||||
|
"maps"
|
||||||
|
)
|
||||||
|
|
||||||
// A SizeLimited is a Set which is limited to a given size. Once
|
// A SizeLimited is a Set which is limited to a given size. Once
|
||||||
// the capacity is reached an element will be removed at random.
|
// the capacity is reached an element will be removed at random.
|
||||||
type SizeLimited[T comparable] struct {
|
type SizeLimited[T comparable] struct {
|
||||||
|
@ -15,8 +20,8 @@ func NewSizeLimited[T comparable](capacity int) *SizeLimited[T] {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add adds an element to the set.
|
// Insert adds an element to the set.
|
||||||
func (s *SizeLimited[T]) Add(element T) {
|
func (s *SizeLimited[T]) Insert(element T) {
|
||||||
s.m[element] = struct{}{}
|
s.m[element] = struct{}{}
|
||||||
for len(s.m) > s.capacity {
|
for len(s.m) > s.capacity {
|
||||||
for k := range s.m {
|
for k := range s.m {
|
||||||
|
@ -26,11 +31,7 @@ func (s *SizeLimited[T]) Add(element T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// ForEach iterates over all the elements in the set.
|
// Items returns an iterator over the items in the set. Order is not specified.
|
||||||
func (s *SizeLimited[T]) ForEach(callback func(element T) bool) {
|
func (s *SizeLimited[T]) Items() iter.Seq[T] {
|
||||||
for k := range s.m {
|
return maps.Keys(s.m)
|
||||||
if !callback(k) {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,67 +0,0 @@
|
||||||
package sets
|
|
||||||
|
|
||||||
import (
|
|
||||||
"github.com/google/btree"
|
|
||||||
"golang.org/x/exp/constraints"
|
|
||||||
)
|
|
||||||
|
|
||||||
// A Sorted is a set with sorted iteration.
|
|
||||||
type Sorted[T any] struct {
|
|
||||||
b *btree.BTreeG[T]
|
|
||||||
less func(a, b T) bool
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewSorted creates a new sorted string set.
|
|
||||||
func NewSorted[T constraints.Ordered]() *Sorted[T] {
|
|
||||||
less := func(a, b T) bool {
|
|
||||||
return a < b
|
|
||||||
}
|
|
||||||
return &Sorted[T]{
|
|
||||||
b: btree.NewG(8, less),
|
|
||||||
less: less,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Add adds a string to the set.
|
|
||||||
func (s *Sorted[T]) Add(elements ...T) {
|
|
||||||
for _, element := range elements {
|
|
||||||
s.b.ReplaceOrInsert(element)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Clear clears the set.
|
|
||||||
func (s *Sorted[T]) Clear() {
|
|
||||||
s.b = btree.NewG(8, s.less)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Delete deletes an element from the set.
|
|
||||||
func (s *Sorted[T]) Delete(element T) {
|
|
||||||
s.b.Delete(element)
|
|
||||||
}
|
|
||||||
|
|
||||||
// ForEach iterates over the set in ascending order.
|
|
||||||
func (s *Sorted[T]) ForEach(callback func(element T) bool) {
|
|
||||||
s.b.Ascend(func(item T) bool {
|
|
||||||
return callback(item)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// Has returns true if the element is in the set.
|
|
||||||
func (s *Sorted[T]) Has(element T) bool {
|
|
||||||
return s.b.Has(element)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Size returns the size of the set.
|
|
||||||
func (s *Sorted[T]) Size() int {
|
|
||||||
return s.b.Len()
|
|
||||||
}
|
|
||||||
|
|
||||||
// ToSlice returns a slice of all the elements in the set.
|
|
||||||
func (s *Sorted[T]) ToSlice() []T {
|
|
||||||
arr := make([]T, 0, s.Size())
|
|
||||||
s.b.Ascend(func(item T) bool {
|
|
||||||
arr = append(arr, item)
|
|
||||||
return true
|
|
||||||
})
|
|
||||||
return arr
|
|
||||||
}
|
|
|
@ -9,12 +9,12 @@ import (
|
||||||
"go.opentelemetry.io/otel/sdk/metric"
|
"go.opentelemetry.io/otel/sdk/metric"
|
||||||
"go.opentelemetry.io/otel/sdk/metric/metricdata"
|
"go.opentelemetry.io/otel/sdk/metric/metricdata"
|
||||||
|
|
||||||
"github.com/pomerium/pomerium/internal/sets"
|
"github.com/hashicorp/go-set/v3"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Producer struct {
|
type Producer struct {
|
||||||
producer metric.Producer
|
producer metric.Producer
|
||||||
filter atomic.Pointer[sets.Hash[string]]
|
filter atomic.Pointer[set.Set[string]]
|
||||||
}
|
}
|
||||||
|
|
||||||
var _ metric.Producer = (*Producer)(nil)
|
var _ metric.Producer = (*Producer)(nil)
|
||||||
|
@ -37,7 +37,7 @@ func (p *Producer) Produce(ctx context.Context) ([]metricdata.ScopeMetrics, erro
|
||||||
for _, sm := range metrics {
|
for _, sm := range metrics {
|
||||||
var m []metricdata.Metrics
|
var m []metricdata.Metrics
|
||||||
for _, metric := range sm.Metrics {
|
for _, metric := range sm.Metrics {
|
||||||
if filter.Has(metric.Name) {
|
if filter.Contains(metric.Name) {
|
||||||
m = append(m, metric)
|
m = append(m, metric)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -52,5 +52,5 @@ func (p *Producer) Produce(ctx context.Context) ([]metricdata.ScopeMetrics, erro
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *Producer) SetFilter(names []string) {
|
func (p *Producer) SetFilter(names []string) {
|
||||||
p.filter.Store(sets.NewHash(names...))
|
p.filter.Store(set.From(names))
|
||||||
}
|
}
|
||||||
|
|
|
@ -5,7 +5,8 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/pomerium/pomerium/internal/sets"
|
"github.com/hashicorp/go-set/v3"
|
||||||
|
|
||||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||||
"github.com/pomerium/pomerium/pkg/grpc/session"
|
"github.com/pomerium/pomerium/pkg/grpc/session"
|
||||||
"github.com/pomerium/pomerium/pkg/protoutil"
|
"github.com/pomerium/pomerium/pkg/protoutil"
|
||||||
|
@ -25,7 +26,7 @@ func CurrentUsers(
|
||||||
return nil, fmt.Errorf("fetching sessions: %w", err)
|
return nil, fmt.Errorf("fetching sessions: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
users := sets.NewHash[string]()
|
users := set.New[string](len(records))
|
||||||
utcNow := time.Now().UTC()
|
utcNow := time.Now().UTC()
|
||||||
threshold := time.Date(utcNow.Year(), utcNow.Month(), utcNow.Day(), 0, 0, 0, 0, time.UTC)
|
threshold := time.Date(utcNow.Year(), utcNow.Month(), utcNow.Day(), 0, 0, 0, 0, time.UTC)
|
||||||
|
|
||||||
|
@ -44,8 +45,8 @@ func CurrentUsers(
|
||||||
if s.AccessedAt.AsTime().Before(threshold) {
|
if s.AccessedAt.AsTime().Before(threshold) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
users.Add(s.UserId)
|
users.Insert(s.UserId)
|
||||||
}
|
}
|
||||||
|
|
||||||
return users.Items(), nil
|
return users.Slice(), nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -4,10 +4,10 @@ import (
|
||||||
"context"
|
"context"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/hashicorp/go-set/v3"
|
||||||
"google.golang.org/protobuf/proto"
|
"google.golang.org/protobuf/proto"
|
||||||
"google.golang.org/protobuf/types/known/durationpb"
|
"google.golang.org/protobuf/types/known/durationpb"
|
||||||
|
|
||||||
"github.com/pomerium/pomerium/internal/sets"
|
|
||||||
"github.com/pomerium/pomerium/pkg/grpc/registry"
|
"github.com/pomerium/pomerium/pkg/grpc/registry"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -37,10 +37,10 @@ func (backend registryServer) List(
|
||||||
}
|
}
|
||||||
|
|
||||||
res := new(registry.ServiceList)
|
res := new(registry.ServiceList)
|
||||||
s := sets.NewHash[registry.ServiceKind]()
|
s := set.New[registry.ServiceKind](len(all))
|
||||||
s.Add(req.GetKinds()...)
|
s.InsertSlice(req.GetKinds())
|
||||||
for _, svc := range all {
|
for _, svc := range all {
|
||||||
if s.Size() == 0 || s.Has(svc.GetKind()) {
|
if s.Size() == 0 || s.Contains(svc.GetKind()) {
|
||||||
res.Services = append(res.Services, svc)
|
res.Services = append(res.Services, svc)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue