mirror of
https://github.com/pomerium/pomerium.git
synced 2025-05-03 20:36:03 +02:00
add a bare-bones ssh integration test
Introduce a new SSHUpstream in the testenv package along with some related machinery for configuring Pomerium with ssh routes. Add a basic test case that configures one ssh upstream and attempts an ssh connection to Pomerium itself.
This commit is contained in:
parent
319a801e1d
commit
d5c60b3597
6 changed files with 468 additions and 5 deletions
67
integration2/ssh_int_test.go
Normal file
67
integration2/ssh_int_test.go
Normal file
|
@ -0,0 +1,67 @@
|
||||||
|
package ssh
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/ed25519"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"golang.org/x/crypto/ssh"
|
||||||
|
|
||||||
|
"github.com/pomerium/pomerium/config"
|
||||||
|
"github.com/pomerium/pomerium/internal/testenv"
|
||||||
|
"github.com/pomerium/pomerium/internal/testenv/scenarios"
|
||||||
|
"github.com/pomerium/pomerium/internal/testenv/snippets"
|
||||||
|
"github.com/pomerium/pomerium/internal/testenv/upstreams"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestSSH(t *testing.T) {
|
||||||
|
clientKey := newSSHKey(t)
|
||||||
|
serverHostKey := newSSHKey(t)
|
||||||
|
|
||||||
|
// ssh client setup
|
||||||
|
var ki scenarios.EmptyKeyboardInteractiveChallenge
|
||||||
|
clientConfig := &ssh.ClientConfig{
|
||||||
|
User: "demo",
|
||||||
|
Auth: []ssh.AuthMethod{
|
||||||
|
ssh.PublicKeys(clientKey),
|
||||||
|
ssh.KeyboardInteractive(ki.Do),
|
||||||
|
},
|
||||||
|
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
|
||||||
|
}
|
||||||
|
|
||||||
|
// pomerium + upstream setup
|
||||||
|
env := testenv.New(t)
|
||||||
|
|
||||||
|
env.Add(scenarios.NewIDP([]*scenarios.User{{Email: "test@example.com"}}, scenarios.WithEnableDeviceAuth(true)))
|
||||||
|
env.Add(scenarios.SSH(scenarios.SSHConfig{}))
|
||||||
|
env.Add(&ki)
|
||||||
|
|
||||||
|
up := upstreams.SSH(
|
||||||
|
upstreams.WithHostKeys(serverHostKey),
|
||||||
|
upstreams.WithAuthorizedKey(clientKey.PublicKey(), "demo"))
|
||||||
|
r := up.Route().
|
||||||
|
From(env.SubdomainURLWithScheme("ssh", "ssh")).
|
||||||
|
Policy(func(p *config.Policy) { p.AllowAnyAuthenticatedUser = true })
|
||||||
|
env.AddUpstream(up)
|
||||||
|
env.Start()
|
||||||
|
snippets.WaitStartupComplete(env)
|
||||||
|
|
||||||
|
// verify that a connection can be established
|
||||||
|
client, err := up.Dial(r, clientConfig)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer client.Close()
|
||||||
|
|
||||||
|
sess, err := client.NewSession()
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer sess.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
// newSSHKey generates and returns a new Ed25519 ssh key.
|
||||||
|
func newSSHKey(t *testing.T) ssh.Signer {
|
||||||
|
t.Helper()
|
||||||
|
_, priv, err := ed25519.GenerateKey(nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
signer, err := ssh.NewSignerFromKey(priv)
|
||||||
|
require.NoError(t, err)
|
||||||
|
return signer
|
||||||
|
}
|
|
@ -104,6 +104,8 @@ type Environment interface {
|
||||||
SharedSecret() []byte
|
SharedSecret() []byte
|
||||||
CookieSecret() []byte
|
CookieSecret() []byte
|
||||||
|
|
||||||
|
Config() *config.Config
|
||||||
|
|
||||||
// Add adds the given [Modifier] to the environment. All modifiers will be
|
// Add adds the given [Modifier] to the environment. All modifiers will be
|
||||||
// invoked upon calling Start() to apply individual modifications to the
|
// invoked upon calling Start() to apply individual modifications to the
|
||||||
// configuration before starting the Pomerium server.
|
// configuration before starting the Pomerium server.
|
||||||
|
@ -141,6 +143,13 @@ type Environment interface {
|
||||||
// can be used as the 'from' value for routes.
|
// can be used as the 'from' value for routes.
|
||||||
SubdomainURL(subdomain string) values.Value[string]
|
SubdomainURL(subdomain string) values.Value[string]
|
||||||
|
|
||||||
|
// SubdomainURL returns a string [values.Value] which will contain a complete
|
||||||
|
// URL for the given subdomain of the server's domain (given by its serving
|
||||||
|
// certificate), with the given scheme and the random http server port.
|
||||||
|
// This value will only be resolved some time after Start() is called, and
|
||||||
|
// can be used as the 'from' value for routes.
|
||||||
|
SubdomainURLWithScheme(subdomain, scheme string) values.Value[string]
|
||||||
|
|
||||||
// NewLogRecorder returns a new [*LogRecorder] and starts capturing logs for
|
// NewLogRecorder returns a new [*LogRecorder] and starts capturing logs for
|
||||||
// the Pomerium server and Envoy.
|
// the Pomerium server and Envoy.
|
||||||
NewLogRecorder(opts ...LogRecorderOption) *LogRecorder
|
NewLogRecorder(opts ...LogRecorderOption) *LogRecorder
|
||||||
|
@ -510,8 +519,12 @@ func (e *environment) Require() *require.Assertions {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *environment) SubdomainURL(subdomain string) values.Value[string] {
|
func (e *environment) SubdomainURL(subdomain string) values.Value[string] {
|
||||||
|
return e.SubdomainURLWithScheme(subdomain, "https")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *environment) SubdomainURLWithScheme(subdomain, scheme string) values.Value[string] {
|
||||||
return values.Bind(e.ports.ProxyHTTP, func(port int) string {
|
return values.Bind(e.ports.ProxyHTTP, func(port int) string {
|
||||||
return fmt.Sprintf("https://%s.%s:%d", subdomain, e.domain, port)
|
return fmt.Sprintf("%s://%s.%s:%d", scheme, subdomain, e.domain, port)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -536,6 +549,10 @@ func (e *environment) Host() string {
|
||||||
return e.host
|
return e.host
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (e *environment) Config() *config.Config {
|
||||||
|
return e.src.GetConfig()
|
||||||
|
}
|
||||||
|
|
||||||
func (e *environment) CACert() *tls.Certificate {
|
func (e *environment) CACert() *tls.Certificate {
|
||||||
caCert, err := tls.LoadX509KeyPair(
|
caCert, err := tls.LoadX509KeyPair(
|
||||||
filepath.Join(e.tempDir, "certs", "ca.pem"),
|
filepath.Join(e.tempDir, "certs", "ca.pem"),
|
||||||
|
|
|
@ -20,6 +20,8 @@ import (
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"golang.org/x/oauth2"
|
||||||
|
|
||||||
"github.com/go-jose/go-jose/v3"
|
"github.com/go-jose/go-jose/v3"
|
||||||
"github.com/go-jose/go-jose/v3/jwt"
|
"github.com/go-jose/go-jose/v3/jwt"
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
|
@ -46,6 +48,7 @@ type IDP struct {
|
||||||
|
|
||||||
type IDPOptions struct {
|
type IDPOptions struct {
|
||||||
enableTLS bool
|
enableTLS bool
|
||||||
|
enableDeviceAuth bool
|
||||||
}
|
}
|
||||||
|
|
||||||
type IDPOption func(*IDPOptions)
|
type IDPOption func(*IDPOptions)
|
||||||
|
@ -62,6 +65,12 @@ func WithEnableTLS(enableTLS bool) IDPOption {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func WithEnableDeviceAuth(enableDeviceAuth bool) IDPOption {
|
||||||
|
return func(o *IDPOptions) {
|
||||||
|
o.enableDeviceAuth = enableDeviceAuth
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Attach implements testenv.Modifier.
|
// Attach implements testenv.Modifier.
|
||||||
func (idp *IDP) Attach(ctx context.Context) {
|
func (idp *IDP) Attach(ctx context.Context) {
|
||||||
env := testenv.EnvFromContext(ctx)
|
env := testenv.EnvFromContext(ctx)
|
||||||
|
@ -120,7 +129,7 @@ func (idp *IDP) Attach(ctx context.Context) {
|
||||||
router.Handle("/.well-known/openid-configuration", func(w http.ResponseWriter, r *http.Request) {
|
router.Handle("/.well-known/openid-configuration", func(w http.ResponseWriter, r *http.Request) {
|
||||||
log.Ctx(ctx).Debug().Str("method", r.Method).Str("uri", r.RequestURI).Send()
|
log.Ctx(ctx).Debug().Str("method", r.Method).Str("uri", r.RequestURI).Send()
|
||||||
rootURL, _ := url.Parse(idp.url.Value())
|
rootURL, _ := url.Parse(idp.url.Value())
|
||||||
_ = json.NewEncoder(w).Encode(map[string]interface{}{
|
config := map[string]interface{}{
|
||||||
"issuer": rootURL.String(),
|
"issuer": rootURL.String(),
|
||||||
"authorization_endpoint": rootURL.ResolveReference(&url.URL{Path: "/oidc/auth"}).String(),
|
"authorization_endpoint": rootURL.ResolveReference(&url.URL{Path: "/oidc/auth"}).String(),
|
||||||
"token_endpoint": rootURL.ResolveReference(&url.URL{Path: "/oidc/token"}).String(),
|
"token_endpoint": rootURL.ResolveReference(&url.URL{Path: "/oidc/token"}).String(),
|
||||||
|
@ -129,11 +138,19 @@ func (idp *IDP) Attach(ctx context.Context) {
|
||||||
"id_token_signing_alg_values_supported": []string{
|
"id_token_signing_alg_values_supported": []string{
|
||||||
"ES256",
|
"ES256",
|
||||||
},
|
},
|
||||||
})
|
}
|
||||||
|
if idp.enableDeviceAuth {
|
||||||
|
config["device_authorization_endpoint"] =
|
||||||
|
rootURL.ResolveReference(&url.URL{Path: "/oidc/device/code"}).String()
|
||||||
|
}
|
||||||
|
serveJSON(w, config)
|
||||||
})
|
})
|
||||||
router.Handle("/oidc/auth", idp.HandleAuth)
|
router.Handle("/oidc/auth", idp.HandleAuth)
|
||||||
router.Handle("/oidc/token", idp.HandleToken)
|
router.Handle("/oidc/token", idp.HandleToken)
|
||||||
router.Handle("/oidc/userinfo", idp.HandleUserInfo)
|
router.Handle("/oidc/userinfo", idp.HandleUserInfo)
|
||||||
|
if idp.enableDeviceAuth {
|
||||||
|
router.Handle("/oidc/device/code", idp.HandleDeviceCode)
|
||||||
|
}
|
||||||
|
|
||||||
env.AddUpstream(router)
|
env.AddUpstream(router)
|
||||||
}
|
}
|
||||||
|
@ -258,14 +275,25 @@ func (idp *IDP) HandleAuth(w http.ResponseWriter, r *http.Request) {
|
||||||
|
|
||||||
// HandleToken handles the token flow for OIDC.
|
// HandleToken handles the token flow for OIDC.
|
||||||
func (idp *IDP) HandleToken(w http.ResponseWriter, r *http.Request) {
|
func (idp *IDP) HandleToken(w http.ResponseWriter, r *http.Request) {
|
||||||
rawCode := r.FormValue("code")
|
if idp.enableDeviceAuth && r.FormValue("device_code") != "" {
|
||||||
|
idp.serveToken(w, r, &State{
|
||||||
|
ClientID: r.FormValue("client_id"),
|
||||||
|
Email: "fake.user@example.com",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
rawCode := r.FormValue("code")
|
||||||
state, err := DecodeState(rawCode)
|
state, err := DecodeState(rawCode)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
idp.serveToken(w, r, state)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (idp *IDP) serveToken(w http.ResponseWriter, r *http.Request, state *State) {
|
||||||
serveJSON(w, map[string]interface{}{
|
serveJSON(w, map[string]interface{}{
|
||||||
"access_token": state.Encode(),
|
"access_token": state.Encode(),
|
||||||
"refresh_token": state.Encode(),
|
"refresh_token": state.Encode(),
|
||||||
|
@ -300,6 +328,30 @@ func (idp *IDP) HandleUserInfo(w http.ResponseWriter, r *http.Request) {
|
||||||
serveJSON(w, state.GetUserInfo(idp.userLookup))
|
serveJSON(w, state.GetUserInfo(idp.userLookup))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// HandleDeviceCode initiates a device auth code flow.
|
||||||
|
//
|
||||||
|
// This is the bare minimum to simulate the device auth code flow. There is no client_id
|
||||||
|
// verification or any actual login.
|
||||||
|
func (idp *IDP) HandleDeviceCode(w http.ResponseWriter, r *http.Request) {
|
||||||
|
deviceCode := "GmRhmhcxhwAzkoEqiMEg_DnyEysNkuNhszIySk9eS"
|
||||||
|
userCode := "ABCD-EFGH"
|
||||||
|
|
||||||
|
rootURL, _ := url.Parse(idp.url.Value())
|
||||||
|
u := rootURL.ResolveReference(&url.URL{Path: "/oidc/device"}) // note: not actually implemented
|
||||||
|
verificationURI := u.String()
|
||||||
|
u.RawQuery = "user_code=" + userCode
|
||||||
|
verificationURIComplete := u.String()
|
||||||
|
|
||||||
|
serveJSON(w, &oauth2.DeviceAuthResponse{
|
||||||
|
DeviceCode: deviceCode,
|
||||||
|
UserCode: userCode,
|
||||||
|
VerificationURI: verificationURI,
|
||||||
|
VerificationURIComplete: verificationURIComplete,
|
||||||
|
Expiry: time.Now().Add(5 * time.Minute),
|
||||||
|
Interval: 1,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
type RootURLKey struct{}
|
type RootURLKey struct{}
|
||||||
|
|
||||||
var rootURLKey RootURLKey
|
var rootURLKey RootURLKey
|
||||||
|
|
114
internal/testenv/scenarios/ssh.go
Normal file
114
internal/testenv/scenarios/ssh.go
Normal file
|
@ -0,0 +1,114 @@
|
||||||
|
package scenarios
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/ed25519"
|
||||||
|
"encoding/pem"
|
||||||
|
"fmt"
|
||||||
|
"math/rand/v2"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
|
||||||
|
"golang.org/x/crypto/ssh"
|
||||||
|
|
||||||
|
"github.com/pomerium/pomerium/config"
|
||||||
|
"github.com/pomerium/pomerium/internal/testenv"
|
||||||
|
"github.com/pomerium/pomerium/pkg/slices"
|
||||||
|
)
|
||||||
|
|
||||||
|
type SSHConfig struct {
|
||||||
|
// SSH listener address. Defaults to ":2200" if not set.
|
||||||
|
Addr string
|
||||||
|
|
||||||
|
Hostname string
|
||||||
|
|
||||||
|
// Host key(s). An Ed25519 key will be generated if not set.
|
||||||
|
// Elements must be of a type supported by [ssh.NewSignerFromKey].
|
||||||
|
HostKeys []any
|
||||||
|
|
||||||
|
// User CA key, for signing SSH certificates used to authenticate to an
|
||||||
|
// upstream. An Ed25519 key will be generated if not set.
|
||||||
|
// Must be a type supported by [ssh.NewSignerFromKey].
|
||||||
|
UserCAKey any
|
||||||
|
}
|
||||||
|
|
||||||
|
func SSH(c SSHConfig) testenv.Modifier {
|
||||||
|
return testenv.ModifierFunc(func(ctx context.Context, cfg *config.Config) {
|
||||||
|
env := testenv.EnvFromContext(ctx)
|
||||||
|
|
||||||
|
// Apply defaults.
|
||||||
|
if c.Addr == "" {
|
||||||
|
c.Addr = ":2200"
|
||||||
|
}
|
||||||
|
if len(c.HostKeys) == 0 {
|
||||||
|
c.HostKeys = []any{newEd25519Key(env)}
|
||||||
|
}
|
||||||
|
if c.Hostname == "" {
|
||||||
|
// XXX: is there a reasonable default for this?
|
||||||
|
}
|
||||||
|
if c.UserCAKey == nil {
|
||||||
|
c.UserCAKey = newEd25519Key(env)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update configuration.
|
||||||
|
cfg.Options.SSHAddr = c.Addr
|
||||||
|
cfg.Options.SSHHostname = c.Hostname
|
||||||
|
cfg.Options.SSHHostKeys = slices.Map(c.HostKeys, func(key any) config.SSHKeyPair {
|
||||||
|
return writeSSHKeyPair(env, key)
|
||||||
|
})
|
||||||
|
cfg.Options.SSHUserCAKey = writeSSHKeyPair(env, c.UserCAKey)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func newEd25519Key(env testenv.Environment) ed25519.PrivateKey {
|
||||||
|
_, priv, err := ed25519.GenerateKey(nil)
|
||||||
|
env.Require().NoError(err)
|
||||||
|
return priv
|
||||||
|
}
|
||||||
|
|
||||||
|
// writeSSHKeyPair takes a private key and writes SSH private and public key
|
||||||
|
// files to the test env temp directory, returning a [config.SSHKeyPair] with
|
||||||
|
// the written filenames. The key must be of a type supported by the
|
||||||
|
// [ssh.NewSignerFromKey] method.
|
||||||
|
func writeSSHKeyPair(env testenv.Environment, key any) config.SSHKeyPair {
|
||||||
|
signer, err := ssh.NewSignerFromKey(key)
|
||||||
|
pub := signer.PublicKey()
|
||||||
|
env.Require().NoError(err)
|
||||||
|
|
||||||
|
dir := env.TempDir()
|
||||||
|
basename := fmt.Sprintf("ssh-key-%d", rand.Int())
|
||||||
|
privname := filepath.Join(dir, basename)
|
||||||
|
pubname := privname + ".pub"
|
||||||
|
|
||||||
|
// marshal and write private key to disk
|
||||||
|
pemBlock, err := ssh.MarshalPrivateKey(key, "")
|
||||||
|
env.Require().NoError(err)
|
||||||
|
privkeyContents := pem.EncodeToMemory(pemBlock)
|
||||||
|
err = os.WriteFile(privname, privkeyContents, 0o600)
|
||||||
|
env.Require().NoError(err)
|
||||||
|
|
||||||
|
// marshal and write public key to disk
|
||||||
|
pubkeyContents := ssh.MarshalAuthorizedKey(pub)
|
||||||
|
err = os.WriteFile(pubname, pubkeyContents, 0o600)
|
||||||
|
env.Require().NoError(err)
|
||||||
|
|
||||||
|
return config.SSHKeyPair{
|
||||||
|
PublicKeyFile: pubname,
|
||||||
|
PrivateKeyFile: privname,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// EmptyKeyboardInteractiveChallenge responds to any keyboard-interactive
|
||||||
|
// challenges with zero prompts, and fails otherwise.
|
||||||
|
type EmptyKeyboardInteractiveChallenge struct {
|
||||||
|
testenv.DefaultAttach
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *EmptyKeyboardInteractiveChallenge) Do(
|
||||||
|
name, instruction string, questions []string, echos []bool,
|
||||||
|
) (answers []string, err error) {
|
||||||
|
if len(questions) > 0 {
|
||||||
|
c.Env().Require().FailNow("unsupported keyboard-interactive challenge")
|
||||||
|
}
|
||||||
|
return nil, nil
|
||||||
|
}
|
|
@ -69,6 +69,8 @@ func (d *DefaultAttach) RecordCaller() {
|
||||||
d.caller = getCaller(4)
|
d.caller = getCaller(4)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (d *DefaultAttach) Modify(*config.Config) {}
|
||||||
|
|
||||||
// Aggregate should be embedded in types implementing [Modifier] when the type
|
// Aggregate should be embedded in types implementing [Modifier] when the type
|
||||||
// contains other modifiers. Used as an alternative to [DefaultAttach].
|
// contains other modifiers. Used as an alternative to [DefaultAttach].
|
||||||
// Embedding this struct will properly keep track of when constituent modifiers
|
// Embedding this struct will properly keep track of when constituent modifiers
|
||||||
|
|
211
internal/testenv/upstreams/ssh.go
Normal file
211
internal/testenv/upstreams/ssh.go
Normal file
|
@ -0,0 +1,211 @@
|
||||||
|
package upstreams
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"log"
|
||||||
|
"net"
|
||||||
|
|
||||||
|
"github.com/pomerium/pomerium/internal/testenv"
|
||||||
|
"github.com/pomerium/pomerium/internal/testenv/values"
|
||||||
|
"golang.org/x/crypto/ssh"
|
||||||
|
)
|
||||||
|
|
||||||
|
type SSHUpstreamOptions struct {
|
||||||
|
displayName string
|
||||||
|
serverConfig ssh.ServerConfig
|
||||||
|
authorizedKeys authorizedKeysChecker
|
||||||
|
}
|
||||||
|
|
||||||
|
type SSHUpstreamOption interface {
|
||||||
|
applySSH(*SSHUpstreamOptions)
|
||||||
|
}
|
||||||
|
|
||||||
|
type sshUpstreamOption func(o *SSHUpstreamOptions)
|
||||||
|
|
||||||
|
func (s sshUpstreamOption) applySSH(o *SSHUpstreamOptions) { s(o) }
|
||||||
|
|
||||||
|
func WithPublicKeyAuthAlgorithms(algs ...string) SSHUpstreamOption {
|
||||||
|
return sshUpstreamOption(func(o *SSHUpstreamOptions) {
|
||||||
|
o.serverConfig.PublicKeyAuthAlgorithms = algs
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithHostKeys(keys ...ssh.Signer) SSHUpstreamOption {
|
||||||
|
return sshUpstreamOption(func(o *SSHUpstreamOptions) {
|
||||||
|
for _, key := range keys {
|
||||||
|
o.serverConfig.AddHostKey(key)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithBannerCallback(c func(ssh.ConnMetadata) string) SSHUpstreamOption {
|
||||||
|
return sshUpstreamOption(func(o *SSHUpstreamOptions) {
|
||||||
|
o.serverConfig.BannerCallback = c
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithPublicKeyCallback sets a custom callback for the publickey authentication method.
|
||||||
|
// This will override any previous [WithAuthorizedKey] option.
|
||||||
|
func WithPublicKeyCallback(c func(ssh.ConnMetadata, ssh.PublicKey) (*ssh.Permissions, error)) SSHUpstreamOption {
|
||||||
|
return sshUpstreamOption(func(o *SSHUpstreamOptions) {
|
||||||
|
o.serverConfig.PublicKeyCallback = c
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithAuthorizedKey allows the given key to be used to authenticate the given username,
|
||||||
|
// enabling the publickey authentication method. This will override any previous
|
||||||
|
// [WithPublicKeyCallback] option.
|
||||||
|
func WithAuthorizedKey(key ssh.PublicKey, username string) SSHUpstreamOption {
|
||||||
|
return sshUpstreamOption(func(o *SSHUpstreamOptions) {
|
||||||
|
o.authorizedKeys.add(key, username)
|
||||||
|
o.serverConfig.PublicKeyCallback = o.authorizedKeys.check
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
type authorizedKeysChecker map[string]string // map from marshaled public key to corresponding username
|
||||||
|
|
||||||
|
func (c *authorizedKeysChecker) add(key ssh.PublicKey, username string) {
|
||||||
|
if *c == nil {
|
||||||
|
*c = make(map[string]string)
|
||||||
|
}
|
||||||
|
(*c)[string(key.Marshal())] = username
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c authorizedKeysChecker) check(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) {
|
||||||
|
if c[string(key.Marshal())] == conn.User() {
|
||||||
|
return &ssh.Permissions{}, nil
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("not authorized")
|
||||||
|
}
|
||||||
|
|
||||||
|
type ServerConnCallback func(*ssh.ServerConn, <-chan ssh.NewChannel, <-chan *ssh.Request)
|
||||||
|
|
||||||
|
var closeConnCallback ServerConnCallback = func(conn *ssh.ServerConn, ncc <-chan ssh.NewChannel, rq <-chan *ssh.Request) {
|
||||||
|
conn.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
// SSHUpstream represents an ssh server which can be used as the target for
|
||||||
|
// one or more Pomerium routes in a test environment.
|
||||||
|
//
|
||||||
|
// Use SetServerConnCallback() to define the behavior of this server once
|
||||||
|
// a connection is established.
|
||||||
|
//
|
||||||
|
// Dial() can be used to make a client-side connection through the Pomerium
|
||||||
|
// route, while DirectDial() can be used to connect bypassing Pomerium.
|
||||||
|
type SSHUpstream interface {
|
||||||
|
testenv.Upstream
|
||||||
|
|
||||||
|
SetServerConnCallback(callback ServerConnCallback)
|
||||||
|
|
||||||
|
Dial(r testenv.Route, config *ssh.ClientConfig) (*ssh.Client, error)
|
||||||
|
DirectDial(r testenv.Route, config *ssh.ClientConfig) (*ssh.Client, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
type sshUpstream struct {
|
||||||
|
SSHUpstreamOptions
|
||||||
|
testenv.Aggregate
|
||||||
|
serverPort values.MutableValue[int]
|
||||||
|
|
||||||
|
// XXX: does it make sense to cache clients?
|
||||||
|
//clientCache sync.Map // map[testenv.Route]*ssh.Client
|
||||||
|
|
||||||
|
serverConnCallback ServerConnCallback
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
_ testenv.Upstream = (*sshUpstream)(nil)
|
||||||
|
_ SSHUpstream = (*sshUpstream)(nil)
|
||||||
|
)
|
||||||
|
|
||||||
|
// SSH creates a new ssh upstream server.
|
||||||
|
func SSH(opts ...SSHUpstreamOption) SSHUpstream {
|
||||||
|
options := SSHUpstreamOptions{
|
||||||
|
displayName: "SSH Upstream",
|
||||||
|
}
|
||||||
|
for _, op := range opts {
|
||||||
|
op.applySSH(&options)
|
||||||
|
}
|
||||||
|
up := &sshUpstream{
|
||||||
|
SSHUpstreamOptions: options,
|
||||||
|
serverPort: values.Deferred[int](),
|
||||||
|
serverConnCallback: closeConnCallback, // default handler, to avoid hanging connections
|
||||||
|
}
|
||||||
|
up.RecordCaller()
|
||||||
|
return up
|
||||||
|
}
|
||||||
|
|
||||||
|
// Addr implements SSHUpstream.
|
||||||
|
func (h *sshUpstream) Addr() values.Value[string] {
|
||||||
|
return values.Bind(h.serverPort, func(port int) string {
|
||||||
|
return fmt.Sprintf("%s:%d", h.Env().Host(), port)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Router implements SSHUpstream.
|
||||||
|
func (h *sshUpstream) SetServerConnCallback(callback ServerConnCallback) {
|
||||||
|
h.serverConnCallback = callback
|
||||||
|
}
|
||||||
|
|
||||||
|
// Route implements SSHUpstream.
|
||||||
|
func (h *sshUpstream) Route() testenv.RouteStub {
|
||||||
|
r := &testenv.PolicyRoute{}
|
||||||
|
protocol := "ssh"
|
||||||
|
r.To(values.Bind(h.serverPort, func(port int) string {
|
||||||
|
return fmt.Sprintf("%s://%s:%d", protocol, h.Env().Host(), port)
|
||||||
|
}))
|
||||||
|
h.Add(r)
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
// Run implements SSHUpstream.
|
||||||
|
func (h *sshUpstream) Run(ctx context.Context) error {
|
||||||
|
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
h.serverPort.Resolve(listener.Addr().(*net.TCPAddr).Port)
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
<-ctx.Done()
|
||||||
|
listener.Close()
|
||||||
|
}()
|
||||||
|
|
||||||
|
for {
|
||||||
|
conn, err := listener.Accept()
|
||||||
|
if err != nil {
|
||||||
|
// The testenv cleanup expects this function to return a "test cleanup" error,
|
||||||
|
// which propagates via the context.
|
||||||
|
if ctx.Err() != nil {
|
||||||
|
return context.Cause(ctx)
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
go h.handleConnection(ctx, conn)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *sshUpstream) handleConnection(ctx context.Context, conn net.Conn) {
|
||||||
|
serverConn, ncc, rc, err := ssh.NewServerConn(conn, &h.serverConfig)
|
||||||
|
if err != nil {
|
||||||
|
// XXX: figure out the right way to log this
|
||||||
|
log.Println("ssh connection handshake unsuccessful:", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
go func() {
|
||||||
|
<-ctx.Done()
|
||||||
|
conn.Close()
|
||||||
|
}()
|
||||||
|
h.serverConnCallback(serverConn, ncc, rc)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Dial implements SSHUpstream.
|
||||||
|
func (h *sshUpstream) Dial(r testenv.Route, config *ssh.ClientConfig) (*ssh.Client, error) {
|
||||||
|
return ssh.Dial("tcp", h.Env().Config().Options.SSHAddr, config)
|
||||||
|
}
|
||||||
|
|
||||||
|
// DirectDial implements SSHUpstream.
|
||||||
|
func (h *sshUpstream) DirectDial(r testenv.Route, config *ssh.ClientConfig) (*ssh.Client, error) {
|
||||||
|
addr := fmt.Sprintf("127.0.0.1:%d", h.serverPort.Value())
|
||||||
|
return ssh.Dial("tcp", addr, config)
|
||||||
|
}
|
Loading…
Add table
Reference in a new issue