package ssh import ( "bytes" "crypto/ed25519" "errors" "io" "net" "strings" "sync" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "golang.org/x/crypto/ssh" "golang.org/x/term" "github.com/pomerium/pomerium/config" "github.com/pomerium/pomerium/internal/testenv" "github.com/pomerium/pomerium/internal/testenv/scenarios" "github.com/pomerium/pomerium/internal/testenv/snippets" "github.com/pomerium/pomerium/internal/testenv/upstreams" ) func TestSSH(t *testing.T) { clientKey := newSSHKey(t) serverHostKey := newSSHKey(t) userCAKey := newSSHKey(t) // ssh client setup var ki scenarios.EmptyKeyboardInteractiveChallenge var bannerReceived string clientConfig := &ssh.ClientConfig{ User: "demo@example", Auth: []ssh.AuthMethod{ ssh.PublicKeys(newSignerFromKey(t, clientKey)), ssh.KeyboardInteractive(ki.Do), }, HostKeyCallback: ssh.FixedHostKey(newPublicKey(t, serverHostKey.Public())), BannerCallback: func(message string) error { bannerReceived = message; return nil }, } // pomerium + upstream setup env := testenv.New(t) env.Add(scenarios.NewIDP([]*scenarios.User{{Email: "test@example.com"}}, scenarios.WithEnableDeviceAuth(true))) env.Add(scenarios.SSH(scenarios.SSHConfig{ HostKeys: []any{serverHostKey}, UserCAKey: userCAKey, })) env.Add(&ki) userCAPublicKey := newPublicKey(t, userCAKey.Public()) certChecker := ssh.CertChecker{ IsUserAuthority: func(auth ssh.PublicKey) bool { return bytes.Equal(userCAPublicKey.Marshal(), auth.Marshal()) }, } up := upstreams.SSH( upstreams.WithHostKeys(newSignerFromKey(t, serverHostKey)), upstreams.WithPublicKeyCallback(certChecker.Authenticate), upstreams.WithBannerCallback(func(_ ssh.ConnMetadata) string { return "UPSTREAM BANNER" }), ) up.SetServerConnCallback(echoShell{t}.handleConnection) r := up.Route(). From(env.SubdomainURLWithScheme("example", "ssh")). Policy(func(p *config.Policy) { p.AllowAnyAuthenticatedUser = true }) env.AddUpstream(up) env.Start() snippets.WaitStartupComplete(env) // verify that a connection can be established client, err := up.Dial(r, clientConfig) require.NoError(t, err) defer client.Close() assert.Equal(t, "UPSTREAM BANNER", bannerReceived) sess, err := client.NewSession() require.NoError(t, err) defer sess.Close() var b bytes.Buffer sess.Stdout = &b sess.Stdin = strings.NewReader("hello world\r") require.NoError(t, sess.Shell()) require.NoError(t, sess.Wait()) assert.Equal(t, "> hello world\r\nhello world\r\n> ", b.String()) } func TestSSH_JumpHostMode(t *testing.T) { clientKey := newSSHKey(t) serverHostKey := newSSHKey(t) // ssh client setup var ki scenarios.EmptyKeyboardInteractiveChallenge clientConfig := &ssh.ClientConfig{ User: "demo", Auth: []ssh.AuthMethod{ ssh.PublicKeys(newSignerFromKey(t, clientKey)), ssh.KeyboardInteractive(ki.Do), }, HostKeyCallback: ssh.FixedHostKey(newPublicKey(t, serverHostKey.Public())), } // pomerium + upstream setup env := testenv.New(t) env.Add(scenarios.NewIDP([]*scenarios.User{{Email: "test@example.com"}}, scenarios.WithEnableDeviceAuth(true))) env.Add(scenarios.SSH(scenarios.SSHConfig{ HostKeys: []any{serverHostKey}, })) env.Add(&ki) up := upstreams.SSH( upstreams.WithHostKeys(newSignerFromKey(t, serverHostKey)), upstreams.WithAuthorizedKey(newPublicKey(t, clientKey.Public()), "demo"), ) up.SetServerConnCallback(echoShell{t}.handleConnection) r := up.Route(). From(env.SubdomainURLWithScheme("example", "ssh")). Policy(func(p *config.Policy) { p.AllowAnyAuthenticatedUser = true }) env.AddUpstream(up) env.Start() snippets.WaitStartupComplete(env) // verify that a tunneled connection can be established client, err := up.Dial(r, clientConfig) require.NoError(t, err) defer client.Close() _, port, err := net.SplitHostPort(up.Addr().Value()) addr := "example:" + port tunneledClient, err := TunneledClient(client, addr, clientConfig) require.NoError(t, err) sess, err := tunneledClient.NewSession() require.NoError(t, err) var b bytes.Buffer sess.Stdout = &b sess.Stdin = strings.NewReader("hello world\r") require.NoError(t, sess.Shell()) require.NoError(t, sess.Wait()) assert.Equal(t, "> hello world\r\nhello world\r\n> ", b.String()) } func TunneledClient(outer *ssh.Client, addr string, config *ssh.ClientConfig) (*ssh.Client, error) { conn, err := outer.Dial("tcp", addr) if err != nil { return nil, err } c, chans, reqs, err := ssh.NewClientConn(conn, addr, config) if err != nil { return nil, err } return ssh.NewClient(c, chans, reqs), nil } type echoShell struct { t *testing.T } func (sh echoShell) handleConnection(conn *ssh.ServerConn, chans <-chan ssh.NewChannel, reqs <-chan *ssh.Request) { var wg sync.WaitGroup defer wg.Wait() // Reject any global requests from the client. wg.Add(1) go func() { ssh.DiscardRequests(reqs) wg.Done() }() // Accept shell session requests. for newChannel := range chans { if newChannel.ChannelType() != "session" { newChannel.Reject(ssh.UnknownChannelType, "unknown channel type") continue } channel, requests, err := newChannel.Accept() require.NoError(sh.t, err, "echoShell: couldn't accept channel") // Acknowledge a 'shell' request. wg.Add(1) go func(in <-chan *ssh.Request) { for req := range in { req.Reply(req.Type == "shell", nil) } wg.Done() }(requests) // Simulate a terminal that echoes all input lines. term := term.NewTerminal(channel, "> ") wg.Add(1) go func() { defer func() { channel.Close() wg.Done() }() for { line, err := term.ReadLine() if errors.Is(err, io.EOF) { break } require.NoError(sh.t, err, "echoShell: couldn't read line") reply := append([]byte(line), '\n') _, err = term.Write(reply) require.NoError(sh.t, err, "echoShell: couldn't write line") } channel.SendRequest("exit-status", false, make([]byte, 4) /* uint32 0 */) }() } } // newSSHKey generates a new Ed25519 ssh key. func newSSHKey(t *testing.T) ed25519.PrivateKey { t.Helper() _, priv, err := ed25519.GenerateKey(nil) require.NoError(t, err) return priv } // newSignerFromKey is a wrapper around ssh.NewSignerFromKey that will fail on error. func newSignerFromKey(t *testing.T, key any) ssh.Signer { t.Helper() signer, err := ssh.NewSignerFromKey(key) require.NoError(t, err) return signer } // newPublicKey is a wrapper around ssh.NewPublicKey that will fail on error. func newPublicKey(t *testing.T, key any) ssh.PublicKey { t.Helper() sshkey, err := ssh.NewPublicKey(key) require.NoError(t, err) return sshkey }