mirror of
https://github.com/pomerium/pomerium.git
synced 2025-04-28 18:06:34 +02:00
1113 lines
30 KiB
Go
1113 lines
30 KiB
Go
package authorize
|
||
|
||
import (
|
||
"bytes"
|
||
"context"
|
||
"encoding/binary"
|
||
"errors"
|
||
"fmt"
|
||
"io"
|
||
"net/url"
|
||
"slices"
|
||
"strings"
|
||
"sync/atomic"
|
||
"text/template"
|
||
"time"
|
||
|
||
"github.com/charmbracelet/bubbles/list"
|
||
tea "github.com/charmbracelet/bubbletea"
|
||
"github.com/charmbracelet/lipgloss"
|
||
envoy_config_core_v3 "github.com/envoyproxy/go-control-plane/envoy/config/core/v3"
|
||
extensions_ssh "github.com/pomerium/envoy-custom/api/extensions/filters/network/ssh"
|
||
"github.com/pomerium/pomerium/authorize/evaluator"
|
||
"github.com/pomerium/pomerium/config"
|
||
"github.com/pomerium/pomerium/internal/sessions"
|
||
"github.com/pomerium/pomerium/pkg/grpc/session"
|
||
"github.com/pomerium/pomerium/pkg/identity"
|
||
"github.com/pomerium/pomerium/pkg/identity/manager"
|
||
"github.com/pomerium/pomerium/pkg/identity/oauth"
|
||
"github.com/pomerium/pomerium/pkg/storage"
|
||
gossh "golang.org/x/crypto/ssh"
|
||
"golang.org/x/oauth2"
|
||
"golang.org/x/sync/errgroup"
|
||
"google.golang.org/grpc/codes"
|
||
"google.golang.org/grpc/metadata"
|
||
"google.golang.org/grpc/status"
|
||
"google.golang.org/protobuf/types/known/anypb"
|
||
"google.golang.org/protobuf/types/known/structpb"
|
||
"google.golang.org/protobuf/types/known/timestamppb"
|
||
"google.golang.org/protobuf/types/known/wrapperspb"
|
||
)
|
||
|
||
// StreamState is the per-connection authentication state that ManageStream
// accumulates while it processes the messages of a single SSH stream.
type StreamState struct {
	// Username and Hostname are copied from the first auth request that
	// carries a non-empty value for them.
	Username, Hostname string
	// PublicKey is the key presented in the most recent "publickey" request.
	PublicKey []byte
	// MethodsAuthenticated lists the SSH auth method names that have
	// completed so far (for example "publickey", "keyboard-interactive").
	MethodsAuthenticated []string
}
|
||
|
||
func (a *Authorize) ManageStream(
|
||
server extensions_ssh.StreamManagement_ManageStreamServer,
|
||
) error {
|
||
recvC := make(chan *extensions_ssh.ClientMessage, 32)
|
||
sendC := make(chan *extensions_ssh.ServerMessage, 32)
|
||
eg, ctx := errgroup.WithContext(server.Context())
|
||
eg.Go(func() error {
|
||
defer close(recvC)
|
||
for {
|
||
req, err := server.Recv()
|
||
if err != nil {
|
||
if errors.Is(err, io.EOF) {
|
||
return nil
|
||
}
|
||
return err
|
||
}
|
||
recvC <- req
|
||
}
|
||
})
|
||
|
||
// XXX
|
||
querier := storage.NewCachingQuerier(
|
||
storage.NewQuerier(a.state.Load().dataBrokerClient),
|
||
a.globalCache,
|
||
)
|
||
ctx = storage.WithQuerier(ctx, querier)
|
||
|
||
eg.Go(func() error {
|
||
for {
|
||
select {
|
||
case <-ctx.Done():
|
||
return nil
|
||
case msg := <-sendC:
|
||
if err := server.Send(msg); err != nil {
|
||
if errors.Is(err, io.EOF) {
|
||
return nil
|
||
}
|
||
return err
|
||
}
|
||
}
|
||
}
|
||
})
|
||
|
||
var state StreamState
|
||
|
||
//deviceAuthSuccess := &atomic.Bool{}
|
||
sessionState := &atomic.Pointer[sessions.State]{}
|
||
|
||
errC := make(chan error, 1)
|
||
a.activeStreamsMu.Lock()
|
||
a.activeStreams = append(a.activeStreams, errC)
|
||
a.activeStreamsMu.Unlock()
|
||
for {
|
||
select {
|
||
case err := <-errC:
|
||
|
||
return err
|
||
case req, ok := <-recvC:
|
||
if !ok {
|
||
return nil
|
||
}
|
||
switch req := req.Message.(type) {
|
||
case *extensions_ssh.ClientMessage_Event:
|
||
switch event := req.Event.Event.(type) {
|
||
case *extensions_ssh.StreamEvent_DownstreamConnected:
|
||
fmt.Println("downstream connected")
|
||
_ = event
|
||
case nil:
|
||
}
|
||
case *extensions_ssh.ClientMessage_AuthRequest:
|
||
authReq := req.AuthRequest
|
||
fmt.Println("auth request")
|
||
if state.Username == "" {
|
||
state.Username = authReq.Username
|
||
}
|
||
if state.Hostname == "" {
|
||
state.Hostname = authReq.Hostname
|
||
}
|
||
switch authReq.AuthMethod {
|
||
case "publickey":
|
||
methodReq, _ := authReq.MethodRequest.UnmarshalNew()
|
||
pubkeyReq, ok := methodReq.(*extensions_ssh.PublicKeyMethodRequest)
|
||
if !ok {
|
||
return fmt.Errorf("client sent invalid auth request message")
|
||
}
|
||
|
||
//
|
||
// validate public key here
|
||
//
|
||
session, err := a.GetPomeriumSession(ctx, pubkeyReq.PublicKey)
|
||
if err != nil {
|
||
return err // XXX: wrap this error?
|
||
}
|
||
|
||
state.MethodsAuthenticated = append(state.MethodsAuthenticated, "publickey")
|
||
state.PublicKey = pubkeyReq.PublicKey
|
||
|
||
if authReq.Username == "" && authReq.Hostname == "" {
|
||
pkData, _ := anypb.New(publicKeyAllowResponse(state.PublicKey))
|
||
resp := extensions_ssh.ServerMessage{
|
||
Message: &extensions_ssh.ServerMessage_AuthResponse{
|
||
AuthResponse: &extensions_ssh.AuthenticationResponse{
|
||
Response: &extensions_ssh.AuthenticationResponse_Allow{
|
||
Allow: &extensions_ssh.AllowResponse{
|
||
Username: state.Username,
|
||
Hostname: state.Hostname,
|
||
AllowedMethods: []*extensions_ssh.AllowedMethod{
|
||
{
|
||
Method: "publickey",
|
||
MethodData: pkData,
|
||
},
|
||
},
|
||
Target: extensions_ssh.Target_Internal,
|
||
},
|
||
},
|
||
},
|
||
},
|
||
}
|
||
sendC <- &resp
|
||
continue
|
||
}
|
||
|
||
if session != nil {
|
||
// Perform authorize check for this route
|
||
req, err := a.getEvaluatorRequestFromSSHAuthRequest(&state)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
res, err := a.evaluate(ctx, req, &sessions.State{ID: session.Id})
|
||
if err != nil {
|
||
return err
|
||
}
|
||
sendC <- handleEvaluatorResponseForSSH(res, &state)
|
||
}
|
||
|
||
if session == nil && !slices.Contains(state.MethodsAuthenticated, "keyboard-interactive") {
|
||
resp := extensions_ssh.ServerMessage{
|
||
Message: &extensions_ssh.ServerMessage_AuthResponse{
|
||
AuthResponse: &extensions_ssh.AuthenticationResponse{
|
||
Response: &extensions_ssh.AuthenticationResponse_Deny{
|
||
Deny: &extensions_ssh.DenyResponse{
|
||
Partial: true,
|
||
Methods: []string{"keyboard-interactive"},
|
||
},
|
||
},
|
||
},
|
||
},
|
||
}
|
||
sendC <- &resp
|
||
}
|
||
case "keyboard-interactive":
|
||
route := a.getSSHRouteForHostname(state.Hostname)
|
||
if route == nil {
|
||
return fmt.Errorf("invalid route")
|
||
}
|
||
|
||
opts := a.currentOptions.Load()
|
||
idp, err := opts.GetIdentityProviderForPolicy(route)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
authenticator, err := identity.NewAuthenticator(ctx, a.tracerProvider, oauth.Options{
|
||
RedirectURL: &url.URL{},
|
||
ProviderName: idp.GetType(),
|
||
ProviderURL: idp.GetUrl(),
|
||
ClientID: idp.GetClientId(),
|
||
ClientSecret: idp.GetClientSecret(),
|
||
Scopes: idp.GetScopes(),
|
||
AuthCodeOptions: idp.GetRequestParams(),
|
||
DeviceAuthClientType: idp.GetDeviceAuthClientType(),
|
||
})
|
||
if err != nil {
|
||
return err
|
||
}
|
||
deviceAuthResp, err := authenticator.DeviceAuth(ctx)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
infoReq := extensions_ssh.KeyboardInteractiveInfoPrompts{
|
||
Name: "Sign in with " + idp.GetType(),
|
||
Instruction: deviceAuthResp.VerificationURIComplete,
|
||
Prompts: []*extensions_ssh.KeyboardInteractiveInfoPrompts_Prompt{
|
||
//{}, // XXX: proof of concept (no prompt)
|
||
},
|
||
}
|
||
|
||
infoReqAny, _ := anypb.New(&infoReq)
|
||
resp := extensions_ssh.ServerMessage{
|
||
Message: &extensions_ssh.ServerMessage_AuthResponse{
|
||
AuthResponse: &extensions_ssh.AuthenticationResponse{
|
||
Response: &extensions_ssh.AuthenticationResponse_InfoRequest{
|
||
InfoRequest: &extensions_ssh.InfoRequest{
|
||
Method: "keyboard-interactive",
|
||
Request: infoReqAny,
|
||
},
|
||
},
|
||
},
|
||
},
|
||
}
|
||
sendC <- &resp
|
||
|
||
go func() {
|
||
var claims identity.SessionClaims
|
||
|
||
token, err := authenticator.DeviceAccessToken(ctx, deviceAuthResp, &claims)
|
||
if err != nil {
|
||
errC <- err
|
||
return
|
||
}
|
||
s := sessions.NewState(idp.Id)
|
||
claims.Claims.Claims(&s) // XXX
|
||
s.ID, err = getSessionIDForSSH(state.PublicKey)
|
||
if err != nil {
|
||
errC <- err
|
||
return
|
||
}
|
||
fmt.Println(token)
|
||
err = a.PersistSession(ctx, s, claims, token)
|
||
if err != nil {
|
||
fmt.Println("error from PersistSession:", err)
|
||
errC <- fmt.Errorf("error persisting session: %w", err)
|
||
return
|
||
}
|
||
sessionState.Store(s)
|
||
}()
|
||
}
|
||
case *extensions_ssh.ClientMessage_InfoResponse:
|
||
resp := req.InfoResponse
|
||
if resp.Method == "keyboard-interactive" {
|
||
r, _ := resp.Response.UnmarshalNew()
|
||
respInfo, ok := r.(*extensions_ssh.KeyboardInteractiveInfoPromptResponses)
|
||
if ok {
|
||
fmt.Println(respInfo.Responses)
|
||
}
|
||
}
|
||
|
||
// XXX: proof of concept -- busy wait for login to complete
|
||
for sessionState.Load() == nil {
|
||
time.Sleep(time.Second)
|
||
}
|
||
|
||
if sessionState.Load() != nil {
|
||
state.MethodsAuthenticated = append(state.MethodsAuthenticated, "keyboard-interactive")
|
||
} else {
|
||
retryReq := extensions_ssh.KeyboardInteractiveInfoPrompts{
|
||
Name: "",
|
||
Instruction: "Login not successful yet, try again",
|
||
Prompts: []*extensions_ssh.KeyboardInteractiveInfoPrompts_Prompt{
|
||
{},
|
||
},
|
||
}
|
||
infoReqAny, _ := anypb.New(&retryReq)
|
||
|
||
resp := extensions_ssh.ServerMessage{
|
||
Message: &extensions_ssh.ServerMessage_AuthResponse{
|
||
AuthResponse: &extensions_ssh.AuthenticationResponse{
|
||
Response: &extensions_ssh.AuthenticationResponse_InfoRequest{
|
||
InfoRequest: &extensions_ssh.InfoRequest{
|
||
Method: "keyboard-interactive",
|
||
Request: infoReqAny,
|
||
},
|
||
},
|
||
},
|
||
},
|
||
}
|
||
sendC <- &resp
|
||
continue
|
||
}
|
||
|
||
if slices.Contains(state.MethodsAuthenticated, "publickey") {
|
||
// Perform authorize check for this route
|
||
req, err := a.getEvaluatorRequestFromSSHAuthRequest(&state)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
res, err := a.evaluate(ctx, req, sessionState.Load())
|
||
if err != nil {
|
||
return err
|
||
}
|
||
sendC <- handleEvaluatorResponseForSSH(res, &state)
|
||
} else {
|
||
resp := extensions_ssh.ServerMessage{
|
||
Message: &extensions_ssh.ServerMessage_AuthResponse{
|
||
AuthResponse: &extensions_ssh.AuthenticationResponse{
|
||
Response: &extensions_ssh.AuthenticationResponse_Deny{
|
||
Deny: &extensions_ssh.DenyResponse{
|
||
Partial: true,
|
||
Methods: []string{"publickey"},
|
||
},
|
||
},
|
||
},
|
||
},
|
||
}
|
||
sendC <- &resp
|
||
}
|
||
|
||
case nil:
|
||
}
|
||
}
|
||
}
|
||
|
||
return eg.Wait()
|
||
}
|
||
|
||
func (a *Authorize) getSSHRouteForHostname(hostname string) *config.Policy {
|
||
opts := a.currentOptions.Load()
|
||
from := "ssh://" + strings.TrimSuffix(strings.Join([]string{hostname, opts.SSHHostname}, "."), ".")
|
||
for r := range opts.GetAllPolicies() {
|
||
if r.From == from {
|
||
return r
|
||
}
|
||
}
|
||
return nil
|
||
}
|
||
|
||
func (a *Authorize) GetPomeriumSession(
|
||
ctx context.Context, publicKey []byte,
|
||
) (*session.Session, error) {
|
||
sessionID, err := getSessionIDForSSH(publicKey)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
fmt.Println("session ID:", sessionID) // XXX
|
||
|
||
session, err := session.Get(ctx, a.GetDataBrokerServiceClient(), sessionID)
|
||
if err != nil {
|
||
if st, ok := status.FromError(err); ok && st.Code() == codes.NotFound {
|
||
return nil, nil
|
||
}
|
||
return nil, err
|
||
}
|
||
return session, nil
|
||
}
|
||
|
||
func getSessionIDForSSH(publicKey []byte) (string, error) {
|
||
// XXX: get the fingerprint from Envoy rather than computing it here
|
||
k, err := gossh.ParsePublicKey(publicKey)
|
||
if err != nil {
|
||
return "", fmt.Errorf("couldn't parse ssh key: %w", err)
|
||
}
|
||
return "sshkey-" + gossh.FingerprintSHA256(k), nil
|
||
}
|
||
|
||
func (a *Authorize) getEvaluatorRequestFromSSHAuthRequest(
|
||
state *StreamState,
|
||
) (*evaluator.Request, error) {
|
||
sessionID, err := getSessionIDForSSH(state.PublicKey)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
route := a.getSSHRouteForHostname(state.Hostname)
|
||
if route == nil {
|
||
return nil, fmt.Errorf("no route found for hostname %q", state.Hostname)
|
||
}
|
||
req := &evaluator.Request{
|
||
IsInternal: false,
|
||
HTTP: evaluator.RequestHTTP{
|
||
Hostname: route.From, // XXX: this is not quite right
|
||
//IP: ? // TODO
|
||
},
|
||
Session: evaluator.RequestSession{
|
||
ID: sessionID,
|
||
},
|
||
Policy: route,
|
||
}
|
||
return req, nil
|
||
}
|
||
|
||
func handleEvaluatorResponseForSSH(
|
||
result *evaluator.Result, state *StreamState,
|
||
) *extensions_ssh.ServerMessage {
|
||
//fmt.Printf(" *** evaluator result: %+v\n", result)
|
||
|
||
// TODO: ideally there would be a way to keep this in sync with the logic in check_response.go
|
||
allow := result.Allow.Value && !result.Deny.Value
|
||
|
||
// XXX
|
||
sessionID, _ := getSessionIDForSSH(state.PublicKey)
|
||
|
||
if allow {
|
||
pkData, _ := anypb.New(publicKeyAllowResponse(state.PublicKey))
|
||
return &extensions_ssh.ServerMessage{
|
||
Message: &extensions_ssh.ServerMessage_AuthResponse{
|
||
AuthResponse: &extensions_ssh.AuthenticationResponse{
|
||
Response: &extensions_ssh.AuthenticationResponse_Allow{
|
||
Allow: &extensions_ssh.AllowResponse{
|
||
Username: state.Username,
|
||
Hostname: state.Hostname,
|
||
AllowedMethods: []*extensions_ssh.AllowedMethod{
|
||
{
|
||
Method: "publickey",
|
||
MethodData: pkData,
|
||
},
|
||
{
|
||
Method: "keyboard-interactive",
|
||
},
|
||
},
|
||
//Target: extensions_ssh.Target_Upstream,
|
||
Target: extensions_ssh.Target_Internal,
|
||
SetMetadata: &envoy_config_core_v3.Metadata{
|
||
FilterMetadata: map[string]*structpb.Struct{
|
||
"pomerium.ssh": {
|
||
Fields: map[string]*structpb.Value{
|
||
"pomerium-session-id": structpb.NewStringValue(sessionID),
|
||
},
|
||
},
|
||
},
|
||
},
|
||
},
|
||
},
|
||
},
|
||
},
|
||
}
|
||
}
|
||
|
||
// XXX: do we want to send an equivalent to the "show error details" output
|
||
// in the case of a deny result?
|
||
|
||
// XXX: this is not quite right -- needs to exactly match the last list of methods
|
||
methods := []string{"publickey"}
|
||
if slices.Contains(state.MethodsAuthenticated, "keyboard-interactive") {
|
||
methods = append(methods, "keyboard-interactive")
|
||
}
|
||
|
||
return &extensions_ssh.ServerMessage{
|
||
Message: &extensions_ssh.ServerMessage_AuthResponse{
|
||
AuthResponse: &extensions_ssh.AuthenticationResponse{
|
||
Response: &extensions_ssh.AuthenticationResponse_Deny{
|
||
Deny: &extensions_ssh.DenyResponse{
|
||
Methods: methods,
|
||
},
|
||
},
|
||
},
|
||
},
|
||
}
|
||
}
|
||
|
||
func publicKeyAllowResponse(publicKey []byte) *extensions_ssh.PublicKeyAllowResponse {
|
||
return &extensions_ssh.PublicKeyAllowResponse{
|
||
PublicKey: publicKey,
|
||
Permissions: &extensions_ssh.Permissions{
|
||
PermitPortForwarding: true,
|
||
PermitAgentForwarding: true,
|
||
PermitX11Forwarding: true,
|
||
PermitPty: true,
|
||
PermitUserRc: true,
|
||
ValidBefore: timestamppb.New(time.Now().Add(-1 * time.Minute)),
|
||
// XXX: tie this to Pomerium session lifetime?
|
||
ValidAfter: timestamppb.New(time.Now().Add(12 * time.Hour)),
|
||
},
|
||
}
|
||
}
|
||
|
||
// PersistSession stores session and user data in the databroker.
|
||
func (a *Authorize) PersistSession(
|
||
ctx context.Context,
|
||
sessionState *sessions.State, // XXX: consider not using this struct
|
||
claims identity.SessionClaims,
|
||
accessToken *oauth2.Token,
|
||
) error {
|
||
now := time.Now()
|
||
sessionLifetime := a.currentOptions.Load().CookieExpire
|
||
sessionExpiry := timestamppb.New(now.Add(sessionLifetime))
|
||
|
||
sess := &session.Session{
|
||
Id: sessionState.ID,
|
||
UserId: sessionState.UserID(),
|
||
IssuedAt: timestamppb.New(now),
|
||
AccessedAt: timestamppb.New(now),
|
||
ExpiresAt: sessionExpiry,
|
||
OauthToken: manager.ToOAuthToken(accessToken),
|
||
Audience: sessionState.Audience,
|
||
}
|
||
sess.SetRawIDToken(claims.RawIDToken)
|
||
sess.AddClaims(claims.Flatten())
|
||
|
||
// XXX: do we need to create a user record too?
|
||
// compare with Stateful.PersistSession()
|
||
|
||
res, err := session.Put(ctx, a.GetDataBrokerServiceClient(), sess)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
sessionState.DatabrokerServerVersion = res.GetServerVersion()
|
||
sessionState.DatabrokerRecordVersion = res.GetRecord().GetVersion()
|
||
|
||
return nil
|
||
}
|
||
|
||
// SSH connection-protocol message numbers handled by ServeChannel.
// See RFC 4254, sections 5.1 through 5.4.
const (
	msgChannelOpen         = 90
	msgChannelOpenConfirm  = 91
	msgChannelData         = 94
	msgChannelExtendedData = 95
	msgChannelEOF          = 96
	msgChannelClose        = 97
	msgChannelRequest      = 98
	msgChannelSuccess      = 99
	msgChannelFailure      = 100
)

