mirror of
https://github.com/pomerium/pomerium.git
synced 2025-04-29 02:16:28 +02:00
wip: ssh cli mode
This commit is contained in:
parent
5e06f2aef9
commit
3225d3b032
1 changed files with 113 additions and 28 deletions
|
@ -15,6 +15,7 @@ import (
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
|
"text/template"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/charmbracelet/bubbles/list"
|
"github.com/charmbracelet/bubbles/list"
|
||||||
|
@ -45,6 +46,7 @@ import (
|
||||||
"google.golang.org/protobuf/proto"
|
"google.golang.org/protobuf/proto"
|
||||||
"google.golang.org/protobuf/types/known/anypb"
|
"google.golang.org/protobuf/types/known/anypb"
|
||||||
"google.golang.org/protobuf/types/known/emptypb"
|
"google.golang.org/protobuf/types/known/emptypb"
|
||||||
|
"google.golang.org/protobuf/types/known/structpb"
|
||||||
"google.golang.org/protobuf/types/known/timestamppb"
|
"google.golang.org/protobuf/types/known/timestamppb"
|
||||||
"google.golang.org/protobuf/types/known/wrapperspb"
|
"google.golang.org/protobuf/types/known/wrapperspb"
|
||||||
)
|
)
|
||||||
|
@ -226,24 +228,6 @@ func (a *Authorize) ManageStream(
|
||||||
if authReq.Username == "" {
|
if authReq.Username == "" {
|
||||||
return fmt.Errorf("no username given")
|
return fmt.Errorf("no username given")
|
||||||
}
|
}
|
||||||
if authReq.Hostname == "" {
|
|
||||||
resp := extensions_ssh.ServerMessage{
|
|
||||||
Message: &extensions_ssh.ServerMessage_AuthResponse{
|
|
||||||
AuthResponse: &extensions_ssh.AuthenticationResponse{
|
|
||||||
Response: &extensions_ssh.AuthenticationResponse_Allow{
|
|
||||||
Allow: &extensions_ssh.AllowResponse{
|
|
||||||
Username: state.Username,
|
|
||||||
Target: &extensions_ssh.AllowResponse_Internal{
|
|
||||||
Internal: &extensions_ssh.InternalTarget{},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
sendC <- &resp
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
if session != nil {
|
if session != nil {
|
||||||
// Perform authorize check for this route
|
// Perform authorize check for this route
|
||||||
|
@ -255,7 +239,7 @@ func (a *Authorize) ManageStream(
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
sendC <- handleEvaluatorResponseForSSH(res, &state)
|
sendC <- handleEvaluatorResponseForSSH(res, &state, session.Id)
|
||||||
|
|
||||||
if res.Allow.Value && !res.Deny.Value {
|
if res.Allow.Value && !res.Deny.Value {
|
||||||
a.startContinuousAuthorization(ctx, errC, req, session.Id)
|
a.startContinuousAuthorization(ctx, errC, req, session.Id)
|
||||||
|
@ -279,9 +263,9 @@ func (a *Authorize) ManageStream(
|
||||||
}
|
}
|
||||||
case "keyboard-interactive":
|
case "keyboard-interactive":
|
||||||
route := a.getSSHRouteForHostname(state.Hostname)
|
route := a.getSSHRouteForHostname(state.Hostname)
|
||||||
if route == nil {
|
// if route == nil {
|
||||||
return fmt.Errorf("invalid route")
|
// return fmt.Errorf("invalid route")
|
||||||
}
|
// }
|
||||||
|
|
||||||
opts := a.currentConfig.Load().Options
|
opts := a.currentConfig.Load().Options
|
||||||
idp, err := opts.GetIdentityProviderForPolicy(route)
|
idp, err := opts.GetIdentityProviderForPolicy(route)
|
||||||
|
@ -393,7 +377,7 @@ func (a *Authorize) ManageStream(
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
sendC <- handleEvaluatorResponseForSSH(res, &state)
|
sendC <- handleEvaluatorResponseForSSH(res, &state, sessionState.Load().ID)
|
||||||
|
|
||||||
if res.Allow.Value && !res.Deny.Value {
|
if res.Allow.Value && !res.Deny.Value {
|
||||||
a.startContinuousAuthorization(ctx, errC, req, sessionState.Load().ID)
|
a.startContinuousAuthorization(ctx, errC, req, sessionState.Load().ID)
|
||||||
|
@ -468,7 +452,14 @@ func (a *Authorize) getEvaluatorRequestFromSSHAuthRequest(
|
||||||
}
|
}
|
||||||
route := a.getSSHRouteForHostname(state.Hostname)
|
route := a.getSSHRouteForHostname(state.Hostname)
|
||||||
if route == nil {
|
if route == nil {
|
||||||
return nil, fmt.Errorf("no route found for hostname %q", state.Hostname)
|
return &evaluator.Request{
|
||||||
|
IsInternal: true,
|
||||||
|
Session: evaluator.RequestSession{
|
||||||
|
ID: sessionID,
|
||||||
|
},
|
||||||
|
}, nil
|
||||||
|
|
||||||
|
// return nil, fmt.Errorf("no route found for hostname %q", state.Hostname)
|
||||||
}
|
}
|
||||||
req := &evaluator.Request{
|
req := &evaluator.Request{
|
||||||
IsInternal: false,
|
IsInternal: false,
|
||||||
|
@ -485,7 +476,7 @@ func (a *Authorize) getEvaluatorRequestFromSSHAuthRequest(
|
||||||
}
|
}
|
||||||
|
|
||||||
func handleEvaluatorResponseForSSH(
|
func handleEvaluatorResponseForSSH(
|
||||||
result *evaluator.Result, state *StreamState,
|
result *evaluator.Result, state *StreamState, sessionID string,
|
||||||
) *extensions_ssh.ServerMessage {
|
) *extensions_ssh.ServerMessage {
|
||||||
// fmt.Printf(" *** evaluator result: %+v\n", result)
|
// fmt.Printf(" *** evaluator result: %+v\n", result)
|
||||||
|
|
||||||
|
@ -494,6 +485,33 @@ func handleEvaluatorResponseForSSH(
|
||||||
|
|
||||||
if allow {
|
if allow {
|
||||||
pkData, _ := anypb.New(publicKeyAllowResponse(state.PublicKey))
|
pkData, _ := anypb.New(publicKeyAllowResponse(state.PublicKey))
|
||||||
|
|
||||||
|
if state.Hostname == "" {
|
||||||
|
return &extensions_ssh.ServerMessage{
|
||||||
|
Message: &extensions_ssh.ServerMessage_AuthResponse{
|
||||||
|
AuthResponse: &extensions_ssh.AuthenticationResponse{
|
||||||
|
Response: &extensions_ssh.AuthenticationResponse_Allow{
|
||||||
|
Allow: &extensions_ssh.AllowResponse{
|
||||||
|
Username: state.Username,
|
||||||
|
Target: &extensions_ssh.AllowResponse_Internal{
|
||||||
|
Internal: &extensions_ssh.InternalTarget{
|
||||||
|
SetMetadata: &corev3.Metadata{
|
||||||
|
FilterMetadata: map[string]*structpb.Struct{
|
||||||
|
"pomerium": {
|
||||||
|
Fields: map[string]*structpb.Value{
|
||||||
|
"session-id": structpb.NewStringValue(sessionID),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
sessionRecordingExt, _ := anypb.New(&extensions_session_recording.UpstreamTargetExtensionConfig{
|
sessionRecordingExt, _ := anypb.New(&extensions_session_recording.UpstreamTargetExtensionConfig{
|
||||||
RecordingName: fmt.Sprintf("session-%s-at-%s-%d.cast", state.Username, state.Hostname, time.Now().UnixNano()),
|
RecordingName: fmt.Sprintf("session-%s-at-%s-%d.cast", state.Username, state.Hostname, time.Now().UnixNano()),
|
||||||
Format: extensions_session_recording.Format_AsciicastFormat,
|
Format: extensions_session_recording.Format_AsciicastFormat,
|
||||||
|
@ -712,6 +730,8 @@ func (a *Authorize) ServeChannel(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
var sessionID atomic.Pointer[string]
|
||||||
go func() {
|
go func() {
|
||||||
localWindow := uint32(channelWindowSize)
|
localWindow := uint32(channelWindowSize)
|
||||||
for {
|
for {
|
||||||
|
@ -724,6 +744,16 @@ func (a *Authorize) ServeChannel(
|
||||||
errC <- err
|
errC <- err
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
if sessionID.Load() == nil {
|
||||||
|
mdMsg, ok := channelMsg.Message.(*extensions_ssh.ChannelMessage_Metadata)
|
||||||
|
if !ok {
|
||||||
|
errC <- fmt.Errorf("first message was not metadata")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
id := mdMsg.Metadata.FilterMetadata["pomerium"].Fields["session-id"].GetStringValue()
|
||||||
|
sessionID.Store(&id)
|
||||||
|
continue
|
||||||
|
}
|
||||||
if raw, ok := channelMsg.Message.(*extensions_ssh.ChannelMessage_RawBytes); ok {
|
if raw, ok := channelMsg.Message.(*extensions_ssh.ChannelMessage_RawBytes); ok {
|
||||||
msgLen := uint32(len(raw.RawBytes.GetValue()))
|
msgLen := uint32(len(raw.RawBytes.GetValue()))
|
||||||
if msgLen == 0 {
|
if msgLen == 0 {
|
||||||
|
@ -847,7 +877,7 @@ func (a *Authorize) ServeChannel(
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
cmd := NewSSHCLI(a.currentConfig.Load(), downstreamPtyInfo, downstreamChannelInfo, inputR, outputW, sendC, &activeProgram)
|
cmd := a.NewSSHCLI(a.currentConfig.Load(), downstreamPtyInfo, downstreamChannelInfo, *sessionID.Load(), inputR, outputW, sendC, &activeProgram)
|
||||||
if msg.Request == "shell" {
|
if msg.Request == "shell" {
|
||||||
cmd.SetArgs([]string{"portal"})
|
cmd.SetArgs([]string{"portal"})
|
||||||
} else {
|
} else {
|
||||||
|
@ -956,15 +986,16 @@ func streamOutputToChannel(sendC chan<- any, channelID uint32, outputR io.Reader
|
||||||
sendC <- channelDataMsg{
|
sendC <- channelDataMsg{
|
||||||
PeersID: channelID,
|
PeersID: channelID,
|
||||||
Length: uint32(n),
|
Length: uint32(n),
|
||||||
Rest: buf[:n],
|
Rest: slices.Clone(buf[:n]),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewSSHCLI(
|
func (a *Authorize) NewSSHCLI(
|
||||||
cfg *config.Config,
|
cfg *config.Config,
|
||||||
ptyInfo *extensions_ssh.SSHDownstreamPTYInfo,
|
ptyInfo *extensions_ssh.SSHDownstreamPTYInfo,
|
||||||
channelInfo *extensions_ssh.SSHDownstreamChannelInfo,
|
channelInfo *extensions_ssh.SSHDownstreamChannelInfo,
|
||||||
|
sessionID string,
|
||||||
stdin io.Reader,
|
stdin io.Reader,
|
||||||
stdout io.Writer,
|
stdout io.Writer,
|
||||||
sendC chan any,
|
sendC chan any,
|
||||||
|
@ -986,12 +1017,66 @@ func NewSSHCLI(
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
cmd.AddCommand(NewPortalCommand(cfg, ptyInfo, channelInfo, sendC, activeProgram))
|
cmd.AddCommand(NewPortalCommand(cfg, ptyInfo, channelInfo, sendC, activeProgram))
|
||||||
|
cmd.AddCommand(a.NewLogoutCommand(cfg, sessionID))
|
||||||
|
cmd.AddCommand(a.NewWhoamiCommand(cfg, sessionID))
|
||||||
|
cmd.CompletionOptions.DisableDefaultCmd = true
|
||||||
cmd.SetIn(stdin)
|
cmd.SetIn(stdin)
|
||||||
cmd.SetOut(stdout)
|
cmd.SetOut(stdout)
|
||||||
cmd.SetErr(stdout)
|
cmd.SetErr(stdout)
|
||||||
return cmd
|
return cmd
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a *Authorize) NewLogoutCommand(
|
||||||
|
cfg *config.Config,
|
||||||
|
sessionID string,
|
||||||
|
) *cobra.Command {
|
||||||
|
cmd := &cobra.Command{
|
||||||
|
Use: "logout",
|
||||||
|
Short: "Log out",
|
||||||
|
RunE: func(cmd *cobra.Command, args []string) error {
|
||||||
|
client := a.state.Load().dataBrokerClient
|
||||||
|
err := session.Delete(cmd.Context(), client, sessionID)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("internal error: %w", err)
|
||||||
|
}
|
||||||
|
cmd.OutOrStdout().Write([]byte("Logged out successfully\r\n"))
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
return cmd
|
||||||
|
}
|
||||||
|
|
||||||
|
var whoamiTmpl = template.Must(template.New("whoami").Parse(`
|
||||||
|
User ID: {{.UserId}}
|
||||||
|
Session ID: {{.Id}}
|
||||||
|
Expires at: {{.ExpiresAt.AsTime}}
|
||||||
|
Claims:
|
||||||
|
{{- range $k, $v := .Claims }}
|
||||||
|
{{ $k }}: {{ $v.AsSlice }}
|
||||||
|
{{- end }}
|
||||||
|
`))
|
||||||
|
|
||||||
|
func (a *Authorize) NewWhoamiCommand(
|
||||||
|
cfg *config.Config,
|
||||||
|
sessionID string,
|
||||||
|
) *cobra.Command {
|
||||||
|
cmd := &cobra.Command{
|
||||||
|
Use: "whoami",
|
||||||
|
RunE: func(cmd *cobra.Command, args []string) error {
|
||||||
|
client := a.state.Load().dataBrokerClient
|
||||||
|
s, err := session.Get(cmd.Context(), client, sessionID)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("couldn't fetch session: %w", err)
|
||||||
|
}
|
||||||
|
var b bytes.Buffer
|
||||||
|
whoamiTmpl.Execute(&b, s)
|
||||||
|
cmd.OutOrStdout().Write([]byte(b.String() + "\r\n"))
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
return cmd
|
||||||
|
}
|
||||||
|
|
||||||
func NewPortalCommand(
|
func NewPortalCommand(
|
||||||
cfg *config.Config,
|
cfg *config.Config,
|
||||||
ptyInfo *extensions_ssh.SSHDownstreamPTYInfo,
|
ptyInfo *extensions_ssh.SSHDownstreamPTYInfo,
|
||||||
|
|
Loading…
Add table
Reference in a new issue