mirror of
https://github.com/pomerium/pomerium.git
synced 2025-08-06 10:21:05 +02:00
Implement the pkg/ssh.AuthInterface. Add logic for converting from the ssh stream state to an evaluator request, and for interpreting the results of policy evaluation. Refactor some of the existing authorize logic to make it easier to reuse.
2078 lines
65 KiB
Go
2078 lines
65 KiB
Go
package ssh_test
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"crypto/ed25519"
|
|
"crypto/rand"
|
|
"encoding/binary"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"iter"
|
|
"os"
|
|
"runtime"
|
|
"slices"
|
|
"strings"
|
|
"sync"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/charmbracelet/x/ansi"
|
|
"github.com/stretchr/testify/suite"
|
|
. "go.uber.org/mock/gomock" //nolint
|
|
gossh "golang.org/x/crypto/ssh"
|
|
"google.golang.org/grpc"
|
|
"google.golang.org/grpc/codes"
|
|
"google.golang.org/grpc/status"
|
|
"google.golang.org/protobuf/proto"
|
|
"google.golang.org/protobuf/types/known/anypb"
|
|
"google.golang.org/protobuf/types/known/wrapperspb"
|
|
|
|
extensions_ssh "github.com/pomerium/envoy-custom/api/extensions/filters/network/ssh"
|
|
"github.com/pomerium/pomerium/config"
|
|
"github.com/pomerium/pomerium/internal/testutil"
|
|
"github.com/pomerium/pomerium/pkg/ssh"
|
|
mock_ssh "github.com/pomerium/pomerium/pkg/ssh/mock"
|
|
)
|
|
|
|
var DefaultTimeout = 10 * time.Second
|
|
|
|
func init() {
|
|
if isDebuggerAttached() {
|
|
DefaultTimeout = 1 * time.Hour
|
|
}
|
|
}
|
|
|
|
// isDebuggerAttached reports whether the current process is being traced.
//
// It is only implemented for Linux: it reads /proc/self/status, finds the
// "TracerPid:" line, and returns true when the first character of its value is
// not '0'. On other platforms, or when the file cannot be read or has no such
// line, it returns false.
func isDebuggerAttached() bool {
	if runtime.GOOS != "linux" {
		return false
	}
	status, err := os.ReadFile("/proc/self/status")
	if err != nil {
		return false
	}
	const tracerPrefix = "TracerPid:\t"
	// SplitAfter keeps the trailing newline on every line.
	for _, entry := range bytes.SplitAfter(status, []byte("\n")) {
		if !bytes.HasPrefix(entry, []byte(tracerPrefix)) {
			continue
		}
		return entry[len(tracerPrefix)] != '0'
	}
	return false
}
|
|
|
|
func HookWithArgs(f func(s *StreamHandlerSuite, args []any) any, args ...any) []func(s *StreamHandlerSuite) any {
|
|
return []func(s *StreamHandlerSuite) any{
|
|
func(s *StreamHandlerSuite) any {
|
|
return f(s, args)
|
|
},
|
|
}
|
|
}
|
|
|
|
var (
|
|
StreamHandlerSuiteBeforeTestHooks = map[string][]func(s *StreamHandlerSuite) any{}
|
|
StreamHandlerSuiteAfterTestHooks = map[string][]func(s *StreamHandlerSuite) any{}
|
|
)
|
|
|
|
type StreamHandlerSuiteOptions struct {
|
|
ConfigModifiers []func(*config.Config)
|
|
}
|
|
|
|
type StreamHandlerSuite struct {
|
|
suite.Suite
|
|
StreamHandlerSuiteOptions
|
|
|
|
ctrl *Controller
|
|
|
|
mgr *ssh.StreamManager
|
|
cfg *config.Config
|
|
|
|
cleanup []func()
|
|
errC chan error
|
|
|
|
mockAuth *mock_ssh.MockAuthInterface
|
|
|
|
ed25519PublicKey ed25519.PublicKey
|
|
ed25519PrivateKey ed25519.PrivateKey
|
|
ed25519SshPublicKey gossh.PublicKey
|
|
ed25519SshPrivateKey gossh.Signer
|
|
|
|
BeforeTestHookResult any
|
|
}
|
|
|
|
func (s *StreamHandlerSuite) SetupTest() {
|
|
s.ctrl = NewController(s.T())
|
|
s.mockAuth = mock_ssh.NewMockAuthInterface(s.ctrl)
|
|
s.mgr = ssh.NewStreamManager()
|
|
s.cleanup = []func(){}
|
|
s.errC = make(chan error, 1)
|
|
|
|
var err error
|
|
s.ed25519PublicKey, s.ed25519PrivateKey, err = ed25519.GenerateKey(rand.Reader)
|
|
s.Require().NoError(err)
|
|
s.ed25519SshPublicKey, err = gossh.NewPublicKey(s.ed25519PublicKey)
|
|
s.Require().NoError(err)
|
|
s.ed25519SshPrivateKey, err = gossh.NewSignerFromKey(s.ed25519PrivateKey)
|
|
s.Require().NoError(err)
|
|
|
|
s.cfg = &config.Config{Options: config.NewDefaultOptions()}
|
|
s.cfg.Options.Policies = []config.Policy{
|
|
{From: "https://from.notssh.example.com", To: mustParseWeightedURLs(s.T(), "https://to.notssh.example.com")},
|
|
{From: "ssh://host1", To: mustParseWeightedURLs(s.T(), "ssh://dest1:22")},
|
|
{From: "https://from1.notssh.example.com", To: mustParseWeightedURLs(s.T(), "https://to1.notssh.example.com")},
|
|
{From: "ssh://host2", To: mustParseWeightedURLs(s.T(), "ssh://dest2:22")},
|
|
{From: "https://from2.notssh.example.com", To: mustParseWeightedURLs(s.T(), "https://to2.notssh.example.com")},
|
|
}
|
|
for _, f := range s.ConfigModifiers {
|
|
f(s.cfg)
|
|
}
|
|
}
|
|
|
|
func (s *StreamHandlerSuite) TearDownTest() {
|
|
for _, f := range s.cleanup {
|
|
f()
|
|
}
|
|
s.ctrl.Finish()
|
|
}
|
|
|
|
func (s *StreamHandlerSuite) BeforeTest(_, testName string) {
|
|
s.BeforeTestHookResult = nil
|
|
for _, fn := range StreamHandlerSuiteBeforeTestHooks[testName] {
|
|
s.BeforeTestHookResult = fn(s)
|
|
}
|
|
}
|
|
|
|
//
|
|
// Helper methods
|
|
//
|
|
|
|
func marshalAny(msg proto.Message) *anypb.Any {
|
|
a, err := anypb.New(msg)
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
return a
|
|
}
|
|
|
|
func (s *StreamHandlerSuite) expectError(fn func(), msg string) {
|
|
fn()
|
|
select {
|
|
case err := <-s.errC:
|
|
s.ErrorContains(err, msg)
|
|
case <-time.After(DefaultTimeout):
|
|
s.FailNowf("timed out waiting for error %q", msg)
|
|
}
|
|
}
|
|
|
|
func (s *StreamHandlerSuite) startStreamHandler(streamID uint64) *ssh.StreamHandler {
|
|
sh := s.mgr.NewStreamHandler(
|
|
s.cfg, s.mockAuth, &extensions_ssh.DownstreamConnectEvent{StreamId: streamID})
|
|
s.errC = make(chan error, 1)
|
|
ctx, ca := context.WithCancel(s.T().Context())
|
|
go func() {
|
|
defer close(s.errC)
|
|
s.errC <- sh.Run(ctx)
|
|
}()
|
|
s.cleanup = append(s.cleanup, func() {
|
|
start := time.Now()
|
|
for len(sh.ReadC()) > 0 && time.Since(start) < 100*time.Millisecond {
|
|
runtime.Gosched()
|
|
}
|
|
if len(sh.ReadC()) > 0 {
|
|
s.Fail(fmt.Sprintf("read channel contains %d unhandled client messages", len(sh.ReadC())))
|
|
}
|
|
ca()
|
|
var err error
|
|
select {
|
|
case err = <-s.errC:
|
|
case <-time.After(DefaultTimeout):
|
|
s.Fail("timed out waiting for stream handler to close")
|
|
}
|
|
|
|
sh.Close()
|
|
if err != nil {
|
|
s.Require().ErrorIs(err, context.Canceled)
|
|
}
|
|
if len(sh.WriteC()) != 0 {
|
|
logs := []string{"write channel contains unhandled server messages:"}
|
|
i := 0
|
|
for msg := range sh.WriteC() {
|
|
logs = append(logs, fmt.Sprintf("[%d]: %s", i, msg.String()))
|
|
i++
|
|
}
|
|
s.Fail(strings.Join(logs, "\n"))
|
|
}
|
|
})
|
|
return sh
|
|
}
|
|
|
|
func (s *StreamHandlerSuite) msgDownstreamConnected(streamID uint64) *extensions_ssh.ClientMessage {
|
|
return &extensions_ssh.ClientMessage{
|
|
Message: &extensions_ssh.ClientMessage_Event{
|
|
Event: &extensions_ssh.StreamEvent{
|
|
Event: &extensions_ssh.StreamEvent_DownstreamConnected{
|
|
DownstreamConnected: &extensions_ssh.DownstreamConnectEvent{
|
|
StreamId: streamID,
|
|
},
|
|
},
|
|
},
|
|
},
|
|
}
|
|
}
|
|
|
|
func (s *StreamHandlerSuite) msgDownstreamDisconnected(reason string) *extensions_ssh.ClientMessage {
|
|
return &extensions_ssh.ClientMessage{
|
|
Message: &extensions_ssh.ClientMessage_Event{
|
|
Event: &extensions_ssh.StreamEvent{
|
|
Event: &extensions_ssh.StreamEvent_DownstreamDisconnected{
|
|
DownstreamDisconnected: &extensions_ssh.DownstreamDisconnectedEvent{
|
|
Reason: reason,
|
|
},
|
|
},
|
|
},
|
|
},
|
|
}
|
|
}
|
|
|
|
func (s *StreamHandlerSuite) msgUpstreamConnected(streamID uint64) *extensions_ssh.ClientMessage {
|
|
return &extensions_ssh.ClientMessage{
|
|
Message: &extensions_ssh.ClientMessage_Event{
|
|
Event: &extensions_ssh.StreamEvent{
|
|
Event: &extensions_ssh.StreamEvent_UpstreamConnected{
|
|
UpstreamConnected: &extensions_ssh.UpstreamConnectEvent{
|
|
StreamId: streamID,
|
|
},
|
|
},
|
|
},
|
|
},
|
|
}
|
|
}
|
|
|
|
func (s *StreamHandlerSuite) expectAllowUpstream(sh *ssh.StreamHandler, hostname string) {
|
|
select {
|
|
case msg := <-sh.WriteC():
|
|
if authResp := msg.GetAuthResponse(); authResp != nil {
|
|
if allow := authResp.GetAllow(); allow != nil {
|
|
s.Require().NotNil(allow.GetUpstream(), "received an allow response, but not to an upstream target")
|
|
s.Require().Equal(hostname, allow.GetUpstream().GetHostname())
|
|
} else {
|
|
s.FailNowf("received an auth response, but it was not an allow response", authResp.String())
|
|
}
|
|
} else {
|
|
s.FailNow("received a message, but it was not an auth response", msg.String())
|
|
}
|
|
case <-time.After(DefaultTimeout):
|
|
s.FailNow("timed out waiting for upstream allow message")
|
|
}
|
|
}
|
|
|
|
func (s *StreamHandlerSuite) expectDeny(sh *ssh.StreamHandler, partial bool, methods []string) {
|
|
select {
|
|
case msg := <-sh.WriteC():
|
|
if authResp := msg.GetAuthResponse(); authResp != nil {
|
|
if deny := authResp.GetDeny(); deny != nil {
|
|
s.Require().Equal(partial, deny.Partial)
|
|
s.Require().Equal(methods, deny.Methods)
|
|
} else {
|
|
s.Require().Fail("received an auth response, but it was not a deny response", authResp.String())
|
|
}
|
|
} else {
|
|
s.FailNow("received a message, but it was not an auth response", msg.String())
|
|
}
|
|
case <-time.After(DefaultTimeout):
|
|
s.FailNow("timed out waiting for deny message")
|
|
}
|
|
}
|
|
|
|
func (s *StreamHandlerSuite) expectAllowInternal(sh *ssh.StreamHandler) {
|
|
select {
|
|
case msg := <-sh.WriteC():
|
|
if authResp := msg.GetAuthResponse(); authResp != nil {
|
|
if allow := authResp.GetAllow(); allow != nil {
|
|
s.Require().NotNil(allow.GetInternal(), "received an allow response, but not to an internal target")
|
|
} else {
|
|
s.FailNow("received an auth response, but it was not an allow response", authResp.String())
|
|
}
|
|
} else {
|
|
s.FailNow("received a message, but it was not an auth response", msg.String())
|
|
}
|
|
case <-time.After(DefaultTimeout):
|
|
s.FailNow("timed out waiting for internal allow message")
|
|
}
|
|
}
|
|
|
|
func (s *StreamHandlerSuite) expectPrompt(sh *ssh.StreamHandler) {
|
|
select {
|
|
case msg := <-sh.WriteC():
|
|
if authResp := msg.GetAuthResponse(); authResp != nil {
|
|
if info := authResp.GetInfoRequest(); info != nil {
|
|
s.Require().NotNil(info.GetRequest(), "received a nil info request")
|
|
} else {
|
|
s.FailNow("received an auth response, but it was not an info request", authResp.String())
|
|
}
|
|
} else {
|
|
s.FailNow("received a message, but it was not an auth response", msg.String())
|
|
}
|
|
case <-time.After(DefaultTimeout):
|
|
s.FailNow("timed out waiting for prompt message")
|
|
}
|
|
}
|
|
|
|
func (s *StreamHandlerSuite) validPublicKeyMethodRequest() *anypb.Any {
|
|
return marshalAny(&extensions_ssh.PublicKeyMethodRequest{
|
|
PublicKey: s.ed25519SshPublicKey.Marshal(),
|
|
PublicKeyAlg: s.ed25519SshPublicKey.Type(),
|
|
PublicKeyFingerprintSha256: []byte(gossh.FingerprintSHA256(s.ed25519SshPublicKey)),
|
|
})
|
|
}
|
|
|
|
//
|
|
// Tests
|
|
//
|
|
|
|
func (s *StreamHandlerSuite) TestDuplicateDownstreamConnectedEvent() {
|
|
sh := s.startStreamHandler(1)
|
|
s.expectError(func() {
|
|
sh.ReadC() <- s.msgDownstreamConnected(1)
|
|
}, "received duplicate downstream connected event")
|
|
}
|
|
|
|
func (s *StreamHandlerSuite) TestDownstreamDisconnectedEvent() {
|
|
sh := s.startStreamHandler(1)
|
|
sh.ReadC() <- s.msgDownstreamDisconnected("") // this just logs a message
|
|
}
|
|
|
|
func (s *StreamHandlerSuite) TestUpstreamConnectedEvent() {
|
|
sh := s.startStreamHandler(1)
|
|
sh.ReadC() <- s.msgUpstreamConnected(1) // this just logs a message
|
|
}
|
|
|
|
func (s *StreamHandlerSuite) TestInvalidEvent() {
|
|
sh := s.startStreamHandler(1)
|
|
s.expectError(func() {
|
|
sh.ReadC() <- &extensions_ssh.ClientMessage{
|
|
Message: &extensions_ssh.ClientMessage_Event{
|
|
Event: &extensions_ssh.StreamEvent{Event: nil},
|
|
},
|
|
}
|
|
}, "received invalid event")
|
|
}
|
|
|
|
func (s *StreamHandlerSuite) TestHandleAuthRequest_InvalidProtocol() {
|
|
sh := s.startStreamHandler(1)
|
|
s.expectError(func() {
|
|
sh.ReadC() <- &extensions_ssh.ClientMessage{
|
|
Message: &extensions_ssh.ClientMessage_AuthRequest{
|
|
AuthRequest: &extensions_ssh.AuthenticationRequest{
|
|
Protocol: "not-ssh",
|
|
Service: "ssh-connection",
|
|
},
|
|
},
|
|
}
|
|
}, "invalid protocol: not-ssh")
|
|
}
|
|
|
|
func (s *StreamHandlerSuite) TestHandleAuthRequest_InvalidService() {
|
|
sh := s.startStreamHandler(1)
|
|
s.expectError(func() {
|
|
sh.ReadC() <- &extensions_ssh.ClientMessage{
|
|
Message: &extensions_ssh.ClientMessage_AuthRequest{
|
|
AuthRequest: &extensions_ssh.AuthenticationRequest{
|
|
Protocol: "ssh",
|
|
Service: "ssh-userauth",
|
|
},
|
|
},
|
|
}
|
|
}, "invalid service: ssh-userauth")
|
|
}
|
|
|
|
func (s *StreamHandlerSuite) TestHandleAuthRequest_InvalidMessage() {
|
|
sh := s.startStreamHandler(1)
|
|
s.expectError(func() {
|
|
sh.ReadC() <- &extensions_ssh.ClientMessage{
|
|
Message: nil,
|
|
}
|
|
}, "received invalid message")
|
|
}
|
|
|
|
func (s *StreamHandlerSuite) TestHandleAuthRequest_FirstRequestIsKeyboardInteractive() {
|
|
sh := s.startStreamHandler(1)
|
|
s.expectError(func() {
|
|
// first request should be publickey
|
|
sh.ReadC() <- &extensions_ssh.ClientMessage{
|
|
Message: &extensions_ssh.ClientMessage_AuthRequest{
|
|
AuthRequest: &extensions_ssh.AuthenticationRequest{
|
|
Protocol: "ssh",
|
|
Service: "ssh-connection",
|
|
AuthMethod: "keyboard-interactive",
|
|
},
|
|
},
|
|
}
|
|
}, "unexpected auth method: keyboard-interactive")
|
|
}
|
|
|
|
func (s *StreamHandlerSuite) TestHandleAuthRequest_MissingUsername() {
|
|
sh := s.startStreamHandler(1)
|
|
s.expectError(func() {
|
|
sh.ReadC() <- &extensions_ssh.ClientMessage{
|
|
Message: &extensions_ssh.ClientMessage_AuthRequest{
|
|
AuthRequest: &extensions_ssh.AuthenticationRequest{
|
|
Protocol: "ssh",
|
|
Service: "ssh-connection",
|
|
AuthMethod: "publickey",
|
|
Username: "",
|
|
},
|
|
},
|
|
}
|
|
}, "username missing")
|
|
}
|
|
|
|
func (s *StreamHandlerSuite) TestHandleAuthRequest_EmptyHostname() {
|
|
sh := s.startStreamHandler(1)
|
|
|
|
// empty hostname is allowed initially
|
|
s.mockAuth.EXPECT().
|
|
HandlePublicKeyMethodRequest(Any(), Any(), Any()).
|
|
Return(ssh.PublicKeyAuthMethodResponse{Allow: &extensions_ssh.PublicKeyAllowResponse{
|
|
PublicKey: s.ed25519SshPublicKey.Marshal(),
|
|
Permissions: &extensions_ssh.Permissions{},
|
|
}}, nil)
|
|
|
|
sh.ReadC() <- &extensions_ssh.ClientMessage{
|
|
Message: &extensions_ssh.ClientMessage_AuthRequest{
|
|
AuthRequest: &extensions_ssh.AuthenticationRequest{
|
|
Protocol: "ssh",
|
|
Service: "ssh-connection",
|
|
AuthMethod: "publickey",
|
|
Username: "test",
|
|
Hostname: "",
|
|
MethodRequest: s.validPublicKeyMethodRequest(),
|
|
},
|
|
},
|
|
}
|
|
|
|
s.expectAllowInternal(sh)
|
|
}
|
|
|
|
func (s *StreamHandlerSuite) TestHandleAuthRequest_MismatchedAuthMethodAndRequestType() {
|
|
sh := s.startStreamHandler(1)
|
|
|
|
s.expectError(func() {
|
|
sh.ReadC() <- &extensions_ssh.ClientMessage{
|
|
Message: &extensions_ssh.ClientMessage_AuthRequest{
|
|
AuthRequest: &extensions_ssh.AuthenticationRequest{
|
|
Protocol: "ssh",
|
|
Service: "ssh-connection",
|
|
AuthMethod: "publickey",
|
|
Username: "test",
|
|
Hostname: "",
|
|
MethodRequest: marshalAny(&extensions_ssh.KeyboardInteractiveMethodRequest{}),
|
|
},
|
|
},
|
|
}
|
|
}, "invalid public key method request type")
|
|
}
|
|
|
|
func (s *StreamHandlerSuite) TestHandleAuthRequest_ValidPublicKeyMethodRequest() {
|
|
sh := s.startStreamHandler(1)
|
|
|
|
s.mockAuth.EXPECT().
|
|
HandlePublicKeyMethodRequest(Any(), Any(), Any()).
|
|
Return(ssh.PublicKeyAuthMethodResponse{Allow: &extensions_ssh.PublicKeyAllowResponse{
|
|
PublicKey: s.ed25519SshPublicKey.Marshal(),
|
|
Permissions: &extensions_ssh.Permissions{},
|
|
}}, nil)
|
|
|
|
sh.ReadC() <- &extensions_ssh.ClientMessage{
|
|
Message: &extensions_ssh.ClientMessage_AuthRequest{
|
|
AuthRequest: &extensions_ssh.AuthenticationRequest{
|
|
Protocol: "ssh",
|
|
Service: "ssh-connection",
|
|
AuthMethod: "publickey",
|
|
Username: "test",
|
|
Hostname: "host1",
|
|
MethodRequest: s.validPublicKeyMethodRequest(),
|
|
},
|
|
},
|
|
}
|
|
|
|
s.expectAllowUpstream(sh, "host1")
|
|
}
|
|
|
|
func (s *StreamHandlerSuite) TestHandleAuthRequest_ValidPublicKeyMethodRequestError() {
|
|
sh := s.startStreamHandler(1)
|
|
|
|
s.mockAuth.EXPECT().
|
|
HandlePublicKeyMethodRequest(Any(), Any(), Any()).
|
|
Return(ssh.PublicKeyAuthMethodResponse{}, errors.New("test error"))
|
|
|
|
s.expectError(func() {
|
|
sh.ReadC() <- &extensions_ssh.ClientMessage{
|
|
Message: &extensions_ssh.ClientMessage_AuthRequest{
|
|
AuthRequest: &extensions_ssh.AuthenticationRequest{
|
|
Protocol: "ssh",
|
|
Service: "ssh-connection",
|
|
AuthMethod: "publickey",
|
|
Username: "test",
|
|
Hostname: "host1",
|
|
MethodRequest: s.validPublicKeyMethodRequest(),
|
|
},
|
|
},
|
|
}
|
|
}, "test error")
|
|
}
|
|
|
|
func (s *StreamHandlerSuite) TestHandleAuthRequest_PublicKeyRetry() {
|
|
sh := s.startStreamHandler(1)
|
|
|
|
i := -1
|
|
s.mockAuth.EXPECT().
|
|
HandlePublicKeyMethodRequest(Any(), Any(), Any()).
|
|
MaxTimes(4).
|
|
DoAndReturn(func(_ context.Context, _ ssh.StreamAuthInfo, _ *extensions_ssh.PublicKeyMethodRequest) (ssh.PublicKeyAuthMethodResponse, error) {
|
|
i++
|
|
switch i {
|
|
case 0, 1, 2:
|
|
return ssh.PublicKeyAuthMethodResponse{
|
|
RequireAdditionalMethods: []string{"publickey"},
|
|
}, nil
|
|
case 3:
|
|
return ssh.PublicKeyAuthMethodResponse{Allow: &extensions_ssh.PublicKeyAllowResponse{
|
|
PublicKey: s.ed25519SshPublicKey.Marshal(),
|
|
Permissions: &extensions_ssh.Permissions{},
|
|
}}, nil
|
|
default:
|
|
panic("unreachable")
|
|
}
|
|
})
|
|
|
|
for i := range 4 {
|
|
sh.ReadC() <- &extensions_ssh.ClientMessage{
|
|
Message: &extensions_ssh.ClientMessage_AuthRequest{
|
|
AuthRequest: &extensions_ssh.AuthenticationRequest{
|
|
Protocol: "ssh",
|
|
Service: "ssh-connection",
|
|
AuthMethod: "publickey",
|
|
Username: "test",
|
|
Hostname: "host1",
|
|
MethodRequest: s.validPublicKeyMethodRequest(),
|
|
},
|
|
},
|
|
}
|
|
if i < 3 {
|
|
s.expectDeny(sh, false, []string{"publickey"})
|
|
} else {
|
|
s.expectAllowUpstream(sh, "host1")
|
|
}
|
|
}
|
|
}
|
|
|
|
func (s *StreamHandlerSuite) TestHandleAuthRequest_InconsistentUsername() {
|
|
sh := s.startStreamHandler(1)
|
|
|
|
s.mockAuth.EXPECT().
|
|
HandlePublicKeyMethodRequest(Any(), Any(), Any()).
|
|
Times(1).
|
|
DoAndReturn(func(_ context.Context, _ ssh.StreamAuthInfo, _ *extensions_ssh.PublicKeyMethodRequest) (ssh.PublicKeyAuthMethodResponse, error) {
|
|
return ssh.PublicKeyAuthMethodResponse{
|
|
RequireAdditionalMethods: []string{"publickey"},
|
|
}, nil
|
|
})
|
|
sh.ReadC() <- &extensions_ssh.ClientMessage{
|
|
Message: &extensions_ssh.ClientMessage_AuthRequest{
|
|
AuthRequest: &extensions_ssh.AuthenticationRequest{
|
|
Protocol: "ssh",
|
|
Service: "ssh-connection",
|
|
AuthMethod: "publickey",
|
|
Username: "test",
|
|
Hostname: "host1",
|
|
MethodRequest: s.validPublicKeyMethodRequest(),
|
|
},
|
|
},
|
|
}
|
|
s.expectDeny(sh, false, []string{"publickey"})
|
|
s.Equal("test", *sh.Username())
|
|
s.Equal("host1", *sh.Hostname())
|
|
s.expectError(func() {
|
|
sh.ReadC() <- &extensions_ssh.ClientMessage{
|
|
Message: &extensions_ssh.ClientMessage_AuthRequest{
|
|
AuthRequest: &extensions_ssh.AuthenticationRequest{
|
|
Protocol: "ssh",
|
|
Service: "ssh-connection",
|
|
AuthMethod: "publickey",
|
|
Username: "test2",
|
|
Hostname: "host1",
|
|
MethodRequest: s.validPublicKeyMethodRequest(),
|
|
},
|
|
},
|
|
}
|
|
}, "inconsistent username")
|
|
}
|
|
|
|
func (s *StreamHandlerSuite) TestHandleAuthRequest_InconsistentHostname() {
|
|
sh := s.startStreamHandler(1)
|
|
|
|
s.mockAuth.EXPECT().
|
|
HandlePublicKeyMethodRequest(Any(), Any(), Any()).
|
|
Times(1).
|
|
DoAndReturn(func(_ context.Context, _ ssh.StreamAuthInfo, _ *extensions_ssh.PublicKeyMethodRequest) (ssh.PublicKeyAuthMethodResponse, error) {
|
|
return ssh.PublicKeyAuthMethodResponse{
|
|
RequireAdditionalMethods: []string{"publickey"},
|
|
}, nil
|
|
})
|
|
sh.ReadC() <- &extensions_ssh.ClientMessage{
|
|
Message: &extensions_ssh.ClientMessage_AuthRequest{
|
|
AuthRequest: &extensions_ssh.AuthenticationRequest{
|
|
Protocol: "ssh",
|
|
Service: "ssh-connection",
|
|
AuthMethod: "publickey",
|
|
Username: "test",
|
|
Hostname: "host1",
|
|
MethodRequest: s.validPublicKeyMethodRequest(),
|
|
},
|
|
},
|
|
}
|
|
s.expectDeny(sh, false, []string{"publickey"})
|
|
s.expectError(func() {
|
|
sh.ReadC() <- &extensions_ssh.ClientMessage{
|
|
Message: &extensions_ssh.ClientMessage_AuthRequest{
|
|
AuthRequest: &extensions_ssh.AuthenticationRequest{
|
|
Protocol: "ssh",
|
|
Service: "ssh-connection",
|
|
AuthMethod: "publickey",
|
|
Username: "test",
|
|
Hostname: "host2",
|
|
MethodRequest: s.validPublicKeyMethodRequest(),
|
|
},
|
|
},
|
|
}
|
|
}, "inconsistent hostname")
|
|
}
|
|
|
|
func (s *StreamHandlerSuite) TestHandleAuthRequest_InconsistentEmptyHostname() {
|
|
sh := s.startStreamHandler(1)
|
|
|
|
s.mockAuth.EXPECT().
|
|
HandlePublicKeyMethodRequest(Any(), Any(), Any()).
|
|
Times(1).
|
|
DoAndReturn(func(_ context.Context, _ ssh.StreamAuthInfo, req *extensions_ssh.PublicKeyMethodRequest) (ssh.PublicKeyAuthMethodResponse, error) {
|
|
return ssh.PublicKeyAuthMethodResponse{
|
|
Allow: &extensions_ssh.PublicKeyAllowResponse{
|
|
PublicKey: req.PublicKey,
|
|
Permissions: &extensions_ssh.Permissions{},
|
|
},
|
|
RequireAdditionalMethods: []string{"keyboard-interactive"},
|
|
}, nil
|
|
})
|
|
sh.ReadC() <- &extensions_ssh.ClientMessage{
|
|
Message: &extensions_ssh.ClientMessage_AuthRequest{
|
|
AuthRequest: &extensions_ssh.AuthenticationRequest{
|
|
Protocol: "ssh",
|
|
Service: "ssh-connection",
|
|
AuthMethod: "publickey",
|
|
Username: "test",
|
|
Hostname: "",
|
|
MethodRequest: s.validPublicKeyMethodRequest(),
|
|
},
|
|
},
|
|
}
|
|
s.expectDeny(sh, true, []string{"keyboard-interactive"})
|
|
s.expectError(func() {
|
|
sh.ReadC() <- &extensions_ssh.ClientMessage{
|
|
Message: &extensions_ssh.ClientMessage_AuthRequest{
|
|
AuthRequest: &extensions_ssh.AuthenticationRequest{
|
|
Protocol: "ssh",
|
|
Service: "ssh-connection",
|
|
AuthMethod: "keyboard-interactive",
|
|
Username: "test",
|
|
Hostname: "host1",
|
|
MethodRequest: marshalAny(&extensions_ssh.KeyboardInteractiveMethodRequest{}),
|
|
},
|
|
},
|
|
}
|
|
}, "inconsistent hostname")
|
|
}
|
|
|
|
func (s *StreamHandlerSuite) TestHandleAuthRequest_UnknownAuthMethod() {
|
|
sh := s.startStreamHandler(1)
|
|
s.expectError(func() {
|
|
sh.ReadC() <- &extensions_ssh.ClientMessage{
|
|
Message: &extensions_ssh.ClientMessage_AuthRequest{
|
|
AuthRequest: &extensions_ssh.AuthenticationRequest{
|
|
Protocol: "ssh",
|
|
Service: "ssh-connection",
|
|
AuthMethod: "password",
|
|
Username: "test",
|
|
Hostname: "host1",
|
|
},
|
|
},
|
|
}
|
|
}, "unexpected auth method: password")
|
|
}
|
|
|
|
func (s *StreamHandlerSuite) TestHandleAuthRequest_UnimplementedAuthMethod() {
|
|
sh := s.startStreamHandler(1)
|
|
s.mockAuth.EXPECT().
|
|
HandlePublicKeyMethodRequest(Any(), Any(), Any()).
|
|
Times(1).
|
|
DoAndReturn(func(_ context.Context, _ ssh.StreamAuthInfo, _ *extensions_ssh.PublicKeyMethodRequest) (ssh.PublicKeyAuthMethodResponse, error) {
|
|
return ssh.PublicKeyAuthMethodResponse{
|
|
RequireAdditionalMethods: []string{"password"},
|
|
}, nil
|
|
})
|
|
sh.ReadC() <- &extensions_ssh.ClientMessage{
|
|
Message: &extensions_ssh.ClientMessage_AuthRequest{
|
|
AuthRequest: &extensions_ssh.AuthenticationRequest{
|
|
Protocol: "ssh",
|
|
Service: "ssh-connection",
|
|
AuthMethod: "publickey",
|
|
Username: "test",
|
|
Hostname: "host1",
|
|
MethodRequest: s.validPublicKeyMethodRequest(),
|
|
},
|
|
},
|
|
}
|
|
s.expectDeny(sh, false, []string{"password"})
|
|
s.expectError(func() {
|
|
sh.ReadC() <- &extensions_ssh.ClientMessage{
|
|
Message: &extensions_ssh.ClientMessage_AuthRequest{
|
|
AuthRequest: &extensions_ssh.AuthenticationRequest{
|
|
Protocol: "ssh",
|
|
Service: "ssh-connection",
|
|
AuthMethod: "password",
|
|
Username: "test",
|
|
Hostname: "host1",
|
|
},
|
|
},
|
|
}
|
|
}, "bug: server requested an unsupported auth method \"password\"")
|
|
}
|
|
|
|
func (s *StreamHandlerSuite) TestHandleAuthRequest_WrongClientMessage() {
|
|
sh := s.startStreamHandler(1)
|
|
s.mockAuth.EXPECT().
|
|
HandlePublicKeyMethodRequest(Any(), Any(), Any()).
|
|
Times(1).
|
|
DoAndReturn(func(_ context.Context, _ ssh.StreamAuthInfo, req *extensions_ssh.PublicKeyMethodRequest) (ssh.PublicKeyAuthMethodResponse, error) {
|
|
return ssh.PublicKeyAuthMethodResponse{
|
|
Allow: &extensions_ssh.PublicKeyAllowResponse{
|
|
PublicKey: req.PublicKey,
|
|
Permissions: &extensions_ssh.Permissions{},
|
|
},
|
|
RequireAdditionalMethods: []string{"keyboard-interactive"},
|
|
}, nil
|
|
})
|
|
newMsg := func() *extensions_ssh.ClientMessage_AuthRequest {
|
|
return &extensions_ssh.ClientMessage_AuthRequest{
|
|
AuthRequest: &extensions_ssh.AuthenticationRequest{
|
|
Protocol: "ssh",
|
|
Service: "ssh-connection",
|
|
AuthMethod: "publickey",
|
|
Username: "test",
|
|
Hostname: "host1",
|
|
MethodRequest: s.validPublicKeyMethodRequest(),
|
|
},
|
|
}
|
|
}
|
|
sh.ReadC() <- &extensions_ssh.ClientMessage{
|
|
Message: newMsg(),
|
|
}
|
|
s.expectDeny(sh, true, []string{"keyboard-interactive"})
|
|
s.expectError(func() {
|
|
sh.ReadC() <- &extensions_ssh.ClientMessage{
|
|
Message: newMsg(),
|
|
}
|
|
}, "unexpected auth method: publickey")
|
|
}
|
|
|
|
func (s *StreamHandlerSuite) TestHandleAuthRequest_KeyboardInteractive_WrongMethodRequestType() {
|
|
sh := s.startStreamHandler(1)
|
|
|
|
s.mockAuth.EXPECT().
|
|
HandlePublicKeyMethodRequest(Any(), Any(), Any()).
|
|
Times(1).
|
|
DoAndReturn(func(_ context.Context, _ ssh.StreamAuthInfo, req *extensions_ssh.PublicKeyMethodRequest) (ssh.PublicKeyAuthMethodResponse, error) {
|
|
return ssh.PublicKeyAuthMethodResponse{
|
|
Allow: &extensions_ssh.PublicKeyAllowResponse{
|
|
PublicKey: req.PublicKey,
|
|
Permissions: &extensions_ssh.Permissions{},
|
|
},
|
|
RequireAdditionalMethods: []string{"keyboard-interactive"},
|
|
}, nil
|
|
})
|
|
sh.ReadC() <- &extensions_ssh.ClientMessage{
|
|
Message: &extensions_ssh.ClientMessage_AuthRequest{
|
|
AuthRequest: &extensions_ssh.AuthenticationRequest{
|
|
Protocol: "ssh",
|
|
Service: "ssh-connection",
|
|
AuthMethod: "publickey",
|
|
Username: "test",
|
|
Hostname: "host1",
|
|
MethodRequest: s.validPublicKeyMethodRequest(),
|
|
},
|
|
},
|
|
}
|
|
s.expectDeny(sh, true, []string{"keyboard-interactive"})
|
|
s.expectError(func() {
|
|
sh.ReadC() <- &extensions_ssh.ClientMessage{
|
|
Message: &extensions_ssh.ClientMessage_AuthRequest{
|
|
AuthRequest: &extensions_ssh.AuthenticationRequest{
|
|
Protocol: "ssh",
|
|
Service: "ssh-connection",
|
|
AuthMethod: "keyboard-interactive",
|
|
Username: "test",
|
|
Hostname: "host1",
|
|
MethodRequest: s.validPublicKeyMethodRequest(),
|
|
},
|
|
},
|
|
}
|
|
}, "invalid keyboard-interactive method request type")
|
|
}
|
|
|
|
func init() {
|
|
setupKeyboardInteractive := func(s *StreamHandlerSuite, input []any) any {
|
|
querierErr, _ := input[0].(error)
|
|
sh := s.startStreamHandler(100)
|
|
|
|
i := -1
|
|
s.mockAuth.EXPECT().
|
|
HandlePublicKeyMethodRequest(Any(), Any(), Any()).
|
|
Times(2).
|
|
DoAndReturn(func(_ context.Context, _ ssh.StreamAuthInfo, _ *extensions_ssh.PublicKeyMethodRequest) (ssh.PublicKeyAuthMethodResponse, error) {
|
|
i++
|
|
switch i {
|
|
case 0:
|
|
return ssh.PublicKeyAuthMethodResponse{
|
|
RequireAdditionalMethods: []string{"publickey"},
|
|
}, nil
|
|
case 1:
|
|
return ssh.PublicKeyAuthMethodResponse{
|
|
Allow: &extensions_ssh.PublicKeyAllowResponse{
|
|
PublicKey: s.ed25519SshPublicKey.Marshal(),
|
|
Permissions: &extensions_ssh.Permissions{},
|
|
},
|
|
RequireAdditionalMethods: []string{"keyboard-interactive"},
|
|
}, nil
|
|
default:
|
|
panic("unreachable")
|
|
}
|
|
})
|
|
s.mockAuth.EXPECT().
|
|
HandleKeyboardInteractiveMethodRequest(Any(), Any(), Any(), Any()).
|
|
DoAndReturn(func(
|
|
ctx context.Context,
|
|
info ssh.StreamAuthInfo,
|
|
_ *extensions_ssh.KeyboardInteractiveMethodRequest,
|
|
querier ssh.KeyboardInteractiveQuerier,
|
|
) (ssh.KeyboardInteractiveAuthMethodResponse, error) {
|
|
s.Equal("test", *info.Username)
|
|
s.Equal("host1", *info.Hostname)
|
|
s.Equal(uint64(100), info.StreamID)
|
|
resp, err := querier.Prompt(ctx, &extensions_ssh.KeyboardInteractiveInfoPrompts{
|
|
Name: "test-name",
|
|
Instruction: "test-instruction",
|
|
Prompts: []*extensions_ssh.KeyboardInteractiveInfoPrompts_Prompt{
|
|
{
|
|
Prompt: "test-prompt",
|
|
Echo: true,
|
|
},
|
|
},
|
|
})
|
|
s.Require().Equal(querierErr, err, "unexpected error from querier.Prompt")
|
|
if querierErr == nil {
|
|
s.Equal([]string{"test-prompt-response"}, resp.Responses)
|
|
return ssh.KeyboardInteractiveAuthMethodResponse{
|
|
Allow: &extensions_ssh.KeyboardInteractiveAllowResponse{},
|
|
}, nil
|
|
}
|
|
return ssh.KeyboardInteractiveAuthMethodResponse{}, err
|
|
})
|
|
for range 2 {
|
|
sh.ReadC() <- &extensions_ssh.ClientMessage{
|
|
Message: &extensions_ssh.ClientMessage_AuthRequest{
|
|
AuthRequest: &extensions_ssh.AuthenticationRequest{
|
|
Protocol: "ssh",
|
|
Service: "ssh-connection",
|
|
AuthMethod: "publickey",
|
|
Username: "test",
|
|
Hostname: "host1",
|
|
MethodRequest: s.validPublicKeyMethodRequest(),
|
|
},
|
|
},
|
|
}
|
|
}
|
|
s.expectDeny(sh, false, []string{"publickey"})
|
|
s.expectDeny(sh, true, []string{"keyboard-interactive"})
|
|
sh.ReadC() <- &extensions_ssh.ClientMessage{
|
|
Message: &extensions_ssh.ClientMessage_AuthRequest{
|
|
AuthRequest: &extensions_ssh.AuthenticationRequest{
|
|
Protocol: "ssh",
|
|
Service: "ssh-connection",
|
|
AuthMethod: "keyboard-interactive",
|
|
Username: "test",
|
|
Hostname: "host1",
|
|
MethodRequest: marshalAny(&extensions_ssh.KeyboardInteractiveMethodRequest{}),
|
|
},
|
|
},
|
|
}
|
|
|
|
return sh
|
|
}
|
|
StreamHandlerSuiteBeforeTestHooks["TestHandleAuthRequest_KeyboardInteractive"] = HookWithArgs(setupKeyboardInteractive, (error)(nil))
|
|
StreamHandlerSuiteBeforeTestHooks["TestHandleAuthRequest_KeyboardInteractive_NoPromptReply"] = HookWithArgs(setupKeyboardInteractive, context.Canceled)
|
|
StreamHandlerSuiteBeforeTestHooks["TestHandleAuthRequest_KeyboardInteractive_InvalidInfoResponse"] = HookWithArgs(setupKeyboardInteractive, status.Errorf(codes.Internal, "received invalid info response"))
|
|
StreamHandlerSuiteBeforeTestHooks["TestHandleAuthRequest_KeyboardInteractive_InvalidPromptResponse"] = HookWithArgs(setupKeyboardInteractive, status.Errorf(codes.InvalidArgument, "received invalid prompt response"))
|
|
StreamHandlerSuiteBeforeTestHooks["TestHandleAuthRequest_KeyboardInteractive_WrongResponseMessageType"] = HookWithArgs(setupKeyboardInteractive, status.Errorf(codes.InvalidArgument, "received invalid message, expecting info response"))
|
|
}
|
|
|
|
func (s *StreamHandlerSuite) TestHandleAuthRequest_KeyboardInteractive() {
|
|
sh := s.BeforeTestHookResult.(*ssh.StreamHandler)
|
|
|
|
s.expectPrompt(sh)
|
|
sh.ReadC() <- &extensions_ssh.ClientMessage{
|
|
Message: &extensions_ssh.ClientMessage_InfoResponse{
|
|
InfoResponse: &extensions_ssh.InfoResponse{
|
|
Method: "keyboard-interactive",
|
|
Response: marshalAny(&extensions_ssh.KeyboardInteractiveInfoPromptResponses{
|
|
Responses: []string{"test-prompt-response"},
|
|
}),
|
|
},
|
|
},
|
|
}
|
|
s.expectAllowUpstream(sh, "host1")
|
|
}
|
|
|
|
func (s *StreamHandlerSuite) TestHandleAuthRequest_KeyboardInteractive_NoPromptReply() {
|
|
sh := s.BeforeTestHookResult.(*ssh.StreamHandler)
|
|
s.expectPrompt(sh)
|
|
}
|
|
|
|
func (s *StreamHandlerSuite) TestHandleAuthRequest_KeyboardInteractive_InvalidInfoResponse() {
|
|
sh := s.BeforeTestHookResult.(*ssh.StreamHandler)
|
|
s.expectPrompt(sh)
|
|
s.expectError(func() {
|
|
sh.ReadC() <- &extensions_ssh.ClientMessage{
|
|
Message: &extensions_ssh.ClientMessage_InfoResponse{
|
|
InfoResponse: &extensions_ssh.InfoResponse{
|
|
Method: "publickey",
|
|
Response: marshalAny(&extensions_ssh.KeyboardInteractiveInfoPromptResponses{
|
|
Responses: []string{"test-prompt-response"},
|
|
}),
|
|
},
|
|
},
|
|
}
|
|
}, "received invalid info response")
|
|
}
|
|
|
|
func (s *StreamHandlerSuite) TestHandleAuthRequest_KeyboardInteractive_InvalidPromptResponse() {
|
|
sh := s.BeforeTestHookResult.(*ssh.StreamHandler)
|
|
s.expectPrompt(sh)
|
|
s.expectError(func() {
|
|
sh.ReadC() <- &extensions_ssh.ClientMessage{
|
|
Message: &extensions_ssh.ClientMessage_InfoResponse{
|
|
InfoResponse: &extensions_ssh.InfoResponse{
|
|
Method: "keyboard-interactive",
|
|
Response: nil,
|
|
},
|
|
},
|
|
}
|
|
}, "received invalid prompt response")
|
|
}
|
|
|
|
func (s *StreamHandlerSuite) TestHandleAuthRequest_KeyboardInteractive_WrongResponseMessageType() {
|
|
sh := s.BeforeTestHookResult.(*ssh.StreamHandler)
|
|
s.expectPrompt(sh)
|
|
s.expectError(func() {
|
|
sh.ReadC() <- &extensions_ssh.ClientMessage{
|
|
Message: &extensions_ssh.ClientMessage_AuthRequest{
|
|
AuthRequest: &extensions_ssh.AuthenticationRequest{
|
|
Protocol: "ssh",
|
|
Service: "ssh-connection",
|
|
AuthMethod: "keyboard-interactive",
|
|
Username: "test",
|
|
Hostname: "host1",
|
|
MethodRequest: marshalAny(&extensions_ssh.KeyboardInteractiveMethodRequest{}),
|
|
},
|
|
},
|
|
}
|
|
}, "received invalid message, expecting info response")
|
|
}
|
|
|
|
type mockGrpcServerStream struct {
|
|
grpc.ServerStream
|
|
ctx context.Context
|
|
}
|
|
|
|
func (s *mockGrpcServerStream) Context() context.Context {
|
|
return s.ctx
|
|
}
|
|
|
|
type mockChannelStream struct {
|
|
*grpc.GenericServerStream[extensions_ssh.ChannelMessage, extensions_ssh.ChannelMessage]
|
|
|
|
closeServerToClientOnce sync.Once
|
|
serverToClient chan *extensions_ssh.ChannelMessage
|
|
closeClientToServerOnce sync.Once
|
|
clientToServer chan *extensions_ssh.ChannelMessage
|
|
}
|
|
|
|
func newMockChannelStream(t *testing.T) *mockChannelStream {
|
|
cs := &mockChannelStream{
|
|
GenericServerStream: &grpc.GenericServerStream[extensions_ssh.ChannelMessage, extensions_ssh.ChannelMessage]{
|
|
ServerStream: &mockGrpcServerStream{
|
|
ctx: t.Context(),
|
|
},
|
|
},
|
|
serverToClient: make(chan *extensions_ssh.ChannelMessage, 32),
|
|
clientToServer: make(chan *extensions_ssh.ChannelMessage, 32),
|
|
}
|
|
t.Cleanup(func() {
|
|
cs.CloseClientToServer()
|
|
cs.CloseServerToClient()
|
|
})
|
|
return cs
|
|
}
|
|
|
|
func (cs *mockChannelStream) Send(msg *extensions_ssh.ChannelMessage) error {
|
|
cs.serverToClient <- msg
|
|
return nil
|
|
}
|
|
|
|
func (cs *mockChannelStream) Recv() (*extensions_ssh.ChannelMessage, error) {
|
|
msg, ok := <-cs.clientToServer
|
|
if !ok {
|
|
return nil, io.EOF
|
|
}
|
|
return msg, nil
|
|
}
|
|
|
|
func (cs *mockChannelStream) SendClientToServer(msg *extensions_ssh.ChannelMessage) {
|
|
cs.clientToServer <- msg
|
|
}
|
|
|
|
func (cs *mockChannelStream) CloseClientToServer() {
|
|
cs.closeClientToServerOnce.Do(func() {
|
|
close(cs.clientToServer)
|
|
})
|
|
}
|
|
|
|
func (cs *mockChannelStream) CloseServerToClient() {
|
|
cs.closeServerToClientOnce.Do(func() {
|
|
close(cs.serverToClient)
|
|
})
|
|
}
|
|
|
|
func (cs *mockChannelStream) RecvServerToClient() (*extensions_ssh.ChannelMessage, error) {
|
|
select {
|
|
case msg, ok := <-cs.serverToClient:
|
|
if !ok {
|
|
return nil, io.EOF
|
|
}
|
|
return msg, nil
|
|
case <-time.After(DefaultTimeout):
|
|
return nil, errors.New("timed out waiting for server to send message")
|
|
}
|
|
}
|
|
|
|
// Compile-time assertion that *mockChannelStream satisfies the
// StreamManagement_ServeChannelServer interface.
var _ extensions_ssh.StreamManagement_ServeChannelServer = (*mockChannelStream)(nil)
|
|
|
|
func channelMsg(input any) *extensions_ssh.ChannelMessage {
|
|
return &extensions_ssh.ChannelMessage{
|
|
Message: &extensions_ssh.ChannelMessage_RawBytes{
|
|
RawBytes: wrapperspb.Bytes(gossh.Marshal(input)),
|
|
},
|
|
}
|
|
}
|
|
|
|
func recvChannelMsg[T any](s *StreamHandlerSuite, stream *mockChannelStream) T {
|
|
response, err := stream.RecvServerToClient()
|
|
s.Require().NoError(err)
|
|
var msg T
|
|
s.Require().NoError(gossh.Unmarshal(response.GetRawBytes().GetValue(), &msg))
|
|
return msg
|
|
}
|
|
|
|
func sendChannelMsg(stream *mockChannelStream, msg any) {
|
|
stream.SendClientToServer(&extensions_ssh.ChannelMessage{
|
|
Message: &extensions_ssh.ChannelMessage_RawBytes{
|
|
RawBytes: &wrapperspb.BytesValue{
|
|
Value: gossh.Marshal(msg),
|
|
},
|
|
},
|
|
})
|
|
}
|
|
|
|
func (s *StreamHandlerSuite) TestServeChannel_InitialRecvError() {
|
|
sh := s.startStreamHandler(1)
|
|
|
|
stream := newMockChannelStream(s.T())
|
|
stream.CloseClientToServer()
|
|
s.Error(io.EOF, sh.ServeChannel(stream))
|
|
}
|
|
|
|
func (s *StreamHandlerSuite) TestServeChannel_InitialRecvIsNotRawBytes() {
|
|
sh := s.startStreamHandler(1)
|
|
|
|
stream := newMockChannelStream(s.T())
|
|
stream.SendClientToServer(&extensions_ssh.ChannelMessage{
|
|
Message: &extensions_ssh.ChannelMessage_Metadata{},
|
|
})
|
|
s.ErrorIs(status.Errorf(codes.InvalidArgument, "first channel message was not ChannelOpen"), sh.ServeChannel(stream))
|
|
}
|
|
|
|
func (s *StreamHandlerSuite) TestServeChannel_InitialRecvIsNotChannelOpen() {
|
|
sh := s.startStreamHandler(1)
|
|
|
|
stream := newMockChannelStream(s.T())
|
|
stream.SendClientToServer(&extensions_ssh.ChannelMessage{
|
|
Message: &extensions_ssh.ChannelMessage_RawBytes{
|
|
RawBytes: wrapperspb.Bytes([]byte("not ChannelOpen")),
|
|
},
|
|
})
|
|
s.ErrorIs(status.Errorf(codes.InvalidArgument, "first channel message was not ChannelOpen"), sh.ServeChannel(stream))
|
|
}
|
|
|
|
func init() {
|
|
hook := func(s *StreamHandlerSuite, args []any) any {
|
|
errorMatcher := args[0].(Matcher)
|
|
sh := s.startStreamHandler(1)
|
|
|
|
s.mockAuth.EXPECT().
|
|
HandlePublicKeyMethodRequest(Any(), Any(), Any()).
|
|
Times(1).
|
|
DoAndReturn(func(_ context.Context, info ssh.StreamAuthInfo, req *extensions_ssh.PublicKeyMethodRequest) (ssh.PublicKeyAuthMethodResponse, error) {
|
|
s.Equal("test", *info.Username)
|
|
s.Equal("", *info.Hostname)
|
|
return ssh.PublicKeyAuthMethodResponse{
|
|
Allow: &extensions_ssh.PublicKeyAllowResponse{
|
|
PublicKey: req.PublicKey,
|
|
Permissions: &extensions_ssh.Permissions{},
|
|
},
|
|
RequireAdditionalMethods: []string{},
|
|
}, nil
|
|
})
|
|
s.False(sh.IsExpectingInternalChannel())
|
|
sh.ReadC() <- &extensions_ssh.ClientMessage{
|
|
Message: &extensions_ssh.ClientMessage_AuthRequest{
|
|
AuthRequest: &extensions_ssh.AuthenticationRequest{
|
|
Protocol: "ssh",
|
|
Service: "ssh-connection",
|
|
AuthMethod: "publickey",
|
|
Username: "test",
|
|
Hostname: "",
|
|
MethodRequest: s.validPublicKeyMethodRequest(),
|
|
},
|
|
},
|
|
}
|
|
s.expectAllowInternal(sh)
|
|
s.True(sh.IsExpectingInternalChannel())
|
|
s.Equal("test", *sh.Username())
|
|
s.Equal("", *sh.Hostname())
|
|
|
|
stream := newMockChannelStream(s.T())
|
|
errC := make(chan error, 1)
|
|
go func() {
|
|
errC <- sh.ServeChannel(stream)
|
|
stream.CloseServerToClient()
|
|
}()
|
|
s.cleanup = append(s.cleanup, func() {
|
|
stream.CloseClientToServer()
|
|
select {
|
|
case err := <-errC:
|
|
s.Truef(errorMatcher.Matches(err), "expected: %v\nactual: %v", errorMatcher.String(), err)
|
|
case <-time.After(DefaultTimeout):
|
|
s.FailNow("timed out waiting for ServeChannel to exit")
|
|
}
|
|
})
|
|
return stream
|
|
}
|
|
StreamHandlerSuiteBeforeTestHooks["TestServeChannel_Session"] = HookWithArgs(hook, Eq(status.Errorf(codes.Canceled, "channel closed")))
|
|
StreamHandlerSuiteBeforeTestHooks["TestServeChannel_Session_DifferentWindowAndPacketSizes"] = HookWithArgs(hook, Eq(status.Errorf(codes.Canceled, "channel closed")))
|
|
StreamHandlerSuiteBeforeTestHooks["TestServeChannel_DirectTcpip_NoSubMsg"] = HookWithArgs(hook, Not(Nil()))
|
|
StreamHandlerSuiteBeforeTestHooks["TestServeChannel_DirectTcpip_BadHostname"] = HookWithArgs(hook, Not(Nil()))
|
|
StreamHandlerSuiteBeforeTestHooks["TestServeChannel_DirectTcpip_AuthFailed"] = HookWithArgs(hook, Eq(status.Errorf(codes.PermissionDenied, "test error")))
|
|
StreamHandlerSuiteBeforeTestHooks["TestServeChannel_DirectTcpip"] = HookWithArgs(hook, Nil())
|
|
StreamHandlerSuiteBeforeTestHooks["TestServeChannel_InvalidChannelType"] = HookWithArgs(hook, Eq(status.Errorf(codes.InvalidArgument, "unexpected channel type in ChannelOpen message: unknown")))
|
|
StreamHandlerSuiteBeforeTestHooks["TestServeChannel_Session_ExecWithPtyHelp"] = HookWithArgs(hook, Eq(status.Errorf(codes.Canceled, "channel closed")))
|
|
StreamHandlerSuiteBeforeTestHooks["TestServeChannel_Session_Exec_Whoami"] = HookWithArgs(hook, Eq(status.Errorf(codes.Canceled, "channel closed")))
|
|
StreamHandlerSuiteBeforeTestHooks["TestServeChannel_Session_Exec_WhoamiError"] = HookWithArgs(hook, Eq(status.Errorf(codes.Canceled, "channel closed")))
|
|
StreamHandlerSuiteBeforeTestHooks["TestServeChannel_Session_Exec_Logout"] = HookWithArgs(hook, Eq(status.Errorf(codes.Canceled, "channel closed")))
|
|
StreamHandlerSuiteBeforeTestHooks["TestServeChannel_Session_Exec_LogoutError"] = HookWithArgs(hook, Eq(status.Errorf(codes.Canceled, "channel closed")))
|
|
StreamHandlerSuiteBeforeTestHooks["TestServeChannel_Session_RoutesPortal_NonInteractiveError"] = HookWithArgs(hook, Eq(status.Errorf(codes.Canceled, "channel closed")))
|
|
StreamHandlerSuiteBeforeTestHooks["TestServeChannel_Session_InteractiveError"] = HookWithArgs(hook, Eq(status.Errorf(codes.Canceled, "channel closed")))
|
|
StreamHandlerSuiteBeforeTestHooks["TestServeChannel_Session_RoutesPortal"] = HookWithArgs(hook, Eq(status.Errorf(codes.Canceled, "channel closed")))
|
|
StreamHandlerSuiteBeforeTestHooks["TestServeChannel_Session_RoutesPortal_Select"] = HookWithArgs(hook, Eq(status.Errorf(codes.Canceled, "channel closed")))
|
|
StreamHandlerSuiteBeforeTestHooks["TestServeChannel_Session_ChannelCloseResponseTimeout"] = HookWithArgs(hook, Eq(status.Errorf(codes.DeadlineExceeded, "timed out waiting for channel close")))
|
|
}
|
|
|
|
func (s *StreamHandlerSuite) TestServeChannel_Session() {
|
|
stream := s.BeforeTestHookResult.(*mockChannelStream)
|
|
stream.SendClientToServer(channelMsg(ssh.ChannelOpenMsg{
|
|
ChanType: "session",
|
|
PeersID: 2,
|
|
PeersWindow: ssh.ChannelWindowSize,
|
|
MaxPacketSize: ssh.ChannelMaxPacket,
|
|
}))
|
|
resp := recvChannelMsg[ssh.ChannelOpenConfirmMsg](s, stream)
|
|
s.Equal(uint32(ssh.ChannelMaxPacket), resp.MaxPacketSize)
|
|
s.Equal(uint32(ssh.ChannelWindowSize), resp.MyWindow)
|
|
s.Equal(uint32(2), resp.PeersID)
|
|
s.Equal(uint32(1), resp.MyID)
|
|
sendChannelMsg(stream, ssh.ChannelCloseMsg{resp.MyID}) // server id
|
|
recvChannelMsg[ssh.ChannelCloseMsg](s, stream)
|
|
}
|
|
|
|
func (s *StreamHandlerSuite) TestServeChannel_Session_DifferentWindowAndPacketSizes() {
|
|
stream := s.BeforeTestHookResult.(*mockChannelStream)
|
|
stream.SendClientToServer(channelMsg(ssh.ChannelOpenMsg{
|
|
ChanType: "session",
|
|
PeersID: 2, // client id
|
|
PeersWindow: ssh.ChannelWindowSize / 2,
|
|
MaxPacketSize: ssh.ChannelMaxPacket / 2,
|
|
}))
|
|
resp := recvChannelMsg[ssh.ChannelOpenConfirmMsg](s, stream)
|
|
s.Equal(uint32(ssh.ChannelMaxPacket), resp.MaxPacketSize)
|
|
s.Equal(uint32(ssh.ChannelWindowSize), resp.MyWindow)
|
|
s.Equal(uint32(2), resp.PeersID) // client id
|
|
s.Equal(uint32(1), resp.MyID) // server id
|
|
sendChannelMsg(stream, ssh.ChannelCloseMsg{resp.MyID}) // server id
|
|
recvChannelMsg[ssh.ChannelCloseMsg](s, stream)
|
|
}
|
|
|
|
func (s *StreamHandlerSuite) channelDataLoop(peerID uint32, stream *mockChannelStream, exitCode ...uint32) *bytes.Buffer {
|
|
s.T().Helper()
|
|
var channelData bytes.Buffer
|
|
for {
|
|
response, err := stream.RecvServerToClient()
|
|
if errors.Is(err, io.EOF) {
|
|
break
|
|
}
|
|
s.Require().NoError(err)
|
|
bytes := response.GetRawBytes().GetValue()
|
|
switch bytes[0] {
|
|
case ssh.MsgChannelData:
|
|
var msg ssh.ChannelDataMsg
|
|
s.Require().NoError(gossh.Unmarshal(bytes, &msg))
|
|
channelData.Write(msg.Rest)
|
|
case ssh.MsgChannelRequest:
|
|
var msg ssh.ChannelRequestMsg
|
|
s.Require().NoError(gossh.Unmarshal(bytes, &msg))
|
|
s.Equal("exit-status", msg.Request)
|
|
s.Require().NotEmpty(exitCode, "received an exit-status ChannelRequest but the test did not assert an exit code")
|
|
expected := exitCode[0]
|
|
actual := binary.BigEndian.Uint32(msg.RequestSpecificData)
|
|
s.Equal(expected, actual)
|
|
case ssh.MsgChannelClose:
|
|
sendChannelMsg(stream, ssh.ChannelCloseMsg{PeersID: peerID})
|
|
}
|
|
}
|
|
return &channelData
|
|
}
|
|
|
|
// TestServeChannel_Session_ExecWithPtyHelp opens a session channel, requests a
// pty, and then execs "--help". The channel data gathered by channelDataLoop
// (which also checks for exit status 0) is compared against the expected usage
// text.
func (s *StreamHandlerSuite) TestServeChannel_Session_ExecWithPtyHelp() {
|
|
stream := s.BeforeTestHookResult.(*mockChannelStream)
|
|
stream.SendClientToServer(channelMsg(ssh.ChannelOpenMsg{
|
|
ChanType: "session",
|
|
PeersID: 2,
|
|
PeersWindow: ssh.ChannelWindowSize,
|
|
MaxPacketSize: ssh.ChannelMaxPacket,
|
|
}))
|
|
resp := recvChannelMsg[ssh.ChannelOpenConfirmMsg](s, stream)
|
|
peerID := resp.MyID
|
|
stream.SendClientToServer(channelMsg(ssh.ChannelRequestMsg{
|
|
PeersID: peerID,
|
|
Request: "pty-req",
|
|
WantReply: true,
|
|
RequestSpecificData: gossh.Marshal(ssh.PtyReqChannelRequestMsg{}),
|
|
}))
|
|
recvChannelMsg[ssh.ChannelRequestSuccessMsg](s, stream)
|
|
stream.SendClientToServer(channelMsg(ssh.ChannelRequestMsg{
|
|
PeersID: peerID,
|
|
Request: "exec",
|
|
WantReply: true,
|
|
RequestSpecificData: gossh.Marshal(ssh.ExecChannelRequestMsg{
|
|
Command: "--help",
|
|
}),
|
|
}))
|
|
recvChannelMsg[ssh.ChannelRequestSuccessMsg](s, stream)
|
|
|
|
// The "portal" line is only part of the expected help output when the SSH
// routes portal runtime flag is set.
maybeRoutesPortalCmd := ""
|
|
if s.cfg.Options.IsRuntimeFlagSet(config.RuntimeFlagSSHRoutesPortal) {
|
|
maybeRoutesPortalCmd = ` portal Interactive route portal
|
|
`
|
|
}
|
|
channelData := s.channelDataLoop(peerID, stream, 0)
|
|
s.Equal(`
|
|
Usage:
|
|
pomerium [command]
|
|
|
|
Available Commands:
|
|
help Help about any command
|
|
logout Log out
|
|
`[1:]+maybeRoutesPortalCmd+
|
|
` whoami Show details for the current session
|
|
|
|
Flags:
|
|
-h, --help help for pomerium
|
|
|
|
Use "pomerium [command] --help" for more information about a command.
|
|
`, channelData.String())
|
|
}
|
|
|
|
func (s *StreamHandlerSuite) TestServeChannel_Session_ChannelCloseResponseTimeout() {
|
|
stream := s.BeforeTestHookResult.(*mockChannelStream)
|
|
stream.SendClientToServer(channelMsg(ssh.ChannelOpenMsg{
|
|
ChanType: "session",
|
|
PeersID: 2,
|
|
PeersWindow: ssh.ChannelWindowSize,
|
|
MaxPacketSize: ssh.ChannelMaxPacket,
|
|
}))
|
|
resp := recvChannelMsg[ssh.ChannelOpenConfirmMsg](s, stream)
|
|
peerID := resp.MyID
|
|
stream.SendClientToServer(channelMsg(ssh.ChannelRequestMsg{
|
|
PeersID: peerID,
|
|
Request: "pty-req",
|
|
WantReply: true,
|
|
RequestSpecificData: gossh.Marshal(ssh.PtyReqChannelRequestMsg{}),
|
|
}))
|
|
recvChannelMsg[ssh.ChannelRequestSuccessMsg](s, stream)
|
|
stream.SendClientToServer(channelMsg(ssh.ChannelRequestMsg{
|
|
PeersID: peerID,
|
|
Request: "exec",
|
|
WantReply: true,
|
|
RequestSpecificData: gossh.Marshal(ssh.ExecChannelRequestMsg{
|
|
Command: "--help",
|
|
}),
|
|
}))
|
|
recvChannelMsg[ssh.ChannelRequestSuccessMsg](s, stream)
|
|
for {
|
|
response, err := stream.RecvServerToClient()
|
|
if errors.Is(err, io.EOF) {
|
|
break
|
|
}
|
|
s.Require().NoError(err)
|
|
bytes := response.GetRawBytes().GetValue()
|
|
switch bytes[0] {
|
|
case ssh.MsgChannelData:
|
|
var msg ssh.ChannelDataMsg
|
|
s.Require().NoError(gossh.Unmarshal(bytes, &msg))
|
|
case ssh.MsgChannelClose:
|
|
// don't send a response
|
|
}
|
|
}
|
|
}
|
|
|
|
// TestServeChannel_Session_RoutesPortal_NonInteractiveError opens a session
// channel and requests a shell without first requesting a pty. With the SSH
// routes portal runtime flag set it expects exit status 1 and the "requires a
// TTY" error text; otherwise it expects exit status 0 and the usage text.
func (s *StreamHandlerSuite) TestServeChannel_Session_RoutesPortal_NonInteractiveError() {
|
|
stream := s.BeforeTestHookResult.(*mockChannelStream)
|
|
stream.SendClientToServer(channelMsg(ssh.ChannelOpenMsg{
|
|
ChanType: "session",
|
|
PeersID: 2,
|
|
PeersWindow: ssh.ChannelWindowSize,
|
|
MaxPacketSize: ssh.ChannelMaxPacket,
|
|
}))
|
|
resp := recvChannelMsg[ssh.ChannelOpenConfirmMsg](s, stream)
|
|
peerID := resp.MyID
|
|
stream.SendClientToServer(channelMsg(ssh.ChannelRequestMsg{
|
|
PeersID: peerID,
|
|
Request: "shell",
|
|
WantReply: true,
|
|
}))
|
|
recvChannelMsg[ssh.ChannelRequestSuccessMsg](s, stream)
|
|
|
|
if s.cfg.Options.IsRuntimeFlagSet(config.RuntimeFlagSSHRoutesPortal) {
|
|
channelData := s.channelDataLoop(peerID, stream, 1)
|
|
s.Equal("Error: 'portal' is an interactive command and requires a TTY (try passing '-t' to ssh)\n",
|
|
ansi.Strip(channelData.String()))
|
|
} else {
|
|
channelData := s.channelDataLoop(peerID, stream, 0)
|
|
s.Equal(`
|
|
Usage:
|
|
pomerium [command]
|
|
|
|
Available Commands:
|
|
help Help about any command
|
|
logout Log out
|
|
whoami Show details for the current session
|
|
|
|
Flags:
|
|
-h, --help help for pomerium
|
|
|
|
Use "pomerium [command] --help" for more information about a command.
|
|
`[1:], channelData.String())
|
|
}
|
|
}
|
|
|
|
func (s *StreamHandlerSuite) TestServeChannel_Session_InteractiveError() {
|
|
stream := s.BeforeTestHookResult.(*mockChannelStream)
|
|
stream.SendClientToServer(channelMsg(ssh.ChannelOpenMsg{
|
|
ChanType: "session",
|
|
PeersID: 2,
|
|
PeersWindow: ssh.ChannelWindowSize,
|
|
MaxPacketSize: ssh.ChannelMaxPacket,
|
|
}))
|
|
resp := recvChannelMsg[ssh.ChannelOpenConfirmMsg](s, stream)
|
|
peerID := resp.MyID
|
|
stream.SendClientToServer(channelMsg(ssh.ChannelRequestMsg{
|
|
PeersID: peerID,
|
|
Request: "pty-req",
|
|
WantReply: true,
|
|
RequestSpecificData: gossh.Marshal(ssh.PtyReqChannelRequestMsg{}),
|
|
}))
|
|
recvChannelMsg[ssh.ChannelRequestSuccessMsg](s, stream)
|
|
stream.SendClientToServer(channelMsg(ssh.ChannelRequestMsg{
|
|
PeersID: peerID,
|
|
Request: "exec",
|
|
WantReply: true,
|
|
RequestSpecificData: gossh.Marshal(ssh.ExecChannelRequestMsg{
|
|
Command: "whoami",
|
|
}),
|
|
}))
|
|
recvChannelMsg[ssh.ChannelRequestSuccessMsg](s, stream)
|
|
channelData := s.channelDataLoop(peerID, stream, 1)
|
|
s.Equal("Error: 'whoami' is not an interactive command (try passing '-T' to ssh, or removing '-t')\r\n",
|
|
ansi.Strip(channelData.String()))
|
|
}
|
|
|
|
func printFrame(in string) string {
|
|
re := strings.NewReplacer(" ", "·", "\t", "🡒", "\n", "\n⤶", "\r", "⇤")
|
|
return re.Replace(ansi.Strip(in))
|
|
}
|
|
|
|
func postProcessFrame(in string) string {
|
|
return strings.ReplaceAll(ansi.Strip(in), "\r", "")
|
|
}
|
|
|
|
type routesPortalTestHookOutput struct {
|
|
stream *mockChannelStream
|
|
peerID uint32
|
|
}
|
|
|
|
// init appends a second before-test hook for the two routes portal tests. The
// hook opens a session channel on the stream produced by the preceding hook,
// requests a "dumb" 39x10 pty, and requests a shell. When the SSH routes
// portal runtime flag is not set it checks the usage text and returns nil;
// otherwise it returns the stream and channel id for the test to drive.
func init() {
|
|
hook := func(s *StreamHandlerSuite) any {
|
|
stream := s.BeforeTestHookResult.(*mockChannelStream)
|
|
stream.SendClientToServer(channelMsg(ssh.ChannelOpenMsg{
|
|
ChanType: "session",
|
|
PeersID: 2,
|
|
PeersWindow: ssh.ChannelWindowSize,
|
|
MaxPacketSize: ssh.ChannelMaxPacket,
|
|
}))
|
|
resp := recvChannelMsg[ssh.ChannelOpenConfirmMsg](s, stream)
|
|
peerID := resp.MyID
|
|
stream.SendClientToServer(channelMsg(ssh.ChannelRequestMsg{
|
|
PeersID: peerID,
|
|
Request: "pty-req",
|
|
WantReply: true,
|
|
RequestSpecificData: gossh.Marshal(ssh.PtyReqChannelRequestMsg{
|
|
TermEnv: "dumb",
|
|
Width: 39,
|
|
Height: 10,
|
|
}),
|
|
}))
|
|
recvChannelMsg[ssh.ChannelRequestSuccessMsg](s, stream)
|
|
stream.SendClientToServer(channelMsg(ssh.ChannelRequestMsg{
|
|
PeersID: peerID,
|
|
Request: "shell",
|
|
WantReply: true,
|
|
}))
|
|
recvChannelMsg[ssh.ChannelRequestSuccessMsg](s, stream)
|
|
|
|
// Portal disabled: the shell request is expected to print the usage text and
// exit with status 0, and the tests receive a nil hook result.
if !s.cfg.Options.IsRuntimeFlagSet(config.RuntimeFlagSSHRoutesPortal) {
|
|
channelData := s.channelDataLoop(peerID, stream, 0)
|
|
s.Equal(`
|
|
Usage:
|
|
pomerium [command]
|
|
|
|
Available Commands:
|
|
help Help about any command
|
|
logout Log out
|
|
whoami Show details for the current session
|
|
|
|
Flags:
|
|
-h, --help help for pomerium
|
|
|
|
Use "pomerium [command] --help" for more information about a command.
|
|
`[1:], channelData.String())
|
|
return nil
|
|
}
|
|
return &routesPortalTestHookOutput{
|
|
stream,
|
|
peerID,
|
|
}
|
|
}
|
|
StreamHandlerSuiteBeforeTestHooks["TestServeChannel_Session_RoutesPortal"] = append(StreamHandlerSuiteBeforeTestHooks["TestServeChannel_Session_RoutesPortal"], hook)
|
|
StreamHandlerSuiteBeforeTestHooks["TestServeChannel_Session_RoutesPortal_Select"] = append(StreamHandlerSuiteBeforeTestHooks["TestServeChannel_Session_RoutesPortal_Select"], hook)
|
|
}
|
|
|
|
// TestServeChannel_Session_RoutesPortal drives the interactive portal through
// two expected frames: once the first frame has been received it sends a
// cursor-down sequence, and once the second has been received it sends "q".
// It then expects an "exit-status" request with code 0 and answers the
// server's ChannelClose. The test returns early when the hook result is nil
// (routes portal disabled).
func (s *StreamHandlerSuite) TestServeChannel_Session_RoutesPortal() {
|
|
res, _ := s.BeforeTestHookResult.(*routesPortalTestHookOutput)
|
|
if res == nil {
|
|
return // routes portal disabled
|
|
}
|
|
stream, peerID := res.stream, res.peerID
|
|
|
|
// Expected frames; the "|" characters only mark line edges in the source and
// are removed by the loop below before comparison.
frames := []string{
|
|
`
|
|
||
|
|
| Connect to which server? |
|
|
| |
|
|
| > 1. test@host1 |
|
|
| 2. test@host2 |
|
|
| |
|
|
| |
|
|
| ↑/k up • ↓/j down • q quit • ? more|
|
|
| |`[1:],
|
|
`
|
|
||
|
|
||
|
|
||
|
|
| 1. test@host1 |
|
|
| > 2. test@host2 |
|
|
||
|
|
||
|
|
||
|
|
||`[1:],
|
|
}
|
|
for i, frame := range frames {
|
|
frames[i] = strings.ReplaceAll(frame, "|", "")
|
|
}
|
|
var ok bool
|
|
var channelData bytes.Buffer
|
|
currentFrame := 0
|
|
start := time.Now()
|
|
// frameAdvance sends the input that follows the frame just matched, moves on
// to the next frame, and clears the accumulated channel data.
frameAdvance := func() {
|
|
switch currentFrame {
|
|
case 0:
|
|
cursorDown := []byte(ansi.CursorDown(1))
|
|
currentFrame++
|
|
sendChannelMsg(stream, ssh.ChannelDataMsg{
|
|
PeersID: peerID,
|
|
Length: uint32(len(cursorDown)),
|
|
Rest: cursorDown,
|
|
})
|
|
case 1:
|
|
currentFrame++
|
|
ok = true
|
|
sendChannelMsg(stream, ssh.ChannelDataMsg{
|
|
PeersID: peerID,
|
|
Length: uint32(1),
|
|
Rest: []byte("q"),
|
|
})
|
|
}
|
|
channelData.Reset()
|
|
}
|
|
// Read server messages until the channel is closed, a receive fails, or
// DefaultTimeout has elapsed since start.
LOOP:
|
|
for time.Since(start) < DefaultTimeout {
|
|
response, err := stream.RecvServerToClient()
|
|
if err != nil {
|
|
s.Fail(err.Error())
|
|
break
|
|
}
|
|
|
|
bytes := response.GetRawBytes().GetValue()
|
|
switch bytes[0] {
|
|
case ssh.MsgChannelData:
|
|
if ok {
|
|
continue
|
|
}
|
|
var msg ssh.ChannelDataMsg
|
|
s.Require().NoError(gossh.Unmarshal(bytes, &msg))
|
|
channelData.Write(msg.Rest)
|
|
if postProcessFrame(channelData.String()) == frames[currentFrame] {
|
|
frameAdvance()
|
|
if currentFrame >= len(frames) {
|
|
ok = true
|
|
}
|
|
}
|
|
case ssh.MsgChannelRequest:
|
|
// the only channel request we expect to send would be "exit-status"
|
|
var msg ssh.ChannelRequestMsg
|
|
s.Require().NoError(gossh.Unmarshal(bytes, &msg))
|
|
s.Equal("exit-status", msg.Request)
|
|
s.Equal(uint32(0), binary.BigEndian.Uint32(msg.RequestSpecificData))
|
|
case ssh.MsgChannelClose:
|
|
sendChannelMsg(stream, ssh.ChannelCloseMsg{PeersID: peerID})
|
|
break LOOP
|
|
default:
|
|
s.FailNow("test bug")
|
|
}
|
|
}
|
|
currentFrameStr := ""
|
|
if !ok {
|
|
currentFrameStr = printFrame(frames[currentFrame])
|
|
}
|
|
s.Require().Truef(ok, "timed out waiting for frame %d\nbuffer:\n%s\nexpecting:\n%s",
|
|
currentFrame,
|
|
printFrame(postProcessFrame(channelData.String())),
|
|
currentFrameStr)
|
|
}
|
|
|
|
// TestServeChannel_Session_RoutesPortal_Select drives the interactive portal
// through three expected frames (initial list, cursor moved down, redraw after
// a window-change to 36 columns) and then sends a carriage return to select
// the second entry. It expects EvaluateDelayed to be called for test@host2 and
// a hand-off channel control action for that upstream. The test returns early
// when the hook result is nil (routes portal disabled).
//
// NOTE: the frame literals are whitespace-sensitive, so every line other than
// the corrected handoff assertion is left exactly as it was.
func (s *StreamHandlerSuite) TestServeChannel_Session_RoutesPortal_Select() {
|
|
res, _ := s.BeforeTestHookResult.(*routesPortalTestHookOutput)
|
|
if res == nil {
|
|
return // routes portal disabled
|
|
}
|
|
stream, peerID := res.stream, res.peerID
|
|
|
|
// Expected frames; the "|" characters only mark line edges in the source and
// are removed by the loop below before comparison.
frames := []string{
|
|
`
|
|
||
|
|
| Connect to which server? |
|
|
| |
|
|
| > 1. test@host1 |
|
|
| 2. test@host2 |
|
|
| |
|
|
| |
|
|
| ↑/k up • ↓/j down • q quit • ? more|
|
|
| |`[1:],
|
|
`
|
|
||
|
|
||
|
|
||
|
|
| 1. test@host1 |
|
|
| > 2. test@host2 |
|
|
||
|
|
||
|
|
||
|
|
||`[1:],
|
|
`
|
|
||
|
|
| Connect to which server? |
|
|
| |
|
|
| 1. test@host1 |
|
|
| > 2. test@host2 |
|
|
| |
|
|
| |
|
|
| ↑/k up • ↓/j down • q quit …|
|
|
| |`[1:],
|
|
}
|
|
for i, frame := range frames {
|
|
frames[i] = strings.ReplaceAll(frame, "|", "")
|
|
}
|
|
var portalOk bool
|
|
var handoffOk bool
|
|
var expectHandoff bool
|
|
|
|
var channelData bytes.Buffer
|
|
currentFrame := 0
|
|
start := time.Now()
|
|
// frameAdvance sends the input that follows the frame just matched, moves on
// to the next frame, and clears the accumulated channel data.
frameAdvance := func() {
|
|
switch currentFrame {
|
|
case 0:
|
|
cursorDown := []byte(ansi.CursorDown(1))
|
|
currentFrame++
|
|
sendChannelMsg(stream, ssh.ChannelDataMsg{
|
|
PeersID: peerID,
|
|
Length: uint32(len(cursorDown)),
|
|
Rest: cursorDown,
|
|
})
|
|
case 1:
|
|
currentFrame++
|
|
|
|
sendChannelMsg(stream, ssh.ChannelRequestMsg{
|
|
PeersID: peerID,
|
|
Request: "window-change",
|
|
WantReply: false,
|
|
RequestSpecificData: gossh.Marshal(ssh.ChannelWindowChangeRequestMsg{
|
|
WidthColumns: 36,
|
|
HeightRows: 10,
|
|
}),
|
|
})
|
|
case 2:
|
|
currentFrame++
|
|
s.mockAuth.EXPECT().EvaluateDelayed(Any(), Any()).
|
|
DoAndReturn(func(_ context.Context, info ssh.StreamAuthInfo) error {
|
|
s.Equal(info.Username, ptr("test"))
|
|
s.Equal(info.Hostname, ptr("host2"))
|
|
return nil
|
|
})
|
|
expectHandoff = true
|
|
sendChannelMsg(stream, ssh.ChannelDataMsg{
|
|
PeersID: peerID,
|
|
Length: uint32(1),
|
|
Rest: []byte("\r"),
|
|
})
|
|
}
|
|
channelData.Reset()
|
|
}
|
|
// Read server messages until the handoff is seen, a receive fails, or
// DefaultTimeout has elapsed since start.
LOOP:
|
|
for time.Since(start) < DefaultTimeout {
|
|
response, err := stream.RecvServerToClient()
|
|
if err != nil {
|
|
s.Fail(err.Error())
|
|
break
|
|
}
|
|
|
|
if expectHandoff {
|
|
if response.GetRawBytes() != nil {
|
|
// we might get bytes containing a newline
|
|
var msg ssh.ChannelDataMsg
|
|
s.Require().NoError(gossh.Unmarshal(response.GetRawBytes().GetValue(), &msg))
|
|
s.Require().Empty(strings.TrimSpace(ansi.Strip(string(msg.Rest))))
|
|
continue
|
|
}
|
|
action := response.GetChannelControl().GetControlAction()
|
|
s.Require().NotNil(action, "expected channel control action")
|
|
var sshAction extensions_ssh.SSHChannelControlAction
|
|
s.Require().NoError(action.UnmarshalTo(&sshAction))
|
|
handoff := sshAction.GetHandOff()
|
|
// Assert on handoff itself; this check previously re-tested action, which was
// already known to be non-nil.
s.Require().NotNil(handoff, "expected handoff action")
|
|
s.Require().NotNil(handoff.GetUpstreamAuth().GetUpstream(), "expected upstream handoff action")
|
|
s.Equal("test", handoff.GetUpstreamAuth().Username)
|
|
s.Equal("host2", handoff.GetUpstreamAuth().GetUpstream().Hostname)
|
|
testutil.AssertProtoEqual(s.T(), []*extensions_ssh.AllowedMethod{
|
|
{
|
|
Method: "publickey",
|
|
MethodData: marshalAny(&extensions_ssh.PublicKeyAllowResponse{
|
|
PublicKey: s.ed25519SshPublicKey.Marshal(),
|
|
Permissions: &extensions_ssh.Permissions{},
|
|
}),
|
|
},
|
|
}, handoff.GetUpstreamAuth().GetUpstream().AllowedMethods)
|
|
handoffOk = true
|
|
break LOOP
|
|
}
|
|
|
|
bytes := response.GetRawBytes().GetValue()
|
|
s.Require().NotNil(bytes, response.String())
|
|
switch bytes[0] {
|
|
case ssh.MsgChannelData:
|
|
if portalOk {
|
|
continue
|
|
}
|
|
s.Require().False(expectHandoff)
|
|
|
|
var msg ssh.ChannelDataMsg
|
|
s.Require().NoError(gossh.Unmarshal(bytes, &msg))
|
|
channelData.Write(msg.Rest)
|
|
if postProcessFrame(channelData.String()) == frames[currentFrame] {
|
|
frameAdvance()
|
|
if currentFrame >= len(frames) {
|
|
portalOk = true
|
|
}
|
|
}
|
|
default:
|
|
s.FailNow("test bug")
|
|
}
|
|
}
|
|
currentFrameStr := ""
|
|
if !portalOk {
|
|
currentFrameStr = printFrame(frames[currentFrame])
|
|
}
|
|
s.Truef(portalOk, "timed out waiting for frame %d\nbuffer:\n%s\nexpecting:\n%s",
|
|
currentFrame,
|
|
printFrame(postProcessFrame(channelData.String())),
|
|
currentFrameStr)
|
|
s.True(handoffOk, "timed out waiting for handoff")
|
|
sendChannelMsg(stream, ssh.ChannelCloseMsg{PeersID: peerID})
|
|
}
|
|
|
|
func (s *StreamHandlerSuite) TestServeChannel_Session_Exec_Whoami() {
|
|
stream := s.BeforeTestHookResult.(*mockChannelStream)
|
|
stream.SendClientToServer(channelMsg(ssh.ChannelOpenMsg{
|
|
ChanType: "session",
|
|
PeersID: 2,
|
|
PeersWindow: ssh.ChannelWindowSize,
|
|
MaxPacketSize: ssh.ChannelMaxPacket,
|
|
}))
|
|
resp := recvChannelMsg[ssh.ChannelOpenConfirmMsg](s, stream)
|
|
peerID := resp.MyID
|
|
|
|
s.mockAuth.EXPECT().
|
|
FormatSession(Any(), Any()).
|
|
Return([]byte("example"), nil)
|
|
|
|
stream.SendClientToServer(channelMsg(ssh.ChannelRequestMsg{
|
|
PeersID: peerID,
|
|
Request: "exec",
|
|
WantReply: true,
|
|
RequestSpecificData: gossh.Marshal(ssh.ExecChannelRequestMsg{
|
|
Command: "whoami",
|
|
}),
|
|
}))
|
|
recvChannelMsg[ssh.ChannelRequestSuccessMsg](s, stream)
|
|
|
|
channelData := s.channelDataLoop(peerID, stream, 0)
|
|
s.Equal("example", channelData.String())
|
|
}
|
|
|
|
func (s *StreamHandlerSuite) TestServeChannel_Session_Exec_WhoamiError() {
|
|
stream := s.BeforeTestHookResult.(*mockChannelStream)
|
|
stream.SendClientToServer(channelMsg(ssh.ChannelOpenMsg{
|
|
ChanType: "session",
|
|
PeersID: 2,
|
|
PeersWindow: ssh.ChannelWindowSize,
|
|
MaxPacketSize: ssh.ChannelMaxPacket,
|
|
}))
|
|
resp := recvChannelMsg[ssh.ChannelOpenConfirmMsg](s, stream)
|
|
peerID := resp.MyID
|
|
|
|
s.mockAuth.EXPECT().
|
|
FormatSession(Any(), Any()).
|
|
Return(nil, errors.New("test error"))
|
|
|
|
stream.SendClientToServer(channelMsg(ssh.ChannelRequestMsg{
|
|
PeersID: peerID,
|
|
Request: "exec",
|
|
WantReply: true,
|
|
RequestSpecificData: gossh.Marshal(ssh.ExecChannelRequestMsg{
|
|
Command: "whoami",
|
|
}),
|
|
}))
|
|
recvChannelMsg[ssh.ChannelRequestSuccessMsg](s, stream)
|
|
|
|
channelData := s.channelDataLoop(peerID, stream, 1)
|
|
s.Equal("Error: couldn't fetch session: test error\r\n", channelData.String())
|
|
}
|
|
|
|
func (s *StreamHandlerSuite) TestServeChannel_Session_Exec_Logout() {
|
|
stream := s.BeforeTestHookResult.(*mockChannelStream)
|
|
stream.SendClientToServer(channelMsg(ssh.ChannelOpenMsg{
|
|
ChanType: "session",
|
|
PeersID: 2,
|
|
PeersWindow: ssh.ChannelWindowSize,
|
|
MaxPacketSize: ssh.ChannelMaxPacket,
|
|
}))
|
|
resp := recvChannelMsg[ssh.ChannelOpenConfirmMsg](s, stream)
|
|
peerID := resp.MyID
|
|
|
|
s.mockAuth.EXPECT().
|
|
DeleteSession(Any(), Any()).
|
|
Return(nil)
|
|
|
|
stream.SendClientToServer(channelMsg(ssh.ChannelRequestMsg{
|
|
PeersID: peerID,
|
|
Request: "exec",
|
|
WantReply: true,
|
|
RequestSpecificData: gossh.Marshal(ssh.ExecChannelRequestMsg{
|
|
Command: "logout",
|
|
}),
|
|
}))
|
|
recvChannelMsg[ssh.ChannelRequestSuccessMsg](s, stream)
|
|
|
|
channelData := s.channelDataLoop(peerID, stream, 0)
|
|
s.Equal("Logged out successfully\r\n", channelData.String())
|
|
}
|
|
|
|
func (s *StreamHandlerSuite) TestServeChannel_Session_Exec_LogoutError() {
|
|
stream := s.BeforeTestHookResult.(*mockChannelStream)
|
|
stream.SendClientToServer(channelMsg(ssh.ChannelOpenMsg{
|
|
ChanType: "session",
|
|
PeersID: 2,
|
|
PeersWindow: ssh.ChannelWindowSize,
|
|
MaxPacketSize: ssh.ChannelMaxPacket,
|
|
}))
|
|
resp := recvChannelMsg[ssh.ChannelOpenConfirmMsg](s, stream)
|
|
peerID := resp.MyID
|
|
|
|
s.mockAuth.EXPECT().
|
|
DeleteSession(Any(), Any()).
|
|
Return(errors.New("test error"))
|
|
|
|
stream.SendClientToServer(channelMsg(ssh.ChannelRequestMsg{
|
|
PeersID: peerID,
|
|
Request: "exec",
|
|
WantReply: true,
|
|
RequestSpecificData: gossh.Marshal(ssh.ExecChannelRequestMsg{
|
|
Command: "logout",
|
|
}),
|
|
}))
|
|
recvChannelMsg[ssh.ChannelRequestSuccessMsg](s, stream)
|
|
|
|
channelData := s.channelDataLoop(peerID, stream, 1)
|
|
s.Equal("Error: failed to delete session: test error\r\n", channelData.String())
|
|
}
|
|
|
|
func (s *StreamHandlerSuite) TestServeChannel_DirectTcpip_NoSubMsg() {
|
|
stream := s.BeforeTestHookResult.(*mockChannelStream)
|
|
stream.SendClientToServer(channelMsg(ssh.ChannelOpenMsg{
|
|
ChanType: "direct-tcpip",
|
|
PeersID: 2,
|
|
PeersWindow: ssh.ChannelWindowSize,
|
|
MaxPacketSize: ssh.ChannelMaxPacket,
|
|
}))
|
|
// error checked in cleanup
|
|
}
|
|
|
|
func (s *StreamHandlerSuite) TestServeChannel_DirectTcpip_BadHostname() {
|
|
stream := s.BeforeTestHookResult.(*mockChannelStream)
|
|
stream.SendClientToServer(channelMsg(ssh.ChannelOpenMsg{
|
|
ChanType: "direct-tcpip",
|
|
PeersID: 2,
|
|
PeersWindow: ssh.ChannelWindowSize,
|
|
MaxPacketSize: ssh.ChannelMaxPacket,
|
|
TypeSpecificData: gossh.Marshal(ssh.ChannelOpenDirectMsg{
|
|
DestAddr: "", // invalid
|
|
DestPort: 22,
|
|
SrcAddr: "127.0.0.1",
|
|
SrcPort: 12345,
|
|
}),
|
|
}))
|
|
// error checked in cleanup
|
|
}
|
|
|
|
func (s *StreamHandlerSuite) TestServeChannel_DirectTcpip_AuthFailed() {
|
|
s.mockAuth.EXPECT().
|
|
EvaluateDelayed(Any(), Any()).
|
|
Times(1).
|
|
Return(errors.New("test error"))
|
|
stream := s.BeforeTestHookResult.(*mockChannelStream)
|
|
stream.SendClientToServer(channelMsg(ssh.ChannelOpenMsg{
|
|
ChanType: "direct-tcpip",
|
|
PeersID: 2,
|
|
PeersWindow: ssh.ChannelWindowSize,
|
|
MaxPacketSize: ssh.ChannelMaxPacket,
|
|
TypeSpecificData: gossh.Marshal(ssh.ChannelOpenDirectMsg{
|
|
DestAddr: "host1",
|
|
DestPort: 22,
|
|
SrcAddr: "127.0.0.1",
|
|
SrcPort: 12345,
|
|
}),
|
|
}))
|
|
// error checked in cleanup
|
|
}
|
|
|
|
func (s *StreamHandlerSuite) TestServeChannel_DirectTcpip() {
|
|
s.mockAuth.EXPECT().
|
|
EvaluateDelayed(Any(), Any()).
|
|
Times(1).
|
|
Return(nil)
|
|
stream := s.BeforeTestHookResult.(*mockChannelStream)
|
|
stream.SendClientToServer(channelMsg(ssh.ChannelOpenMsg{
|
|
ChanType: "direct-tcpip",
|
|
PeersID: 2,
|
|
PeersWindow: ssh.ChannelWindowSize,
|
|
MaxPacketSize: ssh.ChannelMaxPacket,
|
|
TypeSpecificData: gossh.Marshal(ssh.ChannelOpenDirectMsg{
|
|
DestAddr: "host1", // i.e. 'ssh -J pomerium test@host1'
|
|
DestPort: 22, // this will be sent by the ssh client, but is ignored
|
|
SrcAddr: "127.0.0.1",
|
|
SrcPort: 12345,
|
|
}),
|
|
}))
|
|
recv, err := stream.RecvServerToClient()
|
|
s.Require().NoError(err)
|
|
action := recv.GetChannelControl().GetControlAction()
|
|
s.Require().NotNil(action, "received a message, but it was not a channel control action")
|
|
handoff := extensions_ssh.SSHChannelControlAction{}
|
|
s.Require().NoError(action.UnmarshalTo(&handoff))
|
|
testutil.AssertProtoEqual(s.T(), extensions_ssh.SSHChannelControlAction_HandOffUpstream{
|
|
DownstreamChannelInfo: &extensions_ssh.SSHDownstreamChannelInfo{
|
|
ChannelType: "direct-tcpip",
|
|
DownstreamChannelId: 2,
|
|
InternalUpstreamChannelId: 1,
|
|
InitialWindowSize: ssh.ChannelWindowSize,
|
|
MaxPacketSize: ssh.ChannelMaxPacket,
|
|
},
|
|
DownstreamPtyInfo: nil,
|
|
UpstreamAuth: &extensions_ssh.AllowResponse{
|
|
Username: "test",
|
|
Target: &extensions_ssh.AllowResponse_Upstream{
|
|
Upstream: &extensions_ssh.UpstreamTarget{
|
|
Hostname: "host1",
|
|
DirectTcpip: true,
|
|
AllowedMethods: []*extensions_ssh.AllowedMethod{
|
|
{
|
|
Method: "publickey",
|
|
MethodData: marshalAny(&extensions_ssh.PublicKeyAllowResponse{
|
|
PublicKey: s.ed25519SshPublicKey.Marshal(),
|
|
Permissions: &extensions_ssh.Permissions{},
|
|
}),
|
|
},
|
|
},
|
|
},
|
|
},
|
|
},
|
|
}, handoff.GetHandOff())
|
|
}
|
|
|
|
func (s *StreamHandlerSuite) TestServeChannel_InvalidChannelType() {
|
|
stream := s.BeforeTestHookResult.(*mockChannelStream)
|
|
stream.SendClientToServer(channelMsg(ssh.ChannelOpenMsg{
|
|
ChanType: "unknown",
|
|
PeersID: 2,
|
|
PeersWindow: ssh.ChannelWindowSize,
|
|
MaxPacketSize: ssh.ChannelMaxPacket,
|
|
}))
|
|
// error checked in cleanup
|
|
}
|
|
|
|
func (s *StreamHandlerSuite) TestFormatSession() {
|
|
s.mockAuth.EXPECT().
|
|
FormatSession(Any(), Any()).
|
|
Return([]byte("example"), nil)
|
|
sh := s.mgr.NewStreamHandler(
|
|
s.cfg, s.mockAuth, &extensions_ssh.DownstreamConnectEvent{StreamId: 1})
|
|
ctx, ca := context.WithCancel(context.Background())
|
|
ca()
|
|
// this will exit immediately, but it will have a state, which is only
|
|
// created upon calling Run()
|
|
sh.Run(ctx)
|
|
|
|
res, err := sh.FormatSession(s.T().Context())
|
|
s.NoError(err)
|
|
s.Equal([]byte("example"), res)
|
|
}
|
|
|
|
func (s *StreamHandlerSuite) TestDeleteSession() {
|
|
s.mockAuth.EXPECT().
|
|
DeleteSession(Any(), Any()).
|
|
Return(nil)
|
|
sh := s.mgr.NewStreamHandler(
|
|
s.cfg, s.mockAuth, &extensions_ssh.DownstreamConnectEvent{StreamId: 1})
|
|
ctx, ca := context.WithCancel(context.Background())
|
|
ca()
|
|
// this will exit immediately, but it will have a state, which is only
|
|
// created upon calling Run()
|
|
sh.Run(ctx)
|
|
|
|
err := sh.DeleteSession(s.T().Context())
|
|
s.NoError(err)
|
|
}
|
|
|
|
func (s *StreamHandlerSuite) TestRunCalledTwice() {
|
|
sh := s.mgr.NewStreamHandler(
|
|
s.cfg, s.mockAuth, &extensions_ssh.DownstreamConnectEvent{StreamId: 1})
|
|
ctx, ca := context.WithCancel(context.Background())
|
|
ca()
|
|
sh.Run(ctx)
|
|
s.PanicsWithValue("Run called twice", func() {
|
|
sh.Run(context.Background())
|
|
})
|
|
}
|
|
|
|
func (s *StreamHandlerSuite) TestAllSSHRoutes() {
|
|
sh := s.mgr.NewStreamHandler(
|
|
s.cfg, s.mockAuth, &extensions_ssh.DownstreamConnectEvent{StreamId: 1})
|
|
routes := slices.Collect(sh.AllSSHRoutes())
|
|
s.Len(routes, 2)
|
|
s.Equal("ssh://host1", routes[0].From)
|
|
s.Equal("ssh://dest1:22", routes[0].To[0].String())
|
|
s.Equal("ssh://host2", routes[1].From)
|
|
s.Equal("ssh://dest2:22", routes[1].To[0].String())
|
|
|
|
next, stop := iter.Pull(sh.AllSSHRoutes())
|
|
v, ok := next()
|
|
s.NotNil(v)
|
|
s.True(ok)
|
|
stop()
|
|
v, ok = next()
|
|
s.Nil(v)
|
|
s.False(ok)
|
|
}
|
|
|
|
func TestStreamHandlerSuite(t *testing.T) {
|
|
suite.Run(t, &StreamHandlerSuite{})
|
|
}
|
|
|
|
func TestStreamHandlerSuiteWithRuntimeFlags(t *testing.T) {
|
|
suite.Run(t, &StreamHandlerSuite{
|
|
StreamHandlerSuiteOptions: StreamHandlerSuiteOptions{
|
|
ConfigModifiers: []func(*config.Config){
|
|
func(c *config.Config) {
|
|
c.Options.RuntimeFlags[config.RuntimeFlagSSHRoutesPortal] = true
|
|
},
|
|
},
|
|
},
|
|
})
|
|
}
|
|
|
|
// ptr returns a pointer to a fresh copy of t. It is handy for taking the
// address of a literal or other non-addressable value.
func ptr[T any](t T) *T {
	p := new(T)
	*p = t
	return p
}
|