mirror of
https://github.com/pomerium/pomerium.git
synced 2025-04-29 18:36:30 +02:00
wip: ssh cli mode
This commit is contained in:
parent
5e06f2aef9
commit
3225d3b032
1 changed files with 113 additions and 28 deletions
|
@ -15,6 +15,7 @@ import (
|
|||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"text/template"
|
||||
"time"
|
||||
|
||||
"github.com/charmbracelet/bubbles/list"
|
||||
|
@ -45,6 +46,7 @@ import (
|
|||
"google.golang.org/protobuf/proto"
|
||||
"google.golang.org/protobuf/types/known/anypb"
|
||||
"google.golang.org/protobuf/types/known/emptypb"
|
||||
"google.golang.org/protobuf/types/known/structpb"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
"google.golang.org/protobuf/types/known/wrapperspb"
|
||||
)
|
||||
|
@ -226,24 +228,6 @@ func (a *Authorize) ManageStream(
|
|||
if authReq.Username == "" {
|
||||
return fmt.Errorf("no username given")
|
||||
}
|
||||
if authReq.Hostname == "" {
|
||||
resp := extensions_ssh.ServerMessage{
|
||||
Message: &extensions_ssh.ServerMessage_AuthResponse{
|
||||
AuthResponse: &extensions_ssh.AuthenticationResponse{
|
||||
Response: &extensions_ssh.AuthenticationResponse_Allow{
|
||||
Allow: &extensions_ssh.AllowResponse{
|
||||
Username: state.Username,
|
||||
Target: &extensions_ssh.AllowResponse_Internal{
|
||||
Internal: &extensions_ssh.InternalTarget{},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
sendC <- &resp
|
||||
continue
|
||||
}
|
||||
|
||||
if session != nil {
|
||||
// Perform authorize check for this route
|
||||
|
@ -255,7 +239,7 @@ func (a *Authorize) ManageStream(
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
sendC <- handleEvaluatorResponseForSSH(res, &state)
|
||||
sendC <- handleEvaluatorResponseForSSH(res, &state, session.Id)
|
||||
|
||||
if res.Allow.Value && !res.Deny.Value {
|
||||
a.startContinuousAuthorization(ctx, errC, req, session.Id)
|
||||
|
@ -279,9 +263,9 @@ func (a *Authorize) ManageStream(
|
|||
}
|
||||
case "keyboard-interactive":
|
||||
route := a.getSSHRouteForHostname(state.Hostname)
|
||||
if route == nil {
|
||||
return fmt.Errorf("invalid route")
|
||||
}
|
||||
// if route == nil {
|
||||
// return fmt.Errorf("invalid route")
|
||||
// }
|
||||
|
||||
opts := a.currentConfig.Load().Options
|
||||
idp, err := opts.GetIdentityProviderForPolicy(route)
|
||||
|
@ -393,7 +377,7 @@ func (a *Authorize) ManageStream(
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
sendC <- handleEvaluatorResponseForSSH(res, &state)
|
||||
sendC <- handleEvaluatorResponseForSSH(res, &state, sessionState.Load().ID)
|
||||
|
||||
if res.Allow.Value && !res.Deny.Value {
|
||||
a.startContinuousAuthorization(ctx, errC, req, sessionState.Load().ID)
|
||||
|
@ -468,7 +452,14 @@ func (a *Authorize) getEvaluatorRequestFromSSHAuthRequest(
|
|||
}
|
||||
route := a.getSSHRouteForHostname(state.Hostname)
|
||||
if route == nil {
|
||||
return nil, fmt.Errorf("no route found for hostname %q", state.Hostname)
|
||||
return &evaluator.Request{
|
||||
IsInternal: true,
|
||||
Session: evaluator.RequestSession{
|
||||
ID: sessionID,
|
||||
},
|
||||
}, nil
|
||||
|
||||
// return nil, fmt.Errorf("no route found for hostname %q", state.Hostname)
|
||||
}
|
||||
req := &evaluator.Request{
|
||||
IsInternal: false,
|
||||
|
@ -485,7 +476,7 @@ func (a *Authorize) getEvaluatorRequestFromSSHAuthRequest(
|
|||
}
|
||||
|
||||
func handleEvaluatorResponseForSSH(
|
||||
result *evaluator.Result, state *StreamState,
|
||||
result *evaluator.Result, state *StreamState, sessionID string,
|
||||
) *extensions_ssh.ServerMessage {
|
||||
// fmt.Printf(" *** evaluator result: %+v\n", result)
|
||||
|
||||
|
@ -494,6 +485,33 @@ func handleEvaluatorResponseForSSH(
|
|||
|
||||
if allow {
|
||||
pkData, _ := anypb.New(publicKeyAllowResponse(state.PublicKey))
|
||||
|
||||
if state.Hostname == "" {
|
||||
return &extensions_ssh.ServerMessage{
|
||||
Message: &extensions_ssh.ServerMessage_AuthResponse{
|
||||
AuthResponse: &extensions_ssh.AuthenticationResponse{
|
||||
Response: &extensions_ssh.AuthenticationResponse_Allow{
|
||||
Allow: &extensions_ssh.AllowResponse{
|
||||
Username: state.Username,
|
||||
Target: &extensions_ssh.AllowResponse_Internal{
|
||||
Internal: &extensions_ssh.InternalTarget{
|
||||
SetMetadata: &corev3.Metadata{
|
||||
FilterMetadata: map[string]*structpb.Struct{
|
||||
"pomerium": {
|
||||
Fields: map[string]*structpb.Value{
|
||||
"session-id": structpb.NewStringValue(sessionID),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
sessionRecordingExt, _ := anypb.New(&extensions_session_recording.UpstreamTargetExtensionConfig{
|
||||
RecordingName: fmt.Sprintf("session-%s-at-%s-%d.cast", state.Username, state.Hostname, time.Now().UnixNano()),
|
||||
Format: extensions_session_recording.Format_AsciicastFormat,
|
||||
|
@ -712,6 +730,8 @@ func (a *Authorize) ServeChannel(
|
|||
}
|
||||
}
|
||||
}()
|
||||
|
||||
var sessionID atomic.Pointer[string]
|
||||
go func() {
|
||||
localWindow := uint32(channelWindowSize)
|
||||
for {
|
||||
|
@ -724,6 +744,16 @@ func (a *Authorize) ServeChannel(
|
|||
errC <- err
|
||||
return
|
||||
}
|
||||
if sessionID.Load() == nil {
|
||||
mdMsg, ok := channelMsg.Message.(*extensions_ssh.ChannelMessage_Metadata)
|
||||
if !ok {
|
||||
errC <- fmt.Errorf("first message was not metadata")
|
||||
return
|
||||
}
|
||||
id := mdMsg.Metadata.FilterMetadata["pomerium"].Fields["session-id"].GetStringValue()
|
||||
sessionID.Store(&id)
|
||||
continue
|
||||
}
|
||||
if raw, ok := channelMsg.Message.(*extensions_ssh.ChannelMessage_RawBytes); ok {
|
||||
msgLen := uint32(len(raw.RawBytes.GetValue()))
|
||||
if msgLen == 0 {
|
||||
|
@ -847,7 +877,7 @@ func (a *Authorize) ServeChannel(
|
|||
return err
|
||||
}
|
||||
|
||||
cmd := NewSSHCLI(a.currentConfig.Load(), downstreamPtyInfo, downstreamChannelInfo, inputR, outputW, sendC, &activeProgram)
|
||||
cmd := a.NewSSHCLI(a.currentConfig.Load(), downstreamPtyInfo, downstreamChannelInfo, *sessionID.Load(), inputR, outputW, sendC, &activeProgram)
|
||||
if msg.Request == "shell" {
|
||||
cmd.SetArgs([]string{"portal"})
|
||||
} else {
|
||||
|
@ -956,15 +986,16 @@ func streamOutputToChannel(sendC chan<- any, channelID uint32, outputR io.Reader
|
|||
sendC <- channelDataMsg{
|
||||
PeersID: channelID,
|
||||
Length: uint32(n),
|
||||
Rest: buf[:n],
|
||||
Rest: slices.Clone(buf[:n]),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func NewSSHCLI(
|
||||
func (a *Authorize) NewSSHCLI(
|
||||
cfg *config.Config,
|
||||
ptyInfo *extensions_ssh.SSHDownstreamPTYInfo,
|
||||
channelInfo *extensions_ssh.SSHDownstreamChannelInfo,
|
||||
sessionID string,
|
||||
stdin io.Reader,
|
||||
stdout io.Writer,
|
||||
sendC chan any,
|
||||
|
@ -986,12 +1017,66 @@ func NewSSHCLI(
|
|||
},
|
||||
}
|
||||
cmd.AddCommand(NewPortalCommand(cfg, ptyInfo, channelInfo, sendC, activeProgram))
|
||||
cmd.AddCommand(a.NewLogoutCommand(cfg, sessionID))
|
||||
cmd.AddCommand(a.NewWhoamiCommand(cfg, sessionID))
|
||||
cmd.CompletionOptions.DisableDefaultCmd = true
|
||||
cmd.SetIn(stdin)
|
||||
cmd.SetOut(stdout)
|
||||
cmd.SetErr(stdout)
|
||||
return cmd
|
||||
}
|
||||
|
||||
func (a *Authorize) NewLogoutCommand(
|
||||
cfg *config.Config,
|
||||
sessionID string,
|
||||
) *cobra.Command {
|
||||
cmd := &cobra.Command{
|
||||
Use: "logout",
|
||||
Short: "Log out",
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
client := a.state.Load().dataBrokerClient
|
||||
err := session.Delete(cmd.Context(), client, sessionID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("internal error: %w", err)
|
||||
}
|
||||
cmd.OutOrStdout().Write([]byte("Logged out successfully\r\n"))
|
||||
return nil
|
||||
},
|
||||
}
|
||||
return cmd
|
||||
}
|
||||
|
||||
var whoamiTmpl = template.Must(template.New("whoami").Parse(`
|
||||
User ID: {{.UserId}}
|
||||
Session ID: {{.Id}}
|
||||
Expires at: {{.ExpiresAt.AsTime}}
|
||||
Claims:
|
||||
{{- range $k, $v := .Claims }}
|
||||
{{ $k }}: {{ $v.AsSlice }}
|
||||
{{- end }}
|
||||
`))
|
||||
|
||||
func (a *Authorize) NewWhoamiCommand(
|
||||
cfg *config.Config,
|
||||
sessionID string,
|
||||
) *cobra.Command {
|
||||
cmd := &cobra.Command{
|
||||
Use: "whoami",
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
client := a.state.Load().dataBrokerClient
|
||||
s, err := session.Get(cmd.Context(), client, sessionID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("couldn't fetch session: %w", err)
|
||||
}
|
||||
var b bytes.Buffer
|
||||
whoamiTmpl.Execute(&b, s)
|
||||
cmd.OutOrStdout().Write([]byte(b.String() + "\r\n"))
|
||||
return nil
|
||||
},
|
||||
}
|
||||
return cmd
|
||||
}
|
||||
|
||||
func NewPortalCommand(
|
||||
cfg *config.Config,
|
||||
ptyInfo *extensions_ssh.SSHDownstreamPTYInfo,
|
||||
|
|
Loading…
Add table
Reference in a new issue