pomerium/internal/testenv/upstreams/ssh.go
Kenneth Jenkins d5c60b3597 add a bare-bones ssh integration test
Introduce a new SSHUpstream in the testenv package along with some
related machinery for configuring Pomerium with ssh routes. Add a basic
test case that configures one ssh upstream and attempts an ssh
connection to Pomerium itself.
2025-04-08 16:42:43 -07:00

211 lines
5.9 KiB
Go

package upstreams
import (
"context"
"fmt"
"log"
"net"
"github.com/pomerium/pomerium/internal/testenv"
"github.com/pomerium/pomerium/internal/testenv/values"
"golang.org/x/crypto/ssh"
)
type (
	// SSHUpstreamOptions holds the configuration accumulated from the
	// SSHUpstreamOption values passed to SSH.
	SSHUpstreamOptions struct {
		displayName    string
		serverConfig   ssh.ServerConfig
		authorizedKeys authorizedKeysChecker
	}

	// SSHUpstreamOption configures an ssh upstream created by SSH.
	SSHUpstreamOption interface {
		applySSH(*SSHUpstreamOptions)
	}

	// sshUpstreamOption adapts a plain function to the SSHUpstreamOption
	// interface.
	sshUpstreamOption func(o *SSHUpstreamOptions)
)

// applySSH implements SSHUpstreamOption by calling the function itself.
func (fn sshUpstreamOption) applySSH(opts *SSHUpstreamOptions) {
	fn(opts)
}
// WithPublicKeyAuthAlgorithms sets the server config's PublicKeyAuthAlgorithms,
// i.e. the client public key authentication algorithms the server supports.
// The given slice is stored as-is.
func WithPublicKeyAuthAlgorithms(algs ...string) SSHUpstreamOption {
	apply := func(o *SSHUpstreamOptions) {
		o.serverConfig.PublicKeyAuthAlgorithms = algs
	}
	return sshUpstreamOption(apply)
}
// WithHostKeys registers each of the given signers as a host key on the
// server config, in order, via AddHostKey.
func WithHostKeys(keys ...ssh.Signer) SSHUpstreamOption {
	apply := func(o *SSHUpstreamOptions) {
		for i := range keys {
			o.serverConfig.AddHostKey(keys[i])
		}
	}
	return sshUpstreamOption(apply)
}
// WithBannerCallback installs c as the server config's BannerCallback, which
// produces the banner text shown to connecting clients.
func WithBannerCallback(c func(ssh.ConnMetadata) string) SSHUpstreamOption {
	apply := func(o *SSHUpstreamOptions) {
		o.serverConfig.BannerCallback = c
	}
	return sshUpstreamOption(apply)
}
// WithPublicKeyCallback installs c as the handler for the publickey
// authentication method. It replaces the callback set by any earlier
// [WithAuthorizedKey] option. A later WithAuthorizedKey replaces c in turn.
func WithPublicKeyCallback(c func(ssh.ConnMetadata, ssh.PublicKey) (*ssh.Permissions, error)) SSHUpstreamOption {
	apply := func(o *SSHUpstreamOptions) {
		o.serverConfig.PublicKeyCallback = c
	}
	return sshUpstreamOption(apply)
}
// WithAuthorizedKey lets key authenticate as username and enables the
// publickey authentication method. The callback it installs replaces any
// earlier [WithPublicKeyCallback] option. Repeated WithAuthorizedKey options
// accumulate into the same set of authorized keys.
func WithAuthorizedKey(key ssh.PublicKey, username string) SSHUpstreamOption {
	return sshUpstreamOption(func(o *SSHUpstreamOptions) {
		keys := &o.authorizedKeys
		keys.add(key, username)
		o.serverConfig.PublicKeyCallback = keys.check
	})
}
// authorizedKeysChecker maps a marshaled public key (key.Marshal()) to the
// single username that key may authenticate as. Adding the same key again
// replaces its username.
type authorizedKeysChecker map[string]string

