mirror of
https://github.com/pomerium/pomerium.git
synced 2025-08-06 10:21:05 +02:00
refactor session state
This commit is contained in:
parent
3225d3b032
commit
315ee2610f
2 changed files with 118 additions and 59 deletions
|
@ -6,7 +6,6 @@ import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"slices"
|
"slices"
|
||||||
"sync"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/rs/zerolog"
|
"github.com/rs/zerolog"
|
||||||
|
@ -45,8 +44,7 @@ type Authorize struct {
|
||||||
tracerProvider oteltrace.TracerProvider
|
tracerProvider oteltrace.TracerProvider
|
||||||
tracer oteltrace.Tracer
|
tracer oteltrace.Tracer
|
||||||
|
|
||||||
activeStreamsMu sync.Mutex
|
activeStreams ActiveStreams
|
||||||
activeStreams []chan error
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// New validates and creates a new Authorize service from a set of config options.
|
// New validates and creates a new Authorize service from a set of config options.
|
||||||
|
@ -59,7 +57,9 @@ func New(ctx context.Context, cfg *config.Config) (*Authorize, error) {
|
||||||
globalCache: storage.NewGlobalCache(time.Minute),
|
globalCache: storage.NewGlobalCache(time.Minute),
|
||||||
tracerProvider: tracerProvider,
|
tracerProvider: tracerProvider,
|
||||||
tracer: tracer,
|
tracer: tracer,
|
||||||
activeStreams: []chan error{},
|
activeStreams: ActiveStreams{
|
||||||
|
streamsById: map[uint64]*StreamState{},
|
||||||
|
},
|
||||||
}
|
}
|
||||||
a.accessTracker = NewAccessTracker(a, accessTrackerMaxSize, accessTrackerDebouncePeriod)
|
a.accessTracker = NewAccessTracker(a, accessTrackerMaxSize, accessTrackerDebouncePeriod)
|
||||||
|
|
||||||
|
@ -167,15 +167,6 @@ func newPolicyEvaluator(
|
||||||
|
|
||||||
// OnConfigChange updates internal structures based on config.Options
|
// OnConfigChange updates internal structures based on config.Options
|
||||||
func (a *Authorize) OnConfigChange(ctx context.Context, cfg *config.Config) {
|
func (a *Authorize) OnConfigChange(ctx context.Context, cfg *config.Config) {
|
||||||
a.activeStreamsMu.Lock()
|
|
||||||
// demo code
|
|
||||||
if cfg.Options.Routes[0].AllowAnyAuthenticatedUser == false {
|
|
||||||
for _, s := range a.activeStreams {
|
|
||||||
s <- fmt.Errorf("no longer authorized")
|
|
||||||
}
|
|
||||||
clear(a.activeStreams)
|
|
||||||
}
|
|
||||||
a.activeStreamsMu.Unlock()
|
|
||||||
currentState := a.state.Load()
|
currentState := a.state.Load()
|
||||||
a.currentConfig.Store(cfg)
|
a.currentConfig.Store(cfg)
|
||||||
if newState, err := newAuthorizeStateFromConfig(ctx, a.tracerProvider, cfg, a.store, currentState.evaluator); err != nil {
|
if newState, err := newAuthorizeStateFromConfig(ctx, a.tracerProvider, cfg, a.store, currentState.evaluator); err != nil {
|
||||||
|
|
|
@ -51,11 +51,46 @@ import (
|
||||||
"google.golang.org/protobuf/types/known/wrapperspb"
|
"google.golang.org/protobuf/types/known/wrapperspb"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type ActiveStreams struct {
|
||||||
|
mu sync.Mutex
|
||||||
|
streamsById map[uint64]*StreamState
|
||||||
|
}
|
||||||
|
|
||||||
type StreamState struct {
|
type StreamState struct {
|
||||||
|
Context context.Context
|
||||||
|
StreamID uint64
|
||||||
|
ErrorC chan<- error
|
||||||
Username string
|
Username string
|
||||||
Hostname string
|
Hostname string
|
||||||
PublicKey []byte
|
PublicKey []byte
|
||||||
MethodsAuthenticated []string
|
MethodsAuthenticated []string
|
||||||
|
Session *session.Session
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *ActiveStreams) Get(id uint64) *StreamState {
|
||||||
|
a.mu.Lock()
|
||||||
|
defer a.mu.Unlock()
|
||||||
|
return a.streamsById[id]
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *ActiveStreams) Put(id uint64, state *StreamState) {
|
||||||
|
a.mu.Lock()
|
||||||
|
defer a.mu.Unlock()
|
||||||
|
a.streamsById[id] = state
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *ActiveStreams) Delete(id uint64) {
|
||||||
|
a.mu.Lock()
|
||||||
|
defer a.mu.Unlock()
|
||||||
|
delete(a.streamsById, id)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *ActiveStreams) Range(f func(id uint64, state *StreamState)) {
|
||||||
|
a.mu.Lock()
|
||||||
|
defer a.mu.Unlock()
|
||||||
|
for id, state := range a.streamsById {
|
||||||
|
f(id, state)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Authorize) RecordingFinalized(
|
func (a *Authorize) RecordingFinalized(
|
||||||
|
@ -126,8 +161,6 @@ READ:
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
var activeStreamIds sync.Map
|
|
||||||
|
|
||||||
func (a *Authorize) ManageStream(
|
func (a *Authorize) ManageStream(
|
||||||
server extensions_ssh.StreamManagement_ManageStreamServer,
|
server extensions_ssh.StreamManagement_ManageStreamServer,
|
||||||
) error {
|
) error {
|
||||||
|
@ -171,15 +204,13 @@ func (a *Authorize) ManageStream(
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
var state StreamState
|
errC := make(chan error, 1)
|
||||||
|
state := &StreamState{
|
||||||
|
Context: ctx,
|
||||||
|
ErrorC: errC,
|
||||||
|
}
|
||||||
|
|
||||||
deviceAuthDone := make(chan struct{})
|
deviceAuthDone := make(chan struct{})
|
||||||
sessionState := &atomic.Pointer[sessions.State]{}
|
|
||||||
|
|
||||||
errC := make(chan error, 1)
|
|
||||||
a.activeStreamsMu.Lock()
|
|
||||||
a.activeStreams = append(a.activeStreams, errC)
|
|
||||||
a.activeStreamsMu.Unlock()
|
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case err := <-errC:
|
case err := <-errC:
|
||||||
|
@ -192,10 +223,14 @@ func (a *Authorize) ManageStream(
|
||||||
case *extensions_ssh.ClientMessage_Event:
|
case *extensions_ssh.ClientMessage_Event:
|
||||||
switch event := req.Event.Event.(type) {
|
switch event := req.Event.Event.(type) {
|
||||||
case *extensions_ssh.StreamEvent_DownstreamConnected:
|
case *extensions_ssh.StreamEvent_DownstreamConnected:
|
||||||
_ = event
|
id := event.DownstreamConnected.StreamId
|
||||||
|
if id == 0 {
|
||||||
|
return fmt.Errorf("invalid stream ID: %v", id)
|
||||||
|
}
|
||||||
|
state.StreamID = id
|
||||||
|
a.activeStreams.Put(id, state)
|
||||||
|
defer a.activeStreams.Delete(id)
|
||||||
case *extensions_ssh.StreamEvent_UpstreamConnected:
|
case *extensions_ssh.StreamEvent_UpstreamConnected:
|
||||||
activeStreamIds.Store(event.UpstreamConnected.GetStreamId(), state)
|
|
||||||
defer activeStreamIds.Delete(event.UpstreamConnected.GetStreamId())
|
|
||||||
case nil:
|
case nil:
|
||||||
}
|
}
|
||||||
case *extensions_ssh.ClientMessage_AuthRequest:
|
case *extensions_ssh.ClientMessage_AuthRequest:
|
||||||
|
@ -230,8 +265,9 @@ func (a *Authorize) ManageStream(
|
||||||
}
|
}
|
||||||
|
|
||||||
if session != nil {
|
if session != nil {
|
||||||
|
state.Session = session
|
||||||
// Perform authorize check for this route
|
// Perform authorize check for this route
|
||||||
req, err := a.getEvaluatorRequestFromSSHAuthRequest(&state)
|
req, err := a.getEvaluatorRequestFromSSHAuthRequest(state)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -239,7 +275,7 @@ func (a *Authorize) ManageStream(
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
sendC <- handleEvaluatorResponseForSSH(res, &state, session.Id)
|
sendC <- handleEvaluatorResponseForSSH(res, state)
|
||||||
|
|
||||||
if res.Allow.Value && !res.Deny.Value {
|
if res.Allow.Value && !res.Deny.Value {
|
||||||
a.startContinuousAuthorization(ctx, errC, req, session.Id)
|
a.startContinuousAuthorization(ctx, errC, req, session.Id)
|
||||||
|
@ -263,9 +299,7 @@ func (a *Authorize) ManageStream(
|
||||||
}
|
}
|
||||||
case "keyboard-interactive":
|
case "keyboard-interactive":
|
||||||
route := a.getSSHRouteForHostname(state.Hostname)
|
route := a.getSSHRouteForHostname(state.Hostname)
|
||||||
// if route == nil {
|
// route can be nil, in which case the default idp will be used
|
||||||
// return fmt.Errorf("invalid route")
|
|
||||||
// }
|
|
||||||
|
|
||||||
opts := a.currentConfig.Load().Options
|
opts := a.currentConfig.Load().Options
|
||||||
idp, err := opts.GetIdentityProviderForPolicy(route)
|
idp, err := opts.GetIdentityProviderForPolicy(route)
|
||||||
|
@ -292,9 +326,7 @@ func (a *Authorize) ManageStream(
|
||||||
infoReq := extensions_ssh.KeyboardInteractiveInfoPrompts{
|
infoReq := extensions_ssh.KeyboardInteractiveInfoPrompts{
|
||||||
Name: "Sign in with " + idp.GetType(),
|
Name: "Sign in with " + idp.GetType(),
|
||||||
Instruction: deviceAuthResp.VerificationURIComplete,
|
Instruction: deviceAuthResp.VerificationURIComplete,
|
||||||
Prompts: []*extensions_ssh.KeyboardInteractiveInfoPrompts_Prompt{
|
Prompts: []*extensions_ssh.KeyboardInteractiveInfoPrompts_Prompt{},
|
||||||
// {},
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
infoReqAny, _ := anypb.New(&infoReq)
|
infoReqAny, _ := anypb.New(&infoReq)
|
||||||
|
@ -328,13 +360,12 @@ func (a *Authorize) ManageStream(
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
fmt.Println(token)
|
fmt.Println(token)
|
||||||
err = a.PersistSession(ctx, s, claims, token)
|
state.Session, err = a.PersistSession(ctx, s, claims, token)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Println("error from PersistSession:", err)
|
fmt.Println("error from PersistSession:", err)
|
||||||
errC <- fmt.Errorf("error persisting session: %w", err)
|
errC <- fmt.Errorf("error persisting session: %w", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
sessionState.Store(s)
|
|
||||||
close(deviceAuthDone)
|
close(deviceAuthDone)
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
@ -351,7 +382,7 @@ func (a *Authorize) ManageStream(
|
||||||
case <-deviceAuthDone:
|
case <-deviceAuthDone:
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
}
|
}
|
||||||
if sessionState.Load() != nil {
|
if state.Session != nil {
|
||||||
state.MethodsAuthenticated = append(state.MethodsAuthenticated, "keyboard-interactive")
|
state.MethodsAuthenticated = append(state.MethodsAuthenticated, "keyboard-interactive")
|
||||||
} else {
|
} else {
|
||||||
resp := extensions_ssh.ServerMessage{
|
resp := extensions_ssh.ServerMessage{
|
||||||
|
@ -369,18 +400,18 @@ func (a *Authorize) ManageStream(
|
||||||
|
|
||||||
if slices.Contains(state.MethodsAuthenticated, "publickey") {
|
if slices.Contains(state.MethodsAuthenticated, "publickey") {
|
||||||
// Perform authorize check for this route
|
// Perform authorize check for this route
|
||||||
req, err := a.getEvaluatorRequestFromSSHAuthRequest(&state)
|
req, err := a.getEvaluatorRequestFromSSHAuthRequest(state)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
res, err := a.evaluate(ctx, req, sessionState.Load())
|
res, err := a.evaluate(ctx, req, &sessions.State{ID: state.Session.Id})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
sendC <- handleEvaluatorResponseForSSH(res, &state, sessionState.Load().ID)
|
sendC <- handleEvaluatorResponseForSSH(res, state)
|
||||||
|
|
||||||
if res.Allow.Value && !res.Deny.Value {
|
if res.Allow.Value && !res.Deny.Value {
|
||||||
a.startContinuousAuthorization(ctx, errC, req, sessionState.Load().ID)
|
a.startContinuousAuthorization(ctx, errC, req, state.Session.Id)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
resp := extensions_ssh.ServerMessage{
|
resp := extensions_ssh.ServerMessage{
|
||||||
|
@ -476,7 +507,8 @@ func (a *Authorize) getEvaluatorRequestFromSSHAuthRequest(
|
||||||
}
|
}
|
||||||
|
|
||||||
func handleEvaluatorResponseForSSH(
|
func handleEvaluatorResponseForSSH(
|
||||||
result *evaluator.Result, state *StreamState, sessionID string,
|
result *evaluator.Result,
|
||||||
|
state *StreamState,
|
||||||
) *extensions_ssh.ServerMessage {
|
) *extensions_ssh.ServerMessage {
|
||||||
// fmt.Printf(" *** evaluator result: %+v\n", result)
|
// fmt.Printf(" *** evaluator result: %+v\n", result)
|
||||||
|
|
||||||
|
@ -499,7 +531,7 @@ func handleEvaluatorResponseForSSH(
|
||||||
FilterMetadata: map[string]*structpb.Struct{
|
FilterMetadata: map[string]*structpb.Struct{
|
||||||
"pomerium": {
|
"pomerium": {
|
||||||
Fields: map[string]*structpb.Value{
|
Fields: map[string]*structpb.Value{
|
||||||
"session-id": structpb.NewStringValue(sessionID),
|
"stream-id": structpb.NewStringValue(strconv.FormatUint(state.StreamID, 10)),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
@ -592,7 +624,7 @@ func (a *Authorize) PersistSession(
|
||||||
sessionState *sessions.State, // XXX: consider not using this struct
|
sessionState *sessions.State, // XXX: consider not using this struct
|
||||||
claims identity.SessionClaims,
|
claims identity.SessionClaims,
|
||||||
accessToken *oauth2.Token,
|
accessToken *oauth2.Token,
|
||||||
) error {
|
) (*session.Session, error) {
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
sessionLifetime := a.currentConfig.Load().Options.CookieExpire
|
sessionLifetime := a.currentConfig.Load().Options.CookieExpire
|
||||||
sessionExpiry := timestamppb.New(now.Add(sessionLifetime))
|
sessionExpiry := timestamppb.New(now.Add(sessionLifetime))
|
||||||
|
@ -614,12 +646,12 @@ func (a *Authorize) PersistSession(
|
||||||
|
|
||||||
res, err := session.Put(ctx, a.GetDataBrokerServiceClient(), sess)
|
res, err := session.Put(ctx, a.GetDataBrokerServiceClient(), sess)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return nil, err
|
||||||
}
|
}
|
||||||
sessionState.DatabrokerServerVersion = res.GetServerVersion()
|
sessionState.DatabrokerServerVersion = res.GetServerVersion()
|
||||||
sessionState.DatabrokerRecordVersion = res.GetRecord().GetVersion()
|
sessionState.DatabrokerRecordVersion = res.GetRecord().GetVersion()
|
||||||
|
|
||||||
return nil
|
return sess, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Authorize) startContinuousAuthorization(
|
func (a *Authorize) startContinuousAuthorization(
|
||||||
|
@ -731,7 +763,7 @@ func (a *Authorize) ServeChannel(
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
var sessionID atomic.Pointer[string]
|
var state *StreamState
|
||||||
go func() {
|
go func() {
|
||||||
localWindow := uint32(channelWindowSize)
|
localWindow := uint32(channelWindowSize)
|
||||||
for {
|
for {
|
||||||
|
@ -744,14 +776,28 @@ func (a *Authorize) ServeChannel(
|
||||||
errC <- err
|
errC <- err
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if sessionID.Load() == nil {
|
if state == nil {
|
||||||
mdMsg, ok := channelMsg.Message.(*extensions_ssh.ChannelMessage_Metadata)
|
mdMsg, ok := channelMsg.Message.(*extensions_ssh.ChannelMessage_Metadata)
|
||||||
if !ok {
|
if !ok {
|
||||||
errC <- fmt.Errorf("first message was not metadata")
|
errC <- fmt.Errorf("first message was not metadata")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
id := mdMsg.Metadata.FilterMetadata["pomerium"].Fields["session-id"].GetStringValue()
|
idStr := mdMsg.Metadata.FilterMetadata["pomerium"].Fields["stream-id"].GetStringValue()
|
||||||
sessionID.Store(&id)
|
if idStr == "" {
|
||||||
|
errC <- fmt.Errorf("no session ID found for stream %q", idStr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
id, err := strconv.ParseUint(idStr, 10, 64)
|
||||||
|
if err != nil {
|
||||||
|
errC <- fmt.Errorf("invalid stream ID %q: %w", idStr, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if v := a.activeStreams.Get(id); v != nil {
|
||||||
|
state = v
|
||||||
|
} else {
|
||||||
|
errC <- fmt.Errorf("no stream state found for ID %d", id)
|
||||||
|
return
|
||||||
|
}
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if raw, ok := channelMsg.Message.(*extensions_ssh.ChannelMessage_RawBytes); ok {
|
if raw, ok := channelMsg.Message.(*extensions_ssh.ChannelMessage_RawBytes); ok {
|
||||||
|
@ -877,7 +923,7 @@ func (a *Authorize) ServeChannel(
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
cmd := a.NewSSHCLI(a.currentConfig.Load(), downstreamPtyInfo, downstreamChannelInfo, *sessionID.Load(), inputR, outputW, sendC, &activeProgram)
|
cmd := a.NewSSHCLI(a.currentConfig.Load(), downstreamPtyInfo, downstreamChannelInfo, state, inputR, outputW, sendC, &activeProgram)
|
||||||
if msg.Request == "shell" {
|
if msg.Request == "shell" {
|
||||||
cmd.SetArgs([]string{"portal"})
|
cmd.SetArgs([]string{"portal"})
|
||||||
} else {
|
} else {
|
||||||
|
@ -995,10 +1041,10 @@ func (a *Authorize) NewSSHCLI(
|
||||||
cfg *config.Config,
|
cfg *config.Config,
|
||||||
ptyInfo *extensions_ssh.SSHDownstreamPTYInfo,
|
ptyInfo *extensions_ssh.SSHDownstreamPTYInfo,
|
||||||
channelInfo *extensions_ssh.SSHDownstreamChannelInfo,
|
channelInfo *extensions_ssh.SSHDownstreamChannelInfo,
|
||||||
sessionID string,
|
state *StreamState,
|
||||||
stdin io.Reader,
|
stdin io.Reader,
|
||||||
stdout io.Writer,
|
stdout io.Writer,
|
||||||
sendC chan any,
|
sendC chan<- any,
|
||||||
activeProgram *atomic.Pointer[tea.Program],
|
activeProgram *atomic.Pointer[tea.Program],
|
||||||
) *cobra.Command {
|
) *cobra.Command {
|
||||||
cmd := &cobra.Command{
|
cmd := &cobra.Command{
|
||||||
|
@ -1016,7 +1062,8 @@ func (a *Authorize) NewSSHCLI(
|
||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
cmd.AddCommand(NewPortalCommand(cfg, ptyInfo, channelInfo, sendC, activeProgram))
|
sessionID := state.Session.Id
|
||||||
|
cmd.AddCommand(a.NewPortalCommand(cfg, ptyInfo, channelInfo, state, sendC, activeProgram))
|
||||||
cmd.AddCommand(a.NewLogoutCommand(cfg, sessionID))
|
cmd.AddCommand(a.NewLogoutCommand(cfg, sessionID))
|
||||||
cmd.AddCommand(a.NewWhoamiCommand(cfg, sessionID))
|
cmd.AddCommand(a.NewWhoamiCommand(cfg, sessionID))
|
||||||
cmd.CompletionOptions.DisableDefaultCmd = true
|
cmd.CompletionOptions.DisableDefaultCmd = true
|
||||||
|
@ -1077,11 +1124,12 @@ func (a *Authorize) NewWhoamiCommand(
|
||||||
return cmd
|
return cmd
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewPortalCommand(
|
func (a *Authorize) NewPortalCommand(
|
||||||
cfg *config.Config,
|
cfg *config.Config,
|
||||||
ptyInfo *extensions_ssh.SSHDownstreamPTYInfo,
|
ptyInfo *extensions_ssh.SSHDownstreamPTYInfo,
|
||||||
channelInfo *extensions_ssh.SSHDownstreamChannelInfo,
|
channelInfo *extensions_ssh.SSHDownstreamChannelInfo,
|
||||||
sendC chan any,
|
state *StreamState,
|
||||||
|
sendC chan<- any,
|
||||||
activeProgram *atomic.Pointer[tea.Program],
|
activeProgram *atomic.Pointer[tea.Program],
|
||||||
) *cobra.Command {
|
) *cobra.Command {
|
||||||
cmd := &cobra.Command{
|
cmd := &cobra.Command{
|
||||||
|
@ -1094,16 +1142,17 @@ func NewPortalCommand(
|
||||||
var routes []string
|
var routes []string
|
||||||
for r := range cfg.Options.GetAllPolicies() {
|
for r := range cfg.Options.GetAllPolicies() {
|
||||||
if strings.HasPrefix(r.From, "ssh://") {
|
if strings.HasPrefix(r.From, "ssh://") {
|
||||||
routes = append(routes, fmt.Sprintf("ubuntu@%s", strings.TrimSuffix(strings.TrimPrefix(r.From, "ssh://"), "."+cfg.Options.SSHHostname)))
|
routes = append(routes, fmt.Sprintf("%s@%s", state.Username, strings.TrimSuffix(strings.TrimPrefix(r.From, "ssh://"), "."+cfg.Options.SSHHostname)))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
items := []list.Item{}
|
items := []list.Item{}
|
||||||
for _, route := range routes {
|
for _, route := range routes {
|
||||||
items = append(items, item(route))
|
items = append(items, item(route))
|
||||||
}
|
}
|
||||||
activeStreamIds.Range(func(key, value any) bool {
|
a.activeStreams.Range(func(id uint64, state *StreamState) {
|
||||||
items = append(items, item(fmt.Sprintf("[demo] mirror session: %v", key)))
|
if id != state.StreamID {
|
||||||
return true
|
items = append(items, item(fmt.Sprintf("[demo] mirror session: %v", id)))
|
||||||
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
l := list.New(items, itemDelegate{}, int(ptyInfo.WidthColumns-2), int(ptyInfo.HeightRows-2))
|
l := list.New(items, itemDelegate{}, int(ptyInfo.WidthColumns-2), int(ptyInfo.HeightRows-2))
|
||||||
|
@ -1152,6 +1201,25 @@ func NewPortalCommand(
|
||||||
})
|
})
|
||||||
} else {
|
} else {
|
||||||
username, hostname, _ := strings.Cut(answer.(model).choice, "@")
|
username, hostname, _ := strings.Cut(answer.(model).choice, "@")
|
||||||
|
// Perform authorize check for this route
|
||||||
|
state.Hostname = hostname
|
||||||
|
if username != state.Username {
|
||||||
|
return fmt.Errorf("internal error: username mismatch")
|
||||||
|
}
|
||||||
|
req, err := a.getEvaluatorRequestFromSSHAuthRequest(state)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
res, err := a.evaluate(cmd.Context(), req, &sessions.State{ID: state.Session.Id})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if res.Allow.Value && !res.Deny.Value {
|
||||||
|
a.startContinuousAuthorization(state.Context, state.ErrorC, req, state.Session.Id)
|
||||||
|
} else {
|
||||||
|
return fmt.Errorf("not authorized")
|
||||||
|
}
|
||||||
sessionRecordingExt, _ := anypb.New(&extensions_session_recording.UpstreamTargetExtensionConfig{
|
sessionRecordingExt, _ := anypb.New(&extensions_session_recording.UpstreamTargetExtensionConfig{
|
||||||
RecordingName: fmt.Sprintf("session-%s-at-%s-%d.cast", username, hostname, time.Now().UnixNano()),
|
RecordingName: fmt.Sprintf("session-%s-at-%s-%d.cast", username, hostname, time.Now().UnixNano()),
|
||||||
Format: extensions_session_recording.Format_AsciicastFormat,
|
Format: extensions_session_recording.Format_AsciicastFormat,
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue