mirror of
https://github.com/pomerium/pomerium.git
synced 2025-04-29 10:26:29 +02:00
1454 lines
40 KiB
Go
1454 lines
40 KiB
Go
package authorize
|
|
|
|
import (
|
|
"bufio"
|
|
"bytes"
|
|
"context"
|
|
"crypto/sha256"
|
|
"encoding/binary"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"net/url"
|
|
"slices"
|
|
"strconv"
|
|
"strings"
|
|
"sync"
|
|
"sync/atomic"
|
|
"text/template"
|
|
"time"
|
|
|
|
"github.com/charmbracelet/bubbles/list"
|
|
tea "github.com/charmbracelet/bubbletea"
|
|
"github.com/charmbracelet/lipgloss"
|
|
corev3 "github.com/envoyproxy/go-control-plane/envoy/config/core/v3"
|
|
"github.com/klauspost/compress/zstd"
|
|
extensions_ssh "github.com/pomerium/envoy-custom/api/extensions/filters/network/ssh"
|
|
extensions_session_recording "github.com/pomerium/envoy-custom/api/extensions/filters/network/ssh/filters/session_recording"
|
|
"github.com/pomerium/pomerium/authorize/evaluator"
|
|
"github.com/pomerium/pomerium/config"
|
|
"github.com/pomerium/pomerium/internal/log"
|
|
"github.com/pomerium/pomerium/internal/sessions"
|
|
"github.com/pomerium/pomerium/pkg/grpc/session"
|
|
"github.com/pomerium/pomerium/pkg/identity"
|
|
"github.com/pomerium/pomerium/pkg/identity/manager"
|
|
"github.com/pomerium/pomerium/pkg/identity/oauth"
|
|
"github.com/pomerium/pomerium/pkg/storage"
|
|
"github.com/spf13/cobra"
|
|
gossh "golang.org/x/crypto/ssh"
|
|
"golang.org/x/oauth2"
|
|
"golang.org/x/sync/errgroup"
|
|
"google.golang.org/grpc"
|
|
"google.golang.org/grpc/codes"
|
|
"google.golang.org/grpc/status"
|
|
"google.golang.org/protobuf/encoding/protodelim"
|
|
"google.golang.org/protobuf/encoding/protojson"
|
|
"google.golang.org/protobuf/proto"
|
|
"google.golang.org/protobuf/types/known/anypb"
|
|
"google.golang.org/protobuf/types/known/emptypb"
|
|
"google.golang.org/protobuf/types/known/structpb"
|
|
"google.golang.org/protobuf/types/known/timestamppb"
|
|
"google.golang.org/protobuf/types/known/wrapperspb"
|
|
)
|
|
|
|
// StreamState tracks the authentication progress of a single SSH stream
// handled by ManageStream.
type StreamState struct {
	// Username and Hostname are taken from the first auth request that
	// supplies a non-empty value; later requests do not overwrite them.
	Username, Hostname string
	// PublicKey is the key from the client's publickey auth request, in the
	// form accepted by gossh.ParsePublicKey.
	PublicKey []byte
	// MethodsAuthenticated lists the auth methods completed so far, in order.
	MethodsAuthenticated []string
}
|
|
|
|
// RecordingFinalized receives a finished session recording from Envoy.
//
// The first message must carry the recording metadata. Subsequent messages
// carry zstd-compressed chunks, optionally terminated by a SHA-256 checksum of
// the concatenated compressed bytes. Raw-format recordings are decompressed and
// their packets counted; asciicast recordings are only logged.
func (a *Authorize) RecordingFinalized(
	stream grpc.ClientStreamingServer[extensions_session_recording.RecordingData, emptypb.Empty],
) error {
	msg, err := stream.Recv()
	if err != nil {
		return err
	}
	md := msg.GetMetadata()
	if md == nil {
		return fmt.Errorf("first message did not contain metadata")
	}
	log.Ctx(stream.Context()).Info().Str("info", protojson.Format(md)).Msg("new recording")

	var recording []byte
READ:
	for {
		msg, err := stream.Recv()
		if err != nil {
			if errors.Is(err, io.EOF) {
				break
			}
			return err
		}
		switch data := msg.Data.(type) {
		case *extensions_session_recording.RecordingData_Chunk:
			recording = append(recording, data.Chunk...)
		case *extensions_session_recording.RecordingData_Checksum:
			// Compare as slices: converting a checksum shorter than 32 bytes
			// to [32]byte would panic, and the value comes from the peer.
			actual := sha256.Sum256(recording)
			if !bytes.Equal(actual[:], data.Checksum) {
				return fmt.Errorf("checksum mismatch")
			}
			break READ
		}
	}

	r, err := zstd.NewReader(bytes.NewReader(recording))
	if err != nil {
		return fmt.Errorf("failed to create zstd reader: %w", err)
	}
	// Close releases the decoder's resources (including background goroutines).
	defer r.Close()

	switch md.Format {
	case extensions_session_recording.Format_AsciicastFormat:
		log.Ctx(stream.Context()).Info().Int("compressed_size", len(recording)).Msg("asciicast recording received")
	case extensions_session_recording.Format_RawFormat:
		reader := bufio.NewReader(r)
		var header extensions_session_recording.Header
		if err := protodelim.UnmarshalFrom(reader, &header); err != nil {
			return fmt.Errorf("failed to unmarshal header: %w", err)
		}

		var packets []*extensions_session_recording.Packet
		for {
			var packet extensions_session_recording.Packet
			err := protodelim.UnmarshalFrom(reader, &packet)
			if err != nil {
				if errors.Is(err, io.EOF) {
					break
				}
				return fmt.Errorf("failed to unmarshal packet: %w", err)
			}
			packets = append(packets, &packet)
		}

		log.Ctx(stream.Context()).Info().Int("compressed_size", len(recording)).Int("packet_count", len(packets)).Msg("recording received")
	}
	return nil
}
|
|
|
|
var (
	// activeStreamIds maps an upstream stream ID to the StreamState captured
	// when ManageStream saw that stream's UpstreamConnected event. Entries are
	// removed when that ManageStream call returns. The portal command ranges
	// over it to list sessions that can be mirrored.
	activeStreamIds sync.Map
)
|
|
|
|
// ManageStream serves the stream-management gRPC stream for one downstream SSH
// connection and drives SSH user authentication:
//
//   - "publickey": the key's fingerprint is used as the Pomerium session ID. If
//     that session exists, the route policy is evaluated immediately;
//     otherwise the client is asked to also complete "keyboard-interactive".
//   - "keyboard-interactive": an OAuth device authorization flow is started
//     with the route's identity provider and the verification URL is shown to
//     the user. On success a session is persisted under the key's fingerprint.
//
// Once a request is allowed, the policy keeps being re-evaluated (see
// startContinuousAuthorization). The function returns nil when the client
// closes its side of the stream, or the first error reported on errC.
func (a *Authorize) ManageStream(
	server extensions_ssh.StreamManagement_ManageStreamServer,
) error {
	recvC := make(chan *extensions_ssh.ClientMessage, 32)
	sendC := make(chan *extensions_ssh.ServerMessage, 32)
	eg, ctx := errgroup.WithContext(server.Context())
	eg.Go(func() error {
		defer close(recvC)
		for {
			req, err := server.Recv()
			if err != nil {
				if errors.Is(err, io.EOF) {
					return nil
				}
				return err
			}
			recvC <- req
		}
	})

	// XXX
	querier := storage.NewCachingQuerier(
		storage.NewQuerier(a.state.Load().dataBrokerClient),
		a.globalCache,
	)
	ctx = storage.WithQuerier(ctx, querier)

	// All writes to the gRPC stream happen on this goroutine.
	eg.Go(func() error {
		for {
			select {
			case <-ctx.Done():
				return nil
			case msg := <-sendC:
				if err := server.Send(msg); err != nil {
					if errors.Is(err, io.EOF) {
						return nil
					}
					return err
				}
			}
		}
	})

	var state StreamState

	deviceAuthDone := make(chan struct{})
	// deviceAuthOnce makes closing deviceAuthDone idempotent, so a client that
	// starts keyboard-interactive more than once cannot trigger a
	// "close of closed channel" panic.
	var deviceAuthOnce sync.Once
	sessionState := &atomic.Pointer[sessions.State]{}

	errC := make(chan error, 1)
	// reportErr hands an error to the main loop, giving up once ctx is done so
	// background goroutines cannot block forever after this function returns.
	reportErr := func(err error) {
		select {
		case errC <- err:
		case <-ctx.Done():
		}
	}
	a.activeStreamsMu.Lock()
	a.activeStreams = append(a.activeStreams, errC)
	a.activeStreamsMu.Unlock()
	for {
		select {
		case err := <-errC:
			return err
		case req, ok := <-recvC:
			if !ok {
				return nil
			}
			switch req := req.Message.(type) {
			case *extensions_ssh.ClientMessage_Event:
				switch event := req.Event.Event.(type) {
				case *extensions_ssh.StreamEvent_DownstreamConnected:
					_ = event
				case *extensions_ssh.StreamEvent_UpstreamConnected:
					// Publish the stream (e.g. for the portal's mirror list)
					// until this call returns.
					activeStreamIds.Store(event.UpstreamConnected.GetStreamId(), state)
					defer activeStreamIds.Delete(event.UpstreamConnected.GetStreamId())
				case nil:
				}
			case *extensions_ssh.ClientMessage_AuthRequest:
				authReq := req.AuthRequest
				if state.Username == "" {
					state.Username = authReq.Username
				}
				if state.Hostname == "" {
					state.Hostname = authReq.Hostname
				}
				switch authReq.AuthMethod {
				case "publickey":
					methodReq, _ := authReq.MethodRequest.UnmarshalNew()
					pubkeyReq, ok := methodReq.(*extensions_ssh.PublicKeyMethodRequest)
					if !ok {
						return fmt.Errorf("client sent invalid auth request message")
					}

					//
					// validate public key here
					//
					session, err := a.GetPomeriumSession(ctx, pubkeyReq.PublicKey)
					if err != nil {
						return err // XXX: wrap this error?
					}

					state.MethodsAuthenticated = append(state.MethodsAuthenticated, "publickey")
					state.PublicKey = pubkeyReq.PublicKey

					if authReq.Username == "" {
						return fmt.Errorf("no username given")
					}

					if session != nil {
						// Perform authorize check for this route
						req, err := a.getEvaluatorRequestFromSSHAuthRequest(&state)
						if err != nil {
							return err
						}
						res, err := a.evaluate(ctx, req, &sessions.State{ID: session.Id})
						if err != nil {
							return err
						}
						sendC <- handleEvaluatorResponseForSSH(res, &state, session.Id)

						if res.Allow.Value && !res.Deny.Value {
							a.startContinuousAuthorization(ctx, errC, req, session.Id)
						}
					}

					if session == nil && !slices.Contains(state.MethodsAuthenticated, "keyboard-interactive") {
						// No session exists for this key yet: require
						// keyboard-interactive as an additional method.
						resp := extensions_ssh.ServerMessage{
							Message: &extensions_ssh.ServerMessage_AuthResponse{
								AuthResponse: &extensions_ssh.AuthenticationResponse{
									Response: &extensions_ssh.AuthenticationResponse_Deny{
										Deny: &extensions_ssh.DenyResponse{
											Partial: true,
											Methods: []string{"keyboard-interactive"},
										},
									},
								},
							},
						}
						sendC <- &resp
					}
				case "keyboard-interactive":
					route := a.getSSHRouteForHostname(state.Hostname)
					// if route == nil {
					// 	return fmt.Errorf("invalid route")
					// }

					opts := a.currentConfig.Load().Options
					idp, err := opts.GetIdentityProviderForPolicy(route)
					if err != nil {
						return err
					}
					authenticator, err := identity.NewAuthenticator(ctx, a.tracerProvider, oauth.Options{
						RedirectURL:          &url.URL{},
						ProviderName:         idp.GetType(),
						ProviderURL:          idp.GetUrl(),
						ClientID:             idp.GetClientId(),
						ClientSecret:         idp.GetClientSecret(),
						Scopes:               idp.GetScopes(),
						AuthCodeOptions:      idp.GetRequestParams(),
						DeviceAuthClientType: idp.GetDeviceAuthClientType(),
					})
					if err != nil {
						return err
					}
					deviceAuthResp, err := authenticator.DeviceAuth(ctx)
					if err != nil {
						return err
					}
					infoReq := extensions_ssh.KeyboardInteractiveInfoPrompts{
						Name:        "Sign in with " + idp.GetType(),
						Instruction: deviceAuthResp.VerificationURIComplete,
						Prompts:     []*extensions_ssh.KeyboardInteractiveInfoPrompts_Prompt{
							// {},
						},
					}

					infoReqAny, _ := anypb.New(&infoReq)
					resp := extensions_ssh.ServerMessage{
						Message: &extensions_ssh.ServerMessage_AuthResponse{
							AuthResponse: &extensions_ssh.AuthenticationResponse{
								Response: &extensions_ssh.AuthenticationResponse_InfoRequest{
									InfoRequest: &extensions_ssh.InfoRequest{
										Method:  "keyboard-interactive",
										Request: infoReqAny,
									},
								},
							},
						},
					}
					sendC <- &resp

					// Copy the key now: the goroutine below must not read
					// state, which this loop keeps mutating.
					publicKey := state.PublicKey
					go func() {
						var claims identity.SessionClaims

						token, err := authenticator.DeviceAccessToken(ctx, deviceAuthResp, &claims)
						if err != nil {
							reportErr(err)
							return
						}
						s := sessions.NewState(idp.Id)
						claims.Claims.Claims(&s) // XXX
						s.ID, err = getSessionIDForSSH(publicKey)
						if err != nil {
							reportErr(err)
							return
						}
						// The token is a bearer credential: never print or log it.
						err = a.PersistSession(ctx, s, claims, token)
						if err != nil {
							reportErr(fmt.Errorf("error persisting session: %w", err))
							return
						}
						sessionState.Store(s)
						deviceAuthOnce.Do(func() { close(deviceAuthDone) })
					}()
				}
			case *extensions_ssh.ClientMessage_InfoResponse:
				// The keyboard-interactive prompt list sent to the client is
				// empty, so its responses carry nothing to check (and, being
				// user input, are deliberately not printed or logged).
				// Completion is signalled by the device authorization goroutine.
				select {
				case <-deviceAuthDone:
				case <-ctx.Done():
				}
				if sessionState.Load() != nil {
					state.MethodsAuthenticated = append(state.MethodsAuthenticated, "keyboard-interactive")
				} else {
					resp := extensions_ssh.ServerMessage{
						Message: &extensions_ssh.ServerMessage_AuthResponse{
							AuthResponse: &extensions_ssh.AuthenticationResponse{
								Response: &extensions_ssh.AuthenticationResponse_Deny{
									Deny: &extensions_ssh.DenyResponse{},
								},
							},
						},
					}
					sendC <- &resp
					continue
				}

				if slices.Contains(state.MethodsAuthenticated, "publickey") {
					// Perform authorize check for this route
					req, err := a.getEvaluatorRequestFromSSHAuthRequest(&state)
					if err != nil {
						return err
					}
					res, err := a.evaluate(ctx, req, sessionState.Load())
					if err != nil {
						return err
					}
					sendC <- handleEvaluatorResponseForSSH(res, &state, sessionState.Load().ID)

					if res.Allow.Value && !res.Deny.Value {
						a.startContinuousAuthorization(ctx, errC, req, sessionState.Load().ID)
					}
				} else {
					resp := extensions_ssh.ServerMessage{
						Message: &extensions_ssh.ServerMessage_AuthResponse{
							AuthResponse: &extensions_ssh.AuthenticationResponse{
								Response: &extensions_ssh.AuthenticationResponse_Deny{
									Deny: &extensions_ssh.DenyResponse{
										Partial: true,
										Methods: []string{"publickey"},
									},
								},
							},
						},
					}
					sendC <- &resp
				}

			case nil:
			}
		}
	}
}
|
|
|
|
// getSSHRouteForHostname returns the policy whose From URL is
// "ssh://<hostname>.<SSHHostname>" (with any trailing "." trimmed, e.g. when
// SSHHostname is empty), or nil if no policy matches.
func (a *Authorize) getSSHRouteForHostname(hostname string) *config.Policy {
	options := a.currentConfig.Load().Options
	fqdn := strings.TrimSuffix(hostname+"."+options.SSHHostname, ".")
	want := "ssh://" + fqdn
	for policy := range options.GetAllPolicies() {
		if policy.From != want {
			continue
		}
		return policy
	}
	return nil
}
|
|
|
|
// GetPomeriumSession looks up the Pomerium session associated with an SSH
// public key. The session ID is derived from the key via getSessionIDForSSH.
// It returns (nil, nil) when the databroker reports that no such session
// exists, and a non-nil error for any other failure.
func (a *Authorize) GetPomeriumSession(
	ctx context.Context, publicKey []byte,
) (*session.Session, error) {
	sessionID, err := getSessionIDForSSH(publicKey)
	if err != nil {
		return nil, err
	}
	log.Ctx(ctx).Debug().Str("session-id", sessionID).Msg("looking up ssh session")

	sess, err := session.Get(ctx, a.GetDataBrokerServiceClient(), sessionID)
	if status.Code(err) == codes.NotFound {
		return nil, nil
	}
	if err != nil {
		return nil, err
	}
	return sess, nil
}
|
|
|
|
// getSessionIDForSSH derives the Pomerium session ID for an SSH public key:
// "sshkey-" followed by the key's SHA-256 fingerprint. It fails if publicKey
// cannot be parsed by gossh.ParsePublicKey.
func getSessionIDForSSH(publicKey []byte) (string, error) {
	// XXX: get the fingerprint from Envoy rather than computing it here
	key, parseErr := gossh.ParsePublicKey(publicKey)
	if parseErr != nil {
		return "", fmt.Errorf("couldn't parse ssh key: %w", parseErr)
	}
	fingerprint := gossh.FingerprintSHA256(key)
	return "sshkey-" + fingerprint, nil
}
|
|
|
|
func (a *Authorize) getEvaluatorRequestFromSSHAuthRequest(
|
|
state *StreamState,
|
|
) (*evaluator.Request, error) {
|
|
sessionID, err := getSessionIDForSSH(state.PublicKey)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
route := a.getSSHRouteForHostname(state.Hostname)
|
|
if route == nil {
|
|
return &evaluator.Request{
|
|
IsInternal: true,
|
|
Session: evaluator.RequestSession{
|
|
ID: sessionID,
|
|
},
|
|
}, nil
|
|
|
|
// return nil, fmt.Errorf("no route found for hostname %q", state.Hostname)
|
|
}
|
|
req := &evaluator.Request{
|
|
IsInternal: false,
|
|
HTTP: evaluator.RequestHTTP{
|
|
Hostname: route.From, // XXX: this is not quite right
|
|
// IP: ? // TODO
|
|
},
|
|
Session: evaluator.RequestSession{
|
|
ID: sessionID,
|
|
},
|
|
Policy: route,
|
|
}
|
|
return req, nil
|
|
}
|
|
|
|
// handleEvaluatorResponseForSSH converts a policy evaluation result into the
// SSH auth response to send to Envoy.
//
// On deny, the response lists the methods to retry: "publickey", plus
// "keyboard-interactive" if that method has already been completed. On
// allow, an empty state.Hostname yields an internal target carrying the
// session ID in "pomerium" filter metadata. Otherwise it yields an upstream
// target with publickey/keyboard-interactive methods and an asciicast
// session-recording extension.
func handleEvaluatorResponseForSSH(
	result *evaluator.Result, state *StreamState, sessionID string,
) *extensions_ssh.ServerMessage {
	// fmt.Printf(" *** evaluator result: %+v\n", result)

	// wrap places an AuthenticationResponse in the ServerMessage envelope.
	wrap := func(resp *extensions_ssh.AuthenticationResponse) *extensions_ssh.ServerMessage {
		return &extensions_ssh.ServerMessage{
			Message: &extensions_ssh.ServerMessage_AuthResponse{AuthResponse: resp},
		}
	}

	// TODO: ideally there would be a way to keep this in sync with the logic in check_response.go
	if !result.Allow.Value || result.Deny.Value {
		// XXX: do we want to send an equivalent to the "show error details" output
		// in the case of a deny result?

		// XXX: this is not quite right -- needs to exactly match the last list of methods
		retry := []string{"publickey"}
		if slices.Contains(state.MethodsAuthenticated, "keyboard-interactive") {
			retry = append(retry, "keyboard-interactive")
		}
		return wrap(&extensions_ssh.AuthenticationResponse{
			Response: &extensions_ssh.AuthenticationResponse_Deny{
				Deny: &extensions_ssh.DenyResponse{Methods: retry},
			},
		})
	}

	pkData, _ := anypb.New(publicKeyAllowResponse(state.PublicKey))

	if state.Hostname == "" {
		internal := &extensions_ssh.InternalTarget{
			SetMetadata: &corev3.Metadata{
				FilterMetadata: map[string]*structpb.Struct{
					"pomerium": {
						Fields: map[string]*structpb.Value{
							"session-id": structpb.NewStringValue(sessionID),
						},
					},
				},
			},
		}
		return wrap(&extensions_ssh.AuthenticationResponse{
			Response: &extensions_ssh.AuthenticationResponse_Allow{
				Allow: &extensions_ssh.AllowResponse{
					Username: state.Username,
					Target:   &extensions_ssh.AllowResponse_Internal{Internal: internal},
				},
			},
		})
	}

	recordingName := fmt.Sprintf("session-%s-at-%s-%d.cast", state.Username, state.Hostname, time.Now().UnixNano())
	sessionRecordingExt, _ := anypb.New(&extensions_session_recording.UpstreamTargetExtensionConfig{
		RecordingName: recordingName,
		Format:        extensions_session_recording.Format_AsciicastFormat,
	})
	upstream := &extensions_ssh.UpstreamTarget{
		Hostname: state.Hostname,
		AllowedMethods: []*extensions_ssh.AllowedMethod{
			{Method: "publickey", MethodData: pkData},
			{Method: "keyboard-interactive"},
		},
		Extensions: []*corev3.TypedExtensionConfig{
			{TypedConfig: sessionRecordingExt},
		},
	}
	return wrap(&extensions_ssh.AuthenticationResponse{
		Response: &extensions_ssh.AuthenticationResponse_Allow{
			Allow: &extensions_ssh.AllowResponse{
				Username: state.Username,
				Target:   &extensions_ssh.AllowResponse_Upstream{Upstream: upstream},
			},
		},
	})
}
|
|
|
|
// publicKeyAllowResponse builds the publickey method data for an allowed
// upstream connection. It permits port/agent/X11 forwarding, PTY allocation,
// and user rc, and sets validity timestamps of now-1m and now+12h.
//
// NOTE(review): ValidBefore is set to the earlier time and ValidAfter to the
// later one, which looks inverted relative to OpenSSH certificate semantics
// (valid_after = start, valid_before = end). Confirm how envoy-custom
// interprets these fields before changing them.
func publicKeyAllowResponse(publicKey []byte) *extensions_ssh.PublicKeyAllowResponse {
	perms := &extensions_ssh.Permissions{
		PermitPortForwarding:  true,
		PermitAgentForwarding: true,
		PermitX11Forwarding:   true,
		PermitPty:             true,
		PermitUserRc:          true,
	}
	perms.ValidBefore = timestamppb.New(time.Now().Add(-1 * time.Minute))
	// XXX: tie this to Pomerium session lifetime?
	perms.ValidAfter = timestamppb.New(time.Now().Add(12 * time.Hour))
	return &extensions_ssh.PublicKeyAllowResponse{
		PublicKey:   publicKey,
		Permissions: perms,
	}
}
|
|
|
|
// PersistSession stores session and user data in the databroker.
//
// The session record is keyed by sessionState.ID and expires after the
// configured CookieExpire duration. On success, the databroker server and
// record versions returned by the put are written back into sessionState.
func (a *Authorize) PersistSession(
	ctx context.Context,
	sessionState *sessions.State, // XXX: consider not using this struct
	claims identity.SessionClaims,
	accessToken *oauth2.Token,
) error {
	issued := time.Now()
	expires := issued.Add(a.currentConfig.Load().Options.CookieExpire)

	record := &session.Session{
		Id:         sessionState.ID,
		UserId:     sessionState.UserID(),
		IssuedAt:   timestamppb.New(issued),
		AccessedAt: timestamppb.New(issued),
		ExpiresAt:  timestamppb.New(expires),
		OauthToken: manager.ToOAuthToken(accessToken),
		Audience:   sessionState.Audience,
	}
	record.SetRawIDToken(claims.RawIDToken)
	record.AddClaims(claims.Flatten())

	// XXX: do we need to create a user record too?
	// compare with Stateful.PersistSession()

	putRes, putErr := session.Put(ctx, a.GetDataBrokerServiceClient(), record)
	if putErr != nil {
		return putErr
	}
	sessionState.DatabrokerServerVersion = putRes.GetServerVersion()
	sessionState.DatabrokerRecordVersion = putRes.GetRecord().GetVersion()
	return nil
}
|
|
|
|
// startContinuousAuthorization re-evaluates req for sessionID once per second
// in a background goroutine until ctx is done.
//
// If the request is no longer allowed, or the evaluation itself fails (fail
// closed), a single error is delivered on errC and the goroutine exits. The
// send gives up if ctx ends first, so the goroutine never outlives the stream
// blocked on errC.
func (a *Authorize) startContinuousAuthorization(
	ctx context.Context,
	errC chan<- error,
	req *evaluator.Request,
	sessionID string,
) {
	// recheck returns a non-nil error once the stream should be terminated.
	recheck := func() error {
		// XXX: probably want to log the results of this evaluation only if it changes
		res, err := a.evaluate(ctx, req, &sessions.State{ID: sessionID})
		if err != nil {
			// res is not usable on error; dereferencing it would panic.
			return fmt.Errorf("error re-evaluating authorization: %w", err)
		}
		if !res.Allow.Value || res.Deny.Value {
			return fmt.Errorf("no longer authorized")
		}
		return nil
	}

	ticker := time.NewTicker(time.Second)
	go func() {
		defer ticker.Stop()
		for {
			select {
			case <-ticker.C:
				if err := recheck(); err != nil {
					select {
					case errC <- err:
					case <-ctx.Done():
					}
					return
				}
			case <-ctx.Done():
				return
			}
		}
	}()
}
|
|
|
|
func marshalAny(msg proto.Message) *anypb.Any {
|
|
a, err := anypb.New(msg)
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
return a
|
|
}
|
|
|
|
// ErrHandoff is a sentinel error returned by an SSH CLI command to indicate
// that it triggered a channel handoff, so the channel must not be
// automatically disconnected when the command finishes.
var ErrHandoff = errors.New("handoff")
|
|
|
|
// ServeChannel serves one SSH channel that terminates inside Pomerium rather
// than at an upstream server. It speaks the SSH connection protocol over the
// gRPC stream:
//
//   - The first message must be metadata carrying the Pomerium session ID.
//   - Raw SSH channel messages are exchanged after that.
//   - "shell" and "exec" requests run the built-in CLI (see NewSSHCLI).
//   - "direct-tcpip" channel opens are handed off to an upstream target.
//
// Flow control is enforced in both directions using the channel window.
//
// Every write to the gRPC stream goes through sendC and a single send
// goroutine: grpc-go does not allow concurrent Send calls on one stream.
func (a *Authorize) ServeChannel(
	server extensions_ssh.StreamManagement_ServeChannelServer,
) error {
	ctx := server.Context()
	inputR, inputW := io.Pipe()
	outputR, outputW := io.Pipe()
	// peerID is set by the main loop when the channel is opened and read by the
	// receive goroutine for window adjustments, so it is accessed atomically.
	var peerID atomic.Uint32
	var activeProgram atomic.Pointer[tea.Program]

	errC := make(chan error, 1)
	// reportErr records the first error for the main loop. Later errors are
	// dropped instead of blocking their goroutine forever after we return.
	reportErr := func(err error) {
		select {
		case errC <- err:
		default:
		}
	}
	remoteWindow := &window{Cond: sync.NewCond(&sync.Mutex{})}
	sendC := make(chan any, 8)
	recvC := make(chan *extensions_ssh.ChannelMessage)
	go func() {
		for {
			select {
			case msg := <-sendC:
				switch msg := msg.(type) {
				case *extensions_ssh.ChannelControl:
					log.Ctx(ctx).Debug().Msg("sending channel control message")
					if err := server.Send(&extensions_ssh.ChannelMessage{
						Message: &extensions_ssh.ChannelMessage_ChannelControl{
							ChannelControl: msg,
						},
					}); err != nil {
						reportErr(err)
						return
					}
				case windowAdjustMsg, channelRequestMsg, channelRequestSuccessMsg, channelRequestFailureMsg, channelEOFMsg:
					// these messages don't consume window space
					data := gossh.Marshal(msg)
					if err := server.Send(&extensions_ssh.ChannelMessage{
						Message: &extensions_ssh.ChannelMessage_RawBytes{
							RawBytes: wrapperspb.Bytes(data),
						},
					}); err != nil {
						reportErr(err)
						return
					}
					log.Ctx(ctx).Debug().Uint8("type", data[0]).Msg("message sent")
				default:
					data := gossh.Marshal(msg)
					need := uint32(len(data))
					have := uint32(0)
					for have < need {
						n, err := remoteWindow.reserve(need - have)
						if err != nil {
							reportErr(err)
							return
						}
						have += n
					}
					if err := server.Send(&extensions_ssh.ChannelMessage{
						Message: &extensions_ssh.ChannelMessage_RawBytes{
							RawBytes: wrapperspb.Bytes(data),
						},
					}); err != nil {
						reportErr(err)
						return
					}
					log.Ctx(ctx).Debug().Uint8("type", data[0]).Uint32("size", need).Msg("message sent")
				}
			case <-ctx.Done():
				reportErr(ctx.Err())
				return
			}
		}
	}()

	var sessionID atomic.Pointer[string]
	go func() {
		localWindow := uint32(channelWindowSize)
		for {
			channelMsg, err := server.Recv()
			if err != nil {
				if errors.Is(err, io.EOF) {
					reportErr(nil)
					return
				}
				reportErr(err)
				return
			}
			if sessionID.Load() == nil {
				mdMsg, ok := channelMsg.Message.(*extensions_ssh.ChannelMessage_Metadata)
				if !ok {
					reportErr(fmt.Errorf("first message was not metadata"))
					return
				}
				id := mdMsg.Metadata.FilterMetadata["pomerium"].Fields["session-id"].GetStringValue()
				sessionID.Store(&id)
				continue
			}
			if raw, ok := channelMsg.Message.(*extensions_ssh.ChannelMessage_RawBytes); ok {
				msgLen := uint32(len(raw.RawBytes.GetValue()))
				if msgLen == 0 {
					reportErr(status.Errorf(codes.InvalidArgument, "peer sent empty message"))
					return
				}
				if msgLen > channelMaxPacket {
					reportErr(status.Errorf(codes.ResourceExhausted, "message too large"))
					return
				}
				log.Ctx(ctx).Debug().Uint8("type", raw.RawBytes.Value[0]).Uint32("size", msgLen).Msg("message received")
				// peek the first byte to check if we need to deduct from the window
				switch raw.RawBytes.Value[0] {
				case msgChannelWindowAdjust, msgChannelRequest, msgChannelSuccess, msgChannelFailure, msgChannelEOF:
					// these messages don't consume window space
				default:
					if localWindow < msgLen {
						reportErr(status.Errorf(codes.ResourceExhausted, "peer sent more bytes than allowed by channel window"))
						return
					}
					localWindow -= msgLen
					if localWindow < channelWindowSize/2 {
						log.Ctx(ctx).Debug().Msg("flow control: increasing local window size")
						localWindow += channelWindowSize
						select {
						case sendC <- windowAdjustMsg{
							PeersID:         peerID.Load(),
							AdditionalBytes: channelWindowSize,
						}:
						case <-ctx.Done():
							reportErr(ctx.Err())
							return
						}
					}
				}
			}

			select {
			case recvC <- channelMsg:
			case <-ctx.Done():
				reportErr(ctx.Err())
				return
			}
		}
	}()

	var downstreamChannelInfo *extensions_ssh.SSHDownstreamChannelInfo
	var downstreamPtyInfo *extensions_ssh.SSHDownstreamPTYInfo
	var channelIdCounter uint32
	for {
		select {
		case channelMsg := <-recvC:
			rawMsg := channelMsg.GetRawBytes().GetValue()
			if len(rawMsg) == 0 {
				// Only raw SSH messages are handled below; skip anything else
				// instead of indexing into an empty payload.
				log.Ctx(ctx).Debug().Msg("ignoring non-raw channel message")
				continue
			}
			switch rawMsg[0] {
			case msgChannelOpen:
				var msg channelOpenMsg
				if err := gossh.Unmarshal(rawMsg, &msg); err != nil {
					return err
				}
				channelIdCounter++
				if channelIdCounter > 1 {
					return fmt.Errorf("only one channel can be opened")
				}
				peerID.Store(msg.PeersID)
				downstreamChannelInfo = &extensions_ssh.SSHDownstreamChannelInfo{
					ChannelType:               msg.ChanType,
					DownstreamChannelId:       msg.PeersID,
					InternalUpstreamChannelId: channelIdCounter,
					InitialWindowSize:         msg.PeersWindow,
					MaxPacketSize:             msg.MaxPacketSize,
				}
				remoteWindow.add(msg.PeersWindow)
				switch msg.ChanType {
				case "session":
					sendC <- channelOpenConfirmMsg{
						PeersID:       msg.PeersID,
						MyID:          channelIdCounter,
						MyWindow:      channelWindowSize,
						MaxPacketSize: channelMaxPacket,
					}
				case "direct-tcpip":
					var subMsg channelOpenDirectMsg
					if err := gossh.Unmarshal(msg.TypeSpecificData, &subMsg); err != nil {
						return err
					}
					// The action embeds client-supplied strings, which can fail
					// to marshal (e.g. invalid UTF-8), so check the error.
					handOff, err := anypb.New(&extensions_ssh.SSHChannelControlAction{
						Action: &extensions_ssh.SSHChannelControlAction_HandOff{
							HandOff: &extensions_ssh.SSHChannelControlAction_HandOffUpstream{
								DownstreamChannelInfo: downstreamChannelInfo,
								UpstreamAuth: &extensions_ssh.AllowResponse{
									Target: &extensions_ssh.AllowResponse_Upstream{
										Upstream: &extensions_ssh.UpstreamTarget{
											Hostname:    subMsg.DestAddr,
											DirectTcpip: true,
										},
									},
								},
							},
						},
					})
					if err != nil {
						return err
					}
					sendC <- &extensions_ssh.ChannelControl{
						Protocol:      "ssh",
						ControlAction: handOff,
					}
				}

			case msgChannelRequest:
				var msg channelRequestMsg
				if err := gossh.Unmarshal(rawMsg, &msg); err != nil {
					return err
				}

				switch msg.Request {
				case "shell", "exec":
					// Queued before the command starts, so it reaches the
					// client ahead of any command output.
					sendC <- channelRequestSuccessMsg{PeersID: peerID.Load()}

					cmd := a.NewSSHCLI(a.currentConfig.Load(), downstreamPtyInfo, downstreamChannelInfo, *sessionID.Load(), inputR, outputW, sendC, &activeProgram)
					if msg.Request == "shell" {
						cmd.SetArgs([]string{"portal"})
					} else {
						var execReq execChannelRequestMsg
						if err := gossh.Unmarshal(msg.RequestSpecificData, &execReq); err != nil {
							return err
						}
						cmd.SetArgs(strings.Fields(execReq.Command))
					}
					go func() {
						defer activeProgram.Store(nil)
						defer outputW.Close()
						defer inputR.Close()
						err := cmd.Execute()
						if !errors.Is(err, ErrHandoff) {
							sendC <- &extensions_ssh.ChannelControl{
								Protocol: "ssh",
								ControlAction: marshalAny(&extensions_ssh.SSHChannelControlAction_Disconnect{
									ReasonCode: 11,
								}),
							}
						}
					}()
					go streamOutputToChannel(sendC, peerID.Load(), outputR)

				case "pty-req":
					req := parsePtyReq(msg.RequestSpecificData)
					downstreamPtyInfo = &extensions_ssh.SSHDownstreamPTYInfo{
						TermEnv:      req.TermEnv,
						WidthColumns: req.Width,
						HeightRows:   req.Height,
						WidthPx:      req.WidthPx,
						HeightPx:     req.HeightPx,
						Modes:        req.Modes,
					}
					sendC <- channelRequestSuccessMsg{PeersID: peerID.Load()}
				case "window-change":
					var req channelWindowChangeRequestMsg
					if err := gossh.Unmarshal(msg.RequestSpecificData, &req); err != nil {
						return err
					}
					if p := activeProgram.Load(); p != nil {
						p.Send(tea.WindowSizeMsg{
							Width:  int(req.WidthColumns),
							Height: int(req.HeightRows),
						})
					}
				}
			case msgChannelData:
				var msg channelDataMsg
				if err := gossh.Unmarshal(rawMsg, &msg); err != nil {
					return err
				}
				if activeProgram.Load() != nil {
					// Best effort: fails only once the program's input is closed.
					_, _ = inputW.Write(msg.Rest)
				}
			case msgChannelClose:
				// No action is taken on channel close.
			case msgChannelWindowAdjust:
				var msg windowAdjustMsg
				if err := gossh.Unmarshal(rawMsg, &msg); err != nil {
					return err
				}
				log.Ctx(ctx).Debug().Uint32("bytes", msg.AdditionalBytes).Msg("flow control: remote window size increased")
				remoteWindow.add(msg.AdditionalBytes)
			case msgChannelEOF:
				return nil
			default:
				// Reject unknown messages instead of panicking the server.
				return status.Errorf(codes.InvalidArgument, "unhandled message: %d", rawMsg[0])
			}
		case err := <-errC:
			log.Ctx(ctx).Err(err).Msg("channel error")
			return err
		}
	}
}
|
|
|
|
type ptyReq struct {
|
|
TermEnv string
|
|
Width, Height uint32
|
|
WidthPx, HeightPx uint32
|
|
Modes []byte
|
|
}
|
|
|
|
func parsePtyReq(reqData []byte) ptyReq {
|
|
termEnvLen := binary.BigEndian.Uint32(reqData)
|
|
reqData = reqData[4:]
|
|
termEnv := string(reqData[:termEnvLen])
|
|
reqData = reqData[termEnvLen:]
|
|
return ptyReq{
|
|
TermEnv: termEnv,
|
|
Width: binary.BigEndian.Uint32(reqData),
|
|
Height: binary.BigEndian.Uint32(reqData[4:]),
|
|
WidthPx: binary.BigEndian.Uint32(reqData[8:]),
|
|
HeightPx: binary.BigEndian.Uint32(reqData[12:]),
|
|
Modes: reqData[16:],
|
|
}
|
|
}
|
|
|
|
// streamOutputToChannel copies everything read from outputR to the SSH
// channel channelID. The data is sent as channel data messages of at most
// 4096 bytes each, queued on sendC. It returns once outputR reports any
// error, including io.EOF when the writing side is closed.
func streamOutputToChannel(sendC chan<- any, channelID uint32, outputR io.Reader) {
	var buf [4096]byte
	for {
		n, err := outputR.Read(buf[:])
		// Per the io.Reader contract, bytes returned alongside an error must
		// still be delivered before the error is acted upon.
		if n > 0 {
			sendC <- channelDataMsg{
				PeersID: channelID,
				Length:  uint32(n),
				Rest:    slices.Clone(buf[:n]),
			}
		}
		if err != nil {
			return
		}
	}
}
|
|
|
|
// NewSSHCLI builds the root "pomerium" command for an SSH session. It has the
// portal, logout, and whoami subcommands, with stdin/stdout (and stderr)
// bound to the channel pipes.
//
// A persistent pre-run hook enforces TTY usage: commands annotated
// "interactive" require a PTY, and all other commands refuse one.
func (a *Authorize) NewSSHCLI(
	cfg *config.Config,
	ptyInfo *extensions_ssh.SSHDownstreamPTYInfo,
	channelInfo *extensions_ssh.SSHDownstreamChannelInfo,
	sessionID string,
	stdin io.Reader,
	stdout io.Writer,
	sendC chan any,
	activeProgram *atomic.Pointer[tea.Program],
) *cobra.Command {
	hasPty := ptyInfo != nil
	requireMatchingTTY := func(cmd *cobra.Command, _ []string) error {
		_, interactive := cmd.Annotations["interactive"]
		if interactive && !hasPty {
			cmd.SilenceUsage = true
			return fmt.Errorf("\x1b[31m'%s' is an interactive command and requires a TTY (try passing '-t' to ssh)\x1b[0m", cmd.Use)
		}
		if !interactive && hasPty {
			cmd.SilenceUsage = true
			// A PTY is attached here, so the message needs an explicit "\r".
			return fmt.Errorf("\x1b[31m'%s' is not an interactive command (try passing '-T' to ssh, or removing '-t')\x1b[0m\r", cmd.Use)
		}
		return nil
	}

	root := &cobra.Command{
		Use:               "pomerium",
		PersistentPreRunE: requireMatchingTTY,
	}
	root.AddCommand(
		NewPortalCommand(cfg, ptyInfo, channelInfo, sendC, activeProgram),
		a.NewLogoutCommand(cfg, sessionID),
		a.NewWhoamiCommand(cfg, sessionID),
	)
	root.CompletionOptions.DisableDefaultCmd = true
	root.SetIn(stdin)
	root.SetOut(stdout)
	root.SetErr(stdout)
	return root
}
|
|
|
|
// NewLogoutCommand returns the "logout" command, which deletes the session
// identified by sessionID from the databroker.
func (a *Authorize) NewLogoutCommand(
	cfg *config.Config,
	sessionID string,
) *cobra.Command {
	logout := func(cmd *cobra.Command, _ []string) error {
		databroker := a.state.Load().dataBrokerClient
		if err := session.Delete(cmd.Context(), databroker, sessionID); err != nil {
			return fmt.Errorf("internal error: %w", err)
		}
		cmd.OutOrStdout().Write([]byte("Logged out successfully\r\n"))
		return nil
	}
	return &cobra.Command{
		Use:   "logout",
		Short: "Log out",
		RunE:  logout,
	}
}
|
|
|
|
// whoamiTmpl renders the output of the "whoami" command from a session record:
// user ID, session ID, expiry time, and the values of each claim.
var whoamiTmpl = template.Must(template.New("whoami").Parse(`
|
|
User ID: {{.UserId}}
|
|
Session ID: {{.Id}}
|
|
Expires at: {{.ExpiresAt.AsTime}}
|
|
Claims:
|
|
{{- range $k, $v := .Claims }}
|
|
{{ $k }}: {{ $v.AsSlice }}
|
|
{{- end }}
|
|
`))
|
|
|
|
// NewWhoamiCommand returns the "whoami" command. It fetches the session
// identified by sessionID from the databroker and prints it using
// whoamiTmpl. Errors fetching, rendering, or writing the session are
// returned to cobra instead of being silently dropped.
func (a *Authorize) NewWhoamiCommand(
	cfg *config.Config,
	sessionID string,
) *cobra.Command {
	cmd := &cobra.Command{
		Use: "whoami",
		RunE: func(cmd *cobra.Command, args []string) error {
			client := a.state.Load().dataBrokerClient
			s, err := session.Get(cmd.Context(), client, sessionID)
			if err != nil {
				return fmt.Errorf("couldn't fetch session: %w", err)
			}
			var b bytes.Buffer
			if err := whoamiTmpl.Execute(&b, s); err != nil {
				return fmt.Errorf("couldn't render session info: %w", err)
			}
			b.WriteString("\r\n")
			if _, err := cmd.OutOrStdout().Write(b.Bytes()); err != nil {
				return err
			}
			return nil
		},
	}
	return cmd
}
|
|
|
|
func NewPortalCommand(
|
|
cfg *config.Config,
|
|
ptyInfo *extensions_ssh.SSHDownstreamPTYInfo,
|
|
channelInfo *extensions_ssh.SSHDownstreamChannelInfo,
|
|
sendC chan any,
|
|
activeProgram *atomic.Pointer[tea.Program],
|
|
) *cobra.Command {
|
|
cmd := &cobra.Command{
|
|
Use: "portal",
|
|
Short: "Interactive route portal",
|
|
Annotations: map[string]string{
|
|
"interactive": "",
|
|
},
|
|
RunE: func(cmd *cobra.Command, args []string) error {
|
|
var routes []string
|
|
for r := range cfg.Options.GetAllPolicies() {
|
|
if strings.HasPrefix(r.From, "ssh://") {
|
|
routes = append(routes, fmt.Sprintf("ubuntu@%s", strings.TrimSuffix(strings.TrimPrefix(r.From, "ssh://"), "."+cfg.Options.SSHHostname)))
|
|
}
|
|
}
|
|
items := []list.Item{}
|
|
for _, route := range routes {
|
|
items = append(items, item(route))
|
|
}
|
|
activeStreamIds.Range(func(key, value any) bool {
|
|
items = append(items, item(fmt.Sprintf("[demo] mirror session: %v", key)))
|
|
return true
|
|
})
|
|
|
|
l := list.New(items, itemDelegate{}, int(ptyInfo.WidthColumns-2), int(ptyInfo.HeightRows-2))
|
|
l.Title = "Connect to which server?"
|
|
l.SetShowStatusBar(false)
|
|
l.SetFilteringEnabled(false)
|
|
l.Styles.Title = titleStyle
|
|
l.Styles.PaginationStyle = paginationStyle
|
|
l.Styles.HelpStyle = helpStyle
|
|
|
|
program := tea.NewProgram(model{list: l},
|
|
tea.WithInput(cmd.InOrStdin()),
|
|
tea.WithOutput(cmd.OutOrStdout()),
|
|
tea.WithAltScreen(),
|
|
tea.WithContext(cmd.Context()),
|
|
tea.WithEnvironment([]string{"TERM=" + ptyInfo.TermEnv}),
|
|
)
|
|
activeProgram.Store(program)
|
|
|
|
go program.Send(tea.WindowSizeMsg{Width: int(ptyInfo.WidthColumns), Height: int(ptyInfo.HeightRows)})
|
|
answer, err := program.Run()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
var handOff *anypb.Any
|
|
if strings.HasPrefix(answer.(model).choice, "[demo] mirror session: ") {
|
|
id, err := strconv.ParseUint(strings.TrimPrefix(answer.(model).choice, "[demo] mirror session: "), 10, 64)
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
handOff = marshalAny(&extensions_ssh.SSHChannelControlAction{
|
|
Action: &extensions_ssh.SSHChannelControlAction_HandOff{
|
|
HandOff: &extensions_ssh.SSHChannelControlAction_HandOffUpstream{
|
|
DownstreamChannelInfo: channelInfo,
|
|
DownstreamPtyInfo: ptyInfo,
|
|
UpstreamAuth: &extensions_ssh.AllowResponse{
|
|
Target: &extensions_ssh.AllowResponse_MirrorSession{
|
|
MirrorSession: &extensions_ssh.MirrorSessionTarget{
|
|
SourceId: id,
|
|
Mode: extensions_ssh.MirrorSessionTarget_ReadWrite,
|
|
},
|
|
},
|
|
},
|
|
},
|
|
},
|
|
})
|
|
} else {
|
|
username, hostname, _ := strings.Cut(answer.(model).choice, "@")
|
|
sessionRecordingExt, _ := anypb.New(&extensions_session_recording.UpstreamTargetExtensionConfig{
|
|
RecordingName: fmt.Sprintf("session-%s-at-%s-%d.cast", username, hostname, time.Now().UnixNano()),
|
|
Format: extensions_session_recording.Format_AsciicastFormat,
|
|
})
|
|
handOff = marshalAny(&extensions_ssh.SSHChannelControlAction{
|
|
Action: &extensions_ssh.SSHChannelControlAction_HandOff{
|
|
HandOff: &extensions_ssh.SSHChannelControlAction_HandOffUpstream{
|
|
DownstreamChannelInfo: channelInfo,
|
|
DownstreamPtyInfo: ptyInfo,
|
|
UpstreamAuth: &extensions_ssh.AllowResponse{
|
|
Username: username,
|
|
Target: &extensions_ssh.AllowResponse_Upstream{
|
|
Upstream: &extensions_ssh.UpstreamTarget{
|
|
AllowMirrorConnections: true,
|
|
Hostname: hostname,
|
|
Extensions: []*corev3.TypedExtensionConfig{
|
|
{
|
|
TypedConfig: sessionRecordingExt,
|
|
},
|
|
},
|
|
},
|
|
},
|
|
},
|
|
},
|
|
},
|
|
})
|
|
}
|
|
|
|
sendC <- &extensions_ssh.ChannelControl{
|
|
Protocol: "ssh",
|
|
ControlAction: handOff,
|
|
}
|
|
return ErrHandoff
|
|
},
|
|
}
|
|
return cmd
|
|
}
|
|
|
|
var (
|
|
titleStyle = lipgloss.NewStyle().MarginLeft(2)
|
|
itemStyle = lipgloss.NewStyle().PaddingLeft(4)
|
|
selectedItemStyle = lipgloss.NewStyle().PaddingLeft(2).Foreground(lipgloss.Color("170"))
|
|
paginationStyle = list.DefaultStyles().PaginationStyle.PaddingLeft(4)
|
|
helpStyle = list.DefaultStyles().HelpStyle.PaddingLeft(4).PaddingBottom(1)
|
|
quitTextStyle = lipgloss.NewStyle().Margin(1, 0, 2, 4)
|
|
)
|
|
|
|
// item is one entry in the portal's route list. Its string value is both the
// text shown to the user and the value recorded as the model's choice.
type item string

// FilterValue satisfies list.Item. It always returns the empty string, since
// filtering is switched off on the portal list.
func (it item) FilterValue() string {
	return ""
}
|
|
|
|
type itemDelegate struct{}
|
|
|
|
func (d itemDelegate) Height() int { return 1 }
|
|
func (d itemDelegate) Spacing() int { return 0 }
|
|
func (d itemDelegate) Update(_ tea.Msg, _ *list.Model) tea.Cmd { return nil }
|
|
func (d itemDelegate) Render(w io.Writer, m list.Model, index int, listItem list.Item) {
|
|
i, ok := listItem.(item)
|
|
if !ok {
|
|
return
|
|
}
|
|
|
|
str := fmt.Sprintf("%d. %s", index+1, i)
|
|
|
|
fn := itemStyle.Render
|
|
if index == m.Index() {
|
|
fn = func(s ...string) string {
|
|
return selectedItemStyle.Render("> " + strings.Join(s, " "))
|
|
}
|
|
}
|
|
|
|
fmt.Fprint(w, fn(str))
|
|
}
|
|
|
|
type model struct {
|
|
list list.Model
|
|
choice string
|
|
quitting bool
|
|
}
|
|
|
|
func (m model) Init() tea.Cmd {
|
|
return nil
|
|
}
|
|
|
|
func (m model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
|
switch msg := msg.(type) {
|
|
case tea.WindowSizeMsg:
|
|
m.list.SetWidth(msg.Width - 2)
|
|
m.list.SetHeight(msg.Height - 2)
|
|
return m, nil
|
|
|
|
case tea.KeyMsg:
|
|
switch keypress := msg.String(); keypress {
|
|
case "q", "ctrl+c":
|
|
m.quitting = true
|
|
return m, tea.Quit
|
|
|
|
case "enter":
|
|
i, ok := m.list.SelectedItem().(item)
|
|
if ok {
|
|
m.choice = string(i)
|
|
}
|
|
return m, tea.Quit
|
|
}
|
|
}
|
|
|
|
var cmd tea.Cmd
|
|
m.list, cmd = m.list.Update(msg)
|
|
return m, cmd
|
|
}
|
|
|
|
func (m model) View() string {
|
|
return "\n" + m.list.View()
|
|
}
|
|
|
|
// code below copied from x/crypto/ssh/common.go
|
|
|
|
const (
	// channelMaxPacket is the largest number of bytes sent in a single
	// channel packet: 32 KiB, which RFC 4253 section 6.1 also sets as the
	// minimum payload size implementations must accept.
	channelMaxPacket = 32 * 1024

	// channelWindowSize is 64 maximum-size packets (2 MiB), following OpenSSH.
	channelWindowSize = channelMaxPacket * 64
)
|
|
|
|
// window represents the buffer available to clients
// wishing to write to a channel.
//
// The embedded *sync.Cond must be non-nil before any method is called, since
// every method below locks w.L. These methods read and write the fields only
// while holding that lock.
type window struct {
	*sync.Cond
	win          uint32 // RFC 4254 5.2 says the window size can grow to 2^32-1
	writeWaiters int    // goroutines currently inside reserve (inc before waiting, dec after)
	closed       bool   // set by close; once true, reserve returns io.EOF
}

// add adds win to the amount of window available
// for consumers.
//
// It returns false, leaving the window unchanged, when the addition would
// overflow uint32. Otherwise it returns true (including for win == 0) and
// wakes every goroutine waiting on the Cond.
func (w *window) add(win uint32) bool {
	// a zero sized window adjust is a noop.
	if win == 0 {
		return true
	}
	w.L.Lock()
	// Unsigned wrap-around: the sum is smaller than an operand only on overflow.
	if w.win+win < win {
		w.L.Unlock()
		return false
	}
	w.win += win
	// It is unusual that multiple goroutines would be attempting to reserve
	// window space, but not guaranteed. Use broadcast to notify all waiters
	// that additional window is available.
	w.Broadcast()
	w.L.Unlock()
	return true
}

// close sets the window to closed, so all reservations fail
// immediately.
//
// All goroutines blocked in reserve are woken; they return io.EOF.
func (w *window) close() {
	w.L.Lock()
	w.closed = true
	w.Broadcast()
	w.L.Unlock()
}

// reserve reserves win from the available window capacity.
// If no capacity remains, reserve will block. reserve may
// return less than requested.
//
// It waits until the window is non-zero or closed, then deducts and returns
// min(win, available). If the window has been closed it also returns io.EOF,
// together with whatever amount (possibly zero) was deducted.
func (w *window) reserve(win uint32) (uint32, error) {
	var err error
	w.L.Lock()
	// Record this goroutine as a waiter. The Broadcast presumably lets other
	// goroutines waiting on the Cond observe the writeWaiters change.
	// NOTE(review): confirm which caller relies on this.
	w.writeWaiters++
	w.Broadcast()
	for w.win == 0 && !w.closed {
		w.Wait()
	}
	w.writeWaiters--
	// Clamp the request to what is currently available.
	if w.win < win {
		win = w.win
	}
	w.win -= win
	if w.closed {
		err = io.EOF
	}
	w.L.Unlock()
	return win, err
}
|
|
|
|
// code below copied from x/crypto/ssh/messages.go
|
|
// (with some additional messages not included there)
|
|
|
|
// See RFC 4254, section 5.1.
const msgChannelOpen = 90

// channelOpenMsg is the channel-open message (message number 90).
//
// As in x/crypto/ssh, which this is copied from, the `sshtype` tag on the
// first field carries the message number. The wire encoding presumably
// follows field declaration order, so do not reorder or rename-by-reordering
// these fields.
type channelOpenMsg struct {
	ChanType         string `sshtype:"90"`
	PeersID          uint32
	PeersWindow      uint32
	MaxPacketSize    uint32
	TypeSpecificData []byte `ssh:"rest"` // remaining bytes of the message
}
|
|
|
|
// Message numbers for channel data transfer (RFC 4254, section 5.2):
// ordinary data and extended (e.g. stderr) data.
const (
	msgChannelData         = 94
	msgChannelExtendedData = 95
)
|
|
|
|
// See RFC 4253, section 11.1.
const msgDisconnect = 1

// disconnectMsg is the message that signals a disconnect. It is also
// the error type returned from mux.Wait()
//
// NOTE(review): the mux.Wait reference comes from the upstream x/crypto/ssh
// source this was copied from; confirm whether it applies in this package.
type disconnectMsg struct {
	Reason   uint32 `sshtype:"1"`
	Message  string
	Language string
}
|
|
|
|
// Used for debug print outs of packets.
//
// channelDataMsg mirrors the channel-data message (number 94): the channel
// ID, the data length, then the data bytes.
// NOTE(review): the "debug print" note is inherited from x/crypto/ssh; how
// this copy is used is not visible here — confirm.
type channelDataMsg struct {
	PeersID uint32 `sshtype:"94"`
	Length  uint32
	Rest    []byte `ssh:"rest"`
}
|
|
|
|
// See RFC 4254, section 5.1.
const msgChannelOpenConfirm = 91

// channelOpenConfirmMsg is the channel-open confirmation (message number 91).
// Field order is part of the wire layout; do not reorder.
type channelOpenConfirmMsg struct {
	PeersID          uint32 `sshtype:"91"`
	MyID             uint32
	MyWindow         uint32
	MaxPacketSize    uint32
	TypeSpecificData []byte `ssh:"rest"` // remaining bytes of the message
}
|
|
|
|
// msgChannelRequest is the channel-request message number (RFC 4254, section 5.4).
const msgChannelRequest = 98

// channelRequestMsg is a channel request (message number 98). The
// request-type-specific payload (e.g. one of the *ChannelRequestMsg structs
// in this file) is carried, still encoded, in RequestSpecificData.
type channelRequestMsg struct {
	PeersID             uint32 `sshtype:"98"`
	Request             string
	WantReply           bool
	RequestSpecificData []byte `ssh:"rest"`
}
|
|
|
|
// channelOpenDirectMsg is presumably the type-specific payload of a
// "direct-tcpip" channel open (RFC 4254, section 7.2): the destination host
// and port, then the originator address and port.
// These payload structs carry no sshtype tag because they are not standalone
// messages.
type channelOpenDirectMsg struct {
	DestAddr string
	DestPort uint32
	SrcAddr  string
	SrcPort  uint32
}

