mirror of
https://github.com/pomerium/pomerium.git
synced 2025-06-05 20:32:57 +02:00
wip
This commit is contained in:
parent
08252f32df
commit
5e06f2aef9
1 changed files with 601 additions and 292 deletions
|
@ -33,6 +33,7 @@ import (
|
|||
"github.com/pomerium/pomerium/pkg/identity/manager"
|
||||
"github.com/pomerium/pomerium/pkg/identity/oauth"
|
||||
"github.com/pomerium/pomerium/pkg/storage"
|
||||
"github.com/spf13/cobra"
|
||||
gossh "golang.org/x/crypto/ssh"
|
||||
"golang.org/x/oauth2"
|
||||
"golang.org/x/sync/errgroup"
|
||||
|
@ -41,6 +42,7 @@ import (
|
|||
"google.golang.org/grpc/status"
|
||||
"google.golang.org/protobuf/encoding/protodelim"
|
||||
"google.golang.org/protobuf/encoding/protojson"
|
||||
"google.golang.org/protobuf/proto"
|
||||
"google.golang.org/protobuf/types/known/anypb"
|
||||
"google.golang.org/protobuf/types/known/emptypb"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
|
@ -188,17 +190,14 @@ func (a *Authorize) ManageStream(
|
|||
case *extensions_ssh.ClientMessage_Event:
|
||||
switch event := req.Event.Event.(type) {
|
||||
case *extensions_ssh.StreamEvent_DownstreamConnected:
|
||||
fmt.Println("downstream connected")
|
||||
_ = event
|
||||
case *extensions_ssh.StreamEvent_UpstreamConnected:
|
||||
fmt.Printf("upstream connected: %d\n", event.UpstreamConnected.GetStreamId())
|
||||
activeStreamIds.Store(event.UpstreamConnected.GetStreamId(), state)
|
||||
defer activeStreamIds.Delete(event.UpstreamConnected.GetStreamId())
|
||||
case nil:
|
||||
}
|
||||
case *extensions_ssh.ClientMessage_AuthRequest:
|
||||
authReq := req.AuthRequest
|
||||
fmt.Println("auth request")
|
||||
if state.Username == "" {
|
||||
state.Username = authReq.Username
|
||||
}
|
||||
|
@ -633,269 +632,191 @@ func (a *Authorize) startContinuousAuthorization(
|
|||
}()
|
||||
}
|
||||
|
||||
// See RFC 4254, section 5.1.
|
||||
const msgChannelOpen = 90
|
||||
|
||||
type channelOpenMsg struct {
|
||||
ChanType string `sshtype:"90"`
|
||||
PeersID uint32
|
||||
PeersWindow uint32
|
||||
MaxPacketSize uint32
|
||||
TypeSpecificData []byte `ssh:"rest"`
|
||||
func marshalAny(msg proto.Message) *anypb.Any {
|
||||
a, err := anypb.New(msg)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return a
|
||||
}
|
||||
|
||||
const (
|
||||
msgChannelExtendedData = 95
|
||||
msgChannelData = 94
|
||||
)
|
||||
|
||||
// Used for debug print outs of packets.
|
||||
type channelDataMsg struct {
|
||||
PeersID uint32 `sshtype:"94"`
|
||||
Length uint32
|
||||
Rest []byte `ssh:"rest"`
|
||||
}
|
||||
|
||||
// See RFC 4254, section 5.1.
|
||||
const msgChannelOpenConfirm = 91
|
||||
|
||||
type channelOpenConfirmMsg struct {
|
||||
PeersID uint32 `sshtype:"91"`
|
||||
MyID uint32
|
||||
MyWindow uint32
|
||||
MaxPacketSize uint32
|
||||
TypeSpecificData []byte `ssh:"rest"`
|
||||
}
|
||||
|
||||
const msgChannelRequest = 98
|
||||
|
||||
type channelRequestMsg struct {
|
||||
PeersID uint32 `sshtype:"98"`
|
||||
Request string
|
||||
WantReply bool
|
||||
RequestSpecificData []byte `ssh:"rest"`
|
||||
}
|
||||
|
||||
type channelOpenDirectMsg struct {
|
||||
DestAddr string
|
||||
DestPort uint32
|
||||
SrcAddr string
|
||||
SrcPort uint32
|
||||
}
|
||||
|
||||
// See RFC 4254, section 5.4.
|
||||
const msgChannelSuccess = 99
|
||||
|
||||
type channelRequestSuccessMsg struct {
|
||||
PeersID uint32 `sshtype:"99"`
|
||||
}
|
||||
|
||||
// See RFC 4254, section 5.4.
|
||||
const msgChannelFailure = 100
|
||||
|
||||
type channelRequestFailureMsg struct {
|
||||
PeersID uint32 `sshtype:"100"`
|
||||
}
|
||||
|
||||
// See RFC 4254, section 5.3
|
||||
const msgChannelClose = 97
|
||||
|
||||
type channelCloseMsg struct {
|
||||
PeersID uint32 `sshtype:"97"`
|
||||
}
|
||||
|
||||
// See RFC 4254, section 5.3
|
||||
const msgChannelEOF = 96
|
||||
|
||||
type channelEOFMsg struct {
|
||||
PeersID uint32 `sshtype:"96"`
|
||||
}
|
||||
// sentinel error to indicate that the command triggered a handoff, and we
|
||||
// should not automatically disconnect
|
||||
var ErrHandoff = errors.New("handoff")
|
||||
|
||||
func (a *Authorize) ServeChannel(
|
||||
server extensions_ssh.StreamManagement_ServeChannelServer,
|
||||
) error {
|
||||
var program *tea.Program
|
||||
ctx := server.Context()
|
||||
inputR, inputW := io.Pipe()
|
||||
outputR, outputW := io.Pipe()
|
||||
var peerId uint32
|
||||
var activeProgram atomic.Pointer[tea.Program]
|
||||
|
||||
errC := make(chan error, 1)
|
||||
remoteWindow := &window{Cond: sync.NewCond(&sync.Mutex{})}
|
||||
sendC := make(chan any, 8)
|
||||
recvC := make(chan *extensions_ssh.ChannelMessage)
|
||||
go func() {
|
||||
for {
|
||||
select {
|
||||
case msg := <-sendC:
|
||||
switch msg := msg.(type) {
|
||||
case *extensions_ssh.ChannelControl:
|
||||
log.Ctx(ctx).Debug().Msg("sending channel control message")
|
||||
if err := server.Send(&extensions_ssh.ChannelMessage{
|
||||
Message: &extensions_ssh.ChannelMessage_ChannelControl{
|
||||
ChannelControl: msg,
|
||||
},
|
||||
}); err != nil {
|
||||
errC <- err
|
||||
return
|
||||
}
|
||||
case windowAdjustMsg, channelRequestMsg, channelRequestSuccessMsg, channelRequestFailureMsg, channelEOFMsg:
|
||||
// these messages don't consume window space
|
||||
data := gossh.Marshal(msg)
|
||||
if err := server.Send(&extensions_ssh.ChannelMessage{
|
||||
Message: &extensions_ssh.ChannelMessage_RawBytes{
|
||||
RawBytes: wrapperspb.Bytes(data),
|
||||
},
|
||||
}); err != nil {
|
||||
errC <- err
|
||||
return
|
||||
}
|
||||
log.Ctx(ctx).Debug().Uint8("type", data[0]).Msg("message sent")
|
||||
default:
|
||||
data := gossh.Marshal(msg)
|
||||
need := uint32(len(data))
|
||||
have := uint32(0)
|
||||
for have < need {
|
||||
n, err := remoteWindow.reserve(need - have)
|
||||
if err != nil {
|
||||
errC <- err
|
||||
return
|
||||
}
|
||||
have += n
|
||||
}
|
||||
if err := server.Send(&extensions_ssh.ChannelMessage{
|
||||
Message: &extensions_ssh.ChannelMessage_RawBytes{
|
||||
RawBytes: wrapperspb.Bytes(data),
|
||||
},
|
||||
}); err != nil {
|
||||
errC <- err
|
||||
return
|
||||
}
|
||||
log.Ctx(ctx).Debug().Uint8("type", data[0]).Uint32("size", need).Msg("message sent")
|
||||
}
|
||||
case <-ctx.Done():
|
||||
errC <- ctx.Err()
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
go func() {
|
||||
localWindow := uint32(channelWindowSize)
|
||||
for {
|
||||
channelMsg, err := server.Recv()
|
||||
if err != nil {
|
||||
if errors.Is(err, io.EOF) {
|
||||
errC <- nil
|
||||
return
|
||||
}
|
||||
errC <- err
|
||||
return
|
||||
}
|
||||
if raw, ok := channelMsg.Message.(*extensions_ssh.ChannelMessage_RawBytes); ok {
|
||||
msgLen := uint32(len(raw.RawBytes.GetValue()))
|
||||
if msgLen == 0 {
|
||||
errC <- status.Errorf(codes.InvalidArgument, "peer sent empty message")
|
||||
return
|
||||
}
|
||||
if msgLen > channelMaxPacket {
|
||||
errC <- status.Errorf(codes.ResourceExhausted, "message too large")
|
||||
return
|
||||
}
|
||||
log.Ctx(ctx).Debug().Uint8("type", raw.RawBytes.Value[0]).Uint32("size", msgLen).Msg("message received")
|
||||
// peek the first byte to check if we need to deduct from the window
|
||||
switch raw.RawBytes.Value[0] {
|
||||
case msgChannelWindowAdjust, msgChannelRequest, msgChannelSuccess, msgChannelFailure, msgChannelEOF:
|
||||
// these messages don't consume window space
|
||||
default:
|
||||
if localWindow < msgLen {
|
||||
errC <- status.Errorf(codes.ResourceExhausted, "peer sent more bytes than allowed by channel window")
|
||||
return
|
||||
}
|
||||
localWindow -= msgLen
|
||||
if localWindow < channelWindowSize/2 {
|
||||
log.Ctx(ctx).Debug().Msg("flow control: increasing local window size")
|
||||
localWindow += channelWindowSize
|
||||
sendC <- windowAdjustMsg{
|
||||
PeersID: peerId,
|
||||
AdditionalBytes: channelWindowSize,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
select {
|
||||
case recvC <- channelMsg:
|
||||
case <-ctx.Done():
|
||||
errC <- ctx.Err()
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
var downstreamChannelInfo *extensions_ssh.SSHDownstreamChannelInfo
|
||||
var downstreamPtyInfo *extensions_ssh.SSHDownstreamPTYInfo
|
||||
var channelIdCounter uint32
|
||||
for {
|
||||
channelMsg, err := server.Recv()
|
||||
if err != nil {
|
||||
if errors.Is(err, io.EOF) {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
rawMsg := channelMsg.GetRawBytes().GetValue()
|
||||
switch rawMsg[0] {
|
||||
case msgChannelOpen:
|
||||
var msg channelOpenMsg
|
||||
gossh.Unmarshal(rawMsg, &msg)
|
||||
|
||||
var confirm channelOpenConfirmMsg
|
||||
peerId = msg.PeersID
|
||||
confirm.PeersID = peerId
|
||||
confirm.MyID = 1
|
||||
confirm.MyWindow = msg.PeersWindow
|
||||
confirm.MaxPacketSize = msg.MaxPacketSize
|
||||
downstreamChannelInfo = &extensions_ssh.SSHDownstreamChannelInfo{
|
||||
ChannelType: msg.ChanType,
|
||||
DownstreamChannelId: confirm.PeersID,
|
||||
InternalUpstreamChannelId: confirm.MyID,
|
||||
InitialWindowSize: confirm.MyWindow,
|
||||
MaxPacketSize: confirm.MaxPacketSize,
|
||||
}
|
||||
switch msg.ChanType {
|
||||
case "session":
|
||||
if err := server.Send(&extensions_ssh.ChannelMessage{
|
||||
Message: &extensions_ssh.ChannelMessage_RawBytes{
|
||||
RawBytes: &wrapperspb.BytesValue{
|
||||
Value: gossh.Marshal(confirm),
|
||||
},
|
||||
},
|
||||
}); err != nil {
|
||||
return err
|
||||
select {
|
||||
case channelMsg := <-recvC:
|
||||
rawMsg := channelMsg.GetRawBytes().GetValue()
|
||||
switch rawMsg[0] {
|
||||
case msgChannelOpen:
|
||||
var msg channelOpenMsg
|
||||
gossh.Unmarshal(rawMsg, &msg)
|
||||
channelIdCounter++
|
||||
if channelIdCounter > 1 {
|
||||
return fmt.Errorf("only one channel can be opened")
|
||||
}
|
||||
case "direct-tcpip":
|
||||
var subMsg channelOpenDirectMsg
|
||||
if err := gossh.Unmarshal(msg.TypeSpecificData, &subMsg); err != nil {
|
||||
return err
|
||||
peerId = msg.PeersID
|
||||
downstreamChannelInfo = &extensions_ssh.SSHDownstreamChannelInfo{
|
||||
ChannelType: msg.ChanType,
|
||||
DownstreamChannelId: peerId,
|
||||
InternalUpstreamChannelId: channelIdCounter,
|
||||
InitialWindowSize: msg.PeersWindow,
|
||||
MaxPacketSize: msg.MaxPacketSize,
|
||||
}
|
||||
handOff, _ := anypb.New(&extensions_ssh.SSHChannelControlAction{
|
||||
Action: &extensions_ssh.SSHChannelControlAction_HandOff{
|
||||
HandOff: &extensions_ssh.SSHChannelControlAction_HandOffUpstream{
|
||||
DownstreamChannelInfo: downstreamChannelInfo,
|
||||
UpstreamAuth: &extensions_ssh.AllowResponse{
|
||||
Target: &extensions_ssh.AllowResponse_Upstream{
|
||||
Upstream: &extensions_ssh.UpstreamTarget{
|
||||
Hostname: subMsg.DestAddr,
|
||||
DirectTcpip: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
if err := server.Send(&extensions_ssh.ChannelMessage{
|
||||
Message: &extensions_ssh.ChannelMessage_ChannelControl{
|
||||
ChannelControl: &extensions_ssh.ChannelControl{
|
||||
Protocol: "ssh",
|
||||
ControlAction: handOff,
|
||||
},
|
||||
},
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
case msgChannelRequest:
|
||||
var msg channelRequestMsg
|
||||
gossh.Unmarshal(rawMsg, &msg)
|
||||
|
||||
switch msg.Request {
|
||||
case "pty-req":
|
||||
opts := a.currentConfig.Load().Options
|
||||
var routes []string
|
||||
for r := range opts.GetAllPolicies() {
|
||||
if strings.HasPrefix(r.From, "ssh://") {
|
||||
routes = append(routes, fmt.Sprintf("ubuntu@%s", strings.TrimSuffix(strings.TrimPrefix(r.From, "ssh://"), "."+opts.SSHHostname)))
|
||||
remoteWindow.add(msg.PeersWindow)
|
||||
switch msg.ChanType {
|
||||
case "session":
|
||||
sendC <- channelOpenConfirmMsg{
|
||||
PeersID: peerId,
|
||||
MyID: channelIdCounter,
|
||||
MyWindow: channelWindowSize,
|
||||
MaxPacketSize: channelMaxPacket,
|
||||
}
|
||||
}
|
||||
req := parsePtyReq(msg.RequestSpecificData)
|
||||
items := []list.Item{}
|
||||
for _, route := range routes {
|
||||
items = append(items, item(route))
|
||||
}
|
||||
activeStreamIds.Range(func(key, value any) bool {
|
||||
items = append(items, item(fmt.Sprintf("[demo] mirror session: %v", key)))
|
||||
return true
|
||||
})
|
||||
downstreamPtyInfo = &extensions_ssh.SSHDownstreamPTYInfo{
|
||||
TermEnv: req.TermEnv,
|
||||
WidthColumns: req.Width,
|
||||
HeightRows: req.Height,
|
||||
WidthPx: req.WidthPx,
|
||||
HeightPx: req.HeightPx,
|
||||
Modes: req.Modes,
|
||||
}
|
||||
|
||||
const defaultWidth = 20
|
||||
|
||||
l := list.New(items, itemDelegate{}, defaultWidth, listHeight)
|
||||
l.Title = "Connect to which server?"
|
||||
l.SetShowStatusBar(false)
|
||||
l.SetFilteringEnabled(false)
|
||||
l.Styles.Title = titleStyle
|
||||
l.Styles.PaginationStyle = paginationStyle
|
||||
l.Styles.HelpStyle = helpStyle
|
||||
|
||||
program = tea.NewProgram(model{list: l},
|
||||
tea.WithInput(inputR),
|
||||
tea.WithOutput(outputW),
|
||||
tea.WithAltScreen(),
|
||||
tea.WithContext(server.Context()),
|
||||
tea.WithEnvironment([]string{"TERM=" + req.TermEnv}),
|
||||
)
|
||||
go func() {
|
||||
answer, err := program.Run()
|
||||
if err != nil {
|
||||
return
|
||||
case "direct-tcpip":
|
||||
var subMsg channelOpenDirectMsg
|
||||
if err := gossh.Unmarshal(msg.TypeSpecificData, &subMsg); err != nil {
|
||||
return err
|
||||
}
|
||||
var handOff *anypb.Any
|
||||
if strings.HasPrefix(answer.(model).choice, "[demo] mirror session: ") {
|
||||
id, err := strconv.ParseUint(strings.TrimPrefix(answer.(model).choice, "[demo] mirror session: "), 10, 64)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
handOff, _ = anypb.New(&extensions_ssh.SSHChannelControlAction{
|
||||
Action: &extensions_ssh.SSHChannelControlAction_HandOff{
|
||||
HandOff: &extensions_ssh.SSHChannelControlAction_HandOffUpstream{
|
||||
DownstreamChannelInfo: downstreamChannelInfo,
|
||||
DownstreamPtyInfo: downstreamPtyInfo,
|
||||
UpstreamAuth: &extensions_ssh.AllowResponse{
|
||||
Target: &extensions_ssh.AllowResponse_MirrorSession{
|
||||
MirrorSession: &extensions_ssh.MirrorSessionTarget{
|
||||
SourceId: id,
|
||||
Mode: extensions_ssh.MirrorSessionTarget_ReadWrite,
|
||||
},
|
||||
handOff, _ := anypb.New(&extensions_ssh.SSHChannelControlAction{
|
||||
Action: &extensions_ssh.SSHChannelControlAction_HandOff{
|
||||
HandOff: &extensions_ssh.SSHChannelControlAction_HandOffUpstream{
|
||||
DownstreamChannelInfo: downstreamChannelInfo,
|
||||
UpstreamAuth: &extensions_ssh.AllowResponse{
|
||||
Target: &extensions_ssh.AllowResponse_Upstream{
|
||||
Upstream: &extensions_ssh.UpstreamTarget{
|
||||
Hostname: subMsg.DestAddr,
|
||||
DirectTcpip: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
} else {
|
||||
username, hostname, _ := strings.Cut(answer.(model).choice, "@")
|
||||
sessionRecordingExt, _ := anypb.New(&extensions_session_recording.UpstreamTargetExtensionConfig{
|
||||
RecordingName: fmt.Sprintf("session-%s-at-%s-%d.cast", username, hostname, time.Now().UnixNano()),
|
||||
Format: extensions_session_recording.Format_AsciicastFormat,
|
||||
})
|
||||
handOff, _ = anypb.New(&extensions_ssh.SSHChannelControlAction{
|
||||
Action: &extensions_ssh.SSHChannelControlAction_HandOff{
|
||||
HandOff: &extensions_ssh.SSHChannelControlAction_HandOffUpstream{
|
||||
DownstreamChannelInfo: downstreamChannelInfo,
|
||||
DownstreamPtyInfo: downstreamPtyInfo,
|
||||
UpstreamAuth: &extensions_ssh.AllowResponse{
|
||||
Username: username,
|
||||
Target: &extensions_ssh.AllowResponse_Upstream{
|
||||
Upstream: &extensions_ssh.UpstreamTarget{
|
||||
AllowMirrorConnections: true,
|
||||
Hostname: hostname,
|
||||
Extensions: []*corev3.TypedExtensionConfig{
|
||||
{
|
||||
TypedConfig: sessionRecordingExt,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
},
|
||||
})
|
||||
if err := server.Send(&extensions_ssh.ChannelMessage{
|
||||
Message: &extensions_ssh.ChannelMessage_ChannelControl{
|
||||
ChannelControl: &extensions_ssh.ChannelControl{
|
||||
|
@ -904,57 +825,101 @@ func (a *Authorize) ServeChannel(
|
|||
},
|
||||
},
|
||||
}); err != nil {
|
||||
return
|
||||
return err
|
||||
}
|
||||
}()
|
||||
go func() {
|
||||
var buf [4096]byte
|
||||
for {
|
||||
n, err := outputR.Read(buf[:])
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if err := server.Send(&extensions_ssh.ChannelMessage{
|
||||
Message: &extensions_ssh.ChannelMessage_RawBytes{
|
||||
RawBytes: &wrapperspb.BytesValue{
|
||||
Value: gossh.Marshal(channelDataMsg{
|
||||
PeersID: peerId,
|
||||
Length: uint32(n),
|
||||
Rest: buf[:n],
|
||||
}),
|
||||
},
|
||||
},
|
||||
}); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
program.Send(tea.WindowSizeMsg{Width: int(req.Width), Height: int(req.Height)})
|
||||
}
|
||||
|
||||
if err := server.Send(&extensions_ssh.ChannelMessage{
|
||||
Message: &extensions_ssh.ChannelMessage_RawBytes{
|
||||
RawBytes: &wrapperspb.BytesValue{
|
||||
Value: gossh.Marshal(channelRequestSuccessMsg{
|
||||
PeersID: peerId,
|
||||
}),
|
||||
case msgChannelRequest:
|
||||
var msg channelRequestMsg
|
||||
gossh.Unmarshal(rawMsg, &msg)
|
||||
|
||||
switch msg.Request {
|
||||
case "shell", "exec":
|
||||
if err := server.Send(&extensions_ssh.ChannelMessage{
|
||||
Message: &extensions_ssh.ChannelMessage_RawBytes{
|
||||
RawBytes: &wrapperspb.BytesValue{
|
||||
Value: gossh.Marshal(channelRequestSuccessMsg{
|
||||
PeersID: peerId,
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
}); err != nil {
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cmd := NewSSHCLI(a.currentConfig.Load(), downstreamPtyInfo, downstreamChannelInfo, inputR, outputW, sendC, &activeProgram)
|
||||
if msg.Request == "shell" {
|
||||
cmd.SetArgs([]string{"portal"})
|
||||
} else {
|
||||
var execReq execChannelRequestMsg
|
||||
if err := gossh.Unmarshal(msg.RequestSpecificData, &execReq); err != nil {
|
||||
return err
|
||||
}
|
||||
cmd.SetArgs(strings.Fields(execReq.Command))
|
||||
}
|
||||
go func() {
|
||||
defer activeProgram.Store(nil)
|
||||
defer outputW.Close()
|
||||
defer inputR.Close()
|
||||
err := cmd.Execute()
|
||||
if !errors.Is(err, ErrHandoff) {
|
||||
sendC <- &extensions_ssh.ChannelControl{
|
||||
Protocol: "ssh",
|
||||
ControlAction: marshalAny(&extensions_ssh.SSHChannelControlAction_Disconnect{
|
||||
ReasonCode: 11,
|
||||
}),
|
||||
}
|
||||
}
|
||||
}()
|
||||
go streamOutputToChannel(sendC, peerId, outputR)
|
||||
|
||||
case "pty-req":
|
||||
req := parsePtyReq(msg.RequestSpecificData)
|
||||
downstreamPtyInfo = &extensions_ssh.SSHDownstreamPTYInfo{
|
||||
TermEnv: req.TermEnv,
|
||||
WidthColumns: req.Width,
|
||||
HeightRows: req.Height,
|
||||
WidthPx: req.WidthPx,
|
||||
HeightPx: req.HeightPx,
|
||||
Modes: req.Modes,
|
||||
}
|
||||
sendC <- channelRequestSuccessMsg{PeersID: peerId}
|
||||
case "window-change":
|
||||
var req channelWindowChangeRequestMsg
|
||||
if err := gossh.Unmarshal(msg.RequestSpecificData, &req); err != nil {
|
||||
return err
|
||||
}
|
||||
if p := activeProgram.Load(); p != nil {
|
||||
p.Send(tea.WindowSizeMsg{
|
||||
Width: int(req.WidthColumns),
|
||||
Height: int(req.HeightRows),
|
||||
})
|
||||
}
|
||||
}
|
||||
case msgChannelData:
|
||||
var msg channelDataMsg
|
||||
gossh.Unmarshal(rawMsg, &msg)
|
||||
if activeProgram.Load() != nil {
|
||||
inputW.Write(msg.Rest)
|
||||
}
|
||||
case msgChannelClose:
|
||||
var msg channelDataMsg
|
||||
gossh.Unmarshal(rawMsg, &msg)
|
||||
case msgChannelWindowAdjust:
|
||||
var msg windowAdjustMsg
|
||||
if err := gossh.Unmarshal(rawMsg, &msg); err != nil {
|
||||
return err
|
||||
}
|
||||
log.Ctx(ctx).Debug().Uint32("bytes", msg.AdditionalBytes).Msg("flow control: remote window size increased")
|
||||
remoteWindow.add(msg.AdditionalBytes)
|
||||
case msgChannelEOF:
|
||||
return nil
|
||||
default:
|
||||
panic("unhandled message: " + fmt.Sprint(rawMsg[1]))
|
||||
}
|
||||
case msgChannelData:
|
||||
var msg channelDataMsg
|
||||
gossh.Unmarshal(rawMsg, &msg)
|
||||
|
||||
if program != nil {
|
||||
inputW.Write(msg.Rest)
|
||||
}
|
||||
case msgChannelClose:
|
||||
var msg channelDataMsg
|
||||
gossh.Unmarshal(rawMsg, &msg)
|
||||
default:
|
||||
panic("unhandled message: " + fmt.Sprint(rawMsg[1]))
|
||||
case err := <-errC:
|
||||
log.Ctx(ctx).Err(err).Msg("channel error")
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -981,7 +946,164 @@ func parsePtyReq(reqData []byte) ptyReq {
|
|||
}
|
||||
}
|
||||
|
||||
const listHeight = 14
|
||||
func streamOutputToChannel(sendC chan<- any, channelID uint32, outputR io.Reader) {
|
||||
var buf [4096]byte
|
||||
for {
|
||||
n, err := outputR.Read(buf[:])
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
sendC <- channelDataMsg{
|
||||
PeersID: channelID,
|
||||
Length: uint32(n),
|
||||
Rest: buf[:n],
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func NewSSHCLI(
|
||||
cfg *config.Config,
|
||||
ptyInfo *extensions_ssh.SSHDownstreamPTYInfo,
|
||||
channelInfo *extensions_ssh.SSHDownstreamChannelInfo,
|
||||
stdin io.Reader,
|
||||
stdout io.Writer,
|
||||
sendC chan any,
|
||||
activeProgram *atomic.Pointer[tea.Program],
|
||||
) *cobra.Command {
|
||||
cmd := &cobra.Command{
|
||||
Use: "pomerium",
|
||||
PersistentPreRunE: func(cmd *cobra.Command, args []string) error {
|
||||
_, cmdIsInteractive := cmd.Annotations["interactive"]
|
||||
switch {
|
||||
case (ptyInfo == nil) && cmdIsInteractive:
|
||||
cmd.SilenceUsage = true
|
||||
return fmt.Errorf("\x1b[31m'%s' is an interactive command and requires a TTY (try passing '-t' to ssh)\x1b[0m", cmd.Use)
|
||||
case (ptyInfo != nil) && !cmdIsInteractive:
|
||||
cmd.SilenceUsage = true
|
||||
return fmt.Errorf("\x1b[31m'%s' is not an interactive command (try passing '-T' to ssh, or removing '-t')\x1b[0m\r", cmd.Use)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
}
|
||||
cmd.AddCommand(NewPortalCommand(cfg, ptyInfo, channelInfo, sendC, activeProgram))
|
||||
cmd.SetIn(stdin)
|
||||
cmd.SetOut(stdout)
|
||||
cmd.SetErr(stdout)
|
||||
return cmd
|
||||
}
|
||||
|
||||
func NewPortalCommand(
|
||||
cfg *config.Config,
|
||||
ptyInfo *extensions_ssh.SSHDownstreamPTYInfo,
|
||||
channelInfo *extensions_ssh.SSHDownstreamChannelInfo,
|
||||
sendC chan any,
|
||||
activeProgram *atomic.Pointer[tea.Program],
|
||||
) *cobra.Command {
|
||||
cmd := &cobra.Command{
|
||||
Use: "portal",
|
||||
Short: "Interactive route portal",
|
||||
Annotations: map[string]string{
|
||||
"interactive": "",
|
||||
},
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
var routes []string
|
||||
for r := range cfg.Options.GetAllPolicies() {
|
||||
if strings.HasPrefix(r.From, "ssh://") {
|
||||
routes = append(routes, fmt.Sprintf("ubuntu@%s", strings.TrimSuffix(strings.TrimPrefix(r.From, "ssh://"), "."+cfg.Options.SSHHostname)))
|
||||
}
|
||||
}
|
||||
items := []list.Item{}
|
||||
for _, route := range routes {
|
||||
items = append(items, item(route))
|
||||
}
|
||||
activeStreamIds.Range(func(key, value any) bool {
|
||||
items = append(items, item(fmt.Sprintf("[demo] mirror session: %v", key)))
|
||||
return true
|
||||
})
|
||||
|
||||
l := list.New(items, itemDelegate{}, int(ptyInfo.WidthColumns-2), int(ptyInfo.HeightRows-2))
|
||||
l.Title = "Connect to which server?"
|
||||
l.SetShowStatusBar(false)
|
||||
l.SetFilteringEnabled(false)
|
||||
l.Styles.Title = titleStyle
|
||||
l.Styles.PaginationStyle = paginationStyle
|
||||
l.Styles.HelpStyle = helpStyle
|
||||
|
||||
program := tea.NewProgram(model{list: l},
|
||||
tea.WithInput(cmd.InOrStdin()),
|
||||
tea.WithOutput(cmd.OutOrStdout()),
|
||||
tea.WithAltScreen(),
|
||||
tea.WithContext(cmd.Context()),
|
||||
tea.WithEnvironment([]string{"TERM=" + ptyInfo.TermEnv}),
|
||||
)
|
||||
activeProgram.Store(program)
|
||||
|
||||
go program.Send(tea.WindowSizeMsg{Width: int(ptyInfo.WidthColumns), Height: int(ptyInfo.HeightRows)})
|
||||
answer, err := program.Run()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
var handOff *anypb.Any
|
||||
if strings.HasPrefix(answer.(model).choice, "[demo] mirror session: ") {
|
||||
id, err := strconv.ParseUint(strings.TrimPrefix(answer.(model).choice, "[demo] mirror session: "), 10, 64)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
handOff = marshalAny(&extensions_ssh.SSHChannelControlAction{
|
||||
Action: &extensions_ssh.SSHChannelControlAction_HandOff{
|
||||
HandOff: &extensions_ssh.SSHChannelControlAction_HandOffUpstream{
|
||||
DownstreamChannelInfo: channelInfo,
|
||||
DownstreamPtyInfo: ptyInfo,
|
||||
UpstreamAuth: &extensions_ssh.AllowResponse{
|
||||
Target: &extensions_ssh.AllowResponse_MirrorSession{
|
||||
MirrorSession: &extensions_ssh.MirrorSessionTarget{
|
||||
SourceId: id,
|
||||
Mode: extensions_ssh.MirrorSessionTarget_ReadWrite,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
} else {
|
||||
username, hostname, _ := strings.Cut(answer.(model).choice, "@")
|
||||
sessionRecordingExt, _ := anypb.New(&extensions_session_recording.UpstreamTargetExtensionConfig{
|
||||
RecordingName: fmt.Sprintf("session-%s-at-%s-%d.cast", username, hostname, time.Now().UnixNano()),
|
||||
Format: extensions_session_recording.Format_AsciicastFormat,
|
||||
})
|
||||
handOff = marshalAny(&extensions_ssh.SSHChannelControlAction{
|
||||
Action: &extensions_ssh.SSHChannelControlAction_HandOff{
|
||||
HandOff: &extensions_ssh.SSHChannelControlAction_HandOffUpstream{
|
||||
DownstreamChannelInfo: channelInfo,
|
||||
DownstreamPtyInfo: ptyInfo,
|
||||
UpstreamAuth: &extensions_ssh.AllowResponse{
|
||||
Username: username,
|
||||
Target: &extensions_ssh.AllowResponse_Upstream{
|
||||
Upstream: &extensions_ssh.UpstreamTarget{
|
||||
AllowMirrorConnections: true,
|
||||
Hostname: hostname,
|
||||
Extensions: []*corev3.TypedExtensionConfig{
|
||||
{
|
||||
TypedConfig: sessionRecordingExt,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
sendC <- &extensions_ssh.ChannelControl{
|
||||
Protocol: "ssh",
|
||||
ControlAction: handOff,
|
||||
}
|
||||
return ErrHandoff
|
||||
},
|
||||
}
|
||||
return cmd
|
||||
}
|
||||
|
||||
var (
|
||||
titleStyle = lipgloss.NewStyle().MarginLeft(2)
|
||||
|
@ -1032,7 +1154,8 @@ func (m model) Init() tea.Cmd {
|
|||
func (m model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
switch msg := msg.(type) {
|
||||
case tea.WindowSizeMsg:
|
||||
m.list.SetWidth(msg.Width)
|
||||
m.list.SetWidth(msg.Width - 2)
|
||||
m.list.SetHeight(msg.Height - 2)
|
||||
return m, nil
|
||||
|
||||
case tea.KeyMsg:
|
||||
|
@ -1058,3 +1181,189 @@ func (m model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
|||
func (m model) View() string {
|
||||
return "\n" + m.list.View()
|
||||
}
|
||||
|
||||
// code below copied from x/crypto/ssh/common.go
|
||||
|
||||
const (
|
||||
// channelMaxPacket contains the maximum number of bytes that will be
|
||||
// sent in a single packet. As per RFC 4253, section 6.1, 32k is also
|
||||
// the minimum.
|
||||
channelMaxPacket = 1 << 15
|
||||
// We follow OpenSSH here.
|
||||
channelWindowSize = 64 * channelMaxPacket
|
||||
)
|
||||
|
||||
// window represents the buffer available to clients
|
||||
// wishing to write to a channel.
|
||||
type window struct {
|
||||
*sync.Cond
|
||||
win uint32 // RFC 4254 5.2 says the window size can grow to 2^32-1
|
||||
writeWaiters int
|
||||
closed bool
|
||||
}
|
||||
|
||||
// add adds win to the amount of window available
|
||||
// for consumers.
|
||||
func (w *window) add(win uint32) bool {
|
||||
// a zero sized window adjust is a noop.
|
||||
if win == 0 {
|
||||
return true
|
||||
}
|
||||
w.L.Lock()
|
||||
if w.win+win < win {
|
||||
w.L.Unlock()
|
||||
return false
|
||||
}
|
||||
w.win += win
|
||||
// It is unusual that multiple goroutines would be attempting to reserve
|
||||
// window space, but not guaranteed. Use broadcast to notify all waiters
|
||||
// that additional window is available.
|
||||
w.Broadcast()
|
||||
w.L.Unlock()
|
||||
return true
|
||||
}
|
||||
|
||||
// close sets the window to closed, so all reservations fail
|
||||
// immediately.
|
||||
func (w *window) close() {
|
||||
w.L.Lock()
|
||||
w.closed = true
|
||||
w.Broadcast()
|
||||
w.L.Unlock()
|
||||
}
|
||||
|
||||
// reserve reserves win from the available window capacity.
|
||||
// If no capacity remains, reserve will block. reserve may
|
||||
// return less than requested.
|
||||
func (w *window) reserve(win uint32) (uint32, error) {
|
||||
var err error
|
||||
w.L.Lock()
|
||||
w.writeWaiters++
|
||||
w.Broadcast()
|
||||
for w.win == 0 && !w.closed {
|
||||
w.Wait()
|
||||
}
|
||||
w.writeWaiters--
|
||||
if w.win < win {
|
||||
win = w.win
|
||||
}
|
||||
w.win -= win
|
||||
if w.closed {
|
||||
err = io.EOF
|
||||
}
|
||||
w.L.Unlock()
|
||||
return win, err
|
||||
}
|
||||
|
||||
// code below copied from x/crypto/ssh/messages.go
|
||||
// (with some additional messages not included there)
|
||||
|
||||
// See RFC 4254, section 5.1.
|
||||
const msgChannelOpen = 90
|
||||
|
||||
type channelOpenMsg struct {
|
||||
ChanType string `sshtype:"90"`
|
||||
PeersID uint32
|
||||
PeersWindow uint32
|
||||
MaxPacketSize uint32
|
||||
TypeSpecificData []byte `ssh:"rest"`
|
||||
}
|
||||
|
||||
const (
|
||||
msgChannelExtendedData = 95
|
||||
msgChannelData = 94
|
||||
)
|
||||
|
||||
// See RFC 4253, section 11.1.
|
||||
const msgDisconnect = 1
|
||||
|
||||
// disconnectMsg is the message that signals a disconnect. It is also
|
||||
// the error type returned from mux.Wait()
|
||||
type disconnectMsg struct {
|
||||
Reason uint32 `sshtype:"1"`
|
||||
Message string
|
||||
Language string
|
||||
}
|
||||
|
||||
// Used for debug print outs of packets.
|
||||
type channelDataMsg struct {
|
||||
PeersID uint32 `sshtype:"94"`
|
||||
Length uint32
|
||||
Rest []byte `ssh:"rest"`
|
||||
}
|
||||
|
||||
// See RFC 4254, section 5.1.
|
||||
const msgChannelOpenConfirm = 91
|
||||
|
||||
type channelOpenConfirmMsg struct {
|
||||
PeersID uint32 `sshtype:"91"`
|
||||
MyID uint32
|
||||
MyWindow uint32
|
||||
MaxPacketSize uint32
|
||||
TypeSpecificData []byte `ssh:"rest"`
|
||||
}
|
||||
|
||||
const msgChannelRequest = 98
|
||||
|
||||
type channelRequestMsg struct {
|
||||
PeersID uint32 `sshtype:"98"`
|
||||
Request string
|
||||
WantReply bool
|
||||
RequestSpecificData []byte `ssh:"rest"`
|
||||
}
|
||||
|
||||
type channelOpenDirectMsg struct {
|
||||
DestAddr string
|
||||
DestPort uint32
|
||||
SrcAddr string
|
||||
SrcPort uint32
|
||||
}
|
||||
|
||||
type channelWindowChangeRequestMsg struct {
|
||||
WidthColumns uint32
|
||||
HeightRows uint32
|
||||
WidthPx uint32
|
||||
HeightPx uint32
|
||||
}
|
||||
|
||||
type shellChannelRequestMsg struct{}
|
||||
|
||||
type execChannelRequestMsg struct {
|
||||
Command string
|
||||
}
|
||||
|
||||
// See RFC 4254, section 5.2
|
||||
const msgChannelWindowAdjust = 93
|
||||
|
||||
type windowAdjustMsg struct {
|
||||
PeersID uint32 `sshtype:"93"`
|
||||
AdditionalBytes uint32
|
||||
}
|
||||
|
||||
// See RFC 4254, section 5.4.
|
||||
const msgChannelSuccess = 99
|
||||
|
||||
type channelRequestSuccessMsg struct {
|
||||
PeersID uint32 `sshtype:"99"`
|
||||
}
|
||||
|
||||
// See RFC 4254, section 5.4.
|
||||
const msgChannelFailure = 100
|
||||
|
||||
type channelRequestFailureMsg struct {
|
||||
PeersID uint32 `sshtype:"100"`
|
||||
}
|
||||
|
||||
// See RFC 4254, section 5.3
|
||||
const msgChannelClose = 97
|
||||
|
||||
type channelCloseMsg struct {
|
||||
PeersID uint32 `sshtype:"97"`
|
||||
}
|
||||
|
||||
// See RFC 4254, section 5.3
|
||||
const msgChannelEOF = 96
|
||||
|
||||
type channelEOFMsg struct {
|
||||
PeersID uint32 `sshtype:"96"`
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue