ssh: continuous authorization (#5687)

Re-evaluate ssh authorization decision on a fixed interval, or whenever 
the config changes. If access is no longer allowed, log a new 'authorize
check' message and disconnect. 

Refactor the ssh.StreamManager initialization so that its lifecycle 
matches the Authorize lifecycle.
This commit is contained in:
Kenneth Jenkins 2025-07-02 12:01:25 -07:00 committed by GitHub
parent 31020a75a6
commit 177677f239
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 170 additions and 98 deletions

View file

@ -20,6 +20,7 @@ import (
"github.com/pomerium/pomerium/internal/telemetry/metrics"
"github.com/pomerium/pomerium/pkg/cryptutil"
"github.com/pomerium/pomerium/pkg/grpc/databroker"
"github.com/pomerium/pomerium/pkg/ssh"
"github.com/pomerium/pomerium/pkg/telemetry/trace"
)
@ -31,6 +32,7 @@ type Authorize struct {
store *store.Store
currentConfig *atomicutil.Value[*config.Config]
accessTracker *AccessTracker
ssh *ssh.StreamManager
tracerProvider oteltrace.TracerProvider
tracer oteltrace.Tracer
@ -52,6 +54,7 @@ func New(ctx context.Context, cfg *config.Config) (*Authorize, error) {
tracer: tracer,
}
a.accessTracker = NewAccessTracker(a, accessTrackerMaxSize, accessTrackerDebouncePeriod)
a.ssh = ssh.NewStreamManager(ctx, ssh.NewAuth(a, a.currentConfig, a.tracerProvider), cfg)
state, err := newAuthorizeStateFromConfig(ctx, nil, tracerProvider, cfg, a.store)
if err != nil {
@ -161,4 +164,5 @@ func (a *Authorize) OnConfigChange(ctx context.Context, cfg *config.Config) {
} else {
a.state.Store(newState)
}
a.ssh.OnConfigChange(cfg)
}

View file

@ -27,12 +27,7 @@ func (a *Authorize) ManageStream(stream extensions_ssh.StreamManagement_ManageSt
return status.Errorf(codes.Internal, "first message was not a downstream connected event")
}
state := a.state.Load()
handler := state.ssh.NewStreamHandler(
a.currentConfig.Load(),
ssh.NewAuth(a, state.dataBrokerClient, a.currentConfig, a.tracerProvider),
downstream,
)
handler := a.ssh.NewStreamHandler(downstream)
defer handler.Close()
eg, ctx := errgroup.WithContext(stream.Context())
@ -85,7 +80,7 @@ func (a *Authorize) ServeChannel(stream extensions_ssh.StreamManagement_ServeCha
} else {
return status.Errorf(codes.Internal, "first message was not metadata")
}
handler := a.state.Load().ssh.LookupStream(streamID)
handler := a.ssh.LookupStream(streamID)
if handler == nil || !handler.IsExpectingInternalChannel() {
return status.Errorf(codes.InvalidArgument, "stream not found")
}
@ -121,6 +116,8 @@ func (a *Authorize) EvaluateSSH(ctx context.Context, req *ssh.Request) (*evaluat
return nil, err
}
skipLogging := req.LogOnlyIfDenied && res.Allow.Value && !res.Deny.Value
if !skipLogging {
s, _ := a.getDataBrokerSessionOrServiceAccount(ctx, req.SessionID, 0)
var u *user.User
@ -128,6 +125,7 @@ func (a *Authorize) EvaluateSSH(ctx context.Context, req *ssh.Request) (*evaluat
u, _ = a.getDataBrokerUser(ctx, s.GetUserId())
}
a.logAuthorizeCheck(ctx, &evalreq, res, s, u)
}
return res, nil
}

View file

@ -20,7 +20,6 @@ import (
"github.com/pomerium/pomerium/pkg/grpc/session"
"github.com/pomerium/pomerium/pkg/grpc/user"
"github.com/pomerium/pomerium/pkg/grpcutil"
"github.com/pomerium/pomerium/pkg/ssh"
"github.com/pomerium/pomerium/pkg/storage"
)
@ -40,7 +39,6 @@ type authorizeState struct {
authenticateFlow authenticateFlow
syncQueriers map[string]storage.Querier
mcp *mcp.Handler
ssh *ssh.StreamManager
}
func newAuthorizeStateFromConfig(
@ -72,8 +70,6 @@ func newAuthorizeStateFromConfig(
evaluatorOptions = append(evaluatorOptions, evaluator.WithMCPAccessTokenProvider(mcp))
}
state.ssh = ssh.NewStreamManager()
state.evaluator, err = newPolicyEvaluator(ctx, cfg.Options, store, previousEvaluator, evaluatorOptions...)
if err != nil {
return nil, fmt.Errorf("authorize: failed to update policy with options: %w", err)

View file

@ -30,8 +30,9 @@ import (
"github.com/pomerium/pomerium/pkg/storage"
)
type PolicyEvaluator interface {
type Evaluator interface {
EvaluateSSH(context.Context, *Request) (*evaluator.Result, error)
GetDataBrokerServiceClient() databroker.DataBrokerServiceClient
}
type Request struct {
@ -39,22 +40,22 @@ type Request struct {
Hostname string
PublicKey []byte
SessionID string
LogOnlyIfDenied bool
}
type Auth struct {
evaluator PolicyEvaluator
dataBrokerClient databroker.DataBrokerServiceClient
evaluator Evaluator
currentConfig *atomicutil.Value[*config.Config]
tracerProvider oteltrace.TracerProvider
}
func NewAuth(
evaluator PolicyEvaluator,
client databroker.DataBrokerServiceClient,
evaluator Evaluator,
currentConfig *atomicutil.Value[*config.Config],
tracerProvider oteltrace.TracerProvider,
) *Auth {
return &Auth{evaluator, client, currentConfig, tracerProvider}
return &Auth{evaluator, currentConfig, tracerProvider}
}
func (a *Auth) HandlePublicKeyMethodRequest(
@ -93,7 +94,7 @@ func (a *Auth) handlePublicKeyMethodRequest(
// Special case: internal command (e.g. routes portal).
if *info.Hostname == "" {
_, err := session.Get(ctx, a.dataBrokerClient, sessionID)
_, err := session.Get(ctx, a.evaluator.GetDataBrokerServiceClient(), sessionID)
if status.Code(err) == codes.NotFound {
// Require IdP login.
return PublicKeyAuthMethodResponse{
@ -229,7 +230,7 @@ func (a *Auth) handleLogin(
return a.saveSession(ctx, sessionID, &sessionClaims, token)
}
var errAccessDenied = errors.New("access denied")
var errAccessDenied = status.Error(codes.PermissionDenied, "access denied")
func (a *Auth) EvaluateDelayed(ctx context.Context, info StreamAuthInfo) error {
req, err := sshRequestFromStreamAuthInfo(info)
@ -252,7 +253,7 @@ func (a *Auth) FormatSession(ctx context.Context, info StreamAuthInfo) ([]byte,
if err != nil {
return nil, err
}
session, err := session.Get(ctx, a.dataBrokerClient, sessionID)
session, err := session.Get(ctx, a.evaluator.GetDataBrokerServiceClient(), sessionID)
if err != nil {
return nil, err
}
@ -269,7 +270,7 @@ func (a *Auth) DeleteSession(ctx context.Context, info StreamAuthInfo) error {
if err != nil {
return err
}
err = session.Delete(ctx, a.dataBrokerClient, sessionID)
err = session.Delete(ctx, a.evaluator.GetDataBrokerServiceClient(), sessionID)
a.invalidateCacheForRecord(ctx, &databroker.Record{
Type: "type.googleapis.com/session.Session",
Id: sessionID,
@ -304,7 +305,7 @@ func (a *Auth) saveSession(
sess.SetRawIDToken(claims.RawIDToken)
sess.AddClaims(claims.Flatten())
u, _ := user.Get(ctx, a.dataBrokerClient, sess.GetUserId())
u, _ := user.Get(ctx, a.evaluator.GetDataBrokerServiceClient(), sess.GetUserId())
if u == nil {
// if no user exists yet, create a new one
u = &user.User{
@ -312,12 +313,12 @@ func (a *Auth) saveSession(
}
}
u.PopulateFromClaims(claims.Claims)
_, err := databroker.Put(ctx, a.dataBrokerClient, u)
_, err := databroker.Put(ctx, a.evaluator.GetDataBrokerServiceClient(), u)
if err != nil {
return err
}
resp, err := session.Put(ctx, a.dataBrokerClient, sess)
resp, err := session.Put(ctx, a.evaluator.GetDataBrokerServiceClient(), sess)
if err != nil {
return err
}
@ -327,7 +328,9 @@ func (a *Auth) saveSession(
func (a *Auth) invalidateCacheForRecord(ctx context.Context, record *databroker.Record) {
ctx = storage.WithQuerier(ctx,
storage.NewCachingQuerier(storage.NewQuerier(a.dataBrokerClient), storage.GlobalCache))
storage.NewCachingQuerier(
storage.NewQuerier(a.evaluator.GetDataBrokerServiceClient()),
storage.GlobalCache))
storage.InvalidateCacheForDataBrokerRecords(ctx, record)
}
@ -375,6 +378,8 @@ func sshRequestFromStreamAuthInfo(info StreamAuthInfo) (*Request, error) {
Hostname: *info.Hostname,
PublicKey: info.PublicKeyAllow.Value.PublicKey,
SessionID: sessionID,
LogOnlyIfDenied: info.InitialAuthComplete,
}, nil
}

View file

@ -38,10 +38,10 @@ func TestHandlePublicKeyMethodRequest(t *testing.T) {
}
var req extensions_ssh.PublicKeyMethodRequest
req.PublicKeyFingerprintSha256 = []byte("ABCDEFGHIJKLMNOPQRSTUVWXYZ123456")
pe := policyEvaluatorFunc(func(context.Context, *Request) (*evaluator.Result, error) {
pe := func(context.Context, *Request) (*evaluator.Result, error) {
return nil, errors.New("error evaluating policy")
})
a := NewAuth(pe, nil, nil, nil)
}
a := NewAuth(fakePolicyEvaluator{evaluateSSH: pe}, nil, nil)
_, err := a.handlePublicKeyMethodRequest(t.Context(), info, &req)
assert.ErrorContains(t, err, "error evaluating policy")
})
@ -54,7 +54,7 @@ func TestHandlePublicKeyMethodRequest(t *testing.T) {
req.PublicKeyFingerprintSha256 = []byte("ABCDEFGHIJKLMNOPQRSTUVWXYZ123456")
fakePublicKey := []byte("fake-public-key")
req.PublicKey = fakePublicKey
pe := policyEvaluatorFunc(func(_ context.Context, r *Request) (*evaluator.Result, error) {
pe := func(_ context.Context, r *Request) (*evaluator.Result, error) {
assert.Equal(t, r, &Request{
Username: "username",
Hostname: "hostname",
@ -65,8 +65,8 @@ func TestHandlePublicKeyMethodRequest(t *testing.T) {
Allow: evaluator.NewRuleResult(true),
Deny: evaluator.NewRuleResult(false),
}, nil
})
a := NewAuth(pe, nil, nil, nil)
}
a := NewAuth(fakePolicyEvaluator{evaluateSSH: pe}, nil, nil)
res, err := a.HandlePublicKeyMethodRequest(t.Context(), info, &req)
assert.NoError(t, err)
assert.Empty(t, res.RequireAdditionalMethods)
@ -80,13 +80,13 @@ func TestHandlePublicKeyMethodRequest(t *testing.T) {
}
var req extensions_ssh.PublicKeyMethodRequest
req.PublicKeyFingerprintSha256 = []byte("ABCDEFGHIJKLMNOPQRSTUVWXYZ123456")
pe := policyEvaluatorFunc(func(_ context.Context, _ *Request) (*evaluator.Result, error) {
pe := func(_ context.Context, _ *Request) (*evaluator.Result, error) {
return &evaluator.Result{
Allow: evaluator.NewRuleResult(true),
Deny: evaluator.NewRuleResult(true),
}, nil
})
a := NewAuth(pe, nil, nil, nil)
}
a := NewAuth(fakePolicyEvaluator{evaluateSSH: pe}, nil, nil)
res, err := a.HandlePublicKeyMethodRequest(t.Context(), info, &req)
assert.NoError(t, err)
assert.Nil(t, res.Allow)
@ -99,13 +99,13 @@ func TestHandlePublicKeyMethodRequest(t *testing.T) {
}
var req extensions_ssh.PublicKeyMethodRequest
req.PublicKeyFingerprintSha256 = []byte("ABCDEFGHIJKLMNOPQRSTUVWXYZ123456")
pe := policyEvaluatorFunc(func(_ context.Context, _ *Request) (*evaluator.Result, error) {
pe := func(_ context.Context, _ *Request) (*evaluator.Result, error) {
return &evaluator.Result{
Allow: evaluator.NewRuleResult(false, criteria.ReasonSSHPublickeyUnauthorized),
Deny: evaluator.NewRuleResult(false),
}, nil
})
a := NewAuth(pe, nil, nil, nil)
}
a := NewAuth(fakePolicyEvaluator{evaluateSSH: pe}, nil, nil)
res, err := a.HandlePublicKeyMethodRequest(t.Context(), info, &req)
assert.NoError(t, err)
assert.Nil(t, res.Allow)
@ -118,13 +118,13 @@ func TestHandlePublicKeyMethodRequest(t *testing.T) {
}
var req extensions_ssh.PublicKeyMethodRequest
req.PublicKeyFingerprintSha256 = []byte("ABCDEFGHIJKLMNOPQRSTUVWXYZ123456")
pe := policyEvaluatorFunc(func(_ context.Context, _ *Request) (*evaluator.Result, error) {
pe := func(_ context.Context, _ *Request) (*evaluator.Result, error) {
return &evaluator.Result{
Allow: evaluator.NewRuleResult(false),
Deny: evaluator.NewRuleResult(false, criteria.ReasonUserUnauthenticated),
}, nil
})
a := NewAuth(pe, nil, nil, nil)
}
a := NewAuth(fakePolicyEvaluator{evaluateSSH: pe}, nil, nil)
res, err := a.HandlePublicKeyMethodRequest(t.Context(), info, &req)
assert.NoError(t, err)
assert.NotNil(t, res.Allow)
@ -144,13 +144,13 @@ func TestHandlePublicKeyMethodRequest(t *testing.T) {
}
var req extensions_ssh.PublicKeyMethodRequest
req.PublicKeyFingerprintSha256 = []byte("ABCDEFGHIJKLMNOPQRSTUVWXYZ123456")
pe := policyEvaluatorFunc(func(_ context.Context, _ *Request) (*evaluator.Result, error) {
pe := func(_ context.Context, _ *Request) (*evaluator.Result, error) {
return &evaluator.Result{
Allow: evaluator.NewRuleResult(false),
Deny: evaluator.NewRuleResult(false, criteria.ReasonUserUnauthenticated),
}, nil
})
a := NewAuth(pe, client, nil, nil)
}
a := NewAuth(fakePolicyEvaluator{pe, client}, nil, nil)
res, err := a.HandlePublicKeyMethodRequest(t.Context(), info, &req)
assert.NoError(t, err)
assert.NotNil(t, res.Allow)
@ -179,13 +179,13 @@ func TestHandlePublicKeyMethodRequest(t *testing.T) {
}
var req extensions_ssh.PublicKeyMethodRequest
req.PublicKeyFingerprintSha256 = []byte("ABCDEFGHIJKLMNOPQRSTUVWXYZ123456")
pe := policyEvaluatorFunc(func(_ context.Context, _ *Request) (*evaluator.Result, error) {
pe := func(_ context.Context, _ *Request) (*evaluator.Result, error) {
return &evaluator.Result{
Allow: evaluator.NewRuleResult(true),
Deny: evaluator.NewRuleResult(false),
}, nil
})
a := NewAuth(pe, client, nil, nil)
}
a := NewAuth(fakePolicyEvaluator{pe, client}, nil, nil)
res, err := a.HandlePublicKeyMethodRequest(t.Context(), info, &req)
assert.NoError(t, err)
assert.NotNil(t, res.Allow)
@ -205,13 +205,13 @@ func TestHandlePublicKeyMethodRequest(t *testing.T) {
}
var req extensions_ssh.PublicKeyMethodRequest
req.PublicKeyFingerprintSha256 = []byte("ABCDEFGHIJKLMNOPQRSTUVWXYZ123456")
pe := policyEvaluatorFunc(func(_ context.Context, _ *Request) (*evaluator.Result, error) {
pe := func(_ context.Context, _ *Request) (*evaluator.Result, error) {
return &evaluator.Result{
Allow: evaluator.NewRuleResult(true),
Deny: evaluator.NewRuleResult(false),
}, nil
})
a := NewAuth(pe, client, nil, nil)
}
a := NewAuth(fakePolicyEvaluator{pe, client}, nil, nil)
_, err := a.HandlePublicKeyMethodRequest(t.Context(), info, &req)
assert.ErrorContains(t, err, "internal error")
})
@ -224,12 +224,12 @@ func TestHandleKeyboardInteractiveMethodRequest(t *testing.T) {
assert.ErrorContains(t, err, "expected PublicKeyAllow message not to be nil")
})
t.Run("ok", func(t *testing.T) {
pe := policyEvaluatorFunc(func(_ context.Context, _ *Request) (*evaluator.Result, error) {
pe := func(_ context.Context, _ *Request) (*evaluator.Result, error) {
return &evaluator.Result{
Allow: evaluator.NewRuleResult(true),
Deny: evaluator.NewRuleResult(false),
}, nil
})
}
var putRecords []*databroker.Record
client := fakeDataBrokerServiceClient{
get: func(
@ -255,7 +255,7 @@ func TestHandleKeyboardInteractiveMethodRequest(t *testing.T) {
cfg.Options.ProviderURL = idpURL
cfg.Options.ClientID = "client-id"
cfg.Options.ClientSecret = "client-secret"
a := NewAuth(pe, client, atomicutil.NewValue(&cfg), nil)
a := NewAuth(fakePolicyEvaluator{pe, client}, atomicutil.NewValue(&cfg), nil)
info := StreamAuthInfo{
Username: ptr("username"),
Hostname: ptr("hostname"),
@ -281,12 +281,12 @@ func TestHandleKeyboardInteractiveMethodRequest(t *testing.T) {
assert.Equal(t, "sshkey-SHA256:QUJDREVGR0hJSktMTU5PUFFSU1RVVldYWVoxMjM0NTY=", putRecords[1].Id)
})
t.Run("denied", func(t *testing.T) {
pe := policyEvaluatorFunc(func(_ context.Context, _ *Request) (*evaluator.Result, error) {
pe := func(_ context.Context, _ *Request) (*evaluator.Result, error) {
return &evaluator.Result{
Allow: evaluator.NewRuleResult(false),
Deny: evaluator.NewRuleResult(false),
}, nil
})
}
client := fakeDataBrokerServiceClient{
get: func(
_ context.Context, _ *databroker.GetRequest, _ ...grpc.CallOption,
@ -310,7 +310,7 @@ func TestHandleKeyboardInteractiveMethodRequest(t *testing.T) {
cfg.Options.ProviderURL = idpURL
cfg.Options.ClientID = "client-id"
cfg.Options.ClientSecret = "client-secret"
a := NewAuth(pe, client, atomicutil.NewValue(&cfg), nil)
a := NewAuth(fakePolicyEvaluator{pe, client}, atomicutil.NewValue(&cfg), nil)
info := StreamAuthInfo{
Username: ptr("username"),
Hostname: ptr("hostname"),
@ -336,7 +336,7 @@ func TestHandleKeyboardInteractiveMethodRequest(t *testing.T) {
cfg.Options.ProviderURL = idpURL
cfg.Options.ClientID = "client-id"
cfg.Options.ClientSecret = "client-secret"
a := NewAuth(nil, nil, atomicutil.NewValue(&cfg), nil)
a := NewAuth(nil, atomicutil.NewValue(&cfg), nil)
info := StreamAuthInfo{
Username: ptr("username"),
Hostname: ptr("hostname"),
@ -386,7 +386,7 @@ func TestFormatSession(t *testing.T) {
}, nil
},
}
a := NewAuth(nil, client, nil, nil)
a := NewAuth(fakePolicyEvaluator{client: client}, nil, nil)
info := StreamAuthInfo{
PublicKeyFingerprintSha256: []byte("ABCDEFGHIJKLMNOPQRSTUVWXYZ123456"),
}
@ -424,7 +424,7 @@ func TestDeleteSession(t *testing.T) {
return nil, putError
},
}
a := NewAuth(nil, client, nil, nil)
a := NewAuth(fakePolicyEvaluator{client: client}, nil, nil)
info := StreamAuthInfo{
PublicKeyFingerprintSha256: []byte("ABCDEFGHIJKLMNOPQRSTUVWXYZ123456"),
}
@ -433,12 +433,17 @@ func TestDeleteSession(t *testing.T) {
})
}
type policyEvaluatorFunc func(context.Context, *Request) (*evaluator.Result, error)
type fakePolicyEvaluator struct {
evaluateSSH func(context.Context, *Request) (*evaluator.Result, error)
client databroker.DataBrokerServiceClient
}
func (f policyEvaluatorFunc) EvaluateSSH(
ctx context.Context, req *Request,
) (*evaluator.Result, error) {
return f(ctx, req)
func (f fakePolicyEvaluator) EvaluateSSH(ctx context.Context, req *Request) (*evaluator.Result, error) {
return f.evaluateSSH(ctx, req)
}
func (f fakePolicyEvaluator) GetDataBrokerServiceClient() databroker.DataBrokerServiceClient {
return f.client
}
type fakeDataBrokerServiceClient struct {
@ -448,12 +453,12 @@ type fakeDataBrokerServiceClient struct {
put func(ctx context.Context, in *databroker.PutRequest, opts ...grpc.CallOption) (*databroker.PutResponse, error)
}
func (m fakeDataBrokerServiceClient) Get(ctx context.Context, in *databroker.GetRequest, opts ...grpc.CallOption) (*databroker.GetResponse, error) {
return m.get(ctx, in, opts...)
func (f fakeDataBrokerServiceClient) Get(ctx context.Context, in *databroker.GetRequest, opts ...grpc.CallOption) (*databroker.GetResponse, error) {
return f.get(ctx, in, opts...)
}
func (m fakeDataBrokerServiceClient) Put(ctx context.Context, in *databroker.PutRequest, opts ...grpc.CallOption) (*databroker.PutResponse, error) {
return m.put(ctx, in, opts...)
func (f fakeDataBrokerServiceClient) Put(ctx context.Context, in *databroker.PutRequest, opts ...grpc.CallOption) (*databroker.PutResponse, error) {
return f.put(ctx, in, opts...)
}
type noopQuerier struct{}

View file

@ -1,6 +1,7 @@
package ssh
import (
"context"
"sync"
extensions_ssh "github.com/pomerium/envoy-custom/api/extensions/filters/network/ssh"
@ -8,29 +9,43 @@ import (
)
type StreamManager struct {
auth AuthInterface
reauthC chan struct{}
mu sync.Mutex
cfg *config.Config
activeStreams map[uint64]*StreamHandler
}
func NewStreamManager() *StreamManager {
return &StreamManager{
func NewStreamManager(ctx context.Context, auth AuthInterface, cfg *config.Config) *StreamManager {
sm := &StreamManager{
auth: auth,
reauthC: make(chan struct{}, 1),
cfg: cfg,
activeStreams: map[uint64]*StreamHandler{},
}
go sm.reauthLoop(ctx)
return sm
}
func (sm *StreamManager) OnConfigChange(cfg *config.Config) {
sm.mu.Lock()
sm.cfg = cfg
sm.mu.Unlock()
select {
case sm.reauthC <- struct{}{}:
default:
}
}
func (sm *StreamManager) LookupStream(streamID uint64) *StreamHandler {
sm.mu.Lock()
defer sm.mu.Unlock()
stream := sm.activeStreams[streamID]
if stream == nil {
return nil
}
return stream
return sm.activeStreams[streamID]
}
func (sm *StreamManager) NewStreamHandler(
cfg *config.Config,
auth AuthInterface,
downstream *extensions_ssh.DownstreamConnectEvent,
) *StreamHandler {
sm.mu.Lock()
@ -38,11 +53,12 @@ func (sm *StreamManager) NewStreamHandler(
streamID := downstream.StreamId
writeC := make(chan *extensions_ssh.ServerMessage, 32)
sh := &StreamHandler{
auth: auth,
config: cfg,
auth: sm.auth,
config: sm.cfg,
downstream: downstream,
readC: make(chan *extensions_ssh.ClientMessage, 32),
writeC: writeC,
reauthC: make(chan struct{}),
close: func() {
sm.mu.Lock()
defer sm.mu.Unlock()
@ -53,3 +69,23 @@ func (sm *StreamManager) NewStreamHandler(
sm.activeStreams[streamID] = sh
return sh
}
func (sm *StreamManager) reauthLoop(ctx context.Context) {
for {
select {
case <-ctx.Done():
return
case <-sm.reauthC:
sm.mu.Lock()
snapshot := make([]*StreamHandler, 0, len(sm.activeStreams))
for _, s := range sm.activeStreams {
snapshot = append(snapshot, s)
}
sm.mu.Unlock()
for _, s := range snapshot {
s.Reauth()
}
}
}
}

View file

@ -22,17 +22,17 @@ func mustParseWeightedURLs(t *testing.T, urls ...string) []config.WeightedURL {
func TestStreamManager(t *testing.T) {
ctrl := gomock.NewController(t)
auth := mock_ssh.NewMockAuthInterface(ctrl)
m := ssh.NewStreamManager()
cfg := &config.Config{Options: config.NewDefaultOptions()}
cfg.Options.Policies = []config.Policy{
{From: "ssh://host1", To: mustParseWeightedURLs(t, "ssh://dest1:22")},
{From: "ssh://host2", To: mustParseWeightedURLs(t, "ssh://dest2:22")},
}
m := ssh.NewStreamManager(t.Context(), auth, cfg)
t.Run("LookupStream", func(t *testing.T) {
assert.Nil(t, m.LookupStream(1234))
sh := m.NewStreamHandler(cfg, auth, &extensions_ssh.DownstreamConnectEvent{StreamId: 1234})
sh := m.NewStreamHandler(&extensions_ssh.DownstreamConnectEvent{StreamId: 1234})
assert.Equal(t, sh, m.LookupStream(1234))
sh.Close()
assert.Nil(t, m.LookupStream(1234))

View file

@ -3,6 +3,7 @@ package ssh
import (
"context"
"iter"
"time"
corev3 "github.com/envoyproxy/go-control-plane/envoy/config/core/v3"
gossh "golang.org/x/crypto/ssh"
@ -73,6 +74,7 @@ type StreamAuthInfo struct {
PublicKeyFingerprintSha256 []byte
PublicKeyAllow AuthMethodValue[extensions_ssh.PublicKeyAllowResponse]
KeyboardInteractiveAllow AuthMethodValue[extensions_ssh.KeyboardInteractiveAllowResponse]
InitialAuthComplete bool
}
func (i *StreamAuthInfo) allMethodsValid() bool {
@ -93,6 +95,7 @@ type StreamHandler struct {
downstream *extensions_ssh.DownstreamConnectEvent
writeC chan *extensions_ssh.ServerMessage
readC chan *extensions_ssh.ClientMessage
reauthC chan struct{}
state *StreamState
close func()
@ -119,6 +122,21 @@ func (sh *StreamHandler) WriteC() <-chan *extensions_ssh.ServerMessage {
return sh.writeC
}
// Reauth blocks until authorization policy is reevaluated.
func (sh *StreamHandler) Reauth() {
sh.reauthC <- struct{}{}
}
func (sh *StreamHandler) periodicReauth() (cancel func()) {
t := time.NewTicker(1 * time.Minute)
go func() {
for range t.C {
sh.Reauth()
}
}()
return t.Stop
}
// Prompt implements KeyboardInteractiveQuerier.
func (sh *StreamHandler) Prompt(ctx context.Context, prompts *extensions_ssh.KeyboardInteractiveInfoPrompts) (*extensions_ssh.KeyboardInteractiveInfoPromptResponses, error) {
sh.sendInfoPrompts(prompts)
@ -154,10 +172,16 @@ func (sh *StreamHandler) Run(ctx context.Context) error {
SourceAddress: sh.downstream.SourceAddress,
},
}
cancelReauth := sh.periodicReauth()
defer cancelReauth()
for {
select {
case <-ctx.Done():
return context.Cause(ctx)
case <-sh.reauthC:
if err := sh.reauth(ctx); err != nil {
return err
}
case req := <-sh.readC:
switch req := req.Message.(type) {
case *extensions_ssh.ClientMessage_Event:
@ -317,6 +341,7 @@ func (sh *StreamHandler) handleAuthRequest(ctx context.Context, req *extensions_
if len(sh.state.RemainingUnauthenticatedMethods) == 0 && sh.state.allMethodsValid() {
// if there are no methods remaining, the user is allowed if all attempted
// methods have a valid response in the state
sh.state.InitialAuthComplete = true
log.Ctx(ctx).Debug().Msg("ssh: all methods valid, sending allow response")
sh.sendAllowResponse()
} else {
@ -326,6 +351,13 @@ func (sh *StreamHandler) handleAuthRequest(ctx context.Context, req *extensions_
return nil
}
func (sh *StreamHandler) reauth(ctx context.Context) error {
if !sh.state.InitialAuthComplete {
return nil
}
return sh.auth.EvaluateDelayed(ctx, sh.state.StreamAuthInfo)
}
func (sh *StreamHandler) PrepareHandoff(ctx context.Context, hostname string, ptyInfo *extensions_ssh.SSHDownstreamPTYInfo) (*extensions_ssh.SSHChannelControlAction, error) {
if hostname == "" {
return nil, status.Errorf(codes.PermissionDenied, "invalid hostname")

View file

@ -100,7 +100,6 @@ type StreamHandlerSuite struct {
func (s *StreamHandlerSuite) SetupTest() {
s.ctrl = NewController(s.T())
s.mockAuth = mock_ssh.NewMockAuthInterface(s.ctrl)
s.mgr = ssh.NewStreamManager()
s.cleanup = []func(){}
s.errC = make(chan error, 1)
@ -123,6 +122,8 @@ func (s *StreamHandlerSuite) SetupTest() {
for _, f := range s.ConfigModifiers {
f(s.cfg)
}
s.mgr = ssh.NewStreamManager(context.Background(), s.mockAuth, s.cfg)
}
func (s *StreamHandlerSuite) TearDownTest() {
@ -162,8 +163,7 @@ func (s *StreamHandlerSuite) expectError(fn func(), msg string) {
}
func (s *StreamHandlerSuite) startStreamHandler(streamID uint64) *ssh.StreamHandler {
sh := s.mgr.NewStreamHandler(
s.cfg, s.mockAuth, &extensions_ssh.DownstreamConnectEvent{StreamId: streamID})
sh := s.mgr.NewStreamHandler(&extensions_ssh.DownstreamConnectEvent{StreamId: streamID})
s.errC = make(chan error, 1)
ctx, ca := context.WithCancel(s.T().Context())
go func() {
@ -1997,8 +1997,7 @@ func (s *StreamHandlerSuite) TestFormatSession() {
s.mockAuth.EXPECT().
FormatSession(Any(), Any()).
Return([]byte("example"), nil)
sh := s.mgr.NewStreamHandler(
s.cfg, s.mockAuth, &extensions_ssh.DownstreamConnectEvent{StreamId: 1})
sh := s.mgr.NewStreamHandler(&extensions_ssh.DownstreamConnectEvent{StreamId: 1})
ctx, ca := context.WithCancel(context.Background())
ca()
// this will exit immediately, but it will have a state, which is only
@ -2014,8 +2013,7 @@ func (s *StreamHandlerSuite) TestDeleteSession() {
s.mockAuth.EXPECT().
DeleteSession(Any(), Any()).
Return(nil)
sh := s.mgr.NewStreamHandler(
s.cfg, s.mockAuth, &extensions_ssh.DownstreamConnectEvent{StreamId: 1})
sh := s.mgr.NewStreamHandler(&extensions_ssh.DownstreamConnectEvent{StreamId: 1})
ctx, ca := context.WithCancel(context.Background())
ca()
// this will exit immediately, but it will have a state, which is only
@ -2027,8 +2025,7 @@ func (s *StreamHandlerSuite) TestDeleteSession() {
}
func (s *StreamHandlerSuite) TestRunCalledTwice() {
sh := s.mgr.NewStreamHandler(
s.cfg, s.mockAuth, &extensions_ssh.DownstreamConnectEvent{StreamId: 1})
sh := s.mgr.NewStreamHandler(&extensions_ssh.DownstreamConnectEvent{StreamId: 1})
ctx, ca := context.WithCancel(context.Background())
ca()
sh.Run(ctx)
@ -2038,8 +2035,7 @@ func (s *StreamHandlerSuite) TestRunCalledTwice() {
}
func (s *StreamHandlerSuite) TestAllSSHRoutes() {
sh := s.mgr.NewStreamHandler(
s.cfg, s.mockAuth, &extensions_ssh.DownstreamConnectEvent{StreamId: 1})
sh := s.mgr.NewStreamHandler(&extensions_ssh.DownstreamConnectEvent{StreamId: 1})
routes := slices.Collect(sh.AllSSHRoutes())
s.Len(routes, 2)
s.Equal("ssh://host1", routes[0].From)