// channelWindowChangeRequestMsg is presumably the payload of a
// "window-change" channel request (RFC 4254, section 6.7): terminal size in
// character cells and in pixels.
type channelWindowChangeRequestMsg struct {
	WidthColumns uint32
	HeightRows   uint32
	WidthPx      uint32
	HeightPx     uint32
}

// shellChannelRequestMsg is the (empty) payload of a "shell" channel request.
type shellChannelRequestMsg struct{}

// execChannelRequestMsg is presumably the payload of an "exec" channel
// request: the command line to run.
type execChannelRequestMsg struct {
	Command string
}
|
|
|
|
// See RFC 4254, section 5.2
const msgChannelWindowAdjust = 93

// windowAdjustMsg grants the peer AdditionalBytes more send window on the
// channel identified by PeersID (message number 93).
type windowAdjustMsg struct {
	PeersID         uint32 `sshtype:"93"`
	AdditionalBytes uint32
}
|
|
|
|
// See RFC 4254, section 5.4.
const msgChannelSuccess = 99

// channelRequestSuccessMsg is the positive reply to a channel request sent
// with WantReply set (message number 99).
type channelRequestSuccessMsg struct {
	PeersID uint32 `sshtype:"99"`
}

// See RFC 4254, section 5.4.
const msgChannelFailure = 100

// channelRequestFailureMsg is the negative reply to a channel request sent
// with WantReply set (message number 100).
type channelRequestFailureMsg struct {
	PeersID uint32 `sshtype:"100"`
}

// See RFC 4254, section 5.3
const msgChannelClose = 97

// channelCloseMsg closes the channel identified by PeersID (message number 97).
type channelCloseMsg struct {
	PeersID uint32 `sshtype:"97"`
}

// See RFC 4254, section 5.3
const msgChannelEOF = 96

// channelEOFMsg signals that no more data will be sent on the channel
// identified by PeersID (message number 96).
type channelEOFMsg struct {
	PeersID uint32 `sshtype:"96"`
}
|