ssh: add runtime flag for jump host mode (#5699)

Adds a new runtime flag `ssh_allow_direct_tcpip` (default false) which
enables the "jump-host mode". This is disabled by default since we are
missing related config options/policy criteria.
This commit is contained in:
Joe Kralicky 2025-07-07 12:29:05 -04:00 committed by GitHub
parent 624622f236
commit 559545f686
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 71 additions and 20 deletions

View file

@ -33,7 +33,12 @@ var (
// RuntimeFlagMCP enables the MCP services for the authorize service
RuntimeFlagMCP = runtimeFlag("mcp", false)
// RuntimeFlagSSHRoutesPortal enables the SSH routes portal
RuntimeFlagSSHRoutesPortal = runtimeFlag("ssh_routes_portal", false)
// RuntimeFlagSSHAllowDirectTcpip allows downstream clients to open 'direct-tcpip'
// channels (jump host mode)
RuntimeFlagSSHAllowDirectTcpip = runtimeFlag("ssh_allow_direct_tcpip", false)
)
// RuntimeFlag is a runtime flag that can flip on/off certain features

View file

@ -22,6 +22,12 @@ type SSHConfig struct {
// upstream. An Ed25519 key will be generated if not set.
// Must be a type supported by [ssh.NewSignerFromKey].
UserCAKey any
// If true, enables the 'ssh_allow_direct_tcpip' runtime flag
EnableDirectTcpip bool
// If true, enables the 'ssh_routes_portal' runtime flag
EnableRoutesPortal bool
}
func SSH(c SSHConfig) testenv.Modifier {
@ -44,6 +50,8 @@ func SSH(c SSHConfig) testenv.Modifier {
configHostKeys := slices.Map(c.HostKeys, marshalPrivateKey)
cfg.Options.SSHHostKeys = &configHostKeys
cfg.Options.SSHUserCAKey = marshalPrivateKey(c.UserCAKey)
cfg.Options.RuntimeFlags[config.RuntimeFlagSSHAllowDirectTcpip] = c.EnableDirectTcpip
cfg.Options.RuntimeFlags[config.RuntimeFlagSSHRoutesPortal] = c.EnableRoutesPortal
})
}

View file

