add test exercising the jump host mode

This commit is contained in:
Kenneth Jenkins 2025-04-11 14:29:40 -07:00
parent fe2df405e0
commit 64b4c284c8

View file

@ -5,6 +5,7 @@ import (
"crypto/ed25519"
"errors"
"io"
"net"
"strings"
"sync"
"testing"
@ -90,6 +91,77 @@ func TestSSH(t *testing.T) {
assert.Equal(t, "> hello world\r\nhello world\r\n> ", b.String())
}
func TestSSH_JumpHostMode(t *testing.T) {
clientKey := newSSHKey(t)
serverHostKey := newSSHKey(t)
// ssh client setup
var ki scenarios.EmptyKeyboardInteractiveChallenge
clientConfig := &ssh.ClientConfig{
User: "demo",
Auth: []ssh.AuthMethod{
ssh.PublicKeys(newSignerFromKey(t, clientKey)),
ssh.KeyboardInteractive(ki.Do),
},
HostKeyCallback: ssh.FixedHostKey(newPublicKey(t, serverHostKey.Public())),
}
// pomerium + upstream setup
env := testenv.New(t)
env.Add(scenarios.NewIDP([]*scenarios.User{{Email: "test@example.com"}}, scenarios.WithEnableDeviceAuth(true)))
env.Add(scenarios.SSH(scenarios.SSHConfig{
HostKeys: []any{serverHostKey},
}))
env.Add(&ki)
up := upstreams.SSH(
upstreams.WithHostKeys(newSignerFromKey(t, serverHostKey)),
upstreams.WithAuthorizedKey(newPublicKey(t, clientKey.Public()), "demo"),
)
up.SetServerConnCallback(echoShell{t}.handleConnection)
r := up.Route().
From(env.SubdomainURLWithScheme("example", "ssh")).
Policy(func(p *config.Policy) { p.AllowAnyAuthenticatedUser = true })
env.AddUpstream(up)
env.Start()
snippets.WaitStartupComplete(env)
// verify that a tunneled connection can be established
client, err := up.Dial(r, clientConfig)
require.NoError(t, err)
defer client.Close()
_, port, err := net.SplitHostPort(up.Addr().Value())
addr := "example:" + port
tunneledClient, err := TunneledClient(client, addr, clientConfig)
require.NoError(t, err)
sess, err := tunneledClient.NewSession()
require.NoError(t, err)
var b bytes.Buffer
sess.Stdout = &b
sess.Stdin = strings.NewReader("hello world\r")
require.NoError(t, sess.Shell())
require.NoError(t, sess.Wait())
assert.Equal(t, "> hello world\r\nhello world\r\n> ", b.String())
}
func TunneledClient(outer *ssh.Client, addr string, config *ssh.ClientConfig) (*ssh.Client, error) {
conn, err := outer.Dial("tcp", addr)
if err != nil {
return nil, err
}
c, chans, reqs, err := ssh.NewClientConn(conn, addr, config)
if err != nil {
return nil, err
}
return ssh.NewClient(c, chans, reqs), nil
}
type echoShell struct {
t *testing.T
}