mirror of
https://github.com/pomerium/pomerium.git
synced 2025-05-03 04:16:03 +02:00
Introduce a new SSHUpstream in the testenv package along with some related machinery for configuring Pomerium with ssh routes. Add a basic test case that configures one ssh upstream and attempts an ssh connection to Pomerium itself.
211 lines
5.9 KiB
Go
211 lines
5.9 KiB
Go
package upstreams
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"log"
|
|
"net"
|
|
|
|
"github.com/pomerium/pomerium/internal/testenv"
|
|
"github.com/pomerium/pomerium/internal/testenv/values"
|
|
"golang.org/x/crypto/ssh"
|
|
)
|
|
|
|
type SSHUpstreamOptions struct {
|
|
displayName string
|
|
serverConfig ssh.ServerConfig
|
|
authorizedKeys authorizedKeysChecker
|
|
}
|
|
|
|
type SSHUpstreamOption interface {
|
|
applySSH(*SSHUpstreamOptions)
|
|
}
|
|
|
|
type sshUpstreamOption func(o *SSHUpstreamOptions)
|
|
|
|
func (s sshUpstreamOption) applySSH(o *SSHUpstreamOptions) { s(o) }
|
|
|
|
func WithPublicKeyAuthAlgorithms(algs ...string) SSHUpstreamOption {
|
|
return sshUpstreamOption(func(o *SSHUpstreamOptions) {
|
|
o.serverConfig.PublicKeyAuthAlgorithms = algs
|
|
})
|
|
}
|
|
|
|
func WithHostKeys(keys ...ssh.Signer) SSHUpstreamOption {
|
|
return sshUpstreamOption(func(o *SSHUpstreamOptions) {
|
|
for _, key := range keys {
|
|
o.serverConfig.AddHostKey(key)
|
|
}
|
|
})
|
|
}
|
|
|
|
func WithBannerCallback(c func(ssh.ConnMetadata) string) SSHUpstreamOption {
|
|
return sshUpstreamOption(func(o *SSHUpstreamOptions) {
|
|
o.serverConfig.BannerCallback = c
|
|
})
|
|
}
|
|
|
|
// WithPublicKeyCallback sets a custom callback for the publickey authentication method.
|
|
// This will override any previous [WithAuthorizedKey] option.
|
|
func WithPublicKeyCallback(c func(ssh.ConnMetadata, ssh.PublicKey) (*ssh.Permissions, error)) SSHUpstreamOption {
|
|
return sshUpstreamOption(func(o *SSHUpstreamOptions) {
|
|
o.serverConfig.PublicKeyCallback = c
|
|
})
|
|
}
|
|
|
|
// WithAuthorizedKey allows the given key to be used to authenticate the given username,
|
|
// enabling the publickey authentication method. This will override any previous
|
|
// [WithPublicKeyCallback] option.
|
|
func WithAuthorizedKey(key ssh.PublicKey, username string) SSHUpstreamOption {
|
|
return sshUpstreamOption(func(o *SSHUpstreamOptions) {
|
|
o.authorizedKeys.add(key, username)
|
|
o.serverConfig.PublicKeyCallback = o.authorizedKeys.check
|
|
})
|
|
}
|
|
|
|
type authorizedKeysChecker map[string]string // map from marshaled public key to corresponding username
|
|
|
|
func (c *authorizedKeysChecker) add(key ssh.PublicKey, username string) {
|
|
if *c == nil {
|
|
*c = make(map[string]string)
|
|
}
|
|
(*c)[string(key.Marshal())] = username
|
|
}
|
|
|
|
func (c authorizedKeysChecker) check(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) {
|
|
if c[string(key.Marshal())] == conn.User() {
|
|
return &ssh.Permissions{}, nil
|
|
}
|
|
return nil, fmt.Errorf("not authorized")
|
|
}
|
|
|
|
type ServerConnCallback func(*ssh.ServerConn, <-chan ssh.NewChannel, <-chan *ssh.Request)
|
|
|
|
var closeConnCallback ServerConnCallback = func(conn *ssh.ServerConn, ncc <-chan ssh.NewChannel, rq <-chan *ssh.Request) {
|
|
conn.Close()
|
|
}
|
|
|
|
// SSHUpstream represents an ssh server which can be used as the target for
|
|
// one or more Pomerium routes in a test environment.
|
|
//
|
|
// Use SetServerConnCallback() to define the behavior of this server once
|
|
// a connection is established.
|
|
//
|
|
// Dial() can be used to make a client-side connection through the Pomerium
|
|
// route, while DirectDial() can be used to connect bypassing Pomerium.
|
|
type SSHUpstream interface {
|
|
testenv.Upstream
|
|
|
|
SetServerConnCallback(callback ServerConnCallback)
|
|
|
|
Dial(r testenv.Route, config *ssh.ClientConfig) (*ssh.Client, error)
|
|
DirectDial(r testenv.Route, config *ssh.ClientConfig) (*ssh.Client, error)
|
|
}
|
|
|
|
type sshUpstream struct {
|
|
SSHUpstreamOptions
|
|
testenv.Aggregate
|
|
serverPort values.MutableValue[int]
|
|
|
|
// XXX: does it make sense to cache clients?
|
|
//clientCache sync.Map // map[testenv.Route]*ssh.Client
|
|
|
|
serverConnCallback ServerConnCallback
|
|
}
|
|
|
|
var (
|
|
_ testenv.Upstream = (*sshUpstream)(nil)
|
|
_ SSHUpstream = (*sshUpstream)(nil)
|
|
)
|
|
|
|
// SSH creates a new ssh upstream server.
|
|
func SSH(opts ...SSHUpstreamOption) SSHUpstream {
|
|
options := SSHUpstreamOptions{
|
|
displayName: "SSH Upstream",
|
|
}
|
|
for _, op := range opts {
|
|
op.applySSH(&options)
|
|
}
|
|
up := &sshUpstream{
|
|
SSHUpstreamOptions: options,
|
|
serverPort: values.Deferred[int](),
|
|
serverConnCallback: closeConnCallback, // default handler, to avoid hanging connections
|
|
}
|
|
up.RecordCaller()
|
|
return up
|
|
}
|
|
|
|
// Addr implements SSHUpstream.
|
|
func (h *sshUpstream) Addr() values.Value[string] {
|
|
return values.Bind(h.serverPort, func(port int) string {
|
|
return fmt.Sprintf("%s:%d", h.Env().Host(), port)
|
|
})
|
|
}
|
|
|
|
// Router implements SSHUpstream.
|
|
func (h *sshUpstream) SetServerConnCallback(callback ServerConnCallback) {
|
|
h.serverConnCallback = callback
|
|
}
|
|
|
|
// Route implements SSHUpstream.
|
|
func (h *sshUpstream) Route() testenv.RouteStub {
|
|
r := &testenv.PolicyRoute{}
|
|
protocol := "ssh"
|
|
r.To(values.Bind(h.serverPort, func(port int) string {
|
|
return fmt.Sprintf("%s://%s:%d", protocol, h.Env().Host(), port)
|
|
}))
|
|
h.Add(r)
|
|
return r
|
|
}
|
|
|
|
// Run implements SSHUpstream.
|
|
func (h *sshUpstream) Run(ctx context.Context) error {
|
|
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
|
if err != nil {
|
|
return err
|
|
}
|
|
h.serverPort.Resolve(listener.Addr().(*net.TCPAddr).Port)
|
|
|
|
go func() {
|
|
<-ctx.Done()
|
|
listener.Close()
|
|
}()
|
|
|
|
for {
|
|
conn, err := listener.Accept()
|
|
if err != nil {
|
|
// The testenv cleanup expects this function to return a "test cleanup" error,
|
|
// which propagates via the context.
|
|
if ctx.Err() != nil {
|
|
return context.Cause(ctx)
|
|
}
|
|
return err
|
|
}
|
|
go h.handleConnection(ctx, conn)
|
|
}
|
|
}
|
|
|
|
func (h *sshUpstream) handleConnection(ctx context.Context, conn net.Conn) {
|
|
serverConn, ncc, rc, err := ssh.NewServerConn(conn, &h.serverConfig)
|
|
if err != nil {
|
|
// XXX: figure out the right way to log this
|
|
log.Println("ssh connection handshake unsuccessful:", err)
|
|
return
|
|
}
|
|
go func() {
|
|
<-ctx.Done()
|
|
conn.Close()
|
|
}()
|
|
h.serverConnCallback(serverConn, ncc, rc)
|
|
}
|
|
|
|
// Dial implements SSHUpstream.
|
|
func (h *sshUpstream) Dial(r testenv.Route, config *ssh.ClientConfig) (*ssh.Client, error) {
|
|
return ssh.Dial("tcp", h.Env().Config().Options.SSHAddr, config)
|
|
}
|
|
|
|
// DirectDial implements SSHUpstream.
|
|
func (h *sshUpstream) DirectDial(r testenv.Route, config *ssh.ClientConfig) (*ssh.Client, error) {
|
|
addr := fmt.Sprintf("127.0.0.1:%d", h.serverPort.Value())
|
|
return ssh.Dial("tcp", addr, config)
|
|
}
|