diff --git a/go.mod b/go.mod index 2ecceed7f..3ab879f5a 100644 --- a/go.mod +++ b/go.mod @@ -255,6 +255,7 @@ require ( go.uber.org/zap/exp v0.3.0 // indirect golang.org/x/exp v0.0.0-20240808152545-0cdaa3abc0fa // indirect golang.org/x/mod v0.21.0 // indirect + golang.org/x/term v0.30.0 // indirect golang.org/x/text v0.23.0 // indirect golang.org/x/tools v0.24.0 // indirect google.golang.org/genproto v0.0.0-20241118233622-e639e219e697 // indirect diff --git a/integration2/ssh_int_test.go b/integration2/ssh_int_test.go index 5673f5d51..ebe8bf102 100644 --- a/integration2/ssh_int_test.go +++ b/integration2/ssh_int_test.go @@ -3,12 +3,16 @@ package ssh import ( "bytes" "crypto/ed25519" + "errors" + "io" "strings" + "sync" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "golang.org/x/crypto/ssh" + "golang.org/x/term" "github.com/pomerium/pomerium/config" "github.com/pomerium/pomerium/internal/testenv" @@ -20,31 +24,40 @@ import ( func TestSSH(t *testing.T) { clientKey := newSSHKey(t) serverHostKey := newSSHKey(t) + userCAKey := newSSHKey(t) // ssh client setup var ki scenarios.EmptyKeyboardInteractiveChallenge clientConfig := &ssh.ClientConfig{ User: "demo@example", Auth: []ssh.AuthMethod{ - ssh.PublicKeys(clientKey), + ssh.PublicKeys(newSignerFromKey(t, clientKey)), ssh.KeyboardInteractive(ki.Do), }, - HostKeyCallback: ssh.InsecureIgnoreHostKey(), + HostKeyCallback: ssh.FixedHostKey(newPublicKey(t, serverHostKey.Public())), } // pomerium + upstream setup env := testenv.New(t) env.Add(scenarios.NewIDP([]*scenarios.User{{Email: "test@example.com"}}, scenarios.WithEnableDeviceAuth(true))) - env.Add(scenarios.SSH(scenarios.SSHConfig{})) + env.Add(scenarios.SSH(scenarios.SSHConfig{ + HostKeys: []any{serverHostKey}, + UserCAKey: userCAKey, + })) env.Add(&ki) + userCAPublicKey := newPublicKey(t, userCAKey.Public()) + certChecker := ssh.CertChecker{ + IsUserAuthority: func(auth ssh.PublicKey) bool { + return bytes.Equal(userCAPublicKey.Marshal(), auth.Marshal()) + }, + } up := upstreams.SSH( - upstreams.WithHostKeys(serverHostKey), - upstreams.WithAuthorizedKey(clientKey.PublicKey(), "demo"), - upstreams.WithBannerCallback(func(_ ssh.ConnMetadata) string { - return "TEST BANNER" - })) + upstreams.WithHostKeys(newSignerFromKey(t, serverHostKey)), + upstreams.WithPublicKeyCallback(certChecker.Authenticate), + ) + up.SetServerConnCallback(echoShell{t}.handleConnection) r := up.Route(). From(env.SubdomainURLWithScheme("example", "ssh")). Policy(func(p *config.Policy) { p.AllowAnyAuthenticatedUser = true }) @@ -63,19 +76,90 @@ func TestSSH(t *testing.T) { var b bytes.Buffer sess.Stdout = &b - sess.Stdin = strings.NewReader("") - sess.Shell() - sess.Wait() + sess.Stdin = strings.NewReader("hello world\r") + require.NoError(t, sess.Shell()) + require.NoError(t, sess.Wait()) - assert.Equal(t, "TEST BANNER", b.String()) + assert.Equal(t, "> hello world\r\nhello world\r\n> ", b.String()) } -// newSSHKey generates and returns a new Ed25519 ssh key. -func newSSHKey(t *testing.T) ssh.Signer { +type echoShell struct { + t *testing.T +} + +func (sh echoShell) handleConnection(conn *ssh.ServerConn, chans <-chan ssh.NewChannel, reqs <-chan *ssh.Request) { + var wg sync.WaitGroup + defer wg.Wait() + + // Reject any global requests from the client. + wg.Add(1) + go func() { + ssh.DiscardRequests(reqs) + wg.Done() + }() + + // Accept shell session requests. + for newChannel := range chans { + if newChannel.ChannelType() != "session" { + newChannel.Reject(ssh.UnknownChannelType, "unknown channel type") + continue + } + channel, requests, err := newChannel.Accept() + require.NoError(sh.t, err, "echoShell: couldn't accept channel") + + // Acknowledge a 'shell' request. + wg.Add(1) + go func(in <-chan *ssh.Request) { + for req := range in { + req.Reply(req.Type == "shell", nil) + } + wg.Done() + }(requests) + + // Simulate a terminal that echoes all input lines. + term := term.NewTerminal(channel, "> ") + + wg.Add(1) + go func() { + defer func() { + channel.Close() + wg.Done() + }() + for { + line, err := term.ReadLine() + if errors.Is(err, io.EOF) { + break + } + require.NoError(sh.t, err, "echoShell: couldn't read line") + reply := append([]byte(line), '\n') + _, err = term.Write(reply) + require.NoError(sh.t, err, "echoShell: couldn't write line") + } + channel.SendRequest("exit-status", false, make([]byte, 4) /* uint32 0 */) + }() + } +} + +// newSSHKey generates a new Ed25519 ssh key. +func newSSHKey(t *testing.T) ed25519.PrivateKey { t.Helper() _, priv, err := ed25519.GenerateKey(nil) require.NoError(t, err) - signer, err := ssh.NewSignerFromKey(priv) + return priv +} + +// newSignerFromKey is a wrapper around ssh.NewSignerFromKey that will fail on error. +func newSignerFromKey(t *testing.T, key any) ssh.Signer { + t.Helper() + signer, err := ssh.NewSignerFromKey(key) require.NoError(t, err) return signer } + +// newPublicKey is a wrapper around ssh.NewPublicKey that will fail on error. +func newPublicKey(t *testing.T, key any) ssh.PublicKey { + t.Helper() + sshkey, err := ssh.NewPublicKey(key) + require.NoError(t, err) + return sshkey +} diff --git a/internal/testenv/upstreams/ssh.go b/internal/testenv/upstreams/ssh.go index 282c3d181..e27d789d0 100644 --- a/internal/testenv/upstreams/ssh.go +++ b/internal/testenv/upstreams/ssh.go @@ -203,6 +203,7 @@ func (h *sshUpstream) handleConnection(ctx context.Context, conn net.Conn) { // Dial implements SSHUpstream. func (h *sshUpstream) Dial(r testenv.Route, config *ssh.ClientConfig) (*ssh.Client, error) { return ssh.Dial("tcp", strings.TrimPrefix(r.URL().Value(), "ssh://"), config) + //return ssh.Dial("tcp", h.Env().Config().Options.SSHAddr, config) } // DirectDial implements SSHUpstream.