diff --git a/authorize/authorize.go b/authorize/authorize.go index 726780402..cadc71193 100644 --- a/authorize/authorize.go +++ b/authorize/authorize.go @@ -6,7 +6,6 @@ import ( "context" "fmt" "slices" - "sync" "time" "github.com/rs/zerolog" @@ -45,8 +44,7 @@ type Authorize struct { tracerProvider oteltrace.TracerProvider tracer oteltrace.Tracer - activeStreamsMu sync.Mutex - activeStreams []chan error + activeStreams ActiveStreams } // New validates and creates a new Authorize service from a set of config options. @@ -59,7 +57,9 @@ func New(ctx context.Context, cfg *config.Config) (*Authorize, error) { globalCache: storage.NewGlobalCache(time.Minute), tracerProvider: tracerProvider, tracer: tracer, - activeStreams: []chan error{}, + activeStreams: ActiveStreams{ + streamsById: map[uint64]*StreamState{}, + }, } a.accessTracker = NewAccessTracker(a, accessTrackerMaxSize, accessTrackerDebouncePeriod) @@ -167,15 +167,6 @@ func newPolicyEvaluator( // OnConfigChange updates internal structures based on config.Options func (a *Authorize) OnConfigChange(ctx context.Context, cfg *config.Config) { - a.activeStreamsMu.Lock() - // demo code - if cfg.Options.Routes[0].AllowAnyAuthenticatedUser == false { - for _, s := range a.activeStreams { - s <- fmt.Errorf("no longer authorized") - } - clear(a.activeStreams) - } - a.activeStreamsMu.Unlock() currentState := a.state.Load() a.currentConfig.Store(cfg) if newState, err := newAuthorizeStateFromConfig(ctx, a.tracerProvider, cfg, a.store, currentState.evaluator); err != nil { diff --git a/authorize/ssh_grpc.go b/authorize/ssh_grpc.go index 40cfaacb3..34a478bd4 100644 --- a/authorize/ssh_grpc.go +++ b/authorize/ssh_grpc.go @@ -51,11 +51,46 @@ import ( "google.golang.org/protobuf/types/known/wrapperspb" ) +type ActiveStreams struct { + mu sync.Mutex + streamsById map[uint64]*StreamState +} + type StreamState struct { + Context context.Context + StreamID uint64 + ErrorC chan<- error Username string Hostname string PublicKey []byte MethodsAuthenticated []string + Session *session.Session +} + +func (a *ActiveStreams) Get(id uint64) *StreamState { + a.mu.Lock() + defer a.mu.Unlock() + return a.streamsById[id] +} + +func (a *ActiveStreams) Put(id uint64, state *StreamState) { + a.mu.Lock() + defer a.mu.Unlock() + a.streamsById[id] = state +} + +func (a *ActiveStreams) Delete(id uint64) { + a.mu.Lock() + defer a.mu.Unlock() + delete(a.streamsById, id) +} + +func (a *ActiveStreams) Range(f func(id uint64, state *StreamState)) { + a.mu.Lock() + defer a.mu.Unlock() + for id, state := range a.streamsById { + f(id, state) + } } func (a *Authorize) RecordingFinalized( @@ -126,8 +161,6 @@ READ: return nil } -var activeStreamIds sync.Map - func (a *Authorize) ManageStream( server extensions_ssh.StreamManagement_ManageStreamServer, ) error { @@ -171,15 +204,13 @@ func (a *Authorize) ManageStream( } }) - var state StreamState + errC := make(chan error, 1) + state := &StreamState{ + Context: ctx, + ErrorC: errC, + } deviceAuthDone := make(chan struct{}) - sessionState := &atomic.Pointer[sessions.State]{} - - errC := make(chan error, 1) - a.activeStreamsMu.Lock() - a.activeStreams = append(a.activeStreams, errC) - a.activeStreamsMu.Unlock() for { select { case err := <-errC: @@ -192,10 +223,14 @@ func (a *Authorize) ManageStream( case *extensions_ssh.ClientMessage_Event: switch event := req.Event.Event.(type) { case *extensions_ssh.StreamEvent_DownstreamConnected: - _ = event + id := event.DownstreamConnected.StreamId + if id == 0 { + return fmt.Errorf("invalid stream ID: %v", id) + } + state.StreamID = id + a.activeStreams.Put(id, state) + defer a.activeStreams.Delete(id) case *extensions_ssh.StreamEvent_UpstreamConnected: - activeStreamIds.Store(event.UpstreamConnected.GetStreamId(), state) - defer activeStreamIds.Delete(event.UpstreamConnected.GetStreamId()) case nil: } case *extensions_ssh.ClientMessage_AuthRequest: @@ -230,8 +265,9 @@ func (a *Authorize) ManageStream( } if session != nil { + state.Session = session // Perform authorize check for this route - req, err := a.getEvaluatorRequestFromSSHAuthRequest(&state) + req, err := a.getEvaluatorRequestFromSSHAuthRequest(state) if err != nil { return err } @@ -239,7 +275,7 @@ func (a *Authorize) ManageStream( if err != nil { return err } - sendC <- handleEvaluatorResponseForSSH(res, &state, session.Id) + sendC <- handleEvaluatorResponseForSSH(res, state) if res.Allow.Value && !res.Deny.Value { a.startContinuousAuthorization(ctx, errC, req, session.Id) @@ -263,9 +299,7 @@ func (a *Authorize) ManageStream( } case "keyboard-interactive": route := a.getSSHRouteForHostname(state.Hostname) - // if route == nil { - // return fmt.Errorf("invalid route") - // } + // route can be nil, in which case the default idp will be used opts := a.currentConfig.Load().Options idp, err := opts.GetIdentityProviderForPolicy(route) @@ -292,9 +326,7 @@ func (a *Authorize) ManageStream( infoReq := extensions_ssh.KeyboardInteractiveInfoPrompts{ Name: "Sign in with " + idp.GetType(), Instruction: deviceAuthResp.VerificationURIComplete, - Prompts: []*extensions_ssh.KeyboardInteractiveInfoPrompts_Prompt{ - // {}, - }, + Prompts: []*extensions_ssh.KeyboardInteractiveInfoPrompts_Prompt{}, } infoReqAny, _ := anypb.New(&infoReq) @@ -328,13 +360,12 @@ func (a *Authorize) ManageStream( return } fmt.Println(token) - err = a.PersistSession(ctx, s, claims, token) + state.Session, err = a.PersistSession(ctx, s, claims, token) if err != nil { fmt.Println("error from PersistSession:", err) errC <- fmt.Errorf("error persisting session: %w", err) return } - sessionState.Store(s) close(deviceAuthDone) }() } @@ -351,7 +382,7 @@ func (a *Authorize) ManageStream( case <-deviceAuthDone: case <-ctx.Done(): } - if sessionState.Load() != nil { + if state.Session != nil { state.MethodsAuthenticated = append(state.MethodsAuthenticated, "keyboard-interactive") } else { resp := extensions_ssh.ServerMessage{ @@ -369,18 +400,18 @@ func (a *Authorize) ManageStream( if slices.Contains(state.MethodsAuthenticated, "publickey") { // Perform authorize check for this route - req, err := a.getEvaluatorRequestFromSSHAuthRequest(&state) + req, err := a.getEvaluatorRequestFromSSHAuthRequest(state) if err != nil { return err } - res, err := a.evaluate(ctx, req, sessionState.Load()) + res, err := a.evaluate(ctx, req, &sessions.State{ID: state.Session.Id}) if err != nil { return err } - sendC <- handleEvaluatorResponseForSSH(res, &state, sessionState.Load().ID) + sendC <- handleEvaluatorResponseForSSH(res, state) if res.Allow.Value && !res.Deny.Value { - a.startContinuousAuthorization(ctx, errC, req, sessionState.Load().ID) + a.startContinuousAuthorization(ctx, errC, req, state.Session.Id) } } else { resp := extensions_ssh.ServerMessage{ @@ -476,7 +507,8 @@ func (a *Authorize) getEvaluatorRequestFromSSHAuthRequest( } func handleEvaluatorResponseForSSH( - result *evaluator.Result, state *StreamState, sessionID string, + result *evaluator.Result, + state *StreamState, ) *extensions_ssh.ServerMessage { // fmt.Printf(" *** evaluator result: %+v\n", result) @@ -499,7 +531,7 @@ func handleEvaluatorResponseForSSH( FilterMetadata: map[string]*structpb.Struct{ "pomerium": { Fields: map[string]*structpb.Value{ - "session-id": structpb.NewStringValue(sessionID), + "stream-id": structpb.NewStringValue(strconv.FormatUint(state.StreamID, 10)), }, }, }, @@ -592,7 +624,7 @@ func (a *Authorize) PersistSession( sessionState *sessions.State, // XXX: consider not using this struct claims identity.SessionClaims, accessToken *oauth2.Token, -) error { +) (*session.Session, error) { now := time.Now() sessionLifetime := a.currentConfig.Load().Options.CookieExpire sessionExpiry := timestamppb.New(now.Add(sessionLifetime)) @@ -614,12 +646,12 @@ func (a *Authorize) PersistSession( res, err := session.Put(ctx, a.GetDataBrokerServiceClient(), sess) if err != nil { - return err + return nil, err } sessionState.DatabrokerServerVersion = res.GetServerVersion() sessionState.DatabrokerRecordVersion = res.GetRecord().GetVersion() - return nil + return sess, nil } func (a *Authorize) startContinuousAuthorization( @@ -731,7 +763,7 @@ func (a *Authorize) ServeChannel( } }() - var sessionID atomic.Pointer[string] + var state *StreamState go func() { localWindow := uint32(channelWindowSize) for { @@ -744,14 +776,28 @@ func (a *Authorize) ServeChannel( errC <- err return } - if sessionID.Load() == nil { + if state == nil { mdMsg, ok := channelMsg.Message.(*extensions_ssh.ChannelMessage_Metadata) if !ok { errC <- fmt.Errorf("first message was not metadata") return } - id := mdMsg.Metadata.FilterMetadata["pomerium"].Fields["session-id"].GetStringValue() - sessionID.Store(&id) + idStr := mdMsg.Metadata.FilterMetadata["pomerium"].Fields["stream-id"].GetStringValue() + if idStr == "" { + errC <- fmt.Errorf("no session ID found for stream %q", idStr) + return + } + id, err := strconv.ParseUint(idStr, 10, 64) + if err != nil { + errC <- fmt.Errorf("invalid stream ID %q: %w", idStr, err) + return + } + if v := a.activeStreams.Get(id); v != nil { + state = v + } else { + errC <- fmt.Errorf("no stream state found for ID %d", id) + return + } continue } if raw, ok := channelMsg.Message.(*extensions_ssh.ChannelMessage_RawBytes); ok { @@ -877,7 +923,7 @@ func (a *Authorize) ServeChannel( return err } - cmd := a.NewSSHCLI(a.currentConfig.Load(), downstreamPtyInfo, downstreamChannelInfo, *sessionID.Load(), inputR, outputW, sendC, &activeProgram) + cmd := a.NewSSHCLI(a.currentConfig.Load(), downstreamPtyInfo, downstreamChannelInfo, state, inputR, outputW, sendC, &activeProgram) if msg.Request == "shell" { cmd.SetArgs([]string{"portal"}) } else { @@ -995,10 +1041,10 @@ func (a *Authorize) NewSSHCLI( cfg *config.Config, ptyInfo *extensions_ssh.SSHDownstreamPTYInfo, channelInfo *extensions_ssh.SSHDownstreamChannelInfo, - sessionID string, + state *StreamState, stdin io.Reader, stdout io.Writer, - sendC chan any, + sendC chan<- any, activeProgram *atomic.Pointer[tea.Program], ) *cobra.Command { cmd := &cobra.Command{ @@ -1016,7 +1062,8 @@ func (a *Authorize) NewSSHCLI( return nil }, } - cmd.AddCommand(NewPortalCommand(cfg, ptyInfo, channelInfo, sendC, activeProgram)) + sessionID := state.Session.Id + cmd.AddCommand(a.NewPortalCommand(cfg, ptyInfo, channelInfo, state, sendC, activeProgram)) cmd.AddCommand(a.NewLogoutCommand(cfg, sessionID)) cmd.AddCommand(a.NewWhoamiCommand(cfg, sessionID)) cmd.CompletionOptions.DisableDefaultCmd = true @@ -1077,11 +1124,12 @@ func (a *Authorize) NewWhoamiCommand( return cmd } -func NewPortalCommand( +func (a *Authorize) NewPortalCommand( cfg *config.Config, ptyInfo *extensions_ssh.SSHDownstreamPTYInfo, channelInfo *extensions_ssh.SSHDownstreamChannelInfo, - sendC chan any, + state *StreamState, + sendC chan<- any, activeProgram *atomic.Pointer[tea.Program], ) *cobra.Command { cmd := &cobra.Command{ @@ -1094,16 +1142,17 @@ func NewPortalCommand( var routes []string for r := range cfg.Options.GetAllPolicies() { if strings.HasPrefix(r.From, "ssh://") { - routes = append(routes, fmt.Sprintf("ubuntu@%s", strings.TrimSuffix(strings.TrimPrefix(r.From, "ssh://"), "."+cfg.Options.SSHHostname))) + routes = append(routes, fmt.Sprintf("%s@%s", state.Username, strings.TrimSuffix(strings.TrimPrefix(r.From, "ssh://"), "."+cfg.Options.SSHHostname))) } } items := []list.Item{} for _, route := range routes { items = append(items, item(route)) } - activeStreamIds.Range(func(key, value any) bool { - items = append(items, item(fmt.Sprintf("[demo] mirror session: %v", key))) - return true + a.activeStreams.Range(func(id uint64, state *StreamState) { + if id != state.StreamID { + items = append(items, item(fmt.Sprintf("[demo] mirror session: %v", id))) + } }) l := list.New(items, itemDelegate{}, int(ptyInfo.WidthColumns-2), int(ptyInfo.HeightRows-2)) @@ -1152,6 +1201,25 @@ func NewPortalCommand( }) } else { username, hostname, _ := strings.Cut(answer.(model).choice, "@") + // Perform authorize check for this route + state.Hostname = hostname + if username != state.Username { + return fmt.Errorf("internal error: username mismatch") + } + req, err := a.getEvaluatorRequestFromSSHAuthRequest(state) + if err != nil { + return err + } + res, err := a.evaluate(cmd.Context(), req, &sessions.State{ID: state.Session.Id}) + if err != nil { + return err + } + + if res.Allow.Value && !res.Deny.Value { + a.startContinuousAuthorization(state.Context, state.ErrorC, req, state.Session.Id) + } else { + return fmt.Errorf("not authorized") + } sessionRecordingExt, _ := anypb.New(&extensions_session_recording.UpstreamTargetExtensionConfig{ RecordingName: fmt.Sprintf("session-%s-at-%s-%d.cast", username, hostname, time.Now().UnixNano()), Format: extensions_session_recording.Format_AsciicastFormat,