package authorize import ( "bufio" "bytes" "context" "crypto/sha256" "encoding/binary" "errors" "fmt" "io" "net/url" "slices" "strconv" "strings" "sync" "sync/atomic" "text/template" "time" "github.com/charmbracelet/bubbles/list" tea "github.com/charmbracelet/bubbletea" "github.com/charmbracelet/lipgloss" corev3 "github.com/envoyproxy/go-control-plane/envoy/config/core/v3" "github.com/klauspost/compress/zstd" extensions_ssh "github.com/pomerium/envoy-custom/api/extensions/filters/network/ssh" extensions_session_recording "github.com/pomerium/envoy-custom/api/extensions/filters/network/ssh/filters/session_recording" "github.com/pomerium/pomerium/authorize/evaluator" "github.com/pomerium/pomerium/config" "github.com/pomerium/pomerium/internal/log" "github.com/pomerium/pomerium/internal/sessions" "github.com/pomerium/pomerium/pkg/grpc/databroker" "github.com/pomerium/pomerium/pkg/grpc/session" "github.com/pomerium/pomerium/pkg/grpcutil" "github.com/pomerium/pomerium/pkg/identity" "github.com/pomerium/pomerium/pkg/identity/manager" "github.com/pomerium/pomerium/pkg/identity/oauth" "github.com/pomerium/pomerium/pkg/storage" "github.com/spf13/cobra" gossh "golang.org/x/crypto/ssh" "golang.org/x/oauth2" "golang.org/x/sync/errgroup" "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" "google.golang.org/protobuf/encoding/protodelim" "google.golang.org/protobuf/encoding/protojson" "google.golang.org/protobuf/proto" "google.golang.org/protobuf/types/known/anypb" "google.golang.org/protobuf/types/known/emptypb" "google.golang.org/protobuf/types/known/structpb" "google.golang.org/protobuf/types/known/timestamppb" "google.golang.org/protobuf/types/known/wrapperspb" ) type ActiveStreams struct { mu sync.Mutex streamsById map[uint64]*StreamState } type StreamState struct { Context context.Context StreamID uint64 ErrorC chan<- error Username string Hostname string PublicKey []byte MethodsAuthenticated []string Session *session.Session } func (a *ActiveStreams) Get(id uint64) *StreamState { a.mu.Lock() defer a.mu.Unlock() return a.streamsById[id] } func (a *ActiveStreams) Put(id uint64, state *StreamState) { a.mu.Lock() defer a.mu.Unlock() a.streamsById[id] = state } func (a *ActiveStreams) Delete(id uint64) { a.mu.Lock() defer a.mu.Unlock() delete(a.streamsById, id) } func (a *ActiveStreams) Range(f func(id uint64, state *StreamState)) { a.mu.Lock() defer a.mu.Unlock() for id, state := range a.streamsById { f(id, state) } } func (a *Authorize) RecordingFinalized( stream grpc.ClientStreamingServer[extensions_session_recording.RecordingData, emptypb.Empty], ) error { msg, err := stream.Recv() if err != nil { return err } md := msg.GetMetadata() if md == nil { return fmt.Errorf("first message did not contain metadata") } log.Ctx(stream.Context()).Info().Str("info", protojson.Format(md)).Msg("new recording") var recording []byte READ: for { msg, err := stream.Recv() if err != nil { if errors.Is(err, io.EOF) { break } return err } switch data := msg.Data.(type) { case *extensions_session_recording.RecordingData_Chunk: recording = append(recording, data.Chunk...) case *extensions_session_recording.RecordingData_Checksum: actual := sha256.Sum256(recording) if actual != [32]byte(data.Checksum) { return fmt.Errorf("checksum mismatch") } break READ } } r, err := zstd.NewReader(bytes.NewReader(recording)) if err != nil { return fmt.Errorf("failed to create zstd reader: %w", err) } switch md.Format { case extensions_session_recording.Format_AsciicastFormat: log.Ctx(stream.Context()).Info().Int("compressed_size", len(recording)).Msg("asciicast recording received") case extensions_session_recording.Format_RawFormat: reader := bufio.NewReader(r) var header extensions_session_recording.Header if err := protodelim.UnmarshalFrom(reader, &header); err != nil { return fmt.Errorf("failed to unmarshal header: %w", err) } var packets []*extensions_session_recording.Packet for { var packet extensions_session_recording.Packet err := protodelim.UnmarshalFrom(reader, &packet) if err != nil { if errors.Is(err, io.EOF) { break } return fmt.Errorf("failed to unmarshal packet: %w", err) } packets = append(packets, &packet) } log.Ctx(stream.Context()).Info().Int("compressed_size", len(recording)).Int("packet_count", len(packets)).Msg("recording received") } return nil } func (a *Authorize) ManageStream( server extensions_ssh.StreamManagement_ManageStreamServer, ) error { recvC := make(chan *extensions_ssh.ClientMessage, 32) sendC := make(chan *extensions_ssh.ServerMessage, 32) eg, ctx := errgroup.WithContext(server.Context()) eg.Go(func() error { defer close(recvC) for { req, err := server.Recv() if err != nil { if errors.Is(err, io.EOF) { return nil } return err } recvC <- req } }) // XXX querier := storage.NewCachingQuerier( storage.NewQuerier(a.state.Load().dataBrokerClient), storage.GlobalCache, ) ctx = storage.WithQuerier(ctx, querier) eg.Go(func() error { for { select { case <-ctx.Done(): return nil case msg := <-sendC: if err := server.Send(msg); err != nil { if errors.Is(err, io.EOF) { return nil } return err } } } }) errC := make(chan error, 1) state := &StreamState{ Context: ctx, ErrorC: errC, } deviceAuthDone := make(chan struct{}) for { select { case err := <-errC: return err case req, ok := <-recvC: if !ok { return nil } switch req := req.Message.(type) { case *extensions_ssh.ClientMessage_Event: switch event := req.Event.Event.(type) { case *extensions_ssh.StreamEvent_DownstreamConnected: id := event.DownstreamConnected.StreamId if id == 0 { return fmt.Errorf("invalid stream ID: %v", id) } state.StreamID = id a.activeStreams.Put(id, state) defer a.activeStreams.Delete(id) case *extensions_ssh.StreamEvent_UpstreamConnected: case nil: } case *extensions_ssh.ClientMessage_AuthRequest: authReq := req.AuthRequest if state.Username == "" { state.Username = authReq.Username } if state.Hostname == "" { state.Hostname = authReq.Hostname } switch authReq.AuthMethod { case "publickey": methodReq, _ := authReq.MethodRequest.UnmarshalNew() pubkeyReq, ok := methodReq.(*extensions_ssh.PublicKeyMethodRequest) if !ok { return fmt.Errorf("client sent invalid auth request message") } // // validate public key here // session, err := a.GetPomeriumSession(ctx, pubkeyReq.PublicKey) if err != nil { return err // XXX: wrap this error? } state.MethodsAuthenticated = append(state.MethodsAuthenticated, "publickey") state.PublicKey = pubkeyReq.PublicKey if authReq.Username == "" { return fmt.Errorf("no username given") } if session != nil { state.Session = session // Perform authorize check for this route req, err := a.getEvaluatorRequestFromSSHAuthRequest(state) if err != nil { return err } res, err := a.evaluate(ctx, req, &sessions.State{ID: session.Id}) if err != nil { return err } sendC <- handleEvaluatorResponseForSSH(res, state) if res.Allow.Value && !res.Deny.Value { a.startContinuousAuthorization(ctx, errC, req, session) } } if session == nil && !slices.Contains(state.MethodsAuthenticated, "keyboard-interactive") { resp := extensions_ssh.ServerMessage{ Message: &extensions_ssh.ServerMessage_AuthResponse{ AuthResponse: &extensions_ssh.AuthenticationResponse{ Response: &extensions_ssh.AuthenticationResponse_Deny{ Deny: &extensions_ssh.DenyResponse{ Partial: true, Methods: []string{"keyboard-interactive"}, }, }, }, }, } sendC <- &resp } case "keyboard-interactive": route := a.getSSHRouteForHostname(state.Hostname) // route can be nil, in which case the default idp will be used opts := a.currentConfig.Load().Options idp, err := opts.GetIdentityProviderForPolicy(route) if err != nil { return err } authenticator, err := identity.NewAuthenticator(ctx, a.tracerProvider, oauth.Options{ RedirectURL: &url.URL{}, ProviderName: idp.GetType(), ProviderURL: idp.GetUrl(), ClientID: idp.GetClientId(), ClientSecret: idp.GetClientSecret(), Scopes: idp.GetScopes(), AuthCodeOptions: idp.GetRequestParams(), DeviceAuthClientType: idp.GetDeviceAuthClientType(), }) if err != nil { return err } deviceAuthResp, err := authenticator.DeviceAuth(ctx) if err != nil { return err } infoReq := extensions_ssh.KeyboardInteractiveInfoPrompts{ Name: "Sign in with " + idp.GetType(), Instruction: deviceAuthResp.VerificationURIComplete, Prompts: []*extensions_ssh.KeyboardInteractiveInfoPrompts_Prompt{}, } infoReqAny, _ := anypb.New(&infoReq) resp := extensions_ssh.ServerMessage{ Message: &extensions_ssh.ServerMessage_AuthResponse{ AuthResponse: &extensions_ssh.AuthenticationResponse{ Response: &extensions_ssh.AuthenticationResponse_InfoRequest{ InfoRequest: &extensions_ssh.InfoRequest{ Method: "keyboard-interactive", Request: infoReqAny, }, }, }, }, } sendC <- &resp go func() { var claims identity.SessionClaims token, err := authenticator.DeviceAccessToken(ctx, deviceAuthResp, &claims) if err != nil { errC <- err return } s := sessions.NewState(idp.Id) claims.Claims.Claims(&s) // XXX s.ID, err = getSessionIDForSSH(state.PublicKey) if err != nil { errC <- err return } fmt.Println(token) state.Session, err = a.PersistSession(ctx, s, claims, token) if err != nil { fmt.Println("error from PersistSession:", err) errC <- fmt.Errorf("error persisting session: %w", err) return } close(deviceAuthDone) }() } case *extensions_ssh.ClientMessage_InfoResponse: resp := req.InfoResponse if resp.Method == "keyboard-interactive" { r, _ := resp.Response.UnmarshalNew() respInfo, ok := r.(*extensions_ssh.KeyboardInteractiveInfoPromptResponses) if ok { fmt.Println(respInfo.Responses) } } select { case <-deviceAuthDone: case <-ctx.Done(): } if state.Session != nil { state.MethodsAuthenticated = append(state.MethodsAuthenticated, "keyboard-interactive") } else { resp := extensions_ssh.ServerMessage{ Message: &extensions_ssh.ServerMessage_AuthResponse{ AuthResponse: &extensions_ssh.AuthenticationResponse{ Response: &extensions_ssh.AuthenticationResponse_Deny{ Deny: &extensions_ssh.DenyResponse{}, }, }, }, } sendC <- &resp continue } if slices.Contains(state.MethodsAuthenticated, "publickey") { // Perform authorize check for this route req, err := a.getEvaluatorRequestFromSSHAuthRequest(state) if err != nil { return err } res, err := a.evaluate(ctx, req, &sessions.State{ID: state.Session.Id}) if err != nil { return err } sendC <- handleEvaluatorResponseForSSH(res, state) if res.Allow.Value && !res.Deny.Value { a.startContinuousAuthorization(ctx, errC, req, state.Session) } } else { resp := extensions_ssh.ServerMessage{ Message: &extensions_ssh.ServerMessage_AuthResponse{ AuthResponse: &extensions_ssh.AuthenticationResponse{ Response: &extensions_ssh.AuthenticationResponse_Deny{ Deny: &extensions_ssh.DenyResponse{ Partial: true, Methods: []string{"publickey"}, }, }, }, }, } sendC <- &resp } case nil: } } } } func (a *Authorize) getSSHRouteForHostname(hostname string) *config.Policy { opts := a.currentConfig.Load().Options from := "ssh://" + strings.TrimSuffix(strings.Join([]string{hostname, opts.SSHHostname}, "."), ".") for r := range opts.GetAllPolicies() { if r.From == from { return r } } return nil } func (a *Authorize) GetPomeriumSession( ctx context.Context, publicKey []byte, ) (*session.Session, error) { sessionID, err := getSessionIDForSSH(publicKey) if err != nil { return nil, err } fmt.Println("session ID:", sessionID) // XXX session, err := session.Get(ctx, a.GetDataBrokerServiceClient(), sessionID) if err != nil { if st, ok := status.FromError(err); ok && st.Code() == codes.NotFound { return nil, nil } return nil, err } return session, nil } func getSessionIDForSSH(publicKey []byte) (string, error) { // XXX: get the fingerprint from Envoy rather than computing it here k, err := gossh.ParsePublicKey(publicKey) if err != nil { return "", fmt.Errorf("couldn't parse ssh key: %w", err) } return "sshkey-" + gossh.FingerprintSHA256(k), nil } func (a *Authorize) getEvaluatorRequestFromSSHAuthRequest( state *StreamState, ) (*evaluator.Request, error) { sessionID, err := getSessionIDForSSH(state.PublicKey) if err != nil { return nil, err } route := a.getSSHRouteForHostname(state.Hostname) if route == nil { return &evaluator.Request{ IsInternal: true, Session: evaluator.RequestSession{ ID: sessionID, }, }, nil // return nil, fmt.Errorf("no route found for hostname %q", state.Hostname) } req := &evaluator.Request{ IsInternal: false, HTTP: evaluator.RequestHTTP{ Hostname: route.From, // XXX: this is not quite right // IP: ? // TODO }, Session: evaluator.RequestSession{ ID: sessionID, }, Policy: route, } return req, nil } func handleEvaluatorResponseForSSH( result *evaluator.Result, state *StreamState, ) *extensions_ssh.ServerMessage { // fmt.Printf(" *** evaluator result: %+v\n", result) // TODO: ideally there would be a way to keep this in sync with the logic in check_response.go allow := result.Allow.Value && !result.Deny.Value if allow { pkData, _ := anypb.New(publicKeyAllowResponse(state.PublicKey)) if state.Hostname == "" { return &extensions_ssh.ServerMessage{ Message: &extensions_ssh.ServerMessage_AuthResponse{ AuthResponse: &extensions_ssh.AuthenticationResponse{ Response: &extensions_ssh.AuthenticationResponse_Allow{ Allow: &extensions_ssh.AllowResponse{ Username: state.Username, Target: &extensions_ssh.AllowResponse_Internal{ Internal: &extensions_ssh.InternalTarget{ SetMetadata: &corev3.Metadata{ FilterMetadata: map[string]*structpb.Struct{ "pomerium": { Fields: map[string]*structpb.Value{ "stream-id": structpb.NewStringValue(strconv.FormatUint(state.StreamID, 10)), }, }, }, }, }, }, }, }, }, }, } } sessionRecordingExt, _ := anypb.New(&extensions_session_recording.UpstreamTargetExtensionConfig{ RecordingName: fmt.Sprintf("session-%s-at-%s-%d.cast", state.Username, state.Hostname, time.Now().UnixNano()), Format: extensions_session_recording.Format_AsciicastFormat, }) return &extensions_ssh.ServerMessage{ Message: &extensions_ssh.ServerMessage_AuthResponse{ AuthResponse: &extensions_ssh.AuthenticationResponse{ Response: &extensions_ssh.AuthenticationResponse_Allow{ Allow: &extensions_ssh.AllowResponse{ Username: state.Username, Target: &extensions_ssh.AllowResponse_Upstream{ Upstream: &extensions_ssh.UpstreamTarget{ Hostname: state.Hostname, AllowedMethods: []*extensions_ssh.AllowedMethod{ { Method: "publickey", MethodData: pkData, }, { Method: "keyboard-interactive", }, }, Extensions: []*corev3.TypedExtensionConfig{ { TypedConfig: sessionRecordingExt, }, }, }, }, }, }, }, }, } } // XXX: do we want to send an equivalent to the "show error details" output // in the case of a deny result? // XXX: this is not quite right -- needs to exactly match the last list of methods methods := []string{"publickey"} if slices.Contains(state.MethodsAuthenticated, "keyboard-interactive") { methods = append(methods, "keyboard-interactive") } return &extensions_ssh.ServerMessage{ Message: &extensions_ssh.ServerMessage_AuthResponse{ AuthResponse: &extensions_ssh.AuthenticationResponse{ Response: &extensions_ssh.AuthenticationResponse_Deny{ Deny: &extensions_ssh.DenyResponse{ Methods: methods, }, }, }, }, } } func publicKeyAllowResponse(publicKey []byte) *extensions_ssh.PublicKeyAllowResponse { return &extensions_ssh.PublicKeyAllowResponse{ PublicKey: publicKey, Permissions: &extensions_ssh.Permissions{ PermitPortForwarding: true, PermitAgentForwarding: true, PermitX11Forwarding: true, PermitPty: true, PermitUserRc: true, ValidBefore: timestamppb.New(time.Now().Add(-1 * time.Minute)), // XXX: tie this to Pomerium session lifetime? ValidAfter: timestamppb.New(time.Now().Add(12 * time.Hour)), }, } } // PersistSession stores session and user data in the databroker. func (a *Authorize) PersistSession( ctx context.Context, sessionState *sessions.State, // XXX: consider not using this struct claims identity.SessionClaims, accessToken *oauth2.Token, ) (*session.Session, error) { now := time.Now() sessionLifetime := a.currentConfig.Load().Options.CookieExpire sessionExpiry := timestamppb.New(now.Add(sessionLifetime)) sess := &session.Session{ Id: sessionState.ID, UserId: sessionState.UserID(), IssuedAt: timestamppb.New(now), AccessedAt: timestamppb.New(now), ExpiresAt: sessionExpiry, OauthToken: manager.ToOAuthToken(accessToken), Audience: sessionState.Audience, } sess.SetRawIDToken(claims.RawIDToken) sess.AddClaims(claims.Flatten()) // XXX: do we need to create a user record too? // compare with Stateful.PersistSession() res, err := session.Put(ctx, a.GetDataBrokerServiceClient(), sess) if err != nil { return nil, err } sessionState.DatabrokerServerVersion = res.GetServerVersion() sessionState.DatabrokerRecordVersion = res.GetRecord().GetVersion() return sess, nil } func (a *Authorize) startContinuousAuthorization( ctx context.Context, errC chan<- error, req *evaluator.Request, session *session.Session, ) { recheck := func() { // XXX: probably want to log the results of this evaluation only if it changes res, _ := a.evaluate(ctx, req, &sessions.State{ID: session.Id}) if !res.Allow.Value || res.Deny.Value { errC <- fmt.Errorf("no longer authorized") } } keyReq := &databroker.QueryRequest{ Type: grpcutil.GetTypeURL(session), Limit: 1, } keyReq.SetFilterByIDOrIndex(session.Id) key, err := (&proto.MarshalOptions{ Deterministic: true, }).Marshal(keyReq) if err != nil { panic(err) } ticker := time.NewTicker(10 * time.Second) go func() { for { select { case <-a.sessionsCacheWarmer.cache.Wait(key): errC <- fmt.Errorf("session expired") return case <-ticker.C: recheck() case <-ctx.Done(): ticker.Stop() return } } }() } func marshalAny(msg proto.Message) *anypb.Any { a, err := anypb.New(msg) if err != nil { panic(err) } return a } // sentinel error to indicate that the command triggered a handoff, and we // should not automatically disconnect var ErrHandoff = errors.New("handoff") func (a *Authorize) ServeChannel( server extensions_ssh.StreamManagement_ServeChannelServer, ) error { ctx := server.Context() inputR, inputW := io.Pipe() outputR, outputW := io.Pipe() var peerId uint32 var activeProgram atomic.Pointer[tea.Program] errC := make(chan error, 1) remoteWindow := &window{Cond: sync.NewCond(&sync.Mutex{})} sendC := make(chan any, 8) recvC := make(chan *extensions_ssh.ChannelMessage) go func() { for { select { case msg := <-sendC: switch msg := msg.(type) { case *extensions_ssh.ChannelControl: log.Ctx(ctx).Debug().Msg("sending channel control message") if err := server.Send(&extensions_ssh.ChannelMessage{ Message: &extensions_ssh.ChannelMessage_ChannelControl{ ChannelControl: msg, }, }); err != nil { errC <- err return } case windowAdjustMsg, channelRequestMsg, channelRequestSuccessMsg, channelRequestFailureMsg, channelEOFMsg: // these messages don't consume window space data := gossh.Marshal(msg) if err := server.Send(&extensions_ssh.ChannelMessage{ Message: &extensions_ssh.ChannelMessage_RawBytes{ RawBytes: wrapperspb.Bytes(data), }, }); err != nil { errC <- err return } log.Ctx(ctx).Debug().Uint8("type", data[0]).Msg("message sent") default: data := gossh.Marshal(msg) need := uint32(len(data)) have := uint32(0) for have < need { n, err := remoteWindow.reserve(need - have) if err != nil { errC <- err return } have += n } if err := server.Send(&extensions_ssh.ChannelMessage{ Message: &extensions_ssh.ChannelMessage_RawBytes{ RawBytes: wrapperspb.Bytes(data), }, }); err != nil { errC <- err return } log.Ctx(ctx).Debug().Uint8("type", data[0]).Uint32("size", need).Msg("message sent") } case <-ctx.Done(): errC <- ctx.Err() return } } }() var state *StreamState go func() { localWindow := uint32(channelWindowSize) for { channelMsg, err := server.Recv() if err != nil { if errors.Is(err, io.EOF) { errC <- nil return } errC <- err return } if state == nil { mdMsg, ok := channelMsg.Message.(*extensions_ssh.ChannelMessage_Metadata) if !ok { errC <- fmt.Errorf("first message was not metadata") return } idStr := mdMsg.Metadata.FilterMetadata["pomerium"].Fields["stream-id"].GetStringValue() if idStr == "" { errC <- fmt.Errorf("no session ID found for stream %q", idStr) return } id, err := strconv.ParseUint(idStr, 10, 64) if err != nil { errC <- fmt.Errorf("invalid stream ID %q: %w", idStr, err) return } if v := a.activeStreams.Get(id); v != nil { state = v } else { errC <- fmt.Errorf("no stream state found for ID %d", id) return } continue } if raw, ok := channelMsg.Message.(*extensions_ssh.ChannelMessage_RawBytes); ok { msgLen := uint32(len(raw.RawBytes.GetValue())) if msgLen == 0 { errC <- status.Errorf(codes.InvalidArgument, "peer sent empty message") return } if msgLen > channelMaxPacket { errC <- status.Errorf(codes.ResourceExhausted, "message too large") return } log.Ctx(ctx).Debug().Uint8("type", raw.RawBytes.Value[0]).Uint32("size", msgLen).Msg("message received") // peek the first byte to check if we need to deduct from the window switch raw.RawBytes.Value[0] { case msgChannelWindowAdjust, msgChannelRequest, msgChannelSuccess, msgChannelFailure, msgChannelEOF: // these messages don't consume window space default: if localWindow < msgLen { errC <- status.Errorf(codes.ResourceExhausted, "peer sent more bytes than allowed by channel window") return } localWindow -= msgLen if localWindow < channelWindowSize/2 { log.Ctx(ctx).Debug().Msg("flow control: increasing local window size") localWindow += channelWindowSize sendC <- windowAdjustMsg{ PeersID: peerId, AdditionalBytes: channelWindowSize, } } } } select { case recvC <- channelMsg: case <-ctx.Done(): errC <- ctx.Err() return } } }() var downstreamChannelInfo *extensions_ssh.SSHDownstreamChannelInfo var downstreamPtyInfo *extensions_ssh.SSHDownstreamPTYInfo var channelIdCounter uint32 for { select { case channelMsg := <-recvC: rawMsg := channelMsg.GetRawBytes().GetValue() switch rawMsg[0] { case msgChannelOpen: var msg channelOpenMsg gossh.Unmarshal(rawMsg, &msg) channelIdCounter++ if channelIdCounter > 1 { return fmt.Errorf("only one channel can be opened") } peerId = msg.PeersID downstreamChannelInfo = &extensions_ssh.SSHDownstreamChannelInfo{ ChannelType: msg.ChanType, DownstreamChannelId: peerId, InternalUpstreamChannelId: channelIdCounter, InitialWindowSize: msg.PeersWindow, MaxPacketSize: msg.MaxPacketSize, } remoteWindow.add(msg.PeersWindow) switch msg.ChanType { case "session": sendC <- channelOpenConfirmMsg{ PeersID: peerId, MyID: channelIdCounter, MyWindow: channelWindowSize, MaxPacketSize: channelMaxPacket, } case "direct-tcpip": var subMsg channelOpenDirectMsg if err := gossh.Unmarshal(msg.TypeSpecificData, &subMsg); err != nil { return err } handOff, _ := anypb.New(&extensions_ssh.SSHChannelControlAction{ Action: &extensions_ssh.SSHChannelControlAction_HandOff{ HandOff: &extensions_ssh.SSHChannelControlAction_HandOffUpstream{ DownstreamChannelInfo: downstreamChannelInfo, UpstreamAuth: &extensions_ssh.AllowResponse{ Target: &extensions_ssh.AllowResponse_Upstream{ Upstream: &extensions_ssh.UpstreamTarget{ Hostname: subMsg.DestAddr, DirectTcpip: true, }, }, }, }, }, }) if err := server.Send(&extensions_ssh.ChannelMessage{ Message: &extensions_ssh.ChannelMessage_ChannelControl{ ChannelControl: &extensions_ssh.ChannelControl{ Protocol: "ssh", ControlAction: handOff, }, }, }); err != nil { return err } } case msgChannelRequest: var msg channelRequestMsg gossh.Unmarshal(rawMsg, &msg) switch msg.Request { case "shell", "exec": if err := server.Send(&extensions_ssh.ChannelMessage{ Message: &extensions_ssh.ChannelMessage_RawBytes{ RawBytes: &wrapperspb.BytesValue{ Value: gossh.Marshal(channelRequestSuccessMsg{ PeersID: peerId, }), }, }, }); err != nil { return err } cmd := a.NewSSHCLI(a.currentConfig.Load(), downstreamPtyInfo, downstreamChannelInfo, state, inputR, outputW, sendC, &activeProgram) if msg.Request == "shell" { cmd.SetArgs([]string{"portal"}) } else { var execReq execChannelRequestMsg if err := gossh.Unmarshal(msg.RequestSpecificData, &execReq); err != nil { return err } cmd.SetArgs(strings.Fields(execReq.Command)) } go func() { defer activeProgram.Store(nil) defer outputW.Close() defer inputR.Close() err := cmd.Execute() if !errors.Is(err, ErrHandoff) { sendC <- &extensions_ssh.ChannelControl{ Protocol: "ssh", ControlAction: marshalAny(&extensions_ssh.SSHChannelControlAction_Disconnect{ ReasonCode: 11, }), } } }() go streamOutputToChannel(sendC, peerId, outputR) case "pty-req": req := parsePtyReq(msg.RequestSpecificData) downstreamPtyInfo = &extensions_ssh.SSHDownstreamPTYInfo{ TermEnv: req.TermEnv, WidthColumns: req.Width, HeightRows: req.Height, WidthPx: req.WidthPx, HeightPx: req.HeightPx, Modes: req.Modes, } sendC <- channelRequestSuccessMsg{PeersID: peerId} case "window-change": var req channelWindowChangeRequestMsg if err := gossh.Unmarshal(msg.RequestSpecificData, &req); err != nil { return err } if p := activeProgram.Load(); p != nil { p.Send(tea.WindowSizeMsg{ Width: int(req.WidthColumns), Height: int(req.HeightRows), }) } } case msgChannelData: var msg channelDataMsg gossh.Unmarshal(rawMsg, &msg) if activeProgram.Load() != nil { inputW.Write(msg.Rest) } case msgChannelClose: var msg channelDataMsg gossh.Unmarshal(rawMsg, &msg) case msgChannelWindowAdjust: var msg windowAdjustMsg if err := gossh.Unmarshal(rawMsg, &msg); err != nil { return err } log.Ctx(ctx).Debug().Uint32("bytes", msg.AdditionalBytes).Msg("flow control: remote window size increased") remoteWindow.add(msg.AdditionalBytes) case msgChannelEOF: return nil default: panic("unhandled message: " + fmt.Sprint(rawMsg[1])) } case err := <-errC: log.Ctx(ctx).Err(err).Msg("channel error") return err } } } type ptyReq struct { TermEnv string Width, Height uint32 WidthPx, HeightPx uint32 Modes []byte } func parsePtyReq(reqData []byte) ptyReq { termEnvLen := binary.BigEndian.Uint32(reqData) reqData = reqData[4:] termEnv := string(reqData[:termEnvLen]) reqData = reqData[termEnvLen:] return ptyReq{ TermEnv: termEnv, Width: binary.BigEndian.Uint32(reqData), Height: binary.BigEndian.Uint32(reqData[4:]), WidthPx: binary.BigEndian.Uint32(reqData[8:]), HeightPx: binary.BigEndian.Uint32(reqData[12:]), Modes: reqData[16:], } } func streamOutputToChannel(sendC chan<- any, channelID uint32, outputR io.Reader) { var buf [4096]byte for { n, err := outputR.Read(buf[:]) if err != nil { return } sendC <- channelDataMsg{ PeersID: channelID, Length: uint32(n), Rest: slices.Clone(buf[:n]), } } } func (a *Authorize) NewSSHCLI( cfg *config.Config, ptyInfo *extensions_ssh.SSHDownstreamPTYInfo, channelInfo *extensions_ssh.SSHDownstreamChannelInfo, state *StreamState, stdin io.Reader, stdout io.Writer, sendC chan<- any, activeProgram *atomic.Pointer[tea.Program], ) *cobra.Command { cmd := &cobra.Command{ Use: "pomerium", PersistentPreRunE: func(cmd *cobra.Command, args []string) error { _, cmdIsInteractive := cmd.Annotations["interactive"] switch { case (ptyInfo == nil) && cmdIsInteractive: cmd.SilenceUsage = true return fmt.Errorf("\x1b[31m'%s' is an interactive command and requires a TTY (try passing '-t' to ssh)\x1b[0m", cmd.Use) case (ptyInfo != nil) && !cmdIsInteractive: cmd.SilenceUsage = true return fmt.Errorf("\x1b[31m'%s' is not an interactive command (try passing '-T' to ssh, or removing '-t')\x1b[0m\r", cmd.Use) } return nil }, } sessionID := state.Session.Id cmd.AddCommand(a.NewPortalCommand(cfg, ptyInfo, channelInfo, state, sendC, activeProgram)) cmd.AddCommand(a.NewLogoutCommand(cfg, sessionID)) cmd.AddCommand(a.NewWhoamiCommand(cfg, sessionID)) cmd.CompletionOptions.DisableDefaultCmd = true cmd.SetIn(stdin) cmd.SetOut(stdout) cmd.SetErr(stdout) return cmd } func (a *Authorize) NewLogoutCommand( cfg *config.Config, sessionID string, ) *cobra.Command { cmd := &cobra.Command{ Use: "logout", Short: "Log out", RunE: func(cmd *cobra.Command, args []string) error { client := a.state.Load().dataBrokerClient err := session.Delete(cmd.Context(), client, sessionID) if err != nil { return fmt.Errorf("internal error: %w", err) } cmd.OutOrStdout().Write([]byte("Logged out successfully\r\n")) return nil }, } return cmd } var whoamiTmpl = template.Must(template.New("whoami").Parse(` User ID: {{.UserId}} Session ID: {{.Id}} Expires at: {{.ExpiresAt.AsTime}} Claims: {{- range $k, $v := .Claims }} {{ $k }}: {{ $v.AsSlice }} {{- end }} `)) func (a *Authorize) NewWhoamiCommand( cfg *config.Config, sessionID string, ) *cobra.Command { cmd := &cobra.Command{ Use: "whoami", RunE: func(cmd *cobra.Command, args []string) error { client := a.state.Load().dataBrokerClient s, err := session.Get(cmd.Context(), client, sessionID) if err != nil { return fmt.Errorf("couldn't fetch session: %w", err) } var b bytes.Buffer whoamiTmpl.Execute(&b, s) cmd.OutOrStdout().Write([]byte(b.String() + "\r\n")) return nil }, } return cmd } func (a *Authorize) NewPortalCommand( cfg *config.Config, ptyInfo *extensions_ssh.SSHDownstreamPTYInfo, channelInfo *extensions_ssh.SSHDownstreamChannelInfo, state *StreamState, sendC chan<- any, activeProgram *atomic.Pointer[tea.Program], ) *cobra.Command { cmd := &cobra.Command{ Use: "portal", Short: "Interactive route portal", Annotations: map[string]string{ "interactive": "", }, RunE: func(cmd *cobra.Command, args []string) error { var routes []string for r := range cfg.Options.GetAllPolicies() { if strings.HasPrefix(r.From, "ssh://") { routes = append(routes, fmt.Sprintf("%s@%s", state.Username, strings.TrimSuffix(strings.TrimPrefix(r.From, "ssh://"), "."+cfg.Options.SSHHostname))) } } items := []list.Item{} for _, route := range routes { items = append(items, item(route)) } a.activeStreams.Range(func(id uint64, state *StreamState) { if id != state.StreamID { items = append(items, item(fmt.Sprintf("[demo] mirror session: %v", id))) } }) l := list.New(items, itemDelegate{}, int(ptyInfo.WidthColumns-2), int(ptyInfo.HeightRows-2)) l.Title = "Connect to which server?" l.SetShowStatusBar(false) l.SetFilteringEnabled(false) l.Styles.Title = titleStyle l.Styles.PaginationStyle = paginationStyle l.Styles.HelpStyle = helpStyle program := tea.NewProgram(model{list: l}, tea.WithInput(cmd.InOrStdin()), tea.WithOutput(cmd.OutOrStdout()), tea.WithAltScreen(), tea.WithContext(cmd.Context()), tea.WithEnvironment([]string{"TERM=" + ptyInfo.TermEnv}), ) activeProgram.Store(program) go program.Send(tea.WindowSizeMsg{Width: int(ptyInfo.WidthColumns), Height: int(ptyInfo.HeightRows)}) answer, err := program.Run() if err != nil { return err } if answer.(model).choice == "" { return nil // quit/ctrl+c } var handOff *anypb.Any if strings.HasPrefix(answer.(model).choice, "[demo] mirror session: ") { id, err := strconv.ParseUint(strings.TrimPrefix(answer.(model).choice, "[demo] mirror session: "), 10, 64) if err != nil { panic(err) } handOff = marshalAny(&extensions_ssh.SSHChannelControlAction{ Action: &extensions_ssh.SSHChannelControlAction_HandOff{ HandOff: &extensions_ssh.SSHChannelControlAction_HandOffUpstream{ DownstreamChannelInfo: channelInfo, DownstreamPtyInfo: ptyInfo, UpstreamAuth: &extensions_ssh.AllowResponse{ Target: &extensions_ssh.AllowResponse_MirrorSession{ MirrorSession: &extensions_ssh.MirrorSessionTarget{ SourceId: id, Mode: extensions_ssh.MirrorSessionTarget_ReadWrite, }, }, }, }, }, }) } else { username, hostname, _ := strings.Cut(answer.(model).choice, "@") // Perform authorize check for this route state.Hostname = hostname if username != state.Username { return fmt.Errorf("internal error: username mismatch") } req, err := a.getEvaluatorRequestFromSSHAuthRequest(state) if err != nil { return err } res, err := a.evaluate(cmd.Context(), req, &sessions.State{ID: state.Session.Id}) if err != nil { return err } if res.Allow.Value && !res.Deny.Value { a.startContinuousAuthorization(state.Context, state.ErrorC, req, state.Session) } else { return fmt.Errorf("not authorized") } sessionRecordingExt, _ := anypb.New(&extensions_session_recording.UpstreamTargetExtensionConfig{ RecordingName: fmt.Sprintf("session-%s-at-%s-%d.cast", username, hostname, time.Now().UnixNano()), Format: extensions_session_recording.Format_AsciicastFormat, }) handOff = marshalAny(&extensions_ssh.SSHChannelControlAction{ Action: &extensions_ssh.SSHChannelControlAction_HandOff{ HandOff: &extensions_ssh.SSHChannelControlAction_HandOffUpstream{ DownstreamChannelInfo: channelInfo, DownstreamPtyInfo: ptyInfo, UpstreamAuth: &extensions_ssh.AllowResponse{ Username: username, Target: &extensions_ssh.AllowResponse_Upstream{ Upstream: &extensions_ssh.UpstreamTarget{ AllowMirrorConnections: true, Hostname: hostname, Extensions: []*corev3.TypedExtensionConfig{ { TypedConfig: sessionRecordingExt, }, }, }, }, }, }, }, }) } sendC <- &extensions_ssh.ChannelControl{ Protocol: "ssh", ControlAction: handOff, } return ErrHandoff }, } return cmd } var ( titleStyle = lipgloss.NewStyle().MarginLeft(2) itemStyle = lipgloss.NewStyle().PaddingLeft(4) selectedItemStyle = lipgloss.NewStyle().PaddingLeft(2).Foreground(lipgloss.Color("170")) paginationStyle = list.DefaultStyles().PaginationStyle.PaddingLeft(4) helpStyle = list.DefaultStyles().HelpStyle.PaddingLeft(4).PaddingBottom(1) quitTextStyle = lipgloss.NewStyle().Margin(1, 0, 2, 4) ) type item string func (i item) FilterValue() string { return "" } type itemDelegate struct{} func (d itemDelegate) Height() int { return 1 } func (d itemDelegate) Spacing() int { return 0 } func (d itemDelegate) Update(_ tea.Msg, _ *list.Model) tea.Cmd { return nil } func (d itemDelegate) Render(w io.Writer, m list.Model, index int, listItem list.Item) { i, ok := listItem.(item) if !ok { return } str := fmt.Sprintf("%d. %s", index+1, i) fn := itemStyle.Render if index == m.Index() { fn = func(s ...string) string { return selectedItemStyle.Render("> " + strings.Join(s, " ")) } } fmt.Fprint(w, fn(str)) } type model struct { list list.Model choice string quitting bool } func (m model) Init() tea.Cmd { return nil } func (m model) Update(msg tea.Msg) (tea.Model, tea.Cmd) { switch msg := msg.(type) { case tea.WindowSizeMsg: m.list.SetWidth(msg.Width - 2) m.list.SetHeight(msg.Height - 2) return m, nil case tea.KeyMsg: switch keypress := msg.String(); keypress { case "q", "ctrl+c": m.quitting = true return m, tea.Quit case "enter": i, ok := m.list.SelectedItem().(item) if ok { m.choice = string(i) } return m, tea.Quit } } var cmd tea.Cmd m.list, cmd = m.list.Update(msg) return m, cmd } func (m model) View() string { return "\n" + m.list.View() } // code below copied from x/crypto/ssh/common.go const ( // channelMaxPacket contains the maximum number of bytes that will be // sent in a single packet. As per RFC 4253, section 6.1, 32k is also // the minimum. channelMaxPacket = 1 << 15 // We follow OpenSSH here. channelWindowSize = 64 * channelMaxPacket ) // window represents the buffer available to clients // wishing to write to a channel. type window struct { *sync.Cond win uint32 // RFC 4254 5.2 says the window size can grow to 2^32-1 writeWaiters int closed bool } // add adds win to the amount of window available // for consumers. func (w *window) add(win uint32) bool { // a zero sized window adjust is a noop. if win == 0 { return true } w.L.Lock() if w.win+win < win { w.L.Unlock() return false } w.win += win // It is unusual that multiple goroutines would be attempting to reserve // window space, but not guaranteed. Use broadcast to notify all waiters // that additional window is available. w.Broadcast() w.L.Unlock() return true } // close sets the window to closed, so all reservations fail // immediately. func (w *window) close() { w.L.Lock() w.closed = true w.Broadcast() w.L.Unlock() } // reserve reserves win from the available window capacity. // If no capacity remains, reserve will block. reserve may // return less than requested. func (w *window) reserve(win uint32) (uint32, error) { var err error w.L.Lock() w.writeWaiters++ w.Broadcast() for w.win == 0 && !w.closed { w.Wait() } w.writeWaiters-- if w.win < win { win = w.win } w.win -= win if w.closed { err = io.EOF } w.L.Unlock() return win, err } // code below copied from x/crypto/ssh/messages.go // (with some additional messages not included there) // See RFC 4254, section 5.1. const msgChannelOpen = 90 type channelOpenMsg struct { ChanType string `sshtype:"90"` PeersID uint32 PeersWindow uint32 MaxPacketSize uint32 TypeSpecificData []byte `ssh:"rest"` } const ( msgChannelExtendedData = 95 msgChannelData = 94 ) // See RFC 4253, section 11.1. const msgDisconnect = 1 // disconnectMsg is the message that signals a disconnect. It is also // the error type returned from mux.Wait() type disconnectMsg struct { Reason uint32 `sshtype:"1"` Message string Language string } // Used for debug print outs of packets. type channelDataMsg struct { PeersID uint32 `sshtype:"94"` Length uint32 Rest []byte `ssh:"rest"` } // See RFC 4254, section 5.1. const msgChannelOpenConfirm = 91 type channelOpenConfirmMsg struct { PeersID uint32 `sshtype:"91"` MyID uint32 MyWindow uint32 MaxPacketSize uint32 TypeSpecificData []byte `ssh:"rest"` } const msgChannelRequest = 98 type channelRequestMsg struct { PeersID uint32 `sshtype:"98"` Request string WantReply bool RequestSpecificData []byte `ssh:"rest"` } type channelOpenDirectMsg struct { DestAddr string DestPort uint32 SrcAddr string SrcPort uint32 } type channelWindowChangeRequestMsg struct { WidthColumns uint32 HeightRows uint32 WidthPx uint32 HeightPx uint32 } type shellChannelRequestMsg struct{} type execChannelRequestMsg struct { Command string } // See RFC 4254, section 5.2 const msgChannelWindowAdjust = 93 type windowAdjustMsg struct { PeersID uint32 `sshtype:"93"` AdditionalBytes uint32 } // See RFC 4254, section 5.4. const msgChannelSuccess = 99 type channelRequestSuccessMsg struct { PeersID uint32 `sshtype:"99"` } // See RFC 4254, section 5.4. const msgChannelFailure = 100 type channelRequestFailureMsg struct { PeersID uint32 `sshtype:"100"` } // See RFC 4254, section 5.3 const msgChannelClose = 97 type channelCloseMsg struct { PeersID uint32 `sshtype:"97"` } // See RFC 4254, section 5.3 const msgChannelEOF = 96 type channelEOFMsg struct { PeersID uint32 `sshtype:"96"` }