diff --git a/authorize/ssh_grpc.go b/authorize/ssh_grpc.go index 9769191d2..42bf8f49c 100644 --- a/authorize/ssh_grpc.go +++ b/authorize/ssh_grpc.go @@ -1,27 +1,40 @@ package authorize import ( + "bufio" + "bytes" + "crypto/sha256" "encoding/binary" "errors" "fmt" "io" "net/url" "slices" + "strconv" "strings" + "sync" "sync/atomic" "time" "github.com/charmbracelet/bubbles/list" tea "github.com/charmbracelet/bubbletea" "github.com/charmbracelet/lipgloss" + corev3 "github.com/envoyproxy/go-control-plane/envoy/config/core/v3" + "github.com/klauspost/compress/zstd" extensions_ssh "github.com/pomerium/envoy-custom/api/extensions/filters/network/ssh" + extensions_session_recording "github.com/pomerium/envoy-custom/api/extensions/filters/network/ssh/filters/session_recording" "github.com/pomerium/pomerium/config" + "github.com/pomerium/pomerium/internal/log" "github.com/pomerium/pomerium/internal/sessions" "github.com/pomerium/pomerium/pkg/identity" "github.com/pomerium/pomerium/pkg/identity/oauth" gossh "golang.org/x/crypto/ssh" "golang.org/x/sync/errgroup" + "google.golang.org/grpc" + "google.golang.org/protobuf/encoding/protodelim" + "google.golang.org/protobuf/encoding/protojson" "google.golang.org/protobuf/types/known/anypb" + "google.golang.org/protobuf/types/known/emptypb" "google.golang.org/protobuf/types/known/timestamppb" "google.golang.org/protobuf/types/known/wrapperspb" ) @@ -33,6 +46,76 @@ type StreamState struct { MethodsAuthenticated []string } +func (a *Authorize) RecordingFinalized( + stream grpc.ClientStreamingServer[extensions_session_recording.RecordingData, emptypb.Empty], +) error { + msg, err := stream.Recv() + if err != nil { + return err + } + md := msg.GetMetadata() + if md == nil { + return fmt.Errorf("first message did not contain metadata") + } + log.Ctx(stream.Context()).Info().Str("info", protojson.Format(md)).Msg("new recording") + + var recording []byte +READ: + for { + msg, err := stream.Recv() + if err != nil { + if errors.Is(err, io.EOF) { + break + } + return err + } + switch data := msg.Data.(type) { + case *extensions_session_recording.RecordingData_Chunk: + recording = append(recording, data.Chunk...) + case *extensions_session_recording.RecordingData_Checksum: + actual := sha256.Sum256(recording) + if actual != [32]byte(data.Checksum) { + return fmt.Errorf("checksum mismatch") + } + break READ + } + } + + r, err := zstd.NewReader(bytes.NewReader(recording)) + if err != nil { + return fmt.Errorf("failed to create zstd reader: %w", err) + } + + switch md.Format { + case extensions_session_recording.Format_AsciicastFormat: + log.Ctx(stream.Context()).Info().Int("compressed_size", len(recording)).Msg("asciicast recording received") + case extensions_session_recording.Format_RawFormat: + reader := bufio.NewReader(r) + var header extensions_session_recording.Header + if err := protodelim.UnmarshalFrom(reader, &header); err != nil { + return fmt.Errorf("failed to unmarshal header: %w", err) + } + + var packets []*extensions_session_recording.Packet + for { + var packet extensions_session_recording.Packet + err := protodelim.UnmarshalFrom(reader, &packet) + if err != nil { + if errors.Is(err, io.EOF) { + break + } + return fmt.Errorf("failed to unmarshal packet: %w", err) + } + packets = append(packets, &packet) + } + + log.Ctx(stream.Context()).Info().Int("compressed_size", len(recording)).Int("packet_count", len(packets)).Msg("recording received") + } + return nil +} + +var activeStreamIds sync.Map + func (a *Authorize) ManageStream( server extensions_ssh.StreamManagement_ManageStreamServer, ) error { @@ -72,6 +155,7 @@ func (a *Authorize) ManageStream( var state StreamState deviceAuthSuccess := &atomic.Bool{} + deviceAuthDone := make(chan struct{}) errC := make(chan error, 1) a.activeStreamsMu.Lock() @@ -80,7 +164,6 @@ func (a *Authorize) ManageStream( for { select { case err := <-errC: - return err case req, ok := <-recvC: if !ok { @@ -92,6 +175,10 @@ func (a *Authorize) ManageStream( case *extensions_ssh.StreamEvent_DownstreamConnected: fmt.Println("downstream connected") _ = event + case *extensions_ssh.StreamEvent_UpstreamConnected: + fmt.Printf("upstream connected: %d\n", event.UpstreamConnected.GetStreamId()) + activeStreamIds.Store(event.UpstreamConnected.GetStreamId(), state) + defer activeStreamIds.Delete(event.UpstreamConnected.GetStreamId()) case nil: } case *extensions_ssh.ClientMessage_AuthRequest: @@ -119,32 +206,35 @@ func (a *Authorize) ManageStream( state.PublicKey = pubkeyReq.PublicKey if authReq.Username == "" && authReq.Hostname == "" { - pkData, _ := anypb.New(&extensions_ssh.PublicKeyAllowResponse{ - PublicKey: state.PublicKey, - Permissions: &extensions_ssh.Permissions{ - PermitPortForwarding: true, - PermitAgentForwarding: true, - PermitX11Forwarding: true, - PermitPty: true, - PermitUserRc: true, - ValidBefore: timestamppb.New(time.Now().Add(-1 * time.Minute)), - ValidAfter: timestamppb.New(time.Now().Add(12 * time.Hour)), - }, - }) resp := extensions_ssh.ServerMessage{ Message: &extensions_ssh.ServerMessage_AuthResponse{ AuthResponse: &extensions_ssh.AuthenticationResponse{ Response: &extensions_ssh.AuthenticationResponse_Allow{ Allow: &extensions_ssh.AllowResponse{ Username: state.Username, - Hostname: state.Hostname, - AllowedMethods: []*extensions_ssh.AllowedMethod{ - { - Method: "publickey", - MethodData: pkData, + Target: &extensions_ssh.AllowResponse_Internal{ + Internal: &extensions_ssh.InternalTarget{}, + }, + }, + }, + }, + }, + } + sendC <- &resp + continue + } else if authReq.Username == "_mirror" && authReq.Hostname != "" { + id, _ := strconv.ParseUint(authReq.Hostname, 10, 64) + resp := extensions_ssh.ServerMessage{ + Message: &extensions_ssh.ServerMessage_AuthResponse{ + AuthResponse: &extensions_ssh.AuthenticationResponse{ + Response: &extensions_ssh.AuthenticationResponse_Allow{ + Allow: &extensions_ssh.AllowResponse{ + Target: &extensions_ssh.AllowResponse_MirrorSession{ + MirrorSession: &extensions_ssh.MirrorSessionTarget{ + SourceId: id, + Mode: extensions_ssh.MirrorSessionTarget_ReadWrite, }, }, - Target: extensions_ssh.Target_Internal, }, }, }, @@ -207,8 +297,8 @@ func (a *Authorize) ManageStream( infoReq := extensions_ssh.KeyboardInteractiveInfoPrompts{ Name: "Sign in with " + idp.GetType(), Instruction: deviceAuthResp.VerificationURIComplete, - Prompts: []*extensions_ssh.KeyboardInteractiveInfoPrompts_Prompt{ - {}, + Prompts: []*extensions_ssh.KeyboardInteractiveInfoPrompts_Prompt{ + // {}, }, } @@ -243,6 +333,7 @@ func (a *Authorize) ManageStream( } fmt.Println(token) deviceAuthSuccess.Store(true) + close(deviceAuthDone) }() } case *extensions_ssh.ClientMessage_InfoResponse: @@ -254,31 +345,45 @@ func (a *Authorize) ManageStream( fmt.Println(respInfo.Responses) } } + select { + case <-deviceAuthDone: + case <-ctx.Done(): + } if deviceAuthSuccess.Load() { state.MethodsAuthenticated = append(state.MethodsAuthenticated, "keyboard-interactive") } else { - retryReq := extensions_ssh.KeyboardInteractiveInfoPrompts{ - Name: "", - Instruction: "Login not successful yet, try again", - Prompts: []*extensions_ssh.KeyboardInteractiveInfoPrompts_Prompt{ - {}, - }, - } - infoReqAny, _ := anypb.New(&retryReq) - resp := extensions_ssh.ServerMessage{ Message: &extensions_ssh.ServerMessage_AuthResponse{ AuthResponse: &extensions_ssh.AuthenticationResponse{ - Response: &extensions_ssh.AuthenticationResponse_InfoRequest{ - InfoRequest: &extensions_ssh.InfoRequest{ - Method: "keyboard-interactive", - Request: infoReqAny, - }, + Response: &extensions_ssh.AuthenticationResponse_Deny{ + Deny: &extensions_ssh.DenyResponse{}, }, }, }, } sendC <- &resp + // retryReq := extensions_ssh.KeyboardInteractiveInfoPrompts{ + // Name: "", + // Instruction: "Login not successful yet, try again", + // Prompts: []*extensions_ssh.KeyboardInteractiveInfoPrompts_Prompt{ + // // {}, + // }, + // } + // infoReqAny, _ := anypb.New(&retryReq) + + // resp := extensions_ssh.ServerMessage{ + // Message: &extensions_ssh.ServerMessage_AuthResponse{ + // AuthResponse: &extensions_ssh.AuthenticationResponse{ + // Response: &extensions_ssh.AuthenticationResponse_InfoRequest{ + // InfoRequest: &extensions_ssh.InfoRequest{ + // Method: "keyboard-interactive", + // Request: infoReqAny, + // }, + // }, + // }, + // }, + // } + // sendC <- &resp continue } if slices.Contains(state.MethodsAuthenticated, "publickey") { @@ -294,23 +399,35 @@ func (a *Authorize) ManageStream( ValidAfter: timestamppb.New(time.Now().Add(12 * time.Hour)), }, }) + sessionRecordingExt, _ := anypb.New(&extensions_session_recording.UpstreamTargetExtensionConfig{ + RecordingName: fmt.Sprintf("session-%s-at-%s-%d.cast", state.Username, state.Hostname, time.Now().UnixNano()), + Format: extensions_session_recording.Format_AsciicastFormat, + }) authResponse := extensions_ssh.ServerMessage{ Message: &extensions_ssh.ServerMessage_AuthResponse{ AuthResponse: &extensions_ssh.AuthenticationResponse{ Response: &extensions_ssh.AuthenticationResponse_Allow{ Allow: &extensions_ssh.AllowResponse{ Username: state.Username, - Hostname: state.Hostname, - AllowedMethods: []*extensions_ssh.AllowedMethod{ - { - Method: "publickey", - MethodData: pkData, - }, - { - Method: "keyboard-interactive", + Target: &extensions_ssh.AllowResponse_Upstream{ + Upstream: &extensions_ssh.UpstreamTarget{ + Hostname: state.Hostname, + AllowedMethods: []*extensions_ssh.AllowedMethod{ + { + Method: "publickey", + MethodData: pkData, + }, + { + Method: "keyboard-interactive", + }, + }, + Extensions: []*corev3.TypedExtensionConfig{ + { + TypedConfig: sessionRecordingExt, + }, + }, }, }, - Target: extensions_ssh.Target_Upstream, }, }, }, @@ -465,10 +582,17 @@ func (a *Authorize) ServeChannel( switch msg.Request { case "pty-req": + opts := a.currentOptions.Load() + var routes []string + for r := range opts.GetAllPolicies() { + if strings.HasPrefix(r.From, "ssh://") { + routes = append(routes, fmt.Sprintf("ubuntu@%s", strings.TrimSuffix(strings.TrimPrefix(r.From, "ssh://"), "."+opts.SSHHostname))) + } + } req := parsePtyReq(msg.RequestSpecificData) - items := []list.Item{ - item("ubuntu@vm"), - item("joe@local"), + items := []list.Item{} + for _, route := range routes { + items = append(items, item(route)) } downstreamPtyInfo = &extensions_ssh.SSHDownstreamPTYInfo{ TermEnv: req.TermEnv, @@ -502,6 +626,10 @@ func (a *Authorize) ServeChannel( return } username, hostname, _ := strings.Cut(answer.(model).choice, "@") + sessionRecordingExt, _ := anypb.New(&extensions_session_recording.UpstreamTargetExtensionConfig{ + RecordingName: fmt.Sprintf("session-%s-at-%s-%d.cast", username, hostname, time.Now().UnixNano()), + Format: extensions_session_recording.Format_AsciicastFormat, + }) handOff, _ := anypb.New(&extensions_ssh.SSHChannelControlAction{ Action: &extensions_ssh.SSHChannelControlAction_HandOff{ HandOff: &extensions_ssh.SSHChannelControlAction_HandOffUpstream{ @@ -509,7 +637,17 @@ func (a *Authorize) ServeChannel( DownstreamPtyInfo: downstreamPtyInfo, UpstreamAuth: &extensions_ssh.AllowResponse{ Username: username, - Hostname: hostname, + Target: &extensions_ssh.AllowResponse_Upstream{ + Upstream: &extensions_ssh.UpstreamTarget{ + AllowMirrorConnections: true, + Hostname: hostname, + Extensions: []*corev3.TypedExtensionConfig{ + { + TypedConfig: sessionRecordingExt, + }, + }, + }, + }, }, }, }, @@ -675,11 +813,5 @@ func (m model) Update(msg tea.Msg) (tea.Model, tea.Cmd) { } func (m model) View() string { - if m.choice != "" { - return quitTextStyle.Render(fmt.Sprintf("%s? Sounds good to me.", m.choice)) - } - if m.quitting { - return quitTextStyle.Render("Not hungry? That’s cool.") - } return "\n" + m.list.View() } diff --git a/cmd/pomerium/main.go b/cmd/pomerium/main.go index 3289a4391..48e1dbba3 100644 --- a/cmd/pomerium/main.go +++ b/cmd/pomerium/main.go @@ -53,9 +53,9 @@ func main() { } func run(ctx context.Context, configFile string) error { - ctx = log.WithContext(ctx, func(c zerolog.Context) zerolog.Context { - return c.Str("config_file_source", configFile).Bool("bootstrap", true) - }) + // ctx = log.WithContext(ctx, func(c zerolog.Context) zerolog.Context { + // return c.Str("config_file_source", configFile).Bool("bootstrap", true) + // }) var src config.Source diff --git a/config/envoyconfig/listeners_ssh.go b/config/envoyconfig/listeners_ssh.go index 2574692c9..d1d6ce92b 100644 --- a/config/envoyconfig/listeners_ssh.go +++ b/config/envoyconfig/listeners_ssh.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "net/url" + "os" "strings" "time" @@ -11,14 +12,17 @@ import ( xds_matcher_v3 "github.com/cncf/xds/go/xds/type/matcher/v3" envoy_config_core_v3 "github.com/envoyproxy/go-control-plane/envoy/config/core/v3" envoy_config_listener_v3 "github.com/envoyproxy/go-control-plane/envoy/config/listener/v3" + extensions_compressor_zstd_v3 "github.com/envoyproxy/go-control-plane/envoy/extensions/compression/zstd/compressor/v3" envoy_generic_proxy_action_v3 "github.com/envoyproxy/go-control-plane/envoy/extensions/filters/network/generic_proxy/action/v3" envoy_generic_proxy_matcher_v3 "github.com/envoyproxy/go-control-plane/envoy/extensions/filters/network/generic_proxy/matcher/v3" envoy_generic_router_v3 "github.com/envoyproxy/go-control-plane/envoy/extensions/filters/network/generic_proxy/router/v3" envoy_generic_proxy_v3 "github.com/envoyproxy/go-control-plane/envoy/extensions/filters/network/generic_proxy/v3" matcherv3 "github.com/envoyproxy/go-control-plane/envoy/type/matcher/v3" extensions_ssh "github.com/pomerium/envoy-custom/api/extensions/filters/network/ssh" + extensions_ssh_session_recording "github.com/pomerium/envoy-custom/api/extensions/filters/network/ssh/filters/session_recording" "github.com/pomerium/pomerium/config" "google.golang.org/protobuf/types/known/durationpb" + "google.golang.org/protobuf/types/known/wrapperspb" ) func (b *Builder) buildSSHListener(ctx context.Context, cfg *config.Config) (*envoy_config_listener_v3.Listener, error) { @@ -39,6 +43,15 @@ func (b *Builder) buildSSHListener(ctx context.Context, cfg *config.Config) (*en } else { grpcClientTimeout = durationpb.New(30 * time.Second) } + os.MkdirAll("/tmp/recordings", 0o755) + authorizeService := &envoy_config_core_v3.GrpcService{ + Timeout: grpcClientTimeout, + TargetSpecifier: &envoy_config_core_v3.GrpcService_EnvoyGrpc_{ + EnvoyGrpc: &envoy_config_core_v3.GrpcService_EnvoyGrpc{ + ClusterName: "pomerium-authorize", + }, + }, + } li := &envoy_config_listener_v3.Listener{ Name: "ssh", Address: buildTCPAddress(cfg.Options.SSHAddr, 22), @@ -58,17 +71,30 @@ func (b *Builder) buildSSHListener(ctx context.Context, cfg *config.Config) (*en PublicKeyFile: cfg.Options.SSHUserCAKey.PublicKeyFile, PrivateKeyFile: cfg.Options.SSHUserCAKey.PrivateKeyFile, }, - GrpcService: &envoy_config_core_v3.GrpcService{ - Timeout: grpcClientTimeout, - TargetSpecifier: &envoy_config_core_v3.GrpcService_EnvoyGrpc_{ - EnvoyGrpc: &envoy_config_core_v3.GrpcService_EnvoyGrpc{ - ClusterName: "pomerium-authorize", - }, - }, - }, + GrpcService: authorizeService, }), }, Filters: []*envoy_config_core_v3.TypedExtensionConfig{ + { + Name: "envoy.filters.generic.ssh.session_recording", + TypedConfig: marshalAny(&extensions_ssh_session_recording.Config{ + StorageDir: "/tmp/recordings", + GrpcService: authorizeService, + CompressorLibrary: &envoy_config_core_v3.TypedExtensionConfig{ + Name: "envoy.compression.zstd.compressor", + TypedConfig: marshalAny(&extensions_compressor_zstd_v3.Zstd{ + CompressionLevel: wrapperspb.UInt32(19), + EnableChecksum: false, + Strategy: extensions_compressor_zstd_v3.Zstd_BTULTRA2, + ChunkSize: wrapperspb.UInt32(8192), + }), + }, + }), + }, + // { + // Name: "envoy.filters.generic.ssh.session_multiplexing", + // TypedConfig: marshalAny(&extensions_ssh_session_multiplexing.Config{}), + // }, { Name: "envoy.filters.generic.router", TypedConfig: marshalAny(&envoy_generic_router_v3.Router{ diff --git a/pkg/cmd/pomerium/pomerium.go b/pkg/cmd/pomerium/pomerium.go index 98b079f83..9b2bd0fc2 100644 --- a/pkg/cmd/pomerium/pomerium.go +++ b/pkg/cmd/pomerium/pomerium.go @@ -10,6 +10,7 @@ import ( envoy_service_auth_v3 "github.com/envoyproxy/go-control-plane/envoy/service/auth/v3" extensions_ssh "github.com/pomerium/envoy-custom/api/extensions/filters/network/ssh" + extensions_session_recording "github.com/pomerium/envoy-custom/api/extensions/filters/network/ssh/filters/session_recording" "go.uber.org/automaxprocs/maxprocs" "golang.org/x/sync/errgroup" @@ -270,6 +271,7 @@ func setupAuthorize(ctx context.Context, src config.Source, controlPlane *contro } envoy_service_auth_v3.RegisterAuthorizationServer(controlPlane.GRPCServer, svc) extensions_ssh.RegisterStreamManagementServer(controlPlane.GRPCServer, svc) + extensions_session_recording.RegisterRecordingServiceServer(controlPlane.GRPCServer, svc) log.Ctx(ctx).Info().Msg("enabled authorize service") src.OnConfigChange(ctx, svc.OnConfigChange)