diff --git a/authorize/authorize.go b/authorize/authorize.go index 959056970..e3be78eda 100644 --- a/authorize/authorize.go +++ b/authorize/authorize.go @@ -24,8 +24,10 @@ import ( "github.com/pomerium/pomerium/pkg/contextutil" "github.com/pomerium/pomerium/pkg/cryptutil" "github.com/pomerium/pomerium/pkg/grpc/databroker" + "github.com/pomerium/pomerium/pkg/grpc/session" "github.com/pomerium/pomerium/pkg/grpc/user" "github.com/pomerium/pomerium/pkg/policy/criteria" + "github.com/pomerium/pomerium/pkg/protoutil" "github.com/pomerium/pomerium/pkg/storage" "github.com/pomerium/pomerium/pkg/telemetry/requestid" "github.com/pomerium/pomerium/pkg/telemetry/trace" @@ -33,11 +35,12 @@ import ( // Authorize struct holds type Authorize struct { - state *atomicutil.Value[*authorizeState] - store *store.Store - currentConfig *atomicutil.Value[*config.Config] - accessTracker *AccessTracker - groupsCacheWarmer *cacheWarmer + state *atomicutil.Value[*authorizeState] + store *store.Store + currentConfig *atomicutil.Value[*config.Config] + accessTracker *AccessTracker + groupsCacheWarmer *cacheWarmer + sessionsCacheWarmer *cacheWarmer tracerProvider oteltrace.TracerProvider tracer oteltrace.Tracer @@ -67,6 +70,7 @@ func New(ctx context.Context, cfg *config.Config) (*Authorize, error) { a.state = atomicutil.NewValue(state) a.groupsCacheWarmer = newCacheWarmer(state.dataBrokerClientConnection, storage.GlobalCache, directory.GroupRecordType) + a.sessionsCacheWarmer = newCacheWarmer(state.dataBrokerClientConnection, storage.GlobalCache, protoutil.GetTypeURL(&session.Session{})) return a, nil } @@ -86,6 +90,10 @@ func (a *Authorize) Run(ctx context.Context) error { a.groupsCacheWarmer.Run(ctx) return nil }) + eg.Go(func() error { + a.sessionsCacheWarmer.Run(ctx) + return nil + }) return eg.Wait() } @@ -173,6 +181,7 @@ func (a *Authorize) OnConfigChange(ctx context.Context, cfg *config.Config) { if currentState.dataBrokerClientConnection != newState.dataBrokerClientConnection { a.groupsCacheWarmer.UpdateConn(newState.dataBrokerClientConnection) + a.sessionsCacheWarmer.UpdateConn(newState.dataBrokerClientConnection) } } } diff --git a/authorize/cache_warmer.go b/authorize/cache_warmer.go index 41c1d0ae2..7564b9db8 100644 --- a/authorize/cache_warmer.go +++ b/authorize/cache_warmer.go @@ -111,6 +111,11 @@ func (h cacheWarmerSyncerHandler) UpdateRecords(ctx context.Context, serverVersi log.Ctx(ctx).Error().Err(err).Msg("cache-warmer: failed to marshal query request") continue } + if record.DeletedAt != nil { + log.Ctx(ctx).Info().Msg("record deleted: " + record.Id) + h.cache.Invalidate(key) + return + } value, err := storage.MarshalQueryResponse(res) if err != nil { log.Ctx(ctx).Error().Err(err).Msg("cache-warmer: failed to marshal query response") diff --git a/authorize/ssh_grpc.go b/authorize/ssh_grpc.go index b2f7940e8..eb418ac6e 100644 --- a/authorize/ssh_grpc.go +++ b/authorize/ssh_grpc.go @@ -29,7 +29,9 @@ import ( "github.com/pomerium/pomerium/config" "github.com/pomerium/pomerium/internal/log" "github.com/pomerium/pomerium/internal/sessions" + "github.com/pomerium/pomerium/pkg/grpc/databroker" "github.com/pomerium/pomerium/pkg/grpc/session" + "github.com/pomerium/pomerium/pkg/grpcutil" "github.com/pomerium/pomerium/pkg/identity" "github.com/pomerium/pomerium/pkg/identity/manager" "github.com/pomerium/pomerium/pkg/identity/oauth" @@ -278,7 +280,7 @@ func (a *Authorize) ManageStream( sendC <- handleEvaluatorResponseForSSH(res, state) if res.Allow.Value && !res.Deny.Value { - a.startContinuousAuthorization(ctx, errC, req, session.Id) + a.startContinuousAuthorization(ctx, errC, req, session) } } @@ -411,7 +413,7 @@ func (a *Authorize) ManageStream( sendC <- handleEvaluatorResponseForSSH(res, state) if res.Allow.Value && !res.Deny.Value { - a.startContinuousAuthorization(ctx, errC, req, state.Session.Id) + a.startContinuousAuthorization(ctx, errC, req, state.Session) } } else { resp := extensions_ssh.ServerMessage{ @@ -658,20 +660,35 @@ func (a *Authorize) startContinuousAuthorization( ctx context.Context, errC chan<- error, req *evaluator.Request, - sessionID string, + session *session.Session, ) { recheck := func() { // XXX: probably want to log the results of this evaluation only if it changes - res, _ := a.evaluate(ctx, req, &sessions.State{ID: sessionID}) + res, _ := a.evaluate(ctx, req, &sessions.State{ID: session.Id}) if !res.Allow.Value || res.Deny.Value { errC <- fmt.Errorf("no longer authorized") } } - ticker := time.NewTicker(time.Second) + keyReq := &databroker.QueryRequest{ + Type: grpcutil.GetTypeURL(session), + Limit: 1, + } + keyReq.SetFilterByIDOrIndex(session.Id) + key, err := (&proto.MarshalOptions{ + Deterministic: true, + }).Marshal(keyReq) + if err != nil { + panic(err) + } + + ticker := time.NewTicker(10 * time.Second) go func() { for { select { + case <-a.sessionsCacheWarmer.cache.Wait(key): + errC <- fmt.Errorf("session expired") + return case <-ticker.C: recheck() case <-ctx.Done(): @@ -1177,6 +1194,9 @@ func (a *Authorize) NewPortalCommand( if err != nil { return err } + if answer.(model).choice == "" { + return nil // quit/ctrl+c + } var handOff *anypb.Any if strings.HasPrefix(answer.(model).choice, "[demo] mirror session: ") { id, err := strconv.ParseUint(strings.TrimPrefix(answer.(model).choice, "[demo] mirror session: "), 10, 64) @@ -1216,7 +1236,7 @@ func (a *Authorize) NewPortalCommand( } if res.Allow.Value && !res.Deny.Value { - a.startContinuousAuthorization(state.Context, state.ErrorC, req, state.Session.Id) + a.startContinuousAuthorization(state.Context, state.ErrorC, req, state.Session) } else { return fmt.Errorf("not authorized") } diff --git a/config/envoyconfig/listeners_ssh.go b/config/envoyconfig/listeners_ssh.go index fbb779e09..bdfcb9b5a 100644 --- a/config/envoyconfig/listeners_ssh.go +++ b/config/envoyconfig/listeners_ssh.go @@ -176,6 +176,7 @@ func (b *Builder) buildRouteConfig(_ context.Context, cfg *config.Config) (*envo ClusterSpecifier: &envoy_generic_proxy_action_v3.RouteAction_Cluster{ Cluster: clusterId, }, + Timeout: durationpb.New(0), }), }, }, diff --git a/go.mod b/go.mod index 306f2dd22..10649981c 100644 --- a/go.mod +++ b/go.mod @@ -254,7 +254,7 @@ require ( go.uber.org/multierr v1.11.0 // indirect go.uber.org/zap/exp v0.3.0 // indirect golang.org/x/exp v0.0.0-20240808152545-0cdaa3abc0fa // indirect - golang.org/x/mod v0.20.0 // indirect + golang.org/x/mod v0.21.0 // indirect golang.org/x/text v0.23.0 // indirect golang.org/x/tools v0.24.0 // indirect google.golang.org/genproto v0.0.0-20241118233622-e639e219e697 // indirect diff --git a/go.sum b/go.sum index 5578f0b3c..1c0d417eb 100644 --- a/go.sum +++ b/go.sum @@ -817,8 +817,8 @@ golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= -golang.org/x/mod v0.20.0 h1:utOm6MM3R3dnawAiJgn0y+xvuYRsm1RKM/4giyfDgV0= -golang.org/x/mod v0.20.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= +golang.org/x/mod v0.21.0 h1:vvrHzRwRfVKSiLrG+d4FMl/Qi4ukBCE6kZlTUkDYRT0= +golang.org/x/mod v0.21.0/go.mod h1:6SkKJ3Xj0I0BrPOZoBy3bdMptDDU9oJrpohJ3eWZ1fY= golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= diff --git a/pkg/storage/cache.go b/pkg/storage/cache.go index 3d8d563e6..f864486b9 100644 --- a/pkg/storage/cache.go +++ b/pkg/storage/cache.go @@ -5,6 +5,7 @@ import ( "encoding/binary" "sync" "time" + "unique" "github.com/VictoriaMetrics/fastcache" "golang.org/x/sync/singleflight" @@ -20,6 +21,7 @@ type Cache interface { Invalidate(key []byte) InvalidateAll() Set(expiry time.Time, key, value []byte) + Wait(key []byte) <-chan struct{} } type globalCache struct { @@ -28,6 +30,7 @@ type globalCache struct { singleflight singleflight.Group mu sync.RWMutex fastcache *fastcache.Cache + waiters map[unique.Handle[string]]chan struct{} } // NewGlobalCache creates a new Cache backed by fastcache and a TTL. @@ -35,6 +38,7 @@ func NewGlobalCache(ttl time.Duration) Cache { return &globalCache{ ttl: ttl, fastcache: fastcache.New(256 * 1024 * 1024), // up to 256MB of RAM + waiters: map[unique.Handle[string]]chan struct{}{}, } } @@ -71,12 +75,40 @@ func (cache *globalCache) GetOrUpdate( func (cache *globalCache) Invalidate(key []byte) { cache.mu.Lock() cache.fastcache.Del(key) + keyHandle := unique.Make(string(key)) + if c, ok := cache.waiters[keyHandle]; ok { + close(c) + delete(cache.waiters, keyHandle) + } cache.mu.Unlock() } +var expiredC = make(chan struct{}) + +func init() { + close(expiredC) +} + +func (cache *globalCache) Wait(key []byte) <-chan struct{} { + cache.mu.Lock() + defer cache.mu.Unlock() + if !cache.fastcache.Has(key) { + return expiredC + } + keyHandle := unique.Make(string(key)) + if _, ok := cache.waiters[keyHandle]; !ok { + cache.waiters[keyHandle] = make(chan struct{}) + } + return cache.waiters[keyHandle] +} + func (cache *globalCache) InvalidateAll() { cache.mu.Lock() cache.fastcache.Reset() + for _, c := range cache.waiters { + close(c) + } + clear(cache.waiters) cache.mu.Unlock() }