mirror of
https://github.com/pomerium/pomerium.git
synced 2025-08-06 10:21:05 +02:00
ssh: continuous authorization (#5687)
Re-evaluate ssh authorization decision on a fixed interval, or whenever the config changes. If access is no longer allowed, log a new 'authorize check' message and disconnect. Refactor the ssh.StreamManager initialization so that its lifecycle matches the Authorize lifecycle.
This commit is contained in:
parent
31020a75a6
commit
177677f239
9 changed files with 170 additions and 98 deletions
|
@ -20,6 +20,7 @@ import (
|
|||
"github.com/pomerium/pomerium/internal/telemetry/metrics"
|
||||
"github.com/pomerium/pomerium/pkg/cryptutil"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||
"github.com/pomerium/pomerium/pkg/ssh"
|
||||
"github.com/pomerium/pomerium/pkg/telemetry/trace"
|
||||
)
|
||||
|
||||
|
@ -31,6 +32,7 @@ type Authorize struct {
|
|||
store *store.Store
|
||||
currentConfig *atomicutil.Value[*config.Config]
|
||||
accessTracker *AccessTracker
|
||||
ssh *ssh.StreamManager
|
||||
|
||||
tracerProvider oteltrace.TracerProvider
|
||||
tracer oteltrace.Tracer
|
||||
|
@ -52,6 +54,7 @@ func New(ctx context.Context, cfg *config.Config) (*Authorize, error) {
|
|||
tracer: tracer,
|
||||
}
|
||||
a.accessTracker = NewAccessTracker(a, accessTrackerMaxSize, accessTrackerDebouncePeriod)
|
||||
a.ssh = ssh.NewStreamManager(ctx, ssh.NewAuth(a, a.currentConfig, a.tracerProvider), cfg)
|
||||
|
||||
state, err := newAuthorizeStateFromConfig(ctx, nil, tracerProvider, cfg, a.store)
|
||||
if err != nil {
|
||||
|
@ -161,4 +164,5 @@ func (a *Authorize) OnConfigChange(ctx context.Context, cfg *config.Config) {
|
|||
} else {
|
||||
a.state.Store(newState)
|
||||
}
|
||||
a.ssh.OnConfigChange(cfg)
|
||||
}
|
||||
|
|
|
@ -27,12 +27,7 @@ func (a *Authorize) ManageStream(stream extensions_ssh.StreamManagement_ManageSt
|
|||
return status.Errorf(codes.Internal, "first message was not a downstream connected event")
|
||||
}
|
||||
|
||||
state := a.state.Load()
|
||||
handler := state.ssh.NewStreamHandler(
|
||||
a.currentConfig.Load(),
|
||||
ssh.NewAuth(a, state.dataBrokerClient, a.currentConfig, a.tracerProvider),
|
||||
downstream,
|
||||
)
|
||||
handler := a.ssh.NewStreamHandler(downstream)
|
||||
defer handler.Close()
|
||||
|
||||
eg, ctx := errgroup.WithContext(stream.Context())
|
||||
|
@ -85,7 +80,7 @@ func (a *Authorize) ServeChannel(stream extensions_ssh.StreamManagement_ServeCha
|
|||
} else {
|
||||
return status.Errorf(codes.Internal, "first message was not metadata")
|
||||
}
|
||||
handler := a.state.Load().ssh.LookupStream(streamID)
|
||||
handler := a.ssh.LookupStream(streamID)
|
||||
if handler == nil || !handler.IsExpectingInternalChannel() {
|
||||
return status.Errorf(codes.InvalidArgument, "stream not found")
|
||||
}
|
||||
|
@ -121,6 +116,8 @@ func (a *Authorize) EvaluateSSH(ctx context.Context, req *ssh.Request) (*evaluat
|
|||
return nil, err
|
||||
}
|
||||
|
||||
skipLogging := req.LogOnlyIfDenied && res.Allow.Value && !res.Deny.Value
|
||||
if !skipLogging {
|
||||
s, _ := a.getDataBrokerSessionOrServiceAccount(ctx, req.SessionID, 0)
|
||||
|
||||
var u *user.User
|
||||
|
@ -128,6 +125,7 @@ func (a *Authorize) EvaluateSSH(ctx context.Context, req *ssh.Request) (*evaluat
|
|||
u, _ = a.getDataBrokerUser(ctx, s.GetUserId())
|
||||
}
|
||||
a.logAuthorizeCheck(ctx, &evalreq, res, s, u)
|
||||
}
|
||||
|
||||
return res, nil
|
||||
}
|
||||
|
|
|
@ -20,7 +20,6 @@ import (
|
|||
"github.com/pomerium/pomerium/pkg/grpc/session"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/user"
|
||||
"github.com/pomerium/pomerium/pkg/grpcutil"
|
||||
"github.com/pomerium/pomerium/pkg/ssh"
|
||||
"github.com/pomerium/pomerium/pkg/storage"
|
||||
)
|
||||
|
||||
|
@ -40,7 +39,6 @@ type authorizeState struct {
|
|||
authenticateFlow authenticateFlow
|
||||
syncQueriers map[string]storage.Querier
|
||||
mcp *mcp.Handler
|
||||
ssh *ssh.StreamManager
|
||||
}
|
||||
|
||||
func newAuthorizeStateFromConfig(
|
||||
|
@ -72,8 +70,6 @@ func newAuthorizeStateFromConfig(
|
|||
evaluatorOptions = append(evaluatorOptions, evaluator.WithMCPAccessTokenProvider(mcp))
|
||||
}
|
||||
|
||||
state.ssh = ssh.NewStreamManager()
|
||||
|
||||
state.evaluator, err = newPolicyEvaluator(ctx, cfg.Options, store, previousEvaluator, evaluatorOptions...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("authorize: failed to update policy with options: %w", err)
|
||||
|
|
|
@ -30,8 +30,9 @@ import (
|
|||
"github.com/pomerium/pomerium/pkg/storage"
|
||||
)
|
||||
|
||||
type PolicyEvaluator interface {
|
||||
type Evaluator interface {
|
||||
EvaluateSSH(context.Context, *Request) (*evaluator.Result, error)
|
||||
GetDataBrokerServiceClient() databroker.DataBrokerServiceClient
|
||||
}
|
||||
|
||||
type Request struct {
|
||||
|
@ -39,22 +40,22 @@ type Request struct {
|
|||
Hostname string
|
||||
PublicKey []byte
|
||||
SessionID string
|
||||
|
||||
LogOnlyIfDenied bool
|
||||
}
|
||||
|
||||
type Auth struct {
|
||||
evaluator PolicyEvaluator
|
||||
dataBrokerClient databroker.DataBrokerServiceClient
|
||||
evaluator Evaluator
|
||||
currentConfig *atomicutil.Value[*config.Config]
|
||||
tracerProvider oteltrace.TracerProvider
|
||||
}
|
||||
|
||||
func NewAuth(
|
||||
evaluator PolicyEvaluator,
|
||||
client databroker.DataBrokerServiceClient,
|
||||
evaluator Evaluator,
|
||||
currentConfig *atomicutil.Value[*config.Config],
|
||||
tracerProvider oteltrace.TracerProvider,
|
||||
) *Auth {
|
||||
return &Auth{evaluator, client, currentConfig, tracerProvider}
|
||||
return &Auth{evaluator, currentConfig, tracerProvider}
|
||||
}
|
||||
|
||||
func (a *Auth) HandlePublicKeyMethodRequest(
|
||||
|
@ -93,7 +94,7 @@ func (a *Auth) handlePublicKeyMethodRequest(
|
|||
|
||||
// Special case: internal command (e.g. routes portal).
|
||||
if *info.Hostname == "" {
|
||||
_, err := session.Get(ctx, a.dataBrokerClient, sessionID)
|
||||
_, err := session.Get(ctx, a.evaluator.GetDataBrokerServiceClient(), sessionID)
|
||||
if status.Code(err) == codes.NotFound {
|
||||
// Require IdP login.
|
||||
return PublicKeyAuthMethodResponse{
|
||||
|
@ -229,7 +230,7 @@ func (a *Auth) handleLogin(
|
|||
return a.saveSession(ctx, sessionID, &sessionClaims, token)
|
||||
}
|
||||
|
||||
var errAccessDenied = errors.New("access denied")
|
||||
var errAccessDenied = status.Error(codes.PermissionDenied, "access denied")
|
||||
|
||||
func (a *Auth) EvaluateDelayed(ctx context.Context, info StreamAuthInfo) error {
|
||||
req, err := sshRequestFromStreamAuthInfo(info)
|
||||
|
@ -252,7 +253,7 @@ func (a *Auth) FormatSession(ctx context.Context, info StreamAuthInfo) ([]byte,
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
session, err := session.Get(ctx, a.dataBrokerClient, sessionID)
|
||||
session, err := session.Get(ctx, a.evaluator.GetDataBrokerServiceClient(), sessionID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -269,7 +270,7 @@ func (a *Auth) DeleteSession(ctx context.Context, info StreamAuthInfo) error {
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = session.Delete(ctx, a.dataBrokerClient, sessionID)
|
||||
err = session.Delete(ctx, a.evaluator.GetDataBrokerServiceClient(), sessionID)
|
||||
a.invalidateCacheForRecord(ctx, &databroker.Record{
|
||||
Type: "type.googleapis.com/session.Session",
|
||||
Id: sessionID,
|
||||
|
@ -304,7 +305,7 @@ func (a *Auth) saveSession(
|
|||
sess.SetRawIDToken(claims.RawIDToken)
|
||||
sess.AddClaims(claims.Flatten())
|
||||
|
||||
u, _ := user.Get(ctx, a.dataBrokerClient, sess.GetUserId())
|
||||
u, _ := user.Get(ctx, a.evaluator.GetDataBrokerServiceClient(), sess.GetUserId())
|
||||
if u == nil {
|
||||
// if no user exists yet, create a new one
|
||||
u = &user.User{
|
||||
|
@ -312,12 +313,12 @@ func (a *Auth) saveSession(
|
|||
}
|
||||
}
|
||||
u.PopulateFromClaims(claims.Claims)
|
||||
_, err := databroker.Put(ctx, a.dataBrokerClient, u)
|
||||
_, err := databroker.Put(ctx, a.evaluator.GetDataBrokerServiceClient(), u)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
resp, err := session.Put(ctx, a.dataBrokerClient, sess)
|
||||
resp, err := session.Put(ctx, a.evaluator.GetDataBrokerServiceClient(), sess)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -327,7 +328,9 @@ func (a *Auth) saveSession(
|
|||
|
||||
func (a *Auth) invalidateCacheForRecord(ctx context.Context, record *databroker.Record) {
|
||||
ctx = storage.WithQuerier(ctx,
|
||||
storage.NewCachingQuerier(storage.NewQuerier(a.dataBrokerClient), storage.GlobalCache))
|
||||
storage.NewCachingQuerier(
|
||||
storage.NewQuerier(a.evaluator.GetDataBrokerServiceClient()),
|
||||
storage.GlobalCache))
|
||||
storage.InvalidateCacheForDataBrokerRecords(ctx, record)
|
||||
}
|
||||
|
||||
|
@ -375,6 +378,8 @@ func sshRequestFromStreamAuthInfo(info StreamAuthInfo) (*Request, error) {
|
|||
Hostname: *info.Hostname,
|
||||
PublicKey: info.PublicKeyAllow.Value.PublicKey,
|
||||
SessionID: sessionID,
|
||||
|
||||
LogOnlyIfDenied: info.InitialAuthComplete,
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
|
|
@ -38,10 +38,10 @@ func TestHandlePublicKeyMethodRequest(t *testing.T) {
|
|||
}
|
||||
var req extensions_ssh.PublicKeyMethodRequest
|
||||
req.PublicKeyFingerprintSha256 = []byte("ABCDEFGHIJKLMNOPQRSTUVWXYZ123456")
|
||||
pe := policyEvaluatorFunc(func(context.Context, *Request) (*evaluator.Result, error) {
|
||||
pe := func(context.Context, *Request) (*evaluator.Result, error) {
|
||||
return nil, errors.New("error evaluating policy")
|
||||
})
|
||||
a := NewAuth(pe, nil, nil, nil)
|
||||
}
|
||||
a := NewAuth(fakePolicyEvaluator{evaluateSSH: pe}, nil, nil)
|
||||
_, err := a.handlePublicKeyMethodRequest(t.Context(), info, &req)
|
||||
assert.ErrorContains(t, err, "error evaluating policy")
|
||||
})
|
||||
|
@ -54,7 +54,7 @@ func TestHandlePublicKeyMethodRequest(t *testing.T) {
|
|||
req.PublicKeyFingerprintSha256 = []byte("ABCDEFGHIJKLMNOPQRSTUVWXYZ123456")
|
||||
fakePublicKey := []byte("fake-public-key")
|
||||
req.PublicKey = fakePublicKey
|
||||
pe := policyEvaluatorFunc(func(_ context.Context, r *Request) (*evaluator.Result, error) {
|
||||
pe := func(_ context.Context, r *Request) (*evaluator.Result, error) {
|
||||
assert.Equal(t, r, &Request{
|
||||
Username: "username",
|
||||
Hostname: "hostname",
|
||||
|
@ -65,8 +65,8 @@ func TestHandlePublicKeyMethodRequest(t *testing.T) {
|
|||
Allow: evaluator.NewRuleResult(true),
|
||||
Deny: evaluator.NewRuleResult(false),
|
||||
}, nil
|
||||
})
|
||||
a := NewAuth(pe, nil, nil, nil)
|
||||
}
|
||||
a := NewAuth(fakePolicyEvaluator{evaluateSSH: pe}, nil, nil)
|
||||
res, err := a.HandlePublicKeyMethodRequest(t.Context(), info, &req)
|
||||
assert.NoError(t, err)
|
||||
assert.Empty(t, res.RequireAdditionalMethods)
|
||||
|
@ -80,13 +80,13 @@ func TestHandlePublicKeyMethodRequest(t *testing.T) {
|
|||
}
|
||||
var req extensions_ssh.PublicKeyMethodRequest
|
||||
req.PublicKeyFingerprintSha256 = []byte("ABCDEFGHIJKLMNOPQRSTUVWXYZ123456")
|
||||
pe := policyEvaluatorFunc(func(_ context.Context, _ *Request) (*evaluator.Result, error) {
|
||||
pe := func(_ context.Context, _ *Request) (*evaluator.Result, error) {
|
||||
return &evaluator.Result{
|
||||
Allow: evaluator.NewRuleResult(true),
|
||||
Deny: evaluator.NewRuleResult(true),
|
||||
}, nil
|
||||
})
|
||||
a := NewAuth(pe, nil, nil, nil)
|
||||
}
|
||||
a := NewAuth(fakePolicyEvaluator{evaluateSSH: pe}, nil, nil)
|
||||
res, err := a.HandlePublicKeyMethodRequest(t.Context(), info, &req)
|
||||
assert.NoError(t, err)
|
||||
assert.Nil(t, res.Allow)
|
||||
|
@ -99,13 +99,13 @@ func TestHandlePublicKeyMethodRequest(t *testing.T) {
|
|||
}
|
||||
var req extensions_ssh.PublicKeyMethodRequest
|
||||
req.PublicKeyFingerprintSha256 = []byte("ABCDEFGHIJKLMNOPQRSTUVWXYZ123456")
|
||||
pe := policyEvaluatorFunc(func(_ context.Context, _ *Request) (*evaluator.Result, error) {
|
||||
pe := func(_ context.Context, _ *Request) (*evaluator.Result, error) {
|
||||
return &evaluator.Result{
|
||||
Allow: evaluator.NewRuleResult(false, criteria.ReasonSSHPublickeyUnauthorized),
|
||||
Deny: evaluator.NewRuleResult(false),
|
||||
}, nil
|
||||
})
|
||||
a := NewAuth(pe, nil, nil, nil)
|
||||
}
|
||||
a := NewAuth(fakePolicyEvaluator{evaluateSSH: pe}, nil, nil)
|
||||
res, err := a.HandlePublicKeyMethodRequest(t.Context(), info, &req)
|
||||
assert.NoError(t, err)
|
||||
assert.Nil(t, res.Allow)
|
||||
|
@ -118,13 +118,13 @@ func TestHandlePublicKeyMethodRequest(t *testing.T) {
|
|||
}
|
||||
var req extensions_ssh.PublicKeyMethodRequest
|
||||
req.PublicKeyFingerprintSha256 = []byte("ABCDEFGHIJKLMNOPQRSTUVWXYZ123456")
|
||||
pe := policyEvaluatorFunc(func(_ context.Context, _ *Request) (*evaluator.Result, error) {
|
||||
pe := func(_ context.Context, _ *Request) (*evaluator.Result, error) {
|
||||
return &evaluator.Result{
|
||||
Allow: evaluator.NewRuleResult(false),
|
||||
Deny: evaluator.NewRuleResult(false, criteria.ReasonUserUnauthenticated),
|
||||
}, nil
|
||||
})
|
||||
a := NewAuth(pe, nil, nil, nil)
|
||||
}
|
||||
a := NewAuth(fakePolicyEvaluator{evaluateSSH: pe}, nil, nil)
|
||||
res, err := a.HandlePublicKeyMethodRequest(t.Context(), info, &req)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, res.Allow)
|
||||
|
@ -144,13 +144,13 @@ func TestHandlePublicKeyMethodRequest(t *testing.T) {
|
|||
}
|
||||
var req extensions_ssh.PublicKeyMethodRequest
|
||||
req.PublicKeyFingerprintSha256 = []byte("ABCDEFGHIJKLMNOPQRSTUVWXYZ123456")
|
||||
pe := policyEvaluatorFunc(func(_ context.Context, _ *Request) (*evaluator.Result, error) {
|
||||
pe := func(_ context.Context, _ *Request) (*evaluator.Result, error) {
|
||||
return &evaluator.Result{
|
||||
Allow: evaluator.NewRuleResult(false),
|
||||
Deny: evaluator.NewRuleResult(false, criteria.ReasonUserUnauthenticated),
|
||||
}, nil
|
||||
})
|
||||
a := NewAuth(pe, client, nil, nil)
|
||||
}
|
||||
a := NewAuth(fakePolicyEvaluator{pe, client}, nil, nil)
|
||||
res, err := a.HandlePublicKeyMethodRequest(t.Context(), info, &req)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, res.Allow)
|
||||
|
@ -179,13 +179,13 @@ func TestHandlePublicKeyMethodRequest(t *testing.T) {
|
|||
}
|
||||
var req extensions_ssh.PublicKeyMethodRequest
|
||||
req.PublicKeyFingerprintSha256 = []byte("ABCDEFGHIJKLMNOPQRSTUVWXYZ123456")
|
||||
pe := policyEvaluatorFunc(func(_ context.Context, _ *Request) (*evaluator.Result, error) {
|
||||
pe := func(_ context.Context, _ *Request) (*evaluator.Result, error) {
|
||||
return &evaluator.Result{
|
||||
Allow: evaluator.NewRuleResult(true),
|
||||
Deny: evaluator.NewRuleResult(false),
|
||||
}, nil
|
||||
})
|
||||
a := NewAuth(pe, client, nil, nil)
|
||||
}
|
||||
a := NewAuth(fakePolicyEvaluator{pe, client}, nil, nil)
|
||||
res, err := a.HandlePublicKeyMethodRequest(t.Context(), info, &req)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, res.Allow)
|
||||
|
@ -205,13 +205,13 @@ func TestHandlePublicKeyMethodRequest(t *testing.T) {
|
|||
}
|
||||
var req extensions_ssh.PublicKeyMethodRequest
|
||||
req.PublicKeyFingerprintSha256 = []byte("ABCDEFGHIJKLMNOPQRSTUVWXYZ123456")
|
||||
pe := policyEvaluatorFunc(func(_ context.Context, _ *Request) (*evaluator.Result, error) {
|
||||
pe := func(_ context.Context, _ *Request) (*evaluator.Result, error) {
|
||||
return &evaluator.Result{
|
||||
Allow: evaluator.NewRuleResult(true),
|
||||
Deny: evaluator.NewRuleResult(false),
|
||||
}, nil
|
||||
})
|
||||
a := NewAuth(pe, client, nil, nil)
|
||||
}
|
||||
a := NewAuth(fakePolicyEvaluator{pe, client}, nil, nil)
|
||||
_, err := a.HandlePublicKeyMethodRequest(t.Context(), info, &req)
|
||||
assert.ErrorContains(t, err, "internal error")
|
||||
})
|
||||
|
@ -224,12 +224,12 @@ func TestHandleKeyboardInteractiveMethodRequest(t *testing.T) {
|
|||
assert.ErrorContains(t, err, "expected PublicKeyAllow message not to be nil")
|
||||
})
|
||||
t.Run("ok", func(t *testing.T) {
|
||||
pe := policyEvaluatorFunc(func(_ context.Context, _ *Request) (*evaluator.Result, error) {
|
||||
pe := func(_ context.Context, _ *Request) (*evaluator.Result, error) {
|
||||
return &evaluator.Result{
|
||||
Allow: evaluator.NewRuleResult(true),
|
||||
Deny: evaluator.NewRuleResult(false),
|
||||
}, nil
|
||||
})
|
||||
}
|
||||
var putRecords []*databroker.Record
|
||||
client := fakeDataBrokerServiceClient{
|
||||
get: func(
|
||||
|
@ -255,7 +255,7 @@ func TestHandleKeyboardInteractiveMethodRequest(t *testing.T) {
|
|||
cfg.Options.ProviderURL = idpURL
|
||||
cfg.Options.ClientID = "client-id"
|
||||
cfg.Options.ClientSecret = "client-secret"
|
||||
a := NewAuth(pe, client, atomicutil.NewValue(&cfg), nil)
|
||||
a := NewAuth(fakePolicyEvaluator{pe, client}, atomicutil.NewValue(&cfg), nil)
|
||||
info := StreamAuthInfo{
|
||||
Username: ptr("username"),
|
||||
Hostname: ptr("hostname"),
|
||||
|
@ -281,12 +281,12 @@ func TestHandleKeyboardInteractiveMethodRequest(t *testing.T) {
|
|||
assert.Equal(t, "sshkey-SHA256:QUJDREVGR0hJSktMTU5PUFFSU1RVVldYWVoxMjM0NTY=", putRecords[1].Id)
|
||||
})
|
||||
t.Run("denied", func(t *testing.T) {
|
||||
pe := policyEvaluatorFunc(func(_ context.Context, _ *Request) (*evaluator.Result, error) {
|
||||
pe := func(_ context.Context, _ *Request) (*evaluator.Result, error) {
|
||||
return &evaluator.Result{
|
||||
Allow: evaluator.NewRuleResult(false),
|
||||
Deny: evaluator.NewRuleResult(false),
|
||||
}, nil
|
||||
})
|
||||
}
|
||||
client := fakeDataBrokerServiceClient{
|
||||
get: func(
|
||||
_ context.Context, _ *databroker.GetRequest, _ ...grpc.CallOption,
|
||||
|
@ -310,7 +310,7 @@ func TestHandleKeyboardInteractiveMethodRequest(t *testing.T) {
|
|||
cfg.Options.ProviderURL = idpURL
|
||||
cfg.Options.ClientID = "client-id"
|
||||
cfg.Options.ClientSecret = "client-secret"
|
||||
a := NewAuth(pe, client, atomicutil.NewValue(&cfg), nil)
|
||||
a := NewAuth(fakePolicyEvaluator{pe, client}, atomicutil.NewValue(&cfg), nil)
|
||||
info := StreamAuthInfo{
|
||||
Username: ptr("username"),
|
||||
Hostname: ptr("hostname"),
|
||||
|
@ -336,7 +336,7 @@ func TestHandleKeyboardInteractiveMethodRequest(t *testing.T) {
|
|||
cfg.Options.ProviderURL = idpURL
|
||||
cfg.Options.ClientID = "client-id"
|
||||
cfg.Options.ClientSecret = "client-secret"
|
||||
a := NewAuth(nil, nil, atomicutil.NewValue(&cfg), nil)
|
||||
a := NewAuth(nil, atomicutil.NewValue(&cfg), nil)
|
||||
info := StreamAuthInfo{
|
||||
Username: ptr("username"),
|
||||
Hostname: ptr("hostname"),
|
||||
|
@ -386,7 +386,7 @@ func TestFormatSession(t *testing.T) {
|
|||
}, nil
|
||||
},
|
||||
}
|
||||
a := NewAuth(nil, client, nil, nil)
|
||||
a := NewAuth(fakePolicyEvaluator{client: client}, nil, nil)
|
||||
info := StreamAuthInfo{
|
||||
PublicKeyFingerprintSha256: []byte("ABCDEFGHIJKLMNOPQRSTUVWXYZ123456"),
|
||||
}
|
||||
|
@ -424,7 +424,7 @@ func TestDeleteSession(t *testing.T) {
|
|||
return nil, putError
|
||||
},
|
||||
}
|
||||
a := NewAuth(nil, client, nil, nil)
|
||||
a := NewAuth(fakePolicyEvaluator{client: client}, nil, nil)
|
||||
info := StreamAuthInfo{
|
||||
PublicKeyFingerprintSha256: []byte("ABCDEFGHIJKLMNOPQRSTUVWXYZ123456"),
|
||||
}
|
||||
|
@ -433,12 +433,17 @@ func TestDeleteSession(t *testing.T) {
|
|||
})
|
||||
}
|
||||
|
||||
type policyEvaluatorFunc func(context.Context, *Request) (*evaluator.Result, error)
|
||||
type fakePolicyEvaluator struct {
|
||||
evaluateSSH func(context.Context, *Request) (*evaluator.Result, error)
|
||||
client databroker.DataBrokerServiceClient
|
||||
}
|
||||
|
||||
func (f policyEvaluatorFunc) EvaluateSSH(
|
||||
ctx context.Context, req *Request,
|
||||
) (*evaluator.Result, error) {
|
||||
return f(ctx, req)
|
||||
func (f fakePolicyEvaluator) EvaluateSSH(ctx context.Context, req *Request) (*evaluator.Result, error) {
|
||||
return f.evaluateSSH(ctx, req)
|
||||
}
|
||||
|
||||
func (f fakePolicyEvaluator) GetDataBrokerServiceClient() databroker.DataBrokerServiceClient {
|
||||
return f.client
|
||||
}
|
||||
|
||||
type fakeDataBrokerServiceClient struct {
|
||||
|
@ -448,12 +453,12 @@ type fakeDataBrokerServiceClient struct {
|
|||
put func(ctx context.Context, in *databroker.PutRequest, opts ...grpc.CallOption) (*databroker.PutResponse, error)
|
||||
}
|
||||
|
||||
func (m fakeDataBrokerServiceClient) Get(ctx context.Context, in *databroker.GetRequest, opts ...grpc.CallOption) (*databroker.GetResponse, error) {
|
||||
return m.get(ctx, in, opts...)
|
||||
func (f fakeDataBrokerServiceClient) Get(ctx context.Context, in *databroker.GetRequest, opts ...grpc.CallOption) (*databroker.GetResponse, error) {
|
||||
return f.get(ctx, in, opts...)
|
||||
}
|
||||
|
||||
func (m fakeDataBrokerServiceClient) Put(ctx context.Context, in *databroker.PutRequest, opts ...grpc.CallOption) (*databroker.PutResponse, error) {
|
||||
return m.put(ctx, in, opts...)
|
||||
func (f fakeDataBrokerServiceClient) Put(ctx context.Context, in *databroker.PutRequest, opts ...grpc.CallOption) (*databroker.PutResponse, error) {
|
||||
return f.put(ctx, in, opts...)
|
||||
}
|
||||
|
||||
type noopQuerier struct{}
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package ssh
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
|
||||
extensions_ssh "github.com/pomerium/envoy-custom/api/extensions/filters/network/ssh"
|
||||
|
@ -8,29 +9,43 @@ import (
|
|||
)
|
||||
|
||||
type StreamManager struct {
|
||||
auth AuthInterface
|
||||
reauthC chan struct{}
|
||||
|
||||
mu sync.Mutex
|
||||
cfg *config.Config
|
||||
activeStreams map[uint64]*StreamHandler
|
||||
}
|
||||
|
||||
func NewStreamManager() *StreamManager {
|
||||
return &StreamManager{
|
||||
func NewStreamManager(ctx context.Context, auth AuthInterface, cfg *config.Config) *StreamManager {
|
||||
sm := &StreamManager{
|
||||
auth: auth,
|
||||
reauthC: make(chan struct{}, 1),
|
||||
cfg: cfg,
|
||||
activeStreams: map[uint64]*StreamHandler{},
|
||||
}
|
||||
go sm.reauthLoop(ctx)
|
||||
return sm
|
||||
}
|
||||
|
||||
func (sm *StreamManager) OnConfigChange(cfg *config.Config) {
|
||||
sm.mu.Lock()
|
||||
sm.cfg = cfg
|
||||
sm.mu.Unlock()
|
||||
|
||||
select {
|
||||
case sm.reauthC <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
func (sm *StreamManager) LookupStream(streamID uint64) *StreamHandler {
|
||||
sm.mu.Lock()
|
||||
defer sm.mu.Unlock()
|
||||
stream := sm.activeStreams[streamID]
|
||||
if stream == nil {
|
||||
return nil
|
||||
}
|
||||
return stream
|
||||
return sm.activeStreams[streamID]
|
||||
}
|
||||
|
||||
func (sm *StreamManager) NewStreamHandler(
|
||||
cfg *config.Config,
|
||||
auth AuthInterface,
|
||||
downstream *extensions_ssh.DownstreamConnectEvent,
|
||||
) *StreamHandler {
|
||||
sm.mu.Lock()
|
||||
|
@ -38,11 +53,12 @@ func (sm *StreamManager) NewStreamHandler(
|
|||
streamID := downstream.StreamId
|
||||
writeC := make(chan *extensions_ssh.ServerMessage, 32)
|
||||
sh := &StreamHandler{
|
||||
auth: auth,
|
||||
config: cfg,
|
||||
auth: sm.auth,
|
||||
config: sm.cfg,
|
||||
downstream: downstream,
|
||||
readC: make(chan *extensions_ssh.ClientMessage, 32),
|
||||
writeC: writeC,
|
||||
reauthC: make(chan struct{}),
|
||||
close: func() {
|
||||
sm.mu.Lock()
|
||||
defer sm.mu.Unlock()
|
||||
|
@ -53,3 +69,23 @@ func (sm *StreamManager) NewStreamHandler(
|
|||
sm.activeStreams[streamID] = sh
|
||||
return sh
|
||||
}
|
||||
|
||||
func (sm *StreamManager) reauthLoop(ctx context.Context) {
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-sm.reauthC:
|
||||
sm.mu.Lock()
|
||||
snapshot := make([]*StreamHandler, 0, len(sm.activeStreams))
|
||||
for _, s := range sm.activeStreams {
|
||||
snapshot = append(snapshot, s)
|
||||
}
|
||||
sm.mu.Unlock()
|
||||
|
||||
for _, s := range snapshot {
|
||||
s.Reauth()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -22,17 +22,17 @@ func mustParseWeightedURLs(t *testing.T, urls ...string) []config.WeightedURL {
|
|||
func TestStreamManager(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
auth := mock_ssh.NewMockAuthInterface(ctrl)
|
||||
m := ssh.NewStreamManager()
|
||||
|
||||
cfg := &config.Config{Options: config.NewDefaultOptions()}
|
||||
cfg.Options.Policies = []config.Policy{
|
||||
{From: "ssh://host1", To: mustParseWeightedURLs(t, "ssh://dest1:22")},
|
||||
{From: "ssh://host2", To: mustParseWeightedURLs(t, "ssh://dest2:22")},
|
||||
}
|
||||
m := ssh.NewStreamManager(t.Context(), auth, cfg)
|
||||
|
||||
t.Run("LookupStream", func(t *testing.T) {
|
||||
assert.Nil(t, m.LookupStream(1234))
|
||||
sh := m.NewStreamHandler(cfg, auth, &extensions_ssh.DownstreamConnectEvent{StreamId: 1234})
|
||||
sh := m.NewStreamHandler(&extensions_ssh.DownstreamConnectEvent{StreamId: 1234})
|
||||
assert.Equal(t, sh, m.LookupStream(1234))
|
||||
sh.Close()
|
||||
assert.Nil(t, m.LookupStream(1234))
|
||||
|
|
|
@ -3,6 +3,7 @@ package ssh
|
|||
import (
|
||||
"context"
|
||||
"iter"
|
||||
"time"
|
||||
|
||||
corev3 "github.com/envoyproxy/go-control-plane/envoy/config/core/v3"
|
||||
gossh "golang.org/x/crypto/ssh"
|
||||
|
@ -73,6 +74,7 @@ type StreamAuthInfo struct {
|
|||
PublicKeyFingerprintSha256 []byte
|
||||
PublicKeyAllow AuthMethodValue[extensions_ssh.PublicKeyAllowResponse]
|
||||
KeyboardInteractiveAllow AuthMethodValue[extensions_ssh.KeyboardInteractiveAllowResponse]
|
||||
InitialAuthComplete bool
|
||||
}
|
||||
|
||||
func (i *StreamAuthInfo) allMethodsValid() bool {
|
||||
|
@ -93,6 +95,7 @@ type StreamHandler struct {
|
|||
downstream *extensions_ssh.DownstreamConnectEvent
|
||||
writeC chan *extensions_ssh.ServerMessage
|
||||
readC chan *extensions_ssh.ClientMessage
|
||||
reauthC chan struct{}
|
||||
|
||||
state *StreamState
|
||||
close func()
|
||||
|
@ -119,6 +122,21 @@ func (sh *StreamHandler) WriteC() <-chan *extensions_ssh.ServerMessage {
|
|||
return sh.writeC
|
||||
}
|
||||
|
||||
// Reauth blocks until authorization policy is reevaluated.
|
||||
func (sh *StreamHandler) Reauth() {
|
||||
sh.reauthC <- struct{}{}
|
||||
}
|
||||
|
||||
func (sh *StreamHandler) periodicReauth() (cancel func()) {
|
||||
t := time.NewTicker(1 * time.Minute)
|
||||
go func() {
|
||||
for range t.C {
|
||||
sh.Reauth()
|
||||
}
|
||||
}()
|
||||
return t.Stop
|
||||
}
|
||||
|
||||
// Prompt implements KeyboardInteractiveQuerier.
|
||||
func (sh *StreamHandler) Prompt(ctx context.Context, prompts *extensions_ssh.KeyboardInteractiveInfoPrompts) (*extensions_ssh.KeyboardInteractiveInfoPromptResponses, error) {
|
||||
sh.sendInfoPrompts(prompts)
|
||||
|
@ -154,10 +172,16 @@ func (sh *StreamHandler) Run(ctx context.Context) error {
|
|||
SourceAddress: sh.downstream.SourceAddress,
|
||||
},
|
||||
}
|
||||
cancelReauth := sh.periodicReauth()
|
||||
defer cancelReauth()
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return context.Cause(ctx)
|
||||
case <-sh.reauthC:
|
||||
if err := sh.reauth(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
case req := <-sh.readC:
|
||||
switch req := req.Message.(type) {
|
||||
case *extensions_ssh.ClientMessage_Event:
|
||||
|
@ -317,6 +341,7 @@ func (sh *StreamHandler) handleAuthRequest(ctx context.Context, req *extensions_
|
|||
if len(sh.state.RemainingUnauthenticatedMethods) == 0 && sh.state.allMethodsValid() {
|
||||
// if there are no methods remaining, the user is allowed if all attempted
|
||||
// methods have a valid response in the state
|
||||
sh.state.InitialAuthComplete = true
|
||||
log.Ctx(ctx).Debug().Msg("ssh: all methods valid, sending allow response")
|
||||
sh.sendAllowResponse()
|
||||
} else {
|
||||
|
@ -326,6 +351,13 @@ func (sh *StreamHandler) handleAuthRequest(ctx context.Context, req *extensions_
|
|||
return nil
|
||||
}
|
||||
|
||||
func (sh *StreamHandler) reauth(ctx context.Context) error {
|
||||
if !sh.state.InitialAuthComplete {
|
||||
return nil
|
||||
}
|
||||
return sh.auth.EvaluateDelayed(ctx, sh.state.StreamAuthInfo)
|
||||
}
|
||||
|
||||
func (sh *StreamHandler) PrepareHandoff(ctx context.Context, hostname string, ptyInfo *extensions_ssh.SSHDownstreamPTYInfo) (*extensions_ssh.SSHChannelControlAction, error) {
|
||||
if hostname == "" {
|
||||
return nil, status.Errorf(codes.PermissionDenied, "invalid hostname")
|
||||
|
|
|
@ -100,7 +100,6 @@ type StreamHandlerSuite struct {
|
|||
func (s *StreamHandlerSuite) SetupTest() {
|
||||
s.ctrl = NewController(s.T())
|
||||
s.mockAuth = mock_ssh.NewMockAuthInterface(s.ctrl)
|
||||
s.mgr = ssh.NewStreamManager()
|
||||
s.cleanup = []func(){}
|
||||
s.errC = make(chan error, 1)
|
||||
|
||||
|
@ -123,6 +122,8 @@ func (s *StreamHandlerSuite) SetupTest() {
|
|||
for _, f := range s.ConfigModifiers {
|
||||
f(s.cfg)
|
||||
}
|
||||
|
||||
s.mgr = ssh.NewStreamManager(context.Background(), s.mockAuth, s.cfg)
|
||||
}
|
||||
|
||||
func (s *StreamHandlerSuite) TearDownTest() {
|
||||
|
@ -162,8 +163,7 @@ func (s *StreamHandlerSuite) expectError(fn func(), msg string) {
|
|||
}
|
||||
|
||||
func (s *StreamHandlerSuite) startStreamHandler(streamID uint64) *ssh.StreamHandler {
|
||||
sh := s.mgr.NewStreamHandler(
|
||||
s.cfg, s.mockAuth, &extensions_ssh.DownstreamConnectEvent{StreamId: streamID})
|
||||
sh := s.mgr.NewStreamHandler(&extensions_ssh.DownstreamConnectEvent{StreamId: streamID})
|
||||
s.errC = make(chan error, 1)
|
||||
ctx, ca := context.WithCancel(s.T().Context())
|
||||
go func() {
|
||||
|
@ -1997,8 +1997,7 @@ func (s *StreamHandlerSuite) TestFormatSession() {
|
|||
s.mockAuth.EXPECT().
|
||||
FormatSession(Any(), Any()).
|
||||
Return([]byte("example"), nil)
|
||||
sh := s.mgr.NewStreamHandler(
|
||||
s.cfg, s.mockAuth, &extensions_ssh.DownstreamConnectEvent{StreamId: 1})
|
||||
sh := s.mgr.NewStreamHandler(&extensions_ssh.DownstreamConnectEvent{StreamId: 1})
|
||||
ctx, ca := context.WithCancel(context.Background())
|
||||
ca()
|
||||
// this will exit immediately, but it will have a state, which is only
|
||||
|
@ -2014,8 +2013,7 @@ func (s *StreamHandlerSuite) TestDeleteSession() {
|
|||
s.mockAuth.EXPECT().
|
||||
DeleteSession(Any(), Any()).
|
||||
Return(nil)
|
||||
sh := s.mgr.NewStreamHandler(
|
||||
s.cfg, s.mockAuth, &extensions_ssh.DownstreamConnectEvent{StreamId: 1})
|
||||
sh := s.mgr.NewStreamHandler(&extensions_ssh.DownstreamConnectEvent{StreamId: 1})
|
||||
ctx, ca := context.WithCancel(context.Background())
|
||||
ca()
|
||||
// this will exit immediately, but it will have a state, which is only
|
||||
|
@ -2027,8 +2025,7 @@ func (s *StreamHandlerSuite) TestDeleteSession() {
|
|||
}
|
||||
|
||||
func (s *StreamHandlerSuite) TestRunCalledTwice() {
|
||||
sh := s.mgr.NewStreamHandler(
|
||||
s.cfg, s.mockAuth, &extensions_ssh.DownstreamConnectEvent{StreamId: 1})
|
||||
sh := s.mgr.NewStreamHandler(&extensions_ssh.DownstreamConnectEvent{StreamId: 1})
|
||||
ctx, ca := context.WithCancel(context.Background())
|
||||
ca()
|
||||
sh.Run(ctx)
|
||||
|
@ -2038,8 +2035,7 @@ func (s *StreamHandlerSuite) TestRunCalledTwice() {
|
|||
}
|
||||
|
||||
func (s *StreamHandlerSuite) TestAllSSHRoutes() {
|
||||
sh := s.mgr.NewStreamHandler(
|
||||
s.cfg, s.mockAuth, &extensions_ssh.DownstreamConnectEvent{StreamId: 1})
|
||||
sh := s.mgr.NewStreamHandler(&extensions_ssh.DownstreamConnectEvent{StreamId: 1})
|
||||
routes := slices.Collect(sh.AllSSHRoutes())
|
||||
s.Len(routes, 2)
|
||||
s.Equal("ssh://host1", routes[0].From)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue