starting ssh proxy test harness

This commit is contained in:
Kenneth Jenkins 2025-03-18 12:36:00 -07:00
parent 45da45a7a3
commit d1847e0c94
3 changed files with 287 additions and 1 deletions

View file

@ -0,0 +1,98 @@
package ssh
import (
"bytes"
"crypto/ed25519"
"fmt"
"os"
"strings"
"testing"
"github.com/stretchr/testify/require"
"golang.org/x/crypto/ssh"
"github.com/pomerium/pomerium/config"
"github.com/pomerium/pomerium/internal/testenv"
"github.com/pomerium/pomerium/internal/testenv/snippets"
"github.com/pomerium/pomerium/internal/testenv/upstreams"
)
func TestSSH(t *testing.T) {
// generate client ssh key
_, priv, err := ed25519.GenerateKey(nil)
require.NoError(t, err)
signer, err := ssh.NewSignerFromKey(priv)
require.NoError(t, err)
// ssh client setup
clientConfig := &ssh.ClientConfig{
User: "demo",
Auth: []ssh.AuthMethod{
ssh.PublicKeys(signer),
},
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
}
// pomerium + upstream setup
env := testenv.New(t)
up := upstreams.SSH()
r := up.Route().
From(env.SubdomainURLWithScheme("ssh", "ssh")).
Policy(func(p *config.Policy) { p.AllowPublicUnauthenticatedAccess = true })
env.AddUpstream(up)
env.Start()
snippets.WaitStartupComplete(env)
// test scenario -- first verify that the upstream is working at all
client, err := up.DirectDial(r, clientConfig)
require.NoError(t, err)
defer client.Close()
sess, err := client.NewSession()
require.NoError(t, err)
defer sess.Close()
}
func TestHelloWorld(t *testing.T) {
t.Skip("debugging...")
key, err := os.ReadFile("/Users/kjenkins/scratch/sshd/demo_key")
require.NoError(t, err)
signer, err := ssh.ParsePrivateKey(key)
require.NoError(t, err)
config := &ssh.ClientConfig{
User: "demo",
Auth: []ssh.AuthMethod{
ssh.PublicKeys(signer),
},
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
}
conn, err := ssh.Dial("tcp", "localhost:2222", config)
require.NoError(t, err, "unable to connect")
defer conn.Close()
//conn.ServerVersion()
sess, err := conn.NewSession()
require.NoError(t, err, "unable to start session")
defer sess.Close()
var output bytes.Buffer
sess.Stdout = &output
sess.Stdin = strings.NewReader("whoami\n")
err = sess.Shell()
fmt.Println("Shell() returned ", err)
err = sess.Wait()
fmt.Println("Wait() returned ", err)
fmt.Println(" --> output:\n\n", output.String())
//sess.SendRequest()
}

View file

@ -141,6 +141,13 @@ type Environment interface {
// can be used as the 'from' value for routes. // can be used as the 'from' value for routes.
SubdomainURL(subdomain string) values.Value[string] SubdomainURL(subdomain string) values.Value[string]
// SubdomainURL returns a string [values.Value] which will contain a complete
// URL for the given subdomain of the server's domain (given by its serving
// certificate), with the given scheme and the random http server port.
// This value will only be resolved some time after Start() is called, and
// can be used as the 'from' value for routes.
SubdomainURLWithScheme(subdomain, scheme string) values.Value[string]
// NewLogRecorder returns a new [*LogRecorder] and starts capturing logs for // NewLogRecorder returns a new [*LogRecorder] and starts capturing logs for
// the Pomerium server and Envoy. // the Pomerium server and Envoy.
NewLogRecorder(opts ...LogRecorderOption) *LogRecorder NewLogRecorder(opts ...LogRecorderOption) *LogRecorder
@ -510,8 +517,12 @@ func (e *environment) Require() *require.Assertions {
} }
func (e *environment) SubdomainURL(subdomain string) values.Value[string] { func (e *environment) SubdomainURL(subdomain string) values.Value[string] {
return e.SubdomainURLWithScheme(subdomain, "https")
}
func (e *environment) SubdomainURLWithScheme(subdomain, scheme string) values.Value[string] {
return values.Bind(e.ports.ProxyHTTP, func(port int) string { return values.Bind(e.ports.ProxyHTTP, func(port int) string {
return fmt.Sprintf("https://%s.%s:%d", subdomain, e.domain, port) return fmt.Sprintf("%s://%s.%s:%d", scheme, subdomain, e.domain, port)
}) })
} }

View file

@ -0,0 +1,177 @@
package upstreams
import (
"context"
"fmt"
"log"
"net"
"github.com/pomerium/pomerium/internal/testenv"
"github.com/pomerium/pomerium/internal/testenv/values"
"golang.org/x/crypto/ssh"
)
type SSHUpstreamOptions struct {
displayName string
serverConfig ssh.ServerConfig
}
type SSHUpstreamOption interface {
applySSH(*SSHUpstreamOptions)
}
type sshUpstreamOption func(o *SSHUpstreamOptions)
func (s sshUpstreamOption) applySSH(o *SSHUpstreamOptions) { s(o) }
func WithPublicKeyAuthAlgorithms(algs []string) SSHUpstreamOption {
return sshUpstreamOption(func(o *SSHUpstreamOptions) {
o.serverConfig.PublicKeyAuthAlgorithms = algs
})
}
func WithHostKeys(keys ...ssh.Signer) SSHUpstreamOption {
return sshUpstreamOption(func(o *SSHUpstreamOptions) {
for _, key := range keys {
o.serverConfig.AddHostKey(key)
}
})
}
func WithBannerCallback(c func(ssh.ConnMetadata) string) SSHUpstreamOption {
return sshUpstreamOption(func(o *SSHUpstreamOptions) {
o.serverConfig.BannerCallback = c
})
}
func WithPublicKeyCallback(c func(ssh.ConnMetadata, ssh.PublicKey) (*ssh.Permissions, error)) SSHUpstreamOption {
return sshUpstreamOption(func(o *SSHUpstreamOptions) {
o.serverConfig.PublicKeyCallback = c
})
}
type ServerConnCallback func(*ssh.ServerConn, <-chan ssh.NewChannel, <-chan *ssh.Request)
var closeConnCallback ServerConnCallback = func(conn *ssh.ServerConn, ncc <-chan ssh.NewChannel, rq <-chan *ssh.Request) {
conn.Close()
}
// SSHUpstream represents an ssh server which can be used as the target for
// one or more Pomerium routes in a test environment.
//
// Use SetServerConnCallback() to define the behavior of this server once
// a connection is established.
//
// Dial() can be used to make a client-side connection through the Pomerium
// route, while DirectDial() can be used to connect bypassing Pomerium.
type SSHUpstream interface {
testenv.Upstream
SetServerConnCallback(callback ServerConnCallback)
Dial(r testenv.Route, config *ssh.ClientConfig) (*ssh.Client, error)
DirectDial(r testenv.Route, config *ssh.ClientConfig) (*ssh.Client, error)
}
type sshUpstream struct {
SSHUpstreamOptions
testenv.Aggregate
serverPort values.MutableValue[int]
// XXX: does it make sense to cache clients?
//clientCache sync.Map // map[testenv.Route]*ssh.Client
serverConnCallback ServerConnCallback
}
var (
_ testenv.Upstream = (*sshUpstream)(nil)
_ SSHUpstream = (*sshUpstream)(nil)
)
// SSH creates a new ssh upstream server.
func SSH(opts ...SSHUpstreamOption) SSHUpstream {
options := SSHUpstreamOptions{
displayName: "SSH Upstream",
}
for _, op := range opts {
op.applySSH(&options)
}
up := &sshUpstream{
SSHUpstreamOptions: options,
serverPort: values.Deferred[int](),
serverConnCallback: closeConnCallback, // default handler, to avoid hanging connections
}
up.RecordCaller()
return up
}
// Port implements SSHUpstream.
func (h *sshUpstream) Port() values.Value[int] {
return h.serverPort
}
// Router implements SSHUpstream.
func (h *sshUpstream) SetServerConnCallback(callback ServerConnCallback) {
h.serverConnCallback = callback
}
// Route implements SSHUpstream.
func (h *sshUpstream) Route() testenv.RouteStub {
r := &testenv.PolicyRoute{}
protocol := "ssh"
r.To(values.Bind(h.serverPort, func(port int) string {
return fmt.Sprintf("%s://127.0.0.1:%d", protocol, port)
}))
h.Add(r)
return r
}
// Run implements SSHUpstream.
func (h *sshUpstream) Run(ctx context.Context) error {
listener, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
return err
}
h.serverPort.Resolve(listener.Addr().(*net.TCPAddr).Port)
go func() {
<-ctx.Done()
listener.Close()
}()
for {
conn, err := listener.Accept()
if err != nil {
return err
}
go h.handleConnection(ctx, conn)
}
}
func (h *sshUpstream) handleConnection(ctx context.Context, conn net.Conn) {
serverConn, ncc, rc, err := ssh.NewServerConn(conn, &h.serverConfig)
if err != nil {
// XXX: figure out the right way to log this
log.Println("ssh connection handshake unsuccessful:", err)
return
}
go func() {
<-ctx.Done()
conn.Close()
}()
h.serverConnCallback(serverConn, ncc, rc)
}
// Dial implements SSHUpstream.
func (h *sshUpstream) Dial(r testenv.Route, config *ssh.ClientConfig) (*ssh.Client, error) {
// XXX: need to add ssh listener configuration to Env
//ssh.Dial("tcp", h.Env().)
return nil, fmt.Errorf("not implemented")
}
// DirectDial implements SSHUpstream.
func (h *sshUpstream) DirectDial(r testenv.Route, config *ssh.ClientConfig) (*ssh.Client, error) {
addr := fmt.Sprintf("127.0.0.1:%d", h.serverPort.Value())
return ssh.Dial("tcp", addr, config)
}