@ -21,6 +21,11 @@ import (
const (
MethodPublicKey = "publickey"
MethodKeyboardInteractive = "keyboard-interactive"
ChannelTypeSession = "session"
ChannelTypeDirectTcpip = "direct-tcpip"
ServiceConnection = "ssh-connection"
)
type KeyboardInteractiveQuerier interface {
@ -71,6 +76,7 @@ type StreamAuthInfo struct {
Hostname *string
StreamID uint64
SourceAddress string
ChannelType string
PublicKeyFingerprintSha256 []byte
PublicKeyAllow AuthMethodValue[extensions_ssh.PublicKeyAllowResponse]
KeyboardInteractiveAllow AuthMethodValue[extensions_ssh.KeyboardInteractiveAllowResponse]
@ -83,7 +89,6 @@ func (i *StreamAuthInfo) allMethodsValid() bool {
type StreamState struct {
StreamAuthInfo
DirectTcpip bool
RemainingUnauthenticatedMethods []string
DownstreamChannelInfo *extensions_ssh.SSHDownstreamChannelInfo
}
@ -146,7 +151,7 @@ func (sh *StreamHandler) Prompt(ctx context.Context, prompts *extensions_ssh.Key
case req := <-sh.readC:
switch msg := req.Message.(type) {
case *extensions_ssh.ClientMessage_InfoResponse:
if msg.InfoResponse.Method != "keyboard-interactive" {
if msg.InfoResponse.Method != MethodKeyboardInteractive {
return nil, status.Errorf(codes.Internal, "received invalid info response")
}
r, _ := msg.InfoResponse.Response.UnmarshalNew()
@ -235,9 +240,10 @@ func (sh *StreamHandler) ServeChannel(stream extensions_ssh.StreamManagement_Ser
InitialWindowSize: msg.PeersWindow,
MaxPacketSize: msg.MaxPacketSize,
}
sh.state.ChannelType = msg.ChanType
channel := NewChannelImpl(sh, stream, sh.state.DownstreamChannelInfo)
switch msg.ChanType {
case "session":
case ChannelTypeSession:
if err := channel.SendMessage(ChannelOpenConfirmMsg{
PeersID: sh.state.DownstreamChannelInfo.DownstreamChannelId,
MyID: sh.state.DownstreamChannelInfo.InternalUpstreamChannelId,
@ -248,12 +254,14 @@ func (sh *StreamHandler) ServeChannel(stream extensions_ssh.StreamManagement_Ser
}
ch := NewChannelHandler(channel, sh.config)
return ch.Run(stream.Context())
case "direct-tcpip":
case ChannelTypeDirectTcpip:
if !sh.config.Options.IsRuntimeFlagSet(config.RuntimeFlagSSHAllowDirectTcpip) {
return status.Errorf(codes.Unavailable, "direct-tcpip channels are not enabled")
}
var subMsg ChannelOpenDirectMsg
if err := gossh.Unmarshal(msg.TypeSpecificData, &subMsg); err != nil {
return err
}
sh.state.DirectTcpip = true
action, err := sh.PrepareHandoff(stream.Context(), subMsg.DestAddr, nil)
if err != nil {
return err
@ -268,7 +276,7 @@ func (sh *StreamHandler) handleAuthRequest(ctx context.Context, req *extensions_
if req.Protocol != "ssh" {
return status.Errorf(codes.InvalidArgument, "invalid protocol: %s", req.Protocol)
}
if req.Service != "ssh-connection" {
if req.Service != ServiceConnection {
return status.Errorf(codes.InvalidArgument, "invalid service: %s", req.Service)
}
if !slices.Contains(sh.state.RemainingUnauthenticatedMethods, req.AuthMethod) {
@ -494,7 +502,7 @@ func (sh *StreamHandler) buildUpstreamAllowResponse() *extensions_ssh.AllowRespo
Target: &extensions_ssh.AllowResponse_Upstream{
Upstream: &extensions_ssh.UpstreamTarget{
Hostname: *sh.state.Hostname,
DirectTcpip: sh.state.DirectTcpip,
DirectTcpip: sh.state.ChannelType == ChannelTypeDirectTcpip,
AllowedMethods: allowedMethods,
},
},

View file

@ -66,6 +66,17 @@ func HookWithArgs(f func(s *StreamHandlerSuite, args []any) any, args ...any) []
}
}
func RuntimeFlagDependentHookWithArgs(f func(s *StreamHandlerSuite, args []any) any, flag config.RuntimeFlag, argsIfEnabled []any, argsIfDisabled []any) []func(s *StreamHandlerSuite) any {
return []func(s *StreamHandlerSuite) any{
func(s *StreamHandlerSuite) any {
if s.cfg.Options.IsRuntimeFlagSet(flag) {
return f(s, argsIfEnabled)
}
return f(s, argsIfDisabled)
},
}
}
var (
StreamHandlerSuiteBeforeTestHooks = map[string][]func(s *StreamHandlerSuite) any{}
StreamHandlerSuiteAfterTestHooks = map[string][]func(s *StreamHandlerSuite) any{}
@ -158,7 +169,7 @@ func (s *StreamHandlerSuite) expectError(fn func(), msg string) {
case err := <-s.errC:
s.ErrorContains(err, msg)
case <-time.After(DefaultTimeout):
s.FailNowf("timed out waiting for error %q", msg)
s.FailNow(fmt.Sprintf("timed out waiting for error %q", msg))
}
}
@ -1183,12 +1194,17 @@ func init() {
})
return stream
}
StreamHandlerSuiteBeforeTestHooks["TestServeChannel_Session"] = HookWithArgs(hook, Eq(status.Errorf(codes.Canceled, "channel closed")))
StreamHandlerSuiteBeforeTestHooks["TestServeChannel_Session_DifferentWindowAndPacketSizes"] = HookWithArgs(hook, Eq(status.Errorf(codes.Canceled, "channel closed")))
StreamHandlerSuiteBeforeTestHooks["TestServeChannel_DirectTcpip_NoSubMsg"] = HookWithArgs(hook, Not(Nil()))
StreamHandlerSuiteBeforeTestHooks["TestServeChannel_DirectTcpip_BadHostname"] = HookWithArgs(hook, Not(Nil()))
StreamHandlerSuiteBeforeTestHooks["TestServeChannel_DirectTcpip_AuthFailed"] = HookWithArgs(hook, Eq(status.Errorf(codes.PermissionDenied, "test error")))
StreamHandlerSuiteBeforeTestHooks["TestServeChannel_DirectTcpip"] = HookWithArgs(hook, Nil())
StreamHandlerSuiteBeforeTestHooks["TestServeChannel_DirectTcpip_NoSubMsg"] = RuntimeFlagDependentHookWithArgs(hook,
config.RuntimeFlagSSHAllowDirectTcpip, []any{Not(Nil())}, []any{Eq(status.Errorf(codes.Unavailable, "direct-tcpip channels are not enabled"))})
StreamHandlerSuiteBeforeTestHooks["TestServeChannel_DirectTcpip_BadHostname"] = RuntimeFlagDependentHookWithArgs(hook,
config.RuntimeFlagSSHAllowDirectTcpip, []any{Not(Nil())}, []any{Eq(status.Errorf(codes.Unavailable, "direct-tcpip channels are not enabled"))})
StreamHandlerSuiteBeforeTestHooks["TestServeChannel_DirectTcpip_AuthFailed"] = RuntimeFlagDependentHookWithArgs(hook,
config.RuntimeFlagSSHAllowDirectTcpip, []any{Eq(status.Errorf(codes.PermissionDenied, "test error"))}, []any{Eq(status.Errorf(codes.Unavailable, "direct-tcpip channels are not enabled"))})
StreamHandlerSuiteBeforeTestHooks["TestServeChannel_DirectTcpip"] = RuntimeFlagDependentHookWithArgs(hook,
config.RuntimeFlagSSHAllowDirectTcpip, []any{Nil()}, []any{Eq(status.Errorf(codes.Unavailable, "direct-tcpip channels are not enabled"))})
StreamHandlerSuiteBeforeTestHooks["TestServeChannel_InvalidChannelType"] = HookWithArgs(hook, Eq(status.Errorf(codes.InvalidArgument, "unexpected channel type in ChannelOpen message: unknown")))
StreamHandlerSuiteBeforeTestHooks["TestServeChannel_Session_ExecWithPtyHelp"] = HookWithArgs(hook, Eq(status.Errorf(codes.Canceled, "channel closed")))
StreamHandlerSuiteBeforeTestHooks["TestServeChannel_Session_Exec_Whoami"] = HookWithArgs(hook, Eq(status.Errorf(codes.Canceled, "channel closed")))
@ -1908,10 +1924,12 @@ func (s *StreamHandlerSuite) TestServeChannel_DirectTcpip_BadHostname() {
}
func (s *StreamHandlerSuite) TestServeChannel_DirectTcpip_AuthFailed() {
s.mockAuth.EXPECT().
EvaluateDelayed(Any(), Any()).
Times(1).
Return(errors.New("test error"))
if s.directTcpipEnabled() {
s.mockAuth.EXPECT().
EvaluateDelayed(Any(), Any()).
Times(1).
Return(errors.New("test error"))
}
stream := s.BeforeTestHookResult.(*mockChannelStream)
stream.SendClientToServer(channelMsg(ssh.ChannelOpenMsg{
ChanType: "direct-tcpip",
@ -1929,11 +1947,14 @@ func (s *StreamHandlerSuite) TestServeChannel_DirectTcpip_AuthFailed() {
}
func (s *StreamHandlerSuite) TestServeChannel_DirectTcpip() {
s.mockAuth.EXPECT().
EvaluateDelayed(Any(), Any()).
Times(1).
Return(nil)
stream := s.BeforeTestHookResult.(*mockChannelStream)
if s.directTcpipEnabled() {
s.mockAuth.EXPECT().
EvaluateDelayed(Any(), Any()).
Times(1).
Return(nil)
}
stream.SendClientToServer(channelMsg(ssh.ChannelOpenMsg{
ChanType: "direct-tcpip",
PeersID: 2,
@ -1946,7 +1967,11 @@ func (s *StreamHandlerSuite) TestServeChannel_DirectTcpip() {
SrcPort: 12345,
}),
}))
if !s.directTcpipEnabled() {
return // error checked in cleanup
}
recv, err := stream.RecvServerToClient()
s.Require().NoError(err)
action := recv.GetChannelControl().GetControlAction()
s.Require().NotNil(action, "received a message, but it was not a channel control action")
@ -1982,6 +2007,10 @@ func (s *StreamHandlerSuite) TestServeChannel_DirectTcpip() {
}, handoff.GetHandOff())
}
func (s *StreamHandlerSuite) directTcpipEnabled() bool {
return s.cfg.Options.IsRuntimeFlagSet(config.RuntimeFlagSSHAllowDirectTcpip)
}
func (s *StreamHandlerSuite) TestServeChannel_InvalidChannelType() {
stream := s.BeforeTestHookResult.(*mockChannelStream)
stream.SendClientToServer(channelMsg(ssh.ChannelOpenMsg{
@ -2063,6 +2092,7 @@ func TestStreamHandlerSuiteWithRuntimeFlags(t *testing.T) {
ConfigModifiers: []func(*config.Config){
func(c *config.Config) {
c.Options.RuntimeFlags[config.RuntimeFlagSSHRoutesPortal] = true
c.Options.RuntimeFlags[config.RuntimeFlagSSHAllowDirectTcpip] = true
},
},
},