close ssh connection when session is revoked

This commit is contained in:
Joe Kralicky 2025-03-26 21:11:08 +00:00
parent 8eff4a48a4
commit 19b67bf32d
No known key found for this signature in database
GPG key ID: 75C4875F34A9FB79
7 changed files with 81 additions and 14 deletions

View file

@ -24,8 +24,10 @@ import (
"github.com/pomerium/pomerium/pkg/contextutil"
"github.com/pomerium/pomerium/pkg/cryptutil"
"github.com/pomerium/pomerium/pkg/grpc/databroker"
"github.com/pomerium/pomerium/pkg/grpc/session"
"github.com/pomerium/pomerium/pkg/grpc/user"
"github.com/pomerium/pomerium/pkg/policy/criteria"
"github.com/pomerium/pomerium/pkg/protoutil"
"github.com/pomerium/pomerium/pkg/storage"
"github.com/pomerium/pomerium/pkg/telemetry/requestid"
"github.com/pomerium/pomerium/pkg/telemetry/trace"
@ -33,11 +35,12 @@ import (
// Authorize struct holds
type Authorize struct {
state *atomicutil.Value[*authorizeState]
store *store.Store
currentConfig *atomicutil.Value[*config.Config]
accessTracker *AccessTracker
groupsCacheWarmer *cacheWarmer
state *atomicutil.Value[*authorizeState]
store *store.Store
currentConfig *atomicutil.Value[*config.Config]
accessTracker *AccessTracker
groupsCacheWarmer *cacheWarmer
sessionsCacheWarmer *cacheWarmer
tracerProvider oteltrace.TracerProvider
tracer oteltrace.Tracer
@ -67,6 +70,7 @@ func New(ctx context.Context, cfg *config.Config) (*Authorize, error) {
a.state = atomicutil.NewValue(state)
a.groupsCacheWarmer = newCacheWarmer(state.dataBrokerClientConnection, storage.GlobalCache, directory.GroupRecordType)
a.sessionsCacheWarmer = newCacheWarmer(state.dataBrokerClientConnection, storage.GlobalCache, protoutil.GetTypeURL(&session.Session{}))
return a, nil
}
@ -86,6 +90,10 @@ func (a *Authorize) Run(ctx context.Context) error {
a.groupsCacheWarmer.Run(ctx)
return nil
})
eg.Go(func() error {
a.sessionsCacheWarmer.Run(ctx)
return nil
})
return eg.Wait()
}
@ -173,6 +181,7 @@ func (a *Authorize) OnConfigChange(ctx context.Context, cfg *config.Config) {
if currentState.dataBrokerClientConnection != newState.dataBrokerClientConnection {
a.groupsCacheWarmer.UpdateConn(newState.dataBrokerClientConnection)
a.sessionsCacheWarmer.UpdateConn(newState.dataBrokerClientConnection)
}
}
}

View file

