This commit is contained in:
Joe Kralicky 2025-03-19 18:20:13 +00:00
parent d89a7d97d7
commit ff26890bf4
No known key found for this signature in database
GPG key ID: 75C4875F34A9FB79
3 changed files with 143 additions and 37 deletions

View file

@ -222,18 +222,49 @@ func (a *Authorize) ManageStream(
} }
sendC <- &resp sendC <- &resp
continue continue
} else if authReq.Username == "_mirror" && authReq.Hostname != "" { } else if authReq.Username == "_mirror" && authReq.Hostname == "" {
id, _ := strconv.ParseUint(authReq.Hostname, 10, 64)
resp := extensions_ssh.ServerMessage{ resp := extensions_ssh.ServerMessage{
Message: &extensions_ssh.ServerMessage_AuthResponse{ Message: &extensions_ssh.ServerMessage_AuthResponse{
AuthResponse: &extensions_ssh.AuthenticationResponse{ AuthResponse: &extensions_ssh.AuthenticationResponse{
Response: &extensions_ssh.AuthenticationResponse_Allow{ Response: &extensions_ssh.AuthenticationResponse_Allow{
Allow: &extensions_ssh.AllowResponse{ Allow: &extensions_ssh.AllowResponse{
Target: &extensions_ssh.AllowResponse_MirrorSession{ Username: state.Username,
MirrorSession: &extensions_ssh.MirrorSessionTarget{ Target: &extensions_ssh.AllowResponse_Internal{
SourceId: id, Internal: &extensions_ssh.InternalTarget{},
Mode: extensions_ssh.MirrorSessionTarget_ReadWrite, },
}, },
},
},
},
}
// id, _ := strconv.ParseUint(authReq.Hostname, 10, 64)
// resp := extensions_ssh.ServerMessage{
// Message: &extensions_ssh.ServerMessage_AuthResponse{
// AuthResponse: &extensions_ssh.AuthenticationResponse{
// Response: &extensions_ssh.AuthenticationResponse_Allow{
// Allow: &extensions_ssh.AllowResponse{
// Target: &extensions_ssh.AllowResponse_MirrorSession{
// MirrorSession: &extensions_ssh.MirrorSessionTarget{
// SourceId: id,
// Mode: extensions_ssh.MirrorSessionTarget_ReadWrite,
// },
// },
// },
// },
// },
// },
// }
sendC <- &resp
continue
} else if authReq.Username != "" && authReq.Hostname == "" {
resp := extensions_ssh.ServerMessage{
Message: &extensions_ssh.ServerMessage_AuthResponse{
AuthResponse: &extensions_ssh.AuthenticationResponse{
Response: &extensions_ssh.AuthenticationResponse_Allow{
Allow: &extensions_ssh.AllowResponse{
Username: state.Username,
Target: &extensions_ssh.AllowResponse_Internal{
Internal: &extensions_ssh.InternalTarget{},
}, },
}, },
}, },
@ -501,6 +532,13 @@ type channelRequestMsg struct {
RequestSpecificData []byte `ssh:"rest"` RequestSpecificData []byte `ssh:"rest"`
} }
type channelOpenDirectMsg struct {
DestAddr string
DestPort uint32
SrcAddr string
SrcPort uint32
}
// See RFC 4254, section 5.4. // See RFC 4254, section 5.4.
const msgChannelSuccess = 99 const msgChannelSuccess = 99
@ -566,14 +604,47 @@ func (a *Authorize) ServeChannel(
InitialWindowSize: confirm.MyWindow, InitialWindowSize: confirm.MyWindow,
MaxPacketSize: confirm.MaxPacketSize, MaxPacketSize: confirm.MaxPacketSize,
} }
if err := server.Send(&extensions_ssh.ChannelMessage{ switch msg.ChanType {
Message: &extensions_ssh.ChannelMessage_RawBytes{ case "session":
RawBytes: &wrapperspb.BytesValue{ if err := server.Send(&extensions_ssh.ChannelMessage{
Value: gossh.Marshal(confirm), Message: &extensions_ssh.ChannelMessage_RawBytes{
RawBytes: &wrapperspb.BytesValue{
Value: gossh.Marshal(confirm),
},
}, },
}, }); err != nil {
}); err != nil { return err
return err }
case "direct-tcpip":
var subMsg channelOpenDirectMsg
if err := gossh.Unmarshal(msg.TypeSpecificData, &subMsg); err != nil {
return err
}
handOff, _ := anypb.New(&extensions_ssh.SSHChannelControlAction{
Action: &extensions_ssh.SSHChannelControlAction_HandOff{
HandOff: &extensions_ssh.SSHChannelControlAction_HandOffUpstream{
DownstreamChannelInfo: downstreamChannelInfo,
UpstreamAuth: &extensions_ssh.AllowResponse{
Target: &extensions_ssh.AllowResponse_Upstream{
Upstream: &extensions_ssh.UpstreamTarget{
Hostname: subMsg.DestAddr,
DirectTcpip: true,
},
},
},
},
},
})
if err := server.Send(&extensions_ssh.ChannelMessage{
Message: &extensions_ssh.ChannelMessage_ChannelControl{
ChannelControl: &extensions_ssh.ChannelControl{
Protocol: "ssh",
ControlAction: handOff,
},
},
}); err != nil {
return err
}
} }
case msgChannelRequest: case msgChannelRequest:
@ -594,6 +665,10 @@ func (a *Authorize) ServeChannel(
for _, route := range routes { for _, route := range routes {
items = append(items, item(route)) items = append(items, item(route))
} }
activeStreamIds.Range(func(key, value any) bool {
items = append(items, item(fmt.Sprintf("[demo] mirror session: %v", key)))
return true
})
downstreamPtyInfo = &extensions_ssh.SSHDownstreamPTYInfo{ downstreamPtyInfo = &extensions_ssh.SSHDownstreamPTYInfo{
TermEnv: req.TermEnv, TermEnv: req.TermEnv,
WidthColumns: req.Width, WidthColumns: req.Width,
@ -625,33 +700,57 @@ func (a *Authorize) ServeChannel(
if err != nil { if err != nil {
return return
} }
username, hostname, _ := strings.Cut(answer.(model).choice, "@") var handOff *anypb.Any
sessionRecordingExt, _ := anypb.New(&extensions_session_recording.UpstreamTargetExtensionConfig{ if strings.HasPrefix(answer.(model).choice, "[demo] mirror session: ") {
RecordingName: fmt.Sprintf("session-%s-at-%s-%d.cast", username, hostname, time.Now().UnixNano()), id, err := strconv.ParseUint(strings.TrimPrefix(answer.(model).choice, "[demo] mirror session: "), 10, 64)
Format: extensions_session_recording.Format_AsciicastFormat, if err != nil {
}) panic(err)
handOff, _ := anypb.New(&extensions_ssh.SSHChannelControlAction{ }
Action: &extensions_ssh.SSHChannelControlAction_HandOff{ handOff, _ = anypb.New(&extensions_ssh.SSHChannelControlAction{
HandOff: &extensions_ssh.SSHChannelControlAction_HandOffUpstream{ Action: &extensions_ssh.SSHChannelControlAction_HandOff{
DownstreamChannelInfo: downstreamChannelInfo, HandOff: &extensions_ssh.SSHChannelControlAction_HandOffUpstream{
DownstreamPtyInfo: downstreamPtyInfo, DownstreamChannelInfo: downstreamChannelInfo,
UpstreamAuth: &extensions_ssh.AllowResponse{ DownstreamPtyInfo: downstreamPtyInfo,
Username: username, UpstreamAuth: &extensions_ssh.AllowResponse{
Target: &extensions_ssh.AllowResponse_Upstream{ Target: &extensions_ssh.AllowResponse_MirrorSession{
Upstream: &extensions_ssh.UpstreamTarget{ MirrorSession: &extensions_ssh.MirrorSessionTarget{
AllowMirrorConnections: true, SourceId: id,
Hostname: hostname, Mode: extensions_ssh.MirrorSessionTarget_ReadWrite,
Extensions: []*corev3.TypedExtensionConfig{ },
{ },
TypedConfig: sessionRecordingExt, },
},
},
})
} else {
username, hostname, _ := strings.Cut(answer.(model).choice, "@")
sessionRecordingExt, _ := anypb.New(&extensions_session_recording.UpstreamTargetExtensionConfig{
RecordingName: fmt.Sprintf("session-%s-at-%s-%d.cast", username, hostname, time.Now().UnixNano()),
Format: extensions_session_recording.Format_AsciicastFormat,
})
handOff, _ = anypb.New(&extensions_ssh.SSHChannelControlAction{
Action: &extensions_ssh.SSHChannelControlAction_HandOff{
HandOff: &extensions_ssh.SSHChannelControlAction_HandOffUpstream{
DownstreamChannelInfo: downstreamChannelInfo,
DownstreamPtyInfo: downstreamPtyInfo,
UpstreamAuth: &extensions_ssh.AllowResponse{
Username: username,
Target: &extensions_ssh.AllowResponse_Upstream{
Upstream: &extensions_ssh.UpstreamTarget{
AllowMirrorConnections: true,
Hostname: hostname,
Extensions: []*corev3.TypedExtensionConfig{
{
TypedConfig: sessionRecordingExt,
},
}, },
}, },
}, },
}, },
}, },
}, },
}, })
}) }
if err := server.Send(&extensions_ssh.ChannelMessage{ if err := server.Send(&extensions_ssh.ChannelMessage{
Message: &extensions_ssh.ChannelMessage_ChannelControl{ Message: &extensions_ssh.ChannelMessage_ChannelControl{

View file

@ -89,6 +89,13 @@ func (b *Builder) buildSSHListener(ctx context.Context, cfg *config.Config) (*en
ChunkSize: wrapperspb.UInt32(8192), ChunkSize: wrapperspb.UInt32(8192),
}), }),
}, },
// FileManagerConfig: &async_filesv3.AsyncFileManagerConfig{
// ManagerType: &async_filesv3.AsyncFileManagerConfig_ThreadPool_{
// ThreadPool: &async_filesv3.AsyncFileManagerConfig_ThreadPool{
// ThreadCount: 2,
// },
// },
// },
}), }),
}, },
// { // {

View file

@ -122,7 +122,7 @@ func TestSpanObserver(t *testing.T) {
waitOkToExit.Store(true) waitOkToExit.Store(true)
obs.Observe(Span(7).ID()) obs.Observe(Span(7).ID())
assert.Equal(t, []oteltrace.SpanID{}, obs.XUnobservedIDs()) assert.Equal(t, []oteltrace.SpanID{}, obs.XUnobservedIDs())
assert.Eventually(t, waitExited.Load, 10*time.Millisecond, 1*time.Millisecond) assert.Eventually(t, waitExited.Load, 100*time.Millisecond, 10*time.Millisecond)
}) })
t.Run("multiple waiters", func(t *testing.T) { t.Run("multiple waiters", func(t *testing.T) {
@ -147,7 +147,7 @@ func TestSpanObserver(t *testing.T) {
assert.Eventually(t, func() bool { assert.Eventually(t, func() bool {
return waitersExited.Load() == 10 return waitersExited.Load() == 10
}, 10*time.Millisecond, 1*time.Millisecond) }, 100*time.Millisecond, 10*time.Millisecond)
}) })
} }