// add authorizes key to authenticate as username, allocating the map on
// first use so the zero value is ready to use.
func (c *authorizedKeysChecker) add(key ssh.PublicKey, username string) {
	if *c == nil {
		*c = make(map[string]string)
	}
	(*c)[string(key.Marshal())] = username
}

// check is used as the server's PublicKeyCallback. It accepts key only when
// the key was registered for exactly the username in the connection metadata.
func (c authorizedKeysChecker) check(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) {
	// Require an explicit entry. A bare lookup yields "" for unknown keys,
	// which would otherwise authorize any key for an empty username.
	if username, ok := c[string(key.Marshal())]; ok && username == conn.User() {
		return &ssh.Permissions{}, nil
	}
	return nil, fmt.Errorf("not authorized")
}
// ServerConnCallback handles a connection that completed the ssh handshake.
// It receives the server connection together with the incoming NewChannel
// and Request channels returned by ssh.NewServerConn.
type ServerConnCallback func(*ssh.ServerConn, <-chan ssh.NewChannel, <-chan *ssh.Request)

// closeConnCallback is the default ServerConnCallback. It closes the
// connection right away without reading either channel.
var closeConnCallback = ServerConnCallback(func(sc *ssh.ServerConn, _ <-chan ssh.NewChannel, _ <-chan *ssh.Request) {
	sc.Close()
})
// SSHUpstream is an ssh server that can serve as the target of one or more
// Pomerium routes in a test environment.
//
// What the server does with an established connection is controlled by
// SetServerConnCallback. Dial connects through Pomerium; DirectDial connects
// to the server directly, bypassing Pomerium.
type SSHUpstream interface {
	testenv.Upstream

	// SetServerConnCallback replaces the handler run for each connection
	// that completes the ssh handshake.
	SetServerConnCallback(callback ServerConnCallback)

	// Dial opens a client connection by way of Pomerium.
	Dial(r testenv.Route, config *ssh.ClientConfig) (*ssh.Client, error)

	// DirectDial opens a client connection straight to this server.
	DirectDial(r testenv.Route, config *ssh.ClientConfig) (*ssh.Client, error)
}
// sshUpstream is the concrete SSHUpstream returned by SSH.
type sshUpstream struct {
	SSHUpstreamOptions
	testenv.Aggregate

	// serverPort is resolved by Run once the listener has been bound.
	serverPort values.MutableValue[int]

	// TODO: decide whether clients should be cached per route, e.g. with a
	// sync.Map keyed by testenv.Route holding *ssh.Client values.

	// serverConnCallback handles each connection after a successful handshake.
	serverConnCallback ServerConnCallback
}