// channelOpenMsg is message 90. See RFC 4254, section 5.1.
type channelOpenMsg struct {
	ChanType         string `sshtype:"90"`
	PeersID          uint32
	PeersWindow      uint32
	MaxPacketSize    uint32
	TypeSpecificData []byte `ssh:"rest"`
}

// channelOpenConfirmMsg is message 91. See RFC 4254, section 5.1.
type channelOpenConfirmMsg struct {
	PeersID          uint32 `sshtype:"91"`
	MyID             uint32
	MyWindow         uint32
	MaxPacketSize    uint32
	TypeSpecificData []byte `ssh:"rest"`
}

// channelDataMsg is message 94.
// Used for debug print outs of packets.
type channelDataMsg struct {
	PeersID uint32 `sshtype:"94"`
	Length  uint32
	Rest    []byte `ssh:"rest"`
}

// channelEOFMsg is message 96. See RFC 4254, section 5.3
type channelEOFMsg struct {
	PeersID uint32 `sshtype:"96"`
}

// channelCloseMsg is message 97. See RFC 4254, section 5.3
type channelCloseMsg struct {
	PeersID uint32 `sshtype:"97"`
}

// channelRequestMsg is message 98.
type channelRequestMsg struct {
	PeersID             uint32 `sshtype:"98"`
	Request             string
	WantReply           bool
	RequestSpecificData []byte `ssh:"rest"`
}

