package ssh_test import ( "bytes" "context" "crypto/ed25519" "crypto/rand" "encoding/binary" "errors" "fmt" "io" "iter" "os" "runtime" "slices" "strings" "sync" "testing" "time" "github.com/charmbracelet/x/ansi" "github.com/stretchr/testify/suite" . "go.uber.org/mock/gomock" //nolint gossh "golang.org/x/crypto/ssh" "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" "google.golang.org/protobuf/proto" "google.golang.org/protobuf/types/known/anypb" "google.golang.org/protobuf/types/known/wrapperspb" extensions_ssh "github.com/pomerium/envoy-custom/api/extensions/filters/network/ssh" "github.com/pomerium/pomerium/config" "github.com/pomerium/pomerium/internal/testutil" "github.com/pomerium/pomerium/pkg/ssh" mock_ssh "github.com/pomerium/pomerium/pkg/ssh/mock" ) var DefaultTimeout = 10 * time.Second func init() { if isDebuggerAttached() { DefaultTimeout = 1 * time.Hour } } func isDebuggerAttached() bool { if runtime.GOOS == "linux" { data, err := os.ReadFile("/proc/self/status") if err == nil { for line := range bytes.Lines(data) { if bytes.HasPrefix(line, []byte("TracerPid:\t")) { return line[11] != '0' } } } } return false } func HookWithArgs(f func(s *StreamHandlerSuite, args []any) any, args ...any) []func(s *StreamHandlerSuite) any { return []func(s *StreamHandlerSuite) any{ func(s *StreamHandlerSuite) any { return f(s, args) }, } } var ( StreamHandlerSuiteBeforeTestHooks = map[string][]func(s *StreamHandlerSuite) any{} StreamHandlerSuiteAfterTestHooks = map[string][]func(s *StreamHandlerSuite) any{} ) type StreamHandlerSuiteOptions struct { ConfigModifiers []func(*config.Config) } type StreamHandlerSuite struct { suite.Suite StreamHandlerSuiteOptions ctrl *Controller mgr *ssh.StreamManager cfg *config.Config cleanup []func() errC chan error mockAuth *mock_ssh.MockAuthInterface ed25519PublicKey ed25519.PublicKey ed25519PrivateKey ed25519.PrivateKey ed25519SshPublicKey gossh.PublicKey ed25519SshPrivateKey gossh.Signer BeforeTestHookResult any } func (s *StreamHandlerSuite) SetupTest() { s.ctrl = NewController(s.T()) s.mockAuth = mock_ssh.NewMockAuthInterface(s.ctrl) s.mgr = ssh.NewStreamManager() s.cleanup = []func(){} s.errC = make(chan error, 1) var err error s.ed25519PublicKey, s.ed25519PrivateKey, err = ed25519.GenerateKey(rand.Reader) s.Require().NoError(err) s.ed25519SshPublicKey, err = gossh.NewPublicKey(s.ed25519PublicKey) s.Require().NoError(err) s.ed25519SshPrivateKey, err = gossh.NewSignerFromKey(s.ed25519PrivateKey) s.Require().NoError(err) s.cfg = &config.Config{Options: config.NewDefaultOptions()} s.cfg.Options.Policies = []config.Policy{ {From: "https://from.notssh.example.com", To: mustParseWeightedURLs(s.T(), "https://to.notssh.example.com")}, {From: "ssh://host1", To: mustParseWeightedURLs(s.T(), "ssh://dest1:22")}, {From: "https://from1.notssh.example.com", To: mustParseWeightedURLs(s.T(), "https://to1.notssh.example.com")}, {From: "ssh://host2", To: mustParseWeightedURLs(s.T(), "ssh://dest2:22")}, {From: "https://from2.notssh.example.com", To: mustParseWeightedURLs(s.T(), "https://to2.notssh.example.com")}, } for _, f := range s.ConfigModifiers { f(s.cfg) } } func (s *StreamHandlerSuite) TearDownTest() { for _, f := range s.cleanup { f() } s.ctrl.Finish() } func (s *StreamHandlerSuite) BeforeTest(_, testName string) { s.BeforeTestHookResult = nil for _, fn := range StreamHandlerSuiteBeforeTestHooks[testName] { s.BeforeTestHookResult = fn(s) } } // // Helper methods // func marshalAny(msg proto.Message) *anypb.Any { a, err := anypb.New(msg) if err != nil { panic(err) } return a } func (s *StreamHandlerSuite) expectError(fn func(), msg string) { fn() select { case err := <-s.errC: s.ErrorContains(err, msg) case <-time.After(DefaultTimeout): s.FailNowf("timed out waiting for error %q", msg) } } func (s *StreamHandlerSuite) startStreamHandler(streamID uint64) *ssh.StreamHandler { sh := s.mgr.NewStreamHandler( s.cfg, s.mockAuth, &extensions_ssh.DownstreamConnectEvent{StreamId: streamID}) s.errC = make(chan error, 1) ctx, ca := context.WithCancel(s.T().Context()) go func() { defer close(s.errC) s.errC <- sh.Run(ctx) }() s.cleanup = append(s.cleanup, func() { start := time.Now() for len(sh.ReadC()) > 0 && time.Since(start) < 100*time.Millisecond { runtime.Gosched() } if len(sh.ReadC()) > 0 { s.Fail(fmt.Sprintf("read channel contains %d unhandled client messages", len(sh.ReadC()))) } ca() var err error select { case err = <-s.errC: case <-time.After(DefaultTimeout): s.Fail("timed out waiting for stream handler to close") } sh.Close() if err != nil { s.Require().ErrorIs(err, context.Canceled) } if len(sh.WriteC()) != 0 { logs := []string{"write channel contains unhandled server messages:"} i := 0 for msg := range sh.WriteC() { logs = append(logs, fmt.Sprintf("[%d]: %s", i, msg.String())) i++ } s.Fail(strings.Join(logs, "\n")) } }) return sh } func (s *StreamHandlerSuite) msgDownstreamConnected(streamID uint64) *extensions_ssh.ClientMessage { return &extensions_ssh.ClientMessage{ Message: &extensions_ssh.ClientMessage_Event{ Event: &extensions_ssh.StreamEvent{ Event: &extensions_ssh.StreamEvent_DownstreamConnected{ DownstreamConnected: &extensions_ssh.DownstreamConnectEvent{ StreamId: streamID, }, }, }, }, } } func (s *StreamHandlerSuite) msgDownstreamDisconnected(reason string) *extensions_ssh.ClientMessage { return &extensions_ssh.ClientMessage{ Message: &extensions_ssh.ClientMessage_Event{ Event: &extensions_ssh.StreamEvent{ Event: &extensions_ssh.StreamEvent_DownstreamDisconnected{ DownstreamDisconnected: &extensions_ssh.DownstreamDisconnectedEvent{ Reason: reason, }, }, }, }, } } func (s *StreamHandlerSuite) msgUpstreamConnected(streamID uint64) *extensions_ssh.ClientMessage { return &extensions_ssh.ClientMessage{ Message: &extensions_ssh.ClientMessage_Event{ Event: &extensions_ssh.StreamEvent{ Event: &extensions_ssh.StreamEvent_UpstreamConnected{ UpstreamConnected: &extensions_ssh.UpstreamConnectEvent{ StreamId: streamID, }, }, }, }, } } func (s *StreamHandlerSuite) expectAllowUpstream(sh *ssh.StreamHandler, hostname string) { select { case msg := <-sh.WriteC(): if authResp := msg.GetAuthResponse(); authResp != nil { if allow := authResp.GetAllow(); allow != nil { s.Require().NotNil(allow.GetUpstream(), "received an allow response, but not to an upstream target") s.Require().Equal(hostname, allow.GetUpstream().GetHostname()) } else { s.FailNowf("received an auth response, but it was not an allow response", authResp.String()) } } else { s.FailNow("received a message, but it was not an auth response", msg.String()) } case <-time.After(DefaultTimeout): s.FailNow("timed out waiting for upstream allow message") } } func (s *StreamHandlerSuite) expectDeny(sh *ssh.StreamHandler, partial bool, methods []string) { select { case msg := <-sh.WriteC(): if authResp := msg.GetAuthResponse(); authResp != nil { if deny := authResp.GetDeny(); deny != nil { s.Require().Equal(partial, deny.Partial) s.Require().Equal(methods, deny.Methods) } else { s.Require().Fail("received an auth response, but it was not a deny response", authResp.String()) } } else { s.FailNow("received a message, but it was not an auth response", msg.String()) } case <-time.After(DefaultTimeout): s.FailNow("timed out waiting for deny message") } } func (s *StreamHandlerSuite) expectAllowInternal(sh *ssh.StreamHandler) { select { case msg := <-sh.WriteC(): if authResp := msg.GetAuthResponse(); authResp != nil { if allow := authResp.GetAllow(); allow != nil { s.Require().NotNil(allow.GetInternal(), "received an allow response, but not to an internal target") } else { s.FailNow("received an auth response, but it was not an allow response", authResp.String()) } } else { s.FailNow("received a message, but it was not an auth response", msg.String()) } case <-time.After(DefaultTimeout): s.FailNow("timed out waiting for internal allow message") } } func (s *StreamHandlerSuite) expectPrompt(sh *ssh.StreamHandler) { select { case msg := <-sh.WriteC(): if authResp := msg.GetAuthResponse(); authResp != nil { if info := authResp.GetInfoRequest(); info != nil { s.Require().NotNil(info.GetRequest(), "received a nil info request") } else { s.FailNow("received an auth response, but it was not an info request", authResp.String()) } } else { s.FailNow("received a message, but it was not an auth response", msg.String()) } case <-time.After(DefaultTimeout): s.FailNow("timed out waiting for prompt message") } } func (s *StreamHandlerSuite) validPublicKeyMethodRequest() *anypb.Any { return marshalAny(&extensions_ssh.PublicKeyMethodRequest{ PublicKey: s.ed25519SshPublicKey.Marshal(), PublicKeyAlg: s.ed25519SshPublicKey.Type(), PublicKeyFingerprintSha256: []byte(gossh.FingerprintSHA256(s.ed25519SshPublicKey)), }) } // // Tests // func (s *StreamHandlerSuite) TestDuplicateDownstreamConnectedEvent() { sh := s.startStreamHandler(1) s.expectError(func() { sh.ReadC() <- s.msgDownstreamConnected(1) }, "received duplicate downstream connected event") } func (s *StreamHandlerSuite) TestDownstreamDisconnectedEvent() { sh := s.startStreamHandler(1) sh.ReadC() <- s.msgDownstreamDisconnected("") // this just logs a message } func (s *StreamHandlerSuite) TestUpstreamConnectedEvent() { sh := s.startStreamHandler(1) sh.ReadC() <- s.msgUpstreamConnected(1) // this just logs a message } func (s *StreamHandlerSuite) TestInvalidEvent() { sh := s.startStreamHandler(1) s.expectError(func() { sh.ReadC() <- &extensions_ssh.ClientMessage{ Message: &extensions_ssh.ClientMessage_Event{ Event: &extensions_ssh.StreamEvent{Event: nil}, }, } }, "received invalid event") } func (s *StreamHandlerSuite) TestHandleAuthRequest_InvalidProtocol() { sh := s.startStreamHandler(1) s.expectError(func() { sh.ReadC() <- &extensions_ssh.ClientMessage{ Message: &extensions_ssh.ClientMessage_AuthRequest{ AuthRequest: &extensions_ssh.AuthenticationRequest{ Protocol: "not-ssh", Service: "ssh-connection", }, }, } }, "invalid protocol: not-ssh") } func (s *StreamHandlerSuite) TestHandleAuthRequest_InvalidService() { sh := s.startStreamHandler(1) s.expectError(func() { sh.ReadC() <- &extensions_ssh.ClientMessage{ Message: &extensions_ssh.ClientMessage_AuthRequest{ AuthRequest: &extensions_ssh.AuthenticationRequest{ Protocol: "ssh", Service: "ssh-userauth", }, }, } }, "invalid service: ssh-userauth") } func (s *StreamHandlerSuite) TestHandleAuthRequest_InvalidMessage() { sh := s.startStreamHandler(1) s.expectError(func() { sh.ReadC() <- &extensions_ssh.ClientMessage{ Message: nil, } }, "received invalid message") } func (s *StreamHandlerSuite) TestHandleAuthRequest_FirstRequestIsKeyboardInteractive() { sh := s.startStreamHandler(1) s.expectError(func() { // first request should be publickey sh.ReadC() <- &extensions_ssh.ClientMessage{ Message: &extensions_ssh.ClientMessage_AuthRequest{ AuthRequest: &extensions_ssh.AuthenticationRequest{ Protocol: "ssh", Service: "ssh-connection", AuthMethod: "keyboard-interactive", }, }, } }, "unexpected auth method: keyboard-interactive") } func (s *StreamHandlerSuite) TestHandleAuthRequest_MissingUsername() { sh := s.startStreamHandler(1) s.expectError(func() { sh.ReadC() <- &extensions_ssh.ClientMessage{ Message: &extensions_ssh.ClientMessage_AuthRequest{ AuthRequest: &extensions_ssh.AuthenticationRequest{ Protocol: "ssh", Service: "ssh-connection", AuthMethod: "publickey", Username: "", }, }, } }, "username missing") } func (s *StreamHandlerSuite) TestHandleAuthRequest_EmptyHostname() { sh := s.startStreamHandler(1) // empty hostname is allowed initially s.mockAuth.EXPECT(). HandlePublicKeyMethodRequest(Any(), Any(), Any()). Return(ssh.PublicKeyAuthMethodResponse{Allow: &extensions_ssh.PublicKeyAllowResponse{ PublicKey: s.ed25519SshPublicKey.Marshal(), Permissions: &extensions_ssh.Permissions{}, }}, nil) sh.ReadC() <- &extensions_ssh.ClientMessage{ Message: &extensions_ssh.ClientMessage_AuthRequest{ AuthRequest: &extensions_ssh.AuthenticationRequest{ Protocol: "ssh", Service: "ssh-connection", AuthMethod: "publickey", Username: "test", Hostname: "", MethodRequest: s.validPublicKeyMethodRequest(), }, }, } s.expectAllowInternal(sh) } func (s *StreamHandlerSuite) TestHandleAuthRequest_MismatchedAuthMethodAndRequestType() { sh := s.startStreamHandler(1) s.expectError(func() { sh.ReadC() <- &extensions_ssh.ClientMessage{ Message: &extensions_ssh.ClientMessage_AuthRequest{ AuthRequest: &extensions_ssh.AuthenticationRequest{ Protocol: "ssh", Service: "ssh-connection", AuthMethod: "publickey", Username: "test", Hostname: "", MethodRequest: marshalAny(&extensions_ssh.KeyboardInteractiveMethodRequest{}), }, }, } }, "invalid public key method request type") } func (s *StreamHandlerSuite) TestHandleAuthRequest_ValidPublicKeyMethodRequest() { sh := s.startStreamHandler(1) s.mockAuth.EXPECT(). HandlePublicKeyMethodRequest(Any(), Any(), Any()). Return(ssh.PublicKeyAuthMethodResponse{Allow: &extensions_ssh.PublicKeyAllowResponse{ PublicKey: s.ed25519SshPublicKey.Marshal(), Permissions: &extensions_ssh.Permissions{}, }}, nil) sh.ReadC() <- &extensions_ssh.ClientMessage{ Message: &extensions_ssh.ClientMessage_AuthRequest{ AuthRequest: &extensions_ssh.AuthenticationRequest{ Protocol: "ssh", Service: "ssh-connection", AuthMethod: "publickey", Username: "test", Hostname: "host1", MethodRequest: s.validPublicKeyMethodRequest(), }, }, } s.expectAllowUpstream(sh, "host1") } func (s *StreamHandlerSuite) TestHandleAuthRequest_ValidPublicKeyMethodRequestError() { sh := s.startStreamHandler(1) s.mockAuth.EXPECT(). HandlePublicKeyMethodRequest(Any(), Any(), Any()). Return(ssh.PublicKeyAuthMethodResponse{}, errors.New("test error")) s.expectError(func() { sh.ReadC() <- &extensions_ssh.ClientMessage{ Message: &extensions_ssh.ClientMessage_AuthRequest{ AuthRequest: &extensions_ssh.AuthenticationRequest{ Protocol: "ssh", Service: "ssh-connection", AuthMethod: "publickey", Username: "test", Hostname: "host1", MethodRequest: s.validPublicKeyMethodRequest(), }, }, } }, "test error") } func (s *StreamHandlerSuite) TestHandleAuthRequest_PublicKeyRetry() { sh := s.startStreamHandler(1) i := -1 s.mockAuth.EXPECT(). HandlePublicKeyMethodRequest(Any(), Any(), Any()). MaxTimes(4). DoAndReturn(func(_ context.Context, _ ssh.StreamAuthInfo, _ *extensions_ssh.PublicKeyMethodRequest) (ssh.PublicKeyAuthMethodResponse, error) { i++ switch i { case 0, 1, 2: return ssh.PublicKeyAuthMethodResponse{ RequireAdditionalMethods: []string{"publickey"}, }, nil case 3: return ssh.PublicKeyAuthMethodResponse{Allow: &extensions_ssh.PublicKeyAllowResponse{ PublicKey: s.ed25519SshPublicKey.Marshal(), Permissions: &extensions_ssh.Permissions{}, }}, nil default: panic("unreachable") } }) for i := range 4 { sh.ReadC() <- &extensions_ssh.ClientMessage{ Message: &extensions_ssh.ClientMessage_AuthRequest{ AuthRequest: &extensions_ssh.AuthenticationRequest{ Protocol: "ssh", Service: "ssh-connection", AuthMethod: "publickey", Username: "test", Hostname: "host1", MethodRequest: s.validPublicKeyMethodRequest(), }, }, } if i < 3 { s.expectDeny(sh, false, []string{"publickey"}) } else { s.expectAllowUpstream(sh, "host1") } } } func (s *StreamHandlerSuite) TestHandleAuthRequest_InconsistentUsername() { sh := s.startStreamHandler(1) s.mockAuth.EXPECT(). HandlePublicKeyMethodRequest(Any(), Any(), Any()). Times(1). DoAndReturn(func(_ context.Context, _ ssh.StreamAuthInfo, _ *extensions_ssh.PublicKeyMethodRequest) (ssh.PublicKeyAuthMethodResponse, error) { return ssh.PublicKeyAuthMethodResponse{ RequireAdditionalMethods: []string{"publickey"}, }, nil }) sh.ReadC() <- &extensions_ssh.ClientMessage{ Message: &extensions_ssh.ClientMessage_AuthRequest{ AuthRequest: &extensions_ssh.AuthenticationRequest{ Protocol: "ssh", Service: "ssh-connection", AuthMethod: "publickey", Username: "test", Hostname: "host1", MethodRequest: s.validPublicKeyMethodRequest(), }, }, } s.expectDeny(sh, false, []string{"publickey"}) s.Equal("test", *sh.Username()) s.Equal("host1", *sh.Hostname()) s.expectError(func() { sh.ReadC() <- &extensions_ssh.ClientMessage{ Message: &extensions_ssh.ClientMessage_AuthRequest{ AuthRequest: &extensions_ssh.AuthenticationRequest{ Protocol: "ssh", Service: "ssh-connection", AuthMethod: "publickey", Username: "test2", Hostname: "host1", MethodRequest: s.validPublicKeyMethodRequest(), }, }, } }, "inconsistent username") } func (s *StreamHandlerSuite) TestHandleAuthRequest_InconsistentHostname() { sh := s.startStreamHandler(1) s.mockAuth.EXPECT(). HandlePublicKeyMethodRequest(Any(), Any(), Any()). Times(1). DoAndReturn(func(_ context.Context, _ ssh.StreamAuthInfo, _ *extensions_ssh.PublicKeyMethodRequest) (ssh.PublicKeyAuthMethodResponse, error) { return ssh.PublicKeyAuthMethodResponse{ RequireAdditionalMethods: []string{"publickey"}, }, nil }) sh.ReadC() <- &extensions_ssh.ClientMessage{ Message: &extensions_ssh.ClientMessage_AuthRequest{ AuthRequest: &extensions_ssh.AuthenticationRequest{ Protocol: "ssh", Service: "ssh-connection", AuthMethod: "publickey", Username: "test", Hostname: "host1", MethodRequest: s.validPublicKeyMethodRequest(), }, }, } s.expectDeny(sh, false, []string{"publickey"}) s.expectError(func() { sh.ReadC() <- &extensions_ssh.ClientMessage{ Message: &extensions_ssh.ClientMessage_AuthRequest{ AuthRequest: &extensions_ssh.AuthenticationRequest{ Protocol: "ssh", Service: "ssh-connection", AuthMethod: "publickey", Username: "test", Hostname: "host2", MethodRequest: s.validPublicKeyMethodRequest(), }, }, } }, "inconsistent hostname") } func (s *StreamHandlerSuite) TestHandleAuthRequest_InconsistentEmptyHostname() { sh := s.startStreamHandler(1) s.mockAuth.EXPECT(). HandlePublicKeyMethodRequest(Any(), Any(), Any()). Times(1). DoAndReturn(func(_ context.Context, _ ssh.StreamAuthInfo, req *extensions_ssh.PublicKeyMethodRequest) (ssh.PublicKeyAuthMethodResponse, error) { return ssh.PublicKeyAuthMethodResponse{ Allow: &extensions_ssh.PublicKeyAllowResponse{ PublicKey: req.PublicKey, Permissions: &extensions_ssh.Permissions{}, }, RequireAdditionalMethods: []string{"keyboard-interactive"}, }, nil }) sh.ReadC() <- &extensions_ssh.ClientMessage{ Message: &extensions_ssh.ClientMessage_AuthRequest{ AuthRequest: &extensions_ssh.AuthenticationRequest{ Protocol: "ssh", Service: "ssh-connection", AuthMethod: "publickey", Username: "test", Hostname: "", MethodRequest: s.validPublicKeyMethodRequest(), }, }, } s.expectDeny(sh, true, []string{"keyboard-interactive"}) s.expectError(func() { sh.ReadC() <- &extensions_ssh.ClientMessage{ Message: &extensions_ssh.ClientMessage_AuthRequest{ AuthRequest: &extensions_ssh.AuthenticationRequest{ Protocol: "ssh", Service: "ssh-connection", AuthMethod: "keyboard-interactive", Username: "test", Hostname: "host1", MethodRequest: marshalAny(&extensions_ssh.KeyboardInteractiveMethodRequest{}), }, }, } }, "inconsistent hostname") } func (s *StreamHandlerSuite) TestHandleAuthRequest_UnknownAuthMethod() { sh := s.startStreamHandler(1) s.expectError(func() { sh.ReadC() <- &extensions_ssh.ClientMessage{ Message: &extensions_ssh.ClientMessage_AuthRequest{ AuthRequest: &extensions_ssh.AuthenticationRequest{ Protocol: "ssh", Service: "ssh-connection", AuthMethod: "password", Username: "test", Hostname: "host1", }, }, } }, "unexpected auth method: password") } func (s *StreamHandlerSuite) TestHandleAuthRequest_UnimplementedAuthMethod() { sh := s.startStreamHandler(1) s.mockAuth.EXPECT(). HandlePublicKeyMethodRequest(Any(), Any(), Any()). Times(1). DoAndReturn(func(_ context.Context, _ ssh.StreamAuthInfo, _ *extensions_ssh.PublicKeyMethodRequest) (ssh.PublicKeyAuthMethodResponse, error) { return ssh.PublicKeyAuthMethodResponse{ RequireAdditionalMethods: []string{"password"}, }, nil }) sh.ReadC() <- &extensions_ssh.ClientMessage{ Message: &extensions_ssh.ClientMessage_AuthRequest{ AuthRequest: &extensions_ssh.AuthenticationRequest{ Protocol: "ssh", Service: "ssh-connection", AuthMethod: "publickey", Username: "test", Hostname: "host1", MethodRequest: s.validPublicKeyMethodRequest(), }, }, } s.expectDeny(sh, false, []string{"password"}) s.expectError(func() { sh.ReadC() <- &extensions_ssh.ClientMessage{ Message: &extensions_ssh.ClientMessage_AuthRequest{ AuthRequest: &extensions_ssh.AuthenticationRequest{ Protocol: "ssh", Service: "ssh-connection", AuthMethod: "password", Username: "test", Hostname: "host1", }, }, } }, "bug: server requested an unsupported auth method \"password\"") } func (s *StreamHandlerSuite) TestHandleAuthRequest_WrongClientMessage() { sh := s.startStreamHandler(1) s.mockAuth.EXPECT(). HandlePublicKeyMethodRequest(Any(), Any(), Any()). Times(1). DoAndReturn(func(_ context.Context, _ ssh.StreamAuthInfo, req *extensions_ssh.PublicKeyMethodRequest) (ssh.PublicKeyAuthMethodResponse, error) { return ssh.PublicKeyAuthMethodResponse{ Allow: &extensions_ssh.PublicKeyAllowResponse{ PublicKey: req.PublicKey, Permissions: &extensions_ssh.Permissions{}, }, RequireAdditionalMethods: []string{"keyboard-interactive"}, }, nil }) newMsg := func() *extensions_ssh.ClientMessage_AuthRequest { return &extensions_ssh.ClientMessage_AuthRequest{ AuthRequest: &extensions_ssh.AuthenticationRequest{ Protocol: "ssh", Service: "ssh-connection", AuthMethod: "publickey", Username: "test", Hostname: "host1", MethodRequest: s.validPublicKeyMethodRequest(), }, } } sh.ReadC() <- &extensions_ssh.ClientMessage{ Message: newMsg(), } s.expectDeny(sh, true, []string{"keyboard-interactive"}) s.expectError(func() { sh.ReadC() <- &extensions_ssh.ClientMessage{ Message: newMsg(), } }, "unexpected auth method: publickey") } func (s *StreamHandlerSuite) TestHandleAuthRequest_KeyboardInteractive_WrongMethodRequestType() { sh := s.startStreamHandler(1) s.mockAuth.EXPECT(). HandlePublicKeyMethodRequest(Any(), Any(), Any()). Times(1). DoAndReturn(func(_ context.Context, _ ssh.StreamAuthInfo, req *extensions_ssh.PublicKeyMethodRequest) (ssh.PublicKeyAuthMethodResponse, error) { return ssh.PublicKeyAuthMethodResponse{ Allow: &extensions_ssh.PublicKeyAllowResponse{ PublicKey: req.PublicKey, Permissions: &extensions_ssh.Permissions{}, }, RequireAdditionalMethods: []string{"keyboard-interactive"}, }, nil }) sh.ReadC() <- &extensions_ssh.ClientMessage{ Message: &extensions_ssh.ClientMessage_AuthRequest{ AuthRequest: &extensions_ssh.AuthenticationRequest{ Protocol: "ssh", Service: "ssh-connection", AuthMethod: "publickey", Username: "test", Hostname: "host1", MethodRequest: s.validPublicKeyMethodRequest(), }, }, } s.expectDeny(sh, true, []string{"keyboard-interactive"}) s.expectError(func() { sh.ReadC() <- &extensions_ssh.ClientMessage{ Message: &extensions_ssh.ClientMessage_AuthRequest{ AuthRequest: &extensions_ssh.AuthenticationRequest{ Protocol: "ssh", Service: "ssh-connection", AuthMethod: "keyboard-interactive", Username: "test", Hostname: "host1", MethodRequest: s.validPublicKeyMethodRequest(), }, }, } }, "invalid keyboard-interactive method request type") } func init() { setupKeyboardInteractive := func(s *StreamHandlerSuite, input []any) any { querierErr, _ := input[0].(error) sh := s.startStreamHandler(100) i := -1 s.mockAuth.EXPECT(). HandlePublicKeyMethodRequest(Any(), Any(), Any()). Times(2). DoAndReturn(func(_ context.Context, _ ssh.StreamAuthInfo, _ *extensions_ssh.PublicKeyMethodRequest) (ssh.PublicKeyAuthMethodResponse, error) { i++ switch i { case 0: return ssh.PublicKeyAuthMethodResponse{ RequireAdditionalMethods: []string{"publickey"}, }, nil case 1: return ssh.PublicKeyAuthMethodResponse{ Allow: &extensions_ssh.PublicKeyAllowResponse{ PublicKey: s.ed25519SshPublicKey.Marshal(), Permissions: &extensions_ssh.Permissions{}, }, RequireAdditionalMethods: []string{"keyboard-interactive"}, }, nil default: panic("unreachable") } }) s.mockAuth.EXPECT(). HandleKeyboardInteractiveMethodRequest(Any(), Any(), Any(), Any()). DoAndReturn(func( ctx context.Context, info ssh.StreamAuthInfo, _ *extensions_ssh.KeyboardInteractiveMethodRequest, querier ssh.KeyboardInteractiveQuerier, ) (ssh.KeyboardInteractiveAuthMethodResponse, error) { s.Equal("test", *info.Username) s.Equal("host1", *info.Hostname) s.Equal(uint64(100), info.StreamID) resp, err := querier.Prompt(ctx, &extensions_ssh.KeyboardInteractiveInfoPrompts{ Name: "test-name", Instruction: "test-instruction", Prompts: []*extensions_ssh.KeyboardInteractiveInfoPrompts_Prompt{ { Prompt: "test-prompt", Echo: true, }, }, }) s.Require().Equal(querierErr, err, "unexpected error from querier.Prompt") if querierErr == nil { s.Equal([]string{"test-prompt-response"}, resp.Responses) return ssh.KeyboardInteractiveAuthMethodResponse{ Allow: &extensions_ssh.KeyboardInteractiveAllowResponse{}, }, nil } return ssh.KeyboardInteractiveAuthMethodResponse{}, err }) for range 2 { sh.ReadC() <- &extensions_ssh.ClientMessage{ Message: &extensions_ssh.ClientMessage_AuthRequest{ AuthRequest: &extensions_ssh.AuthenticationRequest{ Protocol: "ssh", Service: "ssh-connection", AuthMethod: "publickey", Username: "test", Hostname: "host1", MethodRequest: s.validPublicKeyMethodRequest(), }, }, } } s.expectDeny(sh, false, []string{"publickey"}) s.expectDeny(sh, true, []string{"keyboard-interactive"}) sh.ReadC() <- &extensions_ssh.ClientMessage{ Message: &extensions_ssh.ClientMessage_AuthRequest{ AuthRequest: &extensions_ssh.AuthenticationRequest{ Protocol: "ssh", Service: "ssh-connection", AuthMethod: "keyboard-interactive", Username: "test", Hostname: "host1", MethodRequest: marshalAny(&extensions_ssh.KeyboardInteractiveMethodRequest{}), }, }, } return sh } StreamHandlerSuiteBeforeTestHooks["TestHandleAuthRequest_KeyboardInteractive"] = HookWithArgs(setupKeyboardInteractive, (error)(nil)) StreamHandlerSuiteBeforeTestHooks["TestHandleAuthRequest_KeyboardInteractive_NoPromptReply"] = HookWithArgs(setupKeyboardInteractive, context.Canceled) StreamHandlerSuiteBeforeTestHooks["TestHandleAuthRequest_KeyboardInteractive_InvalidInfoResponse"] = HookWithArgs(setupKeyboardInteractive, status.Errorf(codes.Internal, "received invalid info response")) StreamHandlerSuiteBeforeTestHooks["TestHandleAuthRequest_KeyboardInteractive_InvalidPromptResponse"] = HookWithArgs(setupKeyboardInteractive, status.Errorf(codes.InvalidArgument, "received invalid prompt response")) StreamHandlerSuiteBeforeTestHooks["TestHandleAuthRequest_KeyboardInteractive_WrongResponseMessageType"] = HookWithArgs(setupKeyboardInteractive, status.Errorf(codes.InvalidArgument, "received invalid message, expecting info response")) } func (s *StreamHandlerSuite) TestHandleAuthRequest_KeyboardInteractive() { sh := s.BeforeTestHookResult.(*ssh.StreamHandler) s.expectPrompt(sh) sh.ReadC() <- &extensions_ssh.ClientMessage{ Message: &extensions_ssh.ClientMessage_InfoResponse{ InfoResponse: &extensions_ssh.InfoResponse{ Method: "keyboard-interactive", Response: marshalAny(&extensions_ssh.KeyboardInteractiveInfoPromptResponses{ Responses: []string{"test-prompt-response"}, }), }, }, } s.expectAllowUpstream(sh, "host1") } func (s *StreamHandlerSuite) TestHandleAuthRequest_KeyboardInteractive_NoPromptReply() { sh := s.BeforeTestHookResult.(*ssh.StreamHandler) s.expectPrompt(sh) } func (s *StreamHandlerSuite) TestHandleAuthRequest_KeyboardInteractive_InvalidInfoResponse() { sh := s.BeforeTestHookResult.(*ssh.StreamHandler) s.expectPrompt(sh) s.expectError(func() { sh.ReadC() <- &extensions_ssh.ClientMessage{ Message: &extensions_ssh.ClientMessage_InfoResponse{ InfoResponse: &extensions_ssh.InfoResponse{ Method: "publickey", Response: marshalAny(&extensions_ssh.KeyboardInteractiveInfoPromptResponses{ Responses: []string{"test-prompt-response"}, }), }, }, } }, "received invalid info response") } func (s *StreamHandlerSuite) TestHandleAuthRequest_KeyboardInteractive_InvalidPromptResponse() { sh := s.BeforeTestHookResult.(*ssh.StreamHandler) s.expectPrompt(sh) s.expectError(func() { sh.ReadC() <- &extensions_ssh.ClientMessage{ Message: &extensions_ssh.ClientMessage_InfoResponse{ InfoResponse: &extensions_ssh.InfoResponse{ Method: "keyboard-interactive", Response: nil, }, }, } }, "received invalid prompt response") } func (s *StreamHandlerSuite) TestHandleAuthRequest_KeyboardInteractive_WrongResponseMessageType() { sh := s.BeforeTestHookResult.(*ssh.StreamHandler) s.expectPrompt(sh) s.expectError(func() { sh.ReadC() <- &extensions_ssh.ClientMessage{ Message: &extensions_ssh.ClientMessage_AuthRequest{ AuthRequest: &extensions_ssh.AuthenticationRequest{ Protocol: "ssh", Service: "ssh-connection", AuthMethod: "keyboard-interactive", Username: "test", Hostname: "host1", MethodRequest: marshalAny(&extensions_ssh.KeyboardInteractiveMethodRequest{}), }, }, } }, "received invalid message, expecting info response") } type mockGrpcServerStream struct { grpc.ServerStream ctx context.Context } func (s *mockGrpcServerStream) Context() context.Context { return s.ctx } type mockChannelStream struct { *grpc.GenericServerStream[extensions_ssh.ChannelMessage, extensions_ssh.ChannelMessage] closeServerToClientOnce sync.Once serverToClient chan *extensions_ssh.ChannelMessage closeClientToServerOnce sync.Once clientToServer chan *extensions_ssh.ChannelMessage } func newMockChannelStream(t *testing.T) *mockChannelStream { cs := &mockChannelStream{ GenericServerStream: &grpc.GenericServerStream[extensions_ssh.ChannelMessage, extensions_ssh.ChannelMessage]{ ServerStream: &mockGrpcServerStream{ ctx: t.Context(), }, }, serverToClient: make(chan *extensions_ssh.ChannelMessage, 32), clientToServer: make(chan *extensions_ssh.ChannelMessage, 32), } t.Cleanup(func() { cs.CloseClientToServer() cs.CloseServerToClient() }) return cs } func (cs *mockChannelStream) Send(msg *extensions_ssh.ChannelMessage) error { cs.serverToClient <- msg return nil } func (cs *mockChannelStream) Recv() (*extensions_ssh.ChannelMessage, error) { msg, ok := <-cs.clientToServer if !ok { return nil, io.EOF } return msg, nil } func (cs *mockChannelStream) SendClientToServer(msg *extensions_ssh.ChannelMessage) { cs.clientToServer <- msg } func (cs *mockChannelStream) CloseClientToServer() { cs.closeClientToServerOnce.Do(func() { close(cs.clientToServer) }) } func (cs *mockChannelStream) CloseServerToClient() { cs.closeServerToClientOnce.Do(func() { close(cs.serverToClient) }) } func (cs *mockChannelStream) RecvServerToClient() (*extensions_ssh.ChannelMessage, error) { select { case msg, ok := <-cs.serverToClient: if !ok { return nil, io.EOF } return msg, nil case <-time.After(DefaultTimeout): return nil, errors.New("timed out waiting for server to send message") } } var _ extensions_ssh.StreamManagement_ServeChannelServer = (*mockChannelStream)(nil) func channelMsg(input any) *extensions_ssh.ChannelMessage { return &extensions_ssh.ChannelMessage{ Message: &extensions_ssh.ChannelMessage_RawBytes{ RawBytes: wrapperspb.Bytes(gossh.Marshal(input)), }, } } func recvChannelMsg[T any](s *StreamHandlerSuite, stream *mockChannelStream) T { response, err := stream.RecvServerToClient() s.Require().NoError(err) var msg T s.Require().NoError(gossh.Unmarshal(response.GetRawBytes().GetValue(), &msg)) return msg } func sendChannelMsg(stream *mockChannelStream, msg any) { stream.SendClientToServer(&extensions_ssh.ChannelMessage{ Message: &extensions_ssh.ChannelMessage_RawBytes{ RawBytes: &wrapperspb.BytesValue{ Value: gossh.Marshal(msg), }, }, }) } func (s *StreamHandlerSuite) TestServeChannel_InitialRecvError() { sh := s.startStreamHandler(1) stream := newMockChannelStream(s.T()) stream.CloseClientToServer() s.Error(io.EOF, sh.ServeChannel(stream)) } func (s *StreamHandlerSuite) TestServeChannel_InitialRecvIsNotRawBytes() { sh := s.startStreamHandler(1) stream := newMockChannelStream(s.T()) stream.SendClientToServer(&extensions_ssh.ChannelMessage{ Message: &extensions_ssh.ChannelMessage_Metadata{}, }) s.ErrorIs(status.Errorf(codes.InvalidArgument, "first channel message was not ChannelOpen"), sh.ServeChannel(stream)) } func (s *StreamHandlerSuite) TestServeChannel_InitialRecvIsNotChannelOpen() { sh := s.startStreamHandler(1) stream := newMockChannelStream(s.T()) stream.SendClientToServer(&extensions_ssh.ChannelMessage{ Message: &extensions_ssh.ChannelMessage_RawBytes{ RawBytes: wrapperspb.Bytes([]byte("not ChannelOpen")), }, }) s.ErrorIs(status.Errorf(codes.InvalidArgument, "first channel message was not ChannelOpen"), sh.ServeChannel(stream)) } func init() { hook := func(s *StreamHandlerSuite, args []any) any { errorMatcher := args[0].(Matcher) sh := s.startStreamHandler(1) s.mockAuth.EXPECT(). HandlePublicKeyMethodRequest(Any(), Any(), Any()). Times(1). DoAndReturn(func(_ context.Context, info ssh.StreamAuthInfo, req *extensions_ssh.PublicKeyMethodRequest) (ssh.PublicKeyAuthMethodResponse, error) { s.Equal("test", *info.Username) s.Equal("", *info.Hostname) return ssh.PublicKeyAuthMethodResponse{ Allow: &extensions_ssh.PublicKeyAllowResponse{ PublicKey: req.PublicKey, Permissions: &extensions_ssh.Permissions{}, }, RequireAdditionalMethods: []string{}, }, nil }) s.False(sh.IsExpectingInternalChannel()) sh.ReadC() <- &extensions_ssh.ClientMessage{ Message: &extensions_ssh.ClientMessage_AuthRequest{ AuthRequest: &extensions_ssh.AuthenticationRequest{ Protocol: "ssh", Service: "ssh-connection", AuthMethod: "publickey", Username: "test", Hostname: "", MethodRequest: s.validPublicKeyMethodRequest(), }, }, } s.expectAllowInternal(sh) s.True(sh.IsExpectingInternalChannel()) s.Equal("test", *sh.Username()) s.Equal("", *sh.Hostname()) stream := newMockChannelStream(s.T()) errC := make(chan error, 1) go func() { errC <- sh.ServeChannel(stream) stream.CloseServerToClient() }() s.cleanup = append(s.cleanup, func() { stream.CloseClientToServer() select { case err := <-errC: s.Truef(errorMatcher.Matches(err), "expected: %v\nactual: %v", errorMatcher.String(), err) case <-time.After(DefaultTimeout): s.FailNow("timed out waiting for ServeChannel to exit") } }) return stream } StreamHandlerSuiteBeforeTestHooks["TestServeChannel_Session"] = HookWithArgs(hook, Eq(status.Errorf(codes.Canceled, "channel closed"))) StreamHandlerSuiteBeforeTestHooks["TestServeChannel_Session_DifferentWindowAndPacketSizes"] = HookWithArgs(hook, Eq(status.Errorf(codes.Canceled, "channel closed"))) StreamHandlerSuiteBeforeTestHooks["TestServeChannel_DirectTcpip_NoSubMsg"] = HookWithArgs(hook, Not(Nil())) StreamHandlerSuiteBeforeTestHooks["TestServeChannel_DirectTcpip_BadHostname"] = HookWithArgs(hook, Not(Nil())) StreamHandlerSuiteBeforeTestHooks["TestServeChannel_DirectTcpip_AuthFailed"] = HookWithArgs(hook, Eq(status.Errorf(codes.PermissionDenied, "test error"))) StreamHandlerSuiteBeforeTestHooks["TestServeChannel_DirectTcpip"] = HookWithArgs(hook, Nil()) StreamHandlerSuiteBeforeTestHooks["TestServeChannel_InvalidChannelType"] = HookWithArgs(hook, Eq(status.Errorf(codes.InvalidArgument, "unexpected channel type in ChannelOpen message: unknown"))) StreamHandlerSuiteBeforeTestHooks["TestServeChannel_Session_ExecWithPtyHelp"] = HookWithArgs(hook, Eq(status.Errorf(codes.Canceled, "channel closed"))) StreamHandlerSuiteBeforeTestHooks["TestServeChannel_Session_Exec_Whoami"] = HookWithArgs(hook, Eq(status.Errorf(codes.Canceled, "channel closed"))) StreamHandlerSuiteBeforeTestHooks["TestServeChannel_Session_Exec_WhoamiError"] = HookWithArgs(hook, Eq(status.Errorf(codes.Canceled, "channel closed"))) StreamHandlerSuiteBeforeTestHooks["TestServeChannel_Session_Exec_Logout"] = HookWithArgs(hook, Eq(status.Errorf(codes.Canceled, "channel closed"))) StreamHandlerSuiteBeforeTestHooks["TestServeChannel_Session_Exec_LogoutError"] = HookWithArgs(hook, Eq(status.Errorf(codes.Canceled, "channel closed"))) StreamHandlerSuiteBeforeTestHooks["TestServeChannel_Session_RoutesPortal_NonInteractiveError"] = HookWithArgs(hook, Eq(status.Errorf(codes.Canceled, "channel closed"))) StreamHandlerSuiteBeforeTestHooks["TestServeChannel_Session_InteractiveError"] = HookWithArgs(hook, Eq(status.Errorf(codes.Canceled, "channel closed"))) StreamHandlerSuiteBeforeTestHooks["TestServeChannel_Session_RoutesPortal"] = HookWithArgs(hook, Eq(status.Errorf(codes.Canceled, "channel closed"))) StreamHandlerSuiteBeforeTestHooks["TestServeChannel_Session_RoutesPortal_Select"] = HookWithArgs(hook, Eq(status.Errorf(codes.Canceled, "channel closed"))) StreamHandlerSuiteBeforeTestHooks["TestServeChannel_Session_ChannelCloseResponseTimeout"] = HookWithArgs(hook, Eq(status.Errorf(codes.DeadlineExceeded, "timed out waiting for channel close"))) } func (s *StreamHandlerSuite) TestServeChannel_Session() { stream := s.BeforeTestHookResult.(*mockChannelStream) stream.SendClientToServer(channelMsg(ssh.ChannelOpenMsg{ ChanType: "session", PeersID: 2, PeersWindow: ssh.ChannelWindowSize, MaxPacketSize: ssh.ChannelMaxPacket, })) resp := recvChannelMsg[ssh.ChannelOpenConfirmMsg](s, stream) s.Equal(uint32(ssh.ChannelMaxPacket), resp.MaxPacketSize) s.Equal(uint32(ssh.ChannelWindowSize), resp.MyWindow) s.Equal(uint32(2), resp.PeersID) s.Equal(uint32(1), resp.MyID) sendChannelMsg(stream, ssh.ChannelCloseMsg{resp.MyID}) // server id recvChannelMsg[ssh.ChannelCloseMsg](s, stream) } func (s *StreamHandlerSuite) TestServeChannel_Session_DifferentWindowAndPacketSizes() { stream := s.BeforeTestHookResult.(*mockChannelStream) stream.SendClientToServer(channelMsg(ssh.ChannelOpenMsg{ ChanType: "session", PeersID: 2, // client id PeersWindow: ssh.ChannelWindowSize / 2, MaxPacketSize: ssh.ChannelMaxPacket / 2, })) resp := recvChannelMsg[ssh.ChannelOpenConfirmMsg](s, stream) s.Equal(uint32(ssh.ChannelMaxPacket), resp.MaxPacketSize) s.Equal(uint32(ssh.ChannelWindowSize), resp.MyWindow) s.Equal(uint32(2), resp.PeersID) // client id s.Equal(uint32(1), resp.MyID) // server id sendChannelMsg(stream, ssh.ChannelCloseMsg{resp.MyID}) // server id recvChannelMsg[ssh.ChannelCloseMsg](s, stream) } func (s *StreamHandlerSuite) channelDataLoop(peerID uint32, stream *mockChannelStream, exitCode ...uint32) *bytes.Buffer { s.T().Helper() var channelData bytes.Buffer for { response, err := stream.RecvServerToClient() if errors.Is(err, io.EOF) { break } s.Require().NoError(err) bytes := response.GetRawBytes().GetValue() switch bytes[0] { case ssh.MsgChannelData: var msg ssh.ChannelDataMsg s.Require().NoError(gossh.Unmarshal(bytes, &msg)) channelData.Write(msg.Rest) case ssh.MsgChannelRequest: var msg ssh.ChannelRequestMsg s.Require().NoError(gossh.Unmarshal(bytes, &msg)) s.Equal("exit-status", msg.Request) s.Require().NotEmpty(exitCode, "received an exit-status ChannelRequest but the test did not assert an exit code") expected := exitCode[0] actual := binary.BigEndian.Uint32(msg.RequestSpecificData) s.Equal(expected, actual) case ssh.MsgChannelClose: sendChannelMsg(stream, ssh.ChannelCloseMsg{PeersID: peerID}) } } return &channelData } func (s *StreamHandlerSuite) TestServeChannel_Session_ExecWithPtyHelp() { stream := s.BeforeTestHookResult.(*mockChannelStream) stream.SendClientToServer(channelMsg(ssh.ChannelOpenMsg{ ChanType: "session", PeersID: 2, PeersWindow: ssh.ChannelWindowSize, MaxPacketSize: ssh.ChannelMaxPacket, })) resp := recvChannelMsg[ssh.ChannelOpenConfirmMsg](s, stream) peerID := resp.MyID stream.SendClientToServer(channelMsg(ssh.ChannelRequestMsg{ PeersID: peerID, Request: "pty-req", WantReply: true, RequestSpecificData: gossh.Marshal(ssh.PtyReqChannelRequestMsg{}), })) recvChannelMsg[ssh.ChannelRequestSuccessMsg](s, stream) stream.SendClientToServer(channelMsg(ssh.ChannelRequestMsg{ PeersID: peerID, Request: "exec", WantReply: true, RequestSpecificData: gossh.Marshal(ssh.ExecChannelRequestMsg{ Command: "--help", }), })) recvChannelMsg[ssh.ChannelRequestSuccessMsg](s, stream) maybeRoutesPortalCmd := "" if s.cfg.Options.IsRuntimeFlagSet(config.RuntimeFlagSSHRoutesPortal) { maybeRoutesPortalCmd = ` portal Interactive route portal ` } channelData := s.channelDataLoop(peerID, stream, 0) s.Equal(` Usage: pomerium [command] Available Commands: help Help about any command logout Log out `[1:]+maybeRoutesPortalCmd+ ` whoami Show details for the current session Flags: -h, --help help for pomerium Use "pomerium [command] --help" for more information about a command. `, channelData.String()) } func (s *StreamHandlerSuite) TestServeChannel_Session_ChannelCloseResponseTimeout() { stream := s.BeforeTestHookResult.(*mockChannelStream) stream.SendClientToServer(channelMsg(ssh.ChannelOpenMsg{ ChanType: "session", PeersID: 2, PeersWindow: ssh.ChannelWindowSize, MaxPacketSize: ssh.ChannelMaxPacket, })) resp := recvChannelMsg[ssh.ChannelOpenConfirmMsg](s, stream) peerID := resp.MyID stream.SendClientToServer(channelMsg(ssh.ChannelRequestMsg{ PeersID: peerID, Request: "pty-req", WantReply: true, RequestSpecificData: gossh.Marshal(ssh.PtyReqChannelRequestMsg{}), })) recvChannelMsg[ssh.ChannelRequestSuccessMsg](s, stream) stream.SendClientToServer(channelMsg(ssh.ChannelRequestMsg{ PeersID: peerID, Request: "exec", WantReply: true, RequestSpecificData: gossh.Marshal(ssh.ExecChannelRequestMsg{ Command: "--help", }), })) recvChannelMsg[ssh.ChannelRequestSuccessMsg](s, stream) for { response, err := stream.RecvServerToClient() if errors.Is(err, io.EOF) { break } s.Require().NoError(err) bytes := response.GetRawBytes().GetValue() switch bytes[0] { case ssh.MsgChannelData: var msg ssh.ChannelDataMsg s.Require().NoError(gossh.Unmarshal(bytes, &msg)) case ssh.MsgChannelClose: // don't send a response } } } func (s *StreamHandlerSuite) TestServeChannel_Session_RoutesPortal_NonInteractiveError() { stream := s.BeforeTestHookResult.(*mockChannelStream) stream.SendClientToServer(channelMsg(ssh.ChannelOpenMsg{ ChanType: "session", PeersID: 2, PeersWindow: ssh.ChannelWindowSize, MaxPacketSize: ssh.ChannelMaxPacket, })) resp := recvChannelMsg[ssh.ChannelOpenConfirmMsg](s, stream) peerID := resp.MyID stream.SendClientToServer(channelMsg(ssh.ChannelRequestMsg{ PeersID: peerID, Request: "shell", WantReply: true, })) recvChannelMsg[ssh.ChannelRequestSuccessMsg](s, stream) if s.cfg.Options.IsRuntimeFlagSet(config.RuntimeFlagSSHRoutesPortal) { channelData := s.channelDataLoop(peerID, stream, 1) s.Equal("Error: 'portal' is an interactive command and requires a TTY (try passing '-t' to ssh)\n", ansi.Strip(channelData.String())) } else { channelData := s.channelDataLoop(peerID, stream, 0) s.Equal(` Usage: pomerium [command] Available Commands: help Help about any command logout Log out whoami Show details for the current session Flags: -h, --help help for pomerium Use "pomerium [command] --help" for more information about a command. `[1:], channelData.String()) } } func (s *StreamHandlerSuite) TestServeChannel_Session_InteractiveError() { stream := s.BeforeTestHookResult.(*mockChannelStream) stream.SendClientToServer(channelMsg(ssh.ChannelOpenMsg{ ChanType: "session", PeersID: 2, PeersWindow: ssh.ChannelWindowSize, MaxPacketSize: ssh.ChannelMaxPacket, })) resp := recvChannelMsg[ssh.ChannelOpenConfirmMsg](s, stream) peerID := resp.MyID stream.SendClientToServer(channelMsg(ssh.ChannelRequestMsg{ PeersID: peerID, Request: "pty-req", WantReply: true, RequestSpecificData: gossh.Marshal(ssh.PtyReqChannelRequestMsg{}), })) recvChannelMsg[ssh.ChannelRequestSuccessMsg](s, stream) stream.SendClientToServer(channelMsg(ssh.ChannelRequestMsg{ PeersID: peerID, Request: "exec", WantReply: true, RequestSpecificData: gossh.Marshal(ssh.ExecChannelRequestMsg{ Command: "whoami", }), })) recvChannelMsg[ssh.ChannelRequestSuccessMsg](s, stream) channelData := s.channelDataLoop(peerID, stream, 1) s.Equal("Error: 'whoami' is not an interactive command (try passing '-T' to ssh, or removing '-t')\r\n", ansi.Strip(channelData.String())) } func printFrame(in string) string { re := strings.NewReplacer(" ", "Ā·", "\t", "šŸ”’", "\n", "\n⤶", "\r", "⇤") return re.Replace(ansi.Strip(in)) } func postProcessFrame(in string) string { return strings.ReplaceAll(ansi.Strip(in), "\r", "") } type routesPortalTestHookOutput struct { stream *mockChannelStream peerID uint32 } func init() { hook := func(s *StreamHandlerSuite) any { stream := s.BeforeTestHookResult.(*mockChannelStream) stream.SendClientToServer(channelMsg(ssh.ChannelOpenMsg{ ChanType: "session", PeersID: 2, PeersWindow: ssh.ChannelWindowSize, MaxPacketSize: ssh.ChannelMaxPacket, })) resp := recvChannelMsg[ssh.ChannelOpenConfirmMsg](s, stream) peerID := resp.MyID stream.SendClientToServer(channelMsg(ssh.ChannelRequestMsg{ PeersID: peerID, Request: "pty-req", WantReply: true, RequestSpecificData: gossh.Marshal(ssh.PtyReqChannelRequestMsg{ TermEnv: "dumb", Width: 39, Height: 10, }), })) recvChannelMsg[ssh.ChannelRequestSuccessMsg](s, stream) stream.SendClientToServer(channelMsg(ssh.ChannelRequestMsg{ PeersID: peerID, Request: "shell", WantReply: true, })) recvChannelMsg[ssh.ChannelRequestSuccessMsg](s, stream) if !s.cfg.Options.IsRuntimeFlagSet(config.RuntimeFlagSSHRoutesPortal) { channelData := s.channelDataLoop(peerID, stream, 0) s.Equal(` Usage: pomerium [command] Available Commands: help Help about any command logout Log out whoami Show details for the current session Flags: -h, --help help for pomerium Use "pomerium [command] --help" for more information about a command. `[1:], channelData.String()) return nil } return &routesPortalTestHookOutput{ stream, peerID, } } StreamHandlerSuiteBeforeTestHooks["TestServeChannel_Session_RoutesPortal"] = append(StreamHandlerSuiteBeforeTestHooks["TestServeChannel_Session_RoutesPortal"], hook) StreamHandlerSuiteBeforeTestHooks["TestServeChannel_Session_RoutesPortal_Select"] = append(StreamHandlerSuiteBeforeTestHooks["TestServeChannel_Session_RoutesPortal_Select"], hook) } func (s *StreamHandlerSuite) TestServeChannel_Session_RoutesPortal() { res, _ := s.BeforeTestHookResult.(*routesPortalTestHookOutput) if res == nil { return // routes portal disabled } stream, peerID := res.stream, res.peerID frames := []string{ ` || | Connect to which server? | | | | > 1. test@host1 | | 2. test@host2 | | | | | | ↑/k up • ↓/j down • q quit • ? more| | |`[1:], ` || || || | 1. test@host1 | | > 2. test@host2 | || || || ||`[1:], } for i, frame := range frames { frames[i] = strings.ReplaceAll(frame, "|", "") } var ok bool var channelData bytes.Buffer currentFrame := 0 start := time.Now() frameAdvance := func() { switch currentFrame { case 0: cursorDown := []byte(ansi.CursorDown(1)) currentFrame++ sendChannelMsg(stream, ssh.ChannelDataMsg{ PeersID: peerID, Length: uint32(len(cursorDown)), Rest: cursorDown, }) case 1: currentFrame++ ok = true sendChannelMsg(stream, ssh.ChannelDataMsg{ PeersID: peerID, Length: uint32(1), Rest: []byte("q"), }) } channelData.Reset() } LOOP: for time.Since(start) < DefaultTimeout { response, err := stream.RecvServerToClient() if err != nil { s.Fail(err.Error()) break } bytes := response.GetRawBytes().GetValue() switch bytes[0] { case ssh.MsgChannelData: if ok { continue } var msg ssh.ChannelDataMsg s.Require().NoError(gossh.Unmarshal(bytes, &msg)) channelData.Write(msg.Rest) if postProcessFrame(channelData.String()) == frames[currentFrame] { frameAdvance() if currentFrame >= len(frames) { ok = true } } case ssh.MsgChannelRequest: // the only channel request we expect to send would be "exit-status" var msg ssh.ChannelRequestMsg s.Require().NoError(gossh.Unmarshal(bytes, &msg)) s.Equal("exit-status", msg.Request) s.Equal(uint32(0), binary.BigEndian.Uint32(msg.RequestSpecificData)) case ssh.MsgChannelClose: sendChannelMsg(stream, ssh.ChannelCloseMsg{PeersID: peerID}) break LOOP default: s.FailNow("test bug") } } currentFrameStr := "" if !ok { currentFrameStr = printFrame(frames[currentFrame]) } s.Require().Truef(ok, "timed out waiting for frame %d\nbuffer:\n%s\nexpecting:\n%s", currentFrame, printFrame(postProcessFrame(channelData.String())), currentFrameStr) } func (s *StreamHandlerSuite) TestServeChannel_Session_RoutesPortal_Select() { res, _ := s.BeforeTestHookResult.(*routesPortalTestHookOutput) if res == nil { return // routes portal disabled } stream, peerID := res.stream, res.peerID frames := []string{ ` || | Connect to which server? | | | | > 1. test@host1 | | 2. test@host2 | | | | | | ↑/k up • ↓/j down • q quit • ? more| | |`[1:], ` || || || | 1. test@host1 | | > 2. test@host2 | || || || ||`[1:], ` || | Connect to which server? | | | | 1. test@host1 | | > 2. test@host2 | | | | | | ↑/k up • ↓/j down • q quit …| | |`[1:], } for i, frame := range frames { frames[i] = strings.ReplaceAll(frame, "|", "") } var portalOk bool var handoffOk bool var expectHandoff bool var channelData bytes.Buffer currentFrame := 0 start := time.Now() frameAdvance := func() { switch currentFrame { case 0: cursorDown := []byte(ansi.CursorDown(1)) currentFrame++ sendChannelMsg(stream, ssh.ChannelDataMsg{ PeersID: peerID, Length: uint32(len(cursorDown)), Rest: cursorDown, }) case 1: currentFrame++ sendChannelMsg(stream, ssh.ChannelRequestMsg{ PeersID: peerID, Request: "window-change", WantReply: false, RequestSpecificData: gossh.Marshal(ssh.ChannelWindowChangeRequestMsg{ WidthColumns: 36, HeightRows: 10, }), }) case 2: currentFrame++ s.mockAuth.EXPECT().EvaluateDelayed(Any(), Any()). DoAndReturn(func(_ context.Context, info ssh.StreamAuthInfo) error { s.Equal(info.Username, ptr("test")) s.Equal(info.Hostname, ptr("host2")) return nil }) expectHandoff = true sendChannelMsg(stream, ssh.ChannelDataMsg{ PeersID: peerID, Length: uint32(1), Rest: []byte("\r"), }) } channelData.Reset() } LOOP: for time.Since(start) < DefaultTimeout { response, err := stream.RecvServerToClient() if err != nil { s.Fail(err.Error()) break } if expectHandoff { if response.GetRawBytes() != nil { // we might get bytes containing a newline var msg ssh.ChannelDataMsg s.Require().NoError(gossh.Unmarshal(response.GetRawBytes().GetValue(), &msg)) s.Require().Empty(strings.TrimSpace(ansi.Strip(string(msg.Rest)))) continue } action := response.GetChannelControl().GetControlAction() s.Require().NotNil(action, "expected channel control action") var sshAction extensions_ssh.SSHChannelControlAction s.Require().NoError(action.UnmarshalTo(&sshAction)) handoff := sshAction.GetHandOff() s.Require().NotNil(action, "expected handoff action") s.Require().NotNil(handoff.GetUpstreamAuth().GetUpstream(), "expected upstream handoff action") s.Equal("test", handoff.GetUpstreamAuth().Username) s.Equal("host2", handoff.GetUpstreamAuth().GetUpstream().Hostname) testutil.AssertProtoEqual(s.T(), []*extensions_ssh.AllowedMethod{ { Method: "publickey", MethodData: marshalAny(&extensions_ssh.PublicKeyAllowResponse{ PublicKey: s.ed25519SshPublicKey.Marshal(), Permissions: &extensions_ssh.Permissions{}, }), }, }, handoff.GetUpstreamAuth().GetUpstream().AllowedMethods) handoffOk = true break LOOP } bytes := response.GetRawBytes().GetValue() s.Require().NotNil(bytes, response.String()) switch bytes[0] { case ssh.MsgChannelData: if portalOk { continue } s.Require().False(expectHandoff) var msg ssh.ChannelDataMsg s.Require().NoError(gossh.Unmarshal(bytes, &msg)) channelData.Write(msg.Rest) if postProcessFrame(channelData.String()) == frames[currentFrame] { frameAdvance() if currentFrame >= len(frames) { portalOk = true } } default: s.FailNow("test bug") } } currentFrameStr := "" if !portalOk { currentFrameStr = printFrame(frames[currentFrame]) } s.Truef(portalOk, "timed out waiting for frame %d\nbuffer:\n%s\nexpecting:\n%s", currentFrame, printFrame(postProcessFrame(channelData.String())), currentFrameStr) s.True(handoffOk, "timed out waiting for handoff") sendChannelMsg(stream, ssh.ChannelCloseMsg{PeersID: peerID}) } func (s *StreamHandlerSuite) TestServeChannel_Session_Exec_Whoami() { stream := s.BeforeTestHookResult.(*mockChannelStream) stream.SendClientToServer(channelMsg(ssh.ChannelOpenMsg{ ChanType: "session", PeersID: 2, PeersWindow: ssh.ChannelWindowSize, MaxPacketSize: ssh.ChannelMaxPacket, })) resp := recvChannelMsg[ssh.ChannelOpenConfirmMsg](s, stream) peerID := resp.MyID s.mockAuth.EXPECT(). FormatSession(Any(), Any()). Return([]byte("example"), nil) stream.SendClientToServer(channelMsg(ssh.ChannelRequestMsg{ PeersID: peerID, Request: "exec", WantReply: true, RequestSpecificData: gossh.Marshal(ssh.ExecChannelRequestMsg{ Command: "whoami", }), })) recvChannelMsg[ssh.ChannelRequestSuccessMsg](s, stream) channelData := s.channelDataLoop(peerID, stream, 0) s.Equal("example", channelData.String()) } func (s *StreamHandlerSuite) TestServeChannel_Session_Exec_WhoamiError() { stream := s.BeforeTestHookResult.(*mockChannelStream) stream.SendClientToServer(channelMsg(ssh.ChannelOpenMsg{ ChanType: "session", PeersID: 2, PeersWindow: ssh.ChannelWindowSize, MaxPacketSize: ssh.ChannelMaxPacket, })) resp := recvChannelMsg[ssh.ChannelOpenConfirmMsg](s, stream) peerID := resp.MyID s.mockAuth.EXPECT(). FormatSession(Any(), Any()). Return(nil, errors.New("test error")) stream.SendClientToServer(channelMsg(ssh.ChannelRequestMsg{ PeersID: peerID, Request: "exec", WantReply: true, RequestSpecificData: gossh.Marshal(ssh.ExecChannelRequestMsg{ Command: "whoami", }), })) recvChannelMsg[ssh.ChannelRequestSuccessMsg](s, stream) channelData := s.channelDataLoop(peerID, stream, 1) s.Equal("Error: couldn't fetch session: test error\r\n", channelData.String()) } func (s *StreamHandlerSuite) TestServeChannel_Session_Exec_Logout() { stream := s.BeforeTestHookResult.(*mockChannelStream) stream.SendClientToServer(channelMsg(ssh.ChannelOpenMsg{ ChanType: "session", PeersID: 2, PeersWindow: ssh.ChannelWindowSize, MaxPacketSize: ssh.ChannelMaxPacket, })) resp := recvChannelMsg[ssh.ChannelOpenConfirmMsg](s, stream) peerID := resp.MyID s.mockAuth.EXPECT(). DeleteSession(Any(), Any()). Return(nil) stream.SendClientToServer(channelMsg(ssh.ChannelRequestMsg{ PeersID: peerID, Request: "exec", WantReply: true, RequestSpecificData: gossh.Marshal(ssh.ExecChannelRequestMsg{ Command: "logout", }), })) recvChannelMsg[ssh.ChannelRequestSuccessMsg](s, stream) channelData := s.channelDataLoop(peerID, stream, 0) s.Equal("Logged out successfully\r\n", channelData.String()) } func (s *StreamHandlerSuite) TestServeChannel_Session_Exec_LogoutError() { stream := s.BeforeTestHookResult.(*mockChannelStream) stream.SendClientToServer(channelMsg(ssh.ChannelOpenMsg{ ChanType: "session", PeersID: 2, PeersWindow: ssh.ChannelWindowSize, MaxPacketSize: ssh.ChannelMaxPacket, })) resp := recvChannelMsg[ssh.ChannelOpenConfirmMsg](s, stream) peerID := resp.MyID s.mockAuth.EXPECT(). DeleteSession(Any(), Any()). Return(errors.New("test error")) stream.SendClientToServer(channelMsg(ssh.ChannelRequestMsg{ PeersID: peerID, Request: "exec", WantReply: true, RequestSpecificData: gossh.Marshal(ssh.ExecChannelRequestMsg{ Command: "logout", }), })) recvChannelMsg[ssh.ChannelRequestSuccessMsg](s, stream) channelData := s.channelDataLoop(peerID, stream, 1) s.Equal("Error: failed to delete session: test error\r\n", channelData.String()) } func (s *StreamHandlerSuite) TestServeChannel_DirectTcpip_NoSubMsg() { stream := s.BeforeTestHookResult.(*mockChannelStream) stream.SendClientToServer(channelMsg(ssh.ChannelOpenMsg{ ChanType: "direct-tcpip", PeersID: 2, PeersWindow: ssh.ChannelWindowSize, MaxPacketSize: ssh.ChannelMaxPacket, })) // error checked in cleanup } func (s *StreamHandlerSuite) TestServeChannel_DirectTcpip_BadHostname() { stream := s.BeforeTestHookResult.(*mockChannelStream) stream.SendClientToServer(channelMsg(ssh.ChannelOpenMsg{ ChanType: "direct-tcpip", PeersID: 2, PeersWindow: ssh.ChannelWindowSize, MaxPacketSize: ssh.ChannelMaxPacket, TypeSpecificData: gossh.Marshal(ssh.ChannelOpenDirectMsg{ DestAddr: "", // invalid DestPort: 22, SrcAddr: "127.0.0.1", SrcPort: 12345, }), })) // error checked in cleanup } func (s *StreamHandlerSuite) TestServeChannel_DirectTcpip_AuthFailed() { s.mockAuth.EXPECT(). EvaluateDelayed(Any(), Any()). Times(1). Return(errors.New("test error")) stream := s.BeforeTestHookResult.(*mockChannelStream) stream.SendClientToServer(channelMsg(ssh.ChannelOpenMsg{ ChanType: "direct-tcpip", PeersID: 2, PeersWindow: ssh.ChannelWindowSize, MaxPacketSize: ssh.ChannelMaxPacket, TypeSpecificData: gossh.Marshal(ssh.ChannelOpenDirectMsg{ DestAddr: "host1", DestPort: 22, SrcAddr: "127.0.0.1", SrcPort: 12345, }), })) // error checked in cleanup } func (s *StreamHandlerSuite) TestServeChannel_DirectTcpip() { s.mockAuth.EXPECT(). EvaluateDelayed(Any(), Any()). Times(1). Return(nil) stream := s.BeforeTestHookResult.(*mockChannelStream) stream.SendClientToServer(channelMsg(ssh.ChannelOpenMsg{ ChanType: "direct-tcpip", PeersID: 2, PeersWindow: ssh.ChannelWindowSize, MaxPacketSize: ssh.ChannelMaxPacket, TypeSpecificData: gossh.Marshal(ssh.ChannelOpenDirectMsg{ DestAddr: "host1", // i.e. 'ssh -J pomerium test@host1' DestPort: 22, // this will be sent by the ssh client, but is ignored SrcAddr: "127.0.0.1", SrcPort: 12345, }), })) recv, err := stream.RecvServerToClient() s.Require().NoError(err) action := recv.GetChannelControl().GetControlAction() s.Require().NotNil(action, "received a message, but it was not a channel control action") handoff := extensions_ssh.SSHChannelControlAction{} s.Require().NoError(action.UnmarshalTo(&handoff)) testutil.AssertProtoEqual(s.T(), extensions_ssh.SSHChannelControlAction_HandOffUpstream{ DownstreamChannelInfo: &extensions_ssh.SSHDownstreamChannelInfo{ ChannelType: "direct-tcpip", DownstreamChannelId: 2, InternalUpstreamChannelId: 1, InitialWindowSize: ssh.ChannelWindowSize, MaxPacketSize: ssh.ChannelMaxPacket, }, DownstreamPtyInfo: nil, UpstreamAuth: &extensions_ssh.AllowResponse{ Username: "test", Target: &extensions_ssh.AllowResponse_Upstream{ Upstream: &extensions_ssh.UpstreamTarget{ Hostname: "host1", DirectTcpip: true, AllowedMethods: []*extensions_ssh.AllowedMethod{ { Method: "publickey", MethodData: marshalAny(&extensions_ssh.PublicKeyAllowResponse{ PublicKey: s.ed25519SshPublicKey.Marshal(), Permissions: &extensions_ssh.Permissions{}, }), }, }, }, }, }, }, handoff.GetHandOff()) } func (s *StreamHandlerSuite) TestServeChannel_InvalidChannelType() { stream := s.BeforeTestHookResult.(*mockChannelStream) stream.SendClientToServer(channelMsg(ssh.ChannelOpenMsg{ ChanType: "unknown", PeersID: 2, PeersWindow: ssh.ChannelWindowSize, MaxPacketSize: ssh.ChannelMaxPacket, })) // error checked in cleanup } func (s *StreamHandlerSuite) TestFormatSession() { s.mockAuth.EXPECT(). FormatSession(Any(), Any()). Return([]byte("example"), nil) sh := s.mgr.NewStreamHandler( s.cfg, s.mockAuth, &extensions_ssh.DownstreamConnectEvent{StreamId: 1}) ctx, ca := context.WithCancel(context.Background()) ca() // this will exit immediately, but it will have a state, which is only // created upon calling Run() sh.Run(ctx) res, err := sh.FormatSession(s.T().Context()) s.NoError(err) s.Equal([]byte("example"), res) } func (s *StreamHandlerSuite) TestDeleteSession() { s.mockAuth.EXPECT(). DeleteSession(Any(), Any()). Return(nil) sh := s.mgr.NewStreamHandler( s.cfg, s.mockAuth, &extensions_ssh.DownstreamConnectEvent{StreamId: 1}) ctx, ca := context.WithCancel(context.Background()) ca() // this will exit immediately, but it will have a state, which is only // created upon calling Run() sh.Run(ctx) err := sh.DeleteSession(s.T().Context()) s.NoError(err) } func (s *StreamHandlerSuite) TestRunCalledTwice() { sh := s.mgr.NewStreamHandler( s.cfg, s.mockAuth, &extensions_ssh.DownstreamConnectEvent{StreamId: 1}) ctx, ca := context.WithCancel(context.Background()) ca() sh.Run(ctx) s.PanicsWithValue("Run called twice", func() { sh.Run(context.Background()) }) } func (s *StreamHandlerSuite) TestAllSSHRoutes() { sh := s.mgr.NewStreamHandler( s.cfg, s.mockAuth, &extensions_ssh.DownstreamConnectEvent{StreamId: 1}) routes := slices.Collect(sh.AllSSHRoutes()) s.Len(routes, 2) s.Equal("ssh://host1", routes[0].From) s.Equal("ssh://dest1:22", routes[0].To[0].String()) s.Equal("ssh://host2", routes[1].From) s.Equal("ssh://dest2:22", routes[1].To[0].String()) next, stop := iter.Pull(sh.AllSSHRoutes()) v, ok := next() s.NotNil(v) s.True(ok) stop() v, ok = next() s.Nil(v) s.False(ok) } func TestStreamHandlerSuite(t *testing.T) { suite.Run(t, &StreamHandlerSuite{}) } func TestStreamHandlerSuiteWithRuntimeFlags(t *testing.T) { suite.Run(t, &StreamHandlerSuite{ StreamHandlerSuiteOptions: StreamHandlerSuiteOptions{ ConfigModifiers: []func(*config.Config){ func(c *config.Config) { c.Options.RuntimeFlags[config.RuntimeFlagSSHRoutesPortal] = true }, }, }, }) } func ptr[T any](t T) *T { return &t }