refactor session state

This commit is contained in:
Joe Kralicky 2025-03-21 21:34:33 +00:00
parent 3225d3b032
commit 315ee2610f
No known key found for this signature in database
GPG key ID: 75C4875F34A9FB79
2 changed files with 118 additions and 59 deletions

View file

@ -6,7 +6,6 @@ import (
"context" "context"
"fmt" "fmt"
"slices" "slices"
"sync"
"time" "time"
"github.com/rs/zerolog" "github.com/rs/zerolog"
@ -45,8 +44,7 @@ type Authorize struct {
tracerProvider oteltrace.TracerProvider tracerProvider oteltrace.TracerProvider
tracer oteltrace.Tracer tracer oteltrace.Tracer
activeStreamsMu sync.Mutex activeStreams ActiveStreams
activeStreams []chan error
} }
// New validates and creates a new Authorize service from a set of config options. // New validates and creates a new Authorize service from a set of config options.
@ -59,7 +57,9 @@ func New(ctx context.Context, cfg *config.Config) (*Authorize, error) {
globalCache: storage.NewGlobalCache(time.Minute), globalCache: storage.NewGlobalCache(time.Minute),
tracerProvider: tracerProvider, tracerProvider: tracerProvider,
tracer: tracer, tracer: tracer,
activeStreams: []chan error{}, activeStreams: ActiveStreams{
streamsById: map[uint64]*StreamState{},
},
} }
a.accessTracker = NewAccessTracker(a, accessTrackerMaxSize, accessTrackerDebouncePeriod) a.accessTracker = NewAccessTracker(a, accessTrackerMaxSize, accessTrackerDebouncePeriod)
@ -167,15 +167,6 @@ func newPolicyEvaluator(
// OnConfigChange updates internal structures based on config.Options // OnConfigChange updates internal structures based on config.Options
func (a *Authorize) OnConfigChange(ctx context.Context, cfg *config.Config) { func (a *Authorize) OnConfigChange(ctx context.Context, cfg *config.Config) {
a.activeStreamsMu.Lock()
// demo code
if cfg.Options.Routes[0].AllowAnyAuthenticatedUser == false {
for _, s := range a.activeStreams {
s <- fmt.Errorf("no longer authorized")
}
clear(a.activeStreams)
}
a.activeStreamsMu.Unlock()
currentState := a.state.Load() currentState := a.state.Load()
a.currentConfig.Store(cfg) a.currentConfig.Store(cfg)
if newState, err := newAuthorizeStateFromConfig(ctx, a.tracerProvider, cfg, a.store, currentState.evaluator); err != nil { if newState, err := newAuthorizeStateFromConfig(ctx, a.tracerProvider, cfg, a.store, currentState.evaluator); err != nil {

View file

@ -51,11 +51,46 @@ import (
"google.golang.org/protobuf/types/known/wrapperspb" "google.golang.org/protobuf/types/known/wrapperspb"
) )
type ActiveStreams struct {
mu sync.Mutex
streamsById map[uint64]*StreamState
}
type StreamState struct { type StreamState struct {
Context context.Context
StreamID uint64
ErrorC chan<- error
Username string Username string
Hostname string Hostname string
PublicKey []byte PublicKey []byte
MethodsAuthenticated []string MethodsAuthenticated []string
Session *session.Session
}
func (a *ActiveStreams) Get(id uint64) *StreamState {
a.mu.Lock()
defer a.mu.Unlock()
return a.streamsById[id]
}
func (a *ActiveStreams) Put(id uint64, state *StreamState) {
a.mu.Lock()
defer a.mu.Unlock()
a.streamsById[id] = state
}
func (a *ActiveStreams) Delete(id uint64) {
a.mu.Lock()
defer a.mu.Unlock()
delete(a.streamsById, id)
}
func (a *ActiveStreams) Range(f func(id uint64, state *StreamState)) {
a.mu.Lock()
defer a.mu.Unlock()
for id, state := range a.streamsById {
f(id, state)
}
} }
func (a *Authorize) RecordingFinalized( func (a *Authorize) RecordingFinalized(
@ -126,8 +161,6 @@ READ:
return nil return nil
} }
var activeStreamIds sync.Map
func (a *Authorize) ManageStream( func (a *Authorize) ManageStream(
server extensions_ssh.StreamManagement_ManageStreamServer, server extensions_ssh.StreamManagement_ManageStreamServer,
) error { ) error {
@ -171,15 +204,13 @@ func (a *Authorize) ManageStream(
} }
}) })
var state StreamState errC := make(chan error, 1)
state := &StreamState{
Context: ctx,
ErrorC: errC,
}
deviceAuthDone := make(chan struct{}) deviceAuthDone := make(chan struct{})
sessionState := &atomic.Pointer[sessions.State]{}
errC := make(chan error, 1)
a.activeStreamsMu.Lock()
a.activeStreams = append(a.activeStreams, errC)
a.activeStreamsMu.Unlock()
for { for {
select { select {
case err := <-errC: case err := <-errC:
@ -192,10 +223,14 @@ func (a *Authorize) ManageStream(
case *extensions_ssh.ClientMessage_Event: case *extensions_ssh.ClientMessage_Event:
switch event := req.Event.Event.(type) { switch event := req.Event.Event.(type) {
case *extensions_ssh.StreamEvent_DownstreamConnected: case *extensions_ssh.StreamEvent_DownstreamConnected:
_ = event id := event.DownstreamConnected.StreamId
if id == 0 {
return fmt.Errorf("invalid stream ID: %v", id)
}
state.StreamID = id
a.activeStreams.Put(id, state)
defer a.activeStreams.Delete(id)
case *extensions_ssh.StreamEvent_UpstreamConnected: case *extensions_ssh.StreamEvent_UpstreamConnected:
activeStreamIds.Store(event.UpstreamConnected.GetStreamId(), state)
defer activeStreamIds.Delete(event.UpstreamConnected.GetStreamId())
case nil: case nil:
} }
case *extensions_ssh.ClientMessage_AuthRequest: case *extensions_ssh.ClientMessage_AuthRequest:
@ -230,8 +265,9 @@ func (a *Authorize) ManageStream(
} }
if session != nil { if session != nil {
state.Session = session
// Perform authorize check for this route // Perform authorize check for this route
req, err := a.getEvaluatorRequestFromSSHAuthRequest(&state) req, err := a.getEvaluatorRequestFromSSHAuthRequest(state)
if err != nil { if err != nil {
return err return err
} }
@ -239,7 +275,7 @@ func (a *Authorize) ManageStream(
if err != nil { if err != nil {
return err return err
} }
sendC <- handleEvaluatorResponseForSSH(res, &state, session.Id) sendC <- handleEvaluatorResponseForSSH(res, state)
if res.Allow.Value && !res.Deny.Value { if res.Allow.Value && !res.Deny.Value {
a.startContinuousAuthorization(ctx, errC, req, session.Id) a.startContinuousAuthorization(ctx, errC, req, session.Id)
@ -263,9 +299,7 @@ func (a *Authorize) ManageStream(
} }
case "keyboard-interactive": case "keyboard-interactive":
route := a.getSSHRouteForHostname(state.Hostname) route := a.getSSHRouteForHostname(state.Hostname)
// if route == nil { // route can be nil, in which case the default idp will be used
// return fmt.Errorf("invalid route")
// }
opts := a.currentConfig.Load().Options opts := a.currentConfig.Load().Options
idp, err := opts.GetIdentityProviderForPolicy(route) idp, err := opts.GetIdentityProviderForPolicy(route)
@ -292,9 +326,7 @@ func (a *Authorize) ManageStream(
infoReq := extensions_ssh.KeyboardInteractiveInfoPrompts{ infoReq := extensions_ssh.KeyboardInteractiveInfoPrompts{
Name: "Sign in with " + idp.GetType(), Name: "Sign in with " + idp.GetType(),
Instruction: deviceAuthResp.VerificationURIComplete, Instruction: deviceAuthResp.VerificationURIComplete,
Prompts: []*extensions_ssh.KeyboardInteractiveInfoPrompts_Prompt{ Prompts: []*extensions_ssh.KeyboardInteractiveInfoPrompts_Prompt{},
// {},
},
} }
infoReqAny, _ := anypb.New(&infoReq) infoReqAny, _ := anypb.New(&infoReq)
@ -328,13 +360,12 @@ func (a *Authorize) ManageStream(
return return
} }
fmt.Println(token) fmt.Println(token)
err = a.PersistSession(ctx, s, claims, token) state.Session, err = a.PersistSession(ctx, s, claims, token)
if err != nil { if err != nil {
fmt.Println("error from PersistSession:", err) fmt.Println("error from PersistSession:", err)
errC <- fmt.Errorf("error persisting session: %w", err) errC <- fmt.Errorf("error persisting session: %w", err)
return return
} }
sessionState.Store(s)
close(deviceAuthDone) close(deviceAuthDone)
}() }()
} }
@ -351,7 +382,7 @@ func (a *Authorize) ManageStream(
case <-deviceAuthDone: case <-deviceAuthDone:
case <-ctx.Done(): case <-ctx.Done():
} }
if sessionState.Load() != nil { if state.Session != nil {
state.MethodsAuthenticated = append(state.MethodsAuthenticated, "keyboard-interactive") state.MethodsAuthenticated = append(state.MethodsAuthenticated, "keyboard-interactive")
} else { } else {
resp := extensions_ssh.ServerMessage{ resp := extensions_ssh.ServerMessage{
@ -369,18 +400,18 @@ func (a *Authorize) ManageStream(
if slices.Contains(state.MethodsAuthenticated, "publickey") { if slices.Contains(state.MethodsAuthenticated, "publickey") {
// Perform authorize check for this route // Perform authorize check for this route
req, err := a.getEvaluatorRequestFromSSHAuthRequest(&state) req, err := a.getEvaluatorRequestFromSSHAuthRequest(state)
if err != nil { if err != nil {
return err return err
} }
res, err := a.evaluate(ctx, req, sessionState.Load()) res, err := a.evaluate(ctx, req, &sessions.State{ID: state.Session.Id})
if err != nil { if err != nil {
return err return err
} }
sendC <- handleEvaluatorResponseForSSH(res, &state, sessionState.Load().ID) sendC <- handleEvaluatorResponseForSSH(res, state)
if res.Allow.Value && !res.Deny.Value { if res.Allow.Value && !res.Deny.Value {
a.startContinuousAuthorization(ctx, errC, req, sessionState.Load().ID) a.startContinuousAuthorization(ctx, errC, req, state.Session.Id)
} }
} else { } else {
resp := extensions_ssh.ServerMessage{ resp := extensions_ssh.ServerMessage{
@ -476,7 +507,8 @@ func (a *Authorize) getEvaluatorRequestFromSSHAuthRequest(
} }
func handleEvaluatorResponseForSSH( func handleEvaluatorResponseForSSH(
result *evaluator.Result, state *StreamState, sessionID string, result *evaluator.Result,
state *StreamState,
) *extensions_ssh.ServerMessage { ) *extensions_ssh.ServerMessage {
// fmt.Printf(" *** evaluator result: %+v\n", result) // fmt.Printf(" *** evaluator result: %+v\n", result)
@ -499,7 +531,7 @@ func handleEvaluatorResponseForSSH(
FilterMetadata: map[string]*structpb.Struct{ FilterMetadata: map[string]*structpb.Struct{
"pomerium": { "pomerium": {
Fields: map[string]*structpb.Value{ Fields: map[string]*structpb.Value{
"session-id": structpb.NewStringValue(sessionID), "stream-id": structpb.NewStringValue(strconv.FormatUint(state.StreamID, 10)),
}, },
}, },
}, },
@ -592,7 +624,7 @@ func (a *Authorize) PersistSession(
sessionState *sessions.State, // XXX: consider not using this struct sessionState *sessions.State, // XXX: consider not using this struct
claims identity.SessionClaims, claims identity.SessionClaims,
accessToken *oauth2.Token, accessToken *oauth2.Token,
) error { ) (*session.Session, error) {
now := time.Now() now := time.Now()
sessionLifetime := a.currentConfig.Load().Options.CookieExpire sessionLifetime := a.currentConfig.Load().Options.CookieExpire
sessionExpiry := timestamppb.New(now.Add(sessionLifetime)) sessionExpiry := timestamppb.New(now.Add(sessionLifetime))
@ -614,12 +646,12 @@ func (a *Authorize) PersistSession(
res, err := session.Put(ctx, a.GetDataBrokerServiceClient(), sess) res, err := session.Put(ctx, a.GetDataBrokerServiceClient(), sess)
if err != nil { if err != nil {
return err return nil, err
} }
sessionState.DatabrokerServerVersion = res.GetServerVersion() sessionState.DatabrokerServerVersion = res.GetServerVersion()
sessionState.DatabrokerRecordVersion = res.GetRecord().GetVersion() sessionState.DatabrokerRecordVersion = res.GetRecord().GetVersion()
return nil return sess, nil
} }
func (a *Authorize) startContinuousAuthorization( func (a *Authorize) startContinuousAuthorization(
@ -731,7 +763,7 @@ func (a *Authorize) ServeChannel(
} }
}() }()
var sessionID atomic.Pointer[string] var state *StreamState
go func() { go func() {
localWindow := uint32(channelWindowSize) localWindow := uint32(channelWindowSize)
for { for {
@ -744,14 +776,28 @@ func (a *Authorize) ServeChannel(
errC <- err errC <- err
return return
} }
if sessionID.Load() == nil { if state == nil {
mdMsg, ok := channelMsg.Message.(*extensions_ssh.ChannelMessage_Metadata) mdMsg, ok := channelMsg.Message.(*extensions_ssh.ChannelMessage_Metadata)
if !ok { if !ok {
errC <- fmt.Errorf("first message was not metadata") errC <- fmt.Errorf("first message was not metadata")
return return
} }
id := mdMsg.Metadata.FilterMetadata["pomerium"].Fields["session-id"].GetStringValue() idStr := mdMsg.Metadata.FilterMetadata["pomerium"].Fields["stream-id"].GetStringValue()
sessionID.Store(&id) if idStr == "" {
errC <- fmt.Errorf("no session ID found for stream %q", idStr)
return
}
id, err := strconv.ParseUint(idStr, 10, 64)
if err != nil {
errC <- fmt.Errorf("invalid stream ID %q: %w", idStr, err)
return
}
if v := a.activeStreams.Get(id); v != nil {
state = v
} else {
errC <- fmt.Errorf("no stream state found for ID %d", id)
return
}
continue continue
} }
if raw, ok := channelMsg.Message.(*extensions_ssh.ChannelMessage_RawBytes); ok { if raw, ok := channelMsg.Message.(*extensions_ssh.ChannelMessage_RawBytes); ok {
@ -877,7 +923,7 @@ func (a *Authorize) ServeChannel(
return err return err
} }
cmd := a.NewSSHCLI(a.currentConfig.Load(), downstreamPtyInfo, downstreamChannelInfo, *sessionID.Load(), inputR, outputW, sendC, &activeProgram) cmd := a.NewSSHCLI(a.currentConfig.Load(), downstreamPtyInfo, downstreamChannelInfo, state, inputR, outputW, sendC, &activeProgram)
if msg.Request == "shell" { if msg.Request == "shell" {
cmd.SetArgs([]string{"portal"}) cmd.SetArgs([]string{"portal"})
} else { } else {
@ -995,10 +1041,10 @@ func (a *Authorize) NewSSHCLI(
cfg *config.Config, cfg *config.Config,
ptyInfo *extensions_ssh.SSHDownstreamPTYInfo, ptyInfo *extensions_ssh.SSHDownstreamPTYInfo,
channelInfo *extensions_ssh.SSHDownstreamChannelInfo, channelInfo *extensions_ssh.SSHDownstreamChannelInfo,
sessionID string, state *StreamState,
stdin io.Reader, stdin io.Reader,
stdout io.Writer, stdout io.Writer,
sendC chan any, sendC chan<- any,
activeProgram *atomic.Pointer[tea.Program], activeProgram *atomic.Pointer[tea.Program],
) *cobra.Command { ) *cobra.Command {
cmd := &cobra.Command{ cmd := &cobra.Command{
@ -1016,7 +1062,8 @@ func (a *Authorize) NewSSHCLI(
return nil return nil
}, },
} }
cmd.AddCommand(NewPortalCommand(cfg, ptyInfo, channelInfo, sendC, activeProgram)) sessionID := state.Session.Id
cmd.AddCommand(a.NewPortalCommand(cfg, ptyInfo, channelInfo, state, sendC, activeProgram))
cmd.AddCommand(a.NewLogoutCommand(cfg, sessionID)) cmd.AddCommand(a.NewLogoutCommand(cfg, sessionID))
cmd.AddCommand(a.NewWhoamiCommand(cfg, sessionID)) cmd.AddCommand(a.NewWhoamiCommand(cfg, sessionID))
cmd.CompletionOptions.DisableDefaultCmd = true cmd.CompletionOptions.DisableDefaultCmd = true
@ -1077,11 +1124,12 @@ func (a *Authorize) NewWhoamiCommand(
return cmd return cmd
} }
func NewPortalCommand( func (a *Authorize) NewPortalCommand(
cfg *config.Config, cfg *config.Config,
ptyInfo *extensions_ssh.SSHDownstreamPTYInfo, ptyInfo *extensions_ssh.SSHDownstreamPTYInfo,
channelInfo *extensions_ssh.SSHDownstreamChannelInfo, channelInfo *extensions_ssh.SSHDownstreamChannelInfo,
sendC chan any, state *StreamState,
sendC chan<- any,
activeProgram *atomic.Pointer[tea.Program], activeProgram *atomic.Pointer[tea.Program],
) *cobra.Command { ) *cobra.Command {
cmd := &cobra.Command{ cmd := &cobra.Command{
@ -1094,16 +1142,17 @@ func NewPortalCommand(
var routes []string var routes []string
for r := range cfg.Options.GetAllPolicies() { for r := range cfg.Options.GetAllPolicies() {
if strings.HasPrefix(r.From, "ssh://") { if strings.HasPrefix(r.From, "ssh://") {
routes = append(routes, fmt.Sprintf("ubuntu@%s", strings.TrimSuffix(strings.TrimPrefix(r.From, "ssh://"), "."+cfg.Options.SSHHostname))) routes = append(routes, fmt.Sprintf("%s@%s", state.Username, strings.TrimSuffix(strings.TrimPrefix(r.From, "ssh://"), "."+cfg.Options.SSHHostname)))
} }
} }
items := []list.Item{} items := []list.Item{}
for _, route := range routes { for _, route := range routes {
items = append(items, item(route)) items = append(items, item(route))
} }
activeStreamIds.Range(func(key, value any) bool { a.activeStreams.Range(func(id uint64, state *StreamState) {
items = append(items, item(fmt.Sprintf("[demo] mirror session: %v", key))) if id != state.StreamID {
return true items = append(items, item(fmt.Sprintf("[demo] mirror session: %v", id)))
}
}) })
l := list.New(items, itemDelegate{}, int(ptyInfo.WidthColumns-2), int(ptyInfo.HeightRows-2)) l := list.New(items, itemDelegate{}, int(ptyInfo.WidthColumns-2), int(ptyInfo.HeightRows-2))
@ -1152,6 +1201,25 @@ func NewPortalCommand(
}) })
} else { } else {
username, hostname, _ := strings.Cut(answer.(model).choice, "@") username, hostname, _ := strings.Cut(answer.(model).choice, "@")
// Perform authorize check for this route
state.Hostname = hostname
if username != state.Username {
return fmt.Errorf("internal error: username mismatch")
}
req, err := a.getEvaluatorRequestFromSSHAuthRequest(state)
if err != nil {
return err
}
res, err := a.evaluate(cmd.Context(), req, &sessions.State{ID: state.Session.Id})
if err != nil {
return err
}
if res.Allow.Value && !res.Deny.Value {
a.startContinuousAuthorization(state.Context, state.ErrorC, req, state.Session.Id)
} else {
return fmt.Errorf("not authorized")
}
sessionRecordingExt, _ := anypb.New(&extensions_session_recording.UpstreamTargetExtensionConfig{ sessionRecordingExt, _ := anypb.New(&extensions_session_recording.UpstreamTargetExtensionConfig{
RecordingName: fmt.Sprintf("session-%s-at-%s-%d.cast", username, hostname, time.Now().UnixNano()), RecordingName: fmt.Sprintf("session-%s-at-%s-%d.cast", username, hostname, time.Now().UnixNano()),
Format: extensions_session_recording.Format_AsciicastFormat, Format: extensions_session_recording.Format_AsciicastFormat,