diff --git a/integration2/ssh_int_test.go b/integration2/ssh_int_test.go new file mode 100644 index 000000000..ff9acc1ec --- /dev/null +++ b/integration2/ssh_int_test.go @@ -0,0 +1,67 @@ +package ssh + +import ( + "crypto/ed25519" + "testing" + + "github.com/stretchr/testify/require" + "golang.org/x/crypto/ssh" + + "github.com/pomerium/pomerium/config" + "github.com/pomerium/pomerium/internal/testenv" + "github.com/pomerium/pomerium/internal/testenv/scenarios" + "github.com/pomerium/pomerium/internal/testenv/snippets" + "github.com/pomerium/pomerium/internal/testenv/upstreams" +) + +func TestSSH(t *testing.T) { + clientKey := newSSHKey(t) + serverHostKey := newSSHKey(t) + + // ssh client setup + var ki scenarios.EmptyKeyboardInteractiveChallenge + clientConfig := &ssh.ClientConfig{ + User: "demo", + Auth: []ssh.AuthMethod{ + ssh.PublicKeys(clientKey), + ssh.KeyboardInteractive(ki.Do), + }, + HostKeyCallback: ssh.InsecureIgnoreHostKey(), + } + + // pomerium + upstream setup + env := testenv.New(t) + + env.Add(scenarios.NewIDP([]*scenarios.User{{Email: "test@example.com"}}, scenarios.WithEnableDeviceAuth(true))) + env.Add(scenarios.SSH(scenarios.SSHConfig{})) + env.Add(&ki) + + up := upstreams.SSH( + upstreams.WithHostKeys(serverHostKey), + upstreams.WithAuthorizedKey(clientKey.PublicKey(), "demo")) + r := up.Route(). + From(env.SubdomainURLWithScheme("ssh", "ssh")). + Policy(func(p *config.Policy) { p.AllowAnyAuthenticatedUser = true }) + env.AddUpstream(up) + env.Start() + snippets.WaitStartupComplete(env) + + // verify that a connection can be established + client, err := up.Dial(r, clientConfig) + require.NoError(t, err) + defer client.Close() + + sess, err := client.NewSession() + require.NoError(t, err) + defer sess.Close() +} + +// newSSHKey generates and returns a new Ed25519 ssh key. +func newSSHKey(t *testing.T) ssh.Signer { + t.Helper() + _, priv, err := ed25519.GenerateKey(nil) + require.NoError(t, err) + signer, err := ssh.NewSignerFromKey(priv) + require.NoError(t, err) + return signer +} diff --git a/internal/testenv/environment.go b/internal/testenv/environment.go index dce233b2a..dcf47b552 100644 --- a/internal/testenv/environment.go +++ b/internal/testenv/environment.go @@ -104,6 +104,8 @@ type Environment interface { SharedSecret() []byte CookieSecret() []byte + Config() *config.Config + // Add adds the given [Modifier] to the environment. All modifiers will be // invoked upon calling Start() to apply individual modifications to the // configuration before starting the Pomerium server. @@ -141,6 +143,13 @@ type Environment interface { // can be used as the 'from' value for routes. SubdomainURL(subdomain string) values.Value[string] + // SubdomainURL returns a string [values.Value] which will contain a complete + // URL for the given subdomain of the server's domain (given by its serving + // certificate), with the given scheme and the random http server port. + // This value will only be resolved some time after Start() is called, and + // can be used as the 'from' value for routes. + SubdomainURLWithScheme(subdomain, scheme string) values.Value[string] + // NewLogRecorder returns a new [*LogRecorder] and starts capturing logs for // the Pomerium server and Envoy. NewLogRecorder(opts ...LogRecorderOption) *LogRecorder @@ -510,8 +519,12 @@ func (e *environment) Require() *require.Assertions { } func (e *environment) SubdomainURL(subdomain string) values.Value[string] { + return e.SubdomainURLWithScheme(subdomain, "https") +} + +func (e *environment) SubdomainURLWithScheme(subdomain, scheme string) values.Value[string] { return values.Bind(e.ports.ProxyHTTP, func(port int) string { - return fmt.Sprintf("https://%s.%s:%d", subdomain, e.domain, port) + return fmt.Sprintf("%s://%s.%s:%d", scheme, subdomain, e.domain, port) }) } @@ -536,6 +549,10 @@ func (e *environment) Host() string { return e.host } +func (e *environment) Config() *config.Config { + return e.src.GetConfig() +} + func (e *environment) CACert() *tls.Certificate { caCert, err := tls.LoadX509KeyPair( filepath.Join(e.tempDir, "certs", "ca.pem"), diff --git a/internal/testenv/scenarios/mock_idp.go b/internal/testenv/scenarios/mock_idp.go index 9e07b70e1..b3dce31ed 100644 --- a/internal/testenv/scenarios/mock_idp.go +++ b/internal/testenv/scenarios/mock_idp.go @@ -20,6 +20,8 @@ import ( "strings" "time" + "golang.org/x/oauth2" + "github.com/go-jose/go-jose/v3" "github.com/go-jose/go-jose/v3/jwt" "github.com/google/uuid" @@ -45,7 +47,8 @@ type IDP struct { } type IDPOptions struct { - enableTLS bool + enableTLS bool + enableDeviceAuth bool } type IDPOption func(*IDPOptions) @@ -62,6 +65,12 @@ func WithEnableTLS(enableTLS bool) IDPOption { } } +func WithEnableDeviceAuth(enableDeviceAuth bool) IDPOption { + return func(o *IDPOptions) { + o.enableDeviceAuth = enableDeviceAuth + } +} + // Attach implements testenv.Modifier. func (idp *IDP) Attach(ctx context.Context) { env := testenv.EnvFromContext(ctx) @@ -120,7 +129,7 @@ func (idp *IDP) Attach(ctx context.Context) { router.Handle("/.well-known/openid-configuration", func(w http.ResponseWriter, r *http.Request) { log.Ctx(ctx).Debug().Str("method", r.Method).Str("uri", r.RequestURI).Send() rootURL, _ := url.Parse(idp.url.Value()) - _ = json.NewEncoder(w).Encode(map[string]interface{}{ + config := map[string]interface{}{ "issuer": rootURL.String(), "authorization_endpoint": rootURL.ResolveReference(&url.URL{Path: "/oidc/auth"}).String(), "token_endpoint": rootURL.ResolveReference(&url.URL{Path: "/oidc/token"}).String(), @@ -129,11 +138,19 @@ func (idp *IDP) Attach(ctx context.Context) { "id_token_signing_alg_values_supported": []string{ "ES256", }, - }) + } + if idp.enableDeviceAuth { + config["device_authorization_endpoint"] = + rootURL.ResolveReference(&url.URL{Path: "/oidc/device/code"}).String() + } + serveJSON(w, config) }) router.Handle("/oidc/auth", idp.HandleAuth) router.Handle("/oidc/token", idp.HandleToken) router.Handle("/oidc/userinfo", idp.HandleUserInfo) + if idp.enableDeviceAuth { + router.Handle("/oidc/device/code", idp.HandleDeviceCode) + } env.AddUpstream(router) } @@ -258,14 +275,25 @@ func (idp *IDP) HandleAuth(w http.ResponseWriter, r *http.Request) { // HandleToken handles the token flow for OIDC. func (idp *IDP) HandleToken(w http.ResponseWriter, r *http.Request) { - rawCode := r.FormValue("code") + if idp.enableDeviceAuth && r.FormValue("device_code") != "" { + idp.serveToken(w, r, &State{ + ClientID: r.FormValue("client_id"), + Email: "fake.user@example.com", + }) + return + } + rawCode := r.FormValue("code") state, err := DecodeState(rawCode) if err != nil { http.Error(w, err.Error(), http.StatusBadRequest) return } + idp.serveToken(w, r, state) +} + +func (idp *IDP) serveToken(w http.ResponseWriter, r *http.Request, state *State) { serveJSON(w, map[string]interface{}{ "access_token": state.Encode(), "refresh_token": state.Encode(), @@ -300,6 +328,30 @@ func (idp *IDP) HandleUserInfo(w http.ResponseWriter, r *http.Request) { serveJSON(w, state.GetUserInfo(idp.userLookup)) } +// HandleDeviceCode initiates a device auth code flow. +// +// This is the bare minimum to simulate the device auth code flow. There is no client_id +// verification or any actual login. +func (idp *IDP) HandleDeviceCode(w http.ResponseWriter, r *http.Request) { + deviceCode := "GmRhmhcxhwAzkoEqiMEg_DnyEysNkuNhszIySk9eS" + userCode := "ABCD-EFGH" + + rootURL, _ := url.Parse(idp.url.Value()) + u := rootURL.ResolveReference(&url.URL{Path: "/oidc/device"}) // note: not actually implemented + verificationURI := u.String() + u.RawQuery = "user_code=" + userCode + verificationURIComplete := u.String() + + serveJSON(w, &oauth2.DeviceAuthResponse{ + DeviceCode: deviceCode, + UserCode: userCode, + VerificationURI: verificationURI, + VerificationURIComplete: verificationURIComplete, + Expiry: time.Now().Add(5 * time.Minute), + Interval: 1, + }) +} + type RootURLKey struct{} var rootURLKey RootURLKey diff --git a/internal/testenv/scenarios/ssh.go b/internal/testenv/scenarios/ssh.go new file mode 100644 index 000000000..d31901459 --- /dev/null +++ b/internal/testenv/scenarios/ssh.go @@ -0,0 +1,114 @@ +package scenarios + +import ( + "context" + "crypto/ed25519" + "encoding/pem" + "fmt" + "math/rand/v2" + "os" + "path/filepath" + + "golang.org/x/crypto/ssh" + + "github.com/pomerium/pomerium/config" + "github.com/pomerium/pomerium/internal/testenv" + "github.com/pomerium/pomerium/pkg/slices" +) + +type SSHConfig struct { + // SSH listener address. Defaults to ":2200" if not set. + Addr string + + Hostname string + + // Host key(s). An Ed25519 key will be generated if not set. + // Elements must be of a type supported by [ssh.NewSignerFromKey]. + HostKeys []any + + // User CA key, for signing SSH certificates used to authenticate to an + // upstream. An Ed25519 key will be generated if not set. + // Must be a type supported by [ssh.NewSignerFromKey]. + UserCAKey any +} + +func SSH(c SSHConfig) testenv.Modifier { + return testenv.ModifierFunc(func(ctx context.Context, cfg *config.Config) { + env := testenv.EnvFromContext(ctx) + + // Apply defaults. + if c.Addr == "" { + c.Addr = ":2200" + } + if len(c.HostKeys) == 0 { + c.HostKeys = []any{newEd25519Key(env)} + } + if c.Hostname == "" { + // XXX: is there a reasonable default for this? + } + if c.UserCAKey == nil { + c.UserCAKey = newEd25519Key(env) + } + + // Update configuration. + cfg.Options.SSHAddr = c.Addr + cfg.Options.SSHHostname = c.Hostname + cfg.Options.SSHHostKeys = slices.Map(c.HostKeys, func(key any) config.SSHKeyPair { + return writeSSHKeyPair(env, key) + }) + cfg.Options.SSHUserCAKey = writeSSHKeyPair(env, c.UserCAKey) + }) +} + +func newEd25519Key(env testenv.Environment) ed25519.PrivateKey { + _, priv, err := ed25519.GenerateKey(nil) + env.Require().NoError(err) + return priv +} + +// writeSSHKeyPair takes a private key and writes SSH private and public key +// files to the test env temp directory, returning a [config.SSHKeyPair] with +// the written filenames. The key must be of a type supported by the +// [ssh.NewSignerFromKey] method. +func writeSSHKeyPair(env testenv.Environment, key any) config.SSHKeyPair { + signer, err := ssh.NewSignerFromKey(key) + pub := signer.PublicKey() + env.Require().NoError(err) + + dir := env.TempDir() + basename := fmt.Sprintf("ssh-key-%d", rand.Int()) + privname := filepath.Join(dir, basename) + pubname := privname + ".pub" + + // marshal and write private key to disk + pemBlock, err := ssh.MarshalPrivateKey(key, "") + env.Require().NoError(err) + privkeyContents := pem.EncodeToMemory(pemBlock) + err = os.WriteFile(privname, privkeyContents, 0o600) + env.Require().NoError(err) + + // marshal and write public key to disk + pubkeyContents := ssh.MarshalAuthorizedKey(pub) + err = os.WriteFile(pubname, pubkeyContents, 0o600) + env.Require().NoError(err) + + return config.SSHKeyPair{ + PublicKeyFile: pubname, + PrivateKeyFile: privname, + } +} + +// EmptyKeyboardInteractiveChallenge responds to any keyboard-interactive +// challenges with zero prompts, and fails otherwise. +type EmptyKeyboardInteractiveChallenge struct { + testenv.DefaultAttach +} + +func (c *EmptyKeyboardInteractiveChallenge) Do( + name, instruction string, questions []string, echos []bool, +) (answers []string, err error) { + if len(questions) > 0 { + c.Env().Require().FailNow("unsupported keyboard-interactive challenge") + } + return nil, nil +} diff --git a/internal/testenv/types.go b/internal/testenv/types.go index 959fd9a33..dfbec80ab 100644 --- a/internal/testenv/types.go +++ b/internal/testenv/types.go @@ -69,6 +69,8 @@ func (d *DefaultAttach) RecordCaller() { d.caller = getCaller(4) } +func (d *DefaultAttach) Modify(*config.Config) {} + // Aggregate should be embedded in types implementing [Modifier] when the type // contains other modifiers. Used as an alternative to [DefaultAttach]. // Embedding this struct will properly keep track of when constituent modifiers diff --git a/internal/testenv/upstreams/ssh.go b/internal/testenv/upstreams/ssh.go new file mode 100644 index 000000000..a54ec6165 --- /dev/null +++ b/internal/testenv/upstreams/ssh.go @@ -0,0 +1,211 @@ +package upstreams + +import ( + "context" + "fmt" + "log" + "net" + + "github.com/pomerium/pomerium/internal/testenv" + "github.com/pomerium/pomerium/internal/testenv/values" + "golang.org/x/crypto/ssh" +) + +type SSHUpstreamOptions struct { + displayName string + serverConfig ssh.ServerConfig + authorizedKeys authorizedKeysChecker +} + +type SSHUpstreamOption interface { + applySSH(*SSHUpstreamOptions) +} + +type sshUpstreamOption func(o *SSHUpstreamOptions) + +func (s sshUpstreamOption) applySSH(o *SSHUpstreamOptions) { s(o) } + +func WithPublicKeyAuthAlgorithms(algs ...string) SSHUpstreamOption { + return sshUpstreamOption(func(o *SSHUpstreamOptions) { + o.serverConfig.PublicKeyAuthAlgorithms = algs + }) +} + +func WithHostKeys(keys ...ssh.Signer) SSHUpstreamOption { + return sshUpstreamOption(func(o *SSHUpstreamOptions) { + for _, key := range keys { + o.serverConfig.AddHostKey(key) + } + }) +} + +func WithBannerCallback(c func(ssh.ConnMetadata) string) SSHUpstreamOption { + return sshUpstreamOption(func(o *SSHUpstreamOptions) { + o.serverConfig.BannerCallback = c + }) +} + +// WithPublicKeyCallback sets a custom callback for the publickey authentication method. +// This will override any previous [WithAuthorizedKey] option. +func WithPublicKeyCallback(c func(ssh.ConnMetadata, ssh.PublicKey) (*ssh.Permissions, error)) SSHUpstreamOption { + return sshUpstreamOption(func(o *SSHUpstreamOptions) { + o.serverConfig.PublicKeyCallback = c + }) +} + +// WithAuthorizedKey allows the given key to be used to authenticate the given username, +// enabling the publickey authentication method. This will override any previous +// [WithPublicKeyCallback] option. +func WithAuthorizedKey(key ssh.PublicKey, username string) SSHUpstreamOption { + return sshUpstreamOption(func(o *SSHUpstreamOptions) { + o.authorizedKeys.add(key, username) + o.serverConfig.PublicKeyCallback = o.authorizedKeys.check + }) +} + +type authorizedKeysChecker map[string]string // map from marshaled public key to corresponding username + +func (c *authorizedKeysChecker) add(key ssh.PublicKey, username string) { + if *c == nil { + *c = make(map[string]string) + } + (*c)[string(key.Marshal())] = username +} + +func (c authorizedKeysChecker) check(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) { + if c[string(key.Marshal())] == conn.User() { + return &ssh.Permissions{}, nil + } + return nil, fmt.Errorf("not authorized") +} + +type ServerConnCallback func(*ssh.ServerConn, <-chan ssh.NewChannel, <-chan *ssh.Request) + +var closeConnCallback ServerConnCallback = func(conn *ssh.ServerConn, ncc <-chan ssh.NewChannel, rq <-chan *ssh.Request) { + conn.Close() +} + +// SSHUpstream represents an ssh server which can be used as the target for +// one or more Pomerium routes in a test environment. +// +// Use SetServerConnCallback() to define the behavior of this server once +// a connection is established. +// +// Dial() can be used to make a client-side connection through the Pomerium +// route, while DirectDial() can be used to connect bypassing Pomerium. +type SSHUpstream interface { + testenv.Upstream + + SetServerConnCallback(callback ServerConnCallback) + + Dial(r testenv.Route, config *ssh.ClientConfig) (*ssh.Client, error) + DirectDial(r testenv.Route, config *ssh.ClientConfig) (*ssh.Client, error) +} + +type sshUpstream struct { + SSHUpstreamOptions + testenv.Aggregate + serverPort values.MutableValue[int] + + // XXX: does it make sense to cache clients? + //clientCache sync.Map // map[testenv.Route]*ssh.Client + + serverConnCallback ServerConnCallback +} + +var ( + _ testenv.Upstream = (*sshUpstream)(nil) + _ SSHUpstream = (*sshUpstream)(nil) +) + +// SSH creates a new ssh upstream server. +func SSH(opts ...SSHUpstreamOption) SSHUpstream { + options := SSHUpstreamOptions{ + displayName: "SSH Upstream", + } + for _, op := range opts { + op.applySSH(&options) + } + up := &sshUpstream{ + SSHUpstreamOptions: options, + serverPort: values.Deferred[int](), + serverConnCallback: closeConnCallback, // default handler, to avoid hanging connections + } + up.RecordCaller() + return up +} + +// Addr implements SSHUpstream. +func (h *sshUpstream) Addr() values.Value[string] { + return values.Bind(h.serverPort, func(port int) string { + return fmt.Sprintf("%s:%d", h.Env().Host(), port) + }) +} + +// Router implements SSHUpstream. +func (h *sshUpstream) SetServerConnCallback(callback ServerConnCallback) { + h.serverConnCallback = callback +} + +// Route implements SSHUpstream. +func (h *sshUpstream) Route() testenv.RouteStub { + r := &testenv.PolicyRoute{} + protocol := "ssh" + r.To(values.Bind(h.serverPort, func(port int) string { + return fmt.Sprintf("%s://%s:%d", protocol, h.Env().Host(), port) + })) + h.Add(r) + return r +} + +// Run implements SSHUpstream. +func (h *sshUpstream) Run(ctx context.Context) error { + listener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + return err + } + h.serverPort.Resolve(listener.Addr().(*net.TCPAddr).Port) + + go func() { + <-ctx.Done() + listener.Close() + }() + + for { + conn, err := listener.Accept() + if err != nil { + // The testenv cleanup expects this function to return a "test cleanup" error, + // which propagates via the context. + if ctx.Err() != nil { + return context.Cause(ctx) + } + return err + } + go h.handleConnection(ctx, conn) + } +} + +func (h *sshUpstream) handleConnection(ctx context.Context, conn net.Conn) { + serverConn, ncc, rc, err := ssh.NewServerConn(conn, &h.serverConfig) + if err != nil { + // XXX: figure out the right way to log this + log.Println("ssh connection handshake unsuccessful:", err) + return + } + go func() { + <-ctx.Done() + conn.Close() + }() + h.serverConnCallback(serverConn, ncc, rc) +} + +// Dial implements SSHUpstream. +func (h *sshUpstream) Dial(r testenv.Route, config *ssh.ClientConfig) (*ssh.Client, error) { + return ssh.Dial("tcp", h.Env().Config().Options.SSHAddr, config) +} + +// DirectDial implements SSHUpstream. +func (h *sshUpstream) DirectDial(r testenv.Route, config *ssh.ClientConfig) (*ssh.Client, error) { + addr := fmt.Sprintf("127.0.0.1:%d", h.serverPort.Value()) + return ssh.Dial("tcp", addr, config) +}