mirror of
https://github.com/pomerium/pomerium.git
synced 2025-08-06 10:21:05 +02:00
ssh: stream management api (#5670)
## Summary This implements the StreamManagement API defined at https://github.com/pomerium/envoy-custom/blob/main/api/extensions/filters/network/ssh/ssh.proto#L46-L60. Policy evaluation and authorization logic is stubbed out here, and implemented in https://github.com/pomerium/pomerium/pull/5665. ## Related issues <!-- For example... - #159 --> ## User Explanation <!-- How would you explain this change to the user? If this change doesn't create any user-facing changes, you can leave this blank. If filled out, add the `docs` label --> ## Checklist - [ ] reference any related issues - [ ] updated unit tests - [ ] add appropriate label (`enhancement`, `bug`, `breaking`, `dependencies`, `ci`) - [ ] ready for review
This commit is contained in:
parent
c53aca0dd8
commit
b216b7a135
18 changed files with 4257 additions and 9 deletions
302
pkg/ssh/channel_impl_test.go
Normal file
302
pkg/ssh/channel_impl_test.go
Normal file
|
@ -0,0 +1,302 @@
|
|||
package ssh_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"math"
|
||||
"runtime"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
gossh "golang.org/x/crypto/ssh"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
"google.golang.org/protobuf/types/known/wrapperspb"
|
||||
|
||||
extensions_ssh "github.com/pomerium/envoy-custom/api/extensions/filters/network/ssh"
|
||||
"github.com/pomerium/pomerium/pkg/ssh"
|
||||
)
|
||||
|
||||
func TestFlowControl_BlockAndWaitForAdjust(t *testing.T) {
|
||||
stream := newMockChannelStream(t)
|
||||
ci := ssh.NewChannelImpl(nil, stream, &extensions_ssh.SSHDownstreamChannelInfo{
|
||||
ChannelType: "session",
|
||||
DownstreamChannelId: 1,
|
||||
InternalUpstreamChannelId: 2,
|
||||
InitialWindowSize: 1024,
|
||||
MaxPacketSize: 4096,
|
||||
})
|
||||
|
||||
sendDone := make(chan struct{})
|
||||
wait := make(chan struct{})
|
||||
go func() {
|
||||
defer close(sendDone)
|
||||
close(wait)
|
||||
ci.SendMessage(ssh.ChannelDataMsg{
|
||||
PeersID: 1,
|
||||
Length: 1024,
|
||||
Rest: make([]byte, 1024),
|
||||
})
|
||||
}()
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer close(done)
|
||||
<-wait
|
||||
stream.SendClientToServer(channelMsg(ssh.WindowAdjustMsg{
|
||||
PeersID: 2,
|
||||
AdditionalBytes: 1024,
|
||||
}))
|
||||
stream.SendClientToServer(channelMsg(ssh.ChannelDataMsg{
|
||||
PeersID: 2,
|
||||
}))
|
||||
msg, err := ci.RecvMsg()
|
||||
<-sendDone
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, ssh.ChannelDataMsg{
|
||||
PeersID: 2,
|
||||
Rest: []byte{},
|
||||
}, msg)
|
||||
}()
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(1 * time.Second):
|
||||
assert.Fail(t, "timed out")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFlowControl_SendWindowAdjust(t *testing.T) {
|
||||
stream := newMockChannelStream(t)
|
||||
ci := ssh.NewChannelImpl(nil, stream, &extensions_ssh.SSHDownstreamChannelInfo{
|
||||
ChannelType: "session",
|
||||
DownstreamChannelId: 1,
|
||||
InternalUpstreamChannelId: 2,
|
||||
InitialWindowSize: 1024,
|
||||
MaxPacketSize: 4096,
|
||||
})
|
||||
|
||||
largeDataMsg := ssh.ChannelDataMsg{
|
||||
PeersID: 1,
|
||||
Length: 16375,
|
||||
Rest: make([]byte, 16375),
|
||||
}
|
||||
encodedLen := len(gossh.Marshal(largeDataMsg))
|
||||
require.Equal(t, 16384, encodedLen) // to make the numbers easier
|
||||
|
||||
const MaxMsgsSentBeforeAdjust = (ssh.ChannelWindowSize / 2) / 16384
|
||||
for i := range MaxMsgsSentBeforeAdjust {
|
||||
stream.SendClientToServer(channelMsg(largeDataMsg))
|
||||
dataMsg, err := ci.RecvMsg()
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, dataMsg)
|
||||
require.Equalf(t, 0, len(stream.serverToClient), "unexpected window adjust on message %d", i)
|
||||
}
|
||||
|
||||
require.Equalf(t, 0, len(stream.serverToClient), "unexpected window adjust on message %d", MaxMsgsSentBeforeAdjust)
|
||||
stream.SendClientToServer(channelMsg(largeDataMsg))
|
||||
dataMsg, err := ci.RecvMsg()
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, dataMsg)
|
||||
require.Equal(t, 1, len(stream.serverToClient))
|
||||
|
||||
recv, err := stream.RecvServerToClient()
|
||||
assert.NoError(t, err)
|
||||
bytes := recv.GetRawBytes().GetValue()
|
||||
var adjust ssh.WindowAdjustMsg
|
||||
assert.NoError(t, gossh.Unmarshal(bytes, &adjust))
|
||||
assert.Equal(t, uint32(ssh.ChannelWindowSize), adjust.AdditionalBytes)
|
||||
assert.Equal(t, uint32(1), adjust.PeersID)
|
||||
}
|
||||
|
||||
func TestFlowControl_WindowAdjustOverflow(t *testing.T) {
|
||||
stream := newMockChannelStream(t)
|
||||
ci := ssh.NewChannelImpl(nil, stream, &extensions_ssh.SSHDownstreamChannelInfo{
|
||||
ChannelType: "session",
|
||||
DownstreamChannelId: 1,
|
||||
InternalUpstreamChannelId: 2,
|
||||
InitialWindowSize: 1024,
|
||||
MaxPacketSize: 4096,
|
||||
})
|
||||
stream.SendClientToServer(channelMsg(ssh.WindowAdjustMsg{
|
||||
PeersID: 2,
|
||||
AdditionalBytes: math.MaxUint32,
|
||||
}))
|
||||
_, err := ci.RecvMsg()
|
||||
assert.ErrorIs(t, err, status.Errorf(codes.InvalidArgument, "invalid window adjustment"))
|
||||
}
|
||||
|
||||
func TestFlowControl_StreamClosed(t *testing.T) {
|
||||
ctx, ca := context.WithCancel(t.Context())
|
||||
stream := &mockChannelStream{
|
||||
GenericServerStream: &grpc.GenericServerStream[extensions_ssh.ChannelMessage, extensions_ssh.ChannelMessage]{
|
||||
ServerStream: &mockGrpcServerStream{
|
||||
ctx: ctx,
|
||||
},
|
||||
},
|
||||
serverToClient: make(chan *extensions_ssh.ChannelMessage, 32),
|
||||
clientToServer: make(chan *extensions_ssh.ChannelMessage, 32),
|
||||
}
|
||||
ci := ssh.NewChannelImpl(nil, stream, &extensions_ssh.SSHDownstreamChannelInfo{
|
||||
ChannelType: "session",
|
||||
DownstreamChannelId: 1,
|
||||
InternalUpstreamChannelId: 2,
|
||||
InitialWindowSize: 0,
|
||||
MaxPacketSize: 4096,
|
||||
})
|
||||
ready := make(chan struct{})
|
||||
errC := make(chan error, 1)
|
||||
go func() {
|
||||
close(ready)
|
||||
errC <- ci.SendMessage(ssh.ChannelDataMsg{
|
||||
PeersID: 1,
|
||||
Length: 1,
|
||||
Rest: []byte("a"),
|
||||
})
|
||||
}()
|
||||
<-ready
|
||||
runtime.Gosched()
|
||||
ca()
|
||||
select {
|
||||
case err := <-errC:
|
||||
assert.ErrorIs(t, err, status.Errorf(codes.Internal, "stream closed"))
|
||||
case <-time.After(DefaultTimeout):
|
||||
assert.Fail(t, "timed out")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecvMsg_EmptyMessage(t *testing.T) {
|
||||
stream := newMockChannelStream(t)
|
||||
ci := ssh.NewChannelImpl(nil, stream, &extensions_ssh.SSHDownstreamChannelInfo{
|
||||
ChannelType: "session",
|
||||
DownstreamChannelId: 1,
|
||||
InternalUpstreamChannelId: 2,
|
||||
InitialWindowSize: 1024,
|
||||
MaxPacketSize: 4096,
|
||||
})
|
||||
|
||||
stream.SendClientToServer(&extensions_ssh.ChannelMessage{
|
||||
Message: &extensions_ssh.ChannelMessage_RawBytes{
|
||||
RawBytes: wrapperspb.Bytes([]byte{}),
|
||||
},
|
||||
})
|
||||
_, err := ci.RecvMsg()
|
||||
assert.ErrorIs(t, status.Errorf(codes.InvalidArgument, "peer sent empty message"), err)
|
||||
}
|
||||
|
||||
func TestRecvMsg_MessageTooLarge(t *testing.T) {
|
||||
stream := newMockChannelStream(t)
|
||||
ci := ssh.NewChannelImpl(nil, stream, &extensions_ssh.SSHDownstreamChannelInfo{
|
||||
ChannelType: "session",
|
||||
DownstreamChannelId: 1,
|
||||
InternalUpstreamChannelId: 2,
|
||||
InitialWindowSize: 1024,
|
||||
MaxPacketSize: 4096,
|
||||
})
|
||||
|
||||
tooLargeDataMsg := ssh.ChannelDataMsg{
|
||||
PeersID: 1,
|
||||
Length: ssh.ChannelMaxPacket,
|
||||
Rest: make([]byte, ssh.ChannelMaxPacket),
|
||||
}
|
||||
stream.SendClientToServer(channelMsg(tooLargeDataMsg))
|
||||
_, err := ci.RecvMsg()
|
||||
assert.ErrorIs(t, status.Errorf(codes.ResourceExhausted, "message too large"), err)
|
||||
}
|
||||
|
||||
func TestRecvMsg_AllowedMessages(t *testing.T) {
|
||||
stream := newMockChannelStream(t)
|
||||
ci := ssh.NewChannelImpl(nil, stream, &extensions_ssh.SSHDownstreamChannelInfo{
|
||||
ChannelType: "session",
|
||||
DownstreamChannelId: 1,
|
||||
InternalUpstreamChannelId: 2,
|
||||
InitialWindowSize: 1024,
|
||||
MaxPacketSize: 4096,
|
||||
})
|
||||
|
||||
// RecvMsg will immediately read another message after WindowAdjust, so
|
||||
// we have to send something
|
||||
stream.SendClientToServer(channelMsg(ssh.WindowAdjustMsg{}))
|
||||
stream.SendClientToServer(channelMsg(ssh.ChannelDataMsg{}))
|
||||
_, err := ci.RecvMsg()
|
||||
assert.NoError(t, err)
|
||||
|
||||
stream.SendClientToServer(channelMsg(ssh.ChannelRequestMsg{}))
|
||||
_, err = ci.RecvMsg()
|
||||
assert.NoError(t, err)
|
||||
|
||||
stream.SendClientToServer(channelMsg(ssh.ChannelDataMsg{}))
|
||||
_, err = ci.RecvMsg()
|
||||
assert.NoError(t, err)
|
||||
|
||||
stream.SendClientToServer(channelMsg(ssh.ChannelCloseMsg{}))
|
||||
_, err = ci.RecvMsg()
|
||||
assert.NoError(t, err)
|
||||
|
||||
stream.SendClientToServer(channelMsg(ssh.ChannelEOFMsg{}))
|
||||
_, err = ci.RecvMsg()
|
||||
assert.NoError(t, err)
|
||||
|
||||
stream.SendClientToServer(channelMsg(ssh.ChannelOpenMsg{}))
|
||||
_, err = ci.RecvMsg()
|
||||
assert.ErrorIs(t, err, status.Errorf(codes.InvalidArgument, "only one channel can be opened"))
|
||||
|
||||
stream.SendClientToServer(channelMsg(ssh.ChannelRequestFailureMsg{}))
|
||||
_, err = ci.RecvMsg()
|
||||
assert.ErrorIs(t, err, status.Errorf(codes.Unimplemented, "received unexpected message with type 100"))
|
||||
|
||||
stream.SendClientToServer(&extensions_ssh.ChannelMessage{Message: &extensions_ssh.ChannelMessage_ChannelControl{}})
|
||||
_, err = ci.RecvMsg()
|
||||
assert.ErrorIs(t, err, status.Errorf(codes.Unimplemented, "unknown channel message received"))
|
||||
}
|
||||
|
||||
func TestRecvMsg_UnmarshalErrors(t *testing.T) {
|
||||
stream := newMockChannelStream(t)
|
||||
ci := ssh.NewChannelImpl(nil, stream, &extensions_ssh.SSHDownstreamChannelInfo{
|
||||
ChannelType: "session",
|
||||
DownstreamChannelId: 1,
|
||||
InternalUpstreamChannelId: 2,
|
||||
InitialWindowSize: 1024,
|
||||
MaxPacketSize: 4096,
|
||||
})
|
||||
|
||||
stream.SendClientToServer(&extensions_ssh.ChannelMessage{
|
||||
Message: &extensions_ssh.ChannelMessage_RawBytes{
|
||||
RawBytes: wrapperspb.Bytes([]byte{ssh.MsgChannelWindowAdjust}),
|
||||
},
|
||||
})
|
||||
_, err := ci.RecvMsg()
|
||||
assert.ErrorContains(t, err, "ssh: short read")
|
||||
|
||||
stream.SendClientToServer(&extensions_ssh.ChannelMessage{
|
||||
Message: &extensions_ssh.ChannelMessage_RawBytes{
|
||||
RawBytes: wrapperspb.Bytes([]byte{ssh.MsgChannelRequest}),
|
||||
},
|
||||
})
|
||||
_, err = ci.RecvMsg()
|
||||
assert.ErrorContains(t, err, "ssh: short read")
|
||||
|
||||
stream.SendClientToServer(&extensions_ssh.ChannelMessage{
|
||||
Message: &extensions_ssh.ChannelMessage_RawBytes{
|
||||
RawBytes: wrapperspb.Bytes([]byte{ssh.MsgChannelData}),
|
||||
},
|
||||
})
|
||||
_, err = ci.RecvMsg()
|
||||
assert.ErrorContains(t, err, "ssh: short read")
|
||||
|
||||
stream.SendClientToServer(&extensions_ssh.ChannelMessage{
|
||||
Message: &extensions_ssh.ChannelMessage_RawBytes{
|
||||
RawBytes: wrapperspb.Bytes([]byte{ssh.MsgChannelClose}),
|
||||
},
|
||||
})
|
||||
_, err = ci.RecvMsg()
|
||||
assert.ErrorContains(t, err, "ssh: short read")
|
||||
|
||||
stream.SendClientToServer(&extensions_ssh.ChannelMessage{
|
||||
Message: &extensions_ssh.ChannelMessage_RawBytes{
|
||||
RawBytes: wrapperspb.Bytes([]byte{ssh.MsgChannelEOF}),
|
||||
},
|
||||
})
|
||||
_, err = ci.RecvMsg()
|
||||
assert.ErrorContains(t, err, "ssh: short read")
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue