diff --git a/integration2/ssh_int_test.go b/integration2/ssh_int_test.go index 91153c80d..dee3b84c1 100644 --- a/integration2/ssh_int_test.go +++ b/integration2/ssh_int_test.go @@ -5,6 +5,7 @@ import ( "crypto/ed25519" "errors" "io" + "net" "strings" "sync" "testing" @@ -90,6 +91,77 @@ func TestSSH(t *testing.T) { assert.Equal(t, "> hello world\r\nhello world\r\n> ", b.String()) } +func TestSSH_JumpHostMode(t *testing.T) { + clientKey := newSSHKey(t) + serverHostKey := newSSHKey(t) + + // ssh client setup + var ki scenarios.EmptyKeyboardInteractiveChallenge + clientConfig := &ssh.ClientConfig{ + User: "demo", + Auth: []ssh.AuthMethod{ + ssh.PublicKeys(newSignerFromKey(t, clientKey)), + ssh.KeyboardInteractive(ki.Do), + }, + HostKeyCallback: ssh.FixedHostKey(newPublicKey(t, serverHostKey.Public())), + } + + // pomerium + upstream setup + env := testenv.New(t) + + env.Add(scenarios.NewIDP([]*scenarios.User{{Email: "test@example.com"}}, scenarios.WithEnableDeviceAuth(true))) + env.Add(scenarios.SSH(scenarios.SSHConfig{ + HostKeys: []any{serverHostKey}, + })) + env.Add(&ki) + + up := upstreams.SSH( + upstreams.WithHostKeys(newSignerFromKey(t, serverHostKey)), + upstreams.WithAuthorizedKey(newPublicKey(t, clientKey.Public()), "demo"), + ) + up.SetServerConnCallback(echoShell{t}.handleConnection) + r := up.Route(). + From(env.SubdomainURLWithScheme("example", "ssh")). + Policy(func(p *config.Policy) { p.AllowAnyAuthenticatedUser = true }) + env.AddUpstream(up) + env.Start() + snippets.WaitStartupComplete(env) + + // verify that a tunneled connection can be established + client, err := up.Dial(r, clientConfig) + require.NoError(t, err) + defer client.Close() + + _, port, err := net.SplitHostPort(up.Addr().Value()) + addr := "example:" + port + + tunneledClient, err := TunneledClient(client, addr, clientConfig) + require.NoError(t, err) + + sess, err := tunneledClient.NewSession() + require.NoError(t, err) + + var b bytes.Buffer + sess.Stdout = &b + sess.Stdin = strings.NewReader("hello world\r") + require.NoError(t, sess.Shell()) + require.NoError(t, sess.Wait()) + + assert.Equal(t, "> hello world\r\nhello world\r\n> ", b.String()) +} + +func TunneledClient(outer *ssh.Client, addr string, config *ssh.ClientConfig) (*ssh.Client, error) { + conn, err := outer.Dial("tcp", addr) + if err != nil { + return nil, err + } + c, chans, reqs, err := ssh.NewClientConn(conn, addr, config) + if err != nil { + return nil, err + } + return ssh.NewClient(c, chans, reqs), nil +} + type echoShell struct { t *testing.T }