From 3225d3b032dbd084efb3837aa00ff3f4b6c9cb25 Mon Sep 17 00:00:00 2001 From: Joe Kralicky Date: Fri, 21 Mar 2025 18:38:26 +0000 Subject: [PATCH] wip: ssh cli mode --- authorize/ssh_grpc.go | 141 +++++++++++++++++++++++++++++++++--------- 1 file changed, 113 insertions(+), 28 deletions(-) diff --git a/authorize/ssh_grpc.go b/authorize/ssh_grpc.go index 41846d7ce..40cfaacb3 100644 --- a/authorize/ssh_grpc.go +++ b/authorize/ssh_grpc.go @@ -15,6 +15,7 @@ import ( "strings" "sync" "sync/atomic" + "text/template" "time" "github.com/charmbracelet/bubbles/list" @@ -45,6 +46,7 @@ import ( "google.golang.org/protobuf/proto" "google.golang.org/protobuf/types/known/anypb" "google.golang.org/protobuf/types/known/emptypb" + "google.golang.org/protobuf/types/known/structpb" "google.golang.org/protobuf/types/known/timestamppb" "google.golang.org/protobuf/types/known/wrapperspb" ) @@ -226,24 +228,6 @@ func (a *Authorize) ManageStream( if authReq.Username == "" { return fmt.Errorf("no username given") } - if authReq.Hostname == "" { - resp := extensions_ssh.ServerMessage{ - Message: &extensions_ssh.ServerMessage_AuthResponse{ - AuthResponse: &extensions_ssh.AuthenticationResponse{ - Response: &extensions_ssh.AuthenticationResponse_Allow{ - Allow: &extensions_ssh.AllowResponse{ - Username: state.Username, - Target: &extensions_ssh.AllowResponse_Internal{ - Internal: &extensions_ssh.InternalTarget{}, - }, - }, - }, - }, - }, - } - sendC <- &resp - continue - } if session != nil { // Perform authorize check for this route @@ -255,7 +239,7 @@ func (a *Authorize) ManageStream( if err != nil { return err } - sendC <- handleEvaluatorResponseForSSH(res, &state) + sendC <- handleEvaluatorResponseForSSH(res, &state, session.Id) if res.Allow.Value && !res.Deny.Value { a.startContinuousAuthorization(ctx, errC, req, session.Id) @@ -279,9 +263,9 @@ func (a *Authorize) ManageStream( } case "keyboard-interactive": route := a.getSSHRouteForHostname(state.Hostname) - if route == nil { - return fmt.Errorf("invalid route") - } + // if route == nil { + // return fmt.Errorf("invalid route") + // } opts := a.currentConfig.Load().Options idp, err := opts.GetIdentityProviderForPolicy(route) @@ -393,7 +377,7 @@ func (a *Authorize) ManageStream( if err != nil { return err } - sendC <- handleEvaluatorResponseForSSH(res, &state) + sendC <- handleEvaluatorResponseForSSH(res, &state, sessionState.Load().ID) if res.Allow.Value && !res.Deny.Value { a.startContinuousAuthorization(ctx, errC, req, sessionState.Load().ID) @@ -468,7 +452,14 @@ func (a *Authorize) getEvaluatorRequestFromSSHAuthRequest( } route := a.getSSHRouteForHostname(state.Hostname) if route == nil { - return nil, fmt.Errorf("no route found for hostname %q", state.Hostname) + return &evaluator.Request{ + IsInternal: true, + Session: evaluator.RequestSession{ + ID: sessionID, + }, + }, nil + + // return nil, fmt.Errorf("no route found for hostname %q", state.Hostname) } req := &evaluator.Request{ IsInternal: false, @@ -485,7 +476,7 @@ func (a *Authorize) getEvaluatorRequestFromSSHAuthRequest( } func handleEvaluatorResponseForSSH( - result *evaluator.Result, state *StreamState, + result *evaluator.Result, state *StreamState, sessionID string, ) *extensions_ssh.ServerMessage { // fmt.Printf(" *** evaluator result: %+v\n", result) @@ -494,6 +485,33 @@ func handleEvaluatorResponseForSSH( if allow { pkData, _ := anypb.New(publicKeyAllowResponse(state.PublicKey)) + + if state.Hostname == "" { + return &extensions_ssh.ServerMessage{ + Message: &extensions_ssh.ServerMessage_AuthResponse{ + AuthResponse: &extensions_ssh.AuthenticationResponse{ + Response: &extensions_ssh.AuthenticationResponse_Allow{ + Allow: &extensions_ssh.AllowResponse{ + Username: state.Username, + Target: &extensions_ssh.AllowResponse_Internal{ + Internal: &extensions_ssh.InternalTarget{ + SetMetadata: &corev3.Metadata{ + FilterMetadata: map[string]*structpb.Struct{ + "pomerium": { + Fields: map[string]*structpb.Value{ + "session-id": structpb.NewStringValue(sessionID), + }, + }, + }, + }, + }, + }, + }, + }, + }, + }, + } + } sessionRecordingExt, _ := anypb.New(&extensions_session_recording.UpstreamTargetExtensionConfig{ RecordingName: fmt.Sprintf("session-%s-at-%s-%d.cast", state.Username, state.Hostname, time.Now().UnixNano()), Format: extensions_session_recording.Format_AsciicastFormat, @@ -712,6 +730,8 @@ func (a *Authorize) ServeChannel( } } }() + + var sessionID atomic.Pointer[string] go func() { localWindow := uint32(channelWindowSize) for { @@ -724,6 +744,16 @@ func (a *Authorize) ServeChannel( errC <- err return } + if sessionID.Load() == nil { + mdMsg, ok := channelMsg.Message.(*extensions_ssh.ChannelMessage_Metadata) + if !ok { + errC <- fmt.Errorf("first message was not metadata") + return + } + id := mdMsg.Metadata.FilterMetadata["pomerium"].Fields["session-id"].GetStringValue() + sessionID.Store(&id) + continue + } if raw, ok := channelMsg.Message.(*extensions_ssh.ChannelMessage_RawBytes); ok { msgLen := uint32(len(raw.RawBytes.GetValue())) if msgLen == 0 { @@ -847,7 +877,7 @@ func (a *Authorize) ServeChannel( return err } - cmd := NewSSHCLI(a.currentConfig.Load(), downstreamPtyInfo, downstreamChannelInfo, inputR, outputW, sendC, &activeProgram) + cmd := a.NewSSHCLI(a.currentConfig.Load(), downstreamPtyInfo, downstreamChannelInfo, *sessionID.Load(), inputR, outputW, sendC, &activeProgram) if msg.Request == "shell" { cmd.SetArgs([]string{"portal"}) } else { @@ -956,15 +986,16 @@ func streamOutputToChannel(sendC chan<- any, channelID uint32, outputR io.Reader sendC <- channelDataMsg{ PeersID: channelID, Length: uint32(n), - Rest: buf[:n], + Rest: slices.Clone(buf[:n]), } } } -func NewSSHCLI( +func (a *Authorize) NewSSHCLI( cfg *config.Config, ptyInfo *extensions_ssh.SSHDownstreamPTYInfo, channelInfo *extensions_ssh.SSHDownstreamChannelInfo, + sessionID string, stdin io.Reader, stdout io.Writer, sendC chan any, @@ -986,12 +1017,66 @@ func NewSSHCLI( }, } cmd.AddCommand(NewPortalCommand(cfg, ptyInfo, channelInfo, sendC, activeProgram)) + cmd.AddCommand(a.NewLogoutCommand(cfg, sessionID)) + cmd.AddCommand(a.NewWhoamiCommand(cfg, sessionID)) + cmd.CompletionOptions.DisableDefaultCmd = true cmd.SetIn(stdin) cmd.SetOut(stdout) cmd.SetErr(stdout) return cmd } +func (a *Authorize) NewLogoutCommand( + cfg *config.Config, + sessionID string, +) *cobra.Command { + cmd := &cobra.Command{ + Use: "logout", + Short: "Log out", + RunE: func(cmd *cobra.Command, args []string) error { + client := a.state.Load().dataBrokerClient + err := session.Delete(cmd.Context(), client, sessionID) + if err != nil { + return fmt.Errorf("internal error: %w", err) + } + cmd.OutOrStdout().Write([]byte("Logged out successfully\r\n")) + return nil + }, + } + return cmd +} + +var whoamiTmpl = template.Must(template.New("whoami").Parse(` +User ID: {{.UserId}} +Session ID: {{.Id}} +Expires at: {{.ExpiresAt.AsTime}} +Claims: +{{- range $k, $v := .Claims }} + {{ $k }}: {{ $v.AsSlice }} +{{- end }} +`)) + +func (a *Authorize) NewWhoamiCommand( + cfg *config.Config, + sessionID string, +) *cobra.Command { + cmd := &cobra.Command{ + Use: "whoami", + RunE: func(cmd *cobra.Command, args []string) error { + client := a.state.Load().dataBrokerClient + s, err := session.Get(cmd.Context(), client, sessionID) + if err != nil { + return fmt.Errorf("couldn't fetch session: %w", err) + } + var b bytes.Buffer + whoamiTmpl.Execute(&b, s) + cmd.OutOrStdout().Write([]byte(b.String() + "\r\n")) + return nil + }, + } + return cmd +} + func NewPortalCommand( cfg *config.Config, ptyInfo *extensions_ssh.SSHDownstreamPTYInfo,