refactor session state

This commit is contained in:
Joe Kralicky 2025-03-21 21:34:33 +00:00
parent 3225d3b032
commit 315ee2610f
No known key found for this signature in database
GPG key ID: 75C4875F34A9FB79
2 changed files with 118 additions and 59 deletions

View file

@ -6,7 +6,6 @@ import (
"context"
"fmt"
"slices"
"sync"
"time"
"github.com/rs/zerolog"
@ -45,8 +44,7 @@ type Authorize struct {
tracerProvider oteltrace.TracerProvider
tracer oteltrace.Tracer
activeStreamsMu sync.Mutex
activeStreams []chan error
activeStreams ActiveStreams
}
// New validates and creates a new Authorize service from a set of config options.
@ -59,7 +57,9 @@ func New(ctx context.Context, cfg *config.Config) (*Authorize, error) {
globalCache: storage.NewGlobalCache(time.Minute),
tracerProvider: tracerProvider,
tracer: tracer,
activeStreams: []chan error{},
activeStreams: ActiveStreams{
streamsById: map[uint64]*StreamState{},
},
}
a.accessTracker = NewAccessTracker(a, accessTrackerMaxSize, accessTrackerDebouncePeriod)
@ -167,15 +167,6 @@ func newPolicyEvaluator(
// OnConfigChange updates internal structures based on config.Options
func (a *Authorize) OnConfigChange(ctx context.Context, cfg *config.Config) {
a.activeStreamsMu.Lock()
// demo code
if cfg.Options.Routes[0].AllowAnyAuthenticatedUser == false {
for _, s := range a.activeStreams {
s <- fmt.Errorf("no longer authorized")
}
clear(a.activeStreams)
}
a.activeStreamsMu.Unlock()
currentState := a.state.Load()
a.currentConfig.Store(cfg)
if newState, err := newAuthorizeStateFromConfig(ctx, a.tracerProvider, cfg, a.store, currentState.evaluator); err != nil {

View file

@ -51,11 +51,46 @@ import (
"google.golang.org/protobuf/types/known/wrapperspb"
)
type ActiveStreams struct {
mu sync.Mutex
streamsById map[uint64]*StreamState
}
type StreamState struct {
Context context.Context
StreamID uint64
ErrorC chan<- error
Username string
Hostname string
PublicKey []byte
MethodsAuthenticated []string
Session *session.Session
}
func (a *ActiveStreams) Get(id uint64) *StreamState {
a.mu.Lock()
defer a.mu.Unlock()
return a.streamsById[id]
}
func (a *ActiveStreams) Put(id uint64, state *StreamState) {
a.mu.Lock()
defer a.mu.Unlock()
a.streamsById[id] = state
}
func (a *ActiveStreams) Delete(id uint64) {
a.mu.Lock()
defer a.mu.Unlock()
delete(a.streamsById, id)
}
func (a *ActiveStreams) Range(f func(id uint64, state *StreamState)) {
a.mu.Lock()
defer a.mu.Unlock()
for id, state := range a.streamsById {
f(id, state)
}
}
func (a *Authorize) RecordingFinalized(
@ -126,8 +161,6 @@ READ:
return nil
}
var activeStreamIds sync.Map
func (a *Authorize) ManageStream(
server extensions_ssh.StreamManagement_ManageStreamServer,
) error {
@ -171,15 +204,13 @@ func (a *Authorize) ManageStream(
}
})
var state StreamState
errC := make(chan error, 1)
state := &StreamState{
Context: ctx,
ErrorC: errC,
}
deviceAuthDone := make(chan struct{})
sessionState := &atomic.Pointer[sessions.State]{}
errC := make(chan error, 1)
a.activeStreamsMu.Lock()
a.activeStreams = append(a.activeStreams, errC)
a.activeStreamsMu.Unlock()
for {
select {
case err := <-errC:
@ -192,10 +223,14 @@ func (a *Authorize) ManageStream(
case *extensions_ssh.ClientMessage_Event:
switch event := req.Event.Event.(type) {
case *extensions_ssh.StreamEvent_DownstreamConnected:
_ = event
id := event.DownstreamConnected.StreamId
if id == 0 {
return fmt.Errorf("invalid stream ID: %v", id)
}
state.StreamID = id
a.activeStreams.Put(id, state)
defer a.activeStreams.Delete(id)
case *extensions_ssh.StreamEvent_UpstreamConnected:
activeStreamIds.Store(event.UpstreamConnected.GetStreamId(), state)
defer activeStreamIds.Delete(event.UpstreamConnected.GetStreamId())
case nil:
}
case *extensions_ssh.ClientMessage_AuthRequest:
@ -230,8 +265,9 @@ func (a *Authorize) ManageStream(
}
if session != nil {
state.Session = session
// Perform authorize check for this route
req, err := a.getEvaluatorRequestFromSSHAuthRequest(&state)
req, err := a.getEvaluatorRequestFromSSHAuthRequest(state)
if err != nil {
return err
}
@ -239,7 +275,7 @@ func (a *Authorize) ManageStream(
if err != nil {
return err
}
sendC <- handleEvaluatorResponseForSSH(res, &state, session.Id)
sendC <- handleEvaluatorResponseForSSH(res, state)
if res.Allow.Value && !res.Deny.Value {
a.startContinuousAuthorization(ctx, errC, req, session.Id)
@ -263,9 +299,7 @@ func (a *Authorize) ManageStream(
}
case "keyboard-interactive":
route := a.getSSHRouteForHostname(state.Hostname)
// if route == nil {
// return fmt.Errorf("invalid route")
// }
// route can be nil, in which case the default idp will be used
opts := a.currentConfig.Load().Options
idp, err := opts.GetIdentityProviderForPolicy(route)
@ -292,9 +326,7 @@ func (a *Authorize) ManageStream(
infoReq := extensions_ssh.KeyboardInteractiveInfoPrompts{
Name: "Sign in with " + idp.GetType(),
Instruction: deviceAuthResp.VerificationURIComplete,
Prompts: []*extensions_ssh.KeyboardInteractiveInfoPrompts_Prompt{
// {},
},
Prompts: []*extensions_ssh.KeyboardInteractiveInfoPrompts_Prompt{},
}
infoReqAny, _ := anypb.New(&infoReq)
@ -328,13 +360,12 @@ func (a *Authorize) ManageStream(
return
}
fmt.Println(token)
err = a.PersistSession(ctx, s, claims, token)
state.Session, err = a.PersistSession(ctx, s, claims, token)
if err != nil {
fmt.Println("error from PersistSession:", err)
errC <- fmt.Errorf("error persisting session: %w", err)
return
}
sessionState.Store(s)
close(deviceAuthDone)
}()
}
@ -351,7 +382,7 @@ func (a *Authorize) ManageStream(
case <-deviceAuthDone:
case <-ctx.Done():
}
if sessionState.Load() != nil {
if state.Session != nil {
state.MethodsAuthenticated = append(state.MethodsAuthenticated, "keyboard-interactive")
} else {
resp := extensions_ssh.ServerMessage{
@ -369,18 +400,18 @@ func (a *Authorize) ManageStream(
if slices.Contains(state.MethodsAuthenticated, "publickey") {
// Perform authorize check for this route
req, err := a.getEvaluatorRequestFromSSHAuthRequest(&state)
req, err := a.getEvaluatorRequestFromSSHAuthRequest(state)
if err != nil {
return err
}
res, err := a.evaluate(ctx, req, sessionState.Load())
res, err := a.evaluate(ctx, req, &sessions.State{ID: state.Session.Id})
if err != nil {
return err
}
sendC <- handleEvaluatorResponseForSSH(res, &state, sessionState.Load().ID)
sendC <- handleEvaluatorResponseForSSH(res, state)
if res.Allow.Value && !res.Deny.Value {
a.startContinuousAuthorization(ctx, errC, req, sessionState.Load().ID)
a.startContinuousAuthorization(ctx, errC, req, state.Session.Id)
}
} else {
resp := extensions_ssh.ServerMessage{
@ -476,7 +507,8 @@ func (a *Authorize) getEvaluatorRequestFromSSHAuthRequest(
}
func handleEvaluatorResponseForSSH(
result *evaluator.Result, state *StreamState, sessionID string,
result *evaluator.Result,
state *StreamState,
) *extensions_ssh.ServerMessage {
// fmt.Printf(" *** evaluator result: %+v\n", result)
@ -499,7 +531,7 @@ func handleEvaluatorResponseForSSH(
FilterMetadata: map[string]*structpb.Struct{
"pomerium": {
Fields: map[string]*structpb.Value{
"session-id": structpb.NewStringValue(sessionID),
"stream-id": structpb.NewStringValue(strconv.FormatUint(state.StreamID, 10)),
},
},
},
@ -592,7 +624,7 @@ func (a *Authorize) PersistSession(
sessionState *sessions.State, // XXX: consider not using this struct
claims identity.SessionClaims,
accessToken *oauth2.Token,
) error {
) (*session.Session, error) {
now := time.Now()
sessionLifetime := a.currentConfig.Load().Options.CookieExpire
sessionExpiry := timestamppb.New(now.Add(sessionLifetime))
@ -614,12 +646,12 @@ func (a *Authorize) PersistSession(
res, err := session.Put(ctx, a.GetDataBrokerServiceClient(), sess)
if err != nil {
return err
return nil, err
}
sessionState.DatabrokerServerVersion = res.GetServerVersion()
sessionState.DatabrokerRecordVersion = res.GetRecord().GetVersion()
return nil
return sess, nil
}
func (a *Authorize) startContinuousAuthorization(
@ -731,7 +763,7 @@ func (a *Authorize) ServeChannel(
}
}()
var sessionID atomic.Pointer[string]
var state *StreamState
go func() {
localWindow := uint32(channelWindowSize)
for {
@ -744,14 +776,28 @@ func (a *Authorize) ServeChannel(
errC <- err
return
}
if sessionID.Load() == nil {
if state == nil {
mdMsg, ok := channelMsg.Message.(*extensions_ssh.ChannelMessage_Metadata)
if !ok {
errC <- fmt.Errorf("first message was not metadata")
return
}
id := mdMsg.Metadata.FilterMetadata["pomerium"].Fields["session-id"].GetStringValue()
sessionID.Store(&id)
idStr := mdMsg.Metadata.FilterMetadata["pomerium"].Fields["stream-id"].GetStringValue()
if idStr == "" {
errC <- fmt.Errorf("no session ID found for stream %q", idStr)
return
}
id, err := strconv.ParseUint(idStr, 10, 64)
if err != nil {
errC <- fmt.Errorf("invalid stream ID %q: %w", idStr, err)
return
}
if v := a.activeStreams.Get(id); v != nil {
state = v
} else {
errC <- fmt.Errorf("no stream state found for ID %d", id)
return
}
continue
}
if raw, ok := channelMsg.Message.(*extensions_ssh.ChannelMessage_RawBytes); ok {
@ -877,7 +923,7 @@ func (a *Authorize) ServeChannel(
return err
}
cmd := a.NewSSHCLI(a.currentConfig.Load(), downstreamPtyInfo, downstreamChannelInfo, *sessionID.Load(), inputR, outputW, sendC, &activeProgram)
cmd := a.NewSSHCLI(a.currentConfig.Load(), downstreamPtyInfo, downstreamChannelInfo, state, inputR, outputW, sendC, &activeProgram)
if msg.Request == "shell" {
cmd.SetArgs([]string{"portal"})
} else {
@ -995,10 +1041,10 @@ func (a *Authorize) NewSSHCLI(
cfg *config.Config,
ptyInfo *extensions_ssh.SSHDownstreamPTYInfo,
channelInfo *extensions_ssh.SSHDownstreamChannelInfo,
sessionID string,
state *StreamState,
stdin io.Reader,
stdout io.Writer,
sendC chan any,
sendC chan<- any,
activeProgram *atomic.Pointer[tea.Program],
) *cobra.Command {
cmd := &cobra.Command{
@ -1016,7 +1062,8 @@ func (a *Authorize) NewSSHCLI(
return nil
},
}
cmd.AddCommand(NewPortalCommand(cfg, ptyInfo, channelInfo, sendC, activeProgram))
sessionID := state.Session.Id
cmd.AddCommand(a.NewPortalCommand(cfg, ptyInfo, channelInfo, state, sendC, activeProgram))
cmd.AddCommand(a.NewLogoutCommand(cfg, sessionID))
cmd.AddCommand(a.NewWhoamiCommand(cfg, sessionID))
cmd.CompletionOptions.DisableDefaultCmd = true
@ -1077,11 +1124,12 @@ func (a *Authorize) NewWhoamiCommand(
return cmd
}
func NewPortalCommand(
func (a *Authorize) NewPortalCommand(
cfg *config.Config,
ptyInfo *extensions_ssh.SSHDownstreamPTYInfo,
channelInfo *extensions_ssh.SSHDownstreamChannelInfo,
sendC chan any,
state *StreamState,
sendC chan<- any,
activeProgram *atomic.Pointer[tea.Program],
) *cobra.Command {
cmd := &cobra.Command{
@ -1094,16 +1142,17 @@ func NewPortalCommand(
var routes []string
for r := range cfg.Options.GetAllPolicies() {
if strings.HasPrefix(r.From, "ssh://") {
routes = append(routes, fmt.Sprintf("ubuntu@%s", strings.TrimSuffix(strings.TrimPrefix(r.From, "ssh://"), "."+cfg.Options.SSHHostname)))
routes = append(routes, fmt.Sprintf("%s@%s", state.Username, strings.TrimSuffix(strings.TrimPrefix(r.From, "ssh://"), "."+cfg.Options.SSHHostname)))
}
}
items := []list.Item{}
for _, route := range routes {
items = append(items, item(route))
}
activeStreamIds.Range(func(key, value any) bool {
items = append(items, item(fmt.Sprintf("[demo] mirror session: %v", key)))
return true
a.activeStreams.Range(func(id uint64, state *StreamState) {
if id != state.StreamID {
items = append(items, item(fmt.Sprintf("[demo] mirror session: %v", id)))
}
})
l := list.New(items, itemDelegate{}, int(ptyInfo.WidthColumns-2), int(ptyInfo.HeightRows-2))
@ -1152,6 +1201,25 @@ func NewPortalCommand(
})
} else {
username, hostname, _ := strings.Cut(answer.(model).choice, "@")
// Perform authorize check for this route
state.Hostname = hostname
if username != state.Username {
return fmt.Errorf("internal error: username mismatch")
}
req, err := a.getEvaluatorRequestFromSSHAuthRequest(state)
if err != nil {
return err
}
res, err := a.evaluate(cmd.Context(), req, &sessions.State{ID: state.Session.Id})
if err != nil {
return err
}
if res.Allow.Value && !res.Deny.Value {
a.startContinuousAuthorization(state.Context, state.ErrorC, req, state.Session.Id)
} else {
return fmt.Errorf("not authorized")
}
sessionRecordingExt, _ := anypb.New(&extensions_session_recording.UpstreamTargetExtensionConfig{
RecordingName: fmt.Sprintf("session-%s-at-%s-%d.cast", username, hostname, time.Now().UnixNano()),
Format: extensions_session_recording.Format_AsciicastFormat,