@ -111,6 +111,11 @@ func (h cacheWarmerSyncerHandler) UpdateRecords(ctx context.Context, serverVersi
log.Ctx(ctx).Error().Err(err).Msg("cache-warmer: failed to marshal query request")
continue
}
if record.DeletedAt != nil {
log.Ctx(ctx).Info().Msg("record deleted: " + record.Id)
h.cache.Invalidate(key)
return
}
value, err := storage.MarshalQueryResponse(res)
if err != nil {
log.Ctx(ctx).Error().Err(err).Msg("cache-warmer: failed to marshal query response")

View file

@ -29,7 +29,9 @@ import (
"github.com/pomerium/pomerium/config"
"github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/sessions"
"github.com/pomerium/pomerium/pkg/grpc/databroker"
"github.com/pomerium/pomerium/pkg/grpc/session"
"github.com/pomerium/pomerium/pkg/grpcutil"
"github.com/pomerium/pomerium/pkg/identity"
"github.com/pomerium/pomerium/pkg/identity/manager"
"github.com/pomerium/pomerium/pkg/identity/oauth"
@ -278,7 +280,7 @@ func (a *Authorize) ManageStream(
sendC <- handleEvaluatorResponseForSSH(res, state)
if res.Allow.Value && !res.Deny.Value {
a.startContinuousAuthorization(ctx, errC, req, session.Id)
a.startContinuousAuthorization(ctx, errC, req, session)
}
}
@ -411,7 +413,7 @@ func (a *Authorize) ManageStream(
sendC <- handleEvaluatorResponseForSSH(res, state)
if res.Allow.Value && !res.Deny.Value {
a.startContinuousAuthorization(ctx, errC, req, state.Session.Id)
a.startContinuousAuthorization(ctx, errC, req, state.Session)
}
} else {
resp := extensions_ssh.ServerMessage{
@ -658,20 +660,35 @@ func (a *Authorize) startContinuousAuthorization(
ctx context.Context,
errC chan<- error,
req *evaluator.Request,
sessionID string,
session *session.Session,
) {
recheck := func() {
// XXX: probably want to log the results of this evaluation only if it changes
res, _ := a.evaluate(ctx, req, &sessions.State{ID: sessionID})
res, _ := a.evaluate(ctx, req, &sessions.State{ID: session.Id})
if !res.Allow.Value || res.Deny.Value {
errC <- fmt.Errorf("no longer authorized")
}
}
ticker := time.NewTicker(time.Second)
keyReq := &databroker.QueryRequest{
Type: grpcutil.GetTypeURL(session),
Limit: 1,
}
keyReq.SetFilterByIDOrIndex(session.Id)
key, err := (&proto.MarshalOptions{
Deterministic: true,
}).Marshal(keyReq)
if err != nil {
panic(err)
}
ticker := time.NewTicker(10 * time.Second)
go func() {
for {
select {
case <-a.sessionsCacheWarmer.cache.Wait(key):
errC <- fmt.Errorf("session expired")
return
case <-ticker.C:
recheck()
case <-ctx.Done():
@ -1177,6 +1194,9 @@ func (a *Authorize) NewPortalCommand(
if err != nil {
return err
}
if answer.(model).choice == "" {
return nil // quit/ctrl+c
}
var handOff *anypb.Any
if strings.HasPrefix(answer.(model).choice, "[demo] mirror session: ") {
id, err := strconv.ParseUint(strings.TrimPrefix(answer.(model).choice, "[demo] mirror session: "), 10, 64)
@ -1216,7 +1236,7 @@ func (a *Authorize) NewPortalCommand(
}
if res.Allow.Value && !res.Deny.Value {
a.startContinuousAuthorization(state.Context, state.ErrorC, req, state.Session.Id)
a.startContinuousAuthorization(state.Context, state.ErrorC, req, state.Session)
} else {
return fmt.Errorf("not authorized")
}

View file

@ -176,6 +176,7 @@ func (b *Builder) buildRouteConfig(_ context.Context, cfg *config.Config) (*envo
ClusterSpecifier: &envoy_generic_proxy_action_v3.RouteAction_Cluster{
Cluster: clusterId,
},
Timeout: durationpb.New(0),
}),
},
},

2
go.mod
View file

@ -254,7 +254,7 @@ require (
go.uber.org/multierr v1.11.0 // indirect
go.uber.org/zap/exp v0.3.0 // indirect
golang.org/x/exp v0.0.0-20240808152545-0cdaa3abc0fa // indirect
golang.org/x/mod v0.20.0 // indirect
golang.org/x/mod v0.21.0 // indirect
golang.org/x/text v0.23.0 // indirect
golang.org/x/tools v0.24.0 // indirect
google.golang.org/genproto v0.0.0-20241118233622-e639e219e697 // indirect

4
go.sum
View file

@ -817,8 +817,8 @@ golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
golang.org/x/mod v0.20.0 h1:utOm6MM3R3dnawAiJgn0y+xvuYRsm1RKM/4giyfDgV0=
golang.org/x/mod v0.20.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
golang.org/x/mod v0.21.0 h1:vvrHzRwRfVKSiLrG+d4FMl/Qi4ukBCE6kZlTUkDYRT0=
golang.org/x/mod v0.21.0/go.mod h1:6SkKJ3Xj0I0BrPOZoBy3bdMptDDU9oJrpohJ3eWZ1fY=
golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=

View file

@ -5,6 +5,7 @@ import (
"encoding/binary"
"sync"
"time"
"unique"
"github.com/VictoriaMetrics/fastcache"
"golang.org/x/sync/singleflight"
@ -20,6 +21,7 @@ type Cache interface {
Invalidate(key []byte)
InvalidateAll()
Set(expiry time.Time, key, value []byte)
Wait(key []byte) <-chan struct{}
}
type globalCache struct {
@ -28,6 +30,7 @@ type globalCache struct {
singleflight singleflight.Group
mu sync.RWMutex
fastcache *fastcache.Cache
waiters map[unique.Handle[string]]chan struct{}
}
// NewGlobalCache creates a new Cache backed by fastcache and a TTL.
@ -35,6 +38,7 @@ func NewGlobalCache(ttl time.Duration) Cache {
return &globalCache{
ttl: ttl,
fastcache: fastcache.New(256 * 1024 * 1024), // up to 256MB of RAM
waiters: map[unique.Handle[string]]chan struct{}{},
}
}
@ -71,12 +75,40 @@ func (cache *globalCache) GetOrUpdate(
func (cache *globalCache) Invalidate(key []byte) {
cache.mu.Lock()
cache.fastcache.Del(key)
keyHandle := unique.Make(string(key))
if c, ok := cache.waiters[keyHandle]; ok {
close(c)
delete(cache.waiters, keyHandle)
}
cache.mu.Unlock()
}
var expiredC = make(chan struct{})
func init() {
close(expiredC)
}
func (cache *globalCache) Wait(key []byte) <-chan struct{} {
cache.mu.Lock()
defer cache.mu.Unlock()
if !cache.fastcache.Has(key) {
return expiredC
}
keyHandle := unique.Make(string(key))
if _, ok := cache.waiters[keyHandle]; !ok {
cache.waiters[keyHandle] = make(chan struct{})
}
return cache.waiters[keyHandle]
}
func (cache *globalCache) InvalidateAll() {
cache.mu.Lock()
cache.fastcache.Reset()
for _, c := range cache.waiters {
close(c)
}
clear(cache.waiters)
cache.mu.Unlock()
}