diff --git a/authorize/ssh_grpc.go b/authorize/ssh_grpc.go index 083dec464..692c876fb 100644 --- a/authorize/ssh_grpc.go +++ b/authorize/ssh_grpc.go @@ -1,16 +1,89 @@ package authorize import ( + "errors" + "io" + + "golang.org/x/sync/errgroup" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" extensions_ssh "github.com/pomerium/envoy-custom/api/extensions/filters/network/ssh" + "github.com/pomerium/pomerium/pkg/storage" ) -func (a *Authorize) ManageStream(extensions_ssh.StreamManagement_ManageStreamServer) error { - return status.Errorf(codes.Unimplemented, "method ManageStream not implemented") +func (a *Authorize) ManageStream(stream extensions_ssh.StreamManagement_ManageStreamServer) error { + event, err := stream.Recv() + if err != nil { + return err + } + // first message should be a downstream connected event + downstream := event.GetEvent().GetDownstreamConnected() + if downstream == nil { + return status.Errorf(codes.Internal, "first message was not a downstream connected event") + } + handler := a.state.Load().ssh.NewStreamHandler(a.currentConfig.Load(), downstream) + defer handler.Close() + + eg, ctx := errgroup.WithContext(stream.Context()) + querier := storage.NewCachingQuerier( + storage.NewQuerier(a.state.Load().dataBrokerClient), + storage.GlobalCache, + ) + ctx = storage.WithQuerier(ctx, querier) + + eg.Go(func() error { + for { + req, err := stream.Recv() + if err != nil { + if errors.Is(err, io.EOF) { + return nil + } + return err + } + handler.ReadC() <- req + } + }) + + eg.Go(func() error { + for { + select { + case <-ctx.Done(): + return nil + case msg := <-handler.WriteC(): + if err := stream.Send(msg); err != nil { + if errors.Is(err, io.EOF) { + return nil + } + return err + } + } + } + }) + + return handler.Run(ctx) } -func (a *Authorize) ServeChannel(extensions_ssh.StreamManagement_ServeChannelServer) error { - return status.Errorf(codes.Unimplemented, "method ServeChannel not implemented") +func (a *Authorize) ServeChannel(stream extensions_ssh.StreamManagement_ServeChannelServer) error { + metadata, err := stream.Recv() + if err != nil { + return err + } + // first message contains metadata + var streamID uint64 + if md := metadata.GetMetadata(); md != nil { + var typedMd extensions_ssh.FilterMetadata + if err := md.GetTypedFilterMetadata()["com.pomerium.ssh"].UnmarshalTo(&typedMd); err != nil { + return err + } + streamID = typedMd.StreamId + } else { + return status.Errorf(codes.Internal, "first message was not metadata") + } + handler := a.state.Load().ssh.LookupStream(streamID) + if handler == nil || !handler.IsExpectingInternalChannel() { + return status.Errorf(codes.InvalidArgument, "stream not found") + } + + return handler.ServeChannel(stream) } diff --git a/authorize/state.go b/authorize/state.go index 3df2e24a8..bc0a37364 100644 --- a/authorize/state.go +++ b/authorize/state.go @@ -20,6 +20,7 @@ import ( "github.com/pomerium/pomerium/pkg/grpc/session" "github.com/pomerium/pomerium/pkg/grpc/user" "github.com/pomerium/pomerium/pkg/grpcutil" + "github.com/pomerium/pomerium/pkg/ssh" "github.com/pomerium/pomerium/pkg/storage" ) @@ -39,6 +40,7 @@ type authorizeState struct { authenticateFlow authenticateFlow syncQueriers map[string]storage.Querier mcp *mcp.Handler + ssh *ssh.StreamManager } func newAuthorizeStateFromConfig( @@ -70,6 +72,8 @@ func newAuthorizeStateFromConfig( evaluatorOptions = append(evaluatorOptions, evaluator.WithMCPAccessTokenProvider(mcp)) } + state.ssh = ssh.NewStreamManager(nil) // XXX + state.evaluator, err = newPolicyEvaluator(ctx, cfg.Options, store, previousEvaluator, evaluatorOptions...) if err != nil { return nil, fmt.Errorf("authorize: failed to update policy with options: %w", err) diff --git a/config/envoyconfig/listeners_ssh.go b/config/envoyconfig/listeners_ssh.go index b45fd18e4..6a7b7594d 100644 --- a/config/envoyconfig/listeners_ssh.go +++ b/config/envoyconfig/listeners_ssh.go @@ -194,5 +194,5 @@ func buildRouteConfig(cfg *config.Config) (*envoy_generic_proxy_v3.RouteConfigur } func shouldStartSSHListener(options *config.Options) bool { - return config.IsAuthorize(options.Services) + return config.IsProxy(options.Services) } diff --git a/config/policy.go b/config/policy.go index 847c3de34..a820ec853 100644 --- a/config/policy.go +++ b/config/policy.go @@ -932,7 +932,7 @@ func (p *Policy) IsUDPUpstream() bool { // IsSSH returns true if the route is for SSH. func (p *Policy) IsSSH() bool { - return len(p.From) > 0 && strings.HasPrefix(p.From, "ssh://") + return strings.HasPrefix(p.From, "ssh://") } // AllAllowedDomains returns all the allowed domains. diff --git a/config/runtime_flags.go b/config/runtime_flags.go index 930b44a4d..aaaa1721f 100644 --- a/config/runtime_flags.go +++ b/config/runtime_flags.go @@ -32,6 +32,8 @@ var ( // RuntimeFlagMCP enables the MCP services for the authorize service RuntimeFlagMCP = runtimeFlag("mcp", false) + + RuntimeFlagSSHRoutesPortal = runtimeFlag("ssh_routes_portal", false) ) // RuntimeFlag is a runtime flag that can flip on/off certain features diff --git a/go.mod b/go.mod index 02ea5c007..d27b17710 100644 --- a/go.mod +++ b/go.mod @@ -15,6 +15,10 @@ require ( github.com/bufbuild/protovalidate-go v0.10.1 github.com/caddyserver/certmagic v0.23.0 github.com/cenkalti/backoff/v4 v4.3.0 + github.com/charmbracelet/bubbles v0.21.0 + github.com/charmbracelet/bubbletea v1.3.4 + github.com/charmbracelet/lipgloss v1.1.0 + github.com/charmbracelet/x/ansi v0.8.0 github.com/cloudflare/circl v1.6.1 github.com/cncf/xds/go v0.0.0-20250326154945-ae57f3c0d45f github.com/cockroachdb/pebble/v2 v2.0.4 @@ -55,7 +59,7 @@ require ( github.com/pires/go-proxyproto v0.8.1 github.com/pomerium/csrf v1.7.0 github.com/pomerium/datasource v0.18.2-0.20221108160055-c6134b5ed524 - github.com/pomerium/envoy-custom v1.33.1-0.20250618175753-a0feae248696 + github.com/pomerium/envoy-custom v1.34.1-rc2.0.20250625214310-c029d58dae62 github.com/pomerium/protoutil v0.0.0-20240813175624-47b7ac43ff46 github.com/pomerium/webauthn v0.0.0-20240603205124-0428df511172 github.com/prometheus/client_golang v1.22.0 @@ -127,6 +131,7 @@ require ( github.com/andybalholm/brotli v1.0.5 // indirect github.com/antlr4-go/antlr/v4 v4.13.0 // indirect github.com/apapsch/go-jsonmerge/v2 v2.0.0 // indirect + github.com/atotto/clipboard v0.1.4 // indirect github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.10 // indirect github.com/aws/aws-sdk-go-v2/credentials v1.17.67 // indirect github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.30 // indirect @@ -142,10 +147,14 @@ require ( github.com/aws/aws-sdk-go-v2/service/ssooidc v1.30.1 // indirect github.com/aws/aws-sdk-go-v2/service/sts v1.33.19 // indirect github.com/aws/smithy-go v1.22.2 // indirect + github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect github.com/beorn7/perks v1.0.1 // indirect github.com/caddyserver/zerossl v0.1.3 // indirect github.com/cenkalti/backoff/v5 v5.0.2 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect + github.com/charmbracelet/colorprofile v0.2.3-0.20250311203215-f60798e515dc // indirect + github.com/charmbracelet/x/cellbuf v0.0.13-0.20250311204145-2c3ea96c31dd // indirect + github.com/charmbracelet/x/term v0.2.1 // indirect github.com/cockroachdb/crlib v0.0.0-20241015224233-894974b3ad94 // indirect github.com/cockroachdb/errors v1.11.3 // indirect github.com/cockroachdb/fifo v0.0.0-20240606204812-0bbfbd93a7ce // indirect @@ -164,6 +173,7 @@ require ( github.com/docker/go-units v0.5.0 // indirect github.com/dustin/go-humanize v1.0.1 // indirect github.com/ebitengine/purego v0.8.2 // indirect + github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f // indirect github.com/felixge/httpsnoop v1.0.4 // indirect github.com/fxamacker/cbor/v2 v2.7.0 // indirect github.com/getsentry/sentry-go v0.27.0 // indirect @@ -198,10 +208,13 @@ require ( github.com/kr/text v0.2.0 // indirect github.com/kralicky/go-adaptive-radix-tree v0.0.0-20240624235931-330eb762e74c // indirect github.com/libdns/libdns v1.0.0-beta.1 // indirect + github.com/lucasb-eyer/go-colorful v1.2.0 // indirect github.com/lufia/plan9stats v0.0.0-20240513124658-fba389f38bae // indirect github.com/magiconair/properties v1.8.10 // indirect github.com/mattn/go-colorable v0.1.13 // indirect github.com/mattn/go-isatty v0.0.20 // indirect + github.com/mattn/go-localereader v0.0.1 // indirect + github.com/mattn/go-runewidth v0.0.16 // indirect github.com/minio/crc64nvme v1.0.1 // indirect github.com/minio/md5-simd v1.1.2 // indirect github.com/moby/docker-image-spec v1.3.1 // indirect @@ -213,6 +226,9 @@ require ( github.com/moby/sys/userns v0.1.0 // indirect github.com/moby/term v0.5.0 // indirect github.com/morikuni/aec v1.0.0 // indirect + github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 // indirect + github.com/muesli/cancelreader v0.2.2 // indirect + github.com/muesli/termenv v0.16.0 // indirect github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect github.com/onsi/ginkgo v1.16.5 // indirect github.com/onsi/ginkgo/v2 v2.19.1 // indirect @@ -228,9 +244,11 @@ require ( github.com/prometheus/statsd_exporter v0.22.7 // indirect github.com/quic-go/qpack v0.5.1 // indirect github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475 // indirect + github.com/rivo/uniseg v0.4.7 // indirect github.com/rogpeppe/go-internal v1.13.1 // indirect github.com/rs/xid v1.6.0 // indirect github.com/sagikazarmark/locafero v0.7.0 // indirect + github.com/sahilm/fuzzy v0.1.1 // indirect github.com/shirou/gopsutil/v4 v4.25.1 // indirect github.com/shoenig/go-m1cpu v0.1.6 // indirect github.com/sirupsen/logrus v1.9.3 // indirect @@ -253,6 +271,7 @@ require ( github.com/x448/float16 v0.8.4 // indirect github.com/xeipuuv/gojsonpointer v0.0.0-20190905194746-02993c407bfb // indirect github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415 // indirect + github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect github.com/yashtewari/glob-intersection v0.2.0 // indirect github.com/yusufpapurcu/wmi v1.2.4 // indirect github.com/zeebo/assert v1.3.1 // indirect diff --git a/go.sum b/go.sum index c77f4d633..1966ccc66 100644 --- a/go.sum +++ b/go.sum @@ -104,6 +104,8 @@ github.com/apapsch/go-jsonmerge/v2 v2.0.0 h1:axGnT1gRIfimI7gJifB699GoE/oq+F2MU7D github.com/apapsch/go-jsonmerge/v2 v2.0.0/go.mod h1:lvDnEdqiQrp0O42VQGgmlKpxL1AP2+08jFMw88y4klk= github.com/arbovm/levenshtein v0.0.0-20160628152529-48b4e1c0c4d0 h1:jfIu9sQUG6Ig+0+Ap1h4unLjW6YQJpKZVmUzxsD4E/Q= github.com/arbovm/levenshtein v0.0.0-20160628152529-48b4e1c0c4d0/go.mod h1:t2tdKJDJF9BV14lnkjHmOQgcvEKgtqs5a1N3LNdJhGE= +github.com/atotto/clipboard v0.1.4 h1:EH0zSVneZPSuFR11BlR9YppQTVDbh5+16AmcJi4g1z4= +github.com/atotto/clipboard v0.1.4/go.mod h1:ZY9tmq7sm5xIbd9bOK4onWV4S6X0u6GY7Vn0Yu86PYI= github.com/aws/aws-sdk-go-v2 v1.36.3 h1:mJoei2CxPutQVxaATCzDUjcZEjVRdpsiiXi2o38yqWM= github.com/aws/aws-sdk-go-v2 v1.36.3/go.mod h1:LLXuLpgzEbD766Z5ECcRmi8AzSwfZItDtmABVkRLGzg= github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.10 h1:zAybnyUQXIZ5mok5Jqwlf58/TFE7uvd3IAsa1aF9cXs= @@ -140,6 +142,10 @@ github.com/aws/aws-sdk-go-v2/service/sts v1.33.19 h1:1XuUZ8mYJw9B6lzAkXhqHlJd/Xv github.com/aws/aws-sdk-go-v2/service/sts v1.33.19/go.mod h1:cQnB8CUnxbMU82JvlqjKR2HBOm3fe9pWorWBza6MBJ4= github.com/aws/smithy-go v1.22.2 h1:6D9hW43xKFrRx/tXXfAlIZc4JI+yQe6snnWcQyxSyLQ= github.com/aws/smithy-go v1.22.2/go.mod h1:irrKGvNn1InZwb2d7fkIRNucdfwR8R+Ts3wxYa/cJHg= +github.com/aymanbagabas/go-osc52/v2 v2.0.1 h1:HwpRHbFMcZLEVr42D4p7XBqjyuxQH5SMiErDT4WkJ2k= +github.com/aymanbagabas/go-osc52/v2 v2.0.1/go.mod h1:uYgXzlJ7ZpABp8OJ+exZzJJhRNQ2ASbcXHWsFqH8hp8= +github.com/aymanbagabas/go-udiff v0.2.0 h1:TK0fH4MteXUDspT88n8CKzvK0X9O2xu9yQjWpi6yML8= +github.com/aymanbagabas/go-udiff v0.2.0/go.mod h1:RE4Ex0qsGkTAJoQdQQCA0uG+nAzJO/pI/QwceO5fgrA= github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24TaqPxmxbtue+5NUziq4I4S80YR8gNf3Q= github.com/beorn7/perks v1.0.0/go.mod h1:KWe93zE9D1o94FZ5RNwFwVgaQK1VOXiVxmqh+CedLV8= github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= @@ -164,6 +170,22 @@ github.com/cespare/xxhash/v2 v2.1.1/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XL github.com/cespare/xxhash/v2 v2.1.2/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/charmbracelet/bubbles v0.21.0 h1:9TdC97SdRVg/1aaXNVWfFH3nnLAwOXr8Fn6u6mfQdFs= +github.com/charmbracelet/bubbles v0.21.0/go.mod h1:HF+v6QUR4HkEpz62dx7ym2xc71/KBHg+zKwJtMw+qtg= +github.com/charmbracelet/bubbletea v1.3.4 h1:kCg7B+jSCFPLYRA52SDZjr51kG/fMUEoPoZrkaDHyoI= +github.com/charmbracelet/bubbletea v1.3.4/go.mod h1:dtcUCyCGEX3g9tosuYiut3MXgY/Jsv9nKVdibKKRRXo= +github.com/charmbracelet/colorprofile v0.2.3-0.20250311203215-f60798e515dc h1:4pZI35227imm7yK2bGPcfpFEmuY1gc2YSTShr4iJBfs= +github.com/charmbracelet/colorprofile v0.2.3-0.20250311203215-f60798e515dc/go.mod h1:X4/0JoqgTIPSFcRA/P6INZzIuyqdFY5rm8tb41s9okk= +github.com/charmbracelet/lipgloss v1.1.0 h1:vYXsiLHVkK7fp74RkV7b2kq9+zDLoEU4MZoFqR/noCY= +github.com/charmbracelet/lipgloss v1.1.0/go.mod h1:/6Q8FR2o+kj8rz4Dq0zQc3vYf7X+B0binUUBwA0aL30= +github.com/charmbracelet/x/ansi v0.8.0 h1:9GTq3xq9caJW8ZrBTe0LIe2fvfLR/bYXKTx2llXn7xE= +github.com/charmbracelet/x/ansi v0.8.0/go.mod h1:wdYl/ONOLHLIVmQaxbIYEC/cRKOQyjTkowiI4blgS9Q= +github.com/charmbracelet/x/cellbuf v0.0.13-0.20250311204145-2c3ea96c31dd h1:vy0GVL4jeHEwG5YOXDmi86oYw2yuYUGqz6a8sLwg0X8= +github.com/charmbracelet/x/cellbuf v0.0.13-0.20250311204145-2c3ea96c31dd/go.mod h1:xe0nKWGd3eJgtqZRaN9RjMtK7xUYchjzPr7q6kcvCCs= +github.com/charmbracelet/x/exp/golden v0.0.0-20241011142426-46044092ad91 h1:payRxjMjKgx2PaCWLZ4p3ro9y97+TVLZNaRZgJwSVDQ= +github.com/charmbracelet/x/exp/golden v0.0.0-20241011142426-46044092ad91/go.mod h1:wDlXFlCrmJ8J+swcL/MnGUuYnqgQdW9rhSD61oNMb6U= +github.com/charmbracelet/x/term v0.2.1 h1:AQeHeLZ1OqSXhrAWpYUtZyX1T3zVxfpZuEQMIQaGIAQ= +github.com/charmbracelet/x/term v0.2.1/go.mod h1:oQ4enTYFV7QN4m0i9mzHrViD7TQKvNEEkHUMCmsxdUg= github.com/chzyer/logex v1.1.10/go.mod h1:+Ywpsq7O8HXn0nuIou7OrIPyXbp3wmkHB+jjWRnGsAI= github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e/go.mod h1:nSuG5e5PlCu98SY8svDHJxuZscDgtXS6KTTbou5AhLI= github.com/chzyer/test v0.0.0-20180213035817-a1ea475d72b1/go.mod h1:Q3SI9o4m/ZMnBNeIyt5eFwwo7qiLfzFZmjNmxjkiQlU= @@ -244,6 +266,8 @@ github.com/envoyproxy/go-control-plane/ratelimit v0.1.0/go.mod h1:Wk+tMFAFbCXaJP github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c= github.com/envoyproxy/protoc-gen-validate v1.2.1 h1:DEo3O99U8j4hBFwbJfrz9VtgcDfUKS7KJ7spH3d86P8= github.com/envoyproxy/protoc-gen-validate v1.2.1/go.mod h1:d/C80l/jxXLdfEIhX1W2TmLfsJ31lvEjwamM4DxlWXU= +github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f h1:Y/CXytFA4m6baUTXGLOoWe4PQhGxaX0KpnayAqC48p4= +github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f/go.mod h1:vw97MGsxSvLiUE2X8qFplwetxpGLQrlU1Q9AUEIzCaM= github.com/exaring/otelpgx v0.9.4-0.20250625070127-170cf59316c5 h1:x/jxx2ODOrUlmVHnb2eGzFWs6h2TpOk/+W9YYTDbamI= github.com/exaring/otelpgx v0.9.4-0.20250625070127-170cf59316c5/go.mod h1:R5/M5LWsPPBZc1SrRE5e0DiU48bI78C1/GPTWs6I66U= github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg= @@ -475,6 +499,8 @@ github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0 github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw= github.com/libdns/libdns v1.0.0-beta.1 h1:KIf4wLfsrEpXpZ3vmc/poM8zCATXT2klbdPe6hyOBjQ= github.com/libdns/libdns v1.0.0-beta.1/go.mod h1:4Bj9+5CQiNMVGf87wjX4CY3HQJypUHRuLvlsfsZqLWQ= +github.com/lucasb-eyer/go-colorful v1.2.0 h1:1nnpGOrhyZZuNyfu1QjKiUICQ74+3FNCN69Aj6K7nkY= +github.com/lucasb-eyer/go-colorful v1.2.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0= github.com/lufia/plan9stats v0.0.0-20240513124658-fba389f38bae h1:dIZY4ULFcto4tAFlj1FYZl8ztUZ13bdq+PLY+NOfbyI= github.com/lufia/plan9stats v0.0.0-20240513124658-fba389f38bae/go.mod h1:ilwx/Dta8jXAgpFYFvSWEMwxmbWXyiUHkd5FwyKhb5k= github.com/magiconair/properties v1.8.10 h1:s31yESBquKXCV9a/ScB3ESkOjUYYv+X0rg8SYxI99mE= @@ -487,6 +513,10 @@ github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/ github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/mattn/go-localereader v0.0.1 h1:ygSAOl7ZXTx4RdPYinUpg6W99U8jWvWi9Ye2JC/oIi4= +github.com/mattn/go-localereader v0.0.1/go.mod h1:8fBrzywKY7BI3czFoHkuzRoWE9C+EiG4R1k4Cjx5p88= +github.com/mattn/go-runewidth v0.0.16 h1:E5ScNMtiwvlvB5paMFdw9p4kSQzbXFikJ5SQO6TULQc= +github.com/mattn/go-runewidth v0.0.16/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w= github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0= github.com/mholt/acmez/v3 v3.1.2 h1:auob8J/0FhmdClQicvJvuDavgd5ezwLBfKuYmynhYzc= github.com/mholt/acmez/v3 v3.1.2/go.mod h1:L1wOU06KKvq7tswuMDwKdcHeKpFFgkppZy/y0DFxagQ= @@ -525,6 +555,12 @@ github.com/modern-go/reflect2 v1.0.1/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3Rllmb github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk= github.com/morikuni/aec v1.0.0 h1:nP9CBfwrvYnBRgY6qfDQkygYDmYwOilePFkwzv4dU8A= github.com/morikuni/aec v1.0.0/go.mod h1:BbKIizmSmc5MMPqRYbxO4ZU0S0+P200+tUnFx7PXmsc= +github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 h1:ZK8zHtRHOkbHy6Mmr5D264iyp3TiX5OmNcI5cIARiQI= +github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6/go.mod h1:CJlz5H+gyd6CUWT45Oy4q24RdLyn7Md9Vj2/ldJBSIo= +github.com/muesli/cancelreader v0.2.2 h1:3I4Kt4BQjOR54NavqnDogx/MIoWBFa0StPA8ELUXHmA= +github.com/muesli/cancelreader v0.2.2/go.mod h1:3XuTXfFS2VjM+HTLZY9Ak0l6eUKfijIfMUZ4EgX0QYo= +github.com/muesli/termenv v0.16.0 h1:S5AlUN9dENB57rsbnkPyfdGuWIlkmzJjbFf0Tf5FWUc= +github.com/muesli/termenv v0.16.0/go.mod h1:ZRfOIKPFDYQoDFF4Olj7/QJbW60Ol/kL1pU3VfY/Cnk= github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA= github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ= github.com/mwitkow/go-conntrack v0.0.0-20161129095857-cc309e4a2223/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U= @@ -579,8 +615,8 @@ github.com/pomerium/csrf v1.7.0 h1:Qp4t6oyEod3svQtKfJZs589mdUTWKVf7q0PgCKYCshY= github.com/pomerium/csrf v1.7.0/go.mod h1:hAPZV47mEj2T9xFs+ysbum4l7SF1IdrryYaY6PdoIqw= github.com/pomerium/datasource v0.18.2-0.20221108160055-c6134b5ed524 h1:3YQY1sb54tEEbr0L73rjHkpLB0IB6qh3zl1+XQbMLis= github.com/pomerium/datasource v0.18.2-0.20221108160055-c6134b5ed524/go.mod h1:7fGbUYJnU8RcxZJvUvhukOIBv1G7LWDAHMfDxAf5+Y0= -github.com/pomerium/envoy-custom v1.33.1-0.20250618175753-a0feae248696 h1:ojei2rggKHZYnDQyCbjeG2mdyqCW8E2tZpxOuiDBwxc= -github.com/pomerium/envoy-custom v1.33.1-0.20250618175753-a0feae248696/go.mod h1:+wpbZvum83bq/OD4cp9/8IZiMV6boBkwDhlFPLOoWoI= +github.com/pomerium/envoy-custom v1.34.1-rc2.0.20250625214310-c029d58dae62 h1:H0UYd/lI+U/+TZC3vZ+6jeSCuaNiAc67GBhZuXbfEVw= +github.com/pomerium/envoy-custom v1.34.1-rc2.0.20250625214310-c029d58dae62/go.mod h1:+wpbZvum83bq/OD4cp9/8IZiMV6boBkwDhlFPLOoWoI= github.com/pomerium/protoutil v0.0.0-20240813175624-47b7ac43ff46 h1:NRTg8JOXCxcIA1lAgD74iYud0rbshbWOB3Ou4+Huil8= github.com/pomerium/protoutil v0.0.0-20240813175624-47b7ac43ff46/go.mod h1:QqZmx6ZgPxz18va7kqoT4t/0yJtP7YFIDiT/W2n2fZ4= github.com/pomerium/webauthn v0.0.0-20240603205124-0428df511172 h1:TqoPqRgXSHpn+tEJq6H72iCS5pv66j3rPprThUEZg0E= @@ -628,6 +664,9 @@ github.com/quic-go/quic-go v0.52.0 h1:/SlHrCRElyaU6MaEPKqKr9z83sBg2v4FLLvWM+Z47p github.com/quic-go/quic-go v0.52.0/go.mod h1:MFlGGpcpJqRAfmYi6NC2cptDPSxRWTOGNuP4wqrWmzQ= github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475 h1:N/ElC8H3+5XpJzTSTfLsJV/mx9Q9g7kxmchpfZyxgzM= github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475/go.mod h1:bCqnVzQkZxMG4s8nGwiZ5l3QUCyqpo9Y+/ZMZ9VjZe4= +github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= +github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ= +github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88= github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII= @@ -641,6 +680,8 @@ github.com/rs/zerolog v1.34.0/go.mod h1:bJsvje4Z08ROH4Nhs5iH600c3IkWhwp44iRc54W6 github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= github.com/sagikazarmark/locafero v0.7.0 h1:5MqpDsTGNDhY8sGp0Aowyf0qKsPrhewaLSsFaodPcyo= github.com/sagikazarmark/locafero v0.7.0/go.mod h1:2za3Cg5rMaTMoG/2Ulr9AwtFaIppKXTRYnozin4aB5k= +github.com/sahilm/fuzzy v0.1.1 h1:ceu5RHF8DGgoi+/dR5PsECjCDH1BE3Fnmpo7aVXOdRA= +github.com/sahilm/fuzzy v0.1.1/go.mod h1:VFvziUEIMCrT6A6tw2RFIXPXXmzXbOsSHF0DOI8ZK9Y= github.com/sergi/go-diff v1.3.1 h1:xkr+Oxo4BOQKmkn/B9eMK0g5Kg/983T9DqqPHwYqD+8= github.com/sergi/go-diff v1.3.1/go.mod h1:aMJSSKb2lpPvRNec0+w3fl7LP9IOFzdc9Pa4NFbPK1I= github.com/shirou/gopsutil/v3 v3.24.5 h1:i0t8kL+kQTvpAYToeuiVk3TgDeKOFioZO3Ztz/iZ9pI= @@ -727,6 +768,8 @@ github.com/xeipuuv/gojsonpointer v0.0.0-20190905194746-02993c407bfb h1:zGWFAtiMc github.com/xeipuuv/gojsonpointer v0.0.0-20190905194746-02993c407bfb/go.mod h1:N2zxlSyiKSe5eX1tZViRH5QA0qijqEDrYZiPEAiq3wU= github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415 h1:EzJWgHovont7NscjpAxXsDA8S8BMYve8Y5+7cuRE7R0= github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415/go.mod h1:GwrjFmJcFw6At/Gs6z4yjiIwzuJ1/+UwLxMQDVQXShQ= +github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e h1:JVG44RsyaB9T2KIHavMF/ppJZNG9ZpyihvCd0w101no= +github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e/go.mod h1:RbqR21r5mrJuqunuUZ/Dhy/avygyECGrLceyNeo4LiM= github.com/yashtewari/glob-intersection v0.2.0 h1:8iuHdN88yYuCzCdjt0gDe+6bAhUwBeEWqThExu54RFg= github.com/yashtewari/glob-intersection v0.2.0/go.mod h1:LK7pIC3piUjovexikBbJ26Yml7g8xa5bsjfx2v1fwok= github.com/yuin/goldmark v1.1.25/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= @@ -968,6 +1011,7 @@ golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20210603081109-ebe580a85c40/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210616094352-59db8d763f22/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20210809222454-d867a43fc93e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220114195835-da31bd327af9/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= diff --git a/pkg/ssh/channel.go b/pkg/ssh/channel.go new file mode 100644 index 000000000..165b484dd --- /dev/null +++ b/pkg/ssh/channel.go @@ -0,0 +1,266 @@ +package ssh + +import ( + "context" + "errors" + "fmt" + "io" + "iter" + "slices" + "strings" + "sync" + "time" + + tea "github.com/charmbracelet/bubbletea" + gossh "golang.org/x/crypto/ssh" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + + extensions_ssh "github.com/pomerium/envoy-custom/api/extensions/filters/network/ssh" + "github.com/pomerium/pomerium/config" + "github.com/pomerium/pomerium/internal/log" +) + +type ChannelControlInterface interface { + StreamHandlerInterface + SendControlAction(*extensions_ssh.SSHChannelControlAction) error + SendMessage(any) error + RecvMsg() (any, error) +} + +type StreamHandlerInterface interface { + PrepareHandoff(ctx context.Context, hostname string, ptyInfo *extensions_ssh.SSHDownstreamPTYInfo) (*extensions_ssh.SSHChannelControlAction, error) + FormatSession(ctx context.Context) ([]byte, error) + DeleteSession(ctx context.Context) error + AllSSHRoutes() iter.Seq[*config.Policy] + Hostname() *string + Username() *string + DownstreamChannelID() uint32 +} + +type ChannelHandler struct { + ctrl ChannelControlInterface + config *config.Config + cli *CLI + ptyInfo *extensions_ssh.SSHDownstreamPTYInfo + stdinR io.Reader + stdinW io.Writer + stdoutR io.Reader + stdoutW io.WriteCloser + cancel context.CancelCauseFunc + stdoutStreamDone chan struct{} + + sendChannelCloseMsgOnce sync.Once +} + +func (ch *ChannelHandler) Run(ctx context.Context) error { + stdinR, stdinW := io.Pipe() + stdoutR, stdoutW := io.Pipe() + ch.stdinR, ch.stdinW, ch.stdoutR, ch.stdoutW = stdinR, stdinW, stdoutR, stdoutW + + recvC := make(chan any) + ctx, ch.cancel = context.WithCancelCause(ctx) + go func() { + for { + msg, err := ch.ctrl.RecvMsg() + if err != nil { + if !errors.Is(err, io.EOF) { + ch.cancel(err) + } + return + } + select { + case recvC <- msg: + case <-ctx.Done(): + return + } + } + }() + ch.stdoutStreamDone = make(chan struct{}) + go func() { + defer close(ch.stdoutStreamDone) + var buf [4096]byte + channelID := ch.ctrl.DownstreamChannelID() + for { + n, err := ch.stdoutR.Read(buf[:]) + if err != nil { + if !errors.Is(err, io.EOF) { + ch.cancel(err) + } + return + } + msg := ChannelDataMsg{ + PeersID: channelID, + Length: uint32(n), + Rest: slices.Clone(buf[:n]), + } + if err := ch.ctrl.SendMessage(msg); err != nil { + ch.cancel(err) + return + } + } + }() + + for { + select { + case msg := <-recvC: + switch msg := msg.(type) { + case ChannelRequestMsg: + if err := ch.handleChannelRequestMsg(ctx, msg); err != nil { + ch.cancel(err) + } + case ChannelDataMsg: + if err := ch.handleChannelDataMsg(msg); err != nil { + ch.cancel(err) + } + case ChannelCloseMsg: + ch.sendChannelCloseMsgOnce.Do(func() { + ch.flushStdout() + ch.sendChannelCloseMsg() + }) + ch.cancel(status.Errorf(codes.Canceled, "channel closed")) + case ChannelEOFMsg: + log.Ctx(ctx).Debug().Msg("ssh: received channel EOF") + default: + panic(fmt.Sprintf("bug: unhandled message type: %T", msg)) + } + case <-ctx.Done(): + return context.Cause(ctx) + } + } +} + +func (ch *ChannelHandler) flushStdout() { + ch.stdoutW.Close() + <-ch.stdoutStreamDone // ensure all output is written before sending the channel close message +} + +func (ch *ChannelHandler) sendChannelCloseMsg() { + _ = ch.ctrl.SendMessage(ChannelCloseMsg{ + PeersID: ch.ctrl.DownstreamChannelID(), + }) +} + +func (ch *ChannelHandler) sendExitStatus(err error) { + var code byte + if err != nil { + code = 1 + } + _ = ch.ctrl.SendMessage(ChannelRequestMsg{ + PeersID: ch.ctrl.DownstreamChannelID(), + Request: "exit-status", + WantReply: false, + RequestSpecificData: []byte{0x0, 0x0, 0x0, code}, + }) +} + +func (ch *ChannelHandler) initiateChannelClose(err error) { + ch.sendChannelCloseMsgOnce.Do(func() { + ch.flushStdout() + ch.sendExitStatus(err) + ch.sendChannelCloseMsg() + // the client needs to respond to our close request before we send a + // disconnect in order to get a clean exit, but if they don't respond in + // a timely manner we will disconnect anyway + time.AfterFunc(5*time.Second, func() { + ch.cancel(status.Errorf(codes.DeadlineExceeded, "timed out waiting for channel close")) + }) + }) +} + +func (ch *ChannelHandler) handleChannelRequestMsg(ctx context.Context, msg ChannelRequestMsg) error { + switch msg.Request { + case "shell", "exec": + if ch.cli != nil { + return status.Errorf(codes.FailedPrecondition, "unexpected channel request: %s", msg.Request) + } + ch.cli = NewCLI(ch.config, ch.ctrl, ch.ptyInfo, ch.stdinR, ch.stdoutW) + switch msg.Request { + case "shell": + if ch.config.Options.IsRuntimeFlagSet(config.RuntimeFlagSSHRoutesPortal) { + ch.cli.SetArgs([]string{"portal"}) + } + case "exec": + var execReq ExecChannelRequestMsg + if err := gossh.Unmarshal(msg.RequestSpecificData, &execReq); err != nil { + return status.Errorf(codes.InvalidArgument, "malformed exec channel request") + } + ch.cli.SetArgs(strings.Fields(execReq.Command)) + } + if msg.WantReply { + if err := ch.ctrl.SendMessage(ChannelRequestSuccessMsg{ + PeersID: ch.ctrl.DownstreamChannelID(), + }); err != nil { + return err + } + } + go func() { + err := ch.cli.ExecuteContext(ctx) + if errors.Is(err, ErrHandoff) { + return // don't disconnect + } + ch.initiateChannelClose(err) + }() + case "env": + log.Ctx(ctx).Warn().Msg("ssh: env channel requests are not implemented yet") + case "pty-req": + if ch.cli != nil || ch.ptyInfo != nil { + return status.Errorf(codes.FailedPrecondition, "unexpected channel request: %s", msg.Request) + } + var ptyReq PtyReqChannelRequestMsg + if err := gossh.Unmarshal(msg.RequestSpecificData, &ptyReq); err != nil { + return status.Errorf(codes.InvalidArgument, "malformed pty-req channel request") + } + ch.ptyInfo = &extensions_ssh.SSHDownstreamPTYInfo{ + TermEnv: ptyReq.TermEnv, + WidthColumns: ptyReq.Width, + HeightRows: ptyReq.Height, + WidthPx: ptyReq.WidthPx, + HeightPx: ptyReq.HeightPx, + Modes: ptyReq.Modes, + } + if msg.WantReply { + if err := ch.ctrl.SendMessage(ChannelRequestSuccessMsg{ + PeersID: ch.ctrl.DownstreamChannelID(), + }); err != nil { + return err + } + } + case "window-change": + if ch.cli == nil || ch.ptyInfo == nil { + return status.Errorf(codes.InvalidArgument, "unexpected channel request: window-change") + } + var req ChannelWindowChangeRequestMsg + if err := gossh.Unmarshal(msg.RequestSpecificData, &req); err != nil { + return status.Errorf(codes.InvalidArgument, "malformed window-change channel request") + } + ch.cli.SendTeaMsg(tea.WindowSizeMsg{ + Width: int(req.WidthColumns), + Height: int(req.HeightRows), + }) + // https://datatracker.ietf.org/doc/html/rfc4254#section-6.7: + // A response SHOULD NOT be sent to this message. + default: + return status.Errorf(codes.InvalidArgument, "unknown channel request: %s", msg.Request) + } + return nil +} + +func (ch *ChannelHandler) handleChannelDataMsg(msg ChannelDataMsg) error { + if ch.cli == nil { + return status.Errorf(codes.FailedPrecondition, "unexpected ChannelDataMsg") + } + _, err := ch.stdinW.Write(msg.Rest) + if err != nil { + return err + } + return nil +} + +func NewChannelHandler(ctrl ChannelControlInterface, cfg *config.Config) *ChannelHandler { + ch := &ChannelHandler{ + ctrl: ctrl, + config: cfg, + } + return ch +} diff --git a/pkg/ssh/channel_impl.go b/pkg/ssh/channel_impl.go new file mode 100644 index 000000000..2e06d573b --- /dev/null +++ b/pkg/ssh/channel_impl.go @@ -0,0 +1,202 @@ +package ssh + +import ( + "context" + "sync" + + gossh "golang.org/x/crypto/ssh" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + "google.golang.org/protobuf/types/known/wrapperspb" + + extensions_ssh "github.com/pomerium/envoy-custom/api/extensions/filters/network/ssh" + "github.com/pomerium/pomerium/internal/log" + "github.com/pomerium/pomerium/pkg/protoutil" +) + +type ChannelImpl struct { + StreamHandlerInterface + info *extensions_ssh.SSHDownstreamChannelInfo + stream extensions_ssh.StreamManagement_ServeChannelServer + remoteWindow *Window + localWindow uint32 +} + +func NewChannelImpl( + sh StreamHandlerInterface, + stream extensions_ssh.StreamManagement_ServeChannelServer, + info *extensions_ssh.SSHDownstreamChannelInfo, +) *ChannelImpl { + remoteWindow := &Window{Cond: sync.NewCond(&sync.Mutex{})} + remoteWindow.add(info.InitialWindowSize) + context.AfterFunc(stream.Context(), func() { + remoteWindow.close() + }) + channel := &ChannelImpl{ + StreamHandlerInterface: sh, + info: info, + stream: stream, + remoteWindow: remoteWindow, + localWindow: ChannelWindowSize, + } + return channel +} + +// SendControlAction implements ChannelControlInterface. +func (ci *ChannelImpl) SendControlAction(action *extensions_ssh.SSHChannelControlAction) error { + return ci.stream.Send(&extensions_ssh.ChannelMessage{ + Message: &extensions_ssh.ChannelMessage_ChannelControl{ + ChannelControl: &extensions_ssh.ChannelControl{ + Protocol: "ssh", + ControlAction: protoutil.NewAny(action), + }, + }, + }) +} + +// SendMessage implements ChannelControlInterface. +func (ci *ChannelImpl) SendMessage(msg any) error { + switch msg := msg.(type) { + case ChannelOpenConfirmMsg, WindowAdjustMsg, ChannelRequestMsg, + ChannelRequestSuccessMsg, ChannelRequestFailureMsg, ChannelEOFMsg: + // these messages don't consume window space + data := gossh.Marshal(msg) + if err := ci.stream.Send(&extensions_ssh.ChannelMessage{ + Message: &extensions_ssh.ChannelMessage_RawBytes{ + RawBytes: wrapperspb.Bytes(data), + }, + }); err != nil { + return err + } + log.Ctx(ci.stream.Context()).Debug(). + Uint8("type", data[0]). + Msg("ssh: message sent") + return nil + default: + data := gossh.Marshal(msg) + need := uint32(len(data)) + have := uint32(0) + for have < need { + n, err := ci.remoteWindow.reserve(need - have) + if err != nil { + return status.Errorf(codes.Internal, "stream closed") + } + have += n + } + if err := ci.stream.Send(&extensions_ssh.ChannelMessage{ + Message: &extensions_ssh.ChannelMessage_RawBytes{ + RawBytes: wrapperspb.Bytes(data), + }, + }); err != nil { + return err + } + log.Ctx(ci.stream.Context()).Debug(). + Uint8("type", data[0]). + Uint32("size", need). + Msg("ssh: message sent") + return nil + } +} + +func (ci *ChannelImpl) RecvMsg() (any, error) { + for { + msgID, msg, err := ci.recvMsg() + switch msgID { + case MsgChannelWindowAdjust: + // handle this internally and skip to the next message + continue + default: + return msg, err + } + } +} + +func (ci *ChannelImpl) recvMsg() (byte, any, error) { + channelMsg, err := ci.stream.Recv() + if err != nil { + return 0, nil, err + } + switch channelMsg := channelMsg.Message.(type) { + case *extensions_ssh.ChannelMessage_RawBytes: + msgLen := uint32(len(channelMsg.RawBytes.GetValue())) + if msgLen == 0 { + return 0, nil, status.Errorf(codes.InvalidArgument, "peer sent empty message") + } + if msgLen > ChannelMaxPacket { + return 0, nil, status.Errorf(codes.ResourceExhausted, "message too large") + } + rawMsg := channelMsg.RawBytes.Value + + log.Ctx(ci.stream.Context()). + Debug(). + Uint8("type", rawMsg[0]). + Uint32("size", msgLen). + Msg("ssh: message received") + + // peek the first byte to check if we need to deduct from the window + switch rawMsg[0] { + case MsgChannelWindowAdjust, MsgChannelRequest, MsgChannelSuccess, MsgChannelFailure, MsgChannelEOF, MsgChannelClose: + // these messages don't consume window space + default: + // NB: It is not possible for localWindow to be < msgLen, since the window + // size is 64x the maximum packet size, and we have already checked the + // packet size above. The window adjust message is sent when the window + // size is at half of its max value. + ci.localWindow -= msgLen + if ci.localWindow < ChannelWindowSize/2 { + log.Ctx(ci.stream.Context()).Debug().Msg("ssh: flow control: increasing local window size") + ci.localWindow += ChannelWindowSize + if err := ci.SendMessage(WindowAdjustMsg{ + PeersID: ci.info.DownstreamChannelId, + AdditionalBytes: ChannelWindowSize, + }); err != nil { + return 0, nil, err + } + } + } + + // decode the channel message + switch msgID := rawMsg[0]; msgID { + case MsgChannelWindowAdjust: + var msg WindowAdjustMsg + if err := gossh.Unmarshal(rawMsg, &msg); err != nil { + return 0, nil, err + } + log.Ctx(ci.stream.Context()).Debug().Uint32("bytes", msg.AdditionalBytes).Msg("ssh: flow control: remote window size increased") + if !ci.remoteWindow.add(msg.AdditionalBytes) { + return 0, nil, status.Errorf(codes.InvalidArgument, "invalid window adjustment") + } + return msgID, msg, nil + case MsgChannelRequest: + var msg ChannelRequestMsg + if err := gossh.Unmarshal(rawMsg, &msg); err != nil { + return 0, nil, err + } + return msgID, msg, nil + case MsgChannelData: + var msg ChannelDataMsg + if err := gossh.Unmarshal(rawMsg, &msg); err != nil { + return 0, nil, err + } + return msgID, msg, nil + case MsgChannelClose: + var msg ChannelCloseMsg + if err := gossh.Unmarshal(rawMsg, &msg); err != nil { + return 0, nil, err + } + return msgID, msg, nil + case MsgChannelEOF: + var msg ChannelEOFMsg + if err := gossh.Unmarshal(rawMsg, &msg); err != nil { + return 0, nil, err + } + return msgID, msg, nil + case MsgChannelOpen: + return 0, nil, status.Errorf(codes.InvalidArgument, "only one channel can be opened") + default: + return 0, nil, status.Errorf(codes.Unimplemented, "received unexpected message with type %d", rawMsg[0]) + } + default: + return 0, nil, status.Errorf(codes.Unimplemented, "unknown channel message received") + } +} diff --git a/pkg/ssh/channel_impl_test.go b/pkg/ssh/channel_impl_test.go new file mode 100644 index 000000000..6dec74d6a --- /dev/null +++ b/pkg/ssh/channel_impl_test.go @@ -0,0 +1,302 @@ +package ssh_test + +import ( + "context" + "math" + "runtime" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + gossh "golang.org/x/crypto/ssh" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + "google.golang.org/protobuf/types/known/wrapperspb" + + extensions_ssh "github.com/pomerium/envoy-custom/api/extensions/filters/network/ssh" + "github.com/pomerium/pomerium/pkg/ssh" +) + +func TestFlowControl_BlockAndWaitForAdjust(t *testing.T) { + stream := newMockChannelStream(t) + ci := ssh.NewChannelImpl(nil, stream, &extensions_ssh.SSHDownstreamChannelInfo{ + ChannelType: "session", + DownstreamChannelId: 1, + InternalUpstreamChannelId: 2, + InitialWindowSize: 1024, + MaxPacketSize: 4096, + }) + + sendDone := make(chan struct{}) + wait := make(chan struct{}) + go func() { + defer close(sendDone) + close(wait) + ci.SendMessage(ssh.ChannelDataMsg{ + PeersID: 1, + Length: 1024, + Rest: make([]byte, 1024), + }) + }() + done := make(chan struct{}) + go func() { + defer close(done) + <-wait + stream.SendClientToServer(channelMsg(ssh.WindowAdjustMsg{ + PeersID: 2, + AdditionalBytes: 1024, + })) + stream.SendClientToServer(channelMsg(ssh.ChannelDataMsg{ + PeersID: 2, + })) + msg, err := ci.RecvMsg() + <-sendDone + assert.NoError(t, err) + assert.Equal(t, ssh.ChannelDataMsg{ + PeersID: 2, + Rest: []byte{}, + }, msg) + }() + select { + case <-done: + case <-time.After(1 * time.Second): + assert.Fail(t, "timed out") + } +} + +func TestFlowControl_SendWindowAdjust(t *testing.T) { + stream := newMockChannelStream(t) + ci := ssh.NewChannelImpl(nil, stream, &extensions_ssh.SSHDownstreamChannelInfo{ + ChannelType: "session", + DownstreamChannelId: 1, + InternalUpstreamChannelId: 2, + InitialWindowSize: 1024, + MaxPacketSize: 4096, + }) + + largeDataMsg := ssh.ChannelDataMsg{ + PeersID: 1, + Length: 16375, + Rest: make([]byte, 16375), + } + encodedLen := len(gossh.Marshal(largeDataMsg)) + require.Equal(t, 16384, encodedLen) // to make the numbers easier + + const MaxMsgsSentBeforeAdjust = (ssh.ChannelWindowSize / 2) / 16384 + for i := range MaxMsgsSentBeforeAdjust { + stream.SendClientToServer(channelMsg(largeDataMsg)) + dataMsg, err := ci.RecvMsg() + assert.NoError(t, err) + assert.NotNil(t, dataMsg) + require.Equalf(t, 0, len(stream.serverToClient), "unexpected window adjust on message %d", i) + } + + require.Equalf(t, 0, len(stream.serverToClient), "unexpected window adjust on message %d", MaxMsgsSentBeforeAdjust) + stream.SendClientToServer(channelMsg(largeDataMsg)) + dataMsg, err := ci.RecvMsg() + assert.NoError(t, err) + assert.NotNil(t, dataMsg) + require.Equal(t, 1, len(stream.serverToClient)) + + recv, err := stream.RecvServerToClient() + assert.NoError(t, err) + bytes := recv.GetRawBytes().GetValue() + var adjust ssh.WindowAdjustMsg + assert.NoError(t, gossh.Unmarshal(bytes, &adjust)) + assert.Equal(t, uint32(ssh.ChannelWindowSize), adjust.AdditionalBytes) + assert.Equal(t, uint32(1), adjust.PeersID) +} + +func TestFlowControl_WindowAdjustOverflow(t *testing.T) { + stream := newMockChannelStream(t) + ci := ssh.NewChannelImpl(nil, stream, &extensions_ssh.SSHDownstreamChannelInfo{ + ChannelType: "session", + DownstreamChannelId: 1, + InternalUpstreamChannelId: 2, + InitialWindowSize: 1024, + MaxPacketSize: 4096, + }) + stream.SendClientToServer(channelMsg(ssh.WindowAdjustMsg{ + PeersID: 2, + AdditionalBytes: math.MaxUint32, + })) + _, err := ci.RecvMsg() + assert.ErrorIs(t, err, status.Errorf(codes.InvalidArgument, "invalid window adjustment")) +} + +func TestFlowControl_StreamClosed(t *testing.T) { + ctx, ca := context.WithCancel(t.Context()) + stream := &mockChannelStream{ + GenericServerStream: &grpc.GenericServerStream[extensions_ssh.ChannelMessage, extensions_ssh.ChannelMessage]{ + ServerStream: &mockGrpcServerStream{ + ctx: ctx, + }, + }, + serverToClient: make(chan *extensions_ssh.ChannelMessage, 32), + clientToServer: make(chan *extensions_ssh.ChannelMessage, 32), + } + ci := ssh.NewChannelImpl(nil, stream, &extensions_ssh.SSHDownstreamChannelInfo{ + ChannelType: "session", + DownstreamChannelId: 1, + InternalUpstreamChannelId: 2, + InitialWindowSize: 0, + MaxPacketSize: 4096, + }) + ready := make(chan struct{}) + errC := make(chan error, 1) + go func() { + close(ready) + errC <- ci.SendMessage(ssh.ChannelDataMsg{ + PeersID: 1, + Length: 1, + Rest: []byte("a"), + }) + }() + <-ready + runtime.Gosched() + ca() + select { + case err := <-errC: + assert.ErrorIs(t, err, status.Errorf(codes.Internal, "stream closed")) + case <-time.After(DefaultTimeout): + assert.Fail(t, "timed out") + } +} + +func TestRecvMsg_EmptyMessage(t *testing.T) { + stream := newMockChannelStream(t) + ci := ssh.NewChannelImpl(nil, stream, &extensions_ssh.SSHDownstreamChannelInfo{ + ChannelType: "session", + DownstreamChannelId: 1, + InternalUpstreamChannelId: 2, + InitialWindowSize: 1024, + MaxPacketSize: 4096, + }) + + stream.SendClientToServer(&extensions_ssh.ChannelMessage{ + Message: &extensions_ssh.ChannelMessage_RawBytes{ + RawBytes: wrapperspb.Bytes([]byte{}), + }, + }) + _, err := ci.RecvMsg() + assert.ErrorIs(t, status.Errorf(codes.InvalidArgument, "peer sent empty message"), err) +} + +func TestRecvMsg_MessageTooLarge(t *testing.T) { + stream := newMockChannelStream(t) + ci := ssh.NewChannelImpl(nil, stream, &extensions_ssh.SSHDownstreamChannelInfo{ + ChannelType: "session", + DownstreamChannelId: 1, + InternalUpstreamChannelId: 2, + InitialWindowSize: 1024, + MaxPacketSize: 4096, + }) + + tooLargeDataMsg := ssh.ChannelDataMsg{ + PeersID: 1, + Length: ssh.ChannelMaxPacket, + Rest: make([]byte, ssh.ChannelMaxPacket), + } + stream.SendClientToServer(channelMsg(tooLargeDataMsg)) + _, err := ci.RecvMsg() + assert.ErrorIs(t, status.Errorf(codes.ResourceExhausted, "message too large"), err) +} + +func TestRecvMsg_AllowedMessages(t *testing.T) { + stream := newMockChannelStream(t) + ci := ssh.NewChannelImpl(nil, stream, &extensions_ssh.SSHDownstreamChannelInfo{ + ChannelType: "session", + DownstreamChannelId: 1, + InternalUpstreamChannelId: 2, + InitialWindowSize: 1024, + MaxPacketSize: 4096, + }) + + // RecvMsg will immediately read another message after WindowAdjust, so + // we have to send something + stream.SendClientToServer(channelMsg(ssh.WindowAdjustMsg{})) + stream.SendClientToServer(channelMsg(ssh.ChannelDataMsg{})) + _, err := ci.RecvMsg() + assert.NoError(t, err) + + stream.SendClientToServer(channelMsg(ssh.ChannelRequestMsg{})) + _, err = ci.RecvMsg() + assert.NoError(t, err) + + stream.SendClientToServer(channelMsg(ssh.ChannelDataMsg{})) + _, err = ci.RecvMsg() + assert.NoError(t, err) + + stream.SendClientToServer(channelMsg(ssh.ChannelCloseMsg{})) + _, err = ci.RecvMsg() + assert.NoError(t, err) + + stream.SendClientToServer(channelMsg(ssh.ChannelEOFMsg{})) + _, err = ci.RecvMsg() + assert.NoError(t, err) + + stream.SendClientToServer(channelMsg(ssh.ChannelOpenMsg{})) + _, err = ci.RecvMsg() + assert.ErrorIs(t, err, status.Errorf(codes.InvalidArgument, "only one channel can be opened")) + + stream.SendClientToServer(channelMsg(ssh.ChannelRequestFailureMsg{})) + _, err = ci.RecvMsg() + assert.ErrorIs(t, err, status.Errorf(codes.Unimplemented, "received unexpected message with type 100")) + + stream.SendClientToServer(&extensions_ssh.ChannelMessage{Message: &extensions_ssh.ChannelMessage_ChannelControl{}}) + _, err = ci.RecvMsg() + assert.ErrorIs(t, err, status.Errorf(codes.Unimplemented, "unknown channel message received")) +} + +func TestRecvMsg_UnmarshalErrors(t *testing.T) { + stream := newMockChannelStream(t) + ci := ssh.NewChannelImpl(nil, stream, &extensions_ssh.SSHDownstreamChannelInfo{ + ChannelType: "session", + DownstreamChannelId: 1, + InternalUpstreamChannelId: 2, + InitialWindowSize: 1024, + MaxPacketSize: 4096, + }) + + stream.SendClientToServer(&extensions_ssh.ChannelMessage{ + Message: &extensions_ssh.ChannelMessage_RawBytes{ + RawBytes: wrapperspb.Bytes([]byte{ssh.MsgChannelWindowAdjust}), + }, + }) + _, err := ci.RecvMsg() + assert.ErrorContains(t, err, "ssh: short read") + + stream.SendClientToServer(&extensions_ssh.ChannelMessage{ + Message: &extensions_ssh.ChannelMessage_RawBytes{ + RawBytes: wrapperspb.Bytes([]byte{ssh.MsgChannelRequest}), + }, + }) + _, err = ci.RecvMsg() + assert.ErrorContains(t, err, "ssh: short read") + + stream.SendClientToServer(&extensions_ssh.ChannelMessage{ + Message: &extensions_ssh.ChannelMessage_RawBytes{ + RawBytes: wrapperspb.Bytes([]byte{ssh.MsgChannelData}), + }, + }) + _, err = ci.RecvMsg() + assert.ErrorContains(t, err, "ssh: short read") + + stream.SendClientToServer(&extensions_ssh.ChannelMessage{ + Message: &extensions_ssh.ChannelMessage_RawBytes{ + RawBytes: wrapperspb.Bytes([]byte{ssh.MsgChannelClose}), + }, + }) + _, err = ci.RecvMsg() + assert.ErrorContains(t, err, "ssh: short read") + + stream.SendClientToServer(&extensions_ssh.ChannelMessage{ + Message: &extensions_ssh.ChannelMessage_RawBytes{ + RawBytes: wrapperspb.Bytes([]byte{ssh.MsgChannelEOF}), + }, + }) + _, err = ci.RecvMsg() + assert.ErrorContains(t, err, "ssh: short read") +} diff --git a/pkg/ssh/cli.go b/pkg/ssh/cli.go new file mode 100644 index 000000000..52aac80b5 --- /dev/null +++ b/pkg/ssh/cli.go @@ -0,0 +1,244 @@ +package ssh + +import ( + "errors" + "fmt" + "io" + "strings" + + "github.com/charmbracelet/bubbles/list" + tea "github.com/charmbracelet/bubbletea" + "github.com/charmbracelet/lipgloss" + "github.com/spf13/cobra" + + "github.com/pomerium/envoy-custom/api/extensions/filters/network/ssh" + "github.com/pomerium/pomerium/config" +) + +type CLI struct { + *cobra.Command + tui *tea.Program + ptyInfo *ssh.SSHDownstreamPTYInfo + username string +} + +func NewCLI( + cfg *config.Config, + ctrl ChannelControlInterface, + ptyInfo *ssh.SSHDownstreamPTYInfo, + stdin io.Reader, + stdout io.Writer, +) *CLI { + cmd := &cobra.Command{ + Use: "pomerium", + PersistentPreRunE: func(cmd *cobra.Command, _ []string) error { + _, cmdIsInteractive := cmd.Annotations["interactive"] + switch { + case (ptyInfo == nil) && cmdIsInteractive: + return fmt.Errorf("\x1b[31m'%s' is an interactive command and requires a TTY (try passing '-t' to ssh)\x1b[0m", cmd.Use) + case (ptyInfo != nil) && !cmdIsInteractive: + return fmt.Errorf("\x1b[31m'%s' is not an interactive command (try passing '-T' to ssh, or removing '-t')\x1b[0m\r", cmd.Use) + } + return nil + }, + } + + cmd.CompletionOptions.DisableDefaultCmd = true + cmd.SetIn(stdin) + cmd.SetOut(stdout) + cmd.SetErr(stdout) + cmd.SilenceUsage = true + + cli := &CLI{ + Command: cmd, + tui: nil, + ptyInfo: ptyInfo, + username: *ctrl.Username(), + } + + if cfg.Options.IsRuntimeFlagSet(config.RuntimeFlagSSHRoutesPortal) { + cli.AddPortalCommand(ctrl) + } + cli.AddLogoutCommand(ctrl) + cli.AddWhoamiCommand(ctrl) + + return cli +} + +func (cli *CLI) AddLogoutCommand(ctrl ChannelControlInterface) { + cli.AddCommand(&cobra.Command{ + Use: "logout", + Short: "Log out", + RunE: func(cmd *cobra.Command, _ []string) error { + err := ctrl.DeleteSession(cmd.Context()) + if err != nil { + return fmt.Errorf("failed to delete session: %w\r", err) + } + _, _ = cmd.OutOrStdout().Write([]byte("Logged out successfully\r\n")) + return nil + }, + }) +} + +func (cli *CLI) AddWhoamiCommand(ctrl ChannelControlInterface) { + cli.AddCommand(&cobra.Command{ + Use: "whoami", + Short: "Show details for the current session", + RunE: func(cmd *cobra.Command, _ []string) error { + s, err := ctrl.FormatSession(cmd.Context()) + if err != nil { + return fmt.Errorf("couldn't fetch session: %w\r", err) + } + _, _ = cmd.OutOrStdout().Write(s) + return nil + }, + }) +} + +// ErrHandoff is a sentinel error to indicate that the command triggered a handoff, +// and we should not automatically disconnect +var ErrHandoff = errors.New("handoff") + +func (cli *CLI) AddPortalCommand(ctrl ChannelControlInterface) { + cli.AddCommand(&cobra.Command{ + Use: "portal", + Short: "Interactive route portal", + Annotations: map[string]string{ + "interactive": "", + }, + RunE: func(cmd *cobra.Command, _ []string) error { + var routes []string + for r := range ctrl.AllSSHRoutes() { + routes = append(routes, fmt.Sprintf("%s@%s", *ctrl.Username(), strings.TrimPrefix(r.From, "ssh://"))) + } + items := []list.Item{} + for _, route := range routes { + items = append(items, item(route)) + } + l := list.New(items, itemDelegate{}, int(cli.ptyInfo.WidthColumns-2), int(cli.ptyInfo.HeightRows-2)) + l.Title = "Connect to which server?" + l.SetShowStatusBar(false) + l.SetFilteringEnabled(false) + l.Styles.Title = titleStyle + l.Styles.PaginationStyle = paginationStyle + l.Styles.HelpStyle = helpStyle + + cli.tui = tea.NewProgram(model{list: l}, + tea.WithInput(cmd.InOrStdin()), + tea.WithOutput(cmd.OutOrStdout()), + tea.WithAltScreen(), + tea.WithContext(cmd.Context()), + tea.WithEnvironment([]string{"TERM=" + cli.ptyInfo.TermEnv}), + ) + + go cli.SendTeaMsg(tea.WindowSizeMsg{Width: int(cli.ptyInfo.WidthColumns), Height: int(cli.ptyInfo.HeightRows)}) + answer, err := cli.tui.Run() + if err != nil { + return err + } + if answer.(model).choice == "" { + return nil // quit/ctrl+c + } + + username, hostname, _ := strings.Cut(answer.(model).choice, "@") + // Perform authorize check for this route + if username != cli.username { + panic("bug: username mismatch") + } + if hostname == "" { + panic("bug: hostname is empty") + } + + handoffMsg, err := ctrl.PrepareHandoff(cmd.Context(), hostname, cli.ptyInfo) + if err != nil { + return err + } + if err := ctrl.SendControlAction(handoffMsg); err != nil { + return err + } + return ErrHandoff + }, + }) +} + +func (cli *CLI) SendTeaMsg(msg tea.Msg) { + if cli.tui != nil { + cli.tui.Send(msg) + } +} + +var ( + titleStyle = lipgloss.NewStyle().MarginLeft(2) + itemStyle = lipgloss.NewStyle().PaddingLeft(4) + selectedItemStyle = lipgloss.NewStyle().PaddingLeft(2).Foreground(lipgloss.Color("170")) + paginationStyle = list.DefaultStyles().PaginationStyle.PaddingLeft(4) + helpStyle = list.DefaultStyles().HelpStyle.PaddingLeft(4).PaddingBottom(1) +) + +type item string + +func (i item) FilterValue() string { return "" } + +type itemDelegate struct{} + +func (d itemDelegate) Height() int { return 1 } +func (d itemDelegate) Spacing() int { return 0 } +func (d itemDelegate) Update(_ tea.Msg, _ *list.Model) tea.Cmd { return nil } +func (d itemDelegate) Render(w io.Writer, m list.Model, index int, listItem list.Item) { + i, ok := listItem.(item) + if !ok { + return + } + + str := fmt.Sprintf("%d. %s", index+1, i) + + fn := itemStyle.Render + if index == m.Index() { + fn = func(s ...string) string { + return selectedItemStyle.Render("> " + strings.Join(s, " ")) + } + } + + fmt.Fprint(w, fn(str)) +} + +type model struct { + list list.Model + choice string + quitting bool +} + +func (m model) Init() tea.Cmd { + return nil +} + +func (m model) Update(msg tea.Msg) (tea.Model, tea.Cmd) { + switch msg := msg.(type) { + case tea.WindowSizeMsg: + m.list.SetWidth(msg.Width - 2) + m.list.SetHeight(msg.Height - 2) + return m, nil + + case tea.KeyMsg: + switch keypress := msg.String(); keypress { + case "q", "ctrl+c": + m.quitting = true + return m, tea.Quit + + case "enter": + i, ok := m.list.SelectedItem().(item) + if ok { + m.choice = string(i) + } + return m, tea.Quit + } + } + + var cmd tea.Cmd + m.list, cmd = m.list.Update(msg) + return m, cmd +} + +func (m model) View() string { + return "\n" + m.list.View() +} diff --git a/pkg/ssh/flow_control.go b/pkg/ssh/flow_control.go new file mode 100644 index 000000000..1b6f8edfb --- /dev/null +++ b/pkg/ssh/flow_control.go @@ -0,0 +1,83 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +// Unexported flow control logic copied from x/crypto/ssh + +import ( + "io" + "sync" +) + +const ( + // ChannelMaxPacket contains the maximum number of bytes that will be + // sent in a single packet. As per RFC 4253, section 6.1, 32k is also + // the minimum. + ChannelMaxPacket = 1 << 15 + // We follow OpenSSH here. + ChannelWindowSize = 64 * ChannelMaxPacket +) + +// Window represents the buffer available to clients +// wishing to write to a channel. +type Window struct { + *sync.Cond + win uint32 // RFC 4254 5.2 says the window size can grow to 2^32-1 + writeWaiters int + closed bool +} + +// add adds win to the amount of window available +// for consumers. +func (w *Window) add(win uint32) bool { + // a zero sized window adjust is a noop. + if win == 0 { + return true + } + w.L.Lock() + if w.win+win < win { + w.L.Unlock() + return false + } + w.win += win + // It is unusual that multiple goroutines would be attempting to reserve + // window space, but not guaranteed. Use broadcast to notify all waiters + // that additional window is available. + w.Broadcast() + w.L.Unlock() + return true +} + +// close sets the window to closed, so all reservations fail +// immediately. +func (w *Window) close() { + w.L.Lock() + w.closed = true + w.Broadcast() + w.L.Unlock() +} + +// reserve reserves win from the available window capacity. +// If no capacity remains, reserve will block. reserve may +// return less than requested. +func (w *Window) reserve(win uint32) (uint32, error) { + var err error + w.L.Lock() + w.writeWaiters++ + w.Broadcast() + for w.win == 0 && !w.closed { + w.Wait() + } + w.writeWaiters-- + if w.win < win { + win = w.win + } + w.win -= win + if w.closed { + err = io.EOF + } + w.L.Unlock() + return win, err +} diff --git a/pkg/ssh/manager.go b/pkg/ssh/manager.go new file mode 100644 index 000000000..7656411a8 --- /dev/null +++ b/pkg/ssh/manager.go @@ -0,0 +1,53 @@ +package ssh + +import ( + "sync" + + extensions_ssh "github.com/pomerium/envoy-custom/api/extensions/filters/network/ssh" + "github.com/pomerium/pomerium/config" +) + +type StreamManager struct { + auth AuthInterface + mu sync.Mutex + activeStreams map[uint64]*StreamHandler +} + +func NewStreamManager(auth AuthInterface) *StreamManager { + return &StreamManager{ + auth: auth, + activeStreams: map[uint64]*StreamHandler{}, + } +} + +func (sm *StreamManager) LookupStream(streamID uint64) *StreamHandler { + sm.mu.Lock() + defer sm.mu.Unlock() + stream := sm.activeStreams[streamID] + if stream == nil { + return nil + } + return stream +} + +func (sm *StreamManager) NewStreamHandler(cfg *config.Config, downstream *extensions_ssh.DownstreamConnectEvent) *StreamHandler { + sm.mu.Lock() + defer sm.mu.Unlock() + streamID := downstream.StreamId + writeC := make(chan *extensions_ssh.ServerMessage, 32) + sh := &StreamHandler{ + auth: sm.auth, + config: cfg, + downstream: downstream, + readC: make(chan *extensions_ssh.ClientMessage, 32), + writeC: writeC, + close: func() { + sm.mu.Lock() + defer sm.mu.Unlock() + delete(sm.activeStreams, streamID) + close(writeC) + }, + } + sm.activeStreams[streamID] = sh + return sh +} diff --git a/pkg/ssh/manager_test.go b/pkg/ssh/manager_test.go new file mode 100644 index 000000000..0cd5aedf3 --- /dev/null +++ b/pkg/ssh/manager_test.go @@ -0,0 +1,40 @@ +package ssh_test + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.uber.org/mock/gomock" + + extensions_ssh "github.com/pomerium/envoy-custom/api/extensions/filters/network/ssh" + "github.com/pomerium/pomerium/config" + "github.com/pomerium/pomerium/pkg/ssh" + mock_ssh "github.com/pomerium/pomerium/pkg/ssh/mock" +) + +func mustParseWeightedURLs(t *testing.T, urls ...string) []config.WeightedURL { + wu, err := config.ParseWeightedUrls(urls...) + require.NoError(t, err) + return wu +} + +func TestStreamManager(t *testing.T) { + ctrl := gomock.NewController(t) + auth := mock_ssh.NewMockAuthInterface(ctrl) + m := ssh.NewStreamManager(auth) + + cfg := &config.Config{Options: config.NewDefaultOptions()} + cfg.Options.Policies = []config.Policy{ + {From: "ssh://host1", To: mustParseWeightedURLs(t, "ssh://dest1:22")}, + {From: "ssh://host2", To: mustParseWeightedURLs(t, "ssh://dest2:22")}, + } + + t.Run("LookupStream", func(t *testing.T) { + assert.Nil(t, m.LookupStream(1234)) + sh := m.NewStreamHandler(cfg, &extensions_ssh.DownstreamConnectEvent{StreamId: 1234}) + assert.Equal(t, sh, m.LookupStream(1234)) + sh.Close() + assert.Nil(t, m.LookupStream(1234)) + }) +} diff --git a/pkg/ssh/messages.go b/pkg/ssh/messages.go new file mode 100644 index 000000000..f2e3df10b --- /dev/null +++ b/pkg/ssh/messages.go @@ -0,0 +1,124 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +// Unexported message types copied from x/crypto/ssh + +// See RFC 4254, section 5.1. +const MsgChannelOpen = 90 + +type ChannelOpenMsg struct { + ChanType string `sshtype:"90"` + PeersID uint32 + PeersWindow uint32 + MaxPacketSize uint32 + TypeSpecificData []byte `ssh:"rest"` +} + +const ( + MsgChannelExtendedData = 95 + MsgChannelData = 94 +) + +// See RFC 4253, section 11.1. +const MsgDisconnect = 1 + +// DisconnectMsg is the message that signals a disconnect. It is also +// the error type returned from mux.Wait() +type DisconnectMsg struct { + Reason uint32 `sshtype:"1"` + Message string + Language string +} + +// Used for debug print outs of packets. +type ChannelDataMsg struct { + PeersID uint32 `sshtype:"94"` + Length uint32 + Rest []byte `ssh:"rest"` +} + +// See RFC 4254, section 5.1. +const MsgChannelOpenConfirm = 91 + +type ChannelOpenConfirmMsg struct { + PeersID uint32 `sshtype:"91"` + MyID uint32 + MyWindow uint32 + MaxPacketSize uint32 + TypeSpecificData []byte `ssh:"rest"` +} + +const MsgChannelRequest = 98 + +type ChannelRequestMsg struct { + PeersID uint32 `sshtype:"98"` + Request string + WantReply bool + RequestSpecificData []byte `ssh:"rest"` +} + +type ChannelOpenDirectMsg struct { + DestAddr string + DestPort uint32 + SrcAddr string + SrcPort uint32 +} + +type ChannelWindowChangeRequestMsg struct { + WidthColumns uint32 + HeightRows uint32 + WidthPx uint32 + HeightPx uint32 +} + +type ShellChannelRequestMsg struct{} + +type ExecChannelRequestMsg struct { + Command string +} + +// See RFC 4254, section 5.2 +const MsgChannelWindowAdjust = 93 + +type WindowAdjustMsg struct { + PeersID uint32 `sshtype:"93"` + AdditionalBytes uint32 +} + +// See RFC 4254, section 5.4. +const MsgChannelSuccess = 99 + +type ChannelRequestSuccessMsg struct { + PeersID uint32 `sshtype:"99"` +} + +// See RFC 4254, section 5.4. +const MsgChannelFailure = 100 + +type ChannelRequestFailureMsg struct { + PeersID uint32 `sshtype:"100"` +} + +// See RFC 4254, section 5.3 +const MsgChannelClose = 97 + +type ChannelCloseMsg struct { + PeersID uint32 `sshtype:"97"` +} + +// See RFC 4254, section 5.3 +const MsgChannelEOF = 96 + +type ChannelEOFMsg struct { + PeersID uint32 `sshtype:"96"` +} + +type PtyReqChannelRequestMsg struct { + TermEnv string + Width, Height uint32 + WidthPx, HeightPx uint32 + Modes []byte +} diff --git a/pkg/ssh/mock/mock_auth_interface.go b/pkg/ssh/mock/mock_auth_interface.go new file mode 100644 index 000000000..c033561e6 --- /dev/null +++ b/pkg/ssh/mock/mock_auth_interface.go @@ -0,0 +1,236 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/pomerium/pomerium/pkg/ssh (interfaces: AuthInterface) +// +// Generated by this command: +// +// mockgen -typed . AuthInterface +// + +// Package mock_ssh is a generated GoMock package. +package mock_ssh + +import ( + context "context" + reflect "reflect" + + ssh "github.com/pomerium/envoy-custom/api/extensions/filters/network/ssh" + ssh0 "github.com/pomerium/pomerium/pkg/ssh" + gomock "go.uber.org/mock/gomock" +) + +// MockAuthInterface is a mock of AuthInterface interface. +type MockAuthInterface struct { + ctrl *gomock.Controller + recorder *MockAuthInterfaceMockRecorder + isgomock struct{} +} + +// MockAuthInterfaceMockRecorder is the mock recorder for MockAuthInterface. +type MockAuthInterfaceMockRecorder struct { + mock *MockAuthInterface +} + +// NewMockAuthInterface creates a new mock instance. +func NewMockAuthInterface(ctrl *gomock.Controller) *MockAuthInterface { + mock := &MockAuthInterface{ctrl: ctrl} + mock.recorder = &MockAuthInterfaceMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockAuthInterface) EXPECT() *MockAuthInterfaceMockRecorder { + return m.recorder +} + +// DeleteSession mocks base method. +func (m *MockAuthInterface) DeleteSession(ctx context.Context, info ssh0.StreamAuthInfo) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DeleteSession", ctx, info) + ret0, _ := ret[0].(error) + return ret0 +} + +// DeleteSession indicates an expected call of DeleteSession. +func (mr *MockAuthInterfaceMockRecorder) DeleteSession(ctx, info any) *MockAuthInterfaceDeleteSessionCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteSession", reflect.TypeOf((*MockAuthInterface)(nil).DeleteSession), ctx, info) + return &MockAuthInterfaceDeleteSessionCall{Call: call} +} + +// MockAuthInterfaceDeleteSessionCall wrap *gomock.Call +type MockAuthInterfaceDeleteSessionCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockAuthInterfaceDeleteSessionCall) Return(arg0 error) *MockAuthInterfaceDeleteSessionCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockAuthInterfaceDeleteSessionCall) Do(f func(context.Context, ssh0.StreamAuthInfo) error) *MockAuthInterfaceDeleteSessionCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockAuthInterfaceDeleteSessionCall) DoAndReturn(f func(context.Context, ssh0.StreamAuthInfo) error) *MockAuthInterfaceDeleteSessionCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + +// EvaluateDelayed mocks base method. +func (m *MockAuthInterface) EvaluateDelayed(ctx context.Context, info ssh0.StreamAuthInfo) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "EvaluateDelayed", ctx, info) + ret0, _ := ret[0].(error) + return ret0 +} + +// EvaluateDelayed indicates an expected call of EvaluateDelayed. +func (mr *MockAuthInterfaceMockRecorder) EvaluateDelayed(ctx, info any) *MockAuthInterfaceEvaluateDelayedCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "EvaluateDelayed", reflect.TypeOf((*MockAuthInterface)(nil).EvaluateDelayed), ctx, info) + return &MockAuthInterfaceEvaluateDelayedCall{Call: call} +} + +// MockAuthInterfaceEvaluateDelayedCall wrap *gomock.Call +type MockAuthInterfaceEvaluateDelayedCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockAuthInterfaceEvaluateDelayedCall) Return(arg0 error) *MockAuthInterfaceEvaluateDelayedCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockAuthInterfaceEvaluateDelayedCall) Do(f func(context.Context, ssh0.StreamAuthInfo) error) *MockAuthInterfaceEvaluateDelayedCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockAuthInterfaceEvaluateDelayedCall) DoAndReturn(f func(context.Context, ssh0.StreamAuthInfo) error) *MockAuthInterfaceEvaluateDelayedCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + +// FormatSession mocks base method. +func (m *MockAuthInterface) FormatSession(ctx context.Context, info ssh0.StreamAuthInfo) ([]byte, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "FormatSession", ctx, info) + ret0, _ := ret[0].([]byte) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// FormatSession indicates an expected call of FormatSession. +func (mr *MockAuthInterfaceMockRecorder) FormatSession(ctx, info any) *MockAuthInterfaceFormatSessionCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FormatSession", reflect.TypeOf((*MockAuthInterface)(nil).FormatSession), ctx, info) + return &MockAuthInterfaceFormatSessionCall{Call: call} +} + +// MockAuthInterfaceFormatSessionCall wrap *gomock.Call +type MockAuthInterfaceFormatSessionCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockAuthInterfaceFormatSessionCall) Return(arg0 []byte, arg1 error) *MockAuthInterfaceFormatSessionCall { + c.Call = c.Call.Return(arg0, arg1) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockAuthInterfaceFormatSessionCall) Do(f func(context.Context, ssh0.StreamAuthInfo) ([]byte, error)) *MockAuthInterfaceFormatSessionCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockAuthInterfaceFormatSessionCall) DoAndReturn(f func(context.Context, ssh0.StreamAuthInfo) ([]byte, error)) *MockAuthInterfaceFormatSessionCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + +// HandleKeyboardInteractiveMethodRequest mocks base method. +func (m *MockAuthInterface) HandleKeyboardInteractiveMethodRequest(ctx context.Context, info ssh0.StreamAuthInfo, req *ssh.KeyboardInteractiveMethodRequest, querier ssh0.KeyboardInteractiveQuerier) (ssh0.KeyboardInteractiveAuthMethodResponse, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "HandleKeyboardInteractiveMethodRequest", ctx, info, req, querier) + ret0, _ := ret[0].(ssh0.KeyboardInteractiveAuthMethodResponse) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// HandleKeyboardInteractiveMethodRequest indicates an expected call of HandleKeyboardInteractiveMethodRequest. +func (mr *MockAuthInterfaceMockRecorder) HandleKeyboardInteractiveMethodRequest(ctx, info, req, querier any) *MockAuthInterfaceHandleKeyboardInteractiveMethodRequestCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HandleKeyboardInteractiveMethodRequest", reflect.TypeOf((*MockAuthInterface)(nil).HandleKeyboardInteractiveMethodRequest), ctx, info, req, querier) + return &MockAuthInterfaceHandleKeyboardInteractiveMethodRequestCall{Call: call} +} + +// MockAuthInterfaceHandleKeyboardInteractiveMethodRequestCall wrap *gomock.Call +type MockAuthInterfaceHandleKeyboardInteractiveMethodRequestCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockAuthInterfaceHandleKeyboardInteractiveMethodRequestCall) Return(arg0 ssh0.KeyboardInteractiveAuthMethodResponse, arg1 error) *MockAuthInterfaceHandleKeyboardInteractiveMethodRequestCall { + c.Call = c.Call.Return(arg0, arg1) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockAuthInterfaceHandleKeyboardInteractiveMethodRequestCall) Do(f func(context.Context, ssh0.StreamAuthInfo, *ssh.KeyboardInteractiveMethodRequest, ssh0.KeyboardInteractiveQuerier) (ssh0.KeyboardInteractiveAuthMethodResponse, error)) *MockAuthInterfaceHandleKeyboardInteractiveMethodRequestCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockAuthInterfaceHandleKeyboardInteractiveMethodRequestCall) DoAndReturn(f func(context.Context, ssh0.StreamAuthInfo, *ssh.KeyboardInteractiveMethodRequest, ssh0.KeyboardInteractiveQuerier) (ssh0.KeyboardInteractiveAuthMethodResponse, error)) *MockAuthInterfaceHandleKeyboardInteractiveMethodRequestCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + +// HandlePublicKeyMethodRequest mocks base method. +func (m *MockAuthInterface) HandlePublicKeyMethodRequest(ctx context.Context, info ssh0.StreamAuthInfo, req *ssh.PublicKeyMethodRequest) (ssh0.PublicKeyAuthMethodResponse, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "HandlePublicKeyMethodRequest", ctx, info, req) + ret0, _ := ret[0].(ssh0.PublicKeyAuthMethodResponse) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// HandlePublicKeyMethodRequest indicates an expected call of HandlePublicKeyMethodRequest. +func (mr *MockAuthInterfaceMockRecorder) HandlePublicKeyMethodRequest(ctx, info, req any) *MockAuthInterfaceHandlePublicKeyMethodRequestCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HandlePublicKeyMethodRequest", reflect.TypeOf((*MockAuthInterface)(nil).HandlePublicKeyMethodRequest), ctx, info, req) + return &MockAuthInterfaceHandlePublicKeyMethodRequestCall{Call: call} +} + +// MockAuthInterfaceHandlePublicKeyMethodRequestCall wrap *gomock.Call +type MockAuthInterfaceHandlePublicKeyMethodRequestCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockAuthInterfaceHandlePublicKeyMethodRequestCall) Return(arg0 ssh0.PublicKeyAuthMethodResponse, arg1 error) *MockAuthInterfaceHandlePublicKeyMethodRequestCall { + c.Call = c.Call.Return(arg0, arg1) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockAuthInterfaceHandlePublicKeyMethodRequestCall) Do(f func(context.Context, ssh0.StreamAuthInfo, *ssh.PublicKeyMethodRequest) (ssh0.PublicKeyAuthMethodResponse, error)) *MockAuthInterfaceHandlePublicKeyMethodRequestCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockAuthInterfaceHandlePublicKeyMethodRequestCall) DoAndReturn(f func(context.Context, ssh0.StreamAuthInfo, *ssh.PublicKeyMethodRequest) (ssh0.PublicKeyAuthMethodResponse, error)) *MockAuthInterfaceHandlePublicKeyMethodRequestCall { + c.Call = c.Call.DoAndReturn(f) + return c +} diff --git a/pkg/ssh/stream.go b/pkg/ssh/stream.go new file mode 100644 index 000000000..a577170fc --- /dev/null +++ b/pkg/ssh/stream.go @@ -0,0 +1,483 @@ +package ssh + +import ( + "context" + "iter" + + corev3 "github.com/envoyproxy/go-control-plane/envoy/config/core/v3" + gossh "golang.org/x/crypto/ssh" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + "google.golang.org/protobuf/types/known/anypb" + + extensions_ssh "github.com/pomerium/envoy-custom/api/extensions/filters/network/ssh" + "github.com/pomerium/pomerium/config" + "github.com/pomerium/pomerium/internal/log" + "github.com/pomerium/pomerium/pkg/protoutil" + "github.com/pomerium/pomerium/pkg/slices" +) + +const ( + MethodPublicKey = "publickey" + MethodKeyboardInteractive = "keyboard-interactive" +) + +type KeyboardInteractiveQuerier interface { + // Prompts the client and returns their responses to the given prompts. + Prompt(ctx context.Context, prompts *extensions_ssh.KeyboardInteractiveInfoPrompts) (*extensions_ssh.KeyboardInteractiveInfoPromptResponses, error) +} + +type AuthMethodResponse[T any] struct { + Allow *T + RequireAdditionalMethods []string +} + +type ( + PublicKeyAuthMethodResponse = AuthMethodResponse[extensions_ssh.PublicKeyAllowResponse] + KeyboardInteractiveAuthMethodResponse = AuthMethodResponse[extensions_ssh.KeyboardInteractiveAllowResponse] +) + +type AuthInterface interface { + HandlePublicKeyMethodRequest(ctx context.Context, info StreamAuthInfo, req *extensions_ssh.PublicKeyMethodRequest) (PublicKeyAuthMethodResponse, error) + HandleKeyboardInteractiveMethodRequest(ctx context.Context, info StreamAuthInfo, req *extensions_ssh.KeyboardInteractiveMethodRequest, querier KeyboardInteractiveQuerier) (KeyboardInteractiveAuthMethodResponse, error) + EvaluateDelayed(ctx context.Context, info StreamAuthInfo) error + FormatSession(ctx context.Context, info StreamAuthInfo) ([]byte, error) + DeleteSession(ctx context.Context, info StreamAuthInfo) error +} + +type AuthMethodValue[T any] struct { + attempted bool + Value *T +} + +func (v *AuthMethodValue[T]) Update(value *T) { + v.attempted = true + v.Value = value +} + +func (v *AuthMethodValue[T]) IsValid() bool { + if v.attempted { + // method was attempted - valid iff there is a value + return v.Value != nil + } + return true // method was not attempted - valid +} + +type StreamAuthInfo struct { + Username *string + Hostname *string + StreamID uint64 + SourceAddress string + PublicKeyFingerprintSha256 []byte + PublicKeyAllow AuthMethodValue[extensions_ssh.PublicKeyAllowResponse] + KeyboardInteractiveAllow AuthMethodValue[extensions_ssh.KeyboardInteractiveAllowResponse] +} + +func (i *StreamAuthInfo) allMethodsValid() bool { + return i.PublicKeyAllow.IsValid() && i.KeyboardInteractiveAllow.IsValid() +} + +type StreamState struct { + StreamAuthInfo + DirectTcpip bool + RemainingUnauthenticatedMethods []string + DownstreamChannelInfo *extensions_ssh.SSHDownstreamChannelInfo +} + +// StreamHandler handles a single SSH stream +type StreamHandler struct { + auth AuthInterface + config *config.Config + downstream *extensions_ssh.DownstreamConnectEvent + writeC chan *extensions_ssh.ServerMessage + readC chan *extensions_ssh.ClientMessage + + state *StreamState + close func() + + channelIDCounter uint32 + expectingInternalChannel bool +} + +var _ StreamHandlerInterface = (*StreamHandler)(nil) + +func (sh *StreamHandler) Close() { + sh.close() +} + +func (sh *StreamHandler) IsExpectingInternalChannel() bool { + return sh.expectingInternalChannel +} + +func (sh *StreamHandler) ReadC() chan<- *extensions_ssh.ClientMessage { + return sh.readC +} + +func (sh *StreamHandler) WriteC() <-chan *extensions_ssh.ServerMessage { + return sh.writeC +} + +// Prompt implements KeyboardInteractiveQuerier. +func (sh *StreamHandler) Prompt(ctx context.Context, prompts *extensions_ssh.KeyboardInteractiveInfoPrompts) (*extensions_ssh.KeyboardInteractiveInfoPromptResponses, error) { + sh.sendInfoPrompts(prompts) + select { + case <-ctx.Done(): + return nil, context.Cause(ctx) + case req := <-sh.readC: + switch msg := req.Message.(type) { + case *extensions_ssh.ClientMessage_InfoResponse: + if msg.InfoResponse.Method != "keyboard-interactive" { + return nil, status.Errorf(codes.Internal, "received invalid info response") + } + r, _ := msg.InfoResponse.Response.UnmarshalNew() + respInfo, ok := r.(*extensions_ssh.KeyboardInteractiveInfoPromptResponses) + if !ok { + return nil, status.Errorf(codes.InvalidArgument, "received invalid prompt response") + } + return respInfo, nil + default: + return nil, status.Errorf(codes.InvalidArgument, "received invalid message, expecting info response") + } + } +} + +func (sh *StreamHandler) Run(ctx context.Context) error { + if sh.state != nil { + panic("Run called twice") + } + sh.state = &StreamState{ + RemainingUnauthenticatedMethods: []string{MethodPublicKey}, + StreamAuthInfo: StreamAuthInfo{ + StreamID: sh.downstream.StreamId, + SourceAddress: sh.downstream.SourceAddress, + }, + } + for { + select { + case <-ctx.Done(): + return context.Cause(ctx) + case req := <-sh.readC: + switch req := req.Message.(type) { + case *extensions_ssh.ClientMessage_Event: + switch event := req.Event.Event.(type) { + case *extensions_ssh.StreamEvent_DownstreamConnected: + // this was already received as the first message in the stream + return status.Errorf(codes.Internal, "received duplicate downstream connected event") + case *extensions_ssh.StreamEvent_UpstreamConnected: + log.Ctx(ctx).Debug(). + Uint64("stream-id", event.UpstreamConnected.StreamId). + Msg("ssh: upstream connected") + case *extensions_ssh.StreamEvent_DownstreamDisconnected: + log.Ctx(ctx).Debug(). + Uint64("stream-id", sh.downstream.StreamId). + Str("reason", event.DownstreamDisconnected.Reason). + Msg("ssh: downstream disconnected") + case nil: + return status.Errorf(codes.Internal, "received invalid event") + } + case *extensions_ssh.ClientMessage_AuthRequest: + if err := sh.handleAuthRequest(ctx, req.AuthRequest); err != nil { + return err + } + default: + return status.Errorf(codes.Internal, "received invalid message") + } + } + } +} + +func (sh *StreamHandler) ServeChannel(stream extensions_ssh.StreamManagement_ServeChannelServer) error { + // The first channel message on this stream should be a ChannelOpen + channelOpen, err := stream.Recv() + if err != nil { + return err + } + rawMsg, ok := channelOpen.GetMessage().(*extensions_ssh.ChannelMessage_RawBytes) + if !ok { + return status.Errorf(codes.InvalidArgument, "first channel message was not ChannelOpen") + } + var msg ChannelOpenMsg + if err := gossh.Unmarshal(rawMsg.RawBytes.GetValue(), &msg); err != nil { + return status.Errorf(codes.InvalidArgument, "first channel message was not ChannelOpen") + } + + sh.channelIDCounter++ + sh.state.DownstreamChannelInfo = &extensions_ssh.SSHDownstreamChannelInfo{ + ChannelType: msg.ChanType, + DownstreamChannelId: msg.PeersID, + InternalUpstreamChannelId: sh.channelIDCounter, + InitialWindowSize: msg.PeersWindow, + MaxPacketSize: msg.MaxPacketSize, + } + channel := NewChannelImpl(sh, stream, sh.state.DownstreamChannelInfo) + switch msg.ChanType { + case "session": + if err := channel.SendMessage(ChannelOpenConfirmMsg{ + PeersID: sh.state.DownstreamChannelInfo.DownstreamChannelId, + MyID: sh.state.DownstreamChannelInfo.InternalUpstreamChannelId, + MyWindow: ChannelWindowSize, + MaxPacketSize: ChannelMaxPacket, + }); err != nil { + return err + } + ch := NewChannelHandler(channel, sh.config) + return ch.Run(stream.Context()) + case "direct-tcpip": + var subMsg ChannelOpenDirectMsg + if err := gossh.Unmarshal(msg.TypeSpecificData, &subMsg); err != nil { + return err + } + sh.state.DirectTcpip = true + action, err := sh.PrepareHandoff(stream.Context(), subMsg.DestAddr, nil) + if err != nil { + return err + } + return channel.SendControlAction(action) + default: + return status.Errorf(codes.InvalidArgument, "unexpected channel type in ChannelOpen message: %s", msg.ChanType) + } +} + +func (sh *StreamHandler) handleAuthRequest(ctx context.Context, req *extensions_ssh.AuthenticationRequest) error { + if req.Protocol != "ssh" { + return status.Errorf(codes.InvalidArgument, "invalid protocol: %s", req.Protocol) + } + if req.Service != "ssh-connection" { + return status.Errorf(codes.InvalidArgument, "invalid service: %s", req.Service) + } + if !slices.Contains(sh.state.RemainingUnauthenticatedMethods, req.AuthMethod) { + return status.Errorf(codes.InvalidArgument, "unexpected auth method: %s", req.AuthMethod) + } + + if sh.state.Username == nil { + if req.Username == "" { + return status.Errorf(codes.InvalidArgument, "username missing") + } + sh.state.Username = &req.Username + } else if *sh.state.Username != req.Username { + return status.Errorf(codes.InvalidArgument, "inconsistent username") + } + if sh.state.Hostname == nil { + sh.state.Hostname = &req.Hostname + } else if *sh.state.Hostname != req.Hostname { + return status.Errorf(codes.InvalidArgument, "inconsistent hostname") + } + + updateMethods := func(add []string) { + sh.state.RemainingUnauthenticatedMethods = slices.Remove(sh.state.RemainingUnauthenticatedMethods, req.AuthMethod) + sh.state.RemainingUnauthenticatedMethods = append(sh.state.RemainingUnauthenticatedMethods, add...) + } + log.Ctx(ctx).Debug(). + Str("method", req.AuthMethod). + Str("username", *sh.state.Username). + Str("hostname", *sh.state.Hostname). + Msg("ssh: handling auth request") + + var partial bool + switch req.AuthMethod { + case MethodPublicKey: + methodReq, _ := req.MethodRequest.UnmarshalNew() + pubkeyReq, ok := methodReq.(*extensions_ssh.PublicKeyMethodRequest) + if !ok { + return status.Errorf(codes.InvalidArgument, "invalid public key method request type") + } + response, err := sh.auth.HandlePublicKeyMethodRequest(ctx, sh.state.StreamAuthInfo, pubkeyReq) + if err != nil { + return err + } + partial = response.Allow != nil + sh.state.PublicKeyAllow.Update(response.Allow) + updateMethods(response.RequireAdditionalMethods) + case MethodKeyboardInteractive: + methodReq, _ := req.MethodRequest.UnmarshalNew() + kbiReq, ok := methodReq.(*extensions_ssh.KeyboardInteractiveMethodRequest) + if !ok { + return status.Errorf(codes.InvalidArgument, "invalid keyboard-interactive method request type") + } + response, err := sh.auth.HandleKeyboardInteractiveMethodRequest(ctx, sh.state.StreamAuthInfo, kbiReq, sh) + if err != nil { + return err + } + partial = response.Allow != nil + sh.state.KeyboardInteractiveAllow.Update(response.Allow) + updateMethods(response.RequireAdditionalMethods) + default: + return status.Errorf(codes.Internal, "bug: server requested an unsupported auth method %q", req.AuthMethod) + } + log.Ctx(ctx).Debug(). + Str("method", req.AuthMethod). + Bool("partial", partial). + Strs("methods-remaining", sh.state.RemainingUnauthenticatedMethods). + Msg("ssh: auth request complete") + + if len(sh.state.RemainingUnauthenticatedMethods) == 0 && sh.state.allMethodsValid() { + // if there are no methods remaining, the user is allowed if all attempted + // methods have a valid response in the state + log.Ctx(ctx).Debug().Msg("ssh: all methods valid, sending allow response") + sh.sendAllowResponse() + } else { + log.Ctx(ctx).Debug().Msg("ssh: unauthenticated methods remain, sending deny response") + sh.sendDenyResponseWithRemainingMethods(partial) + } + return nil +} + +func (sh *StreamHandler) PrepareHandoff(ctx context.Context, hostname string, ptyInfo *extensions_ssh.SSHDownstreamPTYInfo) (*extensions_ssh.SSHChannelControlAction, error) { + if hostname == "" { + return nil, status.Errorf(codes.PermissionDenied, "invalid hostname") + } + if sh.state.Hostname == nil { + panic("bug: PrepareHandoff called but state is missing a hostname") + } + if *sh.state.Hostname != "" { + panic("bug: PrepareHandoff called but previous hostname is not empty") + } + *sh.state.Hostname = hostname + err := sh.auth.EvaluateDelayed(ctx, sh.state.StreamAuthInfo) + if err != nil { + return nil, status.Error(codes.PermissionDenied, err.Error()) + } + log.Ctx(ctx).Debug(). + Str("hostname", *sh.state.Hostname). + Str("username", *sh.state.Username). + Msg("ssh: initiating handoff to upstream") + upstreamAllow := sh.buildUpstreamAllowResponse() + action := &extensions_ssh.SSHChannelControlAction{ + Action: &extensions_ssh.SSHChannelControlAction_HandOff{ + HandOff: &extensions_ssh.SSHChannelControlAction_HandOffUpstream{ + DownstreamChannelInfo: sh.state.DownstreamChannelInfo, + DownstreamPtyInfo: ptyInfo, + UpstreamAuth: upstreamAllow, + }, + }, + } + return action, nil +} + +func (sh *StreamHandler) FormatSession(ctx context.Context) ([]byte, error) { + return sh.auth.FormatSession(ctx, sh.state.StreamAuthInfo) +} + +func (sh *StreamHandler) DeleteSession(ctx context.Context) error { + return sh.auth.DeleteSession(ctx, sh.state.StreamAuthInfo) +} + +func (sh *StreamHandler) AllSSHRoutes() iter.Seq[*config.Policy] { + return func(yield func(*config.Policy) bool) { + for route := range sh.config.Options.GetAllPolicies() { + if route.IsSSH() { + if !yield(route) { + return + } + } + } + } +} + +// DownstreamChannelID implements StreamHandlerInterface. +func (sh *StreamHandler) DownstreamChannelID() uint32 { + return sh.state.DownstreamChannelInfo.DownstreamChannelId +} + +// Hostname implements StreamHandlerInterface. +func (sh *StreamHandler) Hostname() *string { + return sh.state.Hostname +} + +// Username implements StreamHandlerInterface. +func (sh *StreamHandler) Username() *string { + return sh.state.Username +} + +func (sh *StreamHandler) sendDenyResponseWithRemainingMethods(partial bool) { + sh.writeC <- &extensions_ssh.ServerMessage{ + Message: &extensions_ssh.ServerMessage_AuthResponse{ + AuthResponse: &extensions_ssh.AuthenticationResponse{ + Response: &extensions_ssh.AuthenticationResponse_Deny{ + Deny: &extensions_ssh.DenyResponse{ + Partial: partial, + Methods: sh.state.RemainingUnauthenticatedMethods, + }, + }, + }, + }, + } +} + +func (sh *StreamHandler) sendAllowResponse() { + var allow *extensions_ssh.AllowResponse + if *sh.state.Hostname == "" { + sh.expectingInternalChannel = true + allow = sh.buildInternalAllowResponse() + } else { + allow = sh.buildUpstreamAllowResponse() + } + + sh.writeC <- &extensions_ssh.ServerMessage{ + Message: &extensions_ssh.ServerMessage_AuthResponse{ + AuthResponse: &extensions_ssh.AuthenticationResponse{ + Response: &extensions_ssh.AuthenticationResponse_Allow{ + Allow: allow, + }, + }, + }, + } +} + +func (sh *StreamHandler) sendInfoPrompts(prompts *extensions_ssh.KeyboardInteractiveInfoPrompts) { + sh.writeC <- &extensions_ssh.ServerMessage{ + Message: &extensions_ssh.ServerMessage_AuthResponse{ + AuthResponse: &extensions_ssh.AuthenticationResponse{ + Response: &extensions_ssh.AuthenticationResponse_InfoRequest{ + InfoRequest: &extensions_ssh.InfoRequest{ + Method: MethodKeyboardInteractive, + Request: protoutil.NewAny(prompts), + }, + }, + }, + }, + } +} + +func (sh *StreamHandler) buildUpstreamAllowResponse() *extensions_ssh.AllowResponse { + var allowedMethods []*extensions_ssh.AllowedMethod + if value := sh.state.PublicKeyAllow.Value; value != nil { + allowedMethods = append(allowedMethods, &extensions_ssh.AllowedMethod{ + Method: MethodPublicKey, + MethodData: protoutil.NewAny(value), + }) + } + if value := sh.state.KeyboardInteractiveAllow.Value; value != nil { + allowedMethods = append(allowedMethods, &extensions_ssh.AllowedMethod{ + Method: MethodKeyboardInteractive, + MethodData: protoutil.NewAny(value), + }) + } + return &extensions_ssh.AllowResponse{ + Username: *sh.state.Username, + Target: &extensions_ssh.AllowResponse_Upstream{ + Upstream: &extensions_ssh.UpstreamTarget{ + Hostname: *sh.state.Hostname, + DirectTcpip: sh.state.DirectTcpip, + AllowedMethods: allowedMethods, + }, + }, + } +} + +func (sh *StreamHandler) buildInternalAllowResponse() *extensions_ssh.AllowResponse { + return &extensions_ssh.AllowResponse{ + Username: *sh.state.Username, + Target: &extensions_ssh.AllowResponse_Internal{ + Internal: &extensions_ssh.InternalTarget{ + SetMetadata: &corev3.Metadata{ + TypedFilterMetadata: map[string]*anypb.Any{ + "com.pomerium.ssh": protoutil.NewAny(&extensions_ssh.FilterMetadata{ + StreamId: sh.downstream.StreamId, + }), + }, + }, + }, + }, + } +} diff --git a/pkg/ssh/stream_test.go b/pkg/ssh/stream_test.go new file mode 100644 index 000000000..2046bcc21 --- /dev/null +++ b/pkg/ssh/stream_test.go @@ -0,0 +1,2073 @@ +package ssh_test + +import ( + "bytes" + "context" + "crypto/ed25519" + "crypto/rand" + "encoding/binary" + "errors" + "fmt" + "io" + "iter" + "os" + "runtime" + "slices" + "strings" + "sync" + "testing" + "time" + + "github.com/charmbracelet/x/ansi" + "github.com/stretchr/testify/suite" + . "go.uber.org/mock/gomock" //nolint + gossh "golang.org/x/crypto/ssh" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/types/known/anypb" + "google.golang.org/protobuf/types/known/wrapperspb" + + extensions_ssh "github.com/pomerium/envoy-custom/api/extensions/filters/network/ssh" + "github.com/pomerium/pomerium/config" + "github.com/pomerium/pomerium/internal/testutil" + "github.com/pomerium/pomerium/pkg/ssh" + mock_ssh "github.com/pomerium/pomerium/pkg/ssh/mock" +) + +var DefaultTimeout = 10 * time.Second + +func init() { + if isDebuggerAttached() { + DefaultTimeout = 1 * time.Hour + } +} + +func isDebuggerAttached() bool { + if runtime.GOOS == "linux" { + data, err := os.ReadFile("/proc/self/status") + if err == nil { + for line := range bytes.Lines(data) { + if bytes.HasPrefix(line, []byte("TracerPid:\t")) { + return line[11] != '0' + } + } + } + } + return false +} + +func HookWithArgs(f func(s *StreamHandlerSuite, args []any) any, args ...any) []func(s *StreamHandlerSuite) any { + return []func(s *StreamHandlerSuite) any{ + func(s *StreamHandlerSuite) any { + return f(s, args) + }, + } +} + +var ( + StreamHandlerSuiteBeforeTestHooks = map[string][]func(s *StreamHandlerSuite) any{} + StreamHandlerSuiteAfterTestHooks = map[string][]func(s *StreamHandlerSuite) any{} +) + +type StreamHandlerSuiteOptions struct { + ConfigModifiers []func(*config.Config) +} + +type StreamHandlerSuite struct { + suite.Suite + StreamHandlerSuiteOptions + + ctrl *Controller + + mgr *ssh.StreamManager + cfg *config.Config + + cleanup []func() + errC chan error + + mockAuth *mock_ssh.MockAuthInterface + + ed25519PublicKey ed25519.PublicKey + ed25519PrivateKey ed25519.PrivateKey + ed25519SshPublicKey gossh.PublicKey + ed25519SshPrivateKey gossh.Signer + + BeforeTestHookResult any +} + +func (s *StreamHandlerSuite) SetupTest() { + s.ctrl = NewController(s.T()) + s.mockAuth = mock_ssh.NewMockAuthInterface(s.ctrl) + s.mgr = ssh.NewStreamManager(s.mockAuth) + s.cleanup = []func(){} + s.errC = make(chan error, 1) + + var err error + s.ed25519PublicKey, s.ed25519PrivateKey, err = ed25519.GenerateKey(rand.Reader) + s.Require().NoError(err) + s.ed25519SshPublicKey, err = gossh.NewPublicKey(s.ed25519PublicKey) + s.Require().NoError(err) + s.ed25519SshPrivateKey, err = gossh.NewSignerFromKey(s.ed25519PrivateKey) + s.Require().NoError(err) + + s.cfg = &config.Config{Options: config.NewDefaultOptions()} + s.cfg.Options.Policies = []config.Policy{ + {From: "https://from.notssh.example.com", To: mustParseWeightedURLs(s.T(), "https://to.notssh.example.com")}, + {From: "ssh://host1", To: mustParseWeightedURLs(s.T(), "ssh://dest1:22")}, + {From: "https://from1.notssh.example.com", To: mustParseWeightedURLs(s.T(), "https://to1.notssh.example.com")}, + {From: "ssh://host2", To: mustParseWeightedURLs(s.T(), "ssh://dest2:22")}, + {From: "https://from2.notssh.example.com", To: mustParseWeightedURLs(s.T(), "https://to2.notssh.example.com")}, + } + for _, f := range s.ConfigModifiers { + f(s.cfg) + } +} + +func (s *StreamHandlerSuite) TearDownTest() { + for _, f := range s.cleanup { + f() + } + s.ctrl.Finish() +} + +func (s *StreamHandlerSuite) BeforeTest(_, testName string) { + s.BeforeTestHookResult = nil + for _, fn := range StreamHandlerSuiteBeforeTestHooks[testName] { + s.BeforeTestHookResult = fn(s) + } +} + +// +// Helper methods +// + +func marshalAny(msg proto.Message) *anypb.Any { + a, err := anypb.New(msg) + if err != nil { + panic(err) + } + return a +} + +func (s *StreamHandlerSuite) expectError(fn func(), msg string) { + fn() + select { + case err := <-s.errC: + s.ErrorContains(err, msg) + case <-time.After(DefaultTimeout): + s.FailNowf("timed out waiting for error %q", msg) + } +} + +func (s *StreamHandlerSuite) startStreamHandler(streamID uint64) *ssh.StreamHandler { + sh := s.mgr.NewStreamHandler(s.cfg, &extensions_ssh.DownstreamConnectEvent{StreamId: streamID}) + s.errC = make(chan error, 1) + ctx, ca := context.WithCancel(s.T().Context()) + go func() { + defer close(s.errC) + s.errC <- sh.Run(ctx) + }() + s.cleanup = append(s.cleanup, func() { + start := time.Now() + for len(sh.ReadC()) > 0 && time.Since(start) < 100*time.Millisecond { + runtime.Gosched() + } + if len(sh.ReadC()) > 0 { + s.Fail(fmt.Sprintf("read channel contains %d unhandled client messages", len(sh.ReadC()))) + } + ca() + var err error + select { + case err = <-s.errC: + case <-time.After(DefaultTimeout): + s.Fail("timed out waiting for stream handler to close") + } + + sh.Close() + if err != nil { + s.Require().ErrorIs(err, context.Canceled) + } + if len(sh.WriteC()) != 0 { + logs := []string{"write channel contains unhandled server messages:"} + i := 0 + for msg := range sh.WriteC() { + logs = append(logs, fmt.Sprintf("[%d]: %s", i, msg.String())) + i++ + } + s.Fail(strings.Join(logs, "\n")) + } + }) + return sh +} + +func (s *StreamHandlerSuite) msgDownstreamConnected(streamID uint64) *extensions_ssh.ClientMessage { + return &extensions_ssh.ClientMessage{ + Message: &extensions_ssh.ClientMessage_Event{ + Event: &extensions_ssh.StreamEvent{ + Event: &extensions_ssh.StreamEvent_DownstreamConnected{ + DownstreamConnected: &extensions_ssh.DownstreamConnectEvent{ + StreamId: streamID, + }, + }, + }, + }, + } +} + +func (s *StreamHandlerSuite) msgDownstreamDisconnected(reason string) *extensions_ssh.ClientMessage { + return &extensions_ssh.ClientMessage{ + Message: &extensions_ssh.ClientMessage_Event{ + Event: &extensions_ssh.StreamEvent{ + Event: &extensions_ssh.StreamEvent_DownstreamDisconnected{ + DownstreamDisconnected: &extensions_ssh.DownstreamDisconnectedEvent{ + Reason: reason, + }, + }, + }, + }, + } +} + +func (s *StreamHandlerSuite) msgUpstreamConnected(streamID uint64) *extensions_ssh.ClientMessage { + return &extensions_ssh.ClientMessage{ + Message: &extensions_ssh.ClientMessage_Event{ + Event: &extensions_ssh.StreamEvent{ + Event: &extensions_ssh.StreamEvent_UpstreamConnected{ + UpstreamConnected: &extensions_ssh.UpstreamConnectEvent{ + StreamId: streamID, + }, + }, + }, + }, + } +} + +func (s *StreamHandlerSuite) expectAllowUpstream(sh *ssh.StreamHandler, hostname string) { + select { + case msg := <-sh.WriteC(): + if authResp := msg.GetAuthResponse(); authResp != nil { + if allow := authResp.GetAllow(); allow != nil { + s.Require().NotNil(allow.GetUpstream(), "received an allow response, but not to an upstream target") + s.Require().Equal(hostname, allow.GetUpstream().GetHostname()) + } else { + s.FailNowf("received an auth response, but it was not an allow response", authResp.String()) + } + } else { + s.FailNow("received a message, but it was not an auth response", msg.String()) + } + case <-time.After(DefaultTimeout): + s.FailNow("timed out waiting for upstream allow message") + } +} + +func (s *StreamHandlerSuite) expectDeny(sh *ssh.StreamHandler, partial bool, methods []string) { + select { + case msg := <-sh.WriteC(): + if authResp := msg.GetAuthResponse(); authResp != nil { + if deny := authResp.GetDeny(); deny != nil { + s.Require().Equal(partial, deny.Partial) + s.Require().Equal(methods, deny.Methods) + } else { + s.Require().Fail("received an auth response, but it was not a deny response", authResp.String()) + } + } else { + s.FailNow("received a message, but it was not an auth response", msg.String()) + } + case <-time.After(DefaultTimeout): + s.FailNow("timed out waiting for deny message") + } +} + +func (s *StreamHandlerSuite) expectAllowInternal(sh *ssh.StreamHandler) { + select { + case msg := <-sh.WriteC(): + if authResp := msg.GetAuthResponse(); authResp != nil { + if allow := authResp.GetAllow(); allow != nil { + s.Require().NotNil(allow.GetInternal(), "received an allow response, but not to an internal target") + } else { + s.FailNow("received an auth response, but it was not an allow response", authResp.String()) + } + } else { + s.FailNow("received a message, but it was not an auth response", msg.String()) + } + case <-time.After(DefaultTimeout): + s.FailNow("timed out waiting for internal allow message") + } +} + +func (s *StreamHandlerSuite) expectPrompt(sh *ssh.StreamHandler) { + select { + case msg := <-sh.WriteC(): + if authResp := msg.GetAuthResponse(); authResp != nil { + if info := authResp.GetInfoRequest(); info != nil { + s.Require().NotNil(info.GetRequest(), "received a nil info request") + } else { + s.FailNow("received an auth response, but it was not an info request", authResp.String()) + } + } else { + s.FailNow("received a message, but it was not an auth response", msg.String()) + } + case <-time.After(DefaultTimeout): + s.FailNow("timed out waiting for prompt message") + } +} + +func (s *StreamHandlerSuite) validPublicKeyMethodRequest() *anypb.Any { + return marshalAny(&extensions_ssh.PublicKeyMethodRequest{ + PublicKey: s.ed25519SshPublicKey.Marshal(), + PublicKeyAlg: s.ed25519SshPublicKey.Type(), + PublicKeyFingerprintSha256: []byte(gossh.FingerprintSHA256(s.ed25519SshPublicKey)), + }) +} + +// +// Tests +// + +func (s *StreamHandlerSuite) TestDuplicateDownstreamConnectedEvent() { + sh := s.startStreamHandler(1) + s.expectError(func() { + sh.ReadC() <- s.msgDownstreamConnected(1) + }, "received duplicate downstream connected event") +} + +func (s *StreamHandlerSuite) TestDownstreamDisconnectedEvent() { + sh := s.startStreamHandler(1) + sh.ReadC() <- s.msgDownstreamDisconnected("") // this just logs a message +} + +func (s *StreamHandlerSuite) TestUpstreamConnectedEvent() { + sh := s.startStreamHandler(1) + sh.ReadC() <- s.msgUpstreamConnected(1) // this just logs a message +} + +func (s *StreamHandlerSuite) TestInvalidEvent() { + sh := s.startStreamHandler(1) + s.expectError(func() { + sh.ReadC() <- &extensions_ssh.ClientMessage{ + Message: &extensions_ssh.ClientMessage_Event{ + Event: &extensions_ssh.StreamEvent{Event: nil}, + }, + } + }, "received invalid event") +} + +func (s *StreamHandlerSuite) TestHandleAuthRequest_InvalidProtocol() { + sh := s.startStreamHandler(1) + s.expectError(func() { + sh.ReadC() <- &extensions_ssh.ClientMessage{ + Message: &extensions_ssh.ClientMessage_AuthRequest{ + AuthRequest: &extensions_ssh.AuthenticationRequest{ + Protocol: "not-ssh", + Service: "ssh-connection", + }, + }, + } + }, "invalid protocol: not-ssh") +} + +func (s *StreamHandlerSuite) TestHandleAuthRequest_InvalidService() { + sh := s.startStreamHandler(1) + s.expectError(func() { + sh.ReadC() <- &extensions_ssh.ClientMessage{ + Message: &extensions_ssh.ClientMessage_AuthRequest{ + AuthRequest: &extensions_ssh.AuthenticationRequest{ + Protocol: "ssh", + Service: "ssh-userauth", + }, + }, + } + }, "invalid service: ssh-userauth") +} + +func (s *StreamHandlerSuite) TestHandleAuthRequest_InvalidMessage() { + sh := s.startStreamHandler(1) + s.expectError(func() { + sh.ReadC() <- &extensions_ssh.ClientMessage{ + Message: nil, + } + }, "received invalid message") +} + +func (s *StreamHandlerSuite) TestHandleAuthRequest_FirstRequestIsKeyboardInteractive() { + sh := s.startStreamHandler(1) + s.expectError(func() { + // first request should be publickey + sh.ReadC() <- &extensions_ssh.ClientMessage{ + Message: &extensions_ssh.ClientMessage_AuthRequest{ + AuthRequest: &extensions_ssh.AuthenticationRequest{ + Protocol: "ssh", + Service: "ssh-connection", + AuthMethod: "keyboard-interactive", + }, + }, + } + }, "unexpected auth method: keyboard-interactive") +} + +func (s *StreamHandlerSuite) TestHandleAuthRequest_MissingUsername() { + sh := s.startStreamHandler(1) + s.expectError(func() { + sh.ReadC() <- &extensions_ssh.ClientMessage{ + Message: &extensions_ssh.ClientMessage_AuthRequest{ + AuthRequest: &extensions_ssh.AuthenticationRequest{ + Protocol: "ssh", + Service: "ssh-connection", + AuthMethod: "publickey", + Username: "", + }, + }, + } + }, "username missing") +} + +func (s *StreamHandlerSuite) TestHandleAuthRequest_EmptyHostname() { + sh := s.startStreamHandler(1) + + // empty hostname is allowed initially + s.mockAuth.EXPECT(). + HandlePublicKeyMethodRequest(Any(), Any(), Any()). + Return(ssh.PublicKeyAuthMethodResponse{Allow: &extensions_ssh.PublicKeyAllowResponse{ + PublicKey: s.ed25519SshPublicKey.Marshal(), + Permissions: &extensions_ssh.Permissions{}, + }}, nil) + + sh.ReadC() <- &extensions_ssh.ClientMessage{ + Message: &extensions_ssh.ClientMessage_AuthRequest{ + AuthRequest: &extensions_ssh.AuthenticationRequest{ + Protocol: "ssh", + Service: "ssh-connection", + AuthMethod: "publickey", + Username: "test", + Hostname: "", + MethodRequest: s.validPublicKeyMethodRequest(), + }, + }, + } + + s.expectAllowInternal(sh) +} + +func (s *StreamHandlerSuite) TestHandleAuthRequest_MismatchedAuthMethodAndRequestType() { + sh := s.startStreamHandler(1) + + s.expectError(func() { + sh.ReadC() <- &extensions_ssh.ClientMessage{ + Message: &extensions_ssh.ClientMessage_AuthRequest{ + AuthRequest: &extensions_ssh.AuthenticationRequest{ + Protocol: "ssh", + Service: "ssh-connection", + AuthMethod: "publickey", + Username: "test", + Hostname: "", + MethodRequest: marshalAny(&extensions_ssh.KeyboardInteractiveMethodRequest{}), + }, + }, + } + }, "invalid public key method request type") +} + +func (s *StreamHandlerSuite) TestHandleAuthRequest_ValidPublicKeyMethodRequest() { + sh := s.startStreamHandler(1) + + s.mockAuth.EXPECT(). + HandlePublicKeyMethodRequest(Any(), Any(), Any()). + Return(ssh.PublicKeyAuthMethodResponse{Allow: &extensions_ssh.PublicKeyAllowResponse{ + PublicKey: s.ed25519SshPublicKey.Marshal(), + Permissions: &extensions_ssh.Permissions{}, + }}, nil) + + sh.ReadC() <- &extensions_ssh.ClientMessage{ + Message: &extensions_ssh.ClientMessage_AuthRequest{ + AuthRequest: &extensions_ssh.AuthenticationRequest{ + Protocol: "ssh", + Service: "ssh-connection", + AuthMethod: "publickey", + Username: "test", + Hostname: "host1", + MethodRequest: s.validPublicKeyMethodRequest(), + }, + }, + } + + s.expectAllowUpstream(sh, "host1") +} + +func (s *StreamHandlerSuite) TestHandleAuthRequest_ValidPublicKeyMethodRequestError() { + sh := s.startStreamHandler(1) + + s.mockAuth.EXPECT(). + HandlePublicKeyMethodRequest(Any(), Any(), Any()). + Return(ssh.PublicKeyAuthMethodResponse{}, errors.New("test error")) + + s.expectError(func() { + sh.ReadC() <- &extensions_ssh.ClientMessage{ + Message: &extensions_ssh.ClientMessage_AuthRequest{ + AuthRequest: &extensions_ssh.AuthenticationRequest{ + Protocol: "ssh", + Service: "ssh-connection", + AuthMethod: "publickey", + Username: "test", + Hostname: "host1", + MethodRequest: s.validPublicKeyMethodRequest(), + }, + }, + } + }, "test error") +} + +func (s *StreamHandlerSuite) TestHandleAuthRequest_PublicKeyRetry() { + sh := s.startStreamHandler(1) + + i := -1 + s.mockAuth.EXPECT(). + HandlePublicKeyMethodRequest(Any(), Any(), Any()). + MaxTimes(4). + DoAndReturn(func(_ context.Context, _ ssh.StreamAuthInfo, _ *extensions_ssh.PublicKeyMethodRequest) (ssh.PublicKeyAuthMethodResponse, error) { + i++ + switch i { + case 0, 1, 2: + return ssh.PublicKeyAuthMethodResponse{ + RequireAdditionalMethods: []string{"publickey"}, + }, nil + case 3: + return ssh.PublicKeyAuthMethodResponse{Allow: &extensions_ssh.PublicKeyAllowResponse{ + PublicKey: s.ed25519SshPublicKey.Marshal(), + Permissions: &extensions_ssh.Permissions{}, + }}, nil + default: + panic("unreachable") + } + }) + + for i := range 4 { + sh.ReadC() <- &extensions_ssh.ClientMessage{ + Message: &extensions_ssh.ClientMessage_AuthRequest{ + AuthRequest: &extensions_ssh.AuthenticationRequest{ + Protocol: "ssh", + Service: "ssh-connection", + AuthMethod: "publickey", + Username: "test", + Hostname: "host1", + MethodRequest: s.validPublicKeyMethodRequest(), + }, + }, + } + if i < 3 { + s.expectDeny(sh, false, []string{"publickey"}) + } else { + s.expectAllowUpstream(sh, "host1") + } + } +} + +func (s *StreamHandlerSuite) TestHandleAuthRequest_InconsistentUsername() { + sh := s.startStreamHandler(1) + + s.mockAuth.EXPECT(). + HandlePublicKeyMethodRequest(Any(), Any(), Any()). + Times(1). + DoAndReturn(func(_ context.Context, _ ssh.StreamAuthInfo, _ *extensions_ssh.PublicKeyMethodRequest) (ssh.PublicKeyAuthMethodResponse, error) { + return ssh.PublicKeyAuthMethodResponse{ + RequireAdditionalMethods: []string{"publickey"}, + }, nil + }) + sh.ReadC() <- &extensions_ssh.ClientMessage{ + Message: &extensions_ssh.ClientMessage_AuthRequest{ + AuthRequest: &extensions_ssh.AuthenticationRequest{ + Protocol: "ssh", + Service: "ssh-connection", + AuthMethod: "publickey", + Username: "test", + Hostname: "host1", + MethodRequest: s.validPublicKeyMethodRequest(), + }, + }, + } + s.expectDeny(sh, false, []string{"publickey"}) + s.Equal("test", *sh.Username()) + s.Equal("host1", *sh.Hostname()) + s.expectError(func() { + sh.ReadC() <- &extensions_ssh.ClientMessage{ + Message: &extensions_ssh.ClientMessage_AuthRequest{ + AuthRequest: &extensions_ssh.AuthenticationRequest{ + Protocol: "ssh", + Service: "ssh-connection", + AuthMethod: "publickey", + Username: "test2", + Hostname: "host1", + MethodRequest: s.validPublicKeyMethodRequest(), + }, + }, + } + }, "inconsistent username") +} + +func (s *StreamHandlerSuite) TestHandleAuthRequest_InconsistentHostname() { + sh := s.startStreamHandler(1) + + s.mockAuth.EXPECT(). + HandlePublicKeyMethodRequest(Any(), Any(), Any()). + Times(1). + DoAndReturn(func(_ context.Context, _ ssh.StreamAuthInfo, _ *extensions_ssh.PublicKeyMethodRequest) (ssh.PublicKeyAuthMethodResponse, error) { + return ssh.PublicKeyAuthMethodResponse{ + RequireAdditionalMethods: []string{"publickey"}, + }, nil + }) + sh.ReadC() <- &extensions_ssh.ClientMessage{ + Message: &extensions_ssh.ClientMessage_AuthRequest{ + AuthRequest: &extensions_ssh.AuthenticationRequest{ + Protocol: "ssh", + Service: "ssh-connection", + AuthMethod: "publickey", + Username: "test", + Hostname: "host1", + MethodRequest: s.validPublicKeyMethodRequest(), + }, + }, + } + s.expectDeny(sh, false, []string{"publickey"}) + s.expectError(func() { + sh.ReadC() <- &extensions_ssh.ClientMessage{ + Message: &extensions_ssh.ClientMessage_AuthRequest{ + AuthRequest: &extensions_ssh.AuthenticationRequest{ + Protocol: "ssh", + Service: "ssh-connection", + AuthMethod: "publickey", + Username: "test", + Hostname: "host2", + MethodRequest: s.validPublicKeyMethodRequest(), + }, + }, + } + }, "inconsistent hostname") +} + +func (s *StreamHandlerSuite) TestHandleAuthRequest_InconsistentEmptyHostname() { + sh := s.startStreamHandler(1) + + s.mockAuth.EXPECT(). + HandlePublicKeyMethodRequest(Any(), Any(), Any()). + Times(1). + DoAndReturn(func(_ context.Context, _ ssh.StreamAuthInfo, req *extensions_ssh.PublicKeyMethodRequest) (ssh.PublicKeyAuthMethodResponse, error) { + return ssh.PublicKeyAuthMethodResponse{ + Allow: &extensions_ssh.PublicKeyAllowResponse{ + PublicKey: req.PublicKey, + Permissions: &extensions_ssh.Permissions{}, + }, + RequireAdditionalMethods: []string{"keyboard-interactive"}, + }, nil + }) + sh.ReadC() <- &extensions_ssh.ClientMessage{ + Message: &extensions_ssh.ClientMessage_AuthRequest{ + AuthRequest: &extensions_ssh.AuthenticationRequest{ + Protocol: "ssh", + Service: "ssh-connection", + AuthMethod: "publickey", + Username: "test", + Hostname: "", + MethodRequest: s.validPublicKeyMethodRequest(), + }, + }, + } + s.expectDeny(sh, true, []string{"keyboard-interactive"}) + s.expectError(func() { + sh.ReadC() <- &extensions_ssh.ClientMessage{ + Message: &extensions_ssh.ClientMessage_AuthRequest{ + AuthRequest: &extensions_ssh.AuthenticationRequest{ + Protocol: "ssh", + Service: "ssh-connection", + AuthMethod: "keyboard-interactive", + Username: "test", + Hostname: "host1", + MethodRequest: marshalAny(&extensions_ssh.KeyboardInteractiveMethodRequest{}), + }, + }, + } + }, "inconsistent hostname") +} + +func (s *StreamHandlerSuite) TestHandleAuthRequest_UnknownAuthMethod() { + sh := s.startStreamHandler(1) + s.expectError(func() { + sh.ReadC() <- &extensions_ssh.ClientMessage{ + Message: &extensions_ssh.ClientMessage_AuthRequest{ + AuthRequest: &extensions_ssh.AuthenticationRequest{ + Protocol: "ssh", + Service: "ssh-connection", + AuthMethod: "password", + Username: "test", + Hostname: "host1", + }, + }, + } + }, "unexpected auth method: password") +} + +func (s *StreamHandlerSuite) TestHandleAuthRequest_UnimplementedAuthMethod() { + sh := s.startStreamHandler(1) + s.mockAuth.EXPECT(). + HandlePublicKeyMethodRequest(Any(), Any(), Any()). + Times(1). + DoAndReturn(func(_ context.Context, _ ssh.StreamAuthInfo, _ *extensions_ssh.PublicKeyMethodRequest) (ssh.PublicKeyAuthMethodResponse, error) { + return ssh.PublicKeyAuthMethodResponse{ + RequireAdditionalMethods: []string{"password"}, + }, nil + }) + sh.ReadC() <- &extensions_ssh.ClientMessage{ + Message: &extensions_ssh.ClientMessage_AuthRequest{ + AuthRequest: &extensions_ssh.AuthenticationRequest{ + Protocol: "ssh", + Service: "ssh-connection", + AuthMethod: "publickey", + Username: "test", + Hostname: "host1", + MethodRequest: s.validPublicKeyMethodRequest(), + }, + }, + } + s.expectDeny(sh, false, []string{"password"}) + s.expectError(func() { + sh.ReadC() <- &extensions_ssh.ClientMessage{ + Message: &extensions_ssh.ClientMessage_AuthRequest{ + AuthRequest: &extensions_ssh.AuthenticationRequest{ + Protocol: "ssh", + Service: "ssh-connection", + AuthMethod: "password", + Username: "test", + Hostname: "host1", + }, + }, + } + }, "bug: server requested an unsupported auth method \"password\"") +} + +func (s *StreamHandlerSuite) TestHandleAuthRequest_WrongClientMessage() { + sh := s.startStreamHandler(1) + s.mockAuth.EXPECT(). + HandlePublicKeyMethodRequest(Any(), Any(), Any()). + Times(1). + DoAndReturn(func(_ context.Context, _ ssh.StreamAuthInfo, req *extensions_ssh.PublicKeyMethodRequest) (ssh.PublicKeyAuthMethodResponse, error) { + return ssh.PublicKeyAuthMethodResponse{ + Allow: &extensions_ssh.PublicKeyAllowResponse{ + PublicKey: req.PublicKey, + Permissions: &extensions_ssh.Permissions{}, + }, + RequireAdditionalMethods: []string{"keyboard-interactive"}, + }, nil + }) + newMsg := func() *extensions_ssh.ClientMessage_AuthRequest { + return &extensions_ssh.ClientMessage_AuthRequest{ + AuthRequest: &extensions_ssh.AuthenticationRequest{ + Protocol: "ssh", + Service: "ssh-connection", + AuthMethod: "publickey", + Username: "test", + Hostname: "host1", + MethodRequest: s.validPublicKeyMethodRequest(), + }, + } + } + sh.ReadC() <- &extensions_ssh.ClientMessage{ + Message: newMsg(), + } + s.expectDeny(sh, true, []string{"keyboard-interactive"}) + s.expectError(func() { + sh.ReadC() <- &extensions_ssh.ClientMessage{ + Message: newMsg(), + } + }, "unexpected auth method: publickey") +} + +func (s *StreamHandlerSuite) TestHandleAuthRequest_KeyboardInteractive_WrongMethodRequestType() { + sh := s.startStreamHandler(1) + + s.mockAuth.EXPECT(). + HandlePublicKeyMethodRequest(Any(), Any(), Any()). + Times(1). + DoAndReturn(func(_ context.Context, _ ssh.StreamAuthInfo, req *extensions_ssh.PublicKeyMethodRequest) (ssh.PublicKeyAuthMethodResponse, error) { + return ssh.PublicKeyAuthMethodResponse{ + Allow: &extensions_ssh.PublicKeyAllowResponse{ + PublicKey: req.PublicKey, + Permissions: &extensions_ssh.Permissions{}, + }, + RequireAdditionalMethods: []string{"keyboard-interactive"}, + }, nil + }) + sh.ReadC() <- &extensions_ssh.ClientMessage{ + Message: &extensions_ssh.ClientMessage_AuthRequest{ + AuthRequest: &extensions_ssh.AuthenticationRequest{ + Protocol: "ssh", + Service: "ssh-connection", + AuthMethod: "publickey", + Username: "test", + Hostname: "host1", + MethodRequest: s.validPublicKeyMethodRequest(), + }, + }, + } + s.expectDeny(sh, true, []string{"keyboard-interactive"}) + s.expectError(func() { + sh.ReadC() <- &extensions_ssh.ClientMessage{ + Message: &extensions_ssh.ClientMessage_AuthRequest{ + AuthRequest: &extensions_ssh.AuthenticationRequest{ + Protocol: "ssh", + Service: "ssh-connection", + AuthMethod: "keyboard-interactive", + Username: "test", + Hostname: "host1", + MethodRequest: s.validPublicKeyMethodRequest(), + }, + }, + } + }, "invalid keyboard-interactive method request type") +} + +func init() { + setupKeyboardInteractive := func(s *StreamHandlerSuite, input []any) any { + querierErr, _ := input[0].(error) + sh := s.startStreamHandler(100) + + i := -1 + s.mockAuth.EXPECT(). + HandlePublicKeyMethodRequest(Any(), Any(), Any()). + Times(2). + DoAndReturn(func(_ context.Context, _ ssh.StreamAuthInfo, _ *extensions_ssh.PublicKeyMethodRequest) (ssh.PublicKeyAuthMethodResponse, error) { + i++ + switch i { + case 0: + return ssh.PublicKeyAuthMethodResponse{ + RequireAdditionalMethods: []string{"publickey"}, + }, nil + case 1: + return ssh.PublicKeyAuthMethodResponse{ + Allow: &extensions_ssh.PublicKeyAllowResponse{ + PublicKey: s.ed25519SshPublicKey.Marshal(), + Permissions: &extensions_ssh.Permissions{}, + }, + RequireAdditionalMethods: []string{"keyboard-interactive"}, + }, nil + default: + panic("unreachable") + } + }) + s.mockAuth.EXPECT(). + HandleKeyboardInteractiveMethodRequest(Any(), Any(), Any(), Any()). + DoAndReturn(func( + ctx context.Context, + info ssh.StreamAuthInfo, + _ *extensions_ssh.KeyboardInteractiveMethodRequest, + querier ssh.KeyboardInteractiveQuerier, + ) (ssh.KeyboardInteractiveAuthMethodResponse, error) { + s.Equal("test", *info.Username) + s.Equal("host1", *info.Hostname) + s.Equal(uint64(100), info.StreamID) + resp, err := querier.Prompt(ctx, &extensions_ssh.KeyboardInteractiveInfoPrompts{ + Name: "test-name", + Instruction: "test-instruction", + Prompts: []*extensions_ssh.KeyboardInteractiveInfoPrompts_Prompt{ + { + Prompt: "test-prompt", + Echo: true, + }, + }, + }) + s.Require().Equal(querierErr, err, "unexpected error from querier.Prompt") + if querierErr == nil { + s.Equal([]string{"test-prompt-response"}, resp.Responses) + return ssh.KeyboardInteractiveAuthMethodResponse{ + Allow: &extensions_ssh.KeyboardInteractiveAllowResponse{}, + }, nil + } + return ssh.KeyboardInteractiveAuthMethodResponse{}, err + }) + for range 2 { + sh.ReadC() <- &extensions_ssh.ClientMessage{ + Message: &extensions_ssh.ClientMessage_AuthRequest{ + AuthRequest: &extensions_ssh.AuthenticationRequest{ + Protocol: "ssh", + Service: "ssh-connection", + AuthMethod: "publickey", + Username: "test", + Hostname: "host1", + MethodRequest: s.validPublicKeyMethodRequest(), + }, + }, + } + } + s.expectDeny(sh, false, []string{"publickey"}) + s.expectDeny(sh, true, []string{"keyboard-interactive"}) + sh.ReadC() <- &extensions_ssh.ClientMessage{ + Message: &extensions_ssh.ClientMessage_AuthRequest{ + AuthRequest: &extensions_ssh.AuthenticationRequest{ + Protocol: "ssh", + Service: "ssh-connection", + AuthMethod: "keyboard-interactive", + Username: "test", + Hostname: "host1", + MethodRequest: marshalAny(&extensions_ssh.KeyboardInteractiveMethodRequest{}), + }, + }, + } + + return sh + } + StreamHandlerSuiteBeforeTestHooks["TestHandleAuthRequest_KeyboardInteractive"] = HookWithArgs(setupKeyboardInteractive, (error)(nil)) + StreamHandlerSuiteBeforeTestHooks["TestHandleAuthRequest_KeyboardInteractive_NoPromptReply"] = HookWithArgs(setupKeyboardInteractive, context.Canceled) + StreamHandlerSuiteBeforeTestHooks["TestHandleAuthRequest_KeyboardInteractive_InvalidInfoResponse"] = HookWithArgs(setupKeyboardInteractive, status.Errorf(codes.Internal, "received invalid info response")) + StreamHandlerSuiteBeforeTestHooks["TestHandleAuthRequest_KeyboardInteractive_InvalidPromptResponse"] = HookWithArgs(setupKeyboardInteractive, status.Errorf(codes.InvalidArgument, "received invalid prompt response")) + StreamHandlerSuiteBeforeTestHooks["TestHandleAuthRequest_KeyboardInteractive_WrongResponseMessageType"] = HookWithArgs(setupKeyboardInteractive, status.Errorf(codes.InvalidArgument, "received invalid message, expecting info response")) +} + +func (s *StreamHandlerSuite) TestHandleAuthRequest_KeyboardInteractive() { + sh := s.BeforeTestHookResult.(*ssh.StreamHandler) + + s.expectPrompt(sh) + sh.ReadC() <- &extensions_ssh.ClientMessage{ + Message: &extensions_ssh.ClientMessage_InfoResponse{ + InfoResponse: &extensions_ssh.InfoResponse{ + Method: "keyboard-interactive", + Response: marshalAny(&extensions_ssh.KeyboardInteractiveInfoPromptResponses{ + Responses: []string{"test-prompt-response"}, + }), + }, + }, + } + s.expectAllowUpstream(sh, "host1") +} + +func (s *StreamHandlerSuite) TestHandleAuthRequest_KeyboardInteractive_NoPromptReply() { + sh := s.BeforeTestHookResult.(*ssh.StreamHandler) + s.expectPrompt(sh) +} + +func (s *StreamHandlerSuite) TestHandleAuthRequest_KeyboardInteractive_InvalidInfoResponse() { + sh := s.BeforeTestHookResult.(*ssh.StreamHandler) + s.expectPrompt(sh) + s.expectError(func() { + sh.ReadC() <- &extensions_ssh.ClientMessage{ + Message: &extensions_ssh.ClientMessage_InfoResponse{ + InfoResponse: &extensions_ssh.InfoResponse{ + Method: "publickey", + Response: marshalAny(&extensions_ssh.KeyboardInteractiveInfoPromptResponses{ + Responses: []string{"test-prompt-response"}, + }), + }, + }, + } + }, "received invalid info response") +} + +func (s *StreamHandlerSuite) TestHandleAuthRequest_KeyboardInteractive_InvalidPromptResponse() { + sh := s.BeforeTestHookResult.(*ssh.StreamHandler) + s.expectPrompt(sh) + s.expectError(func() { + sh.ReadC() <- &extensions_ssh.ClientMessage{ + Message: &extensions_ssh.ClientMessage_InfoResponse{ + InfoResponse: &extensions_ssh.InfoResponse{ + Method: "keyboard-interactive", + Response: nil, + }, + }, + } + }, "received invalid prompt response") +} + +func (s *StreamHandlerSuite) TestHandleAuthRequest_KeyboardInteractive_WrongResponseMessageType() { + sh := s.BeforeTestHookResult.(*ssh.StreamHandler) + s.expectPrompt(sh) + s.expectError(func() { + sh.ReadC() <- &extensions_ssh.ClientMessage{ + Message: &extensions_ssh.ClientMessage_AuthRequest{ + AuthRequest: &extensions_ssh.AuthenticationRequest{ + Protocol: "ssh", + Service: "ssh-connection", + AuthMethod: "keyboard-interactive", + Username: "test", + Hostname: "host1", + MethodRequest: marshalAny(&extensions_ssh.KeyboardInteractiveMethodRequest{}), + }, + }, + } + }, "received invalid message, expecting info response") +} + +type mockGrpcServerStream struct { + grpc.ServerStream + ctx context.Context +} + +func (s *mockGrpcServerStream) Context() context.Context { + return s.ctx +} + +type mockChannelStream struct { + *grpc.GenericServerStream[extensions_ssh.ChannelMessage, extensions_ssh.ChannelMessage] + + closeServerToClientOnce sync.Once + serverToClient chan *extensions_ssh.ChannelMessage + closeClientToServerOnce sync.Once + clientToServer chan *extensions_ssh.ChannelMessage +} + +func newMockChannelStream(t *testing.T) *mockChannelStream { + cs := &mockChannelStream{ + GenericServerStream: &grpc.GenericServerStream[extensions_ssh.ChannelMessage, extensions_ssh.ChannelMessage]{ + ServerStream: &mockGrpcServerStream{ + ctx: t.Context(), + }, + }, + serverToClient: make(chan *extensions_ssh.ChannelMessage, 32), + clientToServer: make(chan *extensions_ssh.ChannelMessage, 32), + } + t.Cleanup(func() { + cs.CloseClientToServer() + cs.CloseServerToClient() + }) + return cs +} + +func (cs *mockChannelStream) Send(msg *extensions_ssh.ChannelMessage) error { + cs.serverToClient <- msg + return nil +} + +func (cs *mockChannelStream) Recv() (*extensions_ssh.ChannelMessage, error) { + msg, ok := <-cs.clientToServer + if !ok { + return nil, io.EOF + } + return msg, nil +} + +func (cs *mockChannelStream) SendClientToServer(msg *extensions_ssh.ChannelMessage) { + cs.clientToServer <- msg +} + +func (cs *mockChannelStream) CloseClientToServer() { + cs.closeClientToServerOnce.Do(func() { + close(cs.clientToServer) + }) +} + +func (cs *mockChannelStream) CloseServerToClient() { + cs.closeServerToClientOnce.Do(func() { + close(cs.serverToClient) + }) +} + +func (cs *mockChannelStream) RecvServerToClient() (*extensions_ssh.ChannelMessage, error) { + select { + case msg, ok := <-cs.serverToClient: + if !ok { + return nil, io.EOF + } + return msg, nil + case <-time.After(DefaultTimeout): + return nil, errors.New("timed out waiting for server to send message") + } +} + +var _ extensions_ssh.StreamManagement_ServeChannelServer = (*mockChannelStream)(nil) + +func channelMsg(input any) *extensions_ssh.ChannelMessage { + return &extensions_ssh.ChannelMessage{ + Message: &extensions_ssh.ChannelMessage_RawBytes{ + RawBytes: wrapperspb.Bytes(gossh.Marshal(input)), + }, + } +} + +func recvChannelMsg[T any](s *StreamHandlerSuite, stream *mockChannelStream) T { + response, err := stream.RecvServerToClient() + s.Require().NoError(err) + var msg T + s.Require().NoError(gossh.Unmarshal(response.GetRawBytes().GetValue(), &msg)) + return msg +} + +func sendChannelMsg(stream *mockChannelStream, msg any) { + stream.SendClientToServer(&extensions_ssh.ChannelMessage{ + Message: &extensions_ssh.ChannelMessage_RawBytes{ + RawBytes: &wrapperspb.BytesValue{ + Value: gossh.Marshal(msg), + }, + }, + }) +} + +func (s *StreamHandlerSuite) TestServeChannel_InitialRecvError() { + sh := s.startStreamHandler(1) + + stream := newMockChannelStream(s.T()) + stream.CloseClientToServer() + s.Error(io.EOF, sh.ServeChannel(stream)) +} + +func (s *StreamHandlerSuite) TestServeChannel_InitialRecvIsNotRawBytes() { + sh := s.startStreamHandler(1) + + stream := newMockChannelStream(s.T()) + stream.SendClientToServer(&extensions_ssh.ChannelMessage{ + Message: &extensions_ssh.ChannelMessage_Metadata{}, + }) + s.ErrorIs(status.Errorf(codes.InvalidArgument, "first channel message was not ChannelOpen"), sh.ServeChannel(stream)) +} + +func (s *StreamHandlerSuite) TestServeChannel_InitialRecvIsNotChannelOpen() { + sh := s.startStreamHandler(1) + + stream := newMockChannelStream(s.T()) + stream.SendClientToServer(&extensions_ssh.ChannelMessage{ + Message: &extensions_ssh.ChannelMessage_RawBytes{ + RawBytes: wrapperspb.Bytes([]byte("not ChannelOpen")), + }, + }) + s.ErrorIs(status.Errorf(codes.InvalidArgument, "first channel message was not ChannelOpen"), sh.ServeChannel(stream)) +} + +func init() { + hook := func(s *StreamHandlerSuite, args []any) any { + errorMatcher := args[0].(Matcher) + sh := s.startStreamHandler(1) + + s.mockAuth.EXPECT(). + HandlePublicKeyMethodRequest(Any(), Any(), Any()). + Times(1). + DoAndReturn(func(_ context.Context, info ssh.StreamAuthInfo, req *extensions_ssh.PublicKeyMethodRequest) (ssh.PublicKeyAuthMethodResponse, error) { + s.Equal("test", *info.Username) + s.Equal("", *info.Hostname) + return ssh.PublicKeyAuthMethodResponse{ + Allow: &extensions_ssh.PublicKeyAllowResponse{ + PublicKey: req.PublicKey, + Permissions: &extensions_ssh.Permissions{}, + }, + RequireAdditionalMethods: []string{}, + }, nil + }) + s.False(sh.IsExpectingInternalChannel()) + sh.ReadC() <- &extensions_ssh.ClientMessage{ + Message: &extensions_ssh.ClientMessage_AuthRequest{ + AuthRequest: &extensions_ssh.AuthenticationRequest{ + Protocol: "ssh", + Service: "ssh-connection", + AuthMethod: "publickey", + Username: "test", + Hostname: "", + MethodRequest: s.validPublicKeyMethodRequest(), + }, + }, + } + s.expectAllowInternal(sh) + s.True(sh.IsExpectingInternalChannel()) + s.Equal("test", *sh.Username()) + s.Equal("", *sh.Hostname()) + + stream := newMockChannelStream(s.T()) + errC := make(chan error, 1) + go func() { + errC <- sh.ServeChannel(stream) + stream.CloseServerToClient() + }() + s.cleanup = append(s.cleanup, func() { + stream.CloseClientToServer() + select { + case err := <-errC: + s.Truef(errorMatcher.Matches(err), "expected: %v\nactual: %v", errorMatcher.String(), err) + case <-time.After(DefaultTimeout): + s.FailNow("timed out waiting for ServeChannel to exit") + } + }) + return stream + } + StreamHandlerSuiteBeforeTestHooks["TestServeChannel_Session"] = HookWithArgs(hook, Eq(status.Errorf(codes.Canceled, "channel closed"))) + StreamHandlerSuiteBeforeTestHooks["TestServeChannel_Session_DifferentWindowAndPacketSizes"] = HookWithArgs(hook, Eq(status.Errorf(codes.Canceled, "channel closed"))) + StreamHandlerSuiteBeforeTestHooks["TestServeChannel_DirectTcpip_NoSubMsg"] = HookWithArgs(hook, Not(Nil())) + StreamHandlerSuiteBeforeTestHooks["TestServeChannel_DirectTcpip_BadHostname"] = HookWithArgs(hook, Not(Nil())) + StreamHandlerSuiteBeforeTestHooks["TestServeChannel_DirectTcpip_AuthFailed"] = HookWithArgs(hook, Eq(status.Errorf(codes.PermissionDenied, "test error"))) + StreamHandlerSuiteBeforeTestHooks["TestServeChannel_DirectTcpip"] = HookWithArgs(hook, Nil()) + StreamHandlerSuiteBeforeTestHooks["TestServeChannel_InvalidChannelType"] = HookWithArgs(hook, Eq(status.Errorf(codes.InvalidArgument, "unexpected channel type in ChannelOpen message: unknown"))) + StreamHandlerSuiteBeforeTestHooks["TestServeChannel_Session_ExecWithPtyHelp"] = HookWithArgs(hook, Eq(status.Errorf(codes.Canceled, "channel closed"))) + StreamHandlerSuiteBeforeTestHooks["TestServeChannel_Session_Exec_Whoami"] = HookWithArgs(hook, Eq(status.Errorf(codes.Canceled, "channel closed"))) + StreamHandlerSuiteBeforeTestHooks["TestServeChannel_Session_Exec_WhoamiError"] = HookWithArgs(hook, Eq(status.Errorf(codes.Canceled, "channel closed"))) + StreamHandlerSuiteBeforeTestHooks["TestServeChannel_Session_Exec_Logout"] = HookWithArgs(hook, Eq(status.Errorf(codes.Canceled, "channel closed"))) + StreamHandlerSuiteBeforeTestHooks["TestServeChannel_Session_Exec_LogoutError"] = HookWithArgs(hook, Eq(status.Errorf(codes.Canceled, "channel closed"))) + StreamHandlerSuiteBeforeTestHooks["TestServeChannel_Session_RoutesPortal_NonInteractiveError"] = HookWithArgs(hook, Eq(status.Errorf(codes.Canceled, "channel closed"))) + StreamHandlerSuiteBeforeTestHooks["TestServeChannel_Session_InteractiveError"] = HookWithArgs(hook, Eq(status.Errorf(codes.Canceled, "channel closed"))) + StreamHandlerSuiteBeforeTestHooks["TestServeChannel_Session_RoutesPortal"] = HookWithArgs(hook, Eq(status.Errorf(codes.Canceled, "channel closed"))) + StreamHandlerSuiteBeforeTestHooks["TestServeChannel_Session_RoutesPortal_Select"] = HookWithArgs(hook, Eq(status.Errorf(codes.Canceled, "channel closed"))) + StreamHandlerSuiteBeforeTestHooks["TestServeChannel_Session_ChannelCloseResponseTimeout"] = HookWithArgs(hook, Eq(status.Errorf(codes.DeadlineExceeded, "timed out waiting for channel close"))) +} + +func (s *StreamHandlerSuite) TestServeChannel_Session() { + stream := s.BeforeTestHookResult.(*mockChannelStream) + stream.SendClientToServer(channelMsg(ssh.ChannelOpenMsg{ + ChanType: "session", + PeersID: 2, + PeersWindow: ssh.ChannelWindowSize, + MaxPacketSize: ssh.ChannelMaxPacket, + })) + resp := recvChannelMsg[ssh.ChannelOpenConfirmMsg](s, stream) + s.Equal(uint32(ssh.ChannelMaxPacket), resp.MaxPacketSize) + s.Equal(uint32(ssh.ChannelWindowSize), resp.MyWindow) + s.Equal(uint32(2), resp.PeersID) + s.Equal(uint32(1), resp.MyID) + sendChannelMsg(stream, ssh.ChannelCloseMsg{resp.MyID}) // server id + recvChannelMsg[ssh.ChannelCloseMsg](s, stream) +} + +func (s *StreamHandlerSuite) TestServeChannel_Session_DifferentWindowAndPacketSizes() { + stream := s.BeforeTestHookResult.(*mockChannelStream) + stream.SendClientToServer(channelMsg(ssh.ChannelOpenMsg{ + ChanType: "session", + PeersID: 2, // client id + PeersWindow: ssh.ChannelWindowSize / 2, + MaxPacketSize: ssh.ChannelMaxPacket / 2, + })) + resp := recvChannelMsg[ssh.ChannelOpenConfirmMsg](s, stream) + s.Equal(uint32(ssh.ChannelMaxPacket), resp.MaxPacketSize) + s.Equal(uint32(ssh.ChannelWindowSize), resp.MyWindow) + s.Equal(uint32(2), resp.PeersID) // client id + s.Equal(uint32(1), resp.MyID) // server id + sendChannelMsg(stream, ssh.ChannelCloseMsg{resp.MyID}) // server id + recvChannelMsg[ssh.ChannelCloseMsg](s, stream) +} + +func (s *StreamHandlerSuite) channelDataLoop(peerID uint32, stream *mockChannelStream, exitCode ...uint32) *bytes.Buffer { + s.T().Helper() + var channelData bytes.Buffer + for { + response, err := stream.RecvServerToClient() + if errors.Is(err, io.EOF) { + break + } + s.Require().NoError(err) + bytes := response.GetRawBytes().GetValue() + switch bytes[0] { + case ssh.MsgChannelData: + var msg ssh.ChannelDataMsg + s.Require().NoError(gossh.Unmarshal(bytes, &msg)) + channelData.Write(msg.Rest) + case ssh.MsgChannelRequest: + var msg ssh.ChannelRequestMsg + s.Require().NoError(gossh.Unmarshal(bytes, &msg)) + s.Equal("exit-status", msg.Request) + s.Require().NotEmpty(exitCode, "received an exit-status ChannelRequest but the test did not assert an exit code") + expected := exitCode[0] + actual := binary.BigEndian.Uint32(msg.RequestSpecificData) + s.Equal(expected, actual) + case ssh.MsgChannelClose: + sendChannelMsg(stream, ssh.ChannelCloseMsg{PeersID: peerID}) + } + } + return &channelData +} + +func (s *StreamHandlerSuite) TestServeChannel_Session_ExecWithPtyHelp() { + stream := s.BeforeTestHookResult.(*mockChannelStream) + stream.SendClientToServer(channelMsg(ssh.ChannelOpenMsg{ + ChanType: "session", + PeersID: 2, + PeersWindow: ssh.ChannelWindowSize, + MaxPacketSize: ssh.ChannelMaxPacket, + })) + resp := recvChannelMsg[ssh.ChannelOpenConfirmMsg](s, stream) + peerID := resp.MyID + stream.SendClientToServer(channelMsg(ssh.ChannelRequestMsg{ + PeersID: peerID, + Request: "pty-req", + WantReply: true, + RequestSpecificData: gossh.Marshal(ssh.PtyReqChannelRequestMsg{}), + })) + recvChannelMsg[ssh.ChannelRequestSuccessMsg](s, stream) + stream.SendClientToServer(channelMsg(ssh.ChannelRequestMsg{ + PeersID: peerID, + Request: "exec", + WantReply: true, + RequestSpecificData: gossh.Marshal(ssh.ExecChannelRequestMsg{ + Command: "--help", + }), + })) + recvChannelMsg[ssh.ChannelRequestSuccessMsg](s, stream) + + maybeRoutesPortalCmd := "" + if s.cfg.Options.IsRuntimeFlagSet(config.RuntimeFlagSSHRoutesPortal) { + maybeRoutesPortalCmd = ` portal Interactive route portal +` + } + channelData := s.channelDataLoop(peerID, stream, 0) + s.Equal(` +Usage: + pomerium [command] + +Available Commands: + help Help about any command + logout Log out +`[1:]+maybeRoutesPortalCmd+ + ` whoami Show details for the current session + +Flags: + -h, --help help for pomerium + +Use "pomerium [command] --help" for more information about a command. +`, channelData.String()) +} + +func (s *StreamHandlerSuite) TestServeChannel_Session_ChannelCloseResponseTimeout() { + stream := s.BeforeTestHookResult.(*mockChannelStream) + stream.SendClientToServer(channelMsg(ssh.ChannelOpenMsg{ + ChanType: "session", + PeersID: 2, + PeersWindow: ssh.ChannelWindowSize, + MaxPacketSize: ssh.ChannelMaxPacket, + })) + resp := recvChannelMsg[ssh.ChannelOpenConfirmMsg](s, stream) + peerID := resp.MyID + stream.SendClientToServer(channelMsg(ssh.ChannelRequestMsg{ + PeersID: peerID, + Request: "pty-req", + WantReply: true, + RequestSpecificData: gossh.Marshal(ssh.PtyReqChannelRequestMsg{}), + })) + recvChannelMsg[ssh.ChannelRequestSuccessMsg](s, stream) + stream.SendClientToServer(channelMsg(ssh.ChannelRequestMsg{ + PeersID: peerID, + Request: "exec", + WantReply: true, + RequestSpecificData: gossh.Marshal(ssh.ExecChannelRequestMsg{ + Command: "--help", + }), + })) + recvChannelMsg[ssh.ChannelRequestSuccessMsg](s, stream) + for { + response, err := stream.RecvServerToClient() + if errors.Is(err, io.EOF) { + break + } + s.Require().NoError(err) + bytes := response.GetRawBytes().GetValue() + switch bytes[0] { + case ssh.MsgChannelData: + var msg ssh.ChannelDataMsg + s.Require().NoError(gossh.Unmarshal(bytes, &msg)) + case ssh.MsgChannelClose: + // don't send a response + } + } +} + +func (s *StreamHandlerSuite) TestServeChannel_Session_RoutesPortal_NonInteractiveError() { + stream := s.BeforeTestHookResult.(*mockChannelStream) + stream.SendClientToServer(channelMsg(ssh.ChannelOpenMsg{ + ChanType: "session", + PeersID: 2, + PeersWindow: ssh.ChannelWindowSize, + MaxPacketSize: ssh.ChannelMaxPacket, + })) + resp := recvChannelMsg[ssh.ChannelOpenConfirmMsg](s, stream) + peerID := resp.MyID + stream.SendClientToServer(channelMsg(ssh.ChannelRequestMsg{ + PeersID: peerID, + Request: "shell", + WantReply: true, + })) + recvChannelMsg[ssh.ChannelRequestSuccessMsg](s, stream) + + if s.cfg.Options.IsRuntimeFlagSet(config.RuntimeFlagSSHRoutesPortal) { + channelData := s.channelDataLoop(peerID, stream, 1) + s.Equal("Error: 'portal' is an interactive command and requires a TTY (try passing '-t' to ssh)\n", + ansi.Strip(channelData.String())) + } else { + channelData := s.channelDataLoop(peerID, stream, 0) + s.Equal(` +Usage: + pomerium [command] + +Available Commands: + help Help about any command + logout Log out + whoami Show details for the current session + +Flags: + -h, --help help for pomerium + +Use "pomerium [command] --help" for more information about a command. +`[1:], channelData.String()) + } +} + +func (s *StreamHandlerSuite) TestServeChannel_Session_InteractiveError() { + stream := s.BeforeTestHookResult.(*mockChannelStream) + stream.SendClientToServer(channelMsg(ssh.ChannelOpenMsg{ + ChanType: "session", + PeersID: 2, + PeersWindow: ssh.ChannelWindowSize, + MaxPacketSize: ssh.ChannelMaxPacket, + })) + resp := recvChannelMsg[ssh.ChannelOpenConfirmMsg](s, stream) + peerID := resp.MyID + stream.SendClientToServer(channelMsg(ssh.ChannelRequestMsg{ + PeersID: peerID, + Request: "pty-req", + WantReply: true, + RequestSpecificData: gossh.Marshal(ssh.PtyReqChannelRequestMsg{}), + })) + recvChannelMsg[ssh.ChannelRequestSuccessMsg](s, stream) + stream.SendClientToServer(channelMsg(ssh.ChannelRequestMsg{ + PeersID: peerID, + Request: "exec", + WantReply: true, + RequestSpecificData: gossh.Marshal(ssh.ExecChannelRequestMsg{ + Command: "whoami", + }), + })) + recvChannelMsg[ssh.ChannelRequestSuccessMsg](s, stream) + channelData := s.channelDataLoop(peerID, stream, 1) + s.Equal("Error: 'whoami' is not an interactive command (try passing '-T' to ssh, or removing '-t')\r\n", + ansi.Strip(channelData.String())) +} + +func printFrame(in string) string { + re := strings.NewReplacer(" ", "Ā·", "\t", "šŸ”’", "\n", "\n⤶", "\r", "⇤") + return re.Replace(ansi.Strip(in)) +} + +func postProcessFrame(in string) string { + return strings.ReplaceAll(ansi.Strip(in), "\r", "") +} + +type routesPortalTestHookOutput struct { + stream *mockChannelStream + peerID uint32 +} + +func init() { + hook := func(s *StreamHandlerSuite) any { + stream := s.BeforeTestHookResult.(*mockChannelStream) + stream.SendClientToServer(channelMsg(ssh.ChannelOpenMsg{ + ChanType: "session", + PeersID: 2, + PeersWindow: ssh.ChannelWindowSize, + MaxPacketSize: ssh.ChannelMaxPacket, + })) + resp := recvChannelMsg[ssh.ChannelOpenConfirmMsg](s, stream) + peerID := resp.MyID + stream.SendClientToServer(channelMsg(ssh.ChannelRequestMsg{ + PeersID: peerID, + Request: "pty-req", + WantReply: true, + RequestSpecificData: gossh.Marshal(ssh.PtyReqChannelRequestMsg{ + TermEnv: "dumb", + Width: 39, + Height: 10, + }), + })) + recvChannelMsg[ssh.ChannelRequestSuccessMsg](s, stream) + stream.SendClientToServer(channelMsg(ssh.ChannelRequestMsg{ + PeersID: peerID, + Request: "shell", + WantReply: true, + })) + recvChannelMsg[ssh.ChannelRequestSuccessMsg](s, stream) + + if !s.cfg.Options.IsRuntimeFlagSet(config.RuntimeFlagSSHRoutesPortal) { + channelData := s.channelDataLoop(peerID, stream, 0) + s.Equal(` +Usage: + pomerium [command] + +Available Commands: + help Help about any command + logout Log out + whoami Show details for the current session + +Flags: + -h, --help help for pomerium + +Use "pomerium [command] --help" for more information about a command. +`[1:], channelData.String()) + return nil + } + return &routesPortalTestHookOutput{ + stream, + peerID, + } + } + StreamHandlerSuiteBeforeTestHooks["TestServeChannel_Session_RoutesPortal"] = append(StreamHandlerSuiteBeforeTestHooks["TestServeChannel_Session_RoutesPortal"], hook) + StreamHandlerSuiteBeforeTestHooks["TestServeChannel_Session_RoutesPortal_Select"] = append(StreamHandlerSuiteBeforeTestHooks["TestServeChannel_Session_RoutesPortal_Select"], hook) +} + +func (s *StreamHandlerSuite) TestServeChannel_Session_RoutesPortal() { + res, _ := s.BeforeTestHookResult.(*routesPortalTestHookOutput) + if res == nil { + return // routes portal disabled + } + stream, peerID := res.stream, res.peerID + + frames := []string{ + ` +|| +| Connect to which server? | +| | +| > 1. test@host1 | +| 2. test@host2 | +| | +| | +| ↑/k up • ↓/j down • q quit • ? more| +| |`[1:], + ` +|| +|| +|| +| 1. test@host1 | +| > 2. test@host2 | +|| +|| +|| +||`[1:], + } + for i, frame := range frames { + frames[i] = strings.ReplaceAll(frame, "|", "") + } + var ok bool + var channelData bytes.Buffer + currentFrame := 0 + start := time.Now() + frameAdvance := func() { + switch currentFrame { + case 0: + cursorDown := []byte(ansi.CursorDown(1)) + currentFrame++ + sendChannelMsg(stream, ssh.ChannelDataMsg{ + PeersID: peerID, + Length: uint32(len(cursorDown)), + Rest: cursorDown, + }) + case 1: + currentFrame++ + ok = true + sendChannelMsg(stream, ssh.ChannelDataMsg{ + PeersID: peerID, + Length: uint32(1), + Rest: []byte("q"), + }) + } + channelData.Reset() + } +LOOP: + for time.Since(start) < DefaultTimeout { + response, err := stream.RecvServerToClient() + if err != nil { + s.Fail(err.Error()) + break + } + + bytes := response.GetRawBytes().GetValue() + switch bytes[0] { + case ssh.MsgChannelData: + if ok { + continue + } + var msg ssh.ChannelDataMsg + s.Require().NoError(gossh.Unmarshal(bytes, &msg)) + channelData.Write(msg.Rest) + if postProcessFrame(channelData.String()) == frames[currentFrame] { + frameAdvance() + if currentFrame >= len(frames) { + ok = true + } + } + case ssh.MsgChannelRequest: + // the only channel request we expect to send would be "exit-status" + var msg ssh.ChannelRequestMsg + s.Require().NoError(gossh.Unmarshal(bytes, &msg)) + s.Equal("exit-status", msg.Request) + s.Equal(uint32(0), binary.BigEndian.Uint32(msg.RequestSpecificData)) + case ssh.MsgChannelClose: + sendChannelMsg(stream, ssh.ChannelCloseMsg{PeersID: peerID}) + break LOOP + default: + s.FailNow("test bug") + } + } + currentFrameStr := "" + if !ok { + currentFrameStr = printFrame(frames[currentFrame]) + } + s.Require().Truef(ok, "timed out waiting for frame %d\nbuffer:\n%s\nexpecting:\n%s", + currentFrame, + printFrame(postProcessFrame(channelData.String())), + currentFrameStr) +} + +func (s *StreamHandlerSuite) TestServeChannel_Session_RoutesPortal_Select() { + res, _ := s.BeforeTestHookResult.(*routesPortalTestHookOutput) + if res == nil { + return // routes portal disabled + } + stream, peerID := res.stream, res.peerID + + frames := []string{ + ` +|| +| Connect to which server? | +| | +| > 1. test@host1 | +| 2. test@host2 | +| | +| | +| ↑/k up • ↓/j down • q quit • ? more| +| |`[1:], + ` +|| +|| +|| +| 1. test@host1 | +| > 2. test@host2 | +|| +|| +|| +||`[1:], + ` +|| +| Connect to which server? | +| | +| 1. test@host1 | +| > 2. test@host2 | +| | +| | +| ↑/k up • ↓/j down • q quit …| +| |`[1:], + } + for i, frame := range frames { + frames[i] = strings.ReplaceAll(frame, "|", "") + } + var portalOk bool + var handoffOk bool + var expectHandoff bool + + var channelData bytes.Buffer + currentFrame := 0 + start := time.Now() + frameAdvance := func() { + switch currentFrame { + case 0: + cursorDown := []byte(ansi.CursorDown(1)) + currentFrame++ + sendChannelMsg(stream, ssh.ChannelDataMsg{ + PeersID: peerID, + Length: uint32(len(cursorDown)), + Rest: cursorDown, + }) + case 1: + currentFrame++ + + sendChannelMsg(stream, ssh.ChannelRequestMsg{ + PeersID: peerID, + Request: "window-change", + WantReply: false, + RequestSpecificData: gossh.Marshal(ssh.ChannelWindowChangeRequestMsg{ + WidthColumns: 36, + HeightRows: 10, + }), + }) + case 2: + currentFrame++ + s.mockAuth.EXPECT().EvaluateDelayed(Any(), Any()). + DoAndReturn(func(_ context.Context, info ssh.StreamAuthInfo) error { + s.Equal(info.Username, ptr("test")) + s.Equal(info.Hostname, ptr("host2")) + return nil + }) + expectHandoff = true + sendChannelMsg(stream, ssh.ChannelDataMsg{ + PeersID: peerID, + Length: uint32(1), + Rest: []byte("\r"), + }) + } + channelData.Reset() + } +LOOP: + for time.Since(start) < DefaultTimeout { + response, err := stream.RecvServerToClient() + if err != nil { + s.Fail(err.Error()) + break + } + + if expectHandoff { + if response.GetRawBytes() != nil { + // we might get bytes containing a newline + var msg ssh.ChannelDataMsg + s.Require().NoError(gossh.Unmarshal(response.GetRawBytes().GetValue(), &msg)) + s.Require().Empty(strings.TrimSpace(ansi.Strip(string(msg.Rest)))) + continue + } + action := response.GetChannelControl().GetControlAction() + s.Require().NotNil(action, "expected channel control action") + var sshAction extensions_ssh.SSHChannelControlAction + s.Require().NoError(action.UnmarshalTo(&sshAction)) + handoff := sshAction.GetHandOff() + s.Require().NotNil(action, "expected handoff action") + s.Require().NotNil(handoff.GetUpstreamAuth().GetUpstream(), "expected upstream handoff action") + s.Equal("test", handoff.GetUpstreamAuth().Username) + s.Equal("host2", handoff.GetUpstreamAuth().GetUpstream().Hostname) + testutil.AssertProtoEqual(s.T(), []*extensions_ssh.AllowedMethod{ + { + Method: "publickey", + MethodData: marshalAny(&extensions_ssh.PublicKeyAllowResponse{ + PublicKey: s.ed25519SshPublicKey.Marshal(), + Permissions: &extensions_ssh.Permissions{}, + }), + }, + }, handoff.GetUpstreamAuth().GetUpstream().AllowedMethods) + handoffOk = true + break LOOP + } + + bytes := response.GetRawBytes().GetValue() + s.Require().NotNil(bytes, response.String()) + switch bytes[0] { + case ssh.MsgChannelData: + if portalOk { + continue + } + s.Require().False(expectHandoff) + + var msg ssh.ChannelDataMsg + s.Require().NoError(gossh.Unmarshal(bytes, &msg)) + channelData.Write(msg.Rest) + if postProcessFrame(channelData.String()) == frames[currentFrame] { + frameAdvance() + if currentFrame >= len(frames) { + portalOk = true + } + } + default: + s.FailNow("test bug") + } + } + currentFrameStr := "" + if !portalOk { + currentFrameStr = printFrame(frames[currentFrame]) + } + s.Truef(portalOk, "timed out waiting for frame %d\nbuffer:\n%s\nexpecting:\n%s", + currentFrame, + printFrame(postProcessFrame(channelData.String())), + currentFrameStr) + s.True(handoffOk, "timed out waiting for handoff") + sendChannelMsg(stream, ssh.ChannelCloseMsg{PeersID: peerID}) +} + +func (s *StreamHandlerSuite) TestServeChannel_Session_Exec_Whoami() { + stream := s.BeforeTestHookResult.(*mockChannelStream) + stream.SendClientToServer(channelMsg(ssh.ChannelOpenMsg{ + ChanType: "session", + PeersID: 2, + PeersWindow: ssh.ChannelWindowSize, + MaxPacketSize: ssh.ChannelMaxPacket, + })) + resp := recvChannelMsg[ssh.ChannelOpenConfirmMsg](s, stream) + peerID := resp.MyID + + s.mockAuth.EXPECT(). + FormatSession(Any(), Any()). + Return([]byte("example"), nil) + + stream.SendClientToServer(channelMsg(ssh.ChannelRequestMsg{ + PeersID: peerID, + Request: "exec", + WantReply: true, + RequestSpecificData: gossh.Marshal(ssh.ExecChannelRequestMsg{ + Command: "whoami", + }), + })) + recvChannelMsg[ssh.ChannelRequestSuccessMsg](s, stream) + + channelData := s.channelDataLoop(peerID, stream, 0) + s.Equal("example", channelData.String()) +} + +func (s *StreamHandlerSuite) TestServeChannel_Session_Exec_WhoamiError() { + stream := s.BeforeTestHookResult.(*mockChannelStream) + stream.SendClientToServer(channelMsg(ssh.ChannelOpenMsg{ + ChanType: "session", + PeersID: 2, + PeersWindow: ssh.ChannelWindowSize, + MaxPacketSize: ssh.ChannelMaxPacket, + })) + resp := recvChannelMsg[ssh.ChannelOpenConfirmMsg](s, stream) + peerID := resp.MyID + + s.mockAuth.EXPECT(). + FormatSession(Any(), Any()). + Return(nil, errors.New("test error")) + + stream.SendClientToServer(channelMsg(ssh.ChannelRequestMsg{ + PeersID: peerID, + Request: "exec", + WantReply: true, + RequestSpecificData: gossh.Marshal(ssh.ExecChannelRequestMsg{ + Command: "whoami", + }), + })) + recvChannelMsg[ssh.ChannelRequestSuccessMsg](s, stream) + + channelData := s.channelDataLoop(peerID, stream, 1) + s.Equal("Error: couldn't fetch session: test error\r\n", channelData.String()) +} + +func (s *StreamHandlerSuite) TestServeChannel_Session_Exec_Logout() { + stream := s.BeforeTestHookResult.(*mockChannelStream) + stream.SendClientToServer(channelMsg(ssh.ChannelOpenMsg{ + ChanType: "session", + PeersID: 2, + PeersWindow: ssh.ChannelWindowSize, + MaxPacketSize: ssh.ChannelMaxPacket, + })) + resp := recvChannelMsg[ssh.ChannelOpenConfirmMsg](s, stream) + peerID := resp.MyID + + s.mockAuth.EXPECT(). + DeleteSession(Any(), Any()). + Return(nil) + + stream.SendClientToServer(channelMsg(ssh.ChannelRequestMsg{ + PeersID: peerID, + Request: "exec", + WantReply: true, + RequestSpecificData: gossh.Marshal(ssh.ExecChannelRequestMsg{ + Command: "logout", + }), + })) + recvChannelMsg[ssh.ChannelRequestSuccessMsg](s, stream) + + channelData := s.channelDataLoop(peerID, stream, 0) + s.Equal("Logged out successfully\r\n", channelData.String()) +} + +func (s *StreamHandlerSuite) TestServeChannel_Session_Exec_LogoutError() { + stream := s.BeforeTestHookResult.(*mockChannelStream) + stream.SendClientToServer(channelMsg(ssh.ChannelOpenMsg{ + ChanType: "session", + PeersID: 2, + PeersWindow: ssh.ChannelWindowSize, + MaxPacketSize: ssh.ChannelMaxPacket, + })) + resp := recvChannelMsg[ssh.ChannelOpenConfirmMsg](s, stream) + peerID := resp.MyID + + s.mockAuth.EXPECT(). + DeleteSession(Any(), Any()). + Return(errors.New("test error")) + + stream.SendClientToServer(channelMsg(ssh.ChannelRequestMsg{ + PeersID: peerID, + Request: "exec", + WantReply: true, + RequestSpecificData: gossh.Marshal(ssh.ExecChannelRequestMsg{ + Command: "logout", + }), + })) + recvChannelMsg[ssh.ChannelRequestSuccessMsg](s, stream) + + channelData := s.channelDataLoop(peerID, stream, 1) + s.Equal("Error: failed to delete session: test error\r\n", channelData.String()) +} + +func (s *StreamHandlerSuite) TestServeChannel_DirectTcpip_NoSubMsg() { + stream := s.BeforeTestHookResult.(*mockChannelStream) + stream.SendClientToServer(channelMsg(ssh.ChannelOpenMsg{ + ChanType: "direct-tcpip", + PeersID: 2, + PeersWindow: ssh.ChannelWindowSize, + MaxPacketSize: ssh.ChannelMaxPacket, + })) + // error checked in cleanup +} + +func (s *StreamHandlerSuite) TestServeChannel_DirectTcpip_BadHostname() { + stream := s.BeforeTestHookResult.(*mockChannelStream) + stream.SendClientToServer(channelMsg(ssh.ChannelOpenMsg{ + ChanType: "direct-tcpip", + PeersID: 2, + PeersWindow: ssh.ChannelWindowSize, + MaxPacketSize: ssh.ChannelMaxPacket, + TypeSpecificData: gossh.Marshal(ssh.ChannelOpenDirectMsg{ + DestAddr: "", // invalid + DestPort: 22, + SrcAddr: "127.0.0.1", + SrcPort: 12345, + }), + })) + // error checked in cleanup +} + +func (s *StreamHandlerSuite) TestServeChannel_DirectTcpip_AuthFailed() { + s.mockAuth.EXPECT(). + EvaluateDelayed(Any(), Any()). + Times(1). + Return(errors.New("test error")) + stream := s.BeforeTestHookResult.(*mockChannelStream) + stream.SendClientToServer(channelMsg(ssh.ChannelOpenMsg{ + ChanType: "direct-tcpip", + PeersID: 2, + PeersWindow: ssh.ChannelWindowSize, + MaxPacketSize: ssh.ChannelMaxPacket, + TypeSpecificData: gossh.Marshal(ssh.ChannelOpenDirectMsg{ + DestAddr: "host1", + DestPort: 22, + SrcAddr: "127.0.0.1", + SrcPort: 12345, + }), + })) + // error checked in cleanup +} + +func (s *StreamHandlerSuite) TestServeChannel_DirectTcpip() { + s.mockAuth.EXPECT(). + EvaluateDelayed(Any(), Any()). + Times(1). + Return(nil) + stream := s.BeforeTestHookResult.(*mockChannelStream) + stream.SendClientToServer(channelMsg(ssh.ChannelOpenMsg{ + ChanType: "direct-tcpip", + PeersID: 2, + PeersWindow: ssh.ChannelWindowSize, + MaxPacketSize: ssh.ChannelMaxPacket, + TypeSpecificData: gossh.Marshal(ssh.ChannelOpenDirectMsg{ + DestAddr: "host1", // i.e. 'ssh -J pomerium test@host1' + DestPort: 22, // this will be sent by the ssh client, but is ignored + SrcAddr: "127.0.0.1", + SrcPort: 12345, + }), + })) + recv, err := stream.RecvServerToClient() + s.Require().NoError(err) + action := recv.GetChannelControl().GetControlAction() + s.Require().NotNil(action, "received a message, but it was not a channel control action") + handoff := extensions_ssh.SSHChannelControlAction{} + s.Require().NoError(action.UnmarshalTo(&handoff)) + testutil.AssertProtoEqual(s.T(), extensions_ssh.SSHChannelControlAction_HandOffUpstream{ + DownstreamChannelInfo: &extensions_ssh.SSHDownstreamChannelInfo{ + ChannelType: "direct-tcpip", + DownstreamChannelId: 2, + InternalUpstreamChannelId: 1, + InitialWindowSize: ssh.ChannelWindowSize, + MaxPacketSize: ssh.ChannelMaxPacket, + }, + DownstreamPtyInfo: nil, + UpstreamAuth: &extensions_ssh.AllowResponse{ + Username: "test", + Target: &extensions_ssh.AllowResponse_Upstream{ + Upstream: &extensions_ssh.UpstreamTarget{ + Hostname: "host1", + DirectTcpip: true, + AllowedMethods: []*extensions_ssh.AllowedMethod{ + { + Method: "publickey", + MethodData: marshalAny(&extensions_ssh.PublicKeyAllowResponse{ + PublicKey: s.ed25519SshPublicKey.Marshal(), + Permissions: &extensions_ssh.Permissions{}, + }), + }, + }, + }, + }, + }, + }, handoff.GetHandOff()) +} + +func (s *StreamHandlerSuite) TestServeChannel_InvalidChannelType() { + stream := s.BeforeTestHookResult.(*mockChannelStream) + stream.SendClientToServer(channelMsg(ssh.ChannelOpenMsg{ + ChanType: "unknown", + PeersID: 2, + PeersWindow: ssh.ChannelWindowSize, + MaxPacketSize: ssh.ChannelMaxPacket, + })) + // error checked in cleanup +} + +func (s *StreamHandlerSuite) TestFormatSession() { + s.mockAuth.EXPECT(). + FormatSession(Any(), Any()). + Return([]byte("example"), nil) + sh := s.mgr.NewStreamHandler(s.cfg, &extensions_ssh.DownstreamConnectEvent{StreamId: 1}) + ctx, ca := context.WithCancel(context.Background()) + ca() + // this will exit immediately, but it will have a state, which is only + // created upon calling Run() + sh.Run(ctx) + + res, err := sh.FormatSession(s.T().Context()) + s.NoError(err) + s.Equal([]byte("example"), res) +} + +func (s *StreamHandlerSuite) TestDeleteSession() { + s.mockAuth.EXPECT(). + DeleteSession(Any(), Any()). + Return(nil) + sh := s.mgr.NewStreamHandler(s.cfg, &extensions_ssh.DownstreamConnectEvent{StreamId: 1}) + ctx, ca := context.WithCancel(context.Background()) + ca() + // this will exit immediately, but it will have a state, which is only + // created upon calling Run() + sh.Run(ctx) + + err := sh.DeleteSession(s.T().Context()) + s.NoError(err) +} + +func (s *StreamHandlerSuite) TestRunCalledTwice() { + sh := s.mgr.NewStreamHandler(s.cfg, &extensions_ssh.DownstreamConnectEvent{StreamId: 1}) + ctx, ca := context.WithCancel(context.Background()) + ca() + sh.Run(ctx) + s.PanicsWithValue("Run called twice", func() { + sh.Run(context.Background()) + }) +} + +func (s *StreamHandlerSuite) TestAllSSHRoutes() { + sh := s.mgr.NewStreamHandler(s.cfg, &extensions_ssh.DownstreamConnectEvent{StreamId: 1}) + routes := slices.Collect(sh.AllSSHRoutes()) + s.Len(routes, 2) + s.Equal("ssh://host1", routes[0].From) + s.Equal("ssh://dest1:22", routes[0].To[0].String()) + s.Equal("ssh://host2", routes[1].From) + s.Equal("ssh://dest2:22", routes[1].To[0].String()) + + next, stop := iter.Pull(sh.AllSSHRoutes()) + v, ok := next() + s.NotNil(v) + s.True(ok) + stop() + v, ok = next() + s.Nil(v) + s.False(ok) +} + +func TestStreamHandlerSuite(t *testing.T) { + suite.Run(t, &StreamHandlerSuite{}) +} + +func TestStreamHandlerSuiteWithRuntimeFlags(t *testing.T) { + suite.Run(t, &StreamHandlerSuite{ + StreamHandlerSuiteOptions: StreamHandlerSuiteOptions{ + ConfigModifiers: []func(*config.Config){ + func(c *config.Config) { + c.Options.RuntimeFlags[config.RuntimeFlagSSHRoutesPortal] = true + }, + }, + }, + }) +} + +func ptr[T any](t T) *T { + return &t +}