This commit is contained in:
Joe Kralicky 2025-03-21 15:08:08 +00:00
parent 08252f32df
commit 5e06f2aef9
No known key found for this signature in database
GPG key ID: 75C4875F34A9FB79

View file

@ -33,6 +33,7 @@ import (
"github.com/pomerium/pomerium/pkg/identity/manager" "github.com/pomerium/pomerium/pkg/identity/manager"
"github.com/pomerium/pomerium/pkg/identity/oauth" "github.com/pomerium/pomerium/pkg/identity/oauth"
"github.com/pomerium/pomerium/pkg/storage" "github.com/pomerium/pomerium/pkg/storage"
"github.com/spf13/cobra"
gossh "golang.org/x/crypto/ssh" gossh "golang.org/x/crypto/ssh"
"golang.org/x/oauth2" "golang.org/x/oauth2"
"golang.org/x/sync/errgroup" "golang.org/x/sync/errgroup"
@ -41,6 +42,7 @@ import (
"google.golang.org/grpc/status" "google.golang.org/grpc/status"
"google.golang.org/protobuf/encoding/protodelim" "google.golang.org/protobuf/encoding/protodelim"
"google.golang.org/protobuf/encoding/protojson" "google.golang.org/protobuf/encoding/protojson"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/known/anypb" "google.golang.org/protobuf/types/known/anypb"
"google.golang.org/protobuf/types/known/emptypb" "google.golang.org/protobuf/types/known/emptypb"
"google.golang.org/protobuf/types/known/timestamppb" "google.golang.org/protobuf/types/known/timestamppb"
@ -188,17 +190,14 @@ func (a *Authorize) ManageStream(
case *extensions_ssh.ClientMessage_Event: case *extensions_ssh.ClientMessage_Event:
switch event := req.Event.Event.(type) { switch event := req.Event.Event.(type) {
case *extensions_ssh.StreamEvent_DownstreamConnected: case *extensions_ssh.StreamEvent_DownstreamConnected:
fmt.Println("downstream connected")
_ = event _ = event
case *extensions_ssh.StreamEvent_UpstreamConnected: case *extensions_ssh.StreamEvent_UpstreamConnected:
fmt.Printf("upstream connected: %d\n", event.UpstreamConnected.GetStreamId())
activeStreamIds.Store(event.UpstreamConnected.GetStreamId(), state) activeStreamIds.Store(event.UpstreamConnected.GetStreamId(), state)
defer activeStreamIds.Delete(event.UpstreamConnected.GetStreamId()) defer activeStreamIds.Delete(event.UpstreamConnected.GetStreamId())
case nil: case nil:
} }
case *extensions_ssh.ClientMessage_AuthRequest: case *extensions_ssh.ClientMessage_AuthRequest:
authReq := req.AuthRequest authReq := req.AuthRequest
fmt.Println("auth request")
if state.Username == "" { if state.Username == "" {
state.Username = authReq.Username state.Username = authReq.Username
} }
@ -633,269 +632,191 @@ func (a *Authorize) startContinuousAuthorization(
}() }()
} }
// See RFC 4254, section 5.1. func marshalAny(msg proto.Message) *anypb.Any {
const msgChannelOpen = 90 a, err := anypb.New(msg)
if err != nil {
type channelOpenMsg struct { panic(err)
ChanType string `sshtype:"90"` }
PeersID uint32 return a
PeersWindow uint32
MaxPacketSize uint32
TypeSpecificData []byte `ssh:"rest"`
} }
const ( // sentinel error to indicate that the command triggered a handoff, and we
msgChannelExtendedData = 95 // should not automatically disconnect
msgChannelData = 94 var ErrHandoff = errors.New("handoff")
)
// Used for debug print outs of packets.
type channelDataMsg struct {
PeersID uint32 `sshtype:"94"`
Length uint32
Rest []byte `ssh:"rest"`
}
// See RFC 4254, section 5.1.
const msgChannelOpenConfirm = 91
type channelOpenConfirmMsg struct {
PeersID uint32 `sshtype:"91"`
MyID uint32
MyWindow uint32
MaxPacketSize uint32
TypeSpecificData []byte `ssh:"rest"`
}
const msgChannelRequest = 98
type channelRequestMsg struct {
PeersID uint32 `sshtype:"98"`
Request string
WantReply bool
RequestSpecificData []byte `ssh:"rest"`
}
type channelOpenDirectMsg struct {
DestAddr string
DestPort uint32
SrcAddr string
SrcPort uint32
}
// See RFC 4254, section 5.4.
const msgChannelSuccess = 99
type channelRequestSuccessMsg struct {
PeersID uint32 `sshtype:"99"`
}
// See RFC 4254, section 5.4.
const msgChannelFailure = 100
type channelRequestFailureMsg struct {
PeersID uint32 `sshtype:"100"`
}
// See RFC 4254, section 5.3
const msgChannelClose = 97
type channelCloseMsg struct {
PeersID uint32 `sshtype:"97"`
}
// See RFC 4254, section 5.3
const msgChannelEOF = 96
type channelEOFMsg struct {
PeersID uint32 `sshtype:"96"`
}
func (a *Authorize) ServeChannel( func (a *Authorize) ServeChannel(
server extensions_ssh.StreamManagement_ServeChannelServer, server extensions_ssh.StreamManagement_ServeChannelServer,
) error { ) error {
var program *tea.Program ctx := server.Context()
inputR, inputW := io.Pipe() inputR, inputW := io.Pipe()
outputR, outputW := io.Pipe() outputR, outputW := io.Pipe()
var peerId uint32 var peerId uint32
var activeProgram atomic.Pointer[tea.Program]
errC := make(chan error, 1)
remoteWindow := &window{Cond: sync.NewCond(&sync.Mutex{})}
sendC := make(chan any, 8)
recvC := make(chan *extensions_ssh.ChannelMessage)
go func() {
for {
select {
case msg := <-sendC:
switch msg := msg.(type) {
case *extensions_ssh.ChannelControl:
log.Ctx(ctx).Debug().Msg("sending channel control message")
if err := server.Send(&extensions_ssh.ChannelMessage{
Message: &extensions_ssh.ChannelMessage_ChannelControl{
ChannelControl: msg,
},
}); err != nil {
errC <- err
return
}
case windowAdjustMsg, channelRequestMsg, channelRequestSuccessMsg, channelRequestFailureMsg, channelEOFMsg:
// these messages don't consume window space
data := gossh.Marshal(msg)
if err := server.Send(&extensions_ssh.ChannelMessage{
Message: &extensions_ssh.ChannelMessage_RawBytes{
RawBytes: wrapperspb.Bytes(data),
},
}); err != nil {
errC <- err
return
}
log.Ctx(ctx).Debug().Uint8("type", data[0]).Msg("message sent")
default:
data := gossh.Marshal(msg)
need := uint32(len(data))
have := uint32(0)
for have < need {
n, err := remoteWindow.reserve(need - have)
if err != nil {
errC <- err
return
}
have += n
}
if err := server.Send(&extensions_ssh.ChannelMessage{
Message: &extensions_ssh.ChannelMessage_RawBytes{
RawBytes: wrapperspb.Bytes(data),
},
}); err != nil {
errC <- err
return
}
log.Ctx(ctx).Debug().Uint8("type", data[0]).Uint32("size", need).Msg("message sent")
}
case <-ctx.Done():
errC <- ctx.Err()
return
}
}
}()
go func() {
localWindow := uint32(channelWindowSize)
for {
channelMsg, err := server.Recv()
if err != nil {
if errors.Is(err, io.EOF) {
errC <- nil
return
}
errC <- err
return
}
if raw, ok := channelMsg.Message.(*extensions_ssh.ChannelMessage_RawBytes); ok {
msgLen := uint32(len(raw.RawBytes.GetValue()))
if msgLen == 0 {
errC <- status.Errorf(codes.InvalidArgument, "peer sent empty message")
return
}
if msgLen > channelMaxPacket {
errC <- status.Errorf(codes.ResourceExhausted, "message too large")
return
}
log.Ctx(ctx).Debug().Uint8("type", raw.RawBytes.Value[0]).Uint32("size", msgLen).Msg("message received")
// peek the first byte to check if we need to deduct from the window
switch raw.RawBytes.Value[0] {
case msgChannelWindowAdjust, msgChannelRequest, msgChannelSuccess, msgChannelFailure, msgChannelEOF:
// these messages don't consume window space
default:
if localWindow < msgLen {
errC <- status.Errorf(codes.ResourceExhausted, "peer sent more bytes than allowed by channel window")
return
}
localWindow -= msgLen
if localWindow < channelWindowSize/2 {
log.Ctx(ctx).Debug().Msg("flow control: increasing local window size")
localWindow += channelWindowSize
sendC <- windowAdjustMsg{
PeersID: peerId,
AdditionalBytes: channelWindowSize,
}
}
}
}
select {
case recvC <- channelMsg:
case <-ctx.Done():
errC <- ctx.Err()
return
}
}
}()
var downstreamChannelInfo *extensions_ssh.SSHDownstreamChannelInfo var downstreamChannelInfo *extensions_ssh.SSHDownstreamChannelInfo
var downstreamPtyInfo *extensions_ssh.SSHDownstreamPTYInfo var downstreamPtyInfo *extensions_ssh.SSHDownstreamPTYInfo
var channelIdCounter uint32
for { for {
channelMsg, err := server.Recv() select {
if err != nil { case channelMsg := <-recvC:
if errors.Is(err, io.EOF) { rawMsg := channelMsg.GetRawBytes().GetValue()
return nil switch rawMsg[0] {
} case msgChannelOpen:
return err var msg channelOpenMsg
} gossh.Unmarshal(rawMsg, &msg)
rawMsg := channelMsg.GetRawBytes().GetValue() channelIdCounter++
switch rawMsg[0] { if channelIdCounter > 1 {
case msgChannelOpen: return fmt.Errorf("only one channel can be opened")
var msg channelOpenMsg
gossh.Unmarshal(rawMsg, &msg)
var confirm channelOpenConfirmMsg
peerId = msg.PeersID
confirm.PeersID = peerId
confirm.MyID = 1
confirm.MyWindow = msg.PeersWindow
confirm.MaxPacketSize = msg.MaxPacketSize
downstreamChannelInfo = &extensions_ssh.SSHDownstreamChannelInfo{
ChannelType: msg.ChanType,
DownstreamChannelId: confirm.PeersID,
InternalUpstreamChannelId: confirm.MyID,
InitialWindowSize: confirm.MyWindow,
MaxPacketSize: confirm.MaxPacketSize,
}
switch msg.ChanType {
case "session":
if err := server.Send(&extensions_ssh.ChannelMessage{
Message: &extensions_ssh.ChannelMessage_RawBytes{
RawBytes: &wrapperspb.BytesValue{
Value: gossh.Marshal(confirm),
},
},
}); err != nil {
return err
} }
case "direct-tcpip": peerId = msg.PeersID
var subMsg channelOpenDirectMsg downstreamChannelInfo = &extensions_ssh.SSHDownstreamChannelInfo{
if err := gossh.Unmarshal(msg.TypeSpecificData, &subMsg); err != nil { ChannelType: msg.ChanType,
return err DownstreamChannelId: peerId,
InternalUpstreamChannelId: channelIdCounter,
InitialWindowSize: msg.PeersWindow,
MaxPacketSize: msg.MaxPacketSize,
} }
handOff, _ := anypb.New(&extensions_ssh.SSHChannelControlAction{ remoteWindow.add(msg.PeersWindow)
Action: &extensions_ssh.SSHChannelControlAction_HandOff{ switch msg.ChanType {
HandOff: &extensions_ssh.SSHChannelControlAction_HandOffUpstream{ case "session":
DownstreamChannelInfo: downstreamChannelInfo, sendC <- channelOpenConfirmMsg{
UpstreamAuth: &extensions_ssh.AllowResponse{ PeersID: peerId,
Target: &extensions_ssh.AllowResponse_Upstream{ MyID: channelIdCounter,
Upstream: &extensions_ssh.UpstreamTarget{ MyWindow: channelWindowSize,
Hostname: subMsg.DestAddr, MaxPacketSize: channelMaxPacket,
DirectTcpip: true,
},
},
},
},
},
})
if err := server.Send(&extensions_ssh.ChannelMessage{
Message: &extensions_ssh.ChannelMessage_ChannelControl{
ChannelControl: &extensions_ssh.ChannelControl{
Protocol: "ssh",
ControlAction: handOff,
},
},
}); err != nil {
return err
}
}
case msgChannelRequest:
var msg channelRequestMsg
gossh.Unmarshal(rawMsg, &msg)
switch msg.Request {
case "pty-req":
opts := a.currentConfig.Load().Options
var routes []string
for r := range opts.GetAllPolicies() {
if strings.HasPrefix(r.From, "ssh://") {
routes = append(routes, fmt.Sprintf("ubuntu@%s", strings.TrimSuffix(strings.TrimPrefix(r.From, "ssh://"), "."+opts.SSHHostname)))
} }
} case "direct-tcpip":
req := parsePtyReq(msg.RequestSpecificData) var subMsg channelOpenDirectMsg
items := []list.Item{} if err := gossh.Unmarshal(msg.TypeSpecificData, &subMsg); err != nil {
for _, route := range routes { return err
items = append(items, item(route))
}
activeStreamIds.Range(func(key, value any) bool {
items = append(items, item(fmt.Sprintf("[demo] mirror session: %v", key)))
return true
})
downstreamPtyInfo = &extensions_ssh.SSHDownstreamPTYInfo{
TermEnv: req.TermEnv,
WidthColumns: req.Width,
HeightRows: req.Height,
WidthPx: req.WidthPx,
HeightPx: req.HeightPx,
Modes: req.Modes,
}
const defaultWidth = 20
l := list.New(items, itemDelegate{}, defaultWidth, listHeight)
l.Title = "Connect to which server?"
l.SetShowStatusBar(false)
l.SetFilteringEnabled(false)
l.Styles.Title = titleStyle
l.Styles.PaginationStyle = paginationStyle
l.Styles.HelpStyle = helpStyle
program = tea.NewProgram(model{list: l},
tea.WithInput(inputR),
tea.WithOutput(outputW),
tea.WithAltScreen(),
tea.WithContext(server.Context()),
tea.WithEnvironment([]string{"TERM=" + req.TermEnv}),
)
go func() {
answer, err := program.Run()
if err != nil {
return
} }
var handOff *anypb.Any handOff, _ := anypb.New(&extensions_ssh.SSHChannelControlAction{
if strings.HasPrefix(answer.(model).choice, "[demo] mirror session: ") { Action: &extensions_ssh.SSHChannelControlAction_HandOff{
id, err := strconv.ParseUint(strings.TrimPrefix(answer.(model).choice, "[demo] mirror session: "), 10, 64) HandOff: &extensions_ssh.SSHChannelControlAction_HandOffUpstream{
if err != nil { DownstreamChannelInfo: downstreamChannelInfo,
panic(err) UpstreamAuth: &extensions_ssh.AllowResponse{
} Target: &extensions_ssh.AllowResponse_Upstream{
handOff, _ = anypb.New(&extensions_ssh.SSHChannelControlAction{ Upstream: &extensions_ssh.UpstreamTarget{
Action: &extensions_ssh.SSHChannelControlAction_HandOff{ Hostname: subMsg.DestAddr,
HandOff: &extensions_ssh.SSHChannelControlAction_HandOffUpstream{ DirectTcpip: true,
DownstreamChannelInfo: downstreamChannelInfo,
DownstreamPtyInfo: downstreamPtyInfo,
UpstreamAuth: &extensions_ssh.AllowResponse{
Target: &extensions_ssh.AllowResponse_MirrorSession{
MirrorSession: &extensions_ssh.MirrorSessionTarget{
SourceId: id,
Mode: extensions_ssh.MirrorSessionTarget_ReadWrite,
},
}, },
}, },
}, },
}, },
}) },
} else { })
username, hostname, _ := strings.Cut(answer.(model).choice, "@")
sessionRecordingExt, _ := anypb.New(&extensions_session_recording.UpstreamTargetExtensionConfig{
RecordingName: fmt.Sprintf("session-%s-at-%s-%d.cast", username, hostname, time.Now().UnixNano()),
Format: extensions_session_recording.Format_AsciicastFormat,
})
handOff, _ = anypb.New(&extensions_ssh.SSHChannelControlAction{
Action: &extensions_ssh.SSHChannelControlAction_HandOff{
HandOff: &extensions_ssh.SSHChannelControlAction_HandOffUpstream{
DownstreamChannelInfo: downstreamChannelInfo,
DownstreamPtyInfo: downstreamPtyInfo,
UpstreamAuth: &extensions_ssh.AllowResponse{
Username: username,
Target: &extensions_ssh.AllowResponse_Upstream{
Upstream: &extensions_ssh.UpstreamTarget{
AllowMirrorConnections: true,
Hostname: hostname,
Extensions: []*corev3.TypedExtensionConfig{
{
TypedConfig: sessionRecordingExt,
},
},
},
},
},
},
},
})
}
if err := server.Send(&extensions_ssh.ChannelMessage{ if err := server.Send(&extensions_ssh.ChannelMessage{
Message: &extensions_ssh.ChannelMessage_ChannelControl{ Message: &extensions_ssh.ChannelMessage_ChannelControl{
ChannelControl: &extensions_ssh.ChannelControl{ ChannelControl: &extensions_ssh.ChannelControl{
@ -904,57 +825,101 @@ func (a *Authorize) ServeChannel(
}, },
}, },
}); err != nil { }); err != nil {
return return err
} }
}() }
go func() {
var buf [4096]byte
for {
n, err := outputR.Read(buf[:])
if err != nil {
return
}
if err := server.Send(&extensions_ssh.ChannelMessage{
Message: &extensions_ssh.ChannelMessage_RawBytes{
RawBytes: &wrapperspb.BytesValue{
Value: gossh.Marshal(channelDataMsg{
PeersID: peerId,
Length: uint32(n),
Rest: buf[:n],
}),
},
},
}); err != nil {
return
}
}
}()
program.Send(tea.WindowSizeMsg{Width: int(req.Width), Height: int(req.Height)})
if err := server.Send(&extensions_ssh.ChannelMessage{ case msgChannelRequest:
Message: &extensions_ssh.ChannelMessage_RawBytes{ var msg channelRequestMsg
RawBytes: &wrapperspb.BytesValue{ gossh.Unmarshal(rawMsg, &msg)
Value: gossh.Marshal(channelRequestSuccessMsg{
PeersID: peerId, switch msg.Request {
}), case "shell", "exec":
if err := server.Send(&extensions_ssh.ChannelMessage{
Message: &extensions_ssh.ChannelMessage_RawBytes{
RawBytes: &wrapperspb.BytesValue{
Value: gossh.Marshal(channelRequestSuccessMsg{
PeersID: peerId,
}),
},
}, },
}, }); err != nil {
}); err != nil { return err
}
cmd := NewSSHCLI(a.currentConfig.Load(), downstreamPtyInfo, downstreamChannelInfo, inputR, outputW, sendC, &activeProgram)
if msg.Request == "shell" {
cmd.SetArgs([]string{"portal"})
} else {
var execReq execChannelRequestMsg
if err := gossh.Unmarshal(msg.RequestSpecificData, &execReq); err != nil {
return err
}
cmd.SetArgs(strings.Fields(execReq.Command))
}
go func() {
defer activeProgram.Store(nil)
defer outputW.Close()
defer inputR.Close()
err := cmd.Execute()
if !errors.Is(err, ErrHandoff) {
sendC <- &extensions_ssh.ChannelControl{
Protocol: "ssh",
ControlAction: marshalAny(&extensions_ssh.SSHChannelControlAction_Disconnect{
ReasonCode: 11,
}),
}
}
}()
go streamOutputToChannel(sendC, peerId, outputR)
case "pty-req":
req := parsePtyReq(msg.RequestSpecificData)
downstreamPtyInfo = &extensions_ssh.SSHDownstreamPTYInfo{
TermEnv: req.TermEnv,
WidthColumns: req.Width,
HeightRows: req.Height,
WidthPx: req.WidthPx,
HeightPx: req.HeightPx,
Modes: req.Modes,
}
sendC <- channelRequestSuccessMsg{PeersID: peerId}
case "window-change":
var req channelWindowChangeRequestMsg
if err := gossh.Unmarshal(msg.RequestSpecificData, &req); err != nil {
return err
}
if p := activeProgram.Load(); p != nil {
p.Send(tea.WindowSizeMsg{
Width: int(req.WidthColumns),
Height: int(req.HeightRows),
})
}
}
case msgChannelData:
var msg channelDataMsg
gossh.Unmarshal(rawMsg, &msg)
if activeProgram.Load() != nil {
inputW.Write(msg.Rest)
}
case msgChannelClose:
var msg channelDataMsg
gossh.Unmarshal(rawMsg, &msg)
case msgChannelWindowAdjust:
var msg windowAdjustMsg
if err := gossh.Unmarshal(rawMsg, &msg); err != nil {
return err return err
} }
log.Ctx(ctx).Debug().Uint32("bytes", msg.AdditionalBytes).Msg("flow control: remote window size increased")
remoteWindow.add(msg.AdditionalBytes)
case msgChannelEOF:
return nil
default:
panic("unhandled message: " + fmt.Sprint(rawMsg[1]))
} }
case msgChannelData: case err := <-errC:
var msg channelDataMsg log.Ctx(ctx).Err(err).Msg("channel error")
gossh.Unmarshal(rawMsg, &msg) return err
if program != nil {
inputW.Write(msg.Rest)
}
case msgChannelClose:
var msg channelDataMsg
gossh.Unmarshal(rawMsg, &msg)
default:
panic("unhandled message: " + fmt.Sprint(rawMsg[1]))
} }
} }
} }
@ -981,7 +946,164 @@ func parsePtyReq(reqData []byte) ptyReq {
} }
} }
const listHeight = 14 func streamOutputToChannel(sendC chan<- any, channelID uint32, outputR io.Reader) {
var buf [4096]byte
for {
n, err := outputR.Read(buf[:])
if err != nil {
return
}
sendC <- channelDataMsg{
PeersID: channelID,
Length: uint32(n),
Rest: buf[:n],
}
}
}
func NewSSHCLI(
cfg *config.Config,
ptyInfo *extensions_ssh.SSHDownstreamPTYInfo,
channelInfo *extensions_ssh.SSHDownstreamChannelInfo,
stdin io.Reader,
stdout io.Writer,
sendC chan any,
activeProgram *atomic.Pointer[tea.Program],
) *cobra.Command {
cmd := &cobra.Command{
Use: "pomerium",
PersistentPreRunE: func(cmd *cobra.Command, args []string) error {
_, cmdIsInteractive := cmd.Annotations["interactive"]
switch {
case (ptyInfo == nil) && cmdIsInteractive:
cmd.SilenceUsage = true
return fmt.Errorf("\x1b[31m'%s' is an interactive command and requires a TTY (try passing '-t' to ssh)\x1b[0m", cmd.Use)
case (ptyInfo != nil) && !cmdIsInteractive:
cmd.SilenceUsage = true
return fmt.Errorf("\x1b[31m'%s' is not an interactive command (try passing '-T' to ssh, or removing '-t')\x1b[0m\r", cmd.Use)
}
return nil
},
}
cmd.AddCommand(NewPortalCommand(cfg, ptyInfo, channelInfo, sendC, activeProgram))
cmd.SetIn(stdin)
cmd.SetOut(stdout)
cmd.SetErr(stdout)
return cmd
}
func NewPortalCommand(
cfg *config.Config,
ptyInfo *extensions_ssh.SSHDownstreamPTYInfo,
channelInfo *extensions_ssh.SSHDownstreamChannelInfo,
sendC chan any,
activeProgram *atomic.Pointer[tea.Program],
) *cobra.Command {
cmd := &cobra.Command{
Use: "portal",
Short: "Interactive route portal",
Annotations: map[string]string{
"interactive": "",
},
RunE: func(cmd *cobra.Command, args []string) error {
var routes []string
for r := range cfg.Options.GetAllPolicies() {
if strings.HasPrefix(r.From, "ssh://") {
routes = append(routes, fmt.Sprintf("ubuntu@%s", strings.TrimSuffix(strings.TrimPrefix(r.From, "ssh://"), "."+cfg.Options.SSHHostname)))
}
}
items := []list.Item{}
for _, route := range routes {
items = append(items, item(route))
}
activeStreamIds.Range(func(key, value any) bool {
items = append(items, item(fmt.Sprintf("[demo] mirror session: %v", key)))
return true
})
l := list.New(items, itemDelegate{}, int(ptyInfo.WidthColumns-2), int(ptyInfo.HeightRows-2))
l.Title = "Connect to which server?"
l.SetShowStatusBar(false)
l.SetFilteringEnabled(false)
l.Styles.Title = titleStyle
l.Styles.PaginationStyle = paginationStyle
l.Styles.HelpStyle = helpStyle
program := tea.NewProgram(model{list: l},
tea.WithInput(cmd.InOrStdin()),
tea.WithOutput(cmd.OutOrStdout()),
tea.WithAltScreen(),
tea.WithContext(cmd.Context()),
tea.WithEnvironment([]string{"TERM=" + ptyInfo.TermEnv}),
)
activeProgram.Store(program)
go program.Send(tea.WindowSizeMsg{Width: int(ptyInfo.WidthColumns), Height: int(ptyInfo.HeightRows)})
answer, err := program.Run()
if err != nil {
return err
}
var handOff *anypb.Any
if strings.HasPrefix(answer.(model).choice, "[demo] mirror session: ") {
id, err := strconv.ParseUint(strings.TrimPrefix(answer.(model).choice, "[demo] mirror session: "), 10, 64)
if err != nil {
panic(err)
}
handOff = marshalAny(&extensions_ssh.SSHChannelControlAction{
Action: &extensions_ssh.SSHChannelControlAction_HandOff{
HandOff: &extensions_ssh.SSHChannelControlAction_HandOffUpstream{
DownstreamChannelInfo: channelInfo,
DownstreamPtyInfo: ptyInfo,
UpstreamAuth: &extensions_ssh.AllowResponse{
Target: &extensions_ssh.AllowResponse_MirrorSession{
MirrorSession: &extensions_ssh.MirrorSessionTarget{
SourceId: id,
Mode: extensions_ssh.MirrorSessionTarget_ReadWrite,
},
},
},
},
},
})
} else {
username, hostname, _ := strings.Cut(answer.(model).choice, "@")
sessionRecordingExt, _ := anypb.New(&extensions_session_recording.UpstreamTargetExtensionConfig{
RecordingName: fmt.Sprintf("session-%s-at-%s-%d.cast", username, hostname, time.Now().UnixNano()),
Format: extensions_session_recording.Format_AsciicastFormat,
})
handOff = marshalAny(&extensions_ssh.SSHChannelControlAction{
Action: &extensions_ssh.SSHChannelControlAction_HandOff{
HandOff: &extensions_ssh.SSHChannelControlAction_HandOffUpstream{
DownstreamChannelInfo: channelInfo,
DownstreamPtyInfo: ptyInfo,
UpstreamAuth: &extensions_ssh.AllowResponse{
Username: username,
Target: &extensions_ssh.AllowResponse_Upstream{
Upstream: &extensions_ssh.UpstreamTarget{
AllowMirrorConnections: true,
Hostname: hostname,
Extensions: []*corev3.TypedExtensionConfig{
{
TypedConfig: sessionRecordingExt,
},
},
},
},
},
},
},
})
}
sendC <- &extensions_ssh.ChannelControl{
Protocol: "ssh",
ControlAction: handOff,
}
return ErrHandoff
},
}
return cmd
}
var ( var (
titleStyle = lipgloss.NewStyle().MarginLeft(2) titleStyle = lipgloss.NewStyle().MarginLeft(2)
@ -1032,7 +1154,8 @@ func (m model) Init() tea.Cmd {
func (m model) Update(msg tea.Msg) (tea.Model, tea.Cmd) { func (m model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
switch msg := msg.(type) { switch msg := msg.(type) {
case tea.WindowSizeMsg: case tea.WindowSizeMsg:
m.list.SetWidth(msg.Width) m.list.SetWidth(msg.Width - 2)
m.list.SetHeight(msg.Height - 2)
return m, nil return m, nil
case tea.KeyMsg: case tea.KeyMsg:
@ -1058,3 +1181,189 @@ func (m model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
func (m model) View() string { func (m model) View() string {
return "\n" + m.list.View() return "\n" + m.list.View()
} }
// code below copied from x/crypto/ssh/common.go
const (
// channelMaxPacket contains the maximum number of bytes that will be
// sent in a single packet. As per RFC 4253, section 6.1, 32k is also
// the minimum.
channelMaxPacket = 1 << 15
// We follow OpenSSH here.
channelWindowSize = 64 * channelMaxPacket
)
// window represents the buffer available to clients
// wishing to write to a channel.
type window struct {
*sync.Cond
win uint32 // RFC 4254 5.2 says the window size can grow to 2^32-1
writeWaiters int
closed bool
}
// add adds win to the amount of window available
// for consumers.
func (w *window) add(win uint32) bool {
// a zero sized window adjust is a noop.
if win == 0 {
return true
}
w.L.Lock()
if w.win+win < win {
w.L.Unlock()
return false
}
w.win += win
// It is unusual that multiple goroutines would be attempting to reserve
// window space, but not guaranteed. Use broadcast to notify all waiters
// that additional window is available.
w.Broadcast()
w.L.Unlock()
return true
}
// close sets the window to closed, so all reservations fail
// immediately.
func (w *window) close() {
w.L.Lock()
w.closed = true
w.Broadcast()
w.L.Unlock()
}
// reserve reserves win from the available window capacity.
// If no capacity remains, reserve will block. reserve may
// return less than requested.
func (w *window) reserve(win uint32) (uint32, error) {
var err error
w.L.Lock()
w.writeWaiters++
w.Broadcast()
for w.win == 0 && !w.closed {
w.Wait()
}
w.writeWaiters--
if w.win < win {
win = w.win
}
w.win -= win
if w.closed {
err = io.EOF
}
w.L.Unlock()
return win, err
}
// code below copied from x/crypto/ssh/messages.go
// (with some additional messages not included there)
// See RFC 4254, section 5.1.
const msgChannelOpen = 90
type channelOpenMsg struct {
ChanType string `sshtype:"90"`
PeersID uint32
PeersWindow uint32
MaxPacketSize uint32
TypeSpecificData []byte `ssh:"rest"`
}
const (
msgChannelExtendedData = 95
msgChannelData = 94
)
// See RFC 4253, section 11.1.
const msgDisconnect = 1
// disconnectMsg is the message that signals a disconnect. It is also
// the error type returned from mux.Wait()
type disconnectMsg struct {
Reason uint32 `sshtype:"1"`
Message string
Language string
}
// Used for debug print outs of packets.
type channelDataMsg struct {
PeersID uint32 `sshtype:"94"`
Length uint32
Rest []byte `ssh:"rest"`
}
// See RFC 4254, section 5.1.
const msgChannelOpenConfirm = 91
type channelOpenConfirmMsg struct {
PeersID uint32 `sshtype:"91"`
MyID uint32
MyWindow uint32
MaxPacketSize uint32
TypeSpecificData []byte `ssh:"rest"`
}
const msgChannelRequest = 98
type channelRequestMsg struct {
PeersID uint32 `sshtype:"98"`
Request string
WantReply bool
RequestSpecificData []byte `ssh:"rest"`
}
type channelOpenDirectMsg struct {
DestAddr string
DestPort uint32
SrcAddr string
SrcPort uint32
}
type channelWindowChangeRequestMsg struct {
WidthColumns uint32
HeightRows uint32
WidthPx uint32
HeightPx uint32
}
type shellChannelRequestMsg struct{}
type execChannelRequestMsg struct {
Command string
}
// See RFC 4254, section 5.2
const msgChannelWindowAdjust = 93
type windowAdjustMsg struct {
PeersID uint32 `sshtype:"93"`
AdditionalBytes uint32
}
// See RFC 4254, section 5.4.
const msgChannelSuccess = 99
type channelRequestSuccessMsg struct {
PeersID uint32 `sshtype:"99"`
}
// See RFC 4254, section 5.4.
const msgChannelFailure = 100
type channelRequestFailureMsg struct {
PeersID uint32 `sshtype:"100"`
}
// See RFC 4254, section 5.3
const msgChannelClose = 97
type channelCloseMsg struct {
PeersID uint32 `sshtype:"97"`
}
// See RFC 4254, section 5.3
const msgChannelEOF = 96
type channelEOFMsg struct {
PeersID uint32 `sshtype:"96"`
}