close ssh connection when session is revoked

This commit is contained in:
Joe Kralicky 2025-03-26 21:11:08 +00:00
parent 8eff4a48a4
commit 19b67bf32d
No known key found for this signature in database
GPG key ID: 75C4875F34A9FB79
7 changed files with 81 additions and 14 deletions

View file

@ -24,8 +24,10 @@ import (
"github.com/pomerium/pomerium/pkg/contextutil"
"github.com/pomerium/pomerium/pkg/cryptutil"
"github.com/pomerium/pomerium/pkg/grpc/databroker"
"github.com/pomerium/pomerium/pkg/grpc/session"
"github.com/pomerium/pomerium/pkg/grpc/user"
"github.com/pomerium/pomerium/pkg/policy/criteria"
"github.com/pomerium/pomerium/pkg/protoutil"
"github.com/pomerium/pomerium/pkg/storage"
"github.com/pomerium/pomerium/pkg/telemetry/requestid"
"github.com/pomerium/pomerium/pkg/telemetry/trace"
@ -33,11 +35,12 @@ import (
// Authorize struct holds
type Authorize struct {
state *atomicutil.Value[*authorizeState]
store *store.Store
currentConfig *atomicutil.Value[*config.Config]
accessTracker *AccessTracker
groupsCacheWarmer *cacheWarmer
state *atomicutil.Value[*authorizeState]
store *store.Store
currentConfig *atomicutil.Value[*config.Config]
accessTracker *AccessTracker
groupsCacheWarmer *cacheWarmer
sessionsCacheWarmer *cacheWarmer
tracerProvider oteltrace.TracerProvider
tracer oteltrace.Tracer
@ -67,6 +70,7 @@ func New(ctx context.Context, cfg *config.Config) (*Authorize, error) {
a.state = atomicutil.NewValue(state)
a.groupsCacheWarmer = newCacheWarmer(state.dataBrokerClientConnection, storage.GlobalCache, directory.GroupRecordType)
a.sessionsCacheWarmer = newCacheWarmer(state.dataBrokerClientConnection, storage.GlobalCache, protoutil.GetTypeURL(&session.Session{}))
return a, nil
}
@ -86,6 +90,10 @@ func (a *Authorize) Run(ctx context.Context) error {
a.groupsCacheWarmer.Run(ctx)
return nil
})
eg.Go(func() error {
a.sessionsCacheWarmer.Run(ctx)
return nil
})
return eg.Wait()
}
@ -173,6 +181,7 @@ func (a *Authorize) OnConfigChange(ctx context.Context, cfg *config.Config) {
if currentState.dataBrokerClientConnection != newState.dataBrokerClientConnection {
a.groupsCacheWarmer.UpdateConn(newState.dataBrokerClientConnection)
a.sessionsCacheWarmer.UpdateConn(newState.dataBrokerClientConnection)
}
}
}

View file

@ -111,6 +111,11 @@ func (h cacheWarmerSyncerHandler) UpdateRecords(ctx context.Context, serverVersi
log.Ctx(ctx).Error().Err(err).Msg("cache-warmer: failed to marshal query request")
continue
}
if record.DeletedAt != nil {
log.Ctx(ctx).Info().Msg("record deleted: " + record.Id)
h.cache.Invalidate(key)
return
}
value, err := storage.MarshalQueryResponse(res)
if err != nil {
log.Ctx(ctx).Error().Err(err).Msg("cache-warmer: failed to marshal query response")

View file

@ -29,7 +29,9 @@ import (
"github.com/pomerium/pomerium/config"
"github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/sessions"
"github.com/pomerium/pomerium/pkg/grpc/databroker"
"github.com/pomerium/pomerium/pkg/grpc/session"
"github.com/pomerium/pomerium/pkg/grpcutil"
"github.com/pomerium/pomerium/pkg/identity"
"github.com/pomerium/pomerium/pkg/identity/manager"
"github.com/pomerium/pomerium/pkg/identity/oauth"
@ -278,7 +280,7 @@ func (a *Authorize) ManageStream(
sendC <- handleEvaluatorResponseForSSH(res, state)
if res.Allow.Value && !res.Deny.Value {
a.startContinuousAuthorization(ctx, errC, req, session.Id)
a.startContinuousAuthorization(ctx, errC, req, session)
}
}
@ -411,7 +413,7 @@ func (a *Authorize) ManageStream(
sendC <- handleEvaluatorResponseForSSH(res, state)
if res.Allow.Value && !res.Deny.Value {
a.startContinuousAuthorization(ctx, errC, req, state.Session.Id)
a.startContinuousAuthorization(ctx, errC, req, state.Session)
}
} else {
resp := extensions_ssh.ServerMessage{
@ -658,20 +660,35 @@ func (a *Authorize) startContinuousAuthorization(
ctx context.Context,
errC chan<- error,
req *evaluator.Request,
sessionID string,
session *session.Session,
) {
recheck := func() {
// XXX: probably want to log the results of this evaluation only if it changes
res, _ := a.evaluate(ctx, req, &sessions.State{ID: sessionID})
res, _ := a.evaluate(ctx, req, &sessions.State{ID: session.Id})
if !res.Allow.Value || res.Deny.Value {
errC <- fmt.Errorf("no longer authorized")
}
}
ticker := time.NewTicker(time.Second)
keyReq := &databroker.QueryRequest{
Type: grpcutil.GetTypeURL(session),
Limit: 1,
}
keyReq.SetFilterByIDOrIndex(session.Id)
key, err := (&proto.MarshalOptions{
Deterministic: true,
}).Marshal(keyReq)
if err != nil {
panic(err)
}
ticker := time.NewTicker(10 * time.Second)
go func() {
for {
select {
case <-a.sessionsCacheWarmer.cache.Wait(key):
errC <- fmt.Errorf("session expired")
return
case <-ticker.C:
recheck()
case <-ctx.Done():
@ -1177,6 +1194,9 @@ func (a *Authorize) NewPortalCommand(
if err != nil {
return err
}
if answer.(model).choice == "" {
return nil // quit/ctrl+c
}
var handOff *anypb.Any
if strings.HasPrefix(answer.(model).choice, "[demo] mirror session: ") {
id, err := strconv.ParseUint(strings.TrimPrefix(answer.(model).choice, "[demo] mirror session: "), 10, 64)
@ -1216,7 +1236,7 @@ func (a *Authorize) NewPortalCommand(
}
if res.Allow.Value && !res.Deny.Value {
a.startContinuousAuthorization(state.Context, state.ErrorC, req, state.Session.Id)
a.startContinuousAuthorization(state.Context, state.ErrorC, req, state.Session)
} else {
return fmt.Errorf("not authorized")
}