diff --git a/authorize/ssh_grpc.go b/authorize/ssh_grpc.go index 42bf8f49c..39c43f83a 100644 --- a/authorize/ssh_grpc.go +++ b/authorize/ssh_grpc.go @@ -222,18 +222,49 @@ func (a *Authorize) ManageStream( } sendC <- &resp continue - } else if authReq.Username == "_mirror" && authReq.Hostname != "" { - id, _ := strconv.ParseUint(authReq.Hostname, 10, 64) + } else if authReq.Username == "_mirror" && authReq.Hostname == "" { resp := extensions_ssh.ServerMessage{ Message: &extensions_ssh.ServerMessage_AuthResponse{ AuthResponse: &extensions_ssh.AuthenticationResponse{ Response: &extensions_ssh.AuthenticationResponse_Allow{ Allow: &extensions_ssh.AllowResponse{ - Target: &extensions_ssh.AllowResponse_MirrorSession{ - MirrorSession: &extensions_ssh.MirrorSessionTarget{ - SourceId: id, - Mode: extensions_ssh.MirrorSessionTarget_ReadWrite, - }, + Username: state.Username, + Target: &extensions_ssh.AllowResponse_Internal{ + Internal: &extensions_ssh.InternalTarget{}, + }, + }, + }, + }, + }, + } + // id, _ := strconv.ParseUint(authReq.Hostname, 10, 64) + // resp := extensions_ssh.ServerMessage{ + // Message: &extensions_ssh.ServerMessage_AuthResponse{ + // AuthResponse: &extensions_ssh.AuthenticationResponse{ + // Response: &extensions_ssh.AuthenticationResponse_Allow{ + // Allow: &extensions_ssh.AllowResponse{ + // Target: &extensions_ssh.AllowResponse_MirrorSession{ + // MirrorSession: &extensions_ssh.MirrorSessionTarget{ + // SourceId: id, + // Mode: extensions_ssh.MirrorSessionTarget_ReadWrite, + // }, + // }, + // }, + // }, + // }, + // }, + // } + sendC <- &resp + continue + } else if authReq.Username != "" && authReq.Hostname == "" { + resp := extensions_ssh.ServerMessage{ + Message: &extensions_ssh.ServerMessage_AuthResponse{ + AuthResponse: &extensions_ssh.AuthenticationResponse{ + Response: &extensions_ssh.AuthenticationResponse_Allow{ + Allow: &extensions_ssh.AllowResponse{ + Username: state.Username, + Target: &extensions_ssh.AllowResponse_Internal{ + Internal: &extensions_ssh.InternalTarget{}, }, }, }, @@ -501,6 +532,13 @@ type channelRequestMsg struct { RequestSpecificData []byte `ssh:"rest"` } +type channelOpenDirectMsg struct { + DestAddr string + DestPort uint32 + SrcAddr string + SrcPort uint32 +} + // See RFC 4254, section 5.4. const msgChannelSuccess = 99 @@ -566,14 +604,47 @@ func (a *Authorize) ServeChannel( InitialWindowSize: confirm.MyWindow, MaxPacketSize: confirm.MaxPacketSize, } - if err := server.Send(&extensions_ssh.ChannelMessage{ - Message: &extensions_ssh.ChannelMessage_RawBytes{ - RawBytes: &wrapperspb.BytesValue{ - Value: gossh.Marshal(confirm), + switch msg.ChanType { + case "session": + if err := server.Send(&extensions_ssh.ChannelMessage{ + Message: &extensions_ssh.ChannelMessage_RawBytes{ + RawBytes: &wrapperspb.BytesValue{ + Value: gossh.Marshal(confirm), + }, }, - }, - }); err != nil { - return err + }); err != nil { + return err + } + case "direct-tcpip": + var subMsg channelOpenDirectMsg + if err := gossh.Unmarshal(msg.TypeSpecificData, &subMsg); err != nil { + return err + } + handOff, _ := anypb.New(&extensions_ssh.SSHChannelControlAction{ + Action: &extensions_ssh.SSHChannelControlAction_HandOff{ + HandOff: &extensions_ssh.SSHChannelControlAction_HandOffUpstream{ + DownstreamChannelInfo: downstreamChannelInfo, + UpstreamAuth: &extensions_ssh.AllowResponse{ + Target: &extensions_ssh.AllowResponse_Upstream{ + Upstream: &extensions_ssh.UpstreamTarget{ + Hostname: subMsg.DestAddr, + DirectTcpip: true, + }, + }, + }, + }, + }, + }) + if err := server.Send(&extensions_ssh.ChannelMessage{ + Message: &extensions_ssh.ChannelMessage_ChannelControl{ + ChannelControl: &extensions_ssh.ChannelControl{ + Protocol: "ssh", + ControlAction: handOff, + }, + }, + }); err != nil { + return err + } } case msgChannelRequest: @@ -594,6 +665,10 @@ func (a *Authorize) ServeChannel( for _, route := range routes { items = append(items, item(route)) } + activeStreamIds.Range(func(key, value any) bool { + items = append(items, item(fmt.Sprintf("[demo] mirror session: %v", key))) + return true + }) downstreamPtyInfo = &extensions_ssh.SSHDownstreamPTYInfo{ TermEnv: req.TermEnv, WidthColumns: req.Width, @@ -625,33 +700,57 @@ func (a *Authorize) ServeChannel( if err != nil { return } - username, hostname, _ := strings.Cut(answer.(model).choice, "@") - sessionRecordingExt, _ := anypb.New(&extensions_session_recording.UpstreamTargetExtensionConfig{ - RecordingName: fmt.Sprintf("session-%s-at-%s-%d.cast", username, hostname, time.Now().UnixNano()), - Format: extensions_session_recording.Format_AsciicastFormat, - }) - handOff, _ := anypb.New(&extensions_ssh.SSHChannelControlAction{ - Action: &extensions_ssh.SSHChannelControlAction_HandOff{ - HandOff: &extensions_ssh.SSHChannelControlAction_HandOffUpstream{ - DownstreamChannelInfo: downstreamChannelInfo, - DownstreamPtyInfo: downstreamPtyInfo, - UpstreamAuth: &extensions_ssh.AllowResponse{ - Username: username, - Target: &extensions_ssh.AllowResponse_Upstream{ - Upstream: &extensions_ssh.UpstreamTarget{ - AllowMirrorConnections: true, - Hostname: hostname, - Extensions: []*corev3.TypedExtensionConfig{ - { - TypedConfig: sessionRecordingExt, + var handOff *anypb.Any + if strings.HasPrefix(answer.(model).choice, "[demo] mirror session: ") { + id, err := strconv.ParseUint(strings.TrimPrefix(answer.(model).choice, "[demo] mirror session: "), 10, 64) + if err != nil { + panic(err) + } + handOff, _ = anypb.New(&extensions_ssh.SSHChannelControlAction{ + Action: &extensions_ssh.SSHChannelControlAction_HandOff{ + HandOff: &extensions_ssh.SSHChannelControlAction_HandOffUpstream{ + DownstreamChannelInfo: downstreamChannelInfo, + DownstreamPtyInfo: downstreamPtyInfo, + UpstreamAuth: &extensions_ssh.AllowResponse{ + Target: &extensions_ssh.AllowResponse_MirrorSession{ + MirrorSession: &extensions_ssh.MirrorSessionTarget{ + SourceId: id, + Mode: extensions_ssh.MirrorSessionTarget_ReadWrite, + }, + }, + }, + }, + }, + }) + } else { + username, hostname, _ := strings.Cut(answer.(model).choice, "@") + sessionRecordingExt, _ := anypb.New(&extensions_session_recording.UpstreamTargetExtensionConfig{ + RecordingName: fmt.Sprintf("session-%s-at-%s-%d.cast", username, hostname, time.Now().UnixNano()), + Format: extensions_session_recording.Format_AsciicastFormat, + }) + handOff, _ = anypb.New(&extensions_ssh.SSHChannelControlAction{ + Action: &extensions_ssh.SSHChannelControlAction_HandOff{ + HandOff: &extensions_ssh.SSHChannelControlAction_HandOffUpstream{ + DownstreamChannelInfo: downstreamChannelInfo, + DownstreamPtyInfo: downstreamPtyInfo, + UpstreamAuth: &extensions_ssh.AllowResponse{ + Username: username, + Target: &extensions_ssh.AllowResponse_Upstream{ + Upstream: &extensions_ssh.UpstreamTarget{ + AllowMirrorConnections: true, + Hostname: hostname, + Extensions: []*corev3.TypedExtensionConfig{ + { + TypedConfig: sessionRecordingExt, + }, }, }, }, }, }, }, - }, - }) + }) + } if err := server.Send(&extensions_ssh.ChannelMessage{ Message: &extensions_ssh.ChannelMessage_ChannelControl{ diff --git a/config/envoyconfig/listeners_ssh.go b/config/envoyconfig/listeners_ssh.go index d1d6ce92b..fbb779e09 100644 --- a/config/envoyconfig/listeners_ssh.go +++ b/config/envoyconfig/listeners_ssh.go @@ -89,6 +89,13 @@ func (b *Builder) buildSSHListener(ctx context.Context, cfg *config.Config) (*en ChunkSize: wrapperspb.UInt32(8192), }), }, + // FileManagerConfig: &async_filesv3.AsyncFileManagerConfig{ + // ManagerType: &async_filesv3.AsyncFileManagerConfig_ThreadPool_{ + // ThreadPool: &async_filesv3.AsyncFileManagerConfig_ThreadPool{ + // ThreadCount: 2, + // }, + // }, + // }, }), }, // { diff --git a/internal/telemetry/trace/debug_test.go b/internal/telemetry/trace/debug_test.go index 5134f247c..9e58c1d5b 100644 --- a/internal/telemetry/trace/debug_test.go +++ b/internal/telemetry/trace/debug_test.go @@ -122,7 +122,7 @@ func TestSpanObserver(t *testing.T) { waitOkToExit.Store(true) obs.Observe(Span(7).ID()) assert.Equal(t, []oteltrace.SpanID{}, obs.XUnobservedIDs()) - assert.Eventually(t, waitExited.Load, 10*time.Millisecond, 1*time.Millisecond) + assert.Eventually(t, waitExited.Load, 100*time.Millisecond, 10*time.Millisecond) }) t.Run("multiple waiters", func(t *testing.T) { @@ -147,7 +147,7 @@ func TestSpanObserver(t *testing.T) { assert.Eventually(t, func() bool { return waitersExited.Load() == 10 - }, 10*time.Millisecond, 1*time.Millisecond) + }, 100*time.Millisecond, 10*time.Millisecond) }) }