diff --git a/integration2/ssh_int_test.go b/integration2/ssh_int_test.go new file mode 100644 index 000000000..de2a799d9 --- /dev/null +++ b/integration2/ssh_int_test.go @@ -0,0 +1,98 @@ +package ssh + +import ( + "bytes" + "crypto/ed25519" + "fmt" + "os" + "strings" + "testing" + + "github.com/stretchr/testify/require" + "golang.org/x/crypto/ssh" + + "github.com/pomerium/pomerium/config" + "github.com/pomerium/pomerium/internal/testenv" + "github.com/pomerium/pomerium/internal/testenv/snippets" + "github.com/pomerium/pomerium/internal/testenv/upstreams" +) + +func TestSSH(t *testing.T) { + // generate client ssh key + _, priv, err := ed25519.GenerateKey(nil) + require.NoError(t, err) + signer, err := ssh.NewSignerFromKey(priv) + require.NoError(t, err) + + // ssh client setup + clientConfig := &ssh.ClientConfig{ + User: "demo", + Auth: []ssh.AuthMethod{ + ssh.PublicKeys(signer), + }, + HostKeyCallback: ssh.InsecureIgnoreHostKey(), + } + + // pomerium + upstream setup + env := testenv.New(t) + + up := upstreams.SSH() + r := up.Route(). + From(env.SubdomainURLWithScheme("ssh", "ssh")). + Policy(func(p *config.Policy) { p.AllowPublicUnauthenticatedAccess = true }) + env.AddUpstream(up) + env.Start() + snippets.WaitStartupComplete(env) + + // test scenario -- first verify that the upstream is working at all + client, err := up.DirectDial(r, clientConfig) + require.NoError(t, err) + defer client.Close() + + sess, err := client.NewSession() + require.NoError(t, err) + defer sess.Close() +} + +func TestHelloWorld(t *testing.T) { + t.Skip("debugging...") + + key, err := os.ReadFile("/Users/kjenkins/scratch/sshd/demo_key") + require.NoError(t, err) + signer, err := ssh.ParsePrivateKey(key) + require.NoError(t, err) + + config := &ssh.ClientConfig{ + User: "demo", + Auth: []ssh.AuthMethod{ + ssh.PublicKeys(signer), + }, + HostKeyCallback: ssh.InsecureIgnoreHostKey(), + } + + conn, err := ssh.Dial("tcp", "localhost:2222", config) + require.NoError(t, err, "unable to connect") + defer conn.Close() + + //conn.ServerVersion() + + sess, err := conn.NewSession() + require.NoError(t, err, "unable to start session") + defer sess.Close() + + var output bytes.Buffer + sess.Stdout = &output + sess.Stdin = strings.NewReader("whoami\n") + + err = sess.Shell() + + fmt.Println("Shell() returned ", err) + + err = sess.Wait() + + fmt.Println("Wait() returned ", err) + + fmt.Println(" --> output:\n\n", output.String()) + + //sess.SendRequest() +} diff --git a/internal/testenv/environment.go b/internal/testenv/environment.go index dce233b2a..4a70033ff 100644 --- a/internal/testenv/environment.go +++ b/internal/testenv/environment.go @@ -141,6 +141,13 @@ type Environment interface { // can be used as the 'from' value for routes. SubdomainURL(subdomain string) values.Value[string] + // SubdomainURL returns a string [values.Value] which will contain a complete + // URL for the given subdomain of the server's domain (given by its serving + // certificate), with the given scheme and the random http server port. + // This value will only be resolved some time after Start() is called, and + // can be used as the 'from' value for routes. + SubdomainURLWithScheme(subdomain, scheme string) values.Value[string] + // NewLogRecorder returns a new [*LogRecorder] and starts capturing logs for // the Pomerium server and Envoy. NewLogRecorder(opts ...LogRecorderOption) *LogRecorder @@ -510,8 +517,12 @@ func (e *environment) Require() *require.Assertions { } func (e *environment) SubdomainURL(subdomain string) values.Value[string] { + return e.SubdomainURLWithScheme(subdomain, "https") +} + +func (e *environment) SubdomainURLWithScheme(subdomain, scheme string) values.Value[string] { return values.Bind(e.ports.ProxyHTTP, func(port int) string { - return fmt.Sprintf("https://%s.%s:%d", subdomain, e.domain, port) + return fmt.Sprintf("%s://%s.%s:%d", scheme, subdomain, e.domain, port) }) } diff --git a/internal/testenv/upstreams/ssh.go b/internal/testenv/upstreams/ssh.go new file mode 100644 index 000000000..eeaec3d82 --- /dev/null +++ b/internal/testenv/upstreams/ssh.go @@ -0,0 +1,177 @@ +package upstreams + +import ( + "context" + "fmt" + "log" + "net" + + "github.com/pomerium/pomerium/internal/testenv" + "github.com/pomerium/pomerium/internal/testenv/values" + "golang.org/x/crypto/ssh" +) + +type SSHUpstreamOptions struct { + displayName string + serverConfig ssh.ServerConfig +} + +type SSHUpstreamOption interface { + applySSH(*SSHUpstreamOptions) +} + +type sshUpstreamOption func(o *SSHUpstreamOptions) + +func (s sshUpstreamOption) applySSH(o *SSHUpstreamOptions) { s(o) } + +func WithPublicKeyAuthAlgorithms(algs []string) SSHUpstreamOption { + return sshUpstreamOption(func(o *SSHUpstreamOptions) { + o.serverConfig.PublicKeyAuthAlgorithms = algs + }) +} + +func WithHostKeys(keys ...ssh.Signer) SSHUpstreamOption { + return sshUpstreamOption(func(o *SSHUpstreamOptions) { + for _, key := range keys { + o.serverConfig.AddHostKey(key) + } + }) +} + +func WithBannerCallback(c func(ssh.ConnMetadata) string) SSHUpstreamOption { + return sshUpstreamOption(func(o *SSHUpstreamOptions) { + o.serverConfig.BannerCallback = c + }) +} + +func WithPublicKeyCallback(c func(ssh.ConnMetadata, ssh.PublicKey) (*ssh.Permissions, error)) SSHUpstreamOption { + return sshUpstreamOption(func(o *SSHUpstreamOptions) { + o.serverConfig.PublicKeyCallback = c + }) +} + +type ServerConnCallback func(*ssh.ServerConn, <-chan ssh.NewChannel, <-chan *ssh.Request) + +var closeConnCallback ServerConnCallback = func(conn *ssh.ServerConn, ncc <-chan ssh.NewChannel, rq <-chan *ssh.Request) { + conn.Close() +} + +// SSHUpstream represents an ssh server which can be used as the target for +// one or more Pomerium routes in a test environment. +// +// Use SetServerConnCallback() to define the behavior of this server once +// a connection is established. +// +// Dial() can be used to make a client-side connection through the Pomerium +// route, while DirectDial() can be used to connect bypassing Pomerium. +type SSHUpstream interface { + testenv.Upstream + + SetServerConnCallback(callback ServerConnCallback) + + Dial(r testenv.Route, config *ssh.ClientConfig) (*ssh.Client, error) + DirectDial(r testenv.Route, config *ssh.ClientConfig) (*ssh.Client, error) +} + +type sshUpstream struct { + SSHUpstreamOptions + testenv.Aggregate + serverPort values.MutableValue[int] + + // XXX: does it make sense to cache clients? + //clientCache sync.Map // map[testenv.Route]*ssh.Client + + serverConnCallback ServerConnCallback +} + +var ( + _ testenv.Upstream = (*sshUpstream)(nil) + _ SSHUpstream = (*sshUpstream)(nil) +) + +// SSH creates a new ssh upstream server. +func SSH(opts ...SSHUpstreamOption) SSHUpstream { + options := SSHUpstreamOptions{ + displayName: "SSH Upstream", + } + for _, op := range opts { + op.applySSH(&options) + } + up := &sshUpstream{ + SSHUpstreamOptions: options, + serverPort: values.Deferred[int](), + serverConnCallback: closeConnCallback, // default handler, to avoid hanging connections + } + up.RecordCaller() + return up +} + +// Port implements SSHUpstream. +func (h *sshUpstream) Port() values.Value[int] { + return h.serverPort +} + +// Router implements SSHUpstream. +func (h *sshUpstream) SetServerConnCallback(callback ServerConnCallback) { + h.serverConnCallback = callback +} + +// Route implements SSHUpstream. +func (h *sshUpstream) Route() testenv.RouteStub { + r := &testenv.PolicyRoute{} + protocol := "ssh" + r.To(values.Bind(h.serverPort, func(port int) string { + return fmt.Sprintf("%s://127.0.0.1:%d", protocol, port) + })) + h.Add(r) + return r +} + +// Run implements SSHUpstream. +func (h *sshUpstream) Run(ctx context.Context) error { + listener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + return err + } + h.serverPort.Resolve(listener.Addr().(*net.TCPAddr).Port) + + go func() { + <-ctx.Done() + listener.Close() + }() + + for { + conn, err := listener.Accept() + if err != nil { + return err + } + go h.handleConnection(ctx, conn) + } +} + +func (h *sshUpstream) handleConnection(ctx context.Context, conn net.Conn) { + serverConn, ncc, rc, err := ssh.NewServerConn(conn, &h.serverConfig) + if err != nil { + // XXX: figure out the right way to log this + log.Println("ssh connection handshake unsuccessful:", err) + return + } + go func() { + <-ctx.Done() + conn.Close() + }() + h.serverConnCallback(serverConn, ncc, rc) +} + +// Dial implements SSHUpstream. +func (h *sshUpstream) Dial(r testenv.Route, config *ssh.ClientConfig) (*ssh.Client, error) { + // XXX: need to add ssh listener configuration to Env + //ssh.Dial("tcp", h.Env().) + return nil, fmt.Errorf("not implemented") +} + +// DirectDial implements SSHUpstream. +func (h *sshUpstream) DirectDial(r testenv.Route, config *ssh.ClientConfig) (*ssh.Client, error) { + addr := fmt.Sprintf("127.0.0.1:%d", h.serverPort.Value()) + return ssh.Dial("tcp", addr, config) +}