mirror of
https://github.com/pomerium/pomerium.git
synced 2025-08-03 00:40:25 +02:00
ssh: add runtime flag for jump host mode (#5699)
Adds a new runtime flag `ssh_allow_direct_tcpip` (default false) which enables the "jump-host mode". This is disabled by default since we are missing related config options/policy criteria.
This commit is contained in:
parent
624622f236
commit
559545f686
4 changed files with 71 additions and 20 deletions
|
@ -33,7 +33,12 @@ var (
|
|||
// RuntimeFlagMCP enables the MCP services for the authorize service
|
||||
RuntimeFlagMCP = runtimeFlag("mcp", false)
|
||||
|
||||
// RuntimeFlagSSHRoutesPortal enables the SSH routes portal
|
||||
RuntimeFlagSSHRoutesPortal = runtimeFlag("ssh_routes_portal", false)
|
||||
|
||||
// RuntimeFlagSSHAllowDirectTcpip allows downstream clients to open 'direct-tcpip'
|
||||
// channels (jump host mode)
|
||||
RuntimeFlagSSHAllowDirectTcpip = runtimeFlag("ssh_allow_direct_tcpip", false)
|
||||
)
|
||||
|
||||
// RuntimeFlag is a runtime flag that can flip on/off certain features
|
||||
|
|
|
@ -22,6 +22,12 @@ type SSHConfig struct {
|
|||
// upstream. An Ed25519 key will be generated if not set.
|
||||
// Must be a type supported by [ssh.NewSignerFromKey].
|
||||
UserCAKey any
|
||||
|
||||
// If true, enables the 'ssh_allow_direct_tcpip' runtime flag
|
||||
EnableDirectTcpip bool
|
||||
|
||||
// If true, enables the 'ssh_routes_portal' runtime flag
|
||||
EnableRoutesPortal bool
|
||||
}
|
||||
|
||||
func SSH(c SSHConfig) testenv.Modifier {
|
||||
|
@ -44,6 +50,8 @@ func SSH(c SSHConfig) testenv.Modifier {
|
|||
configHostKeys := slices.Map(c.HostKeys, marshalPrivateKey)
|
||||
cfg.Options.SSHHostKeys = &configHostKeys
|
||||
cfg.Options.SSHUserCAKey = marshalPrivateKey(c.UserCAKey)
|
||||
cfg.Options.RuntimeFlags[config.RuntimeFlagSSHAllowDirectTcpip] = c.EnableDirectTcpip
|
||||
cfg.Options.RuntimeFlags[config.RuntimeFlagSSHRoutesPortal] = c.EnableRoutesPortal
|
||||
})
|
||||
}
|
||||
|
||||
|
|
|
@ -21,6 +21,11 @@ import (
|
|||
const (
|
||||
MethodPublicKey = "publickey"
|
||||
MethodKeyboardInteractive = "keyboard-interactive"
|
||||
|
||||
ChannelTypeSession = "session"
|
||||
ChannelTypeDirectTcpip = "direct-tcpip"
|
||||
|
||||
ServiceConnection = "ssh-connection"
|
||||
)
|
||||
|
||||
type KeyboardInteractiveQuerier interface {
|
||||
|
@ -71,6 +76,7 @@ type StreamAuthInfo struct {
|
|||
Hostname *string
|
||||
StreamID uint64
|
||||
SourceAddress string
|
||||
ChannelType string
|
||||
PublicKeyFingerprintSha256 []byte
|
||||
PublicKeyAllow AuthMethodValue[extensions_ssh.PublicKeyAllowResponse]
|
||||
KeyboardInteractiveAllow AuthMethodValue[extensions_ssh.KeyboardInteractiveAllowResponse]
|
||||
|
@ -83,7 +89,6 @@ func (i *StreamAuthInfo) allMethodsValid() bool {
|
|||
|
||||
type StreamState struct {
|
||||
StreamAuthInfo
|
||||
DirectTcpip bool
|
||||
RemainingUnauthenticatedMethods []string
|
||||
DownstreamChannelInfo *extensions_ssh.SSHDownstreamChannelInfo
|
||||
}
|
||||
|
@ -146,7 +151,7 @@ func (sh *StreamHandler) Prompt(ctx context.Context, prompts *extensions_ssh.Key
|
|||
case req := <-sh.readC:
|
||||
switch msg := req.Message.(type) {
|
||||
case *extensions_ssh.ClientMessage_InfoResponse:
|
||||
if msg.InfoResponse.Method != "keyboard-interactive" {
|
||||
if msg.InfoResponse.Method != MethodKeyboardInteractive {
|
||||
return nil, status.Errorf(codes.Internal, "received invalid info response")
|
||||
}
|
||||
r, _ := msg.InfoResponse.Response.UnmarshalNew()
|
||||
|
@ -235,9 +240,10 @@ func (sh *StreamHandler) ServeChannel(stream extensions_ssh.StreamManagement_Ser
|
|||
InitialWindowSize: msg.PeersWindow,
|
||||
MaxPacketSize: msg.MaxPacketSize,
|
||||
}
|
||||
sh.state.ChannelType = msg.ChanType
|
||||
channel := NewChannelImpl(sh, stream, sh.state.DownstreamChannelInfo)
|
||||
switch msg.ChanType {
|
||||
case "session":
|
||||
case ChannelTypeSession:
|
||||
if err := channel.SendMessage(ChannelOpenConfirmMsg{
|
||||
PeersID: sh.state.DownstreamChannelInfo.DownstreamChannelId,
|
||||
MyID: sh.state.DownstreamChannelInfo.InternalUpstreamChannelId,
|
||||
|
@ -248,12 +254,14 @@ func (sh *StreamHandler) ServeChannel(stream extensions_ssh.StreamManagement_Ser
|
|||
}
|
||||
ch := NewChannelHandler(channel, sh.config)
|
||||
return ch.Run(stream.Context())
|
||||
case "direct-tcpip":
|
||||
case ChannelTypeDirectTcpip:
|
||||
if !sh.config.Options.IsRuntimeFlagSet(config.RuntimeFlagSSHAllowDirectTcpip) {
|
||||
return status.Errorf(codes.Unavailable, "direct-tcpip channels are not enabled")
|
||||
}
|
||||
var subMsg ChannelOpenDirectMsg
|
||||
if err := gossh.Unmarshal(msg.TypeSpecificData, &subMsg); err != nil {
|
||||
return err
|
||||
}
|
||||
sh.state.DirectTcpip = true
|
||||
action, err := sh.PrepareHandoff(stream.Context(), subMsg.DestAddr, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -268,7 +276,7 @@ func (sh *StreamHandler) handleAuthRequest(ctx context.Context, req *extensions_
|
|||
if req.Protocol != "ssh" {
|
||||
return status.Errorf(codes.InvalidArgument, "invalid protocol: %s", req.Protocol)
|
||||
}
|
||||
if req.Service != "ssh-connection" {
|
||||
if req.Service != ServiceConnection {
|
||||
return status.Errorf(codes.InvalidArgument, "invalid service: %s", req.Service)
|
||||
}
|
||||
if !slices.Contains(sh.state.RemainingUnauthenticatedMethods, req.AuthMethod) {
|
||||
|
@ -494,7 +502,7 @@ func (sh *StreamHandler) buildUpstreamAllowResponse() *extensions_ssh.AllowRespo
|
|||
Target: &extensions_ssh.AllowResponse_Upstream{
|
||||
Upstream: &extensions_ssh.UpstreamTarget{
|
||||
Hostname: *sh.state.Hostname,
|
||||
DirectTcpip: sh.state.DirectTcpip,
|
||||
DirectTcpip: sh.state.ChannelType == ChannelTypeDirectTcpip,
|
||||
AllowedMethods: allowedMethods,
|
||||
},
|
||||
},
|
||||
|
|
|
@ -66,6 +66,17 @@ func HookWithArgs(f func(s *StreamHandlerSuite, args []any) any, args ...any) []
|
|||
}
|
||||
}
|
||||
|
||||
func RuntimeFlagDependentHookWithArgs(f func(s *StreamHandlerSuite, args []any) any, flag config.RuntimeFlag, argsIfEnabled []any, argsIfDisabled []any) []func(s *StreamHandlerSuite) any {
|
||||
return []func(s *StreamHandlerSuite) any{
|
||||
func(s *StreamHandlerSuite) any {
|
||||
if s.cfg.Options.IsRuntimeFlagSet(flag) {
|
||||
return f(s, argsIfEnabled)
|
||||
}
|
||||
return f(s, argsIfDisabled)
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
var (
|
||||
StreamHandlerSuiteBeforeTestHooks = map[string][]func(s *StreamHandlerSuite) any{}
|
||||
StreamHandlerSuiteAfterTestHooks = map[string][]func(s *StreamHandlerSuite) any{}
|
||||
|
@ -158,7 +169,7 @@ func (s *StreamHandlerSuite) expectError(fn func(), msg string) {
|
|||
case err := <-s.errC:
|
||||
s.ErrorContains(err, msg)
|
||||
case <-time.After(DefaultTimeout):
|
||||
s.FailNowf("timed out waiting for error %q", msg)
|
||||
s.FailNow(fmt.Sprintf("timed out waiting for error %q", msg))
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1183,12 +1194,17 @@ func init() {
|
|||
})
|
||||
return stream
|
||||
}
|
||||
|
||||
StreamHandlerSuiteBeforeTestHooks["TestServeChannel_Session"] = HookWithArgs(hook, Eq(status.Errorf(codes.Canceled, "channel closed")))
|
||||
StreamHandlerSuiteBeforeTestHooks["TestServeChannel_Session_DifferentWindowAndPacketSizes"] = HookWithArgs(hook, Eq(status.Errorf(codes.Canceled, "channel closed")))
|
||||
StreamHandlerSuiteBeforeTestHooks["TestServeChannel_DirectTcpip_NoSubMsg"] = HookWithArgs(hook, Not(Nil()))
|
||||
StreamHandlerSuiteBeforeTestHooks["TestServeChannel_DirectTcpip_BadHostname"] = HookWithArgs(hook, Not(Nil()))
|
||||
StreamHandlerSuiteBeforeTestHooks["TestServeChannel_DirectTcpip_AuthFailed"] = HookWithArgs(hook, Eq(status.Errorf(codes.PermissionDenied, "test error")))
|
||||
StreamHandlerSuiteBeforeTestHooks["TestServeChannel_DirectTcpip"] = HookWithArgs(hook, Nil())
|
||||
StreamHandlerSuiteBeforeTestHooks["TestServeChannel_DirectTcpip_NoSubMsg"] = RuntimeFlagDependentHookWithArgs(hook,
|
||||
config.RuntimeFlagSSHAllowDirectTcpip, []any{Not(Nil())}, []any{Eq(status.Errorf(codes.Unavailable, "direct-tcpip channels are not enabled"))})
|
||||
StreamHandlerSuiteBeforeTestHooks["TestServeChannel_DirectTcpip_BadHostname"] = RuntimeFlagDependentHookWithArgs(hook,
|
||||
config.RuntimeFlagSSHAllowDirectTcpip, []any{Not(Nil())}, []any{Eq(status.Errorf(codes.Unavailable, "direct-tcpip channels are not enabled"))})
|
||||
StreamHandlerSuiteBeforeTestHooks["TestServeChannel_DirectTcpip_AuthFailed"] = RuntimeFlagDependentHookWithArgs(hook,
|
||||
config.RuntimeFlagSSHAllowDirectTcpip, []any{Eq(status.Errorf(codes.PermissionDenied, "test error"))}, []any{Eq(status.Errorf(codes.Unavailable, "direct-tcpip channels are not enabled"))})
|
||||
StreamHandlerSuiteBeforeTestHooks["TestServeChannel_DirectTcpip"] = RuntimeFlagDependentHookWithArgs(hook,
|
||||
config.RuntimeFlagSSHAllowDirectTcpip, []any{Nil()}, []any{Eq(status.Errorf(codes.Unavailable, "direct-tcpip channels are not enabled"))})
|
||||
StreamHandlerSuiteBeforeTestHooks["TestServeChannel_InvalidChannelType"] = HookWithArgs(hook, Eq(status.Errorf(codes.InvalidArgument, "unexpected channel type in ChannelOpen message: unknown")))
|
||||
StreamHandlerSuiteBeforeTestHooks["TestServeChannel_Session_ExecWithPtyHelp"] = HookWithArgs(hook, Eq(status.Errorf(codes.Canceled, "channel closed")))
|
||||
StreamHandlerSuiteBeforeTestHooks["TestServeChannel_Session_Exec_Whoami"] = HookWithArgs(hook, Eq(status.Errorf(codes.Canceled, "channel closed")))
|
||||
|
@ -1908,10 +1924,12 @@ func (s *StreamHandlerSuite) TestServeChannel_DirectTcpip_BadHostname() {
|
|||
}
|
||||
|
||||
func (s *StreamHandlerSuite) TestServeChannel_DirectTcpip_AuthFailed() {
|
||||
s.mockAuth.EXPECT().
|
||||
EvaluateDelayed(Any(), Any()).
|
||||
Times(1).
|
||||
Return(errors.New("test error"))
|
||||
if s.directTcpipEnabled() {
|
||||
s.mockAuth.EXPECT().
|
||||
EvaluateDelayed(Any(), Any()).
|
||||
Times(1).
|
||||
Return(errors.New("test error"))
|
||||
}
|
||||
stream := s.BeforeTestHookResult.(*mockChannelStream)
|
||||
stream.SendClientToServer(channelMsg(ssh.ChannelOpenMsg{
|
||||
ChanType: "direct-tcpip",
|
||||
|
@ -1929,11 +1947,14 @@ func (s *StreamHandlerSuite) TestServeChannel_DirectTcpip_AuthFailed() {
|
|||
}
|
||||
|
||||
func (s *StreamHandlerSuite) TestServeChannel_DirectTcpip() {
|
||||
s.mockAuth.EXPECT().
|
||||
EvaluateDelayed(Any(), Any()).
|
||||
Times(1).
|
||||
Return(nil)
|
||||
stream := s.BeforeTestHookResult.(*mockChannelStream)
|
||||
|
||||
if s.directTcpipEnabled() {
|
||||
s.mockAuth.EXPECT().
|
||||
EvaluateDelayed(Any(), Any()).
|
||||
Times(1).
|
||||
Return(nil)
|
||||
}
|
||||
stream.SendClientToServer(channelMsg(ssh.ChannelOpenMsg{
|
||||
ChanType: "direct-tcpip",
|
||||
PeersID: 2,
|
||||
|
@ -1946,7 +1967,11 @@ func (s *StreamHandlerSuite) TestServeChannel_DirectTcpip() {
|
|||
SrcPort: 12345,
|
||||
}),
|
||||
}))
|
||||
if !s.directTcpipEnabled() {
|
||||
return // error checked in cleanup
|
||||
}
|
||||
recv, err := stream.RecvServerToClient()
|
||||
|
||||
s.Require().NoError(err)
|
||||
action := recv.GetChannelControl().GetControlAction()
|
||||
s.Require().NotNil(action, "received a message, but it was not a channel control action")
|
||||
|
@ -1982,6 +2007,10 @@ func (s *StreamHandlerSuite) TestServeChannel_DirectTcpip() {
|
|||
}, handoff.GetHandOff())
|
||||
}
|
||||
|
||||
func (s *StreamHandlerSuite) directTcpipEnabled() bool {
|
||||
return s.cfg.Options.IsRuntimeFlagSet(config.RuntimeFlagSSHAllowDirectTcpip)
|
||||
}
|
||||
|
||||
func (s *StreamHandlerSuite) TestServeChannel_InvalidChannelType() {
|
||||
stream := s.BeforeTestHookResult.(*mockChannelStream)
|
||||
stream.SendClientToServer(channelMsg(ssh.ChannelOpenMsg{
|
||||
|
@ -2063,6 +2092,7 @@ func TestStreamHandlerSuiteWithRuntimeFlags(t *testing.T) {
|
|||
ConfigModifiers: []func(*config.Config){
|
||||
func(c *config.Config) {
|
||||
c.Options.RuntimeFlags[config.RuntimeFlagSSHRoutesPortal] = true
|
||||
c.Options.RuntimeFlags[config.RuntimeFlagSSHAllowDirectTcpip] = true
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue