mirror of
https://github.com/pomerium/pomerium.git
synced 2025-04-29 10:26:29 +02:00
refactor session state
This commit is contained in:
parent
3225d3b032
commit
315ee2610f
2 changed files with 118 additions and 59 deletions
|
@ -6,7 +6,6 @@ import (
|
|||
"context"
|
||||
"fmt"
|
||||
"slices"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
|
@ -45,8 +44,7 @@ type Authorize struct {
|
|||
tracerProvider oteltrace.TracerProvider
|
||||
tracer oteltrace.Tracer
|
||||
|
||||
activeStreamsMu sync.Mutex
|
||||
activeStreams []chan error
|
||||
activeStreams ActiveStreams
|
||||
}
|
||||
|
||||
// New validates and creates a new Authorize service from a set of config options.
|
||||
|
@ -59,7 +57,9 @@ func New(ctx context.Context, cfg *config.Config) (*Authorize, error) {
|
|||
globalCache: storage.NewGlobalCache(time.Minute),
|
||||
tracerProvider: tracerProvider,
|
||||
tracer: tracer,
|
||||
activeStreams: []chan error{},
|
||||
activeStreams: ActiveStreams{
|
||||
streamsById: map[uint64]*StreamState{},
|
||||
},
|
||||
}
|
||||
a.accessTracker = NewAccessTracker(a, accessTrackerMaxSize, accessTrackerDebouncePeriod)
|
||||
|
||||
|
@ -167,15 +167,6 @@ func newPolicyEvaluator(
|
|||
|
||||
// OnConfigChange updates internal structures based on config.Options
|
||||
func (a *Authorize) OnConfigChange(ctx context.Context, cfg *config.Config) {
|
||||
a.activeStreamsMu.Lock()
|
||||
// demo code
|
||||
if cfg.Options.Routes[0].AllowAnyAuthenticatedUser == false {
|
||||
for _, s := range a.activeStreams {
|
||||
s <- fmt.Errorf("no longer authorized")
|
||||
}
|
||||
clear(a.activeStreams)
|
||||
}
|
||||
a.activeStreamsMu.Unlock()
|
||||
currentState := a.state.Load()
|
||||
a.currentConfig.Store(cfg)
|
||||
if newState, err := newAuthorizeStateFromConfig(ctx, a.tracerProvider, cfg, a.store, currentState.evaluator); err != nil {
|
||||
|
|
|
@ -51,11 +51,46 @@ import (
|
|||
"google.golang.org/protobuf/types/known/wrapperspb"
|
||||
)
|
||||
|
||||
type ActiveStreams struct {
|
||||
mu sync.Mutex
|
||||
streamsById map[uint64]*StreamState
|
||||
}
|
||||
|
||||
type StreamState struct {
|
||||
Context context.Context
|
||||
StreamID uint64
|
||||
ErrorC chan<- error
|
||||
Username string
|
||||
Hostname string
|
||||
PublicKey []byte
|
||||
MethodsAuthenticated []string
|
||||
Session *session.Session
|
||||
}
|
||||
|
||||
func (a *ActiveStreams) Get(id uint64) *StreamState {
|
||||
a.mu.Lock()
|
||||
defer a.mu.Unlock()
|
||||
return a.streamsById[id]
|
||||
}
|
||||
|
||||
func (a *ActiveStreams) Put(id uint64, state *StreamState) {
|
||||
a.mu.Lock()
|
||||
defer a.mu.Unlock()
|
||||
a.streamsById[id] = state
|
||||
}
|
||||
|
||||
func (a *ActiveStreams) Delete(id uint64) {
|
||||
a.mu.Lock()
|
||||
defer a.mu.Unlock()
|
||||
delete(a.streamsById, id)
|
||||
}
|
||||
|
||||
func (a *ActiveStreams) Range(f func(id uint64, state *StreamState)) {
|
||||
a.mu.Lock()
|
||||
defer a.mu.Unlock()
|
||||
for id, state := range a.streamsById {
|
||||
f(id, state)
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Authorize) RecordingFinalized(
|
||||
|
@ -126,8 +161,6 @@ READ:
|
|||
return nil
|
||||
}
|
||||
|
||||
var activeStreamIds sync.Map
|
||||
|
||||
func (a *Authorize) ManageStream(
|
||||
server extensions_ssh.StreamManagement_ManageStreamServer,
|
||||
) error {
|
||||
|
@ -171,15 +204,13 @@ func (a *Authorize) ManageStream(
|
|||
}
|
||||
})
|
||||
|
||||
var state StreamState
|
||||
errC := make(chan error, 1)
|
||||
state := &StreamState{
|
||||
Context: ctx,
|
||||
ErrorC: errC,
|
||||
}
|
||||
|
||||
deviceAuthDone := make(chan struct{})
|
||||
sessionState := &atomic.Pointer[sessions.State]{}
|
||||
|
||||
errC := make(chan error, 1)
|
||||
a.activeStreamsMu.Lock()
|
||||
a.activeStreams = append(a.activeStreams, errC)
|
||||
a.activeStreamsMu.Unlock()
|
||||
for {
|
||||
select {
|
||||
case err := <-errC:
|
||||
|
@ -192,10 +223,14 @@ func (a *Authorize) ManageStream(
|
|||
case *extensions_ssh.ClientMessage_Event:
|
||||
switch event := req.Event.Event.(type) {
|
||||
case *extensions_ssh.StreamEvent_DownstreamConnected:
|
||||
_ = event
|
||||
id := event.DownstreamConnected.StreamId
|
||||
if id == 0 {
|
||||
return fmt.Errorf("invalid stream ID: %v", id)
|
||||
}
|
||||
state.StreamID = id
|
||||
a.activeStreams.Put(id, state)
|
||||
defer a.activeStreams.Delete(id)
|
||||
case *extensions_ssh.StreamEvent_UpstreamConnected:
|
||||
activeStreamIds.Store(event.UpstreamConnected.GetStreamId(), state)
|
||||
defer activeStreamIds.Delete(event.UpstreamConnected.GetStreamId())
|
||||
case nil:
|
||||
}
|
||||
case *extensions_ssh.ClientMessage_AuthRequest:
|
||||
|
@ -230,8 +265,9 @@ func (a *Authorize) ManageStream(
|
|||
}
|
||||
|
||||
if session != nil {
|
||||
state.Session = session
|
||||
// Perform authorize check for this route
|
||||
req, err := a.getEvaluatorRequestFromSSHAuthRequest(&state)
|
||||
req, err := a.getEvaluatorRequestFromSSHAuthRequest(state)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -239,7 +275,7 @@ func (a *Authorize) ManageStream(
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
sendC <- handleEvaluatorResponseForSSH(res, &state, session.Id)
|
||||
sendC <- handleEvaluatorResponseForSSH(res, state)
|
||||
|
||||
if res.Allow.Value && !res.Deny.Value {
|
||||
a.startContinuousAuthorization(ctx, errC, req, session.Id)
|
||||
|
@ -263,9 +299,7 @@ func (a *Authorize) ManageStream(
|
|||
}
|
||||
case "keyboard-interactive":
|
||||
route := a.getSSHRouteForHostname(state.Hostname)
|
||||
// if route == nil {
|
||||
// return fmt.Errorf("invalid route")
|
||||
// }
|
||||
// route can be nil, in which case the default idp will be used
|
||||
|
||||
opts := a.currentConfig.Load().Options
|
||||
idp, err := opts.GetIdentityProviderForPolicy(route)
|
||||
|
@ -292,9 +326,7 @@ func (a *Authorize) ManageStream(
|
|||
infoReq := extensions_ssh.KeyboardInteractiveInfoPrompts{
|
||||
Name: "Sign in with " + idp.GetType(),
|
||||
Instruction: deviceAuthResp.VerificationURIComplete,
|
||||
Prompts: []*extensions_ssh.KeyboardInteractiveInfoPrompts_Prompt{
|
||||
// {},
|
||||
},
|
||||
Prompts: []*extensions_ssh.KeyboardInteractiveInfoPrompts_Prompt{},
|
||||
}
|
||||
|
||||
infoReqAny, _ := anypb.New(&infoReq)
|
||||
|
@ -328,13 +360,12 @@ func (a *Authorize) ManageStream(
|
|||
return
|
||||
}
|
||||
fmt.Println(token)
|
||||
err = a.PersistSession(ctx, s, claims, token)
|
||||
state.Session, err = a.PersistSession(ctx, s, claims, token)
|
||||
if err != nil {
|
||||
fmt.Println("error from PersistSession:", err)
|
||||
errC <- fmt.Errorf("error persisting session: %w", err)
|
||||
return
|
||||
}
|
||||
sessionState.Store(s)
|
||||
close(deviceAuthDone)
|
||||
}()
|
||||
}
|
||||
|
@ -351,7 +382,7 @@ func (a *Authorize) ManageStream(
|
|||
case <-deviceAuthDone:
|
||||
case <-ctx.Done():
|
||||
}
|
||||
if sessionState.Load() != nil {
|
||||
if state.Session != nil {
|
||||
state.MethodsAuthenticated = append(state.MethodsAuthenticated, "keyboard-interactive")
|
||||
} else {
|
||||
resp := extensions_ssh.ServerMessage{
|
||||
|
@ -369,18 +400,18 @@ func (a *Authorize) ManageStream(
|
|||
|
||||
if slices.Contains(state.MethodsAuthenticated, "publickey") {
|
||||
// Perform authorize check for this route
|
||||
req, err := a.getEvaluatorRequestFromSSHAuthRequest(&state)
|
||||
req, err := a.getEvaluatorRequestFromSSHAuthRequest(state)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
res, err := a.evaluate(ctx, req, sessionState.Load())
|
||||
res, err := a.evaluate(ctx, req, &sessions.State{ID: state.Session.Id})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
sendC <- handleEvaluatorResponseForSSH(res, &state, sessionState.Load().ID)
|
||||
sendC <- handleEvaluatorResponseForSSH(res, state)
|
||||
|
||||
if res.Allow.Value && !res.Deny.Value {
|
||||
a.startContinuousAuthorization(ctx, errC, req, sessionState.Load().ID)
|
||||
a.startContinuousAuthorization(ctx, errC, req, state.Session.Id)
|
||||
}
|
||||
} else {
|
||||
resp := extensions_ssh.ServerMessage{
|
||||
|
@ -476,7 +507,8 @@ func (a *Authorize) getEvaluatorRequestFromSSHAuthRequest(
|
|||
}
|
||||
|
||||
func handleEvaluatorResponseForSSH(
|
||||
result *evaluator.Result, state *StreamState, sessionID string,
|
||||
result *evaluator.Result,
|
||||
state *StreamState,
|
||||
) *extensions_ssh.ServerMessage {
|
||||
// fmt.Printf(" *** evaluator result: %+v\n", result)
|
||||
|
||||
|
@ -499,7 +531,7 @@ func handleEvaluatorResponseForSSH(
|
|||
FilterMetadata: map[string]*structpb.Struct{
|
||||
"pomerium": {
|
||||
Fields: map[string]*structpb.Value{
|
||||
"session-id": structpb.NewStringValue(sessionID),
|
||||
"stream-id": structpb.NewStringValue(strconv.FormatUint(state.StreamID, 10)),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
@ -592,7 +624,7 @@ func (a *Authorize) PersistSession(
|
|||
sessionState *sessions.State, // XXX: consider not using this struct
|
||||
claims identity.SessionClaims,
|
||||
accessToken *oauth2.Token,
|
||||
) error {
|
||||
) (*session.Session, error) {
|
||||
now := time.Now()
|
||||
sessionLifetime := a.currentConfig.Load().Options.CookieExpire
|
||||
sessionExpiry := timestamppb.New(now.Add(sessionLifetime))
|
||||
|
@ -614,12 +646,12 @@ func (a *Authorize) PersistSession(
|
|||
|
||||
res, err := session.Put(ctx, a.GetDataBrokerServiceClient(), sess)
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
sessionState.DatabrokerServerVersion = res.GetServerVersion()
|
||||
sessionState.DatabrokerRecordVersion = res.GetRecord().GetVersion()
|
||||
|
||||
return nil
|
||||
return sess, nil
|
||||
}
|
||||
|
||||
func (a *Authorize) startContinuousAuthorization(
|
||||
|
@ -731,7 +763,7 @@ func (a *Authorize) ServeChannel(
|
|||
}
|
||||
}()
|
||||
|
||||
var sessionID atomic.Pointer[string]
|
||||
var state *StreamState
|
||||
go func() {
|
||||
localWindow := uint32(channelWindowSize)
|
||||
for {
|
||||
|
@ -744,14 +776,28 @@ func (a *Authorize) ServeChannel(
|
|||
errC <- err
|
||||
return
|
||||
}
|
||||
if sessionID.Load() == nil {
|
||||
if state == nil {
|
||||
mdMsg, ok := channelMsg.Message.(*extensions_ssh.ChannelMessage_Metadata)
|
||||
if !ok {
|
||||
errC <- fmt.Errorf("first message was not metadata")
|
||||
return
|
||||
}
|
||||
id := mdMsg.Metadata.FilterMetadata["pomerium"].Fields["session-id"].GetStringValue()
|
||||
sessionID.Store(&id)
|
||||
idStr := mdMsg.Metadata.FilterMetadata["pomerium"].Fields["stream-id"].GetStringValue()
|
||||
if idStr == "" {
|
||||
errC <- fmt.Errorf("no session ID found for stream %q", idStr)
|
||||
return
|
||||
}
|
||||
id, err := strconv.ParseUint(idStr, 10, 64)
|
||||
if err != nil {
|
||||
errC <- fmt.Errorf("invalid stream ID %q: %w", idStr, err)
|
||||
return
|
||||
}
|
||||
if v := a.activeStreams.Get(id); v != nil {
|
||||
state = v
|
||||
} else {
|
||||
errC <- fmt.Errorf("no stream state found for ID %d", id)
|
||||
return
|
||||
}
|
||||
continue
|
||||
}
|
||||
if raw, ok := channelMsg.Message.(*extensions_ssh.ChannelMessage_RawBytes); ok {
|
||||
|
@ -877,7 +923,7 @@ func (a *Authorize) ServeChannel(
|
|||
return err
|
||||
}
|
||||
|
||||
cmd := a.NewSSHCLI(a.currentConfig.Load(), downstreamPtyInfo, downstreamChannelInfo, *sessionID.Load(), inputR, outputW, sendC, &activeProgram)
|
||||
cmd := a.NewSSHCLI(a.currentConfig.Load(), downstreamPtyInfo, downstreamChannelInfo, state, inputR, outputW, sendC, &activeProgram)
|
||||
if msg.Request == "shell" {
|
||||
cmd.SetArgs([]string{"portal"})
|
||||
} else {
|
||||
|
@ -995,10 +1041,10 @@ func (a *Authorize) NewSSHCLI(
|
|||
cfg *config.Config,
|
||||
ptyInfo *extensions_ssh.SSHDownstreamPTYInfo,
|
||||
channelInfo *extensions_ssh.SSHDownstreamChannelInfo,
|
||||
sessionID string,
|
||||
state *StreamState,
|
||||
stdin io.Reader,
|
||||
stdout io.Writer,
|
||||
sendC chan any,
|
||||
sendC chan<- any,
|
||||
activeProgram *atomic.Pointer[tea.Program],
|
||||
) *cobra.Command {
|
||||
cmd := &cobra.Command{
|
||||
|
@ -1016,7 +1062,8 @@ func (a *Authorize) NewSSHCLI(
|
|||
return nil
|
||||
},
|
||||
}
|
||||
cmd.AddCommand(NewPortalCommand(cfg, ptyInfo, channelInfo, sendC, activeProgram))
|
||||
sessionID := state.Session.Id
|
||||
cmd.AddCommand(a.NewPortalCommand(cfg, ptyInfo, channelInfo, state, sendC, activeProgram))
|
||||
cmd.AddCommand(a.NewLogoutCommand(cfg, sessionID))
|
||||
cmd.AddCommand(a.NewWhoamiCommand(cfg, sessionID))
|
||||
cmd.CompletionOptions.DisableDefaultCmd = true
|
||||
|
@ -1077,11 +1124,12 @@ func (a *Authorize) NewWhoamiCommand(
|
|||
return cmd
|
||||
}
|
||||
|
||||
func NewPortalCommand(
|
||||
func (a *Authorize) NewPortalCommand(
|
||||
cfg *config.Config,
|
||||
ptyInfo *extensions_ssh.SSHDownstreamPTYInfo,
|
||||
channelInfo *extensions_ssh.SSHDownstreamChannelInfo,
|
||||
sendC chan any,
|
||||
state *StreamState,
|
||||
sendC chan<- any,
|
||||
activeProgram *atomic.Pointer[tea.Program],
|
||||
) *cobra.Command {
|
||||
cmd := &cobra.Command{
|
||||
|
@ -1094,16 +1142,17 @@ func NewPortalCommand(
|
|||
var routes []string
|
||||
for r := range cfg.Options.GetAllPolicies() {
|
||||
if strings.HasPrefix(r.From, "ssh://") {
|
||||
routes = append(routes, fmt.Sprintf("ubuntu@%s", strings.TrimSuffix(strings.TrimPrefix(r.From, "ssh://"), "."+cfg.Options.SSHHostname)))
|
||||
routes = append(routes, fmt.Sprintf("%s@%s", state.Username, strings.TrimSuffix(strings.TrimPrefix(r.From, "ssh://"), "."+cfg.Options.SSHHostname)))
|
||||
}
|
||||
}
|
||||
items := []list.Item{}
|
||||
for _, route := range routes {
|
||||
items = append(items, item(route))
|
||||
}
|
||||
activeStreamIds.Range(func(key, value any) bool {
|
||||
items = append(items, item(fmt.Sprintf("[demo] mirror session: %v", key)))
|
||||
return true
|
||||
a.activeStreams.Range(func(id uint64, state *StreamState) {
|
||||
if id != state.StreamID {
|
||||
items = append(items, item(fmt.Sprintf("[demo] mirror session: %v", id)))
|
||||
}
|
||||
})
|
||||
|
||||
l := list.New(items, itemDelegate{}, int(ptyInfo.WidthColumns-2), int(ptyInfo.HeightRows-2))
|
||||
|
@ -1152,6 +1201,25 @@ func NewPortalCommand(
|
|||
})
|
||||
} else {
|
||||
username, hostname, _ := strings.Cut(answer.(model).choice, "@")
|
||||
// Perform authorize check for this route
|
||||
state.Hostname = hostname
|
||||
if username != state.Username {
|
||||
return fmt.Errorf("internal error: username mismatch")
|
||||
}
|
||||
req, err := a.getEvaluatorRequestFromSSHAuthRequest(state)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
res, err := a.evaluate(cmd.Context(), req, &sessions.State{ID: state.Session.Id})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if res.Allow.Value && !res.Deny.Value {
|
||||
a.startContinuousAuthorization(state.Context, state.ErrorC, req, state.Session.Id)
|
||||
} else {
|
||||
return fmt.Errorf("not authorized")
|
||||
}
|
||||
sessionRecordingExt, _ := anypb.New(&extensions_session_recording.UpstreamTargetExtensionConfig{
|
||||
RecordingName: fmt.Sprintf("session-%s-at-%s-%d.cast", username, hostname, time.Now().UnixNano()),
|
||||
Format: extensions_session_recording.Format_AsciicastFormat,
|
||||
|
|
Loading…
Add table
Reference in a new issue