// channelRequestSuccessMsg is message 99. See RFC 4254, section 5.4.
type channelRequestSuccessMsg struct {
	PeersID uint32 `sshtype:"99"`
}

// channelRequestFailureMsg is message 100. See RFC 4254, section 5.4.
type channelRequestFailureMsg struct {
	PeersID uint32 `sshtype:"100"`
}
|
||
|
||
func (a *Authorize) ServeChannel(
|
||
server extensions_ssh.StreamManagement_ServeChannelServer,
|
||
) error {
|
||
//inputR, inputW := io.Pipe()
|
||
//outputR, outputW := io.Pipe()
|
||
var peerId uint32
|
||
|
||
var downstreamChannelInfo *extensions_ssh.SSHDownstreamChannelInfo
|
||
var downstreamPtyInfo *extensions_ssh.SSHDownstreamPTYInfo
|
||
|
||
handedOff := false
|
||
|
||
handoff := func() error {
|
||
handedOff = true
|
||
handOff, _ := anypb.New(&extensions_ssh.SSHChannelControlAction{
|
||
Action: &extensions_ssh.SSHChannelControlAction_HandOff{
|
||
HandOff: &extensions_ssh.SSHChannelControlAction_HandOffUpstream{
|
||
DownstreamChannelInfo: downstreamChannelInfo,
|
||
DownstreamPtyInfo: downstreamPtyInfo,
|
||
UpstreamAuth: &extensions_ssh.AllowResponse{
|
||
// XXX
|
||
Username: "demo",
|
||
Hostname: "ssh",
|
||
},
|
||
},
|
||
},
|
||
})
|
||
fmt.Println(" *** sending handoff request *** ")
|
||
return server.Send(&extensions_ssh.ChannelMessage{
|
||
Message: &extensions_ssh.ChannelMessage_ChannelControl{
|
||
ChannelControl: &extensions_ssh.ChannelControl{
|
||
Protocol: "ssh",
|
||
ControlAction: handOff,
|
||
},
|
||
},
|
||
})
|
||
}
|
||
|
||
for {
|
||
channelMsg, err := server.Recv()
|
||
if err != nil {
|
||
if errors.Is(err, io.EOF) {
|
||
return nil
|
||
}
|
||
return err
|
||
}
|
||
if handedOff {
|
||
continue
|
||
}
|
||
rawMsg := channelMsg.GetRawBytes().GetValue()
|
||
fmt.Printf(" *** channelMsg: %x\n", rawMsg)
|
||
switch rawMsg[0] {
|
||
case msgChannelOpen:
|
||
var msg channelOpenMsg
|
||
gossh.Unmarshal(rawMsg, &msg)
|
||
|
||
var confirm channelOpenConfirmMsg
|
||
peerId = msg.PeersID
|
||
confirm.PeersID = peerId
|
||
confirm.MyID = 1
|
||
confirm.MyWindow = msg.PeersWindow
|
||
confirm.MaxPacketSize = msg.MaxPacketSize
|
||
downstreamChannelInfo = &extensions_ssh.SSHDownstreamChannelInfo{
|
||
ChannelType: msg.ChanType,
|
||
DownstreamChannelId: confirm.PeersID,
|
||
InternalUpstreamChannelId: confirm.MyID,
|
||
InitialWindowSize: confirm.MyWindow,
|
||
MaxPacketSize: confirm.MaxPacketSize,
|
||
}
|
||
if err := server.Send(&extensions_ssh.ChannelMessage{
|
||
Message: &extensions_ssh.ChannelMessage_RawBytes{
|
||
RawBytes: &wrapperspb.BytesValue{
|
||
Value: gossh.Marshal(confirm),
|
||
},
|
||
},
|
||
}); err != nil {
|
||
return err
|
||
}
|
||
|
||
case msgChannelRequest:
|
||
var msg channelRequestMsg
|
||
gossh.Unmarshal(rawMsg, &msg)
|
||
|
||
fmt.Println(" *** SSH_MSG_CHANNEL_REQUEST: ", msg.Request)
|
||
|
||
switch msg.Request {
|
||
case "env":
|
||
// ignore for now
|
||
case "pty-req":
|
||
req := parsePtyReq(msg.RequestSpecificData)
|
||
downstreamPtyInfo = &extensions_ssh.SSHDownstreamPTYInfo{
|
||
TermEnv: req.TermEnv,
|
||
WidthColumns: req.Width,
|
||
HeightRows: req.Height,
|
||
WidthPx: req.WidthPx,
|
||
HeightPx: req.HeightPx,
|
||
Modes: req.Modes,
|
||
}
|
||
if err := server.Send(&extensions_ssh.ChannelMessage{
|
||
Message: &extensions_ssh.ChannelMessage_RawBytes{
|
||
RawBytes: &wrapperspb.BytesValue{
|
||
Value: gossh.Marshal(channelRequestSuccessMsg{
|
||
PeersID: peerId,
|
||
}),
|
||
},
|
||
},
|
||
}); err != nil {
|
||
return err
|
||
}
|
||
if err := handoff(); err != nil {
|
||
return err
|
||
}
|
||
case "subsystem":
|
||
subsystem := parseString(msg.RequestSpecificData)
|
||
command, isInternal := strings.CutPrefix(subsystem, "pomerium")
|
||
if isInternal {
|
||
command = strings.TrimSpace(command)
|
||
if err := server.Send(&extensions_ssh.ChannelMessage{
|
||
Message: &extensions_ssh.ChannelMessage_RawBytes{
|
||
RawBytes: &wrapperspb.BytesValue{
|
||
Value: gossh.Marshal(channelRequestSuccessMsg{
|
||
PeersID: peerId,
|
||
}),
|
||
},
|
||
},
|
||
}); err != nil {
|
||
return err
|
||
}
|
||
return a.serveInternalCommand(server, peerId, command)
|
||
}
|
||
if err := handoff(); err != nil {
|
||
return err
|
||
}
|
||
default:
|
||
// We're not interested in hijacking any other kinds of session.
|
||
if err := handoff(); err != nil {
|
||
return err
|
||
}
|
||
}
|
||
|
||
case msgChannelData:
|
||
var msg channelDataMsg
|
||
gossh.Unmarshal(rawMsg, &msg)
|
||
// ignore any data from the client (for now)
|
||
|
||
case msgChannelClose:
|
||
var msg channelDataMsg
|
||
gossh.Unmarshal(rawMsg, &msg)
|
||
|
||
default:
|
||
panic("unhandled message: " + fmt.Sprint(rawMsg[1]))
|
||
}
|
||
}
|
||
}
|
||
|
||
/*func (a *Authorize) ServeChannel(
|
||
server extensions_ssh.StreamManagement_ServeChannelServer,
|
||
) error {
|
||
var program *tea.Program
|
||
inputR, inputW := io.Pipe()
|
||
outputR, outputW := io.Pipe()
|
||
var peerId uint32
|
||
|
||
var downstreamChannelInfo *extensions_ssh.SSHDownstreamChannelInfo
|
||
var downstreamPtyInfo *extensions_ssh.SSHDownstreamPTYInfo
|
||
for {
|
||
channelMsg, err := server.Recv()
|
||
if err != nil {
|
||
if errors.Is(err, io.EOF) {
|
||
return nil
|
||
}
|
||
return err
|
||
}
|
||
rawMsg := channelMsg.GetRawBytes().GetValue()
|
||
fmt.Printf(" *** channelMsg: %x\n", rawMsg)
|
||
switch rawMsg[0] {
|
||
case msgChannelOpen:
|
||
var msg channelOpenMsg
|
||
gossh.Unmarshal(rawMsg, &msg)
|
||
|
||
var confirm channelOpenConfirmMsg
|
||
peerId = msg.PeersID
|
||
confirm.PeersID = peerId
|
||
confirm.MyID = 1
|
||
confirm.MyWindow = msg.PeersWindow
|
||
confirm.MaxPacketSize = msg.MaxPacketSize
|
||
downstreamChannelInfo = &extensions_ssh.SSHDownstreamChannelInfo{
|
||
ChannelType: msg.ChanType,
|
||
DownstreamChannelId: confirm.PeersID,
|
||
InternalUpstreamChannelId: confirm.MyID,
|
||
InitialWindowSize: confirm.MyWindow,
|
||
MaxPacketSize: confirm.MaxPacketSize,
|
||
}
|
||
if err := server.Send(&extensions_ssh.ChannelMessage{
|
||
Message: &extensions_ssh.ChannelMessage_RawBytes{
|
||
RawBytes: &wrapperspb.BytesValue{
|
||
Value: gossh.Marshal(confirm),
|
||
},
|
||
},
|
||
}); err != nil {
|
||
return err
|
||
}
|
||
|
||
case msgChannelRequest:
|
||
var msg channelRequestMsg
|
||
gossh.Unmarshal(rawMsg, &msg)
|
||
|
||
fmt.Println(" *** SSH_MSG_CHANNEL_REQUEST: ", msg.Request)
|
||
|
||
switch msg.Request {
|
||
case "pty-req":
|
||
req := parsePtyReq(msg.RequestSpecificData)
|
||
items := []list.Item{
|
||
item("ubuntu@vm"),
|
||
item("joe@local"),
|
||
}
|
||
downstreamPtyInfo = &extensions_ssh.SSHDownstreamPTYInfo{
|
||
TermEnv: req.TermEnv,
|
||
WidthColumns: req.Width,
|
||
HeightRows: req.Height,
|
||
WidthPx: req.WidthPx,
|
||
HeightPx: req.HeightPx,
|
||
Modes: req.Modes,
|
||
}
|
||
|
||
const defaultWidth = 20
|
||
|
||
l := list.New(items, itemDelegate{}, defaultWidth, listHeight)
|
||
l.Title = "Connect to which server?"
|
||
l.SetShowStatusBar(false)
|
||
l.SetFilteringEnabled(false)
|
||
l.Styles.Title = titleStyle
|
||
l.Styles.PaginationStyle = paginationStyle
|
||
l.Styles.HelpStyle = helpStyle
|
||
|
||
program = tea.NewProgram(model{list: l},
|
||
tea.WithInput(inputR),
|
||
tea.WithOutput(outputW),
|
||
tea.WithAltScreen(),
|
||
tea.WithContext(server.Context()),
|
||
tea.WithEnvironment([]string{"TERM=" + req.TermEnv}),
|
||
)
|
||
go func() {
|
||
answer, err := program.Run()
|
||
if err != nil {
|
||
return
|
||
}
|
||
username, hostname, _ := strings.Cut(answer.(model).choice, "@")
|
||
handOff, _ := anypb.New(&extensions_ssh.SSHChannelControlAction{
|
||
Action: &extensions_ssh.SSHChannelControlAction_HandOff{
|
||
HandOff: &extensions_ssh.SSHChannelControlAction_HandOffUpstream{
|
||
DownstreamChannelInfo: downstreamChannelInfo,
|
||
DownstreamPtyInfo: downstreamPtyInfo,
|
||
UpstreamAuth: &extensions_ssh.AllowResponse{
|
||
Username: username,
|
||
Hostname: hostname,
|
||
},
|
||
},
|
||
},
|
||
})
|
||
|
||
if err := server.Send(&extensions_ssh.ChannelMessage{
|
||
Message: &extensions_ssh.ChannelMessage_ChannelControl{
|
||
ChannelControl: &extensions_ssh.ChannelControl{
|
||
Protocol: "ssh",
|
||
ControlAction: handOff,
|
||
},
|
||
},
|
||
}); err != nil {
|
||
return
|
||
}
|
||
}()
|
||
go func() {
|
||
var buf [4096]byte
|
||
for {
|
||
n, err := outputR.Read(buf[:])
|
||
if err != nil {
|
||
return
|
||
}
|
||
if err := server.Send(&extensions_ssh.ChannelMessage{
|
||
Message: &extensions_ssh.ChannelMessage_RawBytes{
|
||
RawBytes: &wrapperspb.BytesValue{
|
||
Value: gossh.Marshal(channelDataMsg{
|
||
PeersID: peerId,
|
||
Length: uint32(n),
|
||
Rest: buf[:n],
|
||
}),
|
||
},
|
||
},
|
||
}); err != nil {
|
||
return
|
||
}
|
||
}
|
||
}()
|
||
program.Send(tea.WindowSizeMsg{Width: int(req.Width), Height: int(req.Height)})
|
||
|
||
if err := server.Send(&extensions_ssh.ChannelMessage{
|
||
Message: &extensions_ssh.ChannelMessage_RawBytes{
|
||
RawBytes: &wrapperspb.BytesValue{
|
||
Value: gossh.Marshal(channelRequestSuccessMsg{
|
||
PeersID: peerId,
|
||
}),
|
||
},
|
||
},
|
||
}); err != nil {
|
||
return err
|
||
}
|
||
}
|
||
case msgChannelData:
|
||
var msg channelDataMsg
|
||
gossh.Unmarshal(rawMsg, &msg)
|
||
|
||
if program != nil {
|
||
inputW.Write(msg.Rest)
|
||
}
|
||
case msgChannelClose:
|
||
var msg channelDataMsg
|
||
gossh.Unmarshal(rawMsg, &msg)
|
||
default:
|
||
panic("unhandled message: " + fmt.Sprint(rawMsg[1]))
|
||
}
|
||
}
|
||
}*/
|
||
|
||
// whoamiTemplateText is the text/template source rendered by the "whoami"
// internal command. It reads UserId, Id, ExpiresAt.AsTime and, for every
// entry of Claims, the value's AsSlice.
const whoamiTemplateText = `
User ID: {{.UserId}}
Session ID: {{.Id}}
Expires at: {{.ExpiresAt.AsTime}}
Claims:
{{- range $k, $v := .Claims }}
{{ $k }}: {{ $v.AsSlice }}
{{- end }}
`

