package authorize import ( "bytes" "context" "encoding/binary" "errors" "fmt" "io" "net/url" "slices" "strings" "sync/atomic" "text/template" "time" "github.com/charmbracelet/bubbles/list" tea "github.com/charmbracelet/bubbletea" "github.com/charmbracelet/lipgloss" envoy_config_core_v3 "github.com/envoyproxy/go-control-plane/envoy/config/core/v3" extensions_ssh "github.com/pomerium/envoy-custom/api/extensions/filters/network/ssh" "github.com/pomerium/pomerium/authorize/evaluator" "github.com/pomerium/pomerium/config" "github.com/pomerium/pomerium/internal/sessions" "github.com/pomerium/pomerium/pkg/grpc/session" "github.com/pomerium/pomerium/pkg/identity" "github.com/pomerium/pomerium/pkg/identity/manager" "github.com/pomerium/pomerium/pkg/identity/oauth" "github.com/pomerium/pomerium/pkg/storage" gossh "golang.org/x/crypto/ssh" "golang.org/x/oauth2" "golang.org/x/sync/errgroup" "google.golang.org/grpc/codes" "google.golang.org/grpc/metadata" "google.golang.org/grpc/status" "google.golang.org/protobuf/types/known/anypb" "google.golang.org/protobuf/types/known/structpb" "google.golang.org/protobuf/types/known/timestamppb" "google.golang.org/protobuf/types/known/wrapperspb" ) type StreamState struct { Username string Hostname string PublicKey []byte MethodsAuthenticated []string } func (a *Authorize) ManageStream( server extensions_ssh.StreamManagement_ManageStreamServer, ) error { recvC := make(chan *extensions_ssh.ClientMessage, 32) sendC := make(chan *extensions_ssh.ServerMessage, 32) eg, ctx := errgroup.WithContext(server.Context()) eg.Go(func() error { defer close(recvC) for { req, err := server.Recv() if err != nil { if errors.Is(err, io.EOF) { return nil } return err } recvC <- req } }) // XXX querier := storage.NewCachingQuerier( storage.NewQuerier(a.state.Load().dataBrokerClient), a.globalCache, ) ctx = storage.WithQuerier(ctx, querier) eg.Go(func() error { for { select { case <-ctx.Done(): return nil case msg := <-sendC: if err := server.Send(msg); err != nil { if errors.Is(err, io.EOF) { return nil } return err } } } }) var state StreamState //deviceAuthSuccess := &atomic.Bool{} sessionState := &atomic.Pointer[sessions.State]{} errC := make(chan error, 1) a.activeStreamsMu.Lock() a.activeStreams = append(a.activeStreams, errC) a.activeStreamsMu.Unlock() for { select { case err := <-errC: return err case req, ok := <-recvC: if !ok { return nil } switch req := req.Message.(type) { case *extensions_ssh.ClientMessage_Event: switch event := req.Event.Event.(type) { case *extensions_ssh.StreamEvent_DownstreamConnected: fmt.Println("downstream connected") _ = event case nil: } case *extensions_ssh.ClientMessage_AuthRequest: authReq := req.AuthRequest fmt.Println("auth request") if state.Username == "" { state.Username = authReq.Username } if state.Hostname == "" { state.Hostname = authReq.Hostname } switch authReq.AuthMethod { case "publickey": methodReq, _ := authReq.MethodRequest.UnmarshalNew() pubkeyReq, ok := methodReq.(*extensions_ssh.PublicKeyMethodRequest) if !ok { return fmt.Errorf("client sent invalid auth request message") } // // validate public key here // session, err := a.GetPomeriumSession(ctx, pubkeyReq.PublicKey) if err != nil { return err // XXX: wrap this error? } state.MethodsAuthenticated = append(state.MethodsAuthenticated, "publickey") state.PublicKey = pubkeyReq.PublicKey if authReq.Username == "" && authReq.Hostname == "" { pkData, _ := anypb.New(publicKeyAllowResponse(state.PublicKey)) resp := extensions_ssh.ServerMessage{ Message: &extensions_ssh.ServerMessage_AuthResponse{ AuthResponse: &extensions_ssh.AuthenticationResponse{ Response: &extensions_ssh.AuthenticationResponse_Allow{ Allow: &extensions_ssh.AllowResponse{ Username: state.Username, Hostname: state.Hostname, AllowedMethods: []*extensions_ssh.AllowedMethod{ { Method: "publickey", MethodData: pkData, }, }, Target: extensions_ssh.Target_Internal, }, }, }, }, } sendC <- &resp continue } if session != nil { // Perform authorize check for this route req, err := a.getEvaluatorRequestFromSSHAuthRequest(&state) if err != nil { return err } res, err := a.evaluate(ctx, req, &sessions.State{ID: session.Id}) if err != nil { return err } sendC <- handleEvaluatorResponseForSSH(res, &state) } if session == nil && !slices.Contains(state.MethodsAuthenticated, "keyboard-interactive") { resp := extensions_ssh.ServerMessage{ Message: &extensions_ssh.ServerMessage_AuthResponse{ AuthResponse: &extensions_ssh.AuthenticationResponse{ Response: &extensions_ssh.AuthenticationResponse_Deny{ Deny: &extensions_ssh.DenyResponse{ Partial: true, Methods: []string{"keyboard-interactive"}, }, }, }, }, } sendC <- &resp } case "keyboard-interactive": route := a.getSSHRouteForHostname(state.Hostname) if route == nil { return fmt.Errorf("invalid route") } opts := a.currentOptions.Load() idp, err := opts.GetIdentityProviderForPolicy(route) if err != nil { return err } authenticator, err := identity.NewAuthenticator(ctx, a.tracerProvider, oauth.Options{ RedirectURL: &url.URL{}, ProviderName: idp.GetType(), ProviderURL: idp.GetUrl(), ClientID: idp.GetClientId(), ClientSecret: idp.GetClientSecret(), Scopes: idp.GetScopes(), AuthCodeOptions: idp.GetRequestParams(), DeviceAuthClientType: idp.GetDeviceAuthClientType(), }) if err != nil { return err } deviceAuthResp, err := authenticator.DeviceAuth(ctx) if err != nil { return err } infoReq := extensions_ssh.KeyboardInteractiveInfoPrompts{ Name: "Sign in with " + idp.GetType(), Instruction: deviceAuthResp.VerificationURIComplete, Prompts: []*extensions_ssh.KeyboardInteractiveInfoPrompts_Prompt{ //{}, // XXX: proof of concept (no prompt) }, } infoReqAny, _ := anypb.New(&infoReq) resp := extensions_ssh.ServerMessage{ Message: &extensions_ssh.ServerMessage_AuthResponse{ AuthResponse: &extensions_ssh.AuthenticationResponse{ Response: &extensions_ssh.AuthenticationResponse_InfoRequest{ InfoRequest: &extensions_ssh.InfoRequest{ Method: "keyboard-interactive", Request: infoReqAny, }, }, }, }, } sendC <- &resp go func() { var claims identity.SessionClaims token, err := authenticator.DeviceAccessToken(ctx, deviceAuthResp, &claims) if err != nil { errC <- err return } s := sessions.NewState(idp.Id) claims.Claims.Claims(&s) // XXX s.ID, err = getSessionIDForSSH(state.PublicKey) if err != nil { errC <- err return } fmt.Println(token) err = a.PersistSession(ctx, s, claims, token) if err != nil { fmt.Println("error from PersistSession:", err) errC <- fmt.Errorf("error persisting session: %w", err) return } sessionState.Store(s) }() } case *extensions_ssh.ClientMessage_InfoResponse: resp := req.InfoResponse if resp.Method == "keyboard-interactive" { r, _ := resp.Response.UnmarshalNew() respInfo, ok := r.(*extensions_ssh.KeyboardInteractiveInfoPromptResponses) if ok { fmt.Println(respInfo.Responses) } } // XXX: proof of concept -- busy wait for login to complete for sessionState.Load() == nil { time.Sleep(time.Second) } if sessionState.Load() != nil { state.MethodsAuthenticated = append(state.MethodsAuthenticated, "keyboard-interactive") } else { retryReq := extensions_ssh.KeyboardInteractiveInfoPrompts{ Name: "", Instruction: "Login not successful yet, try again", Prompts: []*extensions_ssh.KeyboardInteractiveInfoPrompts_Prompt{ {}, }, } infoReqAny, _ := anypb.New(&retryReq) resp := extensions_ssh.ServerMessage{ Message: &extensions_ssh.ServerMessage_AuthResponse{ AuthResponse: &extensions_ssh.AuthenticationResponse{ Response: &extensions_ssh.AuthenticationResponse_InfoRequest{ InfoRequest: &extensions_ssh.InfoRequest{ Method: "keyboard-interactive", Request: infoReqAny, }, }, }, }, } sendC <- &resp continue } if slices.Contains(state.MethodsAuthenticated, "publickey") { // Perform authorize check for this route req, err := a.getEvaluatorRequestFromSSHAuthRequest(&state) if err != nil { return err } res, err := a.evaluate(ctx, req, sessionState.Load()) if err != nil { return err } sendC <- handleEvaluatorResponseForSSH(res, &state) } else { resp := extensions_ssh.ServerMessage{ Message: &extensions_ssh.ServerMessage_AuthResponse{ AuthResponse: &extensions_ssh.AuthenticationResponse{ Response: &extensions_ssh.AuthenticationResponse_Deny{ Deny: &extensions_ssh.DenyResponse{ Partial: true, Methods: []string{"publickey"}, }, }, }, }, } sendC <- &resp } case nil: } } } return eg.Wait() } func (a *Authorize) getSSHRouteForHostname(hostname string) *config.Policy { opts := a.currentOptions.Load() from := "ssh://" + strings.TrimSuffix(strings.Join([]string{hostname, opts.SSHHostname}, "."), ".") for r := range opts.GetAllPolicies() { if r.From == from { return r } } return nil } func (a *Authorize) GetPomeriumSession( ctx context.Context, publicKey []byte, ) (*session.Session, error) { sessionID, err := getSessionIDForSSH(publicKey) if err != nil { return nil, err } fmt.Println("session ID:", sessionID) // XXX session, err := session.Get(ctx, a.GetDataBrokerServiceClient(), sessionID) if err != nil { if st, ok := status.FromError(err); ok && st.Code() == codes.NotFound { return nil, nil } return nil, err } return session, nil } func getSessionIDForSSH(publicKey []byte) (string, error) { // XXX: get the fingerprint from Envoy rather than computing it here k, err := gossh.ParsePublicKey(publicKey) if err != nil { return "", fmt.Errorf("couldn't parse ssh key: %w", err) } return "sshkey-" + gossh.FingerprintSHA256(k), nil } func (a *Authorize) getEvaluatorRequestFromSSHAuthRequest( state *StreamState, ) (*evaluator.Request, error) { sessionID, err := getSessionIDForSSH(state.PublicKey) if err != nil { return nil, err } route := a.getSSHRouteForHostname(state.Hostname) if route == nil { return nil, fmt.Errorf("no route found for hostname %q", state.Hostname) } req := &evaluator.Request{ IsInternal: false, HTTP: evaluator.RequestHTTP{ Hostname: route.From, // XXX: this is not quite right //IP: ? // TODO }, Session: evaluator.RequestSession{ ID: sessionID, }, Policy: route, } return req, nil } func handleEvaluatorResponseForSSH( result *evaluator.Result, state *StreamState, ) *extensions_ssh.ServerMessage { //fmt.Printf(" *** evaluator result: %+v\n", result) // TODO: ideally there would be a way to keep this in sync with the logic in check_response.go allow := result.Allow.Value && !result.Deny.Value // XXX sessionID, _ := getSessionIDForSSH(state.PublicKey) if allow { pkData, _ := anypb.New(publicKeyAllowResponse(state.PublicKey)) return &extensions_ssh.ServerMessage{ Message: &extensions_ssh.ServerMessage_AuthResponse{ AuthResponse: &extensions_ssh.AuthenticationResponse{ Response: &extensions_ssh.AuthenticationResponse_Allow{ Allow: &extensions_ssh.AllowResponse{ Username: state.Username, Hostname: state.Hostname, AllowedMethods: []*extensions_ssh.AllowedMethod{ { Method: "publickey", MethodData: pkData, }, { Method: "keyboard-interactive", }, }, //Target: extensions_ssh.Target_Upstream, Target: extensions_ssh.Target_Internal, SetMetadata: &envoy_config_core_v3.Metadata{ FilterMetadata: map[string]*structpb.Struct{ "pomerium.ssh": { Fields: map[string]*structpb.Value{ "pomerium-session-id": structpb.NewStringValue(sessionID), }, }, }, }, }, }, }, }, } } // XXX: do we want to send an equivalent to the "show error details" output // in the case of a deny result? // XXX: this is not quite right -- needs to exactly match the last list of methods methods := []string{"publickey"} if slices.Contains(state.MethodsAuthenticated, "keyboard-interactive") { methods = append(methods, "keyboard-interactive") } return &extensions_ssh.ServerMessage{ Message: &extensions_ssh.ServerMessage_AuthResponse{ AuthResponse: &extensions_ssh.AuthenticationResponse{ Response: &extensions_ssh.AuthenticationResponse_Deny{ Deny: &extensions_ssh.DenyResponse{ Methods: methods, }, }, }, }, } } func publicKeyAllowResponse(publicKey []byte) *extensions_ssh.PublicKeyAllowResponse { return &extensions_ssh.PublicKeyAllowResponse{ PublicKey: publicKey, Permissions: &extensions_ssh.Permissions{ PermitPortForwarding: true, PermitAgentForwarding: true, PermitX11Forwarding: true, PermitPty: true, PermitUserRc: true, ValidBefore: timestamppb.New(time.Now().Add(-1 * time.Minute)), // XXX: tie this to Pomerium session lifetime? ValidAfter: timestamppb.New(time.Now().Add(12 * time.Hour)), }, } } // PersistSession stores session and user data in the databroker. func (a *Authorize) PersistSession( ctx context.Context, sessionState *sessions.State, // XXX: consider not using this struct claims identity.SessionClaims, accessToken *oauth2.Token, ) error { now := time.Now() sessionLifetime := a.currentOptions.Load().CookieExpire sessionExpiry := timestamppb.New(now.Add(sessionLifetime)) sess := &session.Session{ Id: sessionState.ID, UserId: sessionState.UserID(), IssuedAt: timestamppb.New(now), AccessedAt: timestamppb.New(now), ExpiresAt: sessionExpiry, OauthToken: manager.ToOAuthToken(accessToken), Audience: sessionState.Audience, } sess.SetRawIDToken(claims.RawIDToken) sess.AddClaims(claims.Flatten()) // XXX: do we need to create a user record too? // compare with Stateful.PersistSession() res, err := session.Put(ctx, a.GetDataBrokerServiceClient(), sess) if err != nil { return err } sessionState.DatabrokerServerVersion = res.GetServerVersion() sessionState.DatabrokerRecordVersion = res.GetRecord().GetVersion() return nil } // See RFC 4254, section 5.1. const msgChannelOpen = 90 type channelOpenMsg struct { ChanType string `sshtype:"90"` PeersID uint32 PeersWindow uint32 MaxPacketSize uint32 TypeSpecificData []byte `ssh:"rest"` } const ( msgChannelExtendedData = 95 msgChannelData = 94 ) // Used for debug print outs of packets. type channelDataMsg struct { PeersID uint32 `sshtype:"94"` Length uint32 Rest []byte `ssh:"rest"` } // See RFC 4254, section 5.1. const msgChannelOpenConfirm = 91 type channelOpenConfirmMsg struct { PeersID uint32 `sshtype:"91"` MyID uint32 MyWindow uint32 MaxPacketSize uint32 TypeSpecificData []byte `ssh:"rest"` } const msgChannelRequest = 98 type channelRequestMsg struct { PeersID uint32 `sshtype:"98"` Request string WantReply bool RequestSpecificData []byte `ssh:"rest"` } // See RFC 4254, section 5.4. const msgChannelSuccess = 99 type channelRequestSuccessMsg struct { PeersID uint32 `sshtype:"99"` } // See RFC 4254, section 5.4. const msgChannelFailure = 100 type channelRequestFailureMsg struct { PeersID uint32 `sshtype:"100"` } // See RFC 4254, section 5.3 const msgChannelClose = 97 type channelCloseMsg struct { PeersID uint32 `sshtype:"97"` } // See RFC 4254, section 5.3 const msgChannelEOF = 96 type channelEOFMsg struct { PeersID uint32 `sshtype:"96"` } func (a *Authorize) ServeChannel( server extensions_ssh.StreamManagement_ServeChannelServer, ) error { //inputR, inputW := io.Pipe() //outputR, outputW := io.Pipe() var peerId uint32 var downstreamChannelInfo *extensions_ssh.SSHDownstreamChannelInfo var downstreamPtyInfo *extensions_ssh.SSHDownstreamPTYInfo handedOff := false handoff := func() error { handedOff = true handOff, _ := anypb.New(&extensions_ssh.SSHChannelControlAction{ Action: &extensions_ssh.SSHChannelControlAction_HandOff{ HandOff: &extensions_ssh.SSHChannelControlAction_HandOffUpstream{ DownstreamChannelInfo: downstreamChannelInfo, DownstreamPtyInfo: downstreamPtyInfo, UpstreamAuth: &extensions_ssh.AllowResponse{ // XXX Username: "demo", Hostname: "ssh", }, }, }, }) fmt.Println(" *** sending handoff request *** ") return server.Send(&extensions_ssh.ChannelMessage{ Message: &extensions_ssh.ChannelMessage_ChannelControl{ ChannelControl: &extensions_ssh.ChannelControl{ Protocol: "ssh", ControlAction: handOff, }, }, }) } for { channelMsg, err := server.Recv() if err != nil { if errors.Is(err, io.EOF) { return nil } return err } if handedOff { continue } rawMsg := channelMsg.GetRawBytes().GetValue() fmt.Printf(" *** channelMsg: %x\n", rawMsg) switch rawMsg[0] { case msgChannelOpen: var msg channelOpenMsg gossh.Unmarshal(rawMsg, &msg) var confirm channelOpenConfirmMsg peerId = msg.PeersID confirm.PeersID = peerId confirm.MyID = 1 confirm.MyWindow = msg.PeersWindow confirm.MaxPacketSize = msg.MaxPacketSize downstreamChannelInfo = &extensions_ssh.SSHDownstreamChannelInfo{ ChannelType: msg.ChanType, DownstreamChannelId: confirm.PeersID, InternalUpstreamChannelId: confirm.MyID, InitialWindowSize: confirm.MyWindow, MaxPacketSize: confirm.MaxPacketSize, } if err := server.Send(&extensions_ssh.ChannelMessage{ Message: &extensions_ssh.ChannelMessage_RawBytes{ RawBytes: &wrapperspb.BytesValue{ Value: gossh.Marshal(confirm), }, }, }); err != nil { return err } case msgChannelRequest: var msg channelRequestMsg gossh.Unmarshal(rawMsg, &msg) fmt.Println(" *** SSH_MSG_CHANNEL_REQUEST: ", msg.Request) switch msg.Request { case "env": // ignore for now case "pty-req": req := parsePtyReq(msg.RequestSpecificData) downstreamPtyInfo = &extensions_ssh.SSHDownstreamPTYInfo{ TermEnv: req.TermEnv, WidthColumns: req.Width, HeightRows: req.Height, WidthPx: req.WidthPx, HeightPx: req.HeightPx, Modes: req.Modes, } if err := server.Send(&extensions_ssh.ChannelMessage{ Message: &extensions_ssh.ChannelMessage_RawBytes{ RawBytes: &wrapperspb.BytesValue{ Value: gossh.Marshal(channelRequestSuccessMsg{ PeersID: peerId, }), }, }, }); err != nil { return err } if err := handoff(); err != nil { return err } case "subsystem": subsystem := parseString(msg.RequestSpecificData) command, isInternal := strings.CutPrefix(subsystem, "pomerium") if isInternal { command = strings.TrimSpace(command) if err := server.Send(&extensions_ssh.ChannelMessage{ Message: &extensions_ssh.ChannelMessage_RawBytes{ RawBytes: &wrapperspb.BytesValue{ Value: gossh.Marshal(channelRequestSuccessMsg{ PeersID: peerId, }), }, }, }); err != nil { return err } return a.serveInternalCommand(server, peerId, command) } if err := handoff(); err != nil { return err } default: // We're not interested in hijacking any other kinds of session. if err := handoff(); err != nil { return err } } case msgChannelData: var msg channelDataMsg gossh.Unmarshal(rawMsg, &msg) // ignore any data from the client (for now) case msgChannelClose: var msg channelDataMsg gossh.Unmarshal(rawMsg, &msg) default: panic("unhandled message: " + fmt.Sprint(rawMsg[1])) } } } /*func (a *Authorize) ServeChannel( server extensions_ssh.StreamManagement_ServeChannelServer, ) error { var program *tea.Program inputR, inputW := io.Pipe() outputR, outputW := io.Pipe() var peerId uint32 var downstreamChannelInfo *extensions_ssh.SSHDownstreamChannelInfo var downstreamPtyInfo *extensions_ssh.SSHDownstreamPTYInfo for { channelMsg, err := server.Recv() if err != nil { if errors.Is(err, io.EOF) { return nil } return err } rawMsg := channelMsg.GetRawBytes().GetValue() fmt.Printf(" *** channelMsg: %x\n", rawMsg) switch rawMsg[0] { case msgChannelOpen: var msg channelOpenMsg gossh.Unmarshal(rawMsg, &msg) var confirm channelOpenConfirmMsg peerId = msg.PeersID confirm.PeersID = peerId confirm.MyID = 1 confirm.MyWindow = msg.PeersWindow confirm.MaxPacketSize = msg.MaxPacketSize downstreamChannelInfo = &extensions_ssh.SSHDownstreamChannelInfo{ ChannelType: msg.ChanType, DownstreamChannelId: confirm.PeersID, InternalUpstreamChannelId: confirm.MyID, InitialWindowSize: confirm.MyWindow, MaxPacketSize: confirm.MaxPacketSize, } if err := server.Send(&extensions_ssh.ChannelMessage{ Message: &extensions_ssh.ChannelMessage_RawBytes{ RawBytes: &wrapperspb.BytesValue{ Value: gossh.Marshal(confirm), }, }, }); err != nil { return err } case msgChannelRequest: var msg channelRequestMsg gossh.Unmarshal(rawMsg, &msg) fmt.Println(" *** SSH_MSG_CHANNEL_REQUEST: ", msg.Request) switch msg.Request { case "pty-req": req := parsePtyReq(msg.RequestSpecificData) items := []list.Item{ item("ubuntu@vm"), item("joe@local"), } downstreamPtyInfo = &extensions_ssh.SSHDownstreamPTYInfo{ TermEnv: req.TermEnv, WidthColumns: req.Width, HeightRows: req.Height, WidthPx: req.WidthPx, HeightPx: req.HeightPx, Modes: req.Modes, } const defaultWidth = 20 l := list.New(items, itemDelegate{}, defaultWidth, listHeight) l.Title = "Connect to which server?" l.SetShowStatusBar(false) l.SetFilteringEnabled(false) l.Styles.Title = titleStyle l.Styles.PaginationStyle = paginationStyle l.Styles.HelpStyle = helpStyle program = tea.NewProgram(model{list: l}, tea.WithInput(inputR), tea.WithOutput(outputW), tea.WithAltScreen(), tea.WithContext(server.Context()), tea.WithEnvironment([]string{"TERM=" + req.TermEnv}), ) go func() { answer, err := program.Run() if err != nil { return } username, hostname, _ := strings.Cut(answer.(model).choice, "@") handOff, _ := anypb.New(&extensions_ssh.SSHChannelControlAction{ Action: &extensions_ssh.SSHChannelControlAction_HandOff{ HandOff: &extensions_ssh.SSHChannelControlAction_HandOffUpstream{ DownstreamChannelInfo: downstreamChannelInfo, DownstreamPtyInfo: downstreamPtyInfo, UpstreamAuth: &extensions_ssh.AllowResponse{ Username: username, Hostname: hostname, }, }, }, }) if err := server.Send(&extensions_ssh.ChannelMessage{ Message: &extensions_ssh.ChannelMessage_ChannelControl{ ChannelControl: &extensions_ssh.ChannelControl{ Protocol: "ssh", ControlAction: handOff, }, }, }); err != nil { return } }() go func() { var buf [4096]byte for { n, err := outputR.Read(buf[:]) if err != nil { return } if err := server.Send(&extensions_ssh.ChannelMessage{ Message: &extensions_ssh.ChannelMessage_RawBytes{ RawBytes: &wrapperspb.BytesValue{ Value: gossh.Marshal(channelDataMsg{ PeersID: peerId, Length: uint32(n), Rest: buf[:n], }), }, }, }); err != nil { return } } }() program.Send(tea.WindowSizeMsg{Width: int(req.Width), Height: int(req.Height)}) if err := server.Send(&extensions_ssh.ChannelMessage{ Message: &extensions_ssh.ChannelMessage_RawBytes{ RawBytes: &wrapperspb.BytesValue{ Value: gossh.Marshal(channelRequestSuccessMsg{ PeersID: peerId, }), }, }, }); err != nil { return err } } case msgChannelData: var msg channelDataMsg gossh.Unmarshal(rawMsg, &msg) if program != nil { inputW.Write(msg.Rest) } case msgChannelClose: var msg channelDataMsg gossh.Unmarshal(rawMsg, &msg) default: panic("unhandled message: " + fmt.Sprint(rawMsg[1])) } } }*/ var whoamiTmpl = template.Must(template.New("whoami").Parse(` User ID: {{.UserId}} Session ID: {{.Id}} Expires at: {{.ExpiresAt.AsTime}} Claims: {{- range $k, $v := .Claims }} {{ $k }}: {{ $v.AsSlice }} {{- end }} `)) func (a *Authorize) serveInternalCommand( server extensions_ssh.StreamManagement_ServeChannelServer, peerID uint32, command string, ) error { md, ok := metadata.FromIncomingContext(server.Context()) fmt.Println("metadata.FromIncomingContext: ", md, ok) var sessionID string if h := md.Get("pomerium-session-id"); len(h) == 1 { sessionID = h[0] } var output string switch command { case "logout": client := a.state.Load().dataBrokerClient err := session.Delete(server.Context(), client, sessionID) if err != nil { output = fmt.Sprint("internal error: ", err.Error()) } else { output = "logged out\n" } case "whoami": client := a.state.Load().dataBrokerClient s, err := session.Get(server.Context(), client, sessionID) if err != nil { output = fmt.Sprint("couldn't fetch session: ", err.Error()) } else { var b bytes.Buffer whoamiTmpl.Execute(&b, s) output = b.String() } default: output = `available commands: logout - ends the current Pomerium session whoami - returns information about the current Pomerium session ` } if err := server.Send(&extensions_ssh.ChannelMessage{ Message: &extensions_ssh.ChannelMessage_RawBytes{ RawBytes: &wrapperspb.BytesValue{ Value: gossh.Marshal(channelDataMsg{ PeersID: peerID, Length: uint32(len(output)), Rest: []byte(output), }), }, }, }); err != nil { return err } time.Sleep(time.Second) // XXX return nil } type ptyReq struct { TermEnv string Width, Height uint32 WidthPx, HeightPx uint32 Modes []byte } func parseString(reqData []byte) string { stringLen := binary.BigEndian.Uint32(reqData) reqData = reqData[4:] return string(reqData[:stringLen]) } func parsePtyReq(reqData []byte) ptyReq { termEnvLen := binary.BigEndian.Uint32(reqData) reqData = reqData[4:] termEnv := string(reqData[:termEnvLen]) reqData = reqData[termEnvLen:] return ptyReq{ TermEnv: termEnv, Width: binary.BigEndian.Uint32(reqData), Height: binary.BigEndian.Uint32(reqData[4:]), WidthPx: binary.BigEndian.Uint32(reqData[8:]), HeightPx: binary.BigEndian.Uint32(reqData[12:]), Modes: reqData[16:], } } const listHeight = 14 var ( titleStyle = lipgloss.NewStyle().MarginLeft(2) itemStyle = lipgloss.NewStyle().PaddingLeft(4) selectedItemStyle = lipgloss.NewStyle().PaddingLeft(2).Foreground(lipgloss.Color("170")) paginationStyle = list.DefaultStyles().PaginationStyle.PaddingLeft(4) helpStyle = list.DefaultStyles().HelpStyle.PaddingLeft(4).PaddingBottom(1) quitTextStyle = lipgloss.NewStyle().Margin(1, 0, 2, 4) ) type item string func (i item) FilterValue() string { return "" } type itemDelegate struct{} func (d itemDelegate) Height() int { return 1 } func (d itemDelegate) Spacing() int { return 0 } func (d itemDelegate) Update(_ tea.Msg, _ *list.Model) tea.Cmd { return nil } func (d itemDelegate) Render(w io.Writer, m list.Model, index int, listItem list.Item) { i, ok := listItem.(item) if !ok { return } str := fmt.Sprintf("%d. %s", index+1, i) fn := itemStyle.Render if index == m.Index() { fn = func(s ...string) string { return selectedItemStyle.Render("> " + strings.Join(s, " ")) } } fmt.Fprint(w, fn(str)) } type model struct { list list.Model choice string quitting bool } func (m model) Init() tea.Cmd { return nil } func (m model) Update(msg tea.Msg) (tea.Model, tea.Cmd) { switch msg := msg.(type) { case tea.WindowSizeMsg: m.list.SetWidth(msg.Width) return m, nil case tea.KeyMsg: switch keypress := msg.String(); keypress { case "q", "ctrl+c": m.quitting = true return m, tea.Quit case "enter": i, ok := m.list.SelectedItem().(item) if ok { m.choice = string(i) } return m, tea.Quit } } var cmd tea.Cmd m.list, cmd = m.list.Update(msg) return m, cmd } func (m model) View() string { if m.choice != "" { return quitTextStyle.Render(fmt.Sprintf("%s? Sounds good to me.", m.choice)) } if m.quitting { return quitTextStyle.Render("Not hungry? That’s cool.") } return "\n" + m.list.View() }