// Compile-time checks that *sshUpstream satisfies both interfaces.
var (
	_ SSHUpstream      = (*sshUpstream)(nil)
	_ testenv.Upstream = (*sshUpstream)(nil)
)
// SSH creates a new ssh upstream server configured by opts. The options are
// applied in order. Until SetServerConnCallback is called, every connection
// is closed as soon as its handshake completes.
func SSH(opts ...SSHUpstreamOption) SSHUpstream {
	up := &sshUpstream{
		SSHUpstreamOptions: SSHUpstreamOptions{displayName: "SSH Upstream"},
		serverPort:         values.Deferred[int](),
		// Default handler, so that connections never hang unserviced.
		serverConnCallback: closeConnCallback,
	}
	for i := range opts {
		opts[i].applySSH(&up.SSHUpstreamOptions)
	}
	up.RecordCaller()
	return up
}
// Addr implements SSHUpstream. It reports the server's address as host:port,
// using the test environment's host and the port chosen by Run.
func (h *sshUpstream) Addr() values.Value[string] {
	return values.Bind(h.serverPort, func(port int) string {
		// JoinHostPort brackets IPv6 literals, which "%s:%d" would not.
		return net.JoinHostPort(h.Env().Host(), fmt.Sprint(port))
	})
}
// SetServerConnCallback implements SSHUpstream. The callback is invoked for
// every connection that completes the ssh handshake. Passing nil restores the
// default behavior of closing each connection immediately.
func (h *sshUpstream) SetServerConnCallback(callback ServerConnCallback) {
	if callback == nil {
		// handleConnection calls the callback unconditionally. Keep it non-nil.
		callback = closeConnCallback
	}
	h.serverConnCallback = callback
}
// Route implements SSHUpstream. It registers a new policy route whose
// destination is this server, addressed with the ssh:// scheme.
func (h *sshUpstream) Route() testenv.RouteStub {
	r := &testenv.PolicyRoute{}
	const protocol = "ssh"
	r.To(values.Bind(h.serverPort, func(port int) string {
		// JoinHostPort brackets IPv6 literals so the URL stays well-formed.
		return protocol + "://" + net.JoinHostPort(h.Env().Host(), fmt.Sprint(port))
	}))
	h.Add(r)
	return r
}
// Run implements SSHUpstream. It listens on an ephemeral 127.0.0.1 port,
// resolves serverPort with that port, and hands each accepted connection to
// handleConnection. It runs until ctx is canceled or Accept fails.
func (h *sshUpstream) Run(ctx context.Context) error {
	listener, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		return err
	}
	// Release the listener on every return path, not only on cancellation.
	defer listener.Close()
	h.serverPort.Resolve(listener.Addr().(*net.TCPAddr).Port)

	// Unblock Accept when ctx is canceled. The stop channel lets this
	// goroutine exit when Run returns for any other reason.
	stop := make(chan struct{})
	defer close(stop)
	go func() {
		select {
		case <-ctx.Done():
			listener.Close()
		case <-stop:
		}
	}()

	for {
		conn, err := listener.Accept()
		if err != nil {
			// The testenv cleanup expects this function to return a "test cleanup" error,
			// which propagates via the context.
			if ctx.Err() != nil {
				return context.Cause(ctx)
			}
			return err
		}
		go h.handleConnection(ctx, conn)
	}
}
// handleConnection performs the server side of the ssh handshake on conn.
// On success it passes the connection to the configured ServerConnCallback.
// The connection is closed when ctx is canceled, including while the
// handshake is still in progress.
func (h *sshUpstream) handleConnection(ctx context.Context, conn net.Conn) {
	// connDone is closed once the connection is finished, so the watcher
	// below exits instead of lingering until ctx is canceled.
	connDone := make(chan struct{})
	go func() {
		select {
		case <-ctx.Done():
			conn.Close()
		case <-connDone:
		}
	}()

	serverConn, ncc, rc, err := ssh.NewServerConn(conn, &h.serverConfig)
	if err != nil {
		close(connDone)
		// XXX: figure out the right way to log this
		log.Println("ssh connection handshake unsuccessful:", err)
		return
	}
	go func() {
		// Wait returns once the underlying connection has shut down.
		_ = serverConn.Wait()
		close(connDone)
	}()
	h.serverConnCallback(serverConn, ncc, rc)
}
// Dial implements SSHUpstream. It opens a client connection to Pomerium's
// configured ssh listener (Options.SSHAddr) rather than to this server.
// NOTE(review): r is not used by this implementation.
func (h *sshUpstream) Dial(r testenv.Route, config *ssh.ClientConfig) (*ssh.Client, error) {
	pomeriumAddr := h.Env().Config().Options.SSHAddr
	return ssh.Dial("tcp", pomeriumAddr, config)
}
// DirectDial implements SSHUpstream. It bypasses Pomerium and connects to the
// 127.0.0.1 port that Run resolved into serverPort.
// NOTE(review): r is not used by this implementation.
func (h *sshUpstream) DirectDial(r testenv.Route, config *ssh.ClientConfig) (*ssh.Client, error) {
	port := h.serverPort.Value()
	return ssh.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", port), config)
}