// whoamiTmpl is the parsed form of whoamiTemplateText. Parsing happens once,
// at package initialisation, and panics if the template is invalid.
var whoamiTmpl = template.Must(template.New("whoami").Parse(whoamiTemplateText))
|
||
|
||
func (a *Authorize) serveInternalCommand(
|
||
server extensions_ssh.StreamManagement_ServeChannelServer,
|
||
peerID uint32,
|
||
command string,
|
||
) error {
|
||
md, ok := metadata.FromIncomingContext(server.Context())
|
||
fmt.Println("metadata.FromIncomingContext: ", md, ok)
|
||
|
||
var sessionID string
|
||
if h := md.Get("pomerium-session-id"); len(h) == 1 {
|
||
sessionID = h[0]
|
||
}
|
||
|
||
var output string
|
||
|
||
switch command {
|
||
case "logout":
|
||
client := a.state.Load().dataBrokerClient
|
||
err := session.Delete(server.Context(), client, sessionID)
|
||
if err != nil {
|
||
output = fmt.Sprint("internal error: ", err.Error())
|
||
} else {
|
||
output = "logged out\n"
|
||
}
|
||
case "whoami":
|
||
client := a.state.Load().dataBrokerClient
|
||
s, err := session.Get(server.Context(), client, sessionID)
|
||
if err != nil {
|
||
output = fmt.Sprint("couldn't fetch session: ", err.Error())
|
||
} else {
|
||
var b bytes.Buffer
|
||
whoamiTmpl.Execute(&b, s)
|
||
output = b.String()
|
||
}
|
||
default:
|
||
output = `available commands:
|
||
logout - ends the current Pomerium session
|
||
whoami - returns information about the current Pomerium session
|
||
`
|
||
}
|
||
|
||
if err := server.Send(&extensions_ssh.ChannelMessage{
|
||
Message: &extensions_ssh.ChannelMessage_RawBytes{
|
||
RawBytes: &wrapperspb.BytesValue{
|
||
Value: gossh.Marshal(channelDataMsg{
|
||
PeersID: peerID,
|
||
Length: uint32(len(output)),
|
||
Rest: []byte(output),
|
||
}),
|
||
},
|
||
},
|
||
}); err != nil {
|
||
return err
|
||
}
|
||
|
||
time.Sleep(time.Second) // XXX
|
||
|
||
return nil
|
||
}
|
||
|
||
// ptyReq holds the fields of an SSH "pty-req" channel request
// (RFC 4254, section 6.2).
type ptyReq struct {
	// TermEnv is the TERM environment variable value.
	TermEnv string
	// Width and Height are the terminal dimensions in characters.
	Width, Height uint32
	// WidthPx and HeightPx are the terminal dimensions in pixels.
	WidthPx, HeightPx uint32
	// Modes is everything that follows the four dimension fields, unparsed.
	Modes []byte
}

// cutSSHString splits data into a leading SSH string (a big-endian uint32
// length followed by that many bytes) and the remaining bytes. ok is false
// when data is too short for the length prefix or for the length it declares.
func cutSSHString(data []byte) (value string, rest []byte, ok bool) {
	if len(data) < 4 {
		return "", nil, false
	}
	n := binary.BigEndian.Uint32(data)
	data = data[4:]
	// Compare in 64 bits so a huge declared length cannot wrap on any platform.
	if uint64(n) > uint64(len(data)) {
		return "", nil, false
	}
	return string(data[:n]), data[n:], true
}

// parseString decodes the SSH string at the start of reqData. It returns ""
// when reqData is truncated or its length prefix exceeds the available bytes;
// the input comes from the SSH client and must not be able to cause a panic.
func parseString(reqData []byte) string {
	value, _, ok := cutSSHString(reqData)
	if !ok {
		return ""
	}
	return value
}

// parsePtyReq decodes the request-specific data of a "pty-req" channel
// request: the TERM string, four uint32 dimensions, then the terminal modes.
// It returns the zero ptyReq when reqData is truncated or malformed.
func parsePtyReq(reqData []byte) ptyReq {
	termEnv, rest, ok := cutSSHString(reqData)
	if !ok || len(rest) < 16 {
		return ptyReq{}
	}
	return ptyReq{
		TermEnv:  termEnv,
		Width:    binary.BigEndian.Uint32(rest),
		Height:   binary.BigEndian.Uint32(rest[4:]),
		WidthPx:  binary.BigEndian.Uint32(rest[8:]),
		HeightPx: binary.BigEndian.Uint32(rest[12:]),
		Modes:    rest[16:],
	}
}
|
||
|
||
const listHeight = 14
|
||
|
||
var (
|
||
titleStyle = lipgloss.NewStyle().MarginLeft(2)
|
||
itemStyle = lipgloss.NewStyle().PaddingLeft(4)
|
||
selectedItemStyle = lipgloss.NewStyle().PaddingLeft(2).Foreground(lipgloss.Color("170"))
|
||
paginationStyle = list.DefaultStyles().PaginationStyle.PaddingLeft(4)
|
||
helpStyle = list.DefaultStyles().HelpStyle.PaddingLeft(4).PaddingBottom(1)
|
||
quitTextStyle = lipgloss.NewStyle().Margin(1, 0, 2, 4)
|
||
)
|
||
|
||
type item string
|
||
|
||
func (i item) FilterValue() string { return "" }
|
||
|
||
type itemDelegate struct{}
|
||
|
||
func (d itemDelegate) Height() int { return 1 }
|
||
func (d itemDelegate) Spacing() int { return 0 }
|
||
func (d itemDelegate) Update(_ tea.Msg, _ *list.Model) tea.Cmd { return nil }
|
||
func (d itemDelegate) Render(w io.Writer, m list.Model, index int, listItem list.Item) {
|
||
i, ok := listItem.(item)
|
||
if !ok {
|
||
return
|
||
}
|
||
|
||
str := fmt.Sprintf("%d. %s", index+1, i)
|
||
|
||
fn := itemStyle.Render
|
||
if index == m.Index() {
|
||
fn = func(s ...string) string {
|
||
return selectedItemStyle.Render("> " + strings.Join(s, " "))
|
||
}
|
||
}
|
||
|
||
fmt.Fprint(w, fn(str))
|
||
}
|
||
|
||
type model struct {
|
||
list list.Model
|
||
choice string
|
||
quitting bool
|
||
}
|
||
|
||
func (m model) Init() tea.Cmd {
|
||
return nil
|
||
}
|
||
|
||
func (m model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||
switch msg := msg.(type) {
|
||
case tea.WindowSizeMsg:
|
||
m.list.SetWidth(msg.Width)
|
||
return m, nil
|
||
|
||
case tea.KeyMsg:
|
||
switch keypress := msg.String(); keypress {
|
||
case "q", "ctrl+c":
|
||
m.quitting = true
|
||
return m, tea.Quit
|
||
|
||
case "enter":
|
||
i, ok := m.list.SelectedItem().(item)
|
||
if ok {
|
||
m.choice = string(i)
|
||
}
|
||
return m, tea.Quit
|
||
}
|
||
}
|
||
|
||
var cmd tea.Cmd
|
||
m.list, cmd = m.list.Update(msg)
|
||
return m, cmd
|
||
}
|
||
|
||
func (m model) View() string {
|
||
if m.choice != "" {
|
||
return quitTextStyle.Render(fmt.Sprintf("%s? Sounds good to me.", m.choice))
|
||
}
|
||
if m.quitting {
|
||
return quitTextStyle.Render("Not hungry? That’s cool.")
|
||
}
|
||
return "\n" + m.list.View()
|
||
}
|