mirror of
https://github.com/pomerium/pomerium.git
synced 2025-08-06 10:21:05 +02:00
ssh: add runtime flag for jump host mode (#5699)
Adds a new runtime flag `ssh_allow_direct_tcpip` (default false) which enables the "jump-host mode". This is disabled by default since we are missing related config options/policy criteria.
This commit is contained in:
parent
624622f236
commit
559545f686
4 changed files with 71 additions and 20 deletions
|
@ -33,7 +33,12 @@ var (
|
||||||
// RuntimeFlagMCP enables the MCP services for the authorize service
|
// RuntimeFlagMCP enables the MCP services for the authorize service
|
||||||
RuntimeFlagMCP = runtimeFlag("mcp", false)
|
RuntimeFlagMCP = runtimeFlag("mcp", false)
|
||||||
|
|
||||||
|
// RuntimeFlagSSHRoutesPortal enables the SSH routes portal
|
||||||
RuntimeFlagSSHRoutesPortal = runtimeFlag("ssh_routes_portal", false)
|
RuntimeFlagSSHRoutesPortal = runtimeFlag("ssh_routes_portal", false)
|
||||||
|
|
||||||
|
// RuntimeFlagSSHAllowDirectTcpip allows downstream clients to open 'direct-tcpip'
|
||||||
|
// channels (jump host mode)
|
||||||
|
RuntimeFlagSSHAllowDirectTcpip = runtimeFlag("ssh_allow_direct_tcpip", false)
|
||||||
)
|
)
|
||||||
|
|
||||||
// RuntimeFlag is a runtime flag that can flip on/off certain features
|
// RuntimeFlag is a runtime flag that can flip on/off certain features
|
||||||
|
|
|
@ -22,6 +22,12 @@ type SSHConfig struct {
|
||||||
// upstream. An Ed25519 key will be generated if not set.
|
// upstream. An Ed25519 key will be generated if not set.
|
||||||
// Must be a type supported by [ssh.NewSignerFromKey].
|
// Must be a type supported by [ssh.NewSignerFromKey].
|
||||||
UserCAKey any
|
UserCAKey any
|
||||||
|
|
||||||
|
// If true, enables the 'ssh_allow_direct_tcpip' runtime flag
|
||||||
|
EnableDirectTcpip bool
|
||||||
|
|
||||||
|
// If true, enables the 'ssh_routes_portal' runtime flag
|
||||||
|
EnableRoutesPortal bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func SSH(c SSHConfig) testenv.Modifier {
|
func SSH(c SSHConfig) testenv.Modifier {
|
||||||
|
@ -44,6 +50,8 @@ func SSH(c SSHConfig) testenv.Modifier {
|
||||||
configHostKeys := slices.Map(c.HostKeys, marshalPrivateKey)
|
configHostKeys := slices.Map(c.HostKeys, marshalPrivateKey)
|
||||||
cfg.Options.SSHHostKeys = &configHostKeys
|
cfg.Options.SSHHostKeys = &configHostKeys
|
||||||
cfg.Options.SSHUserCAKey = marshalPrivateKey(c.UserCAKey)
|
cfg.Options.SSHUserCAKey = marshalPrivateKey(c.UserCAKey)
|
||||||
|
cfg.Options.RuntimeFlags[config.RuntimeFlagSSHAllowDirectTcpip] = c.EnableDirectTcpip
|
||||||
|
cfg.Options.RuntimeFlags[config.RuntimeFlagSSHRoutesPortal] = c.EnableRoutesPortal
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -21,6 +21,11 @@ import (
|
||||||
const (
|
const (
|
||||||
MethodPublicKey = "publickey"
|
MethodPublicKey = "publickey"
|
||||||
MethodKeyboardInteractive = "keyboard-interactive"
|
MethodKeyboardInteractive = "keyboard-interactive"
|
||||||
|
|
||||||
|
ChannelTypeSession = "session"
|
||||||
|
ChannelTypeDirectTcpip = "direct-tcpip"
|
||||||
|
|
||||||
|
ServiceConnection = "ssh-connection"
|
||||||
)
|
)
|
||||||
|
|
||||||
type KeyboardInteractiveQuerier interface {
|
type KeyboardInteractiveQuerier interface {
|
||||||
|
@ -71,6 +76,7 @@ type StreamAuthInfo struct {
|
||||||
Hostname *string
|
Hostname *string
|
||||||
StreamID uint64
|
StreamID uint64
|
||||||
SourceAddress string
|
SourceAddress string
|
||||||
|
ChannelType string
|
||||||
PublicKeyFingerprintSha256 []byte
|
PublicKeyFingerprintSha256 []byte
|
||||||
PublicKeyAllow AuthMethodValue[extensions_ssh.PublicKeyAllowResponse]
|
PublicKeyAllow AuthMethodValue[extensions_ssh.PublicKeyAllowResponse]
|
||||||
KeyboardInteractiveAllow AuthMethodValue[extensions_ssh.KeyboardInteractiveAllowResponse]
|
KeyboardInteractiveAllow AuthMethodValue[extensions_ssh.KeyboardInteractiveAllowResponse]
|
||||||
|
@ -83,7 +89,6 @@ func (i *StreamAuthInfo) allMethodsValid() bool {
|
||||||
|
|
||||||
type StreamState struct {
|
type StreamState struct {
|
||||||
StreamAuthInfo
|
StreamAuthInfo
|
||||||
DirectTcpip bool
|
|
||||||
RemainingUnauthenticatedMethods []string
|
RemainingUnauthenticatedMethods []string
|
||||||
DownstreamChannelInfo *extensions_ssh.SSHDownstreamChannelInfo
|
DownstreamChannelInfo *extensions_ssh.SSHDownstreamChannelInfo
|
||||||
}
|
}
|
||||||
|
@ -146,7 +151,7 @@ func (sh *StreamHandler) Prompt(ctx context.Context, prompts *extensions_ssh.Key
|
||||||
case req := <-sh.readC:
|
case req := <-sh.readC:
|
||||||
switch msg := req.Message.(type) {
|
switch msg := req.Message.(type) {
|
||||||
case *extensions_ssh.ClientMessage_InfoResponse:
|
case *extensions_ssh.ClientMessage_InfoResponse:
|
||||||
if msg.InfoResponse.Method != "keyboard-interactive" {
|
if msg.InfoResponse.Method != MethodKeyboardInteractive {
|
||||||
return nil, status.Errorf(codes.Internal, "received invalid info response")
|
return nil, status.Errorf(codes.Internal, "received invalid info response")
|
||||||
}
|
}
|
||||||
r, _ := msg.InfoResponse.Response.UnmarshalNew()
|
r, _ := msg.InfoResponse.Response.UnmarshalNew()
|
||||||
|
@ -235,9 +240,10 @@ func (sh *StreamHandler) ServeChannel(stream extensions_ssh.StreamManagement_Ser
|
||||||
InitialWindowSize: msg.PeersWindow,
|
InitialWindowSize: msg.PeersWindow,
|
||||||
MaxPacketSize: msg.MaxPacketSize,
|
MaxPacketSize: msg.MaxPacketSize,
|
||||||
}
|
}
|
||||||
|
sh.state.ChannelType = msg.ChanType
|
||||||
channel := NewChannelImpl(sh, stream, sh.state.DownstreamChannelInfo)
|
channel := NewChannelImpl(sh, stream, sh.state.DownstreamChannelInfo)
|
||||||
switch msg.ChanType {
|
switch msg.ChanType {
|
||||||
case "session":
|
case ChannelTypeSession:
|
||||||
if err := channel.SendMessage(ChannelOpenConfirmMsg{
|
if err := channel.SendMessage(ChannelOpenConfirmMsg{
|
||||||
PeersID: sh.state.DownstreamChannelInfo.DownstreamChannelId,
|
PeersID: sh.state.DownstreamChannelInfo.DownstreamChannelId,
|
||||||
MyID: sh.state.DownstreamChannelInfo.InternalUpstreamChannelId,
|
MyID: sh.state.DownstreamChannelInfo.InternalUpstreamChannelId,
|
||||||
|
@ -248,12 +254,14 @@ func (sh *StreamHandler) ServeChannel(stream extensions_ssh.StreamManagement_Ser
|
||||||
}
|
}
|
||||||
ch := NewChannelHandler(channel, sh.config)
|
ch := NewChannelHandler(channel, sh.config)
|
||||||
return ch.Run(stream.Context())
|
return ch.Run(stream.Context())
|
||||||
case "direct-tcpip":
|
case ChannelTypeDirectTcpip:
|
||||||
|
if !sh.config.Options.IsRuntimeFlagSet(config.RuntimeFlagSSHAllowDirectTcpip) {
|
||||||
|
return status.Errorf(codes.Unavailable, "direct-tcpip channels are not enabled")
|
||||||
|
}
|
||||||
var subMsg ChannelOpenDirectMsg
|
var subMsg ChannelOpenDirectMsg
|
||||||
if err := gossh.Unmarshal(msg.TypeSpecificData, &subMsg); err != nil {
|
if err := gossh.Unmarshal(msg.TypeSpecificData, &subMsg); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
sh.state.DirectTcpip = true
|
|
||||||
action, err := sh.PrepareHandoff(stream.Context(), subMsg.DestAddr, nil)
|
action, err := sh.PrepareHandoff(stream.Context(), subMsg.DestAddr, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
@ -268,7 +276,7 @@ func (sh *StreamHandler) handleAuthRequest(ctx context.Context, req *extensions_
|
||||||
if req.Protocol != "ssh" {
|
if req.Protocol != "ssh" {
|
||||||
return status.Errorf(codes.InvalidArgument, "invalid protocol: %s", req.Protocol)
|
return status.Errorf(codes.InvalidArgument, "invalid protocol: %s", req.Protocol)
|
||||||
}
|
}
|
||||||
if req.Service != "ssh-connection" {
|
if req.Service != ServiceConnection {
|
||||||
return status.Errorf(codes.InvalidArgument, "invalid service: %s", req.Service)
|
return status.Errorf(codes.InvalidArgument, "invalid service: %s", req.Service)
|
||||||
}
|
}
|
||||||
if !slices.Contains(sh.state.RemainingUnauthenticatedMethods, req.AuthMethod) {
|
if !slices.Contains(sh.state.RemainingUnauthenticatedMethods, req.AuthMethod) {
|
||||||
|
@ -494,7 +502,7 @@ func (sh *StreamHandler) buildUpstreamAllowResponse() *extensions_ssh.AllowRespo
|
||||||
Target: &extensions_ssh.AllowResponse_Upstream{
|
Target: &extensions_ssh.AllowResponse_Upstream{
|
||||||
Upstream: &extensions_ssh.UpstreamTarget{
|
Upstream: &extensions_ssh.UpstreamTarget{
|
||||||
Hostname: *sh.state.Hostname,
|
Hostname: *sh.state.Hostname,
|
||||||
DirectTcpip: sh.state.DirectTcpip,
|
DirectTcpip: sh.state.ChannelType == ChannelTypeDirectTcpip,
|
||||||
AllowedMethods: allowedMethods,
|
AllowedMethods: allowedMethods,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
|
@ -66,6 +66,17 @@ func HookWithArgs(f func(s *StreamHandlerSuite, args []any) any, args ...any) []
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func RuntimeFlagDependentHookWithArgs(f func(s *StreamHandlerSuite, args []any) any, flag config.RuntimeFlag, argsIfEnabled []any, argsIfDisabled []any) []func(s *StreamHandlerSuite) any {
|
||||||
|
return []func(s *StreamHandlerSuite) any{
|
||||||
|
func(s *StreamHandlerSuite) any {
|
||||||
|
if s.cfg.Options.IsRuntimeFlagSet(flag) {
|
||||||
|
return f(s, argsIfEnabled)
|
||||||
|
}
|
||||||
|
return f(s, argsIfDisabled)
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
StreamHandlerSuiteBeforeTestHooks = map[string][]func(s *StreamHandlerSuite) any{}
|
StreamHandlerSuiteBeforeTestHooks = map[string][]func(s *StreamHandlerSuite) any{}
|
||||||
StreamHandlerSuiteAfterTestHooks = map[string][]func(s *StreamHandlerSuite) any{}
|
StreamHandlerSuiteAfterTestHooks = map[string][]func(s *StreamHandlerSuite) any{}
|
||||||
|
@ -158,7 +169,7 @@ func (s *StreamHandlerSuite) expectError(fn func(), msg string) {
|
||||||
case err := <-s.errC:
|
case err := <-s.errC:
|
||||||
s.ErrorContains(err, msg)
|
s.ErrorContains(err, msg)
|
||||||
case <-time.After(DefaultTimeout):
|
case <-time.After(DefaultTimeout):
|
||||||
s.FailNowf("timed out waiting for error %q", msg)
|
s.FailNow(fmt.Sprintf("timed out waiting for error %q", msg))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1183,12 +1194,17 @@ func init() {
|
||||||
})
|
})
|
||||||
return stream
|
return stream
|
||||||
}
|
}
|
||||||
|
|
||||||
StreamHandlerSuiteBeforeTestHooks["TestServeChannel_Session"] = HookWithArgs(hook, Eq(status.Errorf(codes.Canceled, "channel closed")))
|
StreamHandlerSuiteBeforeTestHooks["TestServeChannel_Session"] = HookWithArgs(hook, Eq(status.Errorf(codes.Canceled, "channel closed")))
|
||||||
StreamHandlerSuiteBeforeTestHooks["TestServeChannel_Session_DifferentWindowAndPacketSizes"] = HookWithArgs(hook, Eq(status.Errorf(codes.Canceled, "channel closed")))
|
StreamHandlerSuiteBeforeTestHooks["TestServeChannel_Session_DifferentWindowAndPacketSizes"] = HookWithArgs(hook, Eq(status.Errorf(codes.Canceled, "channel closed")))
|
||||||
StreamHandlerSuiteBeforeTestHooks["TestServeChannel_DirectTcpip_NoSubMsg"] = HookWithArgs(hook, Not(Nil()))
|
StreamHandlerSuiteBeforeTestHooks["TestServeChannel_DirectTcpip_NoSubMsg"] = RuntimeFlagDependentHookWithArgs(hook,
|
||||||
StreamHandlerSuiteBeforeTestHooks["TestServeChannel_DirectTcpip_BadHostname"] = HookWithArgs(hook, Not(Nil()))
|
config.RuntimeFlagSSHAllowDirectTcpip, []any{Not(Nil())}, []any{Eq(status.Errorf(codes.Unavailable, "direct-tcpip channels are not enabled"))})
|
||||||
StreamHandlerSuiteBeforeTestHooks["TestServeChannel_DirectTcpip_AuthFailed"] = HookWithArgs(hook, Eq(status.Errorf(codes.PermissionDenied, "test error")))
|
StreamHandlerSuiteBeforeTestHooks["TestServeChannel_DirectTcpip_BadHostname"] = RuntimeFlagDependentHookWithArgs(hook,
|
||||||
StreamHandlerSuiteBeforeTestHooks["TestServeChannel_DirectTcpip"] = HookWithArgs(hook, Nil())
|
config.RuntimeFlagSSHAllowDirectTcpip, []any{Not(Nil())}, []any{Eq(status.Errorf(codes.Unavailable, "direct-tcpip channels are not enabled"))})
|
||||||
|
StreamHandlerSuiteBeforeTestHooks["TestServeChannel_DirectTcpip_AuthFailed"] = RuntimeFlagDependentHookWithArgs(hook,
|
||||||
|
config.RuntimeFlagSSHAllowDirectTcpip, []any{Eq(status.Errorf(codes.PermissionDenied, "test error"))}, []any{Eq(status.Errorf(codes.Unavailable, "direct-tcpip channels are not enabled"))})
|
||||||
|
StreamHandlerSuiteBeforeTestHooks["TestServeChannel_DirectTcpip"] = RuntimeFlagDependentHookWithArgs(hook,
|
||||||
|
config.RuntimeFlagSSHAllowDirectTcpip, []any{Nil()}, []any{Eq(status.Errorf(codes.Unavailable, "direct-tcpip channels are not enabled"))})
|
||||||
StreamHandlerSuiteBeforeTestHooks["TestServeChannel_InvalidChannelType"] = HookWithArgs(hook, Eq(status.Errorf(codes.InvalidArgument, "unexpected channel type in ChannelOpen message: unknown")))
|
StreamHandlerSuiteBeforeTestHooks["TestServeChannel_InvalidChannelType"] = HookWithArgs(hook, Eq(status.Errorf(codes.InvalidArgument, "unexpected channel type in ChannelOpen message: unknown")))
|
||||||
StreamHandlerSuiteBeforeTestHooks["TestServeChannel_Session_ExecWithPtyHelp"] = HookWithArgs(hook, Eq(status.Errorf(codes.Canceled, "channel closed")))
|
StreamHandlerSuiteBeforeTestHooks["TestServeChannel_Session_ExecWithPtyHelp"] = HookWithArgs(hook, Eq(status.Errorf(codes.Canceled, "channel closed")))
|
||||||
StreamHandlerSuiteBeforeTestHooks["TestServeChannel_Session_Exec_Whoami"] = HookWithArgs(hook, Eq(status.Errorf(codes.Canceled, "channel closed")))
|
StreamHandlerSuiteBeforeTestHooks["TestServeChannel_Session_Exec_Whoami"] = HookWithArgs(hook, Eq(status.Errorf(codes.Canceled, "channel closed")))
|
||||||
|
@ -1908,10 +1924,12 @@ func (s *StreamHandlerSuite) TestServeChannel_DirectTcpip_BadHostname() {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *StreamHandlerSuite) TestServeChannel_DirectTcpip_AuthFailed() {
|
func (s *StreamHandlerSuite) TestServeChannel_DirectTcpip_AuthFailed() {
|
||||||
|
if s.directTcpipEnabled() {
|
||||||
s.mockAuth.EXPECT().
|
s.mockAuth.EXPECT().
|
||||||
EvaluateDelayed(Any(), Any()).
|
EvaluateDelayed(Any(), Any()).
|
||||||
Times(1).
|
Times(1).
|
||||||
Return(errors.New("test error"))
|
Return(errors.New("test error"))
|
||||||
|
}
|
||||||
stream := s.BeforeTestHookResult.(*mockChannelStream)
|
stream := s.BeforeTestHookResult.(*mockChannelStream)
|
||||||
stream.SendClientToServer(channelMsg(ssh.ChannelOpenMsg{
|
stream.SendClientToServer(channelMsg(ssh.ChannelOpenMsg{
|
||||||
ChanType: "direct-tcpip",
|
ChanType: "direct-tcpip",
|
||||||
|
@ -1929,11 +1947,14 @@ func (s *StreamHandlerSuite) TestServeChannel_DirectTcpip_AuthFailed() {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *StreamHandlerSuite) TestServeChannel_DirectTcpip() {
|
func (s *StreamHandlerSuite) TestServeChannel_DirectTcpip() {
|
||||||
|
stream := s.BeforeTestHookResult.(*mockChannelStream)
|
||||||
|
|
||||||
|
if s.directTcpipEnabled() {
|
||||||
s.mockAuth.EXPECT().
|
s.mockAuth.EXPECT().
|
||||||
EvaluateDelayed(Any(), Any()).
|
EvaluateDelayed(Any(), Any()).
|
||||||
Times(1).
|
Times(1).
|
||||||
Return(nil)
|
Return(nil)
|
||||||
stream := s.BeforeTestHookResult.(*mockChannelStream)
|
}
|
||||||
stream.SendClientToServer(channelMsg(ssh.ChannelOpenMsg{
|
stream.SendClientToServer(channelMsg(ssh.ChannelOpenMsg{
|
||||||
ChanType: "direct-tcpip",
|
ChanType: "direct-tcpip",
|
||||||
PeersID: 2,
|
PeersID: 2,
|
||||||
|
@ -1946,7 +1967,11 @@ func (s *StreamHandlerSuite) TestServeChannel_DirectTcpip() {
|
||||||
SrcPort: 12345,
|
SrcPort: 12345,
|
||||||
}),
|
}),
|
||||||
}))
|
}))
|
||||||
|
if !s.directTcpipEnabled() {
|
||||||
|
return // error checked in cleanup
|
||||||
|
}
|
||||||
recv, err := stream.RecvServerToClient()
|
recv, err := stream.RecvServerToClient()
|
||||||
|
|
||||||
s.Require().NoError(err)
|
s.Require().NoError(err)
|
||||||
action := recv.GetChannelControl().GetControlAction()
|
action := recv.GetChannelControl().GetControlAction()
|
||||||
s.Require().NotNil(action, "received a message, but it was not a channel control action")
|
s.Require().NotNil(action, "received a message, but it was not a channel control action")
|
||||||
|
@ -1982,6 +2007,10 @@ func (s *StreamHandlerSuite) TestServeChannel_DirectTcpip() {
|
||||||
}, handoff.GetHandOff())
|
}, handoff.GetHandOff())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *StreamHandlerSuite) directTcpipEnabled() bool {
|
||||||
|
return s.cfg.Options.IsRuntimeFlagSet(config.RuntimeFlagSSHAllowDirectTcpip)
|
||||||
|
}
|
||||||
|
|
||||||
func (s *StreamHandlerSuite) TestServeChannel_InvalidChannelType() {
|
func (s *StreamHandlerSuite) TestServeChannel_InvalidChannelType() {
|
||||||
stream := s.BeforeTestHookResult.(*mockChannelStream)
|
stream := s.BeforeTestHookResult.(*mockChannelStream)
|
||||||
stream.SendClientToServer(channelMsg(ssh.ChannelOpenMsg{
|
stream.SendClientToServer(channelMsg(ssh.ChannelOpenMsg{
|
||||||
|
@ -2063,6 +2092,7 @@ func TestStreamHandlerSuiteWithRuntimeFlags(t *testing.T) {
|
||||||
ConfigModifiers: []func(*config.Config){
|
ConfigModifiers: []func(*config.Config){
|
||||||
func(c *config.Config) {
|
func(c *config.Config) {
|
||||||
c.Options.RuntimeFlags[config.RuntimeFlagSSHRoutesPortal] = true
|
c.Options.RuntimeFlags[config.RuntimeFlagSSHRoutesPortal] = true
|
||||||
|
c.Options.RuntimeFlags[config.RuntimeFlagSSHAllowDirectTcpip] = true
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue