mirror of
https://github.com/pomerium/pomerium.git
synced 2025-04-29 02:16:28 +02:00
close ssh connection when session is revoked
This commit is contained in:
parent
8eff4a48a4
commit
19b67bf32d
7 changed files with 81 additions and 14 deletions
|
@ -24,8 +24,10 @@ import (
|
|||
"github.com/pomerium/pomerium/pkg/contextutil"
|
||||
"github.com/pomerium/pomerium/pkg/cryptutil"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/session"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/user"
|
||||
"github.com/pomerium/pomerium/pkg/policy/criteria"
|
||||
"github.com/pomerium/pomerium/pkg/protoutil"
|
||||
"github.com/pomerium/pomerium/pkg/storage"
|
||||
"github.com/pomerium/pomerium/pkg/telemetry/requestid"
|
||||
"github.com/pomerium/pomerium/pkg/telemetry/trace"
|
||||
|
@ -33,11 +35,12 @@ import (
|
|||
|
||||
// Authorize struct holds
|
||||
type Authorize struct {
|
||||
state *atomicutil.Value[*authorizeState]
|
||||
store *store.Store
|
||||
currentConfig *atomicutil.Value[*config.Config]
|
||||
accessTracker *AccessTracker
|
||||
groupsCacheWarmer *cacheWarmer
|
||||
state *atomicutil.Value[*authorizeState]
|
||||
store *store.Store
|
||||
currentConfig *atomicutil.Value[*config.Config]
|
||||
accessTracker *AccessTracker
|
||||
groupsCacheWarmer *cacheWarmer
|
||||
sessionsCacheWarmer *cacheWarmer
|
||||
|
||||
tracerProvider oteltrace.TracerProvider
|
||||
tracer oteltrace.Tracer
|
||||
|
@ -67,6 +70,7 @@ func New(ctx context.Context, cfg *config.Config) (*Authorize, error) {
|
|||
a.state = atomicutil.NewValue(state)
|
||||
|
||||
a.groupsCacheWarmer = newCacheWarmer(state.dataBrokerClientConnection, storage.GlobalCache, directory.GroupRecordType)
|
||||
a.sessionsCacheWarmer = newCacheWarmer(state.dataBrokerClientConnection, storage.GlobalCache, protoutil.GetTypeURL(&session.Session{}))
|
||||
return a, nil
|
||||
}
|
||||
|
||||
|
@ -86,6 +90,10 @@ func (a *Authorize) Run(ctx context.Context) error {
|
|||
a.groupsCacheWarmer.Run(ctx)
|
||||
return nil
|
||||
})
|
||||
eg.Go(func() error {
|
||||
a.sessionsCacheWarmer.Run(ctx)
|
||||
return nil
|
||||
})
|
||||
return eg.Wait()
|
||||
}
|
||||
|
||||
|
@ -173,6 +181,7 @@ func (a *Authorize) OnConfigChange(ctx context.Context, cfg *config.Config) {
|
|||
|
||||
if currentState.dataBrokerClientConnection != newState.dataBrokerClientConnection {
|
||||
a.groupsCacheWarmer.UpdateConn(newState.dataBrokerClientConnection)
|
||||
a.sessionsCacheWarmer.UpdateConn(newState.dataBrokerClientConnection)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -111,6 +111,11 @@ func (h cacheWarmerSyncerHandler) UpdateRecords(ctx context.Context, serverVersi
|
|||
log.Ctx(ctx).Error().Err(err).Msg("cache-warmer: failed to marshal query request")
|
||||
continue
|
||||
}
|
||||
if record.DeletedAt != nil {
|
||||
log.Ctx(ctx).Info().Msg("record deleted: " + record.Id)
|
||||
h.cache.Invalidate(key)
|
||||
return
|
||||
}
|
||||
value, err := storage.MarshalQueryResponse(res)
|
||||
if err != nil {
|
||||
log.Ctx(ctx).Error().Err(err).Msg("cache-warmer: failed to marshal query response")
|
||||
|
|
|
@ -29,7 +29,9 @@ import (
|
|||
"github.com/pomerium/pomerium/config"
|
||||
"github.com/pomerium/pomerium/internal/log"
|
||||
"github.com/pomerium/pomerium/internal/sessions"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/session"
|
||||
"github.com/pomerium/pomerium/pkg/grpcutil"
|
||||
"github.com/pomerium/pomerium/pkg/identity"
|
||||
"github.com/pomerium/pomerium/pkg/identity/manager"
|
||||
"github.com/pomerium/pomerium/pkg/identity/oauth"
|
||||
|
@ -278,7 +280,7 @@ func (a *Authorize) ManageStream(
|
|||
sendC <- handleEvaluatorResponseForSSH(res, state)
|
||||
|
||||
if res.Allow.Value && !res.Deny.Value {
|
||||
a.startContinuousAuthorization(ctx, errC, req, session.Id)
|
||||
a.startContinuousAuthorization(ctx, errC, req, session)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -411,7 +413,7 @@ func (a *Authorize) ManageStream(
|
|||
sendC <- handleEvaluatorResponseForSSH(res, state)
|
||||
|
||||
if res.Allow.Value && !res.Deny.Value {
|
||||
a.startContinuousAuthorization(ctx, errC, req, state.Session.Id)
|
||||
a.startContinuousAuthorization(ctx, errC, req, state.Session)
|
||||
}
|
||||
} else {
|
||||
resp := extensions_ssh.ServerMessage{
|
||||
|
@ -658,20 +660,35 @@ func (a *Authorize) startContinuousAuthorization(
|
|||
ctx context.Context,
|
||||
errC chan<- error,
|
||||
req *evaluator.Request,
|
||||
sessionID string,
|
||||
session *session.Session,
|
||||
) {
|
||||
recheck := func() {
|
||||
// XXX: probably want to log the results of this evaluation only if it changes
|
||||
res, _ := a.evaluate(ctx, req, &sessions.State{ID: sessionID})
|
||||
res, _ := a.evaluate(ctx, req, &sessions.State{ID: session.Id})
|
||||
if !res.Allow.Value || res.Deny.Value {
|
||||
errC <- fmt.Errorf("no longer authorized")
|
||||
}
|
||||
}
|
||||
|
||||
ticker := time.NewTicker(time.Second)
|
||||
keyReq := &databroker.QueryRequest{
|
||||
Type: grpcutil.GetTypeURL(session),
|
||||
Limit: 1,
|
||||
}
|
||||
keyReq.SetFilterByIDOrIndex(session.Id)
|
||||
key, err := (&proto.MarshalOptions{
|
||||
Deterministic: true,
|
||||
}).Marshal(keyReq)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
ticker := time.NewTicker(10 * time.Second)
|
||||
go func() {
|
||||
for {
|
||||
select {
|
||||
case <-a.sessionsCacheWarmer.cache.Wait(key):
|
||||
errC <- fmt.Errorf("session expired")
|
||||
return
|
||||
case <-ticker.C:
|
||||
recheck()
|
||||
case <-ctx.Done():
|
||||
|
@ -1177,6 +1194,9 @@ func (a *Authorize) NewPortalCommand(
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if answer.(model).choice == "" {
|
||||
return nil // quit/ctrl+c
|
||||
}
|
||||
var handOff *anypb.Any
|
||||
if strings.HasPrefix(answer.(model).choice, "[demo] mirror session: ") {
|
||||
id, err := strconv.ParseUint(strings.TrimPrefix(answer.(model).choice, "[demo] mirror session: "), 10, 64)
|
||||
|
@ -1216,7 +1236,7 @@ func (a *Authorize) NewPortalCommand(
|
|||
}
|
||||
|
||||
if res.Allow.Value && !res.Deny.Value {
|
||||
a.startContinuousAuthorization(state.Context, state.ErrorC, req, state.Session.Id)
|
||||
a.startContinuousAuthorization(state.Context, state.ErrorC, req, state.Session)
|
||||
} else {
|
||||
return fmt.Errorf("not authorized")
|
||||
}
|
||||
|
|
|
@ -176,6 +176,7 @@ func (b *Builder) buildRouteConfig(_ context.Context, cfg *config.Config) (*envo
|
|||
ClusterSpecifier: &envoy_generic_proxy_action_v3.RouteAction_Cluster{
|
||||
Cluster: clusterId,
|
||||
},
|
||||
Timeout: durationpb.New(0),
|
||||
}),
|
||||
},
|
||||
},
|
||||
|
|
2
go.mod
2
go.mod
|
@ -254,7 +254,7 @@ require (
|
|||
go.uber.org/multierr v1.11.0 // indirect
|
||||
go.uber.org/zap/exp v0.3.0 // indirect
|
||||
golang.org/x/exp v0.0.0-20240808152545-0cdaa3abc0fa // indirect
|
||||
golang.org/x/mod v0.20.0 // indirect
|
||||
golang.org/x/mod v0.21.0 // indirect
|
||||
golang.org/x/text v0.23.0 // indirect
|
||||
golang.org/x/tools v0.24.0 // indirect
|
||||
google.golang.org/genproto v0.0.0-20241118233622-e639e219e697 // indirect
|
||||
|
|
4
go.sum
4
go.sum
|
@ -817,8 +817,8 @@ golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
|
|||
golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
|
||||
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
|
||||
golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
|
||||
golang.org/x/mod v0.20.0 h1:utOm6MM3R3dnawAiJgn0y+xvuYRsm1RKM/4giyfDgV0=
|
||||
golang.org/x/mod v0.20.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
|
||||
golang.org/x/mod v0.21.0 h1:vvrHzRwRfVKSiLrG+d4FMl/Qi4ukBCE6kZlTUkDYRT0=
|
||||
golang.org/x/mod v0.21.0/go.mod h1:6SkKJ3Xj0I0BrPOZoBy3bdMptDDU9oJrpohJ3eWZ1fY=
|
||||
golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
|
||||
golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
|
||||
golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
|
||||
|
|
|
@ -5,6 +5,7 @@ import (
|
|||
"encoding/binary"
|
||||
"sync"
|
||||
"time"
|
||||
"unique"
|
||||
|
||||
"github.com/VictoriaMetrics/fastcache"
|
||||
"golang.org/x/sync/singleflight"
|
||||
|
@ -20,6 +21,7 @@ type Cache interface {
|
|||
Invalidate(key []byte)
|
||||
InvalidateAll()
|
||||
Set(expiry time.Time, key, value []byte)
|
||||
Wait(key []byte) <-chan struct{}
|
||||
}
|
||||
|
||||
type globalCache struct {
|
||||
|
@ -28,6 +30,7 @@ type globalCache struct {
|
|||
singleflight singleflight.Group
|
||||
mu sync.RWMutex
|
||||
fastcache *fastcache.Cache
|
||||
waiters map[unique.Handle[string]]chan struct{}
|
||||
}
|
||||
|
||||
// NewGlobalCache creates a new Cache backed by fastcache and a TTL.
|
||||
|
@ -35,6 +38,7 @@ func NewGlobalCache(ttl time.Duration) Cache {
|
|||
return &globalCache{
|
||||
ttl: ttl,
|
||||
fastcache: fastcache.New(256 * 1024 * 1024), // up to 256MB of RAM
|
||||
waiters: map[unique.Handle[string]]chan struct{}{},
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -71,12 +75,40 @@ func (cache *globalCache) GetOrUpdate(
|
|||
func (cache *globalCache) Invalidate(key []byte) {
|
||||
cache.mu.Lock()
|
||||
cache.fastcache.Del(key)
|
||||
keyHandle := unique.Make(string(key))
|
||||
if c, ok := cache.waiters[keyHandle]; ok {
|
||||
close(c)
|
||||
delete(cache.waiters, keyHandle)
|
||||
}
|
||||
cache.mu.Unlock()
|
||||
}
|
||||
|
||||
var expiredC = make(chan struct{})
|
||||
|
||||
func init() {
|
||||
close(expiredC)
|
||||
}
|
||||
|
||||
func (cache *globalCache) Wait(key []byte) <-chan struct{} {
|
||||
cache.mu.Lock()
|
||||
defer cache.mu.Unlock()
|
||||
if !cache.fastcache.Has(key) {
|
||||
return expiredC
|
||||
}
|
||||
keyHandle := unique.Make(string(key))
|
||||
if _, ok := cache.waiters[keyHandle]; !ok {
|
||||
cache.waiters[keyHandle] = make(chan struct{})
|
||||
}
|
||||
return cache.waiters[keyHandle]
|
||||
}
|
||||
|
||||
func (cache *globalCache) InvalidateAll() {
|
||||
cache.mu.Lock()
|
||||
cache.fastcache.Reset()
|
||||
for _, c := range cache.waiters {
|
||||
close(c)
|
||||
}
|
||||
clear(cache.waiters)
|
||||
cache.mu.Unlock()
|
||||
}
|
||||
|
||||
|
|
Loading…
Add table
Reference in a new issue