wip: ssh cli mode

This commit is contained in:
Joe Kralicky 2025-03-21 18:38:26 +00:00
parent 5e06f2aef9
commit 3225d3b032
No known key found for this signature in database
GPG key ID: 75C4875F34A9FB79

View file

@ -15,6 +15,7 @@ import (
"strings" "strings"
"sync" "sync"
"sync/atomic" "sync/atomic"
"text/template"
"time" "time"
"github.com/charmbracelet/bubbles/list" "github.com/charmbracelet/bubbles/list"
@ -45,6 +46,7 @@ import (
"google.golang.org/protobuf/proto" "google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/known/anypb" "google.golang.org/protobuf/types/known/anypb"
"google.golang.org/protobuf/types/known/emptypb" "google.golang.org/protobuf/types/known/emptypb"
"google.golang.org/protobuf/types/known/structpb"
"google.golang.org/protobuf/types/known/timestamppb" "google.golang.org/protobuf/types/known/timestamppb"
"google.golang.org/protobuf/types/known/wrapperspb" "google.golang.org/protobuf/types/known/wrapperspb"
) )
@ -226,24 +228,6 @@ func (a *Authorize) ManageStream(
if authReq.Username == "" { if authReq.Username == "" {
return fmt.Errorf("no username given") return fmt.Errorf("no username given")
} }
if authReq.Hostname == "" {
resp := extensions_ssh.ServerMessage{
Message: &extensions_ssh.ServerMessage_AuthResponse{
AuthResponse: &extensions_ssh.AuthenticationResponse{
Response: &extensions_ssh.AuthenticationResponse_Allow{
Allow: &extensions_ssh.AllowResponse{
Username: state.Username,
Target: &extensions_ssh.AllowResponse_Internal{
Internal: &extensions_ssh.InternalTarget{},
},
},
},
},
},
}
sendC <- &resp
continue
}
if session != nil { if session != nil {
// Perform authorize check for this route // Perform authorize check for this route
@ -255,7 +239,7 @@ func (a *Authorize) ManageStream(
if err != nil { if err != nil {
return err return err
} }
sendC <- handleEvaluatorResponseForSSH(res, &state) sendC <- handleEvaluatorResponseForSSH(res, &state, session.Id)
if res.Allow.Value && !res.Deny.Value { if res.Allow.Value && !res.Deny.Value {
a.startContinuousAuthorization(ctx, errC, req, session.Id) a.startContinuousAuthorization(ctx, errC, req, session.Id)
@ -279,9 +263,9 @@ func (a *Authorize) ManageStream(
} }
case "keyboard-interactive": case "keyboard-interactive":
route := a.getSSHRouteForHostname(state.Hostname) route := a.getSSHRouteForHostname(state.Hostname)
if route == nil { // if route == nil {
return fmt.Errorf("invalid route") // return fmt.Errorf("invalid route")
} // }
opts := a.currentConfig.Load().Options opts := a.currentConfig.Load().Options
idp, err := opts.GetIdentityProviderForPolicy(route) idp, err := opts.GetIdentityProviderForPolicy(route)
@ -393,7 +377,7 @@ func (a *Authorize) ManageStream(
if err != nil { if err != nil {
return err return err
} }
sendC <- handleEvaluatorResponseForSSH(res, &state) sendC <- handleEvaluatorResponseForSSH(res, &state, sessionState.Load().ID)
if res.Allow.Value && !res.Deny.Value { if res.Allow.Value && !res.Deny.Value {
a.startContinuousAuthorization(ctx, errC, req, sessionState.Load().ID) a.startContinuousAuthorization(ctx, errC, req, sessionState.Load().ID)
@ -468,7 +452,14 @@ func (a *Authorize) getEvaluatorRequestFromSSHAuthRequest(
} }
route := a.getSSHRouteForHostname(state.Hostname) route := a.getSSHRouteForHostname(state.Hostname)
if route == nil { if route == nil {
return nil, fmt.Errorf("no route found for hostname %q", state.Hostname) return &evaluator.Request{
IsInternal: true,
Session: evaluator.RequestSession{
ID: sessionID,
},
}, nil
// return nil, fmt.Errorf("no route found for hostname %q", state.Hostname)
} }
req := &evaluator.Request{ req := &evaluator.Request{
IsInternal: false, IsInternal: false,
@ -485,7 +476,7 @@ func (a *Authorize) getEvaluatorRequestFromSSHAuthRequest(
} }
func handleEvaluatorResponseForSSH( func handleEvaluatorResponseForSSH(
result *evaluator.Result, state *StreamState, result *evaluator.Result, state *StreamState, sessionID string,
) *extensions_ssh.ServerMessage { ) *extensions_ssh.ServerMessage {
// fmt.Printf(" *** evaluator result: %+v\n", result) // fmt.Printf(" *** evaluator result: %+v\n", result)
@ -494,6 +485,33 @@ func handleEvaluatorResponseForSSH(
if allow { if allow {
pkData, _ := anypb.New(publicKeyAllowResponse(state.PublicKey)) pkData, _ := anypb.New(publicKeyAllowResponse(state.PublicKey))
if state.Hostname == "" {
return &extensions_ssh.ServerMessage{
Message: &extensions_ssh.ServerMessage_AuthResponse{
AuthResponse: &extensions_ssh.AuthenticationResponse{
Response: &extensions_ssh.AuthenticationResponse_Allow{
Allow: &extensions_ssh.AllowResponse{
Username: state.Username,
Target: &extensions_ssh.AllowResponse_Internal{
Internal: &extensions_ssh.InternalTarget{
SetMetadata: &corev3.Metadata{
FilterMetadata: map[string]*structpb.Struct{
"pomerium": {
Fields: map[string]*structpb.Value{
"session-id": structpb.NewStringValue(sessionID),
},
},
},
},
},
},
},
},
},
},
}
}
sessionRecordingExt, _ := anypb.New(&extensions_session_recording.UpstreamTargetExtensionConfig{ sessionRecordingExt, _ := anypb.New(&extensions_session_recording.UpstreamTargetExtensionConfig{
RecordingName: fmt.Sprintf("session-%s-at-%s-%d.cast", state.Username, state.Hostname, time.Now().UnixNano()), RecordingName: fmt.Sprintf("session-%s-at-%s-%d.cast", state.Username, state.Hostname, time.Now().UnixNano()),
Format: extensions_session_recording.Format_AsciicastFormat, Format: extensions_session_recording.Format_AsciicastFormat,
@ -712,6 +730,8 @@ func (a *Authorize) ServeChannel(
} }
} }
}() }()
var sessionID atomic.Pointer[string]
go func() { go func() {
localWindow := uint32(channelWindowSize) localWindow := uint32(channelWindowSize)
for { for {
@ -724,6 +744,16 @@ func (a *Authorize) ServeChannel(
errC <- err errC <- err
return return
} }
if sessionID.Load() == nil {
mdMsg, ok := channelMsg.Message.(*extensions_ssh.ChannelMessage_Metadata)
if !ok {
errC <- fmt.Errorf("first message was not metadata")
return
}
id := mdMsg.Metadata.FilterMetadata["pomerium"].Fields["session-id"].GetStringValue()
sessionID.Store(&id)
continue
}
if raw, ok := channelMsg.Message.(*extensions_ssh.ChannelMessage_RawBytes); ok { if raw, ok := channelMsg.Message.(*extensions_ssh.ChannelMessage_RawBytes); ok {
msgLen := uint32(len(raw.RawBytes.GetValue())) msgLen := uint32(len(raw.RawBytes.GetValue()))
if msgLen == 0 { if msgLen == 0 {
@ -847,7 +877,7 @@ func (a *Authorize) ServeChannel(
return err return err
} }
cmd := NewSSHCLI(a.currentConfig.Load(), downstreamPtyInfo, downstreamChannelInfo, inputR, outputW, sendC, &activeProgram) cmd := a.NewSSHCLI(a.currentConfig.Load(), downstreamPtyInfo, downstreamChannelInfo, *sessionID.Load(), inputR, outputW, sendC, &activeProgram)
if msg.Request == "shell" { if msg.Request == "shell" {
cmd.SetArgs([]string{"portal"}) cmd.SetArgs([]string{"portal"})
} else { } else {
@ -956,15 +986,16 @@ func streamOutputToChannel(sendC chan<- any, channelID uint32, outputR io.Reader
sendC <- channelDataMsg{ sendC <- channelDataMsg{
PeersID: channelID, PeersID: channelID,
Length: uint32(n), Length: uint32(n),
Rest: buf[:n], Rest: slices.Clone(buf[:n]),
} }
} }
} }
func NewSSHCLI( func (a *Authorize) NewSSHCLI(
cfg *config.Config, cfg *config.Config,
ptyInfo *extensions_ssh.SSHDownstreamPTYInfo, ptyInfo *extensions_ssh.SSHDownstreamPTYInfo,
channelInfo *extensions_ssh.SSHDownstreamChannelInfo, channelInfo *extensions_ssh.SSHDownstreamChannelInfo,
sessionID string,
stdin io.Reader, stdin io.Reader,
stdout io.Writer, stdout io.Writer,
sendC chan any, sendC chan any,
@ -986,12 +1017,66 @@ func NewSSHCLI(
}, },
} }
cmd.AddCommand(NewPortalCommand(cfg, ptyInfo, channelInfo, sendC, activeProgram)) cmd.AddCommand(NewPortalCommand(cfg, ptyInfo, channelInfo, sendC, activeProgram))
cmd.AddCommand(a.NewLogoutCommand(cfg, sessionID))
cmd.AddCommand(a.NewWhoamiCommand(cfg, sessionID))
cmd.CompletionOptions.DisableDefaultCmd = true
cmd.SetIn(stdin) cmd.SetIn(stdin)
cmd.SetOut(stdout) cmd.SetOut(stdout)
cmd.SetErr(stdout) cmd.SetErr(stdout)
return cmd return cmd
} }
func (a *Authorize) NewLogoutCommand(
cfg *config.Config,
sessionID string,
) *cobra.Command {
cmd := &cobra.Command{
Use: "logout",
Short: "Log out",
RunE: func(cmd *cobra.Command, args []string) error {
client := a.state.Load().dataBrokerClient
err := session.Delete(cmd.Context(), client, sessionID)
if err != nil {
return fmt.Errorf("internal error: %w", err)
}
cmd.OutOrStdout().Write([]byte("Logged out successfully\r\n"))
return nil
},
}
return cmd
}
var whoamiTmpl = template.Must(template.New("whoami").Parse(`
User ID: {{.UserId}}
Session ID: {{.Id}}
Expires at: {{.ExpiresAt.AsTime}}
Claims:
{{- range $k, $v := .Claims }}
{{ $k }}: {{ $v.AsSlice }}
{{- end }}
`))
func (a *Authorize) NewWhoamiCommand(
cfg *config.Config,
sessionID string,
) *cobra.Command {
cmd := &cobra.Command{
Use: "whoami",
RunE: func(cmd *cobra.Command, args []string) error {
client := a.state.Load().dataBrokerClient
s, err := session.Get(cmd.Context(), client, sessionID)
if err != nil {
return fmt.Errorf("couldn't fetch session: %w", err)
}
var b bytes.Buffer
whoamiTmpl.Execute(&b, s)
cmd.OutOrStdout().Write([]byte(b.String() + "\r\n"))
return nil
},
}
return cmd
}
func NewPortalCommand( func NewPortalCommand(
cfg *config.Config, cfg *config.Config,
ptyInfo *extensions_ssh.SSHDownstreamPTYInfo, ptyInfo *extensions_ssh.SSHDownstreamPTYInfo,