close ssh connection when session is revoked

This commit is contained in:
Joe Kralicky 2025-03-26 21:11:08 +00:00
parent 8eff4a48a4
commit 19b67bf32d
No known key found for this signature in database
GPG key ID: 75C4875F34A9FB79
7 changed files with 81 additions and 14 deletions

View file

@ -29,7 +29,9 @@ import (
"github.com/pomerium/pomerium/config"
"github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/sessions"
"github.com/pomerium/pomerium/pkg/grpc/databroker"
"github.com/pomerium/pomerium/pkg/grpc/session"
"github.com/pomerium/pomerium/pkg/grpcutil"
"github.com/pomerium/pomerium/pkg/identity"
"github.com/pomerium/pomerium/pkg/identity/manager"
"github.com/pomerium/pomerium/pkg/identity/oauth"
@ -278,7 +280,7 @@ func (a *Authorize) ManageStream(
sendC <- handleEvaluatorResponseForSSH(res, state)
if res.Allow.Value && !res.Deny.Value {
a.startContinuousAuthorization(ctx, errC, req, session.Id)
a.startContinuousAuthorization(ctx, errC, req, session)
}
}
@ -411,7 +413,7 @@ func (a *Authorize) ManageStream(
sendC <- handleEvaluatorResponseForSSH(res, state)
if res.Allow.Value && !res.Deny.Value {
a.startContinuousAuthorization(ctx, errC, req, state.Session.Id)
a.startContinuousAuthorization(ctx, errC, req, state.Session)
}
} else {
resp := extensions_ssh.ServerMessage{
@ -658,20 +660,35 @@ func (a *Authorize) startContinuousAuthorization(
ctx context.Context,
errC chan<- error,
req *evaluator.Request,
sessionID string,
session *session.Session,
) {
recheck := func() {
// XXX: probably want to log the results of this evaluation only if it changes
res, _ := a.evaluate(ctx, req, &sessions.State{ID: sessionID})
res, _ := a.evaluate(ctx, req, &sessions.State{ID: session.Id})
if !res.Allow.Value || res.Deny.Value {
errC <- fmt.Errorf("no longer authorized")
}
}
ticker := time.NewTicker(time.Second)
keyReq := &databroker.QueryRequest{
Type: grpcutil.GetTypeURL(session),
Limit: 1,
}
keyReq.SetFilterByIDOrIndex(session.Id)
key, err := (&proto.MarshalOptions{
Deterministic: true,
}).Marshal(keyReq)
if err != nil {
panic(err)
}
ticker := time.NewTicker(10 * time.Second)
go func() {
for {
select {
case <-a.sessionsCacheWarmer.cache.Wait(key):
errC <- fmt.Errorf("session expired")
return
case <-ticker.C:
recheck()
case <-ctx.Done():
@ -1177,6 +1194,9 @@ func (a *Authorize) NewPortalCommand(
if err != nil {
return err
}
if answer.(model).choice == "" {
return nil // quit/ctrl+c
}
var handOff *anypb.Any
if strings.HasPrefix(answer.(model).choice, "[demo] mirror session: ") {
id, err := strconv.ParseUint(strings.TrimPrefix(answer.(model).choice, "[demo] mirror session: "), 10, 64)
@ -1216,7 +1236,7 @@ func (a *Authorize) NewPortalCommand(
}
if res.Allow.Value && !res.Deny.Value {
a.startContinuousAuthorization(state.Context, state.ErrorC, req, state.Session.Id)
a.startContinuousAuthorization(state.Context, state.ErrorC, req, state.Session)
} else {
return fmt.Errorf("not authorized")
}