mirror of
https://github.com/pomerium/pomerium.git
synced 2025-08-06 10:21:05 +02:00
starting ssh proxy test harness
This commit is contained in:
parent
45da45a7a3
commit
d1847e0c94
3 changed files with 287 additions and 1 deletions
98
integration2/ssh_int_test.go
Normal file
98
integration2/ssh_int_test.go
Normal file
|
@ -0,0 +1,98 @@
|
||||||
|
package ssh
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"crypto/ed25519"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"golang.org/x/crypto/ssh"
|
||||||
|
|
||||||
|
"github.com/pomerium/pomerium/config"
|
||||||
|
"github.com/pomerium/pomerium/internal/testenv"
|
||||||
|
"github.com/pomerium/pomerium/internal/testenv/snippets"
|
||||||
|
"github.com/pomerium/pomerium/internal/testenv/upstreams"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestSSH(t *testing.T) {
|
||||||
|
// generate client ssh key
|
||||||
|
_, priv, err := ed25519.GenerateKey(nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
signer, err := ssh.NewSignerFromKey(priv)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// ssh client setup
|
||||||
|
clientConfig := &ssh.ClientConfig{
|
||||||
|
User: "demo",
|
||||||
|
Auth: []ssh.AuthMethod{
|
||||||
|
ssh.PublicKeys(signer),
|
||||||
|
},
|
||||||
|
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
|
||||||
|
}
|
||||||
|
|
||||||
|
// pomerium + upstream setup
|
||||||
|
env := testenv.New(t)
|
||||||
|
|
||||||
|
up := upstreams.SSH()
|
||||||
|
r := up.Route().
|
||||||
|
From(env.SubdomainURLWithScheme("ssh", "ssh")).
|
||||||
|
Policy(func(p *config.Policy) { p.AllowPublicUnauthenticatedAccess = true })
|
||||||
|
env.AddUpstream(up)
|
||||||
|
env.Start()
|
||||||
|
snippets.WaitStartupComplete(env)
|
||||||
|
|
||||||
|
// test scenario -- first verify that the upstream is working at all
|
||||||
|
client, err := up.DirectDial(r, clientConfig)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer client.Close()
|
||||||
|
|
||||||
|
sess, err := client.NewSession()
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer sess.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHelloWorld(t *testing.T) {
|
||||||
|
t.Skip("debugging...")
|
||||||
|
|
||||||
|
key, err := os.ReadFile("/Users/kjenkins/scratch/sshd/demo_key")
|
||||||
|
require.NoError(t, err)
|
||||||
|
signer, err := ssh.ParsePrivateKey(key)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
config := &ssh.ClientConfig{
|
||||||
|
User: "demo",
|
||||||
|
Auth: []ssh.AuthMethod{
|
||||||
|
ssh.PublicKeys(signer),
|
||||||
|
},
|
||||||
|
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
|
||||||
|
}
|
||||||
|
|
||||||
|
conn, err := ssh.Dial("tcp", "localhost:2222", config)
|
||||||
|
require.NoError(t, err, "unable to connect")
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
//conn.ServerVersion()
|
||||||
|
|
||||||
|
sess, err := conn.NewSession()
|
||||||
|
require.NoError(t, err, "unable to start session")
|
||||||
|
defer sess.Close()
|
||||||
|
|
||||||
|
var output bytes.Buffer
|
||||||
|
sess.Stdout = &output
|
||||||
|
sess.Stdin = strings.NewReader("whoami\n")
|
||||||
|
|
||||||
|
err = sess.Shell()
|
||||||
|
|
||||||
|
fmt.Println("Shell() returned ", err)
|
||||||
|
|
||||||
|
err = sess.Wait()
|
||||||
|
|
||||||
|
fmt.Println("Wait() returned ", err)
|
||||||
|
|
||||||
|
fmt.Println(" --> output:\n\n", output.String())
|
||||||
|
|
||||||
|
//sess.SendRequest()
|
||||||
|
}
|
|
@ -141,6 +141,13 @@ type Environment interface {
|
||||||
// can be used as the 'from' value for routes.
|
// can be used as the 'from' value for routes.
|
||||||
SubdomainURL(subdomain string) values.Value[string]
|
SubdomainURL(subdomain string) values.Value[string]
|
||||||
|
|
||||||
|
// SubdomainURL returns a string [values.Value] which will contain a complete
|
||||||
|
// URL for the given subdomain of the server's domain (given by its serving
|
||||||
|
// certificate), with the given scheme and the random http server port.
|
||||||
|
// This value will only be resolved some time after Start() is called, and
|
||||||
|
// can be used as the 'from' value for routes.
|
||||||
|
SubdomainURLWithScheme(subdomain, scheme string) values.Value[string]
|
||||||
|
|
||||||
// NewLogRecorder returns a new [*LogRecorder] and starts capturing logs for
|
// NewLogRecorder returns a new [*LogRecorder] and starts capturing logs for
|
||||||
// the Pomerium server and Envoy.
|
// the Pomerium server and Envoy.
|
||||||
NewLogRecorder(opts ...LogRecorderOption) *LogRecorder
|
NewLogRecorder(opts ...LogRecorderOption) *LogRecorder
|
||||||
|
@ -510,8 +517,12 @@ func (e *environment) Require() *require.Assertions {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *environment) SubdomainURL(subdomain string) values.Value[string] {
|
func (e *environment) SubdomainURL(subdomain string) values.Value[string] {
|
||||||
|
return e.SubdomainURLWithScheme(subdomain, "https")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *environment) SubdomainURLWithScheme(subdomain, scheme string) values.Value[string] {
|
||||||
return values.Bind(e.ports.ProxyHTTP, func(port int) string {
|
return values.Bind(e.ports.ProxyHTTP, func(port int) string {
|
||||||
return fmt.Sprintf("https://%s.%s:%d", subdomain, e.domain, port)
|
return fmt.Sprintf("%s://%s.%s:%d", scheme, subdomain, e.domain, port)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
177
internal/testenv/upstreams/ssh.go
Normal file
177
internal/testenv/upstreams/ssh.go
Normal file
|
@ -0,0 +1,177 @@
|
||||||
|
package upstreams
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"log"
|
||||||
|
"net"
|
||||||
|
|
||||||
|
"github.com/pomerium/pomerium/internal/testenv"
|
||||||
|
"github.com/pomerium/pomerium/internal/testenv/values"
|
||||||
|
"golang.org/x/crypto/ssh"
|
||||||
|
)
|
||||||
|
|
||||||
|
type SSHUpstreamOptions struct {
|
||||||
|
displayName string
|
||||||
|
serverConfig ssh.ServerConfig
|
||||||
|
}
|
||||||
|
|
||||||
|
type SSHUpstreamOption interface {
|
||||||
|
applySSH(*SSHUpstreamOptions)
|
||||||
|
}
|
||||||
|
|
||||||
|
type sshUpstreamOption func(o *SSHUpstreamOptions)
|
||||||
|
|
||||||
|
func (s sshUpstreamOption) applySSH(o *SSHUpstreamOptions) { s(o) }
|
||||||
|
|
||||||
|
func WithPublicKeyAuthAlgorithms(algs []string) SSHUpstreamOption {
|
||||||
|
return sshUpstreamOption(func(o *SSHUpstreamOptions) {
|
||||||
|
o.serverConfig.PublicKeyAuthAlgorithms = algs
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithHostKeys(keys ...ssh.Signer) SSHUpstreamOption {
|
||||||
|
return sshUpstreamOption(func(o *SSHUpstreamOptions) {
|
||||||
|
for _, key := range keys {
|
||||||
|
o.serverConfig.AddHostKey(key)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithBannerCallback(c func(ssh.ConnMetadata) string) SSHUpstreamOption {
|
||||||
|
return sshUpstreamOption(func(o *SSHUpstreamOptions) {
|
||||||
|
o.serverConfig.BannerCallback = c
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithPublicKeyCallback(c func(ssh.ConnMetadata, ssh.PublicKey) (*ssh.Permissions, error)) SSHUpstreamOption {
|
||||||
|
return sshUpstreamOption(func(o *SSHUpstreamOptions) {
|
||||||
|
o.serverConfig.PublicKeyCallback = c
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
type ServerConnCallback func(*ssh.ServerConn, <-chan ssh.NewChannel, <-chan *ssh.Request)
|
||||||
|
|
||||||
|
var closeConnCallback ServerConnCallback = func(conn *ssh.ServerConn, ncc <-chan ssh.NewChannel, rq <-chan *ssh.Request) {
|
||||||
|
conn.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
// SSHUpstream represents an ssh server which can be used as the target for
|
||||||
|
// one or more Pomerium routes in a test environment.
|
||||||
|
//
|
||||||
|
// Use SetServerConnCallback() to define the behavior of this server once
|
||||||
|
// a connection is established.
|
||||||
|
//
|
||||||
|
// Dial() can be used to make a client-side connection through the Pomerium
|
||||||
|
// route, while DirectDial() can be used to connect bypassing Pomerium.
|
||||||
|
type SSHUpstream interface {
|
||||||
|
testenv.Upstream
|
||||||
|
|
||||||
|
SetServerConnCallback(callback ServerConnCallback)
|
||||||
|
|
||||||
|
Dial(r testenv.Route, config *ssh.ClientConfig) (*ssh.Client, error)
|
||||||
|
DirectDial(r testenv.Route, config *ssh.ClientConfig) (*ssh.Client, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
type sshUpstream struct {
|
||||||
|
SSHUpstreamOptions
|
||||||
|
testenv.Aggregate
|
||||||
|
serverPort values.MutableValue[int]
|
||||||
|
|
||||||
|
// XXX: does it make sense to cache clients?
|
||||||
|
//clientCache sync.Map // map[testenv.Route]*ssh.Client
|
||||||
|
|
||||||
|
serverConnCallback ServerConnCallback
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
_ testenv.Upstream = (*sshUpstream)(nil)
|
||||||
|
_ SSHUpstream = (*sshUpstream)(nil)
|
||||||
|
)
|
||||||
|
|
||||||
|
// SSH creates a new ssh upstream server.
|
||||||
|
func SSH(opts ...SSHUpstreamOption) SSHUpstream {
|
||||||
|
options := SSHUpstreamOptions{
|
||||||
|
displayName: "SSH Upstream",
|
||||||
|
}
|
||||||
|
for _, op := range opts {
|
||||||
|
op.applySSH(&options)
|
||||||
|
}
|
||||||
|
up := &sshUpstream{
|
||||||
|
SSHUpstreamOptions: options,
|
||||||
|
serverPort: values.Deferred[int](),
|
||||||
|
serverConnCallback: closeConnCallback, // default handler, to avoid hanging connections
|
||||||
|
}
|
||||||
|
up.RecordCaller()
|
||||||
|
return up
|
||||||
|
}
|
||||||
|
|
||||||
|
// Port implements SSHUpstream.
|
||||||
|
func (h *sshUpstream) Port() values.Value[int] {
|
||||||
|
return h.serverPort
|
||||||
|
}
|
||||||
|
|
||||||
|
// Router implements SSHUpstream.
|
||||||
|
func (h *sshUpstream) SetServerConnCallback(callback ServerConnCallback) {
|
||||||
|
h.serverConnCallback = callback
|
||||||
|
}
|
||||||
|
|
||||||
|
// Route implements SSHUpstream.
|
||||||
|
func (h *sshUpstream) Route() testenv.RouteStub {
|
||||||
|
r := &testenv.PolicyRoute{}
|
||||||
|
protocol := "ssh"
|
||||||
|
r.To(values.Bind(h.serverPort, func(port int) string {
|
||||||
|
return fmt.Sprintf("%s://127.0.0.1:%d", protocol, port)
|
||||||
|
}))
|
||||||
|
h.Add(r)
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
// Run implements SSHUpstream.
|
||||||
|
func (h *sshUpstream) Run(ctx context.Context) error {
|
||||||
|
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
h.serverPort.Resolve(listener.Addr().(*net.TCPAddr).Port)
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
<-ctx.Done()
|
||||||
|
listener.Close()
|
||||||
|
}()
|
||||||
|
|
||||||
|
for {
|
||||||
|
conn, err := listener.Accept()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
go h.handleConnection(ctx, conn)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *sshUpstream) handleConnection(ctx context.Context, conn net.Conn) {
|
||||||
|
serverConn, ncc, rc, err := ssh.NewServerConn(conn, &h.serverConfig)
|
||||||
|
if err != nil {
|
||||||
|
// XXX: figure out the right way to log this
|
||||||
|
log.Println("ssh connection handshake unsuccessful:", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
go func() {
|
||||||
|
<-ctx.Done()
|
||||||
|
conn.Close()
|
||||||
|
}()
|
||||||
|
h.serverConnCallback(serverConn, ncc, rc)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Dial implements SSHUpstream.
|
||||||
|
func (h *sshUpstream) Dial(r testenv.Route, config *ssh.ClientConfig) (*ssh.Client, error) {
|
||||||
|
// XXX: need to add ssh listener configuration to Env
|
||||||
|
//ssh.Dial("tcp", h.Env().)
|
||||||
|
return nil, fmt.Errorf("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
// DirectDial implements SSHUpstream.
|
||||||
|
func (h *sshUpstream) DirectDial(r testenv.Route, config *ssh.ClientConfig) (*ssh.Client, error) {
|
||||||
|
addr := fmt.Sprintf("127.0.0.1:%d", h.serverPort.Value())
|
||||||
|
return ssh.Dial("tcp", addr, config)
|
||||||
|
}
|
Loading…
Add table
Add a link
Reference in a new issue