diff --git a/authorize/ssh_grpc.go b/authorize/ssh_grpc.go index d076eef59..41846d7ce 100644 --- a/authorize/ssh_grpc.go +++ b/authorize/ssh_grpc.go @@ -33,6 +33,7 @@ import ( "github.com/pomerium/pomerium/pkg/identity/manager" "github.com/pomerium/pomerium/pkg/identity/oauth" "github.com/pomerium/pomerium/pkg/storage" + "github.com/spf13/cobra" gossh "golang.org/x/crypto/ssh" "golang.org/x/oauth2" "golang.org/x/sync/errgroup" @@ -41,6 +42,7 @@ import ( "google.golang.org/grpc/status" "google.golang.org/protobuf/encoding/protodelim" "google.golang.org/protobuf/encoding/protojson" + "google.golang.org/protobuf/proto" "google.golang.org/protobuf/types/known/anypb" "google.golang.org/protobuf/types/known/emptypb" "google.golang.org/protobuf/types/known/timestamppb" @@ -188,17 +190,14 @@ func (a *Authorize) ManageStream( case *extensions_ssh.ClientMessage_Event: switch event := req.Event.Event.(type) { case *extensions_ssh.StreamEvent_DownstreamConnected: - fmt.Println("downstream connected") _ = event case *extensions_ssh.StreamEvent_UpstreamConnected: - fmt.Printf("upstream connected: %d\n", event.UpstreamConnected.GetStreamId()) activeStreamIds.Store(event.UpstreamConnected.GetStreamId(), state) defer activeStreamIds.Delete(event.UpstreamConnected.GetStreamId()) case nil: } case *extensions_ssh.ClientMessage_AuthRequest: authReq := req.AuthRequest - fmt.Println("auth request") if state.Username == "" { state.Username = authReq.Username } @@ -633,269 +632,191 @@ func (a *Authorize) startContinuousAuthorization( }() } -// See RFC 4254, section 5.1. -const msgChannelOpen = 90 - -type channelOpenMsg struct { - ChanType string `sshtype:"90"` - PeersID uint32 - PeersWindow uint32 - MaxPacketSize uint32 - TypeSpecificData []byte `ssh:"rest"` +func marshalAny(msg proto.Message) *anypb.Any { + a, err := anypb.New(msg) + if err != nil { + panic(err) + } + return a } -const ( - msgChannelExtendedData = 95 - msgChannelData = 94 -) - -// Used for debug print outs of packets. -type channelDataMsg struct { - PeersID uint32 `sshtype:"94"` - Length uint32 - Rest []byte `ssh:"rest"` -} - -// See RFC 4254, section 5.1. -const msgChannelOpenConfirm = 91 - -type channelOpenConfirmMsg struct { - PeersID uint32 `sshtype:"91"` - MyID uint32 - MyWindow uint32 - MaxPacketSize uint32 - TypeSpecificData []byte `ssh:"rest"` -} - -const msgChannelRequest = 98 - -type channelRequestMsg struct { - PeersID uint32 `sshtype:"98"` - Request string - WantReply bool - RequestSpecificData []byte `ssh:"rest"` -} - -type channelOpenDirectMsg struct { - DestAddr string - DestPort uint32 - SrcAddr string - SrcPort uint32 -} - -// See RFC 4254, section 5.4. -const msgChannelSuccess = 99 - -type channelRequestSuccessMsg struct { - PeersID uint32 `sshtype:"99"` -} - -// See RFC 4254, section 5.4. -const msgChannelFailure = 100 - -type channelRequestFailureMsg struct { - PeersID uint32 `sshtype:"100"` -} - -// See RFC 4254, section 5.3 -const msgChannelClose = 97 - -type channelCloseMsg struct { - PeersID uint32 `sshtype:"97"` -} - -// See RFC 4254, section 5.3 -const msgChannelEOF = 96 - -type channelEOFMsg struct { - PeersID uint32 `sshtype:"96"` -} +// sentinel error to indicate that the command triggered a handoff, and we +// should not automatically disconnect +var ErrHandoff = errors.New("handoff") func (a *Authorize) ServeChannel( server extensions_ssh.StreamManagement_ServeChannelServer, ) error { - var program *tea.Program + ctx := server.Context() inputR, inputW := io.Pipe() outputR, outputW := io.Pipe() var peerId uint32 + var activeProgram atomic.Pointer[tea.Program] + + errC := make(chan error, 1) + remoteWindow := &window{Cond: sync.NewCond(&sync.Mutex{})} + sendC := make(chan any, 8) + recvC := make(chan *extensions_ssh.ChannelMessage) + go func() { + for { + select { + case msg := <-sendC: + switch msg := msg.(type) { + case *extensions_ssh.ChannelControl: + log.Ctx(ctx).Debug().Msg("sending channel control message") + if err := server.Send(&extensions_ssh.ChannelMessage{ + Message: &extensions_ssh.ChannelMessage_ChannelControl{ + ChannelControl: msg, + }, + }); err != nil { + errC <- err + return + } + case windowAdjustMsg, channelRequestMsg, channelRequestSuccessMsg, channelRequestFailureMsg, channelEOFMsg: + // these messages don't consume window space + data := gossh.Marshal(msg) + if err := server.Send(&extensions_ssh.ChannelMessage{ + Message: &extensions_ssh.ChannelMessage_RawBytes{ + RawBytes: wrapperspb.Bytes(data), + }, + }); err != nil { + errC <- err + return + } + log.Ctx(ctx).Debug().Uint8("type", data[0]).Msg("message sent") + default: + data := gossh.Marshal(msg) + need := uint32(len(data)) + have := uint32(0) + for have < need { + n, err := remoteWindow.reserve(need - have) + if err != nil { + errC <- err + return + } + have += n + } + if err := server.Send(&extensions_ssh.ChannelMessage{ + Message: &extensions_ssh.ChannelMessage_RawBytes{ + RawBytes: wrapperspb.Bytes(data), + }, + }); err != nil { + errC <- err + return + } + log.Ctx(ctx).Debug().Uint8("type", data[0]).Uint32("size", need).Msg("message sent") + } + case <-ctx.Done(): + errC <- ctx.Err() + return + } + } + }() + go func() { + localWindow := uint32(channelWindowSize) + for { + channelMsg, err := server.Recv() + if err != nil { + if errors.Is(err, io.EOF) { + errC <- nil + return + } + errC <- err + return + } + if raw, ok := channelMsg.Message.(*extensions_ssh.ChannelMessage_RawBytes); ok { + msgLen := uint32(len(raw.RawBytes.GetValue())) + if msgLen == 0 { + errC <- status.Errorf(codes.InvalidArgument, "peer sent empty message") + return + } + if msgLen > channelMaxPacket { + errC <- status.Errorf(codes.ResourceExhausted, "message too large") + return + } + log.Ctx(ctx).Debug().Uint8("type", raw.RawBytes.Value[0]).Uint32("size", msgLen).Msg("message received") + // peek the first byte to check if we need to deduct from the window + switch raw.RawBytes.Value[0] { + case msgChannelWindowAdjust, msgChannelRequest, msgChannelSuccess, msgChannelFailure, msgChannelEOF: + // these messages don't consume window space + default: + if localWindow < msgLen { + errC <- status.Errorf(codes.ResourceExhausted, "peer sent more bytes than allowed by channel window") + return + } + localWindow -= msgLen + if localWindow < channelWindowSize/2 { + log.Ctx(ctx).Debug().Msg("flow control: increasing local window size") + localWindow += channelWindowSize + sendC <- windowAdjustMsg{ + PeersID: peerId, + AdditionalBytes: channelWindowSize, + } + } + } + } + + select { + case recvC <- channelMsg: + case <-ctx.Done(): + errC <- ctx.Err() + return + } + } + }() var downstreamChannelInfo *extensions_ssh.SSHDownstreamChannelInfo var downstreamPtyInfo *extensions_ssh.SSHDownstreamPTYInfo + var channelIdCounter uint32 for { - channelMsg, err := server.Recv() - if err != nil { - if errors.Is(err, io.EOF) { - return nil - } - return err - } - rawMsg := channelMsg.GetRawBytes().GetValue() - switch rawMsg[0] { - case msgChannelOpen: - var msg channelOpenMsg - gossh.Unmarshal(rawMsg, &msg) - - var confirm channelOpenConfirmMsg - peerId = msg.PeersID - confirm.PeersID = peerId - confirm.MyID = 1 - confirm.MyWindow = msg.PeersWindow - confirm.MaxPacketSize = msg.MaxPacketSize - downstreamChannelInfo = &extensions_ssh.SSHDownstreamChannelInfo{ - ChannelType: msg.ChanType, - DownstreamChannelId: confirm.PeersID, - InternalUpstreamChannelId: confirm.MyID, - InitialWindowSize: confirm.MyWindow, - MaxPacketSize: confirm.MaxPacketSize, - } - switch msg.ChanType { - case "session": - if err := server.Send(&extensions_ssh.ChannelMessage{ - Message: &extensions_ssh.ChannelMessage_RawBytes{ - RawBytes: &wrapperspb.BytesValue{ - Value: gossh.Marshal(confirm), - }, - }, - }); err != nil { - return err + select { + case channelMsg := <-recvC: + rawMsg := channelMsg.GetRawBytes().GetValue() + switch rawMsg[0] { + case msgChannelOpen: + var msg channelOpenMsg + gossh.Unmarshal(rawMsg, &msg) + channelIdCounter++ + if channelIdCounter > 1 { + return fmt.Errorf("only one channel can be opened") } - case "direct-tcpip": - var subMsg channelOpenDirectMsg - if err := gossh.Unmarshal(msg.TypeSpecificData, &subMsg); err != nil { - return err + peerId = msg.PeersID + downstreamChannelInfo = &extensions_ssh.SSHDownstreamChannelInfo{ + ChannelType: msg.ChanType, + DownstreamChannelId: peerId, + InternalUpstreamChannelId: channelIdCounter, + InitialWindowSize: msg.PeersWindow, + MaxPacketSize: msg.MaxPacketSize, } - handOff, _ := anypb.New(&extensions_ssh.SSHChannelControlAction{ - Action: &extensions_ssh.SSHChannelControlAction_HandOff{ - HandOff: &extensions_ssh.SSHChannelControlAction_HandOffUpstream{ - DownstreamChannelInfo: downstreamChannelInfo, - UpstreamAuth: &extensions_ssh.AllowResponse{ - Target: &extensions_ssh.AllowResponse_Upstream{ - Upstream: &extensions_ssh.UpstreamTarget{ - Hostname: subMsg.DestAddr, - DirectTcpip: true, - }, - }, - }, - }, - }, - }) - if err := server.Send(&extensions_ssh.ChannelMessage{ - Message: &extensions_ssh.ChannelMessage_ChannelControl{ - ChannelControl: &extensions_ssh.ChannelControl{ - Protocol: "ssh", - ControlAction: handOff, - }, - }, - }); err != nil { - return err - } - } - - case msgChannelRequest: - var msg channelRequestMsg - gossh.Unmarshal(rawMsg, &msg) - - switch msg.Request { - case "pty-req": - opts := a.currentConfig.Load().Options - var routes []string - for r := range opts.GetAllPolicies() { - if strings.HasPrefix(r.From, "ssh://") { - routes = append(routes, fmt.Sprintf("ubuntu@%s", strings.TrimSuffix(strings.TrimPrefix(r.From, "ssh://"), "."+opts.SSHHostname))) + remoteWindow.add(msg.PeersWindow) + switch msg.ChanType { + case "session": + sendC <- channelOpenConfirmMsg{ + PeersID: peerId, + MyID: channelIdCounter, + MyWindow: channelWindowSize, + MaxPacketSize: channelMaxPacket, } - } - req := parsePtyReq(msg.RequestSpecificData) - items := []list.Item{} - for _, route := range routes { - items = append(items, item(route)) - } - activeStreamIds.Range(func(key, value any) bool { - items = append(items, item(fmt.Sprintf("[demo] mirror session: %v", key))) - return true - }) - downstreamPtyInfo = &extensions_ssh.SSHDownstreamPTYInfo{ - TermEnv: req.TermEnv, - WidthColumns: req.Width, - HeightRows: req.Height, - WidthPx: req.WidthPx, - HeightPx: req.HeightPx, - Modes: req.Modes, - } - - const defaultWidth = 20 - - l := list.New(items, itemDelegate{}, defaultWidth, listHeight) - l.Title = "Connect to which server?" - l.SetShowStatusBar(false) - l.SetFilteringEnabled(false) - l.Styles.Title = titleStyle - l.Styles.PaginationStyle = paginationStyle - l.Styles.HelpStyle = helpStyle - - program = tea.NewProgram(model{list: l}, - tea.WithInput(inputR), - tea.WithOutput(outputW), - tea.WithAltScreen(), - tea.WithContext(server.Context()), - tea.WithEnvironment([]string{"TERM=" + req.TermEnv}), - ) - go func() { - answer, err := program.Run() - if err != nil { - return + case "direct-tcpip": + var subMsg channelOpenDirectMsg + if err := gossh.Unmarshal(msg.TypeSpecificData, &subMsg); err != nil { + return err } - var handOff *anypb.Any - if strings.HasPrefix(answer.(model).choice, "[demo] mirror session: ") { - id, err := strconv.ParseUint(strings.TrimPrefix(answer.(model).choice, "[demo] mirror session: "), 10, 64) - if err != nil { - panic(err) - } - handOff, _ = anypb.New(&extensions_ssh.SSHChannelControlAction{ - Action: &extensions_ssh.SSHChannelControlAction_HandOff{ - HandOff: &extensions_ssh.SSHChannelControlAction_HandOffUpstream{ - DownstreamChannelInfo: downstreamChannelInfo, - DownstreamPtyInfo: downstreamPtyInfo, - UpstreamAuth: &extensions_ssh.AllowResponse{ - Target: &extensions_ssh.AllowResponse_MirrorSession{ - MirrorSession: &extensions_ssh.MirrorSessionTarget{ - SourceId: id, - Mode: extensions_ssh.MirrorSessionTarget_ReadWrite, - }, + handOff, _ := anypb.New(&extensions_ssh.SSHChannelControlAction{ + Action: &extensions_ssh.SSHChannelControlAction_HandOff{ + HandOff: &extensions_ssh.SSHChannelControlAction_HandOffUpstream{ + DownstreamChannelInfo: downstreamChannelInfo, + UpstreamAuth: &extensions_ssh.AllowResponse{ + Target: &extensions_ssh.AllowResponse_Upstream{ + Upstream: &extensions_ssh.UpstreamTarget{ + Hostname: subMsg.DestAddr, + DirectTcpip: true, }, }, }, }, - }) - } else { - username, hostname, _ := strings.Cut(answer.(model).choice, "@") - sessionRecordingExt, _ := anypb.New(&extensions_session_recording.UpstreamTargetExtensionConfig{ - RecordingName: fmt.Sprintf("session-%s-at-%s-%d.cast", username, hostname, time.Now().UnixNano()), - Format: extensions_session_recording.Format_AsciicastFormat, - }) - handOff, _ = anypb.New(&extensions_ssh.SSHChannelControlAction{ - Action: &extensions_ssh.SSHChannelControlAction_HandOff{ - HandOff: &extensions_ssh.SSHChannelControlAction_HandOffUpstream{ - DownstreamChannelInfo: downstreamChannelInfo, - DownstreamPtyInfo: downstreamPtyInfo, - UpstreamAuth: &extensions_ssh.AllowResponse{ - Username: username, - Target: &extensions_ssh.AllowResponse_Upstream{ - Upstream: &extensions_ssh.UpstreamTarget{ - AllowMirrorConnections: true, - Hostname: hostname, - Extensions: []*corev3.TypedExtensionConfig{ - { - TypedConfig: sessionRecordingExt, - }, - }, - }, - }, - }, - }, - }, - }) - } - + }, + }) if err := server.Send(&extensions_ssh.ChannelMessage{ Message: &extensions_ssh.ChannelMessage_ChannelControl{ ChannelControl: &extensions_ssh.ChannelControl{ @@ -904,57 +825,101 @@ func (a *Authorize) ServeChannel( }, }, }); err != nil { - return + return err } - }() - go func() { - var buf [4096]byte - for { - n, err := outputR.Read(buf[:]) - if err != nil { - return - } - if err := server.Send(&extensions_ssh.ChannelMessage{ - Message: &extensions_ssh.ChannelMessage_RawBytes{ - RawBytes: &wrapperspb.BytesValue{ - Value: gossh.Marshal(channelDataMsg{ - PeersID: peerId, - Length: uint32(n), - Rest: buf[:n], - }), - }, - }, - }); err != nil { - return - } - } - }() - program.Send(tea.WindowSizeMsg{Width: int(req.Width), Height: int(req.Height)}) + } - if err := server.Send(&extensions_ssh.ChannelMessage{ - Message: &extensions_ssh.ChannelMessage_RawBytes{ - RawBytes: &wrapperspb.BytesValue{ - Value: gossh.Marshal(channelRequestSuccessMsg{ - PeersID: peerId, - }), + case msgChannelRequest: + var msg channelRequestMsg + gossh.Unmarshal(rawMsg, &msg) + + switch msg.Request { + case "shell", "exec": + if err := server.Send(&extensions_ssh.ChannelMessage{ + Message: &extensions_ssh.ChannelMessage_RawBytes{ + RawBytes: &wrapperspb.BytesValue{ + Value: gossh.Marshal(channelRequestSuccessMsg{ + PeersID: peerId, + }), + }, }, - }, - }); err != nil { + }); err != nil { + return err + } + + cmd := NewSSHCLI(a.currentConfig.Load(), downstreamPtyInfo, downstreamChannelInfo, inputR, outputW, sendC, &activeProgram) + if msg.Request == "shell" { + cmd.SetArgs([]string{"portal"}) + } else { + var execReq execChannelRequestMsg + if err := gossh.Unmarshal(msg.RequestSpecificData, &execReq); err != nil { + return err + } + cmd.SetArgs(strings.Fields(execReq.Command)) + } + go func() { + defer activeProgram.Store(nil) + defer outputW.Close() + defer inputR.Close() + err := cmd.Execute() + if !errors.Is(err, ErrHandoff) { + sendC <- &extensions_ssh.ChannelControl{ + Protocol: "ssh", + ControlAction: marshalAny(&extensions_ssh.SSHChannelControlAction_Disconnect{ + ReasonCode: 11, + }), + } + } + }() + go streamOutputToChannel(sendC, peerId, outputR) + + case "pty-req": + req := parsePtyReq(msg.RequestSpecificData) + downstreamPtyInfo = &extensions_ssh.SSHDownstreamPTYInfo{ + TermEnv: req.TermEnv, + WidthColumns: req.Width, + HeightRows: req.Height, + WidthPx: req.WidthPx, + HeightPx: req.HeightPx, + Modes: req.Modes, + } + sendC <- channelRequestSuccessMsg{PeersID: peerId} + case "window-change": + var req channelWindowChangeRequestMsg + if err := gossh.Unmarshal(msg.RequestSpecificData, &req); err != nil { + return err + } + if p := activeProgram.Load(); p != nil { + p.Send(tea.WindowSizeMsg{ + Width: int(req.WidthColumns), + Height: int(req.HeightRows), + }) + } + } + case msgChannelData: + var msg channelDataMsg + gossh.Unmarshal(rawMsg, &msg) + if activeProgram.Load() != nil { + inputW.Write(msg.Rest) + } + case msgChannelClose: + var msg channelDataMsg + gossh.Unmarshal(rawMsg, &msg) + case msgChannelWindowAdjust: + var msg windowAdjustMsg + if err := gossh.Unmarshal(rawMsg, &msg); err != nil { return err } + log.Ctx(ctx).Debug().Uint32("bytes", msg.AdditionalBytes).Msg("flow control: remote window size increased") + remoteWindow.add(msg.AdditionalBytes) + case msgChannelEOF: + return nil + default: + panic("unhandled message: " + fmt.Sprint(rawMsg[1])) } - case msgChannelData: - var msg channelDataMsg - gossh.Unmarshal(rawMsg, &msg) - - if program != nil { - inputW.Write(msg.Rest) - } - case msgChannelClose: - var msg channelDataMsg - gossh.Unmarshal(rawMsg, &msg) - default: - panic("unhandled message: " + fmt.Sprint(rawMsg[1])) + case err := <-errC: + log.Ctx(ctx).Err(err).Msg("channel error") + return err } } } @@ -981,7 +946,164 @@ func parsePtyReq(reqData []byte) ptyReq { } } -const listHeight = 14 +func streamOutputToChannel(sendC chan<- any, channelID uint32, outputR io.Reader) { + var buf [4096]byte + for { + n, err := outputR.Read(buf[:]) + if err != nil { + return + } + sendC <- channelDataMsg{ + PeersID: channelID, + Length: uint32(n), + Rest: buf[:n], + } + } +} + +func NewSSHCLI( + cfg *config.Config, + ptyInfo *extensions_ssh.SSHDownstreamPTYInfo, + channelInfo *extensions_ssh.SSHDownstreamChannelInfo, + stdin io.Reader, + stdout io.Writer, + sendC chan any, + activeProgram *atomic.Pointer[tea.Program], +) *cobra.Command { + cmd := &cobra.Command{ + Use: "pomerium", + PersistentPreRunE: func(cmd *cobra.Command, args []string) error { + _, cmdIsInteractive := cmd.Annotations["interactive"] + switch { + case (ptyInfo == nil) && cmdIsInteractive: + cmd.SilenceUsage = true + return fmt.Errorf("\x1b[31m'%s' is an interactive command and requires a TTY (try passing '-t' to ssh)\x1b[0m", cmd.Use) + case (ptyInfo != nil) && !cmdIsInteractive: + cmd.SilenceUsage = true + return fmt.Errorf("\x1b[31m'%s' is not an interactive command (try passing '-T' to ssh, or removing '-t')\x1b[0m\r", cmd.Use) + } + return nil + }, + } + cmd.AddCommand(NewPortalCommand(cfg, ptyInfo, channelInfo, sendC, activeProgram)) + cmd.SetIn(stdin) + cmd.SetOut(stdout) + cmd.SetErr(stdout) + return cmd +} + +func NewPortalCommand( + cfg *config.Config, + ptyInfo *extensions_ssh.SSHDownstreamPTYInfo, + channelInfo *extensions_ssh.SSHDownstreamChannelInfo, + sendC chan any, + activeProgram *atomic.Pointer[tea.Program], +) *cobra.Command { + cmd := &cobra.Command{ + Use: "portal", + Short: "Interactive route portal", + Annotations: map[string]string{ + "interactive": "", + }, + RunE: func(cmd *cobra.Command, args []string) error { + var routes []string + for r := range cfg.Options.GetAllPolicies() { + if strings.HasPrefix(r.From, "ssh://") { + routes = append(routes, fmt.Sprintf("ubuntu@%s", strings.TrimSuffix(strings.TrimPrefix(r.From, "ssh://"), "."+cfg.Options.SSHHostname))) + } + } + items := []list.Item{} + for _, route := range routes { + items = append(items, item(route)) + } + activeStreamIds.Range(func(key, value any) bool { + items = append(items, item(fmt.Sprintf("[demo] mirror session: %v", key))) + return true + }) + + l := list.New(items, itemDelegate{}, int(ptyInfo.WidthColumns-2), int(ptyInfo.HeightRows-2)) + l.Title = "Connect to which server?" + l.SetShowStatusBar(false) + l.SetFilteringEnabled(false) + l.Styles.Title = titleStyle + l.Styles.PaginationStyle = paginationStyle + l.Styles.HelpStyle = helpStyle + + program := tea.NewProgram(model{list: l}, + tea.WithInput(cmd.InOrStdin()), + tea.WithOutput(cmd.OutOrStdout()), + tea.WithAltScreen(), + tea.WithContext(cmd.Context()), + tea.WithEnvironment([]string{"TERM=" + ptyInfo.TermEnv}), + ) + activeProgram.Store(program) + + go program.Send(tea.WindowSizeMsg{Width: int(ptyInfo.WidthColumns), Height: int(ptyInfo.HeightRows)}) + answer, err := program.Run() + if err != nil { + return err + } + var handOff *anypb.Any + if strings.HasPrefix(answer.(model).choice, "[demo] mirror session: ") { + id, err := strconv.ParseUint(strings.TrimPrefix(answer.(model).choice, "[demo] mirror session: "), 10, 64) + if err != nil { + panic(err) + } + handOff = marshalAny(&extensions_ssh.SSHChannelControlAction{ + Action: &extensions_ssh.SSHChannelControlAction_HandOff{ + HandOff: &extensions_ssh.SSHChannelControlAction_HandOffUpstream{ + DownstreamChannelInfo: channelInfo, + DownstreamPtyInfo: ptyInfo, + UpstreamAuth: &extensions_ssh.AllowResponse{ + Target: &extensions_ssh.AllowResponse_MirrorSession{ + MirrorSession: &extensions_ssh.MirrorSessionTarget{ + SourceId: id, + Mode: extensions_ssh.MirrorSessionTarget_ReadWrite, + }, + }, + }, + }, + }, + }) + } else { + username, hostname, _ := strings.Cut(answer.(model).choice, "@") + sessionRecordingExt, _ := anypb.New(&extensions_session_recording.UpstreamTargetExtensionConfig{ + RecordingName: fmt.Sprintf("session-%s-at-%s-%d.cast", username, hostname, time.Now().UnixNano()), + Format: extensions_session_recording.Format_AsciicastFormat, + }) + handOff = marshalAny(&extensions_ssh.SSHChannelControlAction{ + Action: &extensions_ssh.SSHChannelControlAction_HandOff{ + HandOff: &extensions_ssh.SSHChannelControlAction_HandOffUpstream{ + DownstreamChannelInfo: channelInfo, + DownstreamPtyInfo: ptyInfo, + UpstreamAuth: &extensions_ssh.AllowResponse{ + Username: username, + Target: &extensions_ssh.AllowResponse_Upstream{ + Upstream: &extensions_ssh.UpstreamTarget{ + AllowMirrorConnections: true, + Hostname: hostname, + Extensions: []*corev3.TypedExtensionConfig{ + { + TypedConfig: sessionRecordingExt, + }, + }, + }, + }, + }, + }, + }, + }) + } + + sendC <- &extensions_ssh.ChannelControl{ + Protocol: "ssh", + ControlAction: handOff, + } + return ErrHandoff + }, + } + return cmd +} var ( titleStyle = lipgloss.NewStyle().MarginLeft(2) @@ -1032,7 +1154,8 @@ func (m model) Init() tea.Cmd { func (m model) Update(msg tea.Msg) (tea.Model, tea.Cmd) { switch msg := msg.(type) { case tea.WindowSizeMsg: - m.list.SetWidth(msg.Width) + m.list.SetWidth(msg.Width - 2) + m.list.SetHeight(msg.Height - 2) return m, nil case tea.KeyMsg: @@ -1058,3 +1181,189 @@ func (m model) Update(msg tea.Msg) (tea.Model, tea.Cmd) { func (m model) View() string { return "\n" + m.list.View() } + +// code below copied from x/crypto/ssh/common.go + +const ( + // channelMaxPacket contains the maximum number of bytes that will be + // sent in a single packet. As per RFC 4253, section 6.1, 32k is also + // the minimum. + channelMaxPacket = 1 << 15 + // We follow OpenSSH here. + channelWindowSize = 64 * channelMaxPacket +) + +// window represents the buffer available to clients +// wishing to write to a channel. +type window struct { + *sync.Cond + win uint32 // RFC 4254 5.2 says the window size can grow to 2^32-1 + writeWaiters int + closed bool +} + +// add adds win to the amount of window available +// for consumers. +func (w *window) add(win uint32) bool { + // a zero sized window adjust is a noop. + if win == 0 { + return true + } + w.L.Lock() + if w.win+win < win { + w.L.Unlock() + return false + } + w.win += win + // It is unusual that multiple goroutines would be attempting to reserve + // window space, but not guaranteed. Use broadcast to notify all waiters + // that additional window is available. + w.Broadcast() + w.L.Unlock() + return true +} + +// close sets the window to closed, so all reservations fail +// immediately. +func (w *window) close() { + w.L.Lock() + w.closed = true + w.Broadcast() + w.L.Unlock() +} + +// reserve reserves win from the available window capacity. +// If no capacity remains, reserve will block. reserve may +// return less than requested. +func (w *window) reserve(win uint32) (uint32, error) { + var err error + w.L.Lock() + w.writeWaiters++ + w.Broadcast() + for w.win == 0 && !w.closed { + w.Wait() + } + w.writeWaiters-- + if w.win < win { + win = w.win + } + w.win -= win + if w.closed { + err = io.EOF + } + w.L.Unlock() + return win, err +} + +// code below copied from x/crypto/ssh/messages.go +// (with some additional messages not included there) + +// See RFC 4254, section 5.1. +const msgChannelOpen = 90 + +type channelOpenMsg struct { + ChanType string `sshtype:"90"` + PeersID uint32 + PeersWindow uint32 + MaxPacketSize uint32 + TypeSpecificData []byte `ssh:"rest"` +} + +const ( + msgChannelExtendedData = 95 + msgChannelData = 94 +) + +// See RFC 4253, section 11.1. +const msgDisconnect = 1 + +// disconnectMsg is the message that signals a disconnect. It is also +// the error type returned from mux.Wait() +type disconnectMsg struct { + Reason uint32 `sshtype:"1"` + Message string + Language string +} + +// Used for debug print outs of packets. +type channelDataMsg struct { + PeersID uint32 `sshtype:"94"` + Length uint32 + Rest []byte `ssh:"rest"` +} + +// See RFC 4254, section 5.1. +const msgChannelOpenConfirm = 91 + +type channelOpenConfirmMsg struct { + PeersID uint32 `sshtype:"91"` + MyID uint32 + MyWindow uint32 + MaxPacketSize uint32 + TypeSpecificData []byte `ssh:"rest"` +} + +const msgChannelRequest = 98 + +type channelRequestMsg struct { + PeersID uint32 `sshtype:"98"` + Request string + WantReply bool + RequestSpecificData []byte `ssh:"rest"` +} + +type channelOpenDirectMsg struct { + DestAddr string + DestPort uint32 + SrcAddr string + SrcPort uint32 +} + +type channelWindowChangeRequestMsg struct { + WidthColumns uint32 + HeightRows uint32 + WidthPx uint32 + HeightPx uint32 +} + +type shellChannelRequestMsg struct{} + +type execChannelRequestMsg struct { + Command string +} + +// See RFC 4254, section 5.2 +const msgChannelWindowAdjust = 93 + +type windowAdjustMsg struct { + PeersID uint32 `sshtype:"93"` + AdditionalBytes uint32 +} + +// See RFC 4254, section 5.4. +const msgChannelSuccess = 99 + +type channelRequestSuccessMsg struct { + PeersID uint32 `sshtype:"99"` +} + +// See RFC 4254, section 5.4. +const msgChannelFailure = 100 + +type channelRequestFailureMsg struct { + PeersID uint32 `sshtype:"100"` +} + +// See RFC 4254, section 5.3 +const msgChannelClose = 97 + +type channelCloseMsg struct { + PeersID uint32 `sshtype:"97"` +} + +// See RFC 4254, section 5.3 +const msgChannelEOF = 96 + +type channelEOFMsg struct { + PeersID uint32 `sshtype:"96"` +}