mirror of
https://github.com/pomerium/pomerium.git
synced 2025-06-17 18:22:54 +02:00
close ssh connection when session is revoked
This commit is contained in:
parent
8eff4a48a4
commit
19b67bf32d
7 changed files with 81 additions and 14 deletions
|
@ -24,8 +24,10 @@ import (
|
||||||
"github.com/pomerium/pomerium/pkg/contextutil"
|
"github.com/pomerium/pomerium/pkg/contextutil"
|
||||||
"github.com/pomerium/pomerium/pkg/cryptutil"
|
"github.com/pomerium/pomerium/pkg/cryptutil"
|
||||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||||
|
"github.com/pomerium/pomerium/pkg/grpc/session"
|
||||||
"github.com/pomerium/pomerium/pkg/grpc/user"
|
"github.com/pomerium/pomerium/pkg/grpc/user"
|
||||||
"github.com/pomerium/pomerium/pkg/policy/criteria"
|
"github.com/pomerium/pomerium/pkg/policy/criteria"
|
||||||
|
"github.com/pomerium/pomerium/pkg/protoutil"
|
||||||
"github.com/pomerium/pomerium/pkg/storage"
|
"github.com/pomerium/pomerium/pkg/storage"
|
||||||
"github.com/pomerium/pomerium/pkg/telemetry/requestid"
|
"github.com/pomerium/pomerium/pkg/telemetry/requestid"
|
||||||
"github.com/pomerium/pomerium/pkg/telemetry/trace"
|
"github.com/pomerium/pomerium/pkg/telemetry/trace"
|
||||||
|
@ -38,6 +40,7 @@ type Authorize struct {
|
||||||
currentConfig *atomicutil.Value[*config.Config]
|
currentConfig *atomicutil.Value[*config.Config]
|
||||||
accessTracker *AccessTracker
|
accessTracker *AccessTracker
|
||||||
groupsCacheWarmer *cacheWarmer
|
groupsCacheWarmer *cacheWarmer
|
||||||
|
sessionsCacheWarmer *cacheWarmer
|
||||||
|
|
||||||
tracerProvider oteltrace.TracerProvider
|
tracerProvider oteltrace.TracerProvider
|
||||||
tracer oteltrace.Tracer
|
tracer oteltrace.Tracer
|
||||||
|
@ -67,6 +70,7 @@ func New(ctx context.Context, cfg *config.Config) (*Authorize, error) {
|
||||||
a.state = atomicutil.NewValue(state)
|
a.state = atomicutil.NewValue(state)
|
||||||
|
|
||||||
a.groupsCacheWarmer = newCacheWarmer(state.dataBrokerClientConnection, storage.GlobalCache, directory.GroupRecordType)
|
a.groupsCacheWarmer = newCacheWarmer(state.dataBrokerClientConnection, storage.GlobalCache, directory.GroupRecordType)
|
||||||
|
a.sessionsCacheWarmer = newCacheWarmer(state.dataBrokerClientConnection, storage.GlobalCache, protoutil.GetTypeURL(&session.Session{}))
|
||||||
return a, nil
|
return a, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -86,6 +90,10 @@ func (a *Authorize) Run(ctx context.Context) error {
|
||||||
a.groupsCacheWarmer.Run(ctx)
|
a.groupsCacheWarmer.Run(ctx)
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
|
eg.Go(func() error {
|
||||||
|
a.sessionsCacheWarmer.Run(ctx)
|
||||||
|
return nil
|
||||||
|
})
|
||||||
return eg.Wait()
|
return eg.Wait()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -173,6 +181,7 @@ func (a *Authorize) OnConfigChange(ctx context.Context, cfg *config.Config) {
|
||||||
|
|
||||||
if currentState.dataBrokerClientConnection != newState.dataBrokerClientConnection {
|
if currentState.dataBrokerClientConnection != newState.dataBrokerClientConnection {
|
||||||
a.groupsCacheWarmer.UpdateConn(newState.dataBrokerClientConnection)
|
a.groupsCacheWarmer.UpdateConn(newState.dataBrokerClientConnection)
|
||||||
|
a.sessionsCacheWarmer.UpdateConn(newState.dataBrokerClientConnection)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -111,6 +111,11 @@ func (h cacheWarmerSyncerHandler) UpdateRecords(ctx context.Context, serverVersi
|
||||||
log.Ctx(ctx).Error().Err(err).Msg("cache-warmer: failed to marshal query request")
|
log.Ctx(ctx).Error().Err(err).Msg("cache-warmer: failed to marshal query request")
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
if record.DeletedAt != nil {
|
||||||
|
log.Ctx(ctx).Info().Msg("record deleted: " + record.Id)
|
||||||
|
h.cache.Invalidate(key)
|
||||||
|
return
|
||||||
|
}
|
||||||
value, err := storage.MarshalQueryResponse(res)
|
value, err := storage.MarshalQueryResponse(res)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Ctx(ctx).Error().Err(err).Msg("cache-warmer: failed to marshal query response")
|
log.Ctx(ctx).Error().Err(err).Msg("cache-warmer: failed to marshal query response")
|
||||||
|
|
|
@ -29,7 +29,9 @@ import (
|
||||||
"github.com/pomerium/pomerium/config"
|
"github.com/pomerium/pomerium/config"
|
||||||
"github.com/pomerium/pomerium/internal/log"
|
"github.com/pomerium/pomerium/internal/log"
|
||||||
"github.com/pomerium/pomerium/internal/sessions"
|
"github.com/pomerium/pomerium/internal/sessions"
|
||||||
|
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||||
"github.com/pomerium/pomerium/pkg/grpc/session"
|
"github.com/pomerium/pomerium/pkg/grpc/session"
|
||||||
|
"github.com/pomerium/pomerium/pkg/grpcutil"
|
||||||
"github.com/pomerium/pomerium/pkg/identity"
|
"github.com/pomerium/pomerium/pkg/identity"
|
||||||
"github.com/pomerium/pomerium/pkg/identity/manager"
|
"github.com/pomerium/pomerium/pkg/identity/manager"
|
||||||
"github.com/pomerium/pomerium/pkg/identity/oauth"
|
"github.com/pomerium/pomerium/pkg/identity/oauth"
|
||||||
|
@ -278,7 +280,7 @@ func (a *Authorize) ManageStream(
|
||||||
sendC <- handleEvaluatorResponseForSSH(res, state)
|
sendC <- handleEvaluatorResponseForSSH(res, state)
|
||||||
|
|
||||||
if res.Allow.Value && !res.Deny.Value {
|
if res.Allow.Value && !res.Deny.Value {
|
||||||
a.startContinuousAuthorization(ctx, errC, req, session.Id)
|
a.startContinuousAuthorization(ctx, errC, req, session)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -411,7 +413,7 @@ func (a *Authorize) ManageStream(
|
||||||
sendC <- handleEvaluatorResponseForSSH(res, state)
|
sendC <- handleEvaluatorResponseForSSH(res, state)
|
||||||
|
|
||||||
if res.Allow.Value && !res.Deny.Value {
|
if res.Allow.Value && !res.Deny.Value {
|
||||||
a.startContinuousAuthorization(ctx, errC, req, state.Session.Id)
|
a.startContinuousAuthorization(ctx, errC, req, state.Session)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
resp := extensions_ssh.ServerMessage{
|
resp := extensions_ssh.ServerMessage{
|
||||||
|
@ -658,20 +660,35 @@ func (a *Authorize) startContinuousAuthorization(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
errC chan<- error,
|
errC chan<- error,
|
||||||
req *evaluator.Request,
|
req *evaluator.Request,
|
||||||
sessionID string,
|
session *session.Session,
|
||||||
) {
|
) {
|
||||||
recheck := func() {
|
recheck := func() {
|
||||||
// XXX: probably want to log the results of this evaluation only if it changes
|
// XXX: probably want to log the results of this evaluation only if it changes
|
||||||
res, _ := a.evaluate(ctx, req, &sessions.State{ID: sessionID})
|
res, _ := a.evaluate(ctx, req, &sessions.State{ID: session.Id})
|
||||||
if !res.Allow.Value || res.Deny.Value {
|
if !res.Allow.Value || res.Deny.Value {
|
||||||
errC <- fmt.Errorf("no longer authorized")
|
errC <- fmt.Errorf("no longer authorized")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
ticker := time.NewTicker(time.Second)
|
keyReq := &databroker.QueryRequest{
|
||||||
|
Type: grpcutil.GetTypeURL(session),
|
||||||
|
Limit: 1,
|
||||||
|
}
|
||||||
|
keyReq.SetFilterByIDOrIndex(session.Id)
|
||||||
|
key, err := (&proto.MarshalOptions{
|
||||||
|
Deterministic: true,
|
||||||
|
}).Marshal(keyReq)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ticker := time.NewTicker(10 * time.Second)
|
||||||
go func() {
|
go func() {
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
|
case <-a.sessionsCacheWarmer.cache.Wait(key):
|
||||||
|
errC <- fmt.Errorf("session expired")
|
||||||
|
return
|
||||||
case <-ticker.C:
|
case <-ticker.C:
|
||||||
recheck()
|
recheck()
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
|
@ -1177,6 +1194,9 @@ func (a *Authorize) NewPortalCommand(
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
if answer.(model).choice == "" {
|
||||||
|
return nil // quit/ctrl+c
|
||||||
|
}
|
||||||
var handOff *anypb.Any
|
var handOff *anypb.Any
|
||||||
if strings.HasPrefix(answer.(model).choice, "[demo] mirror session: ") {
|
if strings.HasPrefix(answer.(model).choice, "[demo] mirror session: ") {
|
||||||
id, err := strconv.ParseUint(strings.TrimPrefix(answer.(model).choice, "[demo] mirror session: "), 10, 64)
|
id, err := strconv.ParseUint(strings.TrimPrefix(answer.(model).choice, "[demo] mirror session: "), 10, 64)
|
||||||
|
@ -1216,7 +1236,7 @@ func (a *Authorize) NewPortalCommand(
|
||||||
}
|
}
|
||||||
|
|
||||||
if res.Allow.Value && !res.Deny.Value {
|
if res.Allow.Value && !res.Deny.Value {
|
||||||
a.startContinuousAuthorization(state.Context, state.ErrorC, req, state.Session.Id)
|
a.startContinuousAuthorization(state.Context, state.ErrorC, req, state.Session)
|
||||||
} else {
|
} else {
|
||||||
return fmt.Errorf("not authorized")
|
return fmt.Errorf("not authorized")
|
||||||
}
|
}
|
||||||
|
|
|
@ -176,6 +176,7 @@ func (b *Builder) buildRouteConfig(_ context.Context, cfg *config.Config) (*envo
|
||||||
ClusterSpecifier: &envoy_generic_proxy_action_v3.RouteAction_Cluster{
|
ClusterSpecifier: &envoy_generic_proxy_action_v3.RouteAction_Cluster{
|
||||||
Cluster: clusterId,
|
Cluster: clusterId,
|
||||||
},
|
},
|
||||||
|
Timeout: durationpb.New(0),
|
||||||
}),
|
}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
2
go.mod
2
go.mod
|
@ -254,7 +254,7 @@ require (
|
||||||
go.uber.org/multierr v1.11.0 // indirect
|
go.uber.org/multierr v1.11.0 // indirect
|
||||||
go.uber.org/zap/exp v0.3.0 // indirect
|
go.uber.org/zap/exp v0.3.0 // indirect
|
||||||
golang.org/x/exp v0.0.0-20240808152545-0cdaa3abc0fa // indirect
|
golang.org/x/exp v0.0.0-20240808152545-0cdaa3abc0fa // indirect
|
||||||
golang.org/x/mod v0.20.0 // indirect
|
golang.org/x/mod v0.21.0 // indirect
|
||||||
golang.org/x/text v0.23.0 // indirect
|
golang.org/x/text v0.23.0 // indirect
|
||||||
golang.org/x/tools v0.24.0 // indirect
|
golang.org/x/tools v0.24.0 // indirect
|
||||||
google.golang.org/genproto v0.0.0-20241118233622-e639e219e697 // indirect
|
google.golang.org/genproto v0.0.0-20241118233622-e639e219e697 // indirect
|
||||||
|
|
4
go.sum
4
go.sum
|
@ -817,8 +817,8 @@ golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
|
||||||
golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
|
golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
|
||||||
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
|
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
|
||||||
golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
|
golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
|
||||||
golang.org/x/mod v0.20.0 h1:utOm6MM3R3dnawAiJgn0y+xvuYRsm1RKM/4giyfDgV0=
|
golang.org/x/mod v0.21.0 h1:vvrHzRwRfVKSiLrG+d4FMl/Qi4ukBCE6kZlTUkDYRT0=
|
||||||
golang.org/x/mod v0.20.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
|
golang.org/x/mod v0.21.0/go.mod h1:6SkKJ3Xj0I0BrPOZoBy3bdMptDDU9oJrpohJ3eWZ1fY=
|
||||||
golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
|
golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
|
||||||
golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
|
golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
|
||||||
golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
|
golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
|
||||||
|
|
|
@ -5,6 +5,7 @@ import (
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
"unique"
|
||||||
|
|
||||||
"github.com/VictoriaMetrics/fastcache"
|
"github.com/VictoriaMetrics/fastcache"
|
||||||
"golang.org/x/sync/singleflight"
|
"golang.org/x/sync/singleflight"
|
||||||
|
@ -20,6 +21,7 @@ type Cache interface {
|
||||||
Invalidate(key []byte)
|
Invalidate(key []byte)
|
||||||
InvalidateAll()
|
InvalidateAll()
|
||||||
Set(expiry time.Time, key, value []byte)
|
Set(expiry time.Time, key, value []byte)
|
||||||
|
Wait(key []byte) <-chan struct{}
|
||||||
}
|
}
|
||||||
|
|
||||||
type globalCache struct {
|
type globalCache struct {
|
||||||
|
@ -28,6 +30,7 @@ type globalCache struct {
|
||||||
singleflight singleflight.Group
|
singleflight singleflight.Group
|
||||||
mu sync.RWMutex
|
mu sync.RWMutex
|
||||||
fastcache *fastcache.Cache
|
fastcache *fastcache.Cache
|
||||||
|
waiters map[unique.Handle[string]]chan struct{}
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewGlobalCache creates a new Cache backed by fastcache and a TTL.
|
// NewGlobalCache creates a new Cache backed by fastcache and a TTL.
|
||||||
|
@ -35,6 +38,7 @@ func NewGlobalCache(ttl time.Duration) Cache {
|
||||||
return &globalCache{
|
return &globalCache{
|
||||||
ttl: ttl,
|
ttl: ttl,
|
||||||
fastcache: fastcache.New(256 * 1024 * 1024), // up to 256MB of RAM
|
fastcache: fastcache.New(256 * 1024 * 1024), // up to 256MB of RAM
|
||||||
|
waiters: map[unique.Handle[string]]chan struct{}{},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -71,12 +75,40 @@ func (cache *globalCache) GetOrUpdate(
|
||||||
func (cache *globalCache) Invalidate(key []byte) {
|
func (cache *globalCache) Invalidate(key []byte) {
|
||||||
cache.mu.Lock()
|
cache.mu.Lock()
|
||||||
cache.fastcache.Del(key)
|
cache.fastcache.Del(key)
|
||||||
|
keyHandle := unique.Make(string(key))
|
||||||
|
if c, ok := cache.waiters[keyHandle]; ok {
|
||||||
|
close(c)
|
||||||
|
delete(cache.waiters, keyHandle)
|
||||||
|
}
|
||||||
cache.mu.Unlock()
|
cache.mu.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var expiredC = make(chan struct{})
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
close(expiredC)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cache *globalCache) Wait(key []byte) <-chan struct{} {
|
||||||
|
cache.mu.Lock()
|
||||||
|
defer cache.mu.Unlock()
|
||||||
|
if !cache.fastcache.Has(key) {
|
||||||
|
return expiredC
|
||||||
|
}
|
||||||
|
keyHandle := unique.Make(string(key))
|
||||||
|
if _, ok := cache.waiters[keyHandle]; !ok {
|
||||||
|
cache.waiters[keyHandle] = make(chan struct{})
|
||||||
|
}
|
||||||
|
return cache.waiters[keyHandle]
|
||||||
|
}
|
||||||
|
|
||||||
func (cache *globalCache) InvalidateAll() {
|
func (cache *globalCache) InvalidateAll() {
|
||||||
cache.mu.Lock()
|
cache.mu.Lock()
|
||||||
cache.fastcache.Reset()
|
cache.fastcache.Reset()
|
||||||
|
for _, c := range cache.waiters {
|
||||||
|
close(c)
|
||||||
|
}
|
||||||
|
clear(cache.waiters)
|
||||||
cache.mu.Unlock()
|
cache.mu.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue