pomerium/pkg/ssh/stream_test.go
Kenneth Jenkins 9678e6a231
ssh: implement authorization policy evaluation (#5665)
Implement the pkg/ssh.AuthInterface. Add logic for converting from the
ssh stream state to an evaluator request, and for interpreting the
results of policy evaluation. Refactor some of the existing authorize
logic to make it easier to reuse.
2025-07-01 12:04:00 -07:00

2078 lines
65 KiB
Go

package ssh_test
import (
"bytes"
"context"
"crypto/ed25519"
"crypto/rand"
"encoding/binary"
"errors"
"fmt"
"io"
"iter"
"os"
"runtime"
"slices"
"strings"
"sync"
"testing"
"time"
"github.com/charmbracelet/x/ansi"
"github.com/stretchr/testify/suite"
. "go.uber.org/mock/gomock" //nolint
gossh "golang.org/x/crypto/ssh"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/known/anypb"
"google.golang.org/protobuf/types/known/wrapperspb"
extensions_ssh "github.com/pomerium/envoy-custom/api/extensions/filters/network/ssh"
"github.com/pomerium/pomerium/config"
"github.com/pomerium/pomerium/internal/testutil"
"github.com/pomerium/pomerium/pkg/ssh"
mock_ssh "github.com/pomerium/pomerium/pkg/ssh/mock"
)
// DefaultTimeout bounds how long the suite's helpers wait for the stream
// handler (or the mock channel stream) to produce a message or error before
// failing the test.
var DefaultTimeout = 10 * time.Second

// init stretches DefaultTimeout to one hour when a debugger is attached, so
// sitting at a breakpoint does not trip the helpers' timeouts.
func init() {
	if !isDebuggerAttached() {
		return
	}
	DefaultTimeout = time.Hour
}
func isDebuggerAttached() bool {
if runtime.GOOS == "linux" {
data, err := os.ReadFile("/proc/self/status")
if err == nil {
for line := range bytes.Lines(data) {
if bytes.HasPrefix(line, []byte("TracerPid:\t")) {
return line[11] != '0'
}
}
}
}
return false
}
// HookWithArgs adapts f, which needs extra arguments, into a one-element hook
// list suitable for StreamHandlerSuiteBeforeTestHooks. args are bound when the
// hook is registered and handed to f unchanged every time the hook runs.
func HookWithArgs(f func(s *StreamHandlerSuite, args []any) any, args ...any) []func(s *StreamHandlerSuite) any {
	bound := func(suite *StreamHandlerSuite) any {
		return f(suite, args)
	}
	return []func(s *StreamHandlerSuite) any{bound}
}
// StreamHandlerSuiteBeforeTestHooks maps a test method name to hooks that
// BeforeTest runs before that test. The value returned by the last hook is
// exposed to the test as StreamHandlerSuite.BeforeTestHookResult.
//
// StreamHandlerSuiteAfterTestHooks has the same shape, keyed by test name;
// NOTE(review): presumably consumed by an AfterTest method elsewhere in this
// file — confirm.
var (
	StreamHandlerSuiteBeforeTestHooks = make(map[string][]func(s *StreamHandlerSuite) any)
	StreamHandlerSuiteAfterTestHooks  = make(map[string][]func(s *StreamHandlerSuite) any)
)
// StreamHandlerSuiteOptions customizes a StreamHandlerSuite run.
type StreamHandlerSuiteOptions struct {
	// ConfigModifiers are applied, in order, to the suite's config at the end
	// of SetupTest, after the default routes have been installed.
	ConfigModifiers []func(*config.Config)
}
// StreamHandlerSuite is a testify suite that exercises ssh.StreamHandler
// against a gomock-generated ssh.AuthInterface. Per-test state is created in
// SetupTest and released in TearDownTest.
type StreamHandlerSuite struct {
	suite.Suite
	StreamHandlerSuiteOptions

	// ctrl backs mockAuth; TearDownTest calls Finish on it.
	ctrl *Controller
	// mgr creates the stream handlers under test.
	mgr *ssh.StreamManager
	// cfg is handed to every stream handler; SetupTest installs its routes.
	cfg *config.Config
	// cleanup holds functions that TearDownTest runs, in order, after each test.
	cleanup []func()
	// errC receives the return value of the most recently started handler's Run.
	errC chan error
	// mockAuth is the AuthInterface passed to every stream handler.
	mockAuth *mock_ssh.MockAuthInterface

	// A fresh ed25519 key pair per test, in both crypto/ed25519 and
	// x/crypto/ssh forms.
	ed25519PublicKey     ed25519.PublicKey
	ed25519PrivateKey    ed25519.PrivateKey
	ed25519SshPublicKey  gossh.PublicKey
	ed25519SshPrivateKey gossh.Signer

	// BeforeTestHookResult holds the value returned by the last BeforeTest
	// hook registered for the current test, or nil if there were none.
	BeforeTestHookResult any
}
// SetupTest runs before every test. It creates a new gomock controller and
// mock AuthInterface, a new StreamManager, an empty cleanup list and a fresh
// error channel. It generates a new ed25519 key pair (raw and SSH forms) and
// builds a default config mixing HTTPS and SSH routes, then applies any
// ConfigModifiers.
func (s *StreamHandlerSuite) SetupTest() {
	t := s.T()
	s.ctrl = NewController(t)
	s.mockAuth = mock_ssh.NewMockAuthInterface(s.ctrl)
	s.mgr = ssh.NewStreamManager()
	s.cleanup = []func(){}
	s.errC = make(chan error, 1)

	pub, priv, err := ed25519.GenerateKey(rand.Reader)
	s.Require().NoError(err)
	s.ed25519PublicKey, s.ed25519PrivateKey = pub, priv

	sshPub, err := gossh.NewPublicKey(pub)
	s.Require().NoError(err)
	s.ed25519SshPublicKey = sshPub

	signer, err := gossh.NewSignerFromKey(priv)
	s.Require().NoError(err)
	s.ed25519SshPrivateKey = signer

	route := func(from, to string) config.Policy {
		return config.Policy{From: from, To: mustParseWeightedURLs(t, to)}
	}
	s.cfg = &config.Config{Options: config.NewDefaultOptions()}
	s.cfg.Options.Policies = []config.Policy{
		route("https://from.notssh.example.com", "https://to.notssh.example.com"),
		route("ssh://host1", "ssh://dest1:22"),
		route("https://from1.notssh.example.com", "https://to1.notssh.example.com"),
		route("ssh://host2", "ssh://dest2:22"),
		route("https://from2.notssh.example.com", "https://to2.notssh.example.com"),
	}

	for _, modify := range s.ConfigModifiers {
		modify(s.cfg)
	}
}
// TearDownTest runs the registered cleanup functions in registration order,
// then finishes the gomock controller so unmet expectations are reported.
func (s *StreamHandlerSuite) TearDownTest() {
	cleanups := s.cleanup
	for i := range cleanups {
		cleanups[i]()
	}
	s.ctrl.Finish()
}
// BeforeTest implements testify's BeforeTest hook. It clears
// BeforeTestHookResult, then runs each hook registered for testName in
// StreamHandlerSuiteBeforeTestHooks, keeping the last hook's return value.
func (s *StreamHandlerSuite) BeforeTest(_, testName string) {
	s.BeforeTestHookResult = nil
	hooks := StreamHandlerSuiteBeforeTestHooks[testName]
	for i := range hooks {
		s.BeforeTestHookResult = hooks[i](s)
	}
}
//
// Helper methods
//
// marshalAny packs msg into an anypb.Any for use in test fixtures. It panics
// if packing fails.
func marshalAny(msg proto.Message) *anypb.Any {
	packed, err := anypb.New(msg)
	if err == nil {
		return packed
	}
	panic(err)
}
// expectError calls fn (typically to feed a message to the stream handler).
// It then waits up to DefaultTimeout for the handler's Run to report a result
// on s.errC, asserting that the error contains msg. It fails the test
// immediately on timeout.
func (s *StreamHandlerSuite) expectError(fn func(), msg string) {
	fn()
	select {
	case err := <-s.errC:
		s.ErrorContains(err, msg)
	case <-time.After(DefaultTimeout):
		// FailNowf's first argument is not a format string, so the message is
		// formatted explicitly to actually include msg.
		s.FailNow(fmt.Sprintf("timed out waiting for error %q", msg))
	}
}
// startStreamHandler creates a StreamHandler seeded with a
// DownstreamConnectEvent for streamID and runs it in a background goroutine.
// Run's return value is delivered on a fresh channel, which is also stored in
// s.errC for expectError.
//
// It registers a cleanup that, when TearDownTest runs:
//   - waits up to 100ms for the handler to drain its read channel and fails
//     the test if client messages remain unprocessed;
//   - cancels the handler's context and waits up to DefaultTimeout for Run to
//     return, requiring any error to be context.Canceled;
//   - closes the handler and fails the test if server messages were left
//     unread on its write channel.
func (s *StreamHandlerSuite) startStreamHandler(streamID uint64) *ssh.StreamHandler {
	sh := s.mgr.NewStreamHandler(
		s.cfg, s.mockAuth, &extensions_ssh.DownstreamConnectEvent{StreamId: streamID})

	// Hold this handler's error channel in a local. The goroutine and the
	// cleanup must not re-read s.errC later: the field is reassigned by
	// SetupTest and by any further startStreamHandler call, which would race
	// and could pair this cleanup with another handler's channel.
	errC := make(chan error, 1)
	s.errC = errC

	ctx, cancel := context.WithCancel(s.T().Context())
	go func() {
		defer close(errC)
		errC <- sh.Run(ctx)
	}()

	s.cleanup = append(s.cleanup, func() {
		// Give the handler a moment to consume anything the test queued.
		start := time.Now()
		for len(sh.ReadC()) > 0 && time.Since(start) < 100*time.Millisecond {
			runtime.Gosched()
		}
		if n := len(sh.ReadC()); n > 0 {
			s.Fail(fmt.Sprintf("read channel contains %d unhandled client messages", n))
		}

		cancel()
		var err error
		select {
		case err = <-errC:
		case <-time.After(DefaultTimeout):
			s.Fail("timed out waiting for stream handler to close")
		}
		sh.Close()
		if err != nil {
			s.Require().ErrorIs(err, context.Canceled)
		}

		// NOTE(review): ranging over WriteC relies on sh.Close() closing that
		// channel; otherwise this loop would block.
		if len(sh.WriteC()) != 0 {
			logs := []string{"write channel contains unhandled server messages:"}
			i := 0
			for msg := range sh.WriteC() {
				logs = append(logs, fmt.Sprintf("[%d]: %s", i, msg.String()))
				i++
			}
			s.Fail(strings.Join(logs, "\n"))
		}
	})
	return sh
}
// msgDownstreamConnected builds a client message carrying a
// DownstreamConnected stream event for streamID.
func (s *StreamHandlerSuite) msgDownstreamConnected(streamID uint64) *extensions_ssh.ClientMessage {
	event := &extensions_ssh.StreamEvent_DownstreamConnected{
		DownstreamConnected: &extensions_ssh.DownstreamConnectEvent{StreamId: streamID},
	}
	return &extensions_ssh.ClientMessage{
		Message: &extensions_ssh.ClientMessage_Event{
			Event: &extensions_ssh.StreamEvent{Event: event},
		},
	}
}
// msgDownstreamDisconnected builds a client message carrying a
// DownstreamDisconnected stream event with the given reason.
func (s *StreamHandlerSuite) msgDownstreamDisconnected(reason string) *extensions_ssh.ClientMessage {
	event := &extensions_ssh.StreamEvent_DownstreamDisconnected{
		DownstreamDisconnected: &extensions_ssh.DownstreamDisconnectedEvent{Reason: reason},
	}
	return &extensions_ssh.ClientMessage{
		Message: &extensions_ssh.ClientMessage_Event{
			Event: &extensions_ssh.StreamEvent{Event: event},
		},
	}
}
// msgUpstreamConnected builds a client message carrying an UpstreamConnected
// stream event for streamID.
func (s *StreamHandlerSuite) msgUpstreamConnected(streamID uint64) *extensions_ssh.ClientMessage {
	event := &extensions_ssh.StreamEvent_UpstreamConnected{
		UpstreamConnected: &extensions_ssh.UpstreamConnectEvent{StreamId: streamID},
	}
	return &extensions_ssh.ClientMessage{
		Message: &extensions_ssh.ClientMessage_Event{
			Event: &extensions_ssh.StreamEvent{Event: event},
		},
	}
}
// expectAllowUpstream waits up to DefaultTimeout for the next server message.
// It requires that message to be an auth allow response targeting an upstream
// with the given hostname. Any other message, or a timeout, fails the test
// immediately.
func (s *StreamHandlerSuite) expectAllowUpstream(sh *ssh.StreamHandler, hostname string) {
	select {
	case msg := <-sh.WriteC():
		authResp := msg.GetAuthResponse()
		if authResp == nil {
			s.FailNow("received a message, but it was not an auth response", msg.String())
			return
		}
		allow := authResp.GetAllow()
		if allow == nil {
			// Use FailNow, not FailNowf: the proto text must not be
			// interpreted as a format string (it may contain '%').
			s.FailNow("received an auth response, but it was not an allow response", authResp.String())
			return
		}
		s.Require().NotNil(allow.GetUpstream(), "received an allow response, but not to an upstream target")
		s.Require().Equal(hostname, allow.GetUpstream().GetHostname())
	case <-time.After(DefaultTimeout):
		s.FailNow("timed out waiting for upstream allow message")
	}
}
// expectDeny waits up to DefaultTimeout for the next server message. It
// requires that message to be an auth deny response whose Partial flag equals
// partial and whose Methods equal methods. Any other message, or a timeout,
// fails the test immediately.
func (s *StreamHandlerSuite) expectDeny(sh *ssh.StreamHandler, partial bool, methods []string) {
	select {
	case msg := <-sh.WriteC():
		authResp := msg.GetAuthResponse()
		if authResp == nil {
			s.FailNow("received a message, but it was not an auth response", msg.String())
			return
		}
		deny := authResp.GetDeny()
		if deny == nil {
			s.FailNow("received an auth response, but it was not a deny response", authResp.String())
			return
		}
		s.Require().Equal(partial, deny.GetPartial())
		s.Require().Equal(methods, deny.GetMethods())
	case <-time.After(DefaultTimeout):
		s.FailNow("timed out waiting for deny message")
	}
}
// expectAllowInternal waits up to DefaultTimeout for the next server message.
// It requires that message to be an auth allow response with an internal
// (non-upstream) target. Any other message, or a timeout, fails the test
// immediately.
func (s *StreamHandlerSuite) expectAllowInternal(sh *ssh.StreamHandler) {
	select {
	case <-time.After(DefaultTimeout):
		s.FailNow("timed out waiting for internal allow message")
	case msg := <-sh.WriteC():
		authResp := msg.GetAuthResponse()
		if authResp == nil {
			s.FailNow("received a message, but it was not an auth response", msg.String())
			return
		}
		allow := authResp.GetAllow()
		if allow == nil {
			s.FailNow("received an auth response, but it was not an allow response", authResp.String())
			return
		}
		s.Require().NotNil(allow.GetInternal(), "received an allow response, but not to an internal target")
	}
}
// expectPrompt waits up to DefaultTimeout for the next server message. It
// requires that message to be an auth info request with a non-nil request
// payload, i.e. a prompt sent to the client. Any other message, or a timeout,
// fails the test immediately.
func (s *StreamHandlerSuite) expectPrompt(sh *ssh.StreamHandler) {
	select {
	case <-time.After(DefaultTimeout):
		s.FailNow("timed out waiting for prompt message")
	case msg := <-sh.WriteC():
		authResp := msg.GetAuthResponse()
		if authResp == nil {
			s.FailNow("received a message, but it was not an auth response", msg.String())
			return
		}
		info := authResp.GetInfoRequest()
		if info == nil {
			s.FailNow("received an auth response, but it was not an info request", authResp.String())
			return
		}
		s.Require().NotNil(info.GetRequest(), "received a nil info request")
	}
}
// validPublicKeyMethodRequest returns a PublicKeyMethodRequest, packed in an
// Any, carrying the suite's ed25519 SSH public key in wire format plus its
// algorithm name and SHA-256 fingerprint string.
func (s *StreamHandlerSuite) validPublicKeyMethodRequest() *anypb.Any {
	key := s.ed25519SshPublicKey
	req := &extensions_ssh.PublicKeyMethodRequest{
		PublicKey:                  key.Marshal(),
		PublicKeyAlg:               key.Type(),
		PublicKeyFingerprintSha256: []byte(gossh.FingerprintSHA256(key)),
	}
	return marshalAny(req)
}
//
// Tests
//
// TestDuplicateDownstreamConnectedEvent checks that a DownstreamConnected
// event sent to a handler that was already created with one makes Run fail.
func (s *StreamHandlerSuite) TestDuplicateDownstreamConnectedEvent() {
	sh := s.startStreamHandler(1)
	duplicate := s.msgDownstreamConnected(1)
	s.expectError(func() { sh.ReadC() <- duplicate }, "received duplicate downstream connected event")
}
// TestDownstreamDisconnectedEvent sends a DownstreamDisconnected event, which
// the handler only logs; no response or error is expected.
func (s *StreamHandlerSuite) TestDownstreamDisconnectedEvent() {
	handler := s.startStreamHandler(1)
	event := s.msgDownstreamDisconnected("")
	handler.ReadC() <- event
}
// TestUpstreamConnectedEvent sends an UpstreamConnected event, which the
// handler only logs; no response or error is expected.
func (s *StreamHandlerSuite) TestUpstreamConnectedEvent() {
	handler := s.startStreamHandler(1)
	event := s.msgUpstreamConnected(1)
	handler.ReadC() <- event
}
// TestInvalidEvent checks that a stream event with no payload makes Run fail
// with "received invalid event".
func (s *StreamHandlerSuite) TestInvalidEvent() {
	sh := s.startStreamHandler(1)
	emptyEvent := &extensions_ssh.ClientMessage{
		Message: &extensions_ssh.ClientMessage_Event{
			Event: &extensions_ssh.StreamEvent{},
		},
	}
	s.expectError(func() { sh.ReadC() <- emptyEvent }, "received invalid event")
}
// TestHandleAuthRequest_InvalidProtocol checks that an auth request with
// protocol "not-ssh" makes Run fail.
func (s *StreamHandlerSuite) TestHandleAuthRequest_InvalidProtocol() {
	sh := s.startStreamHandler(1)
	req := &extensions_ssh.AuthenticationRequest{
		Protocol: "not-ssh",
		Service:  "ssh-connection",
	}
	s.expectError(func() {
		sh.ReadC() <- &extensions_ssh.ClientMessage{
			Message: &extensions_ssh.ClientMessage_AuthRequest{AuthRequest: req},
		}
	}, "invalid protocol: not-ssh")
}
// TestHandleAuthRequest_InvalidService checks that an auth request for the
// "ssh-userauth" service makes Run fail.
func (s *StreamHandlerSuite) TestHandleAuthRequest_InvalidService() {
	sh := s.startStreamHandler(1)
	req := &extensions_ssh.AuthenticationRequest{
		Protocol: "ssh",
		Service:  "ssh-userauth",
	}
	s.expectError(func() {
		sh.ReadC() <- &extensions_ssh.ClientMessage{
			Message: &extensions_ssh.ClientMessage_AuthRequest{AuthRequest: req},
		}
	}, "invalid service: ssh-userauth")
}
// TestHandleAuthRequest_InvalidMessage checks that a client message with no
// payload makes Run fail with "received invalid message".
func (s *StreamHandlerSuite) TestHandleAuthRequest_InvalidMessage() {
	sh := s.startStreamHandler(1)
	empty := &extensions_ssh.ClientMessage{}
	s.expectError(func() { sh.ReadC() <- empty }, "received invalid message")
}
// TestHandleAuthRequest_FirstRequestIsKeyboardInteractive checks that the
// handler rejects keyboard-interactive as the first auth method (the first
// request should be publickey).
func (s *StreamHandlerSuite) TestHandleAuthRequest_FirstRequestIsKeyboardInteractive() {
	sh := s.startStreamHandler(1)
	req := &extensions_ssh.AuthenticationRequest{
		Protocol:   "ssh",
		Service:    "ssh-connection",
		AuthMethod: "keyboard-interactive",
	}
	s.expectError(func() {
		sh.ReadC() <- &extensions_ssh.ClientMessage{
			Message: &extensions_ssh.ClientMessage_AuthRequest{AuthRequest: req},
		}
	}, "unexpected auth method: keyboard-interactive")
}
// TestHandleAuthRequest_MissingUsername checks that a publickey request with
// an empty username makes Run fail with "username missing".
func (s *StreamHandlerSuite) TestHandleAuthRequest_MissingUsername() {
	sh := s.startStreamHandler(1)
	req := &extensions_ssh.AuthenticationRequest{
		Protocol:   "ssh",
		Service:    "ssh-connection",
		AuthMethod: "publickey",
		Username:   "",
	}
	s.expectError(func() {
		sh.ReadC() <- &extensions_ssh.ClientMessage{
			Message: &extensions_ssh.ClientMessage_AuthRequest{AuthRequest: req},
		}
	}, "username missing")
}
// TestHandleAuthRequest_EmptyHostname checks that an initial publickey request
// with no hostname is accepted. When the auth backend allows it, the handler
// responds with an allow to an internal target.
func (s *StreamHandlerSuite) TestHandleAuthRequest_EmptyHostname() {
	sh := s.startStreamHandler(1)
	allow := &extensions_ssh.PublicKeyAllowResponse{
		PublicKey:   s.ed25519SshPublicKey.Marshal(),
		Permissions: &extensions_ssh.Permissions{},
	}
	s.mockAuth.EXPECT().
		HandlePublicKeyMethodRequest(Any(), Any(), Any()).
		Return(ssh.PublicKeyAuthMethodResponse{Allow: allow}, nil)

	req := &extensions_ssh.AuthenticationRequest{
		Protocol:      "ssh",
		Service:       "ssh-connection",
		AuthMethod:    "publickey",
		Username:      "test",
		Hostname:      "",
		MethodRequest: s.validPublicKeyMethodRequest(),
	}
	sh.ReadC() <- &extensions_ssh.ClientMessage{
		Message: &extensions_ssh.ClientMessage_AuthRequest{AuthRequest: req},
	}
	s.expectAllowInternal(sh)
}
// TestHandleAuthRequest_MismatchedAuthMethodAndRequestType checks that a
// "publickey" request whose MethodRequest holds a
// KeyboardInteractiveMethodRequest makes Run fail.
func (s *StreamHandlerSuite) TestHandleAuthRequest_MismatchedAuthMethodAndRequestType() {
	sh := s.startStreamHandler(1)
	req := &extensions_ssh.AuthenticationRequest{
		Protocol:      "ssh",
		Service:       "ssh-connection",
		AuthMethod:    "publickey",
		Username:      "test",
		Hostname:      "",
		MethodRequest: marshalAny(&extensions_ssh.KeyboardInteractiveMethodRequest{}),
	}
	s.expectError(func() {
		sh.ReadC() <- &extensions_ssh.ClientMessage{
			Message: &extensions_ssh.ClientMessage_AuthRequest{AuthRequest: req},
		}
	}, "invalid public key method request type")
}
// TestHandleAuthRequest_ValidPublicKeyMethodRequest checks that a publickey
// request for "host1" that the auth backend allows results in an allow
// response targeting the "host1" upstream.
func (s *StreamHandlerSuite) TestHandleAuthRequest_ValidPublicKeyMethodRequest() {
	sh := s.startStreamHandler(1)
	allow := &extensions_ssh.PublicKeyAllowResponse{
		PublicKey:   s.ed25519SshPublicKey.Marshal(),
		Permissions: &extensions_ssh.Permissions{},
	}
	s.mockAuth.EXPECT().
		HandlePublicKeyMethodRequest(Any(), Any(), Any()).
		Return(ssh.PublicKeyAuthMethodResponse{Allow: allow}, nil)

	req := &extensions_ssh.AuthenticationRequest{
		Protocol:      "ssh",
		Service:       "ssh-connection",
		AuthMethod:    "publickey",
		Username:      "test",
		Hostname:      "host1",
		MethodRequest: s.validPublicKeyMethodRequest(),
	}
	sh.ReadC() <- &extensions_ssh.ClientMessage{
		Message: &extensions_ssh.ClientMessage_AuthRequest{AuthRequest: req},
	}
	s.expectAllowUpstream(sh, "host1")
}
// TestHandleAuthRequest_ValidPublicKeyMethodRequestError checks that an error
// returned by the auth backend for a publickey request is propagated out of
// Run.
func (s *StreamHandlerSuite) TestHandleAuthRequest_ValidPublicKeyMethodRequestError() {
	sh := s.startStreamHandler(1)
	backendErr := errors.New("test error")
	s.mockAuth.EXPECT().
		HandlePublicKeyMethodRequest(Any(), Any(), Any()).
		Return(ssh.PublicKeyAuthMethodResponse{}, backendErr)

	s.expectError(func() {
		req := &extensions_ssh.AuthenticationRequest{
			Protocol:      "ssh",
			Service:       "ssh-connection",
			AuthMethod:    "publickey",
			Username:      "test",
			Hostname:      "host1",
			MethodRequest: s.validPublicKeyMethodRequest(),
		}
		sh.ReadC() <- &extensions_ssh.ClientMessage{
			Message: &extensions_ssh.ClientMessage_AuthRequest{AuthRequest: req},
		}
	}, "test error")
}
// TestHandleAuthRequest_PublicKeyRetry checks that the handler keeps accepting
// publickey attempts while the auth backend asks for another publickey. The
// first three attempts are denied (non-partial, "publickey" still required)
// and the fourth is allowed to the "host1" upstream.
func (s *StreamHandlerSuite) TestHandleAuthRequest_PublicKeyRetry() {
	const attempts = 4
	sh := s.startStreamHandler(1)

	// calls is only touched by the mock callback. It is deliberately named
	// differently from the loop variable below to avoid shadowing.
	calls := 0
	s.mockAuth.EXPECT().
		HandlePublicKeyMethodRequest(Any(), Any(), Any()).
		Times(attempts). // every attempt must reach the backend
		DoAndReturn(func(_ context.Context, _ ssh.StreamAuthInfo, _ *extensions_ssh.PublicKeyMethodRequest) (ssh.PublicKeyAuthMethodResponse, error) {
			calls++
			switch {
			case calls < attempts:
				return ssh.PublicKeyAuthMethodResponse{
					RequireAdditionalMethods: []string{"publickey"},
				}, nil
			case calls == attempts:
				return ssh.PublicKeyAuthMethodResponse{Allow: &extensions_ssh.PublicKeyAllowResponse{
					PublicKey:   s.ed25519SshPublicKey.Marshal(),
					Permissions: &extensions_ssh.Permissions{},
				}}, nil
			default:
				panic("unreachable")
			}
		})

	for attempt := range attempts {
		sh.ReadC() <- &extensions_ssh.ClientMessage{
			Message: &extensions_ssh.ClientMessage_AuthRequest{
				AuthRequest: &extensions_ssh.AuthenticationRequest{
					Protocol:      "ssh",
					Service:       "ssh-connection",
					AuthMethod:    "publickey",
					Username:      "test",
					Hostname:      "host1",
					MethodRequest: s.validPublicKeyMethodRequest(),
				},
			},
		}
		if attempt < attempts-1 {
			s.expectDeny(sh, false, []string{"publickey"})
		} else {
			s.expectAllowUpstream(sh, "host1")
		}
	}
}
// TestHandleAuthRequest_InconsistentUsername checks that after a denied
// publickey attempt as "test"@"host1", the handler records that username and
// hostname. A follow-up request with a different username makes Run fail.
func (s *StreamHandlerSuite) TestHandleAuthRequest_InconsistentUsername() {
	sh := s.startStreamHandler(1)
	s.mockAuth.EXPECT().
		HandlePublicKeyMethodRequest(Any(), Any(), Any()).
		Times(1).
		DoAndReturn(func(_ context.Context, _ ssh.StreamAuthInfo, _ *extensions_ssh.PublicKeyMethodRequest) (ssh.PublicKeyAuthMethodResponse, error) {
			return ssh.PublicKeyAuthMethodResponse{
				RequireAdditionalMethods: []string{"publickey"},
			}, nil
		})

	publicKeyAuth := func(username string) *extensions_ssh.ClientMessage {
		return &extensions_ssh.ClientMessage{
			Message: &extensions_ssh.ClientMessage_AuthRequest{
				AuthRequest: &extensions_ssh.AuthenticationRequest{
					Protocol:      "ssh",
					Service:       "ssh-connection",
					AuthMethod:    "publickey",
					Username:      username,
					Hostname:      "host1",
					MethodRequest: s.validPublicKeyMethodRequest(),
				},
			},
		}
	}

	sh.ReadC() <- publicKeyAuth("test")
	s.expectDeny(sh, false, []string{"publickey"})
	s.Equal("test", *sh.Username())
	s.Equal("host1", *sh.Hostname())

	s.expectError(func() { sh.ReadC() <- publicKeyAuth("test2") }, "inconsistent username")
}
// TestHandleAuthRequest_InconsistentHostname checks that after a denied
// publickey attempt for "host1", a follow-up request for "host2" makes Run
// fail.
func (s *StreamHandlerSuite) TestHandleAuthRequest_InconsistentHostname() {
	sh := s.startStreamHandler(1)
	s.mockAuth.EXPECT().
		HandlePublicKeyMethodRequest(Any(), Any(), Any()).
		Times(1).
		DoAndReturn(func(_ context.Context, _ ssh.StreamAuthInfo, _ *extensions_ssh.PublicKeyMethodRequest) (ssh.PublicKeyAuthMethodResponse, error) {
			return ssh.PublicKeyAuthMethodResponse{
				RequireAdditionalMethods: []string{"publickey"},
			}, nil
		})

	publicKeyAuth := func(hostname string) *extensions_ssh.ClientMessage {
		return &extensions_ssh.ClientMessage{
			Message: &extensions_ssh.ClientMessage_AuthRequest{
				AuthRequest: &extensions_ssh.AuthenticationRequest{
					Protocol:      "ssh",
					Service:       "ssh-connection",
					AuthMethod:    "publickey",
					Username:      "test",
					Hostname:      hostname,
					MethodRequest: s.validPublicKeyMethodRequest(),
				},
			},
		}
	}

	sh.ReadC() <- publicKeyAuth("host1")
	s.expectDeny(sh, false, []string{"publickey"})
	s.expectError(func() { sh.ReadC() <- publicKeyAuth("host2") }, "inconsistent hostname")
}
// TestHandleAuthRequest_InconsistentEmptyHostname checks that a session which
// started with an empty hostname cannot switch to "host1" in a later
// keyboard-interactive request.
func (s *StreamHandlerSuite) TestHandleAuthRequest_InconsistentEmptyHostname() {
	sh := s.startStreamHandler(1)
	s.mockAuth.EXPECT().
		HandlePublicKeyMethodRequest(Any(), Any(), Any()).
		Times(1).
		DoAndReturn(func(_ context.Context, _ ssh.StreamAuthInfo, req *extensions_ssh.PublicKeyMethodRequest) (ssh.PublicKeyAuthMethodResponse, error) {
			return ssh.PublicKeyAuthMethodResponse{
				Allow: &extensions_ssh.PublicKeyAllowResponse{
					PublicKey:   req.PublicKey,
					Permissions: &extensions_ssh.Permissions{},
				},
				RequireAdditionalMethods: []string{"keyboard-interactive"},
			}, nil
		})

	wrap := func(req *extensions_ssh.AuthenticationRequest) *extensions_ssh.ClientMessage {
		return &extensions_ssh.ClientMessage{
			Message: &extensions_ssh.ClientMessage_AuthRequest{AuthRequest: req},
		}
	}

	// Publickey is accepted with no hostname, but keyboard-interactive is still required.
	sh.ReadC() <- wrap(&extensions_ssh.AuthenticationRequest{
		Protocol:      "ssh",
		Service:       "ssh-connection",
		AuthMethod:    "publickey",
		Username:      "test",
		Hostname:      "",
		MethodRequest: s.validPublicKeyMethodRequest(),
	})
	s.expectDeny(sh, true, []string{"keyboard-interactive"})

	s.expectError(func() {
		sh.ReadC() <- wrap(&extensions_ssh.AuthenticationRequest{
			Protocol:      "ssh",
			Service:       "ssh-connection",
			AuthMethod:    "keyboard-interactive",
			Username:      "test",
			Hostname:      "host1",
			MethodRequest: marshalAny(&extensions_ssh.KeyboardInteractiveMethodRequest{}),
		})
	}, "inconsistent hostname")
}
// TestHandleAuthRequest_UnknownAuthMethod checks that a "password" auth
// request makes Run fail with "unexpected auth method: password".
func (s *StreamHandlerSuite) TestHandleAuthRequest_UnknownAuthMethod() {
	sh := s.startStreamHandler(1)
	req := &extensions_ssh.AuthenticationRequest{
		Protocol:   "ssh",
		Service:    "ssh-connection",
		AuthMethod: "password",
		Username:   "test",
		Hostname:   "host1",
	}
	s.expectError(func() {
		sh.ReadC() <- &extensions_ssh.ClientMessage{
			Message: &extensions_ssh.ClientMessage_AuthRequest{AuthRequest: req},
		}
	}, "unexpected auth method: password")
}
// TestHandleAuthRequest_UnimplementedAuthMethod checks what happens when the
// auth backend asks for "password", which the handler does not support. The
// deny is forwarded to the client, and the client's subsequent password
// attempt makes Run fail with a "bug:" error.
func (s *StreamHandlerSuite) TestHandleAuthRequest_UnimplementedAuthMethod() {
	sh := s.startStreamHandler(1)
	s.mockAuth.EXPECT().
		HandlePublicKeyMethodRequest(Any(), Any(), Any()).
		Times(1).
		DoAndReturn(func(_ context.Context, _ ssh.StreamAuthInfo, _ *extensions_ssh.PublicKeyMethodRequest) (ssh.PublicKeyAuthMethodResponse, error) {
			return ssh.PublicKeyAuthMethodResponse{
				RequireAdditionalMethods: []string{"password"},
			}, nil
		})

	publicKeyReq := &extensions_ssh.AuthenticationRequest{
		Protocol:      "ssh",
		Service:       "ssh-connection",
		AuthMethod:    "publickey",
		Username:      "test",
		Hostname:      "host1",
		MethodRequest: s.validPublicKeyMethodRequest(),
	}
	sh.ReadC() <- &extensions_ssh.ClientMessage{
		Message: &extensions_ssh.ClientMessage_AuthRequest{AuthRequest: publicKeyReq},
	}
	s.expectDeny(sh, false, []string{"password"})

	passwordReq := &extensions_ssh.AuthenticationRequest{
		Protocol:   "ssh",
		Service:    "ssh-connection",
		AuthMethod: "password",
		Username:   "test",
		Hostname:   "host1",
	}
	s.expectError(func() {
		sh.ReadC() <- &extensions_ssh.ClientMessage{
			Message: &extensions_ssh.ClientMessage_AuthRequest{AuthRequest: passwordReq},
		}
	}, "bug: server requested an unsupported auth method \"password\"")
}
// TestHandleAuthRequest_WrongClientMessage checks that once the backend has
// required keyboard-interactive, another publickey request makes Run fail with
// "unexpected auth method: publickey".
func (s *StreamHandlerSuite) TestHandleAuthRequest_WrongClientMessage() {
	sh := s.startStreamHandler(1)
	s.mockAuth.EXPECT().
		HandlePublicKeyMethodRequest(Any(), Any(), Any()).
		Times(1).
		DoAndReturn(func(_ context.Context, _ ssh.StreamAuthInfo, req *extensions_ssh.PublicKeyMethodRequest) (ssh.PublicKeyAuthMethodResponse, error) {
			return ssh.PublicKeyAuthMethodResponse{
				Allow: &extensions_ssh.PublicKeyAllowResponse{
					PublicKey:   req.PublicKey,
					Permissions: &extensions_ssh.Permissions{},
				},
				RequireAdditionalMethods: []string{"keyboard-interactive"},
			}, nil
		})

	publicKeyAuth := func() *extensions_ssh.ClientMessage {
		return &extensions_ssh.ClientMessage{
			Message: &extensions_ssh.ClientMessage_AuthRequest{
				AuthRequest: &extensions_ssh.AuthenticationRequest{
					Protocol:      "ssh",
					Service:       "ssh-connection",
					AuthMethod:    "publickey",
					Username:      "test",
					Hostname:      "host1",
					MethodRequest: s.validPublicKeyMethodRequest(),
				},
			},
		}
	}

	sh.ReadC() <- publicKeyAuth()
	s.expectDeny(sh, true, []string{"keyboard-interactive"})
	s.expectError(func() { sh.ReadC() <- publicKeyAuth() }, "unexpected auth method: publickey")
}
// TestHandleAuthRequest_KeyboardInteractive_WrongMethodRequestType checks
// that a keyboard-interactive request whose MethodRequest holds a
// PublicKeyMethodRequest makes Run fail.
func (s *StreamHandlerSuite) TestHandleAuthRequest_KeyboardInteractive_WrongMethodRequestType() {
	sh := s.startStreamHandler(1)
	s.mockAuth.EXPECT().
		HandlePublicKeyMethodRequest(Any(), Any(), Any()).
		Times(1).
		DoAndReturn(func(_ context.Context, _ ssh.StreamAuthInfo, req *extensions_ssh.PublicKeyMethodRequest) (ssh.PublicKeyAuthMethodResponse, error) {
			return ssh.PublicKeyAuthMethodResponse{
				Allow: &extensions_ssh.PublicKeyAllowResponse{
					PublicKey:   req.PublicKey,
					Permissions: &extensions_ssh.Permissions{},
				},
				RequireAdditionalMethods: []string{"keyboard-interactive"},
			}, nil
		})

	authWith := func(method string) *extensions_ssh.ClientMessage {
		return &extensions_ssh.ClientMessage{
			Message: &extensions_ssh.ClientMessage_AuthRequest{
				AuthRequest: &extensions_ssh.AuthenticationRequest{
					Protocol:      "ssh",
					Service:       "ssh-connection",
					AuthMethod:    method,
					Username:      "test",
					Hostname:      "host1",
					MethodRequest: s.validPublicKeyMethodRequest(),
				},
			},
		}
	}

	sh.ReadC() <- authWith("publickey")
	s.expectDeny(sh, true, []string{"keyboard-interactive"})
	s.expectError(func() { sh.ReadC() <- authWith("keyboard-interactive") }, "invalid keyboard-interactive method request type")
}
// init registers BeforeTest hooks for the TestHandleAuthRequest_KeyboardInteractive*
// tests. Each hook starts a handler for stream 100 and drives it through:
//
//  1. a publickey attempt that is denied with "publickey" still required,
//  2. a second publickey attempt that is partially allowed with
//     "keyboard-interactive" required,
//  3. a keyboard-interactive auth request.
//
// The mocked HandleKeyboardInteractiveMethodRequest then sends one prompt
// through the querier. The hook's single argument is the error that
// querier.Prompt is expected to return (nil for the success path). The hook
// returns the handler, which each test reads from BeforeTestHookResult.
func init() {
	setupKeyboardInteractive := func(s *StreamHandlerSuite, input []any) any {
		// A nil argument fails the type assertion, leaving querierErr nil.
		querierErr, _ := input[0].(error)
		sh := s.startStreamHandler(100)
		i := -1
		s.mockAuth.EXPECT().
			HandlePublicKeyMethodRequest(Any(), Any(), Any()).
			Times(2).
			DoAndReturn(func(_ context.Context, _ ssh.StreamAuthInfo, _ *extensions_ssh.PublicKeyMethodRequest) (ssh.PublicKeyAuthMethodResponse, error) {
				i++
				switch i {
				case 0:
					return ssh.PublicKeyAuthMethodResponse{
						RequireAdditionalMethods: []string{"publickey"},
					}, nil
				case 1:
					return ssh.PublicKeyAuthMethodResponse{
						Allow: &extensions_ssh.PublicKeyAllowResponse{
							PublicKey:   s.ed25519SshPublicKey.Marshal(),
							Permissions: &extensions_ssh.Permissions{},
						},
						RequireAdditionalMethods: []string{"keyboard-interactive"},
					}, nil
				default:
					panic("unreachable")
				}
			})
		s.mockAuth.EXPECT().
			HandleKeyboardInteractiveMethodRequest(Any(), Any(), Any(), Any()).
			DoAndReturn(func(
				ctx context.Context,
				info ssh.StreamAuthInfo,
				_ *extensions_ssh.KeyboardInteractiveMethodRequest,
				querier ssh.KeyboardInteractiveQuerier,
			) (ssh.KeyboardInteractiveAuthMethodResponse, error) {
				// This callback is invoked by the stream handler, not on the
				// test goroutine: Prompt blocks until the test body sends the
				// info response. testing.T.FailNow (and therefore any Require
				// assertion) must only be called from the test goroutine, so
				// only non-fatal assertions are used here. A mismatch is
				// reported as a returned error instead.
				s.Equal("test", *info.Username)
				s.Equal("host1", *info.Hostname)
				s.Equal(uint64(100), info.StreamID)
				resp, err := querier.Prompt(ctx, &extensions_ssh.KeyboardInteractiveInfoPrompts{
					Name:        "test-name",
					Instruction: "test-instruction",
					Prompts: []*extensions_ssh.KeyboardInteractiveInfoPrompts_Prompt{
						{
							Prompt: "test-prompt",
							Echo:   true,
						},
					},
				})
				if !s.Equal(querierErr, err, "unexpected error from querier.Prompt") {
					return ssh.KeyboardInteractiveAuthMethodResponse{}, fmt.Errorf("unexpected querier.Prompt result: %v", err)
				}
				if querierErr != nil {
					return ssh.KeyboardInteractiveAuthMethodResponse{}, err
				}
				s.Equal([]string{"test-prompt-response"}, resp.Responses)
				return ssh.KeyboardInteractiveAuthMethodResponse{
					Allow: &extensions_ssh.KeyboardInteractiveAllowResponse{},
				}, nil
			})
		for range 2 {
			sh.ReadC() <- &extensions_ssh.ClientMessage{
				Message: &extensions_ssh.ClientMessage_AuthRequest{
					AuthRequest: &extensions_ssh.AuthenticationRequest{
						Protocol:      "ssh",
						Service:       "ssh-connection",
						AuthMethod:    "publickey",
						Username:      "test",
						Hostname:      "host1",
						MethodRequest: s.validPublicKeyMethodRequest(),
					},
				},
			}
		}
		s.expectDeny(sh, false, []string{"publickey"})
		s.expectDeny(sh, true, []string{"keyboard-interactive"})
		sh.ReadC() <- &extensions_ssh.ClientMessage{
			Message: &extensions_ssh.ClientMessage_AuthRequest{
				AuthRequest: &extensions_ssh.AuthenticationRequest{
					Protocol:      "ssh",
					Service:       "ssh-connection",
					AuthMethod:    "keyboard-interactive",
					Username:      "test",
					Hostname:      "host1",
					MethodRequest: marshalAny(&extensions_ssh.KeyboardInteractiveMethodRequest{}),
				},
			},
		}
		return sh
	}
	StreamHandlerSuiteBeforeTestHooks["TestHandleAuthRequest_KeyboardInteractive"] = HookWithArgs(setupKeyboardInteractive, (error)(nil))
	StreamHandlerSuiteBeforeTestHooks["TestHandleAuthRequest_KeyboardInteractive_NoPromptReply"] = HookWithArgs(setupKeyboardInteractive, context.Canceled)
	StreamHandlerSuiteBeforeTestHooks["TestHandleAuthRequest_KeyboardInteractive_InvalidInfoResponse"] = HookWithArgs(setupKeyboardInteractive, status.Errorf(codes.Internal, "received invalid info response"))
	StreamHandlerSuiteBeforeTestHooks["TestHandleAuthRequest_KeyboardInteractive_InvalidPromptResponse"] = HookWithArgs(setupKeyboardInteractive, status.Errorf(codes.InvalidArgument, "received invalid prompt response"))
	StreamHandlerSuiteBeforeTestHooks["TestHandleAuthRequest_KeyboardInteractive_WrongResponseMessageType"] = HookWithArgs(setupKeyboardInteractive, status.Errorf(codes.InvalidArgument, "received invalid message, expecting info response"))
}
// TestHandleAuthRequest_KeyboardInteractive answers the keyboard-interactive
// prompt set up by its BeforeTest hook with the expected response and checks
// that the session is allowed to the "host1" upstream.
func (s *StreamHandlerSuite) TestHandleAuthRequest_KeyboardInteractive() {
	sh := s.BeforeTestHookResult.(*ssh.StreamHandler)
	s.expectPrompt(sh)

	answers := marshalAny(&extensions_ssh.KeyboardInteractiveInfoPromptResponses{
		Responses: []string{"test-prompt-response"},
	})
	reply := &extensions_ssh.InfoResponse{
		Method:   "keyboard-interactive",
		Response: answers,
	}
	sh.ReadC() <- &extensions_ssh.ClientMessage{
		Message: &extensions_ssh.ClientMessage_InfoResponse{InfoResponse: reply},
	}
	s.expectAllowUpstream(sh, "host1")
}
// TestHandleAuthRequest_KeyboardInteractive_NoPromptReply receives the prompt
// but never answers it. Its BeforeTest hook expects querier.Prompt to end with
// context.Canceled once the handler is shut down during cleanup.
func (s *StreamHandlerSuite) TestHandleAuthRequest_KeyboardInteractive_NoPromptReply() {
	handler := s.BeforeTestHookResult.(*ssh.StreamHandler)
	s.expectPrompt(handler)
}
// TestHandleAuthRequest_KeyboardInteractive_InvalidInfoResponse answers the
// prompt with an info response whose method is "publickey" and checks that
// Run fails with "received invalid info response".
func (s *StreamHandlerSuite) TestHandleAuthRequest_KeyboardInteractive_InvalidInfoResponse() {
	sh := s.BeforeTestHookResult.(*ssh.StreamHandler)
	s.expectPrompt(sh)

	reply := &extensions_ssh.InfoResponse{
		Method: "publickey",
		Response: marshalAny(&extensions_ssh.KeyboardInteractiveInfoPromptResponses{
			Responses: []string{"test-prompt-response"},
		}),
	}
	s.expectError(func() {
		sh.ReadC() <- &extensions_ssh.ClientMessage{
			Message: &extensions_ssh.ClientMessage_InfoResponse{InfoResponse: reply},
		}
	}, "received invalid info response")
}
// TestHandleAuthRequest_KeyboardInteractive_InvalidPromptResponse answers the
// prompt with a keyboard-interactive info response that has no payload and
// checks that Run fails with "received invalid prompt response".
func (s *StreamHandlerSuite) TestHandleAuthRequest_KeyboardInteractive_InvalidPromptResponse() {
	sh := s.BeforeTestHookResult.(*ssh.StreamHandler)
	s.expectPrompt(sh)

	reply := &extensions_ssh.InfoResponse{
		Method:   "keyboard-interactive",
		Response: nil,
	}
	s.expectError(func() {
		sh.ReadC() <- &extensions_ssh.ClientMessage{
			Message: &extensions_ssh.ClientMessage_InfoResponse{InfoResponse: reply},
		}
	}, "received invalid prompt response")
}
// TestHandleAuthRequest_KeyboardInteractive_WrongResponseMessageType answers
// the prompt with a new auth request instead of an info response and checks
// that Run fails with "received invalid message, expecting info response".
func (s *StreamHandlerSuite) TestHandleAuthRequest_KeyboardInteractive_WrongResponseMessageType() {
	sh := s.BeforeTestHookResult.(*ssh.StreamHandler)
	s.expectPrompt(sh)

	req := &extensions_ssh.AuthenticationRequest{
		Protocol:      "ssh",
		Service:       "ssh-connection",
		AuthMethod:    "keyboard-interactive",
		Username:      "test",
		Hostname:      "host1",
		MethodRequest: marshalAny(&extensions_ssh.KeyboardInteractiveMethodRequest{}),
	}
	s.expectError(func() {
		sh.ReadC() <- &extensions_ssh.ClientMessage{
			Message: &extensions_ssh.ClientMessage_AuthRequest{AuthRequest: req},
		}
	}, "received invalid message, expecting info response")
}
// mockGrpcServerStream is a minimal grpc.ServerStream whose only implemented
// method is Context. All other methods go to the embedded ServerStream. In
// newMockChannelStream that field is left nil, so calling any of them there
// would panic.
type mockGrpcServerStream struct {
	grpc.ServerStream
	ctx context.Context
}

// Context returns the context the stream was constructed with.
func (m *mockGrpcServerStream) Context() context.Context { return m.ctx }
// mockChannelStream is an in-memory StreamManagement_ServeChannelServer for
// driving StreamHandler.ServeChannel:
//   - messages the server Sends are queued on serverToClient and read back
//     with RecvServerToClient;
//   - messages queued with SendClientToServer are returned by Recv.
//
// Each direction can be closed at most once.
type mockChannelStream struct {
	*grpc.GenericServerStream[extensions_ssh.ChannelMessage, extensions_ssh.ChannelMessage]

	closeServerToClientOnce sync.Once
	serverToClient          chan *extensions_ssh.ChannelMessage

	closeClientToServerOnce sync.Once
	clientToServer          chan *extensions_ssh.ChannelMessage
}

// newMockChannelStream builds a mockChannelStream whose Context is t.Context()
// and whose two queues each buffer 32 messages. It registers a t.Cleanup that
// closes both directions.
func newMockChannelStream(t *testing.T) *mockChannelStream {
	const queueDepth = 32
	base := &grpc.GenericServerStream[extensions_ssh.ChannelMessage, extensions_ssh.ChannelMessage]{
		ServerStream: &mockGrpcServerStream{ctx: t.Context()},
	}
	cs := &mockChannelStream{
		GenericServerStream: base,
		serverToClient:      make(chan *extensions_ssh.ChannelMessage, queueDepth),
		clientToServer:      make(chan *extensions_ssh.ChannelMessage, queueDepth),
	}
	t.Cleanup(func() {
		cs.CloseClientToServer()
		cs.CloseServerToClient()
	})
	return cs
}

// Send queues msg for the client side. It blocks while the queue is full and
// panics if the server-to-client direction has been closed.
func (cs *mockChannelStream) Send(msg *extensions_ssh.ChannelMessage) error {
	cs.serverToClient <- msg
	return nil
}

// Recv returns the next message queued by SendClientToServer, or io.EOF once
// the client-to-server direction is closed and drained.
func (cs *mockChannelStream) Recv() (*extensions_ssh.ChannelMessage, error) {
	if msg, ok := <-cs.clientToServer; ok {
		return msg, nil
	}
	return nil, io.EOF
}

// SendClientToServer queues msg to be returned by a later Recv.
func (cs *mockChannelStream) SendClientToServer(msg *extensions_ssh.ChannelMessage) {
	cs.clientToServer <- msg
}

// CloseClientToServer closes the client-to-server queue; repeated calls are no-ops.
func (cs *mockChannelStream) CloseClientToServer() {
	cs.closeClientToServerOnce.Do(func() { close(cs.clientToServer) })
}

// CloseServerToClient closes the server-to-client queue; repeated calls are no-ops.
func (cs *mockChannelStream) CloseServerToClient() {
	cs.closeServerToClientOnce.Do(func() { close(cs.serverToClient) })
}

// RecvServerToClient waits up to DefaultTimeout for the next message the
// server Sent. It returns io.EOF if that direction was closed, or an error on
// timeout.
func (cs *mockChannelStream) RecvServerToClient() (*extensions_ssh.ChannelMessage, error) {
	deadline := time.After(DefaultTimeout)
	select {
	case <-deadline:
		return nil, errors.New("timed out waiting for server to send message")
	case msg, ok := <-cs.serverToClient:
		if !ok {
			return nil, io.EOF
		}
		return msg, nil
	}
}

var _ extensions_ssh.StreamManagement_ServeChannelServer = (*mockChannelStream)(nil)
// channelMsg SSH-wire-encodes input with gossh.Marshal and wraps the bytes in
// a RawBytes channel message.
func channelMsg(input any) *extensions_ssh.ChannelMessage {
	payload := gossh.Marshal(input)
	raw := &extensions_ssh.ChannelMessage_RawBytes{RawBytes: wrapperspb.Bytes(payload)}
	return &extensions_ssh.ChannelMessage{Message: raw}
}
// recvChannelMsg reads the next server-to-client message from stream and
// decodes its raw bytes into a T using gossh.Unmarshal. It fails the test
// immediately if receiving or decoding fails.
func recvChannelMsg[T any](s *StreamHandlerSuite, stream *mockChannelStream) T {
	var out T
	response, err := stream.RecvServerToClient()
	s.Require().NoError(err)
	raw := response.GetRawBytes().GetValue()
	s.Require().NoError(gossh.Unmarshal(raw, &out))
	return out
}
// sendChannelMsg SSH-wire-encodes msg and queues it on stream as a
// client-to-server RawBytes channel message. It reuses channelMsg so both
// directions build messages the same way.
func sendChannelMsg(stream *mockChannelStream, msg any) {
	stream.SendClientToServer(channelMsg(msg))
}
// TestServeChannel_InitialRecvError checks that ServeChannel surfaces the
// error from the stream's first Recv. Closing the client-to-server direction
// makes that Recv return io.EOF.
func (s *StreamHandlerSuite) TestServeChannel_InitialRecvError() {
	sh := s.startStreamHandler(1)
	stream := newMockChannelStream(s.T())
	stream.CloseClientToServer()
	// Assert on ServeChannel's result. s.Error(io.EOF, ...) would only check
	// that io.EOF is itself an error.
	err := sh.ServeChannel(stream)
	s.ErrorIs(err, io.EOF)
}
// TestServeChannel_InitialRecvIsNotRawBytes sends a Metadata message as the
// first channel message and expects ServeChannel to reject it with
// InvalidArgument.
func (s *StreamHandlerSuite) TestServeChannel_InitialRecvIsNotRawBytes() {
	sh := s.startStreamHandler(1)
	stream := newMockChannelStream(s.T())
	stream.SendClientToServer(&extensions_ssh.ChannelMessage{
		Message: &extensions_ssh.ChannelMessage_Metadata{},
	})
	// ErrorIs takes the error under test first and the expected target second.
	s.ErrorIs(sh.ServeChannel(stream),
		status.Errorf(codes.InvalidArgument, "first channel message was not ChannelOpen"))
}
// TestServeChannel_InitialRecvIsNotChannelOpen sends raw bytes that are not an
// SSH ChannelOpen message first and expects ServeChannel to reject them with
// InvalidArgument.
func (s *StreamHandlerSuite) TestServeChannel_InitialRecvIsNotChannelOpen() {
	sh := s.startStreamHandler(1)
	stream := newMockChannelStream(s.T())
	stream.SendClientToServer(&extensions_ssh.ChannelMessage{
		Message: &extensions_ssh.ChannelMessage_RawBytes{
			RawBytes: wrapperspb.Bytes([]byte("not ChannelOpen")),
		},
	})
	// ErrorIs takes the error under test first and the expected target second.
	s.ErrorIs(sh.ServeChannel(stream),
		status.Errorf(codes.InvalidArgument, "first channel message was not ChannelOpen"))
}
func init() {
	// serveChannelHook authenticates the stream handler with publickey as user
	// "test" (empty hostname), starts ServeChannel on a fresh mock stream in a
	// goroutine, and returns that stream as the hook result. args[0] is a
	// gomock Matcher for the error ServeChannel returns; it is checked during
	// suite cleanup, after the client-to-server direction has been closed.
	serveChannelHook := func(s *StreamHandlerSuite, args []any) any {
		wantErr := args[0].(Matcher)
		sh := s.startStreamHandler(1)
		s.mockAuth.EXPECT().
			HandlePublicKeyMethodRequest(Any(), Any(), Any()).
			Times(1).
			DoAndReturn(func(_ context.Context, info ssh.StreamAuthInfo, req *extensions_ssh.PublicKeyMethodRequest) (ssh.PublicKeyAuthMethodResponse, error) {
				s.Equal("test", *info.Username)
				s.Equal("", *info.Hostname)
				return ssh.PublicKeyAuthMethodResponse{
					Allow: &extensions_ssh.PublicKeyAllowResponse{
						PublicKey:   req.PublicKey,
						Permissions: &extensions_ssh.Permissions{},
					},
					RequireAdditionalMethods: []string{},
				}, nil
			})

		s.False(sh.IsExpectingInternalChannel())
		sh.ReadC() <- &extensions_ssh.ClientMessage{
			Message: &extensions_ssh.ClientMessage_AuthRequest{
				AuthRequest: &extensions_ssh.AuthenticationRequest{
					Protocol:      "ssh",
					Service:       "ssh-connection",
					AuthMethod:    "publickey",
					Username:      "test",
					Hostname:      "",
					MethodRequest: s.validPublicKeyMethodRequest(),
				},
			},
		}
		s.expectAllowInternal(sh)
		s.True(sh.IsExpectingInternalChannel())
		s.Equal("test", *sh.Username())
		s.Equal("", *sh.Hostname())

		stream := newMockChannelStream(s.T())
		served := make(chan error, 1)
		go func() {
			served <- sh.ServeChannel(stream)
			stream.CloseServerToClient()
		}()
		s.cleanup = append(s.cleanup, func() {
			stream.CloseClientToServer()
			select {
			case <-time.After(DefaultTimeout):
				s.FailNow("timed out waiting for ServeChannel to exit")
			case err := <-served:
				s.Truef(wantErr.Matches(err), "expected: %v\nactual: %v", wantErr.String(), err)
			}
		})
		return stream
	}

	// channelClosed builds a fresh matcher for the error ServeChannel returns
	// after a normal channel close.
	channelClosed := func() Matcher {
		return Eq(status.Errorf(codes.Canceled, "channel closed"))
	}
	for _, tc := range []struct {
		test    string
		wantErr Matcher
	}{
		{"TestServeChannel_Session", channelClosed()},
		{"TestServeChannel_Session_DifferentWindowAndPacketSizes", channelClosed()},
		{"TestServeChannel_DirectTcpip_NoSubMsg", Not(Nil())},
		{"TestServeChannel_DirectTcpip_BadHostname", Not(Nil())},
		{"TestServeChannel_DirectTcpip_AuthFailed", Eq(status.Errorf(codes.PermissionDenied, "test error"))},
		{"TestServeChannel_DirectTcpip", Nil()},
		{"TestServeChannel_InvalidChannelType", Eq(status.Errorf(codes.InvalidArgument, "unexpected channel type in ChannelOpen message: unknown"))},
		{"TestServeChannel_Session_ExecWithPtyHelp", channelClosed()},
		{"TestServeChannel_Session_Exec_Whoami", channelClosed()},
		{"TestServeChannel_Session_Exec_WhoamiError", channelClosed()},
		{"TestServeChannel_Session_Exec_Logout", channelClosed()},
		{"TestServeChannel_Session_Exec_LogoutError", channelClosed()},
		{"TestServeChannel_Session_RoutesPortal_NonInteractiveError", channelClosed()},
		{"TestServeChannel_Session_InteractiveError", channelClosed()},
		{"TestServeChannel_Session_RoutesPortal", channelClosed()},
		{"TestServeChannel_Session_RoutesPortal_Select", channelClosed()},
		{"TestServeChannel_Session_ChannelCloseResponseTimeout", Eq(status.Errorf(codes.DeadlineExceeded, "timed out waiting for channel close"))},
	} {
		StreamHandlerSuiteBeforeTestHooks[tc.test] = HookWithArgs(serveChannelHook, tc.wantErr)
	}
}
// TestServeChannel_Session opens a session channel, checks the fields of the
// server's ChannelOpenConfirm, then sends ChannelClose and reads the server's
// ChannelClose in return.
func (s *StreamHandlerSuite) TestServeChannel_Session() {
	stream := s.BeforeTestHookResult.(*mockChannelStream)
	stream.SendClientToServer(channelMsg(ssh.ChannelOpenMsg{
		ChanType:      "session",
		PeersID:       2, // client id
		PeersWindow:   ssh.ChannelWindowSize,
		MaxPacketSize: ssh.ChannelMaxPacket,
	}))
	resp := recvChannelMsg[ssh.ChannelOpenConfirmMsg](s, stream)
	s.Equal(uint32(ssh.ChannelMaxPacket), resp.MaxPacketSize)
	s.Equal(uint32(ssh.ChannelWindowSize), resp.MyWindow)
	s.Equal(uint32(2), resp.PeersID) // client id
	s.Equal(uint32(1), resp.MyID)    // server id
	sendChannelMsg(stream, ssh.ChannelCloseMsg{PeersID: resp.MyID}) // server id
	recvChannelMsg[ssh.ChannelCloseMsg](s, stream)
}
// TestServeChannel_Session_DifferentWindowAndPacketSizes opens a session with
// half the default window and packet sizes and checks that the server's
// ChannelOpenConfirm still advertises its own default sizes. It then closes
// the channel and reads the server's ChannelClose.
func (s *StreamHandlerSuite) TestServeChannel_Session_DifferentWindowAndPacketSizes() {
	stream := s.BeforeTestHookResult.(*mockChannelStream)
	stream.SendClientToServer(channelMsg(ssh.ChannelOpenMsg{
		ChanType:      "session",
		PeersID:       2, // client id
		PeersWindow:   ssh.ChannelWindowSize / 2,
		MaxPacketSize: ssh.ChannelMaxPacket / 2,
	}))
	resp := recvChannelMsg[ssh.ChannelOpenConfirmMsg](s, stream)
	s.Equal(uint32(ssh.ChannelMaxPacket), resp.MaxPacketSize)
	s.Equal(uint32(ssh.ChannelWindowSize), resp.MyWindow)
	s.Equal(uint32(2), resp.PeersID) // client id
	s.Equal(uint32(1), resp.MyID)    // server id
	sendChannelMsg(stream, ssh.ChannelCloseMsg{PeersID: resp.MyID}) // server id
	recvChannelMsg[ssh.ChannelCloseMsg](s, stream)
}
// channelDataLoop reads server-to-client messages until the stream reports
// io.EOF and returns the concatenated payload of all ChannelData messages.
//
// Any ChannelRequest must be "exit-status". Receiving one when exitCode is
// empty fails the test; otherwise its status must equal exitCode[0]. When the
// server sends ChannelClose, a ChannelClose addressed to peerID is sent back.
func (s *StreamHandlerSuite) channelDataLoop(peerID uint32, stream *mockChannelStream, exitCode ...uint32) *bytes.Buffer {
	s.T().Helper()
	var channelData bytes.Buffer
	for {
		response, err := stream.RecvServerToClient()
		if errors.Is(err, io.EOF) {
			break
		}
		s.Require().NoError(err)
		// Named raw (not bytes) so the bytes package is not shadowed.
		raw := response.GetRawBytes().GetValue()
		s.Require().NotEmpty(raw, "expected a raw bytes message, got: %v", response)
		switch raw[0] {
		case ssh.MsgChannelData:
			var msg ssh.ChannelDataMsg
			s.Require().NoError(gossh.Unmarshal(raw, &msg))
			channelData.Write(msg.Rest)
		case ssh.MsgChannelRequest:
			var msg ssh.ChannelRequestMsg
			s.Require().NoError(gossh.Unmarshal(raw, &msg))
			s.Equal("exit-status", msg.Request)
			s.Require().NotEmpty(exitCode, "received an exit-status ChannelRequest but the test did not assert an exit code")
			// exit-status carries a uint32; guard before decoding so a short
			// payload fails the test instead of panicking.
			s.Require().GreaterOrEqual(len(msg.RequestSpecificData), 4, "malformed exit-status payload")
			s.Equal(exitCode[0], binary.BigEndian.Uint32(msg.RequestSpecificData))
		case ssh.MsgChannelClose:
			sendChannelMsg(stream, ssh.ChannelCloseMsg{PeersID: peerID})
		}
	}
	return &channelData
}
// TestServeChannel_Session_ExecWithPtyHelp requests a pty, execs "--help",
// and compares the printed usage text (exit status 0). The "portal" command
// line is expected only when the SSH routes portal runtime flag is set.
func (s *StreamHandlerSuite) TestServeChannel_Session_ExecWithPtyHelp() {
	stream := s.BeforeTestHookResult.(*mockChannelStream)
	openMsg := ssh.ChannelOpenMsg{
		ChanType:      "session",
		PeersID:       2,
		PeersWindow:   ssh.ChannelWindowSize,
		MaxPacketSize: ssh.ChannelMaxPacket,
	}
	stream.SendClientToServer(channelMsg(openMsg))
	serverID := recvChannelMsg[ssh.ChannelOpenConfirmMsg](s, stream).MyID

	ptyReq := ssh.ChannelRequestMsg{
		PeersID:             serverID,
		Request:             "pty-req",
		WantReply:           true,
		RequestSpecificData: gossh.Marshal(ssh.PtyReqChannelRequestMsg{}),
	}
	stream.SendClientToServer(channelMsg(ptyReq))
	recvChannelMsg[ssh.ChannelRequestSuccessMsg](s, stream)

	execReq := ssh.ChannelRequestMsg{
		PeersID:             serverID,
		Request:             "exec",
		WantReply:           true,
		RequestSpecificData: gossh.Marshal(ssh.ExecChannelRequestMsg{Command: "--help"}),
	}
	stream.SendClientToServer(channelMsg(execReq))
	recvChannelMsg[ssh.ChannelRequestSuccessMsg](s, stream)

	var portalLine string
	if s.cfg.Options.IsRuntimeFlagSet(config.RuntimeFlagSSHRoutesPortal) {
		portalLine = ` portal Interactive route portal
`
	}
	expected := `
Usage:
pomerium [command]
Available Commands:
help Help about any command
logout Log out
`[1:] + portalLine +
		` whoami Show details for the current session
Flags:
-h, --help help for pomerium
Use "pomerium [command] --help" for more information about a command.
`

	output := s.channelDataLoop(serverID, stream, 0)
	s.Equal(expected, output.String())
}
// TestServeChannel_Session_ChannelCloseResponseTimeout requests a pty and
// execs "--help", then drains server messages until EOF without ever replying
// to the server's ChannelClose. The before-test hook registered for this test
// expects ServeChannel to return DeadlineExceeded ("timed out waiting for
// channel close"); that check runs during cleanup.
func (s *StreamHandlerSuite) TestServeChannel_Session_ChannelCloseResponseTimeout() {
	stream := s.BeforeTestHookResult.(*mockChannelStream)
	stream.SendClientToServer(channelMsg(ssh.ChannelOpenMsg{
		ChanType:      "session",
		PeersID:       2,
		PeersWindow:   ssh.ChannelWindowSize,
		MaxPacketSize: ssh.ChannelMaxPacket,
	}))
	resp := recvChannelMsg[ssh.ChannelOpenConfirmMsg](s, stream)
	peerID := resp.MyID
	stream.SendClientToServer(channelMsg(ssh.ChannelRequestMsg{
		PeersID:             peerID,
		Request:             "pty-req",
		WantReply:           true,
		RequestSpecificData: gossh.Marshal(ssh.PtyReqChannelRequestMsg{}),
	}))
	recvChannelMsg[ssh.ChannelRequestSuccessMsg](s, stream)
	stream.SendClientToServer(channelMsg(ssh.ChannelRequestMsg{
		PeersID:   peerID,
		Request:   "exec",
		WantReply: true,
		RequestSpecificData: gossh.Marshal(ssh.ExecChannelRequestMsg{
			Command: "--help",
		}),
	}))
	recvChannelMsg[ssh.ChannelRequestSuccessMsg](s, stream)
	for {
		response, err := stream.RecvServerToClient()
		if errors.Is(err, io.EOF) {
			break
		}
		s.Require().NoError(err)
		// Named raw (not bytes) so the bytes package is not shadowed; guard
		// against non-RawBytes messages before indexing.
		raw := response.GetRawBytes().GetValue()
		s.Require().NotEmpty(raw, "expected a raw bytes message, got: %v", response)
		switch raw[0] {
		case ssh.MsgChannelData:
			var msg ssh.ChannelDataMsg
			s.Require().NoError(gossh.Unmarshal(raw, &msg))
		case ssh.MsgChannelClose:
			// Deliberately do not reply with ChannelClose.
		}
	}
}
// TestServeChannel_Session_RoutesPortal_NonInteractiveError starts a shell
// without a pty. With the routes portal runtime flag set, the portal refuses
// to run and exits with status 1. Without the flag, the usage text is printed
// and the exit status is 0.
func (s *StreamHandlerSuite) TestServeChannel_Session_RoutesPortal_NonInteractiveError() {
	stream := s.BeforeTestHookResult.(*mockChannelStream)
	openMsg := ssh.ChannelOpenMsg{
		ChanType:      "session",
		PeersID:       2,
		PeersWindow:   ssh.ChannelWindowSize,
		MaxPacketSize: ssh.ChannelMaxPacket,
	}
	stream.SendClientToServer(channelMsg(openMsg))
	serverID := recvChannelMsg[ssh.ChannelOpenConfirmMsg](s, stream).MyID

	shellReq := ssh.ChannelRequestMsg{
		PeersID:   serverID,
		Request:   "shell",
		WantReply: true,
	}
	stream.SendClientToServer(channelMsg(shellReq))
	recvChannelMsg[ssh.ChannelRequestSuccessMsg](s, stream)

	if !s.cfg.Options.IsRuntimeFlagSet(config.RuntimeFlagSSHRoutesPortal) {
		output := s.channelDataLoop(serverID, stream, 0)
		s.Equal(`
Usage:
pomerium [command]
Available Commands:
help Help about any command
logout Log out
whoami Show details for the current session
Flags:
-h, --help help for pomerium
Use "pomerium [command] --help" for more information about a command.
`[1:], output.String())
		return
	}

	output := s.channelDataLoop(serverID, stream, 1)
	s.Equal("Error: 'portal' is an interactive command and requires a TTY (try passing '-t' to ssh)\n",
		ansi.Strip(output.String()))
}
// TestServeChannel_Session_InteractiveError requests a pty and then execs
// "whoami". It expects an error telling the user that whoami is not an
// interactive command, and exit status 1.
func (s *StreamHandlerSuite) TestServeChannel_Session_InteractiveError() {
	stream := s.BeforeTestHookResult.(*mockChannelStream)
	openMsg := ssh.ChannelOpenMsg{
		ChanType:      "session",
		PeersID:       2,
		PeersWindow:   ssh.ChannelWindowSize,
		MaxPacketSize: ssh.ChannelMaxPacket,
	}
	stream.SendClientToServer(channelMsg(openMsg))
	serverID := recvChannelMsg[ssh.ChannelOpenConfirmMsg](s, stream).MyID

	ptyReq := ssh.ChannelRequestMsg{
		PeersID:             serverID,
		Request:             "pty-req",
		WantReply:           true,
		RequestSpecificData: gossh.Marshal(ssh.PtyReqChannelRequestMsg{}),
	}
	stream.SendClientToServer(channelMsg(ptyReq))
	recvChannelMsg[ssh.ChannelRequestSuccessMsg](s, stream)

	execReq := ssh.ChannelRequestMsg{
		PeersID:             serverID,
		Request:             "exec",
		WantReply:           true,
		RequestSpecificData: gossh.Marshal(ssh.ExecChannelRequestMsg{Command: "whoami"}),
	}
	stream.SendClientToServer(channelMsg(execReq))
	recvChannelMsg[ssh.ChannelRequestSuccessMsg](s, stream)

	output := s.channelDataLoop(serverID, stream, 1)
	s.Equal("Error: 'whoami' is not an interactive command (try passing '-T' to ssh, or removing '-t')\r\n",
		ansi.Strip(output.String()))
}
// printFrame makes a rendered frame's whitespace visible for failure
// messages. It strips ANSI sequences, then shows spaces as '·', tabs as '🡒'
// and carriage returns as '⇤', and follows each newline with a '⤶' marker.
func printFrame(in string) string {
	visible := strings.NewReplacer(
		" ", "·",
		"\t", "🡒",
		"\n", "\n⤶",
		"\r", "⇤",
	)
	stripped := ansi.Strip(in)
	return visible.Replace(stripped)
}
// postProcessFrame normalizes captured terminal output for comparison with
// the expected frames by removing ANSI sequences and all carriage returns.
func postProcessFrame(in string) string {
	plain := ansi.Strip(in)
	return strings.ReplaceAll(plain, "\r", "")
}
// routesPortalTestHookOutput is the result of the routes-portal before-test
// hook once a pty session with a running shell has been set up.
type routesPortalTestHookOutput struct {
	// stream is the mocked channel stream the session runs on.
	stream *mockChannelStream
	// peerID is the server's channel id (MyID from ChannelOpenConfirm), used
	// as PeersID in messages the test sends.
	peerID uint32
}
func init() {
	// openPortalShell opens a session channel with a 39x10 "dumb" pty and
	// starts a shell on it. When the routes portal runtime flag is set, it
	// returns a *routesPortalTestHookOutput for the test to drive. Otherwise
	// it checks that the usage text is printed (exit status 0) and returns nil.
	openPortalShell := func(s *StreamHandlerSuite) any {
		stream := s.BeforeTestHookResult.(*mockChannelStream)
		stream.SendClientToServer(channelMsg(ssh.ChannelOpenMsg{
			ChanType:      "session",
			PeersID:       2,
			PeersWindow:   ssh.ChannelWindowSize,
			MaxPacketSize: ssh.ChannelMaxPacket,
		}))
		serverID := recvChannelMsg[ssh.ChannelOpenConfirmMsg](s, stream).MyID

		ptyReq := ssh.ChannelRequestMsg{
			PeersID:   serverID,
			Request:   "pty-req",
			WantReply: true,
			RequestSpecificData: gossh.Marshal(ssh.PtyReqChannelRequestMsg{
				TermEnv: "dumb",
				Width:   39,
				Height:  10,
			}),
		}
		stream.SendClientToServer(channelMsg(ptyReq))
		recvChannelMsg[ssh.ChannelRequestSuccessMsg](s, stream)

		shellReq := ssh.ChannelRequestMsg{
			PeersID:   serverID,
			Request:   "shell",
			WantReply: true,
		}
		stream.SendClientToServer(channelMsg(shellReq))
		recvChannelMsg[ssh.ChannelRequestSuccessMsg](s, stream)

		if s.cfg.Options.IsRuntimeFlagSet(config.RuntimeFlagSSHRoutesPortal) {
			return &routesPortalTestHookOutput{stream: stream, peerID: serverID}
		}
		output := s.channelDataLoop(serverID, stream, 0)
		s.Equal(`
Usage:
pomerium [command]
Available Commands:
help Help about any command
logout Log out
whoami Show details for the current session
Flags:
-h, --help help for pomerium
Use "pomerium [command] --help" for more information about a command.
`[1:], output.String())
		return nil
	}

	for _, name := range []string{
		"TestServeChannel_Session_RoutesPortal",
		"TestServeChannel_Session_RoutesPortal_Select",
	} {
		StreamHandlerSuiteBeforeTestHooks[name] = append(StreamHandlerSuiteBeforeTestHooks[name], openPortalShell)
	}
}
// TestServeChannel_Session_RoutesPortal drives the interactive routes portal
// opened by the before-test hook. It waits for the route list to render,
// sends cursor-down, and waits for the selection to move to the second route.
// It then sends "q" and expects an exit-status of 0 followed by ChannelClose.
// It returns early when the hook produced no *routesPortalTestHookOutput
// (routes portal runtime flag not set).
func (s *StreamHandlerSuite) TestServeChannel_Session_RoutesPortal() {
res, _ := s.BeforeTestHookResult.(*routesPortalTestHookOutput)
if res == nil {
return // routes portal disabled
}
stream, peerID := res.stream, res.peerID
// Expected screen contents after postProcessFrame (ANSI sequences and '\r'
// removed). '|' marks the edges of each line for readability and is
// stripped by the loop below.
frames := []string{
`
||
| Connect to which server? |
| |
| > 1. test@host1 |
| 2. test@host2 |
| |
| |
| ↑/k up • ↓/j down • q quit • ? more|
| |`[1:],
`
||
||
||
| 1. test@host1 |
| > 2. test@host2 |
||
||
||
||`[1:],
}
for i, frame := range frames {
frames[i] = strings.ReplaceAll(frame, "|", "")
}
// ok is set once "q" has been sent; ChannelData received after that is ignored.
var ok bool
var channelData bytes.Buffer
currentFrame := 0
start := time.Now()
// frameAdvance runs when the buffered output matches frames[currentFrame].
// It sends the keystroke for the next step (cursor-down, then "q") and
// clears the buffer.
frameAdvance := func() {
switch currentFrame {
case 0:
cursorDown := []byte(ansi.CursorDown(1))
currentFrame++
sendChannelMsg(stream, ssh.ChannelDataMsg{
PeersID: peerID,
Length: uint32(len(cursorDown)),
Rest: cursorDown,
})
case 1:
currentFrame++
ok = true
sendChannelMsg(stream, ssh.ChannelDataMsg{
PeersID: peerID,
Length: uint32(1),
Rest: []byte("q"),
})
}
channelData.Reset()
}
LOOP:
for time.Since(start) < DefaultTimeout {
response, err := stream.RecvServerToClient()
if err != nil {
s.Fail(err.Error())
break
}
bytes := response.GetRawBytes().GetValue()
switch bytes[0] {
case ssh.MsgChannelData:
if ok {
continue
}
var msg ssh.ChannelDataMsg
s.Require().NoError(gossh.Unmarshal(bytes, &msg))
// Output arrives in arbitrary chunks; keep accumulating until the
// buffer matches the expected frame exactly.
channelData.Write(msg.Rest)
if postProcessFrame(channelData.String()) == frames[currentFrame] {
frameAdvance()
if currentFrame >= len(frames) {
ok = true
}
}
case ssh.MsgChannelRequest:
// the only channel request we expect to send would be "exit-status"
var msg ssh.ChannelRequestMsg
s.Require().NoError(gossh.Unmarshal(bytes, &msg))
s.Equal("exit-status", msg.Request)
s.Equal(uint32(0), binary.BigEndian.Uint32(msg.RequestSpecificData))
case ssh.MsgChannelClose:
// Reply with our own ChannelClose, then stop reading.
sendChannelMsg(stream, ssh.ChannelCloseMsg{PeersID: peerID})
break LOOP
default:
s.FailNow("test bug")
}
}
// On failure, report the frame we were waiting for alongside what was
// actually buffered, with whitespace made visible by printFrame.
currentFrameStr := ""
if !ok {
currentFrameStr = printFrame(frames[currentFrame])
}
s.Require().Truef(ok, "timed out waiting for frame %d\nbuffer:\n%s\nexpecting:\n%s",
currentFrame,
printFrame(postProcessFrame(channelData.String())),
currentFrameStr)
}
// TestServeChannel_Session_RoutesPortal_Select drives the routes portal opened
// by the before-test hook through three steps:
//  1. It moves the selection to the second route.
//  2. It resizes the terminal to 36 columns, which truncates the help line.
//  3. It presses enter.
//
// Selecting host2 must trigger delayed authorization for test@host2. The
// server must then send a channel control message handing the channel off
// upstream to host2 with the user's publickey credentials. The test returns
// early when the hook produced no *routesPortalTestHookOutput (routes portal
// runtime flag not set).
func (s *StreamHandlerSuite) TestServeChannel_Session_RoutesPortal_Select() {
	res, _ := s.BeforeTestHookResult.(*routesPortalTestHookOutput)
	if res == nil {
		return // routes portal disabled
	}
	stream, peerID := res.stream, res.peerID
	// Expected screen contents after postProcessFrame; '|' marks the edges of
	// each line for readability and is stripped below.
	frames := []string{
		`
||
| Connect to which server? |
| |
| > 1. test@host1 |
| 2. test@host2 |
| |
| |
| ↑/k up • ↓/j down • q quit • ? more|
| |`[1:],
		`
||
||
||
| 1. test@host1 |
| > 2. test@host2 |
||
||
||
||`[1:],
		`
||
| Connect to which server? |
| |
| 1. test@host1 |
| > 2. test@host2 |
| |
| |
| ↑/k up • ↓/j down • q quit …|
| |`[1:],
	}
	for i, frame := range frames {
		frames[i] = strings.ReplaceAll(frame, "|", "")
	}
	// portalOk: every frame matched; later ChannelData is ignored.
	var portalOk bool
	// handoffOk: the upstream hand-off control message was received and verified.
	var handoffOk bool
	// expectHandoff: enter was sent, so the next control message must be a hand-off.
	var expectHandoff bool
	var channelData bytes.Buffer
	currentFrame := 0
	start := time.Now()
	// frameAdvance runs when the buffered output matches frames[currentFrame].
	// It sends the input that should produce the next step and clears the
	// buffer.
	frameAdvance := func() {
		switch currentFrame {
		case 0:
			cursorDown := []byte(ansi.CursorDown(1))
			currentFrame++
			sendChannelMsg(stream, ssh.ChannelDataMsg{
				PeersID: peerID,
				Length:  uint32(len(cursorDown)),
				Rest:    cursorDown,
			})
		case 1:
			currentFrame++
			sendChannelMsg(stream, ssh.ChannelRequestMsg{
				PeersID:   peerID,
				Request:   "window-change",
				WantReply: false,
				RequestSpecificData: gossh.Marshal(ssh.ChannelWindowChangeRequestMsg{
					WidthColumns: 36,
					HeightRows:   10,
				}),
			})
		case 2:
			currentFrame++
			s.mockAuth.EXPECT().EvaluateDelayed(Any(), Any()).
				DoAndReturn(func(_ context.Context, info ssh.StreamAuthInfo) error {
					s.Equal(ptr("test"), info.Username)
					s.Equal(ptr("host2"), info.Hostname)
					return nil
				})
			expectHandoff = true
			sendChannelMsg(stream, ssh.ChannelDataMsg{
				PeersID: peerID,
				Length:  uint32(1),
				Rest:    []byte("\r"),
			})
		}
		channelData.Reset()
	}
LOOP:
	for time.Since(start) < DefaultTimeout {
		response, err := stream.RecvServerToClient()
		if err != nil {
			s.Fail(err.Error())
			break
		}
		if expectHandoff {
			if response.GetRawBytes() != nil {
				// we might get bytes containing a newline
				var msg ssh.ChannelDataMsg
				s.Require().NoError(gossh.Unmarshal(response.GetRawBytes().GetValue(), &msg))
				s.Require().Empty(strings.TrimSpace(ansi.Strip(string(msg.Rest))))
				continue
			}
			action := response.GetChannelControl().GetControlAction()
			s.Require().NotNil(action, "expected channel control action")
			var sshAction extensions_ssh.SSHChannelControlAction
			s.Require().NoError(action.UnmarshalTo(&sshAction))
			handoff := sshAction.GetHandOff()
			// Check the hand-off itself; action was already checked above.
			s.Require().NotNil(handoff, "expected handoff action")
			upstreamAuth := handoff.GetUpstreamAuth()
			s.Require().NotNil(upstreamAuth.GetUpstream(), "expected upstream handoff action")
			s.Equal("test", upstreamAuth.GetUsername())
			s.Equal("host2", upstreamAuth.GetUpstream().GetHostname())
			testutil.AssertProtoEqual(s.T(), []*extensions_ssh.AllowedMethod{
				{
					Method: "publickey",
					MethodData: marshalAny(&extensions_ssh.PublicKeyAllowResponse{
						PublicKey:   s.ed25519SshPublicKey.Marshal(),
						Permissions: &extensions_ssh.Permissions{},
					}),
				},
			}, upstreamAuth.GetUpstream().GetAllowedMethods())
			handoffOk = true
			break LOOP
		}
		// Named raw (not bytes) so the bytes package is not shadowed; an empty
		// payload would otherwise panic on raw[0].
		raw := response.GetRawBytes().GetValue()
		s.Require().NotEmpty(raw, response.String())
		switch raw[0] {
		case ssh.MsgChannelData:
			if portalOk {
				continue
			}
			s.Require().False(expectHandoff)
			var msg ssh.ChannelDataMsg
			s.Require().NoError(gossh.Unmarshal(raw, &msg))
			channelData.Write(msg.Rest)
			if postProcessFrame(channelData.String()) == frames[currentFrame] {
				frameAdvance()
				if currentFrame >= len(frames) {
					portalOk = true
				}
			}
		default:
			s.FailNow("test bug")
		}
	}
	currentFrameStr := ""
	if !portalOk {
		currentFrameStr = printFrame(frames[currentFrame])
	}
	s.Truef(portalOk, "timed out waiting for frame %d\nbuffer:\n%s\nexpecting:\n%s",
		currentFrame,
		printFrame(postProcessFrame(channelData.String())),
		currentFrameStr)
	s.True(handoffOk, "timed out waiting for handoff")
	sendChannelMsg(stream, ssh.ChannelCloseMsg{PeersID: peerID})
}
// TestServeChannel_Session_Exec_Whoami execs "whoami" without a pty. It
// expects the bytes returned by Auth.FormatSession to be written as the
// channel output, with exit status 0.
func (s *StreamHandlerSuite) TestServeChannel_Session_Exec_Whoami() {
	stream := s.BeforeTestHookResult.(*mockChannelStream)
	openMsg := ssh.ChannelOpenMsg{
		ChanType:      "session",
		PeersID:       2,
		PeersWindow:   ssh.ChannelWindowSize,
		MaxPacketSize: ssh.ChannelMaxPacket,
	}
	stream.SendClientToServer(channelMsg(openMsg))
	serverID := recvChannelMsg[ssh.ChannelOpenConfirmMsg](s, stream).MyID

	s.mockAuth.EXPECT().
		FormatSession(Any(), Any()).
		Return([]byte("example"), nil)

	execReq := ssh.ChannelRequestMsg{
		PeersID:             serverID,
		Request:             "exec",
		WantReply:           true,
		RequestSpecificData: gossh.Marshal(ssh.ExecChannelRequestMsg{Command: "whoami"}),
	}
	stream.SendClientToServer(channelMsg(execReq))
	recvChannelMsg[ssh.ChannelRequestSuccessMsg](s, stream)

	output := s.channelDataLoop(serverID, stream, 0)
	s.Equal("example", output.String())
}
// TestServeChannel_Session_Exec_WhoamiError execs "whoami" while
// Auth.FormatSession fails. It expects the error to be printed, with exit
// status 1.
func (s *StreamHandlerSuite) TestServeChannel_Session_Exec_WhoamiError() {
	stream := s.BeforeTestHookResult.(*mockChannelStream)
	openMsg := ssh.ChannelOpenMsg{
		ChanType:      "session",
		PeersID:       2,
		PeersWindow:   ssh.ChannelWindowSize,
		MaxPacketSize: ssh.ChannelMaxPacket,
	}
	stream.SendClientToServer(channelMsg(openMsg))
	serverID := recvChannelMsg[ssh.ChannelOpenConfirmMsg](s, stream).MyID

	s.mockAuth.EXPECT().
		FormatSession(Any(), Any()).
		Return(nil, errors.New("test error"))

	execReq := ssh.ChannelRequestMsg{
		PeersID:             serverID,
		Request:             "exec",
		WantReply:           true,
		RequestSpecificData: gossh.Marshal(ssh.ExecChannelRequestMsg{Command: "whoami"}),
	}
	stream.SendClientToServer(channelMsg(execReq))
	recvChannelMsg[ssh.ChannelRequestSuccessMsg](s, stream)

	output := s.channelDataLoop(serverID, stream, 1)
	s.Equal("Error: couldn't fetch session: test error\r\n", output.String())
}
// TestServeChannel_Session_Exec_Logout execs "logout" with
// Auth.DeleteSession succeeding. It expects a success message and exit
// status 0.
func (s *StreamHandlerSuite) TestServeChannel_Session_Exec_Logout() {
	stream := s.BeforeTestHookResult.(*mockChannelStream)
	openMsg := ssh.ChannelOpenMsg{
		ChanType:      "session",
		PeersID:       2,
		PeersWindow:   ssh.ChannelWindowSize,
		MaxPacketSize: ssh.ChannelMaxPacket,
	}
	stream.SendClientToServer(channelMsg(openMsg))
	serverID := recvChannelMsg[ssh.ChannelOpenConfirmMsg](s, stream).MyID

	s.mockAuth.EXPECT().
		DeleteSession(Any(), Any()).
		Return(nil)

	execReq := ssh.ChannelRequestMsg{
		PeersID:             serverID,
		Request:             "exec",
		WantReply:           true,
		RequestSpecificData: gossh.Marshal(ssh.ExecChannelRequestMsg{Command: "logout"}),
	}
	stream.SendClientToServer(channelMsg(execReq))
	recvChannelMsg[ssh.ChannelRequestSuccessMsg](s, stream)

	output := s.channelDataLoop(serverID, stream, 0)
	s.Equal("Logged out successfully\r\n", output.String())
}
// TestServeChannel_Session_Exec_LogoutError execs "logout" while
// Auth.DeleteSession fails. It expects the error to be printed, with exit
// status 1.
func (s *StreamHandlerSuite) TestServeChannel_Session_Exec_LogoutError() {
	stream := s.BeforeTestHookResult.(*mockChannelStream)
	openMsg := ssh.ChannelOpenMsg{
		ChanType:      "session",
		PeersID:       2,
		PeersWindow:   ssh.ChannelWindowSize,
		MaxPacketSize: ssh.ChannelMaxPacket,
	}
	stream.SendClientToServer(channelMsg(openMsg))
	serverID := recvChannelMsg[ssh.ChannelOpenConfirmMsg](s, stream).MyID

	s.mockAuth.EXPECT().
		DeleteSession(Any(), Any()).
		Return(errors.New("test error"))

	execReq := ssh.ChannelRequestMsg{
		PeersID:             serverID,
		Request:             "exec",
		WantReply:           true,
		RequestSpecificData: gossh.Marshal(ssh.ExecChannelRequestMsg{Command: "logout"}),
	}
	stream.SendClientToServer(channelMsg(execReq))
	recvChannelMsg[ssh.ChannelRequestSuccessMsg](s, stream)

	output := s.channelDataLoop(serverID, stream, 1)
	s.Equal("Error: failed to delete session: test error\r\n", output.String())
}
// TestServeChannel_DirectTcpip_NoSubMsg opens a direct-tcpip channel without
// the ChannelOpenDirectMsg payload. The before-test hook for this test
// expects ServeChannel to return a non-nil error; that check runs during
// cleanup.
func (s *StreamHandlerSuite) TestServeChannel_DirectTcpip_NoSubMsg() {
	stream := s.BeforeTestHookResult.(*mockChannelStream)
	openMsg := ssh.ChannelOpenMsg{
		ChanType:      "direct-tcpip",
		PeersID:       2,
		PeersWindow:   ssh.ChannelWindowSize,
		MaxPacketSize: ssh.ChannelMaxPacket,
	}
	stream.SendClientToServer(channelMsg(openMsg))
}
// TestServeChannel_DirectTcpip_BadHostname opens a direct-tcpip channel with
// an empty destination address. The before-test hook for this test expects
// ServeChannel to return a non-nil error; that check runs during cleanup.
func (s *StreamHandlerSuite) TestServeChannel_DirectTcpip_BadHostname() {
	stream := s.BeforeTestHookResult.(*mockChannelStream)
	direct := ssh.ChannelOpenDirectMsg{
		DestAddr: "", // invalid
		DestPort: 22,
		SrcAddr:  "127.0.0.1",
		SrcPort:  12345,
	}
	openMsg := ssh.ChannelOpenMsg{
		ChanType:         "direct-tcpip",
		PeersID:          2,
		PeersWindow:      ssh.ChannelWindowSize,
		MaxPacketSize:    ssh.ChannelMaxPacket,
		TypeSpecificData: gossh.Marshal(direct),
	}
	stream.SendClientToServer(channelMsg(openMsg))
}
// TestServeChannel_DirectTcpip_AuthFailed opens a direct-tcpip channel to
// host1 while delayed authorization fails. The before-test hook for this test
// expects ServeChannel to return PermissionDenied("test error"); that check
// runs during cleanup.
func (s *StreamHandlerSuite) TestServeChannel_DirectTcpip_AuthFailed() {
	s.mockAuth.EXPECT().
		EvaluateDelayed(Any(), Any()).
		Times(1).
		Return(errors.New("test error"))

	stream := s.BeforeTestHookResult.(*mockChannelStream)
	direct := ssh.ChannelOpenDirectMsg{
		DestAddr: "host1",
		DestPort: 22,
		SrcAddr:  "127.0.0.1",
		SrcPort:  12345,
	}
	openMsg := ssh.ChannelOpenMsg{
		ChanType:         "direct-tcpip",
		PeersID:          2,
		PeersWindow:      ssh.ChannelWindowSize,
		MaxPacketSize:    ssh.ChannelMaxPacket,
		TypeSpecificData: gossh.Marshal(direct),
	}
	stream.SendClientToServer(channelMsg(openMsg))
}
// TestServeChannel_DirectTcpip opens a direct-tcpip channel to host1 (as
// 'ssh -J pomerium test@host1' would) with delayed authorization allowed. It
// expects the server to reply with a channel control message that hands the
// channel off upstream to host1 with the user's publickey credentials.
func (s *StreamHandlerSuite) TestServeChannel_DirectTcpip() {
	s.mockAuth.EXPECT().
		EvaluateDelayed(Any(), Any()).
		Times(1).
		Return(nil)
	stream := s.BeforeTestHookResult.(*mockChannelStream)
	stream.SendClientToServer(channelMsg(ssh.ChannelOpenMsg{
		ChanType:      "direct-tcpip",
		PeersID:       2,
		PeersWindow:   ssh.ChannelWindowSize,
		MaxPacketSize: ssh.ChannelMaxPacket,
		TypeSpecificData: gossh.Marshal(ssh.ChannelOpenDirectMsg{
			DestAddr: "host1", // i.e. 'ssh -J pomerium test@host1'
			DestPort: 22,      // this will be sent by the ssh client, but is ignored
			SrcAddr:  "127.0.0.1",
			SrcPort:  12345,
		}),
	}))
	recv, err := stream.RecvServerToClient()
	s.Require().NoError(err)
	action := recv.GetChannelControl().GetControlAction()
	s.Require().NotNil(action, "received a message, but it was not a channel control action")
	var controlAction extensions_ssh.SSHChannelControlAction
	s.Require().NoError(action.UnmarshalTo(&controlAction))
	// Expected value is a pointer, like GetHandOff's result: generated proto
	// messages must not be copied by value.
	testutil.AssertProtoEqual(s.T(), &extensions_ssh.SSHChannelControlAction_HandOffUpstream{
		DownstreamChannelInfo: &extensions_ssh.SSHDownstreamChannelInfo{
			ChannelType:               "direct-tcpip",
			DownstreamChannelId:       2,
			InternalUpstreamChannelId: 1,
			InitialWindowSize:         ssh.ChannelWindowSize,
			MaxPacketSize:             ssh.ChannelMaxPacket,
		},
		DownstreamPtyInfo: nil,
		UpstreamAuth: &extensions_ssh.AllowResponse{
			Username: "test",
			Target: &extensions_ssh.AllowResponse_Upstream{
				Upstream: &extensions_ssh.UpstreamTarget{
					Hostname:    "host1",
					DirectTcpip: true,
					AllowedMethods: []*extensions_ssh.AllowedMethod{
						{
							Method: "publickey",
							MethodData: marshalAny(&extensions_ssh.PublicKeyAllowResponse{
								PublicKey:   s.ed25519SshPublicKey.Marshal(),
								Permissions: &extensions_ssh.Permissions{},
							}),
						},
					},
				},
			},
		},
	}, controlAction.GetHandOff())
}
// TestServeChannel_InvalidChannelType opens a channel of type "unknown". The
// before-test hook for this test expects ServeChannel to return an
// InvalidArgument error naming the channel type; that check runs during
// cleanup.
func (s *StreamHandlerSuite) TestServeChannel_InvalidChannelType() {
	stream := s.BeforeTestHookResult.(*mockChannelStream)
	openMsg := ssh.ChannelOpenMsg{
		ChanType:      "unknown",
		PeersID:       2,
		PeersWindow:   ssh.ChannelWindowSize,
		MaxPacketSize: ssh.ChannelMaxPacket,
	}
	stream.SendClientToServer(channelMsg(openMsg))
}
// TestFormatSession checks that StreamHandler.FormatSession returns the bytes
// produced by Auth.FormatSession.
func (s *StreamHandlerSuite) TestFormatSession() {
	s.mockAuth.EXPECT().
		FormatSession(Any(), Any()).
		Return([]byte("example"), nil)
	sh := s.mgr.NewStreamHandler(
		s.cfg, s.mockAuth, &extensions_ssh.DownstreamConnectEvent{StreamId: 1})

	// Run with an already-cancelled context: it exits immediately, but it
	// creates the handler's state, which is only created by Run.
	runCtx, cancel := context.WithCancel(context.Background())
	cancel()
	sh.Run(runCtx)

	formatted, err := sh.FormatSession(s.T().Context())
	s.NoError(err)
	s.Equal([]byte("example"), formatted)
}
// TestDeleteSession checks that StreamHandler.DeleteSession succeeds when
// Auth.DeleteSession does.
func (s *StreamHandlerSuite) TestDeleteSession() {
	s.mockAuth.EXPECT().
		DeleteSession(Any(), Any()).
		Return(nil)
	sh := s.mgr.NewStreamHandler(
		s.cfg, s.mockAuth, &extensions_ssh.DownstreamConnectEvent{StreamId: 1})

	// Run with an already-cancelled context: it exits immediately, but it
	// creates the handler's state, which is only created by Run.
	runCtx, cancel := context.WithCancel(context.Background())
	cancel()
	sh.Run(runCtx)

	s.NoError(sh.DeleteSession(s.T().Context()))
}
// TestRunCalledTwice checks that a second call to Run panics with the value
// "Run called twice".
func (s *StreamHandlerSuite) TestRunCalledTwice() {
	sh := s.mgr.NewStreamHandler(
		s.cfg, s.mockAuth, &extensions_ssh.DownstreamConnectEvent{StreamId: 1})

	// The first Run gets an already-cancelled context.
	firstCtx, cancel := context.WithCancel(context.Background())
	cancel()
	sh.Run(firstCtx)

	secondRun := func() { sh.Run(context.Background()) }
	s.PanicsWithValue("Run called twice", secondRun)
}
// TestAllSSHRoutes checks that AllSSHRoutes yields the two configured SSH
// routes in order. It also checks that a pulled iterator yields nothing more
// once stop has been called.
func (s *StreamHandlerSuite) TestAllSSHRoutes() {
	sh := s.mgr.NewStreamHandler(
		s.cfg, s.mockAuth, &extensions_ssh.DownstreamConnectEvent{StreamId: 1})
	routes := slices.Collect(sh.AllSSHRoutes())
	// Require (not assert): the indexing below would panic on a short slice.
	s.Require().Len(routes, 2)
	s.Equal("ssh://host1", routes[0].From)
	s.Require().NotEmpty(routes[0].To)
	s.Equal("ssh://dest1:22", routes[0].To[0].String())
	s.Equal("ssh://host2", routes[1].From)
	s.Require().NotEmpty(routes[1].To)
	s.Equal("ssh://dest2:22", routes[1].To[0].String())

	next, stop := iter.Pull(sh.AllSSHRoutes())
	v, ok := next()
	s.NotNil(v)
	s.True(ok)
	stop()
	v, ok = next()
	s.Nil(v)
	s.False(ok)
}
// TestStreamHandlerSuite runs StreamHandlerSuite with default options.
func TestStreamHandlerSuite(t *testing.T) {
	suite.Run(t, new(StreamHandlerSuite))
}
func TestStreamHandlerSuiteWithRuntimeFlags(t *testing.T) {
suite.Run(t, &StreamHandlerSuite{
StreamHandlerSuiteOptions: StreamHandlerSuiteOptions{
ConfigModifiers: []func(*config.Config){
func(c *config.Config) {
c.Options.RuntimeFlags[config.RuntimeFlagSSHRoutesPortal] = true
},
},
},
})
}
// ptr returns a pointer to a fresh copy of t.
func ptr[T any](t T) *T {
	p := new(T)
	*p = t
	return p
}