ssh: implement authorization policy evaluation (#5665)

Implement the pkg/ssh.AuthInterface. Add logic for converting from the
ssh stream state to an evaluator request, and for interpreting the
results of policy evaluation. Refactor some of the existing authorize
logic to make it easier to reuse.
This commit is contained in:
Kenneth Jenkins 2025-07-01 12:04:00 -07:00 committed by GitHub
parent 9437cec21d
commit 9678e6a231
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
20 changed files with 1013 additions and 74 deletions

View file

@ -6,39 +6,19 @@ import (
oteltrace "go.opentelemetry.io/otel/trace"
"github.com/pomerium/pomerium/config"
"github.com/pomerium/pomerium/internal/urlutil"
"github.com/pomerium/pomerium/pkg/identity"
"github.com/pomerium/pomerium/pkg/identity/oauth"
)
func defaultGetIdentityProvider(ctx context.Context, tracerProvider oteltrace.TracerProvider, options *config.Options, idpID string) (identity.Authenticator, error) {
authenticateURL, err := options.GetAuthenticateURL()
redirectURL, err := options.GetAuthenticateRedirectURL()
if err != nil {
return nil, err
}
redirectURL, err := urlutil.DeepCopy(authenticateURL)
if err != nil {
return nil, err
}
redirectURL.Path = options.AuthenticateCallbackPath
idp, err := options.GetIdentityProviderForID(idpID)
if err != nil {
return nil, err
}
o := oauth.Options{
RedirectURL: redirectURL,
ProviderName: idp.GetType(),
ProviderURL: idp.GetUrl(),
ClientID: idp.GetClientId(),
ClientSecret: idp.GetClientSecret(),
Scopes: idp.GetScopes(),
AuthCodeOptions: idp.GetRequestParams(),
}
if v := idp.GetAccessTokenAllowedAudiences(); v != nil {
o.AccessTokenAllowedAudiences = &v.Values
}
return identity.NewAuthenticator(ctx, tracerProvider, o)
return identity.GetIdentityProvider(ctx, tracerProvider, idp, redirectURL)
}

View file

@ -44,15 +44,13 @@ func (a *Authorize) handleResult(
// when the user is unauthenticated it means they haven't
// logged in yet, so redirect to authenticate
if result.Allow.Reasons.Has(criteria.ReasonUserUnauthenticated) ||
result.Deny.Reasons.Has(criteria.ReasonUserUnauthenticated) {
if result.HasReason(criteria.ReasonUserUnauthenticated) {
return a.requireLoginResponse(ctx, in, request)
}
// when the user's device is unauthenticated it means they haven't
// registered a webauthn device yet, so redirect to the webauthn flow
if result.Allow.Reasons.Has(criteria.ReasonDeviceUnauthenticated) ||
result.Deny.Reasons.Has(criteria.ReasonDeviceUnauthenticated) {
if result.HasReason(criteria.ReasonDeviceUnauthenticated) {
return a.requireWebAuthnResponse(ctx, in, request, result)
}

View file

@ -149,6 +149,10 @@ type Result struct {
AdditionalLogFields map[log.AuthorizeLogField]any
}
func (r *Result) HasReason(reason criteria.Reason) bool {
return r.Allow.Reasons.Has(reason) || r.Deny.Reasons.Has(reason)
}
// An Evaluator evaluates policies.
type Evaluator struct {
evaluationCount, allowCount, denyCount metric.Int64Counter

View file

@ -1,6 +1,7 @@
package authorize
import (
"context"
"errors"
"io"
@ -9,7 +10,10 @@ import (
"google.golang.org/grpc/status"
extensions_ssh "github.com/pomerium/envoy-custom/api/extensions/filters/network/ssh"
"github.com/pomerium/pomerium/pkg/storage"
"github.com/pomerium/pomerium/authorize/evaluator"
"github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/pkg/grpc/user"
"github.com/pomerium/pomerium/pkg/ssh"
)
func (a *Authorize) ManageStream(stream extensions_ssh.StreamManagement_ManageStreamServer) error {
@ -22,15 +26,16 @@ func (a *Authorize) ManageStream(stream extensions_ssh.StreamManagement_ManageSt
if downstream == nil {
return status.Errorf(codes.Internal, "first message was not a downstream connected event")
}
handler := a.state.Load().ssh.NewStreamHandler(a.currentConfig.Load(), downstream)
state := a.state.Load()
handler := state.ssh.NewStreamHandler(
a.currentConfig.Load(),
ssh.NewAuth(a, state.dataBrokerClient, a.currentConfig, a.tracerProvider),
downstream,
)
defer handler.Close()
eg, ctx := errgroup.WithContext(stream.Context())
querier := storage.NewCachingQuerier(
storage.NewQuerier(a.state.Load().dataBrokerClient),
storage.GlobalCache,
)
ctx = storage.WithQuerier(ctx, querier)
eg.Go(func() error {
for {
@ -87,3 +92,42 @@ func (a *Authorize) ServeChannel(stream extensions_ssh.StreamManagement_ServeCha
return handler.ServeChannel(stream)
}
func (a *Authorize) EvaluateSSH(ctx context.Context, req *ssh.Request) (*evaluator.Result, error) {
ctx = a.withQuerierForCheckRequest(ctx)
evalreq := evaluator.Request{
HTTP: evaluator.RequestHTTP{
Hostname: req.Hostname,
},
SSH: evaluator.RequestSSH{
Username: req.Username,
PublicKey: req.PublicKey,
},
Session: evaluator.RequestSession{
ID: req.SessionID,
},
}
if req.Hostname == "" {
evalreq.IsInternal = true
} else {
evalreq.Policy = a.currentConfig.Load().Options.GetRouteForSSHHostname(req.Hostname)
}
res, err := a.state.Load().evaluator.Evaluate(ctx, &evalreq)
if err != nil {
log.Ctx(ctx).Error().Err(err).Msg("error during OPA evaluation")
return nil, err
}
s, _ := a.getDataBrokerSessionOrServiceAccount(ctx, req.SessionID, 0)
var u *user.User
if s != nil {
u, _ = a.getDataBrokerUser(ctx, s.GetUserId())
}
a.logAuthorizeCheck(ctx, &evalreq, res, s, u)
return res, nil
}

View file

@ -72,7 +72,7 @@ func newAuthorizeStateFromConfig(
evaluatorOptions = append(evaluatorOptions, evaluator.WithMCPAccessTokenProvider(mcp))
}
state.ssh = ssh.NewStreamManager(nil) // XXX
state.ssh = ssh.NewStreamManager()
state.evaluator, err = newPolicyEvaluator(ctx, cfg.Options, store, previousEvaluator, evaluatorOptions...)
if err != nil {

View file

@ -832,6 +832,21 @@ func (o *Options) GetInternalAuthenticateURL() (*url.URL, error) {
return urlutil.ParseAndValidateURL(o.AuthenticateInternalURLString)
}
func (o *Options) GetAuthenticateRedirectURL() (*url.URL, error) {
authenticateURL, err := o.GetAuthenticateURL()
if err != nil {
return nil, err
}
redirectURL, err := urlutil.DeepCopy(authenticateURL)
if err != nil {
return nil, err
}
redirectURL.Path = o.AuthenticateCallbackPath
return redirectURL, nil
}
// UseStatelessAuthenticateFlow returns true if the stateless authentication
// flow should be used (i.e. for hosted authenticate).
func (o *Options) UseStatelessAuthenticateFlow() bool {
@ -1054,6 +1069,19 @@ func (o *Options) NumPolicies() int {
return len(o.Policies) + len(o.Routes) + len(o.AdditionalPolicies)
}
func (o *Options) GetRouteForSSHHostname(hostname string) *Policy {
if hostname == "" {
return nil
}
from := "ssh://" + hostname
for r := range o.GetAllPolicies() {
if r.From == from {
return r
}
}
return nil
}
// GetMetricsBasicAuth gets the metrics basic auth username and password.
func (o *Options) GetMetricsBasicAuth() (username, password string, ok bool) {
if o.MetricsBasicAuth == "" {

View file

@ -5,18 +5,14 @@ package authenticateflow
import (
"context"
"fmt"
"time"
oteltrace "go.opentelemetry.io/otel/trace"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/stats"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/types/known/structpb"
"github.com/pomerium/pomerium/pkg/grpc"
"github.com/pomerium/pomerium/pkg/grpc/user"
"github.com/pomerium/pomerium/pkg/identity"
"github.com/pomerium/pomerium/pkg/telemetry/trace"
)
@ -25,21 +21,6 @@ var timeNow = time.Now
var outboundGRPCConnection = new(grpc.CachedOutboundGRPClientConn)
func populateUserFromClaims(u *user.User, claims map[string]any) {
if v, ok := claims["name"]; ok {
u.Name = fmt.Sprint(v)
}
if v, ok := claims["email"]; ok {
u.Email = fmt.Sprint(v)
}
if u.Claims == nil {
u.Claims = make(map[string]*structpb.ListValue)
}
for k, vs := range identity.Claims(claims).Flatten().ToPB() {
u.Claims[k] = vs
}
}
var outboundDatabrokerTraceClientOpts = []trace.ClientStatsHandlerOption{
trace.WithStatsInterceptor(ignoreNotFoundErrors),
}

View file

@ -208,7 +208,7 @@ func (s *Stateful) PersistSession(
Id: sess.GetUserId(),
}
}
populateUserFromClaims(u, claims.Claims)
u.PopulateFromClaims(claims.Claims)
_, err := databroker.Put(ctx, s.dataBrokerClient, u)
if err != nil {
return fmt.Errorf("authenticate: error saving user: %w", err)

View file

@ -422,7 +422,7 @@ func (s *Stateless) Callback(w http.ResponseWriter, r *http.Request) error {
if err != nil {
u = &user.User{Id: ss.UserID()}
}
populateUserFromClaims(u, profile.GetClaims().AsMap())
u.PopulateFromClaims(profile.Claims.AsMap())
redirectURI, err := getRedirectURIFromValues(values)
if err != nil {

View file

@ -72,6 +72,7 @@ func New(cfg Config) *IDP {
publicJWK: publicJWK,
signingKey: signingKey,
userLookup: userLookup,
enableDeviceAuth: cfg.EnableDeviceAuth,
}
}

View file

@ -16,7 +16,7 @@ import (
)
var (
envoyVersion = "1.34.1-rc1"
envoyVersion = "1.34.1-rc3"
targets = []string{
"darwin-amd64",
"darwin-arm64",

View file

@ -48,6 +48,17 @@ func (x *ServiceAccount) Validate() error {
return nil
}
// PopulateFromClaims sets the Name, Email, and Claims fields from a claims map.
func (x *User) PopulateFromClaims(claims map[string]any) {
if v, ok := claims["name"]; ok {
x.Name = fmt.Sprint(v)
}
if v, ok := claims["email"]; ok {
x.Email = fmt.Sprint(v)
}
x.AddClaims(identity.Claims(claims).Flatten())
}
// AddClaims adds the flattened claims to the user.
func (x *User) AddClaims(claims identity.FlattenedClaims) {
if x.Claims == nil {

View file

@ -6,11 +6,13 @@ import (
"context"
"fmt"
"net/http"
"net/url"
"go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp"
oteltrace "go.opentelemetry.io/otel/trace"
"golang.org/x/oauth2"
identitypb "github.com/pomerium/pomerium/pkg/grpc/identity"
"github.com/pomerium/pomerium/pkg/identity/identity"
"github.com/pomerium/pomerium/pkg/identity/oauth"
"github.com/pomerium/pomerium/pkg/identity/oauth/apple"
@ -92,3 +94,24 @@ func NewAuthenticator(ctx context.Context, tracerProvider oteltrace.TracerProvid
return ctor(ctx, &o)
}
func GetIdentityProvider(
ctx context.Context,
tracerProvider oteltrace.TracerProvider,
idp *identitypb.Provider,
redirectURL *url.URL,
) (Authenticator, error) {
o := oauth.Options{
RedirectURL: redirectURL,
ProviderName: idp.GetType(),
ProviderURL: idp.GetUrl(),
ClientID: idp.GetClientId(),
ClientSecret: idp.GetClientSecret(),
Scopes: idp.GetScopes(),
AuthCodeOptions: idp.GetRequestParams(),
}
if v := idp.GetAccessTokenAllowedAudiences(); v != nil {
o.AccessTokenAllowedAudiences = &v.Values
}
return NewAuthenticator(ctx, tracerProvider, o)
}

389
pkg/ssh/auth.go Normal file
View file

@ -0,0 +1,389 @@
package ssh
import (
"bytes"
"context"
"crypto/sha256"
"encoding/base64"
"errors"
"text/template"
"time"
oteltrace "go.opentelemetry.io/otel/trace"
"golang.org/x/oauth2"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/types/known/timestamppb"
extensions_ssh "github.com/pomerium/envoy-custom/api/extensions/filters/network/ssh"
"github.com/pomerium/pomerium/authorize/evaluator"
"github.com/pomerium/pomerium/config"
"github.com/pomerium/pomerium/internal/atomicutil"
"github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/sessions"
"github.com/pomerium/pomerium/pkg/grpc/databroker"
"github.com/pomerium/pomerium/pkg/grpc/session"
"github.com/pomerium/pomerium/pkg/grpc/user"
"github.com/pomerium/pomerium/pkg/identity"
"github.com/pomerium/pomerium/pkg/identity/manager"
"github.com/pomerium/pomerium/pkg/policy/criteria"
"github.com/pomerium/pomerium/pkg/storage"
)
type PolicyEvaluator interface {
EvaluateSSH(context.Context, *Request) (*evaluator.Result, error)
}
type Request struct {
Username string
Hostname string
PublicKey []byte
SessionID string
}
type Auth struct {
evaluator PolicyEvaluator
dataBrokerClient databroker.DataBrokerServiceClient
currentConfig *atomicutil.Value[*config.Config]
tracerProvider oteltrace.TracerProvider
}
func NewAuth(
evaluator PolicyEvaluator,
client databroker.DataBrokerServiceClient,
currentConfig *atomicutil.Value[*config.Config],
tracerProvider oteltrace.TracerProvider,
) *Auth {
return &Auth{evaluator, client, currentConfig, tracerProvider}
}
func (a *Auth) HandlePublicKeyMethodRequest(
ctx context.Context,
info StreamAuthInfo,
req *extensions_ssh.PublicKeyMethodRequest,
) (PublicKeyAuthMethodResponse, error) {
resp, err := a.handlePublicKeyMethodRequest(ctx, info, req)
if err != nil {
log.Ctx(ctx).Error().Err(err).Msg("ssh publickey auth request error")
return resp, status.Error(codes.Internal, "internal error")
}
return resp, err
}
func (a *Auth) handlePublicKeyMethodRequest(
ctx context.Context,
info StreamAuthInfo,
req *extensions_ssh.PublicKeyMethodRequest,
) (PublicKeyAuthMethodResponse, error) {
sessionID, err := sessionIDFromFingerprint(req.PublicKeyFingerprintSha256)
if err != nil {
return PublicKeyAuthMethodResponse{}, err
}
sshreq := &Request{
Username: *info.Username,
Hostname: *info.Hostname,
PublicKey: req.PublicKey,
SessionID: sessionID,
}
log.Ctx(ctx).Debug().
Str("username", *info.Username).
Str("hostname", *info.Hostname).
Str("session-id", sessionID).
Msg("ssh publickey auth request")
// Special case: internal command (e.g. routes portal).
if *info.Hostname == "" {
_, err := session.Get(ctx, a.dataBrokerClient, sessionID)
if status.Code(err) == codes.NotFound {
// Require IdP login.
return PublicKeyAuthMethodResponse{
Allow: publicKeyAllowResponse(req.PublicKey),
RequireAdditionalMethods: []string{MethodKeyboardInteractive},
}, nil
} else if err != nil {
return PublicKeyAuthMethodResponse{}, err
}
}
res, err := a.evaluator.EvaluateSSH(ctx, sshreq)
if err != nil {
return PublicKeyAuthMethodResponse{}, err
}
// Interpret the results of policy evaluation.
if res.HasReason(criteria.ReasonSSHPublickeyUnauthorized) {
// This public key is not allowed, but the client is free to try a different key.
return PublicKeyAuthMethodResponse{
RequireAdditionalMethods: []string{MethodPublicKey},
}, nil
} else if res.HasReason(criteria.ReasonUserUnauthenticated) {
// Mark public key as allowed, to initiate IdP login flow.
return PublicKeyAuthMethodResponse{
Allow: publicKeyAllowResponse(req.PublicKey),
RequireAdditionalMethods: []string{MethodKeyboardInteractive},
}, nil
} else if res.Allow.Value && !res.Deny.Value {
// Allowed, no login needed.
return PublicKeyAuthMethodResponse{
Allow: publicKeyAllowResponse(req.PublicKey),
}, nil
}
// Denied, no login needed.
return PublicKeyAuthMethodResponse{}, nil
}
func publicKeyAllowResponse(publicKey []byte) *extensions_ssh.PublicKeyAllowResponse {
return &extensions_ssh.PublicKeyAllowResponse{
PublicKey: publicKey,
Permissions: &extensions_ssh.Permissions{
PermitPortForwarding: true,
PermitAgentForwarding: true,
PermitX11Forwarding: true,
PermitPty: true,
PermitUserRc: true,
ValidStartTime: timestamppb.New(time.Now().Add(-1 * time.Minute)),
ValidEndTime: timestamppb.New(time.Now().Add(1 * time.Hour)),
},
}
}
func (a *Auth) HandleKeyboardInteractiveMethodRequest(
ctx context.Context,
info StreamAuthInfo,
_ *extensions_ssh.KeyboardInteractiveMethodRequest,
querier KeyboardInteractiveQuerier,
) (KeyboardInteractiveAuthMethodResponse, error) {
resp, err := a.handleKeyboardInteractiveMethodRequest(ctx, info, querier)
if err != nil {
log.Ctx(ctx).Error().Err(err).Msg("ssh keyboard-interactive auth request error")
return resp, status.Error(codes.Internal, "internal error")
}
return resp, err
}
func (a *Auth) handleKeyboardInteractiveMethodRequest(
ctx context.Context,
info StreamAuthInfo,
querier KeyboardInteractiveQuerier,
) (KeyboardInteractiveAuthMethodResponse, error) {
if info.PublicKeyAllow.Value == nil {
// Sanity check: this method is only valid if we already accepted a public key.
return KeyboardInteractiveAuthMethodResponse{}, errPublicKeyAllowNil
}
log.Ctx(ctx).Debug().
Str("username", *info.Username).
Str("hostname", *info.Hostname).
Str("publickey-fingerprint", base64.StdEncoding.EncodeToString(info.PublicKeyFingerprintSha256)).
Msg("ssh keyboard-interactive auth request")
// Initiate the IdP login flow.
err := a.handleLogin(ctx, *info.Hostname, info.PublicKeyFingerprintSha256, querier)
if err != nil {
return KeyboardInteractiveAuthMethodResponse{}, err
}
if err := a.EvaluateDelayed(ctx, info); err != nil {
// Denied.
return KeyboardInteractiveAuthMethodResponse{}, nil
}
// Allowed.
return KeyboardInteractiveAuthMethodResponse{
Allow: &extensions_ssh.KeyboardInteractiveAllowResponse{},
}, nil
}
func (a *Auth) handleLogin(
ctx context.Context,
hostname string,
publicKeyFingerprint []byte,
querier KeyboardInteractiveQuerier,
) error {
// Initiate the IdP login flow.
authenticator, err := a.getAuthenticator(ctx, hostname)
if err != nil {
return err
}
resp, err := authenticator.DeviceAuth(ctx)
if err != nil {
return err
}
// Prompt the user to sign in.
_, _ = querier.Prompt(ctx, &extensions_ssh.KeyboardInteractiveInfoPrompts{
Name: "Please sign in with " + authenticator.Name() + " to continue",
Instruction: resp.VerificationURIComplete,
Prompts: nil,
})
var sessionClaims identity.SessionClaims
token, err := authenticator.DeviceAccessToken(ctx, resp, &sessionClaims)
if err != nil {
return err
}
sessionID, err := sessionIDFromFingerprint(publicKeyFingerprint)
if err != nil {
return err
}
return a.saveSession(ctx, sessionID, &sessionClaims, token)
}
var errAccessDenied = errors.New("access denied")
func (a *Auth) EvaluateDelayed(ctx context.Context, info StreamAuthInfo) error {
req, err := sshRequestFromStreamAuthInfo(info)
if err != nil {
return err
}
res, err := a.evaluator.EvaluateSSH(ctx, req)
if err != nil {
return err
}
if res.Allow.Value && !res.Deny.Value {
return nil
}
return errAccessDenied
}
func (a *Auth) FormatSession(ctx context.Context, info StreamAuthInfo) ([]byte, error) {
sessionID, err := sessionIDFromFingerprint(info.PublicKeyFingerprintSha256)
if err != nil {
return nil, err
}
session, err := session.Get(ctx, a.dataBrokerClient, sessionID)
if err != nil {
return nil, err
}
var b bytes.Buffer
err = sessionInfoTmpl.Execute(&b, session)
if err != nil {
return nil, err
}
return b.Bytes(), nil
}
func (a *Auth) DeleteSession(ctx context.Context, info StreamAuthInfo) error {
sessionID, err := sessionIDFromFingerprint(info.PublicKeyFingerprintSha256)
if err != nil {
return err
}
err = session.Delete(ctx, a.dataBrokerClient, sessionID)
a.invalidateCacheForRecord(ctx, &databroker.Record{
Type: "type.googleapis.com/session.Session",
Id: sessionID,
})
return err
}
func (a *Auth) saveSession(
ctx context.Context,
id string,
claims *identity.SessionClaims,
token *oauth2.Token,
) error {
now := time.Now()
nowpb := timestamppb.New(now)
sessionLifetime := a.currentConfig.Load().Options.CookieExpire
state := sessions.State{ID: id}
if err := claims.Claims.Claims(&state); err != nil {
return err
}
sess := &session.Session{
Id: id,
UserId: state.UserID(),
IssuedAt: nowpb,
AccessedAt: nowpb,
ExpiresAt: timestamppb.New(now.Add(sessionLifetime)),
OauthToken: manager.ToOAuthToken(token),
Audience: state.Audience,
}
sess.SetRawIDToken(claims.RawIDToken)
sess.AddClaims(claims.Flatten())
u, _ := user.Get(ctx, a.dataBrokerClient, sess.GetUserId())
if u == nil {
// if no user exists yet, create a new one
u = &user.User{
Id: sess.GetUserId(),
}
}
u.PopulateFromClaims(claims.Claims)
_, err := databroker.Put(ctx, a.dataBrokerClient, u)
if err != nil {
return err
}
resp, err := session.Put(ctx, a.dataBrokerClient, sess)
if err != nil {
return err
}
a.invalidateCacheForRecord(ctx, resp.GetRecord())
return nil
}
func (a *Auth) invalidateCacheForRecord(ctx context.Context, record *databroker.Record) {
ctx = storage.WithQuerier(ctx,
storage.NewCachingQuerier(storage.NewQuerier(a.dataBrokerClient), storage.GlobalCache))
storage.InvalidateCacheForDataBrokerRecords(ctx, record)
}
func (a *Auth) getAuthenticator(ctx context.Context, hostname string) (identity.Authenticator, error) {
opts := a.currentConfig.Load().Options
redirectURL, err := opts.GetAuthenticateRedirectURL()
if err != nil {
return nil, err
}
idp, err := opts.GetIdentityProviderForPolicy(opts.GetRouteForSSHHostname(hostname))
if err != nil {
return nil, err
}
return identity.GetIdentityProvider(ctx, a.tracerProvider, idp, redirectURL)
}
var _ AuthInterface = (*Auth)(nil)
var errInvalidFingerprint = errors.New("invalid public key fingerprint")
func sessionIDFromFingerprint(sha256fingerprint []byte) (string, error) {
if len(sha256fingerprint) != sha256.Size {
return "", errInvalidFingerprint
}
return "sshkey-SHA256:" + base64.StdEncoding.EncodeToString(sha256fingerprint), nil
}
var errPublicKeyAllowNil = errors.New("expected PublicKeyAllow message not to be nil")
// Converts from StreamAuthInfo to an SSHRequest, assuming the PublicKeyAllow field is not nil.
func sshRequestFromStreamAuthInfo(info StreamAuthInfo) (*Request, error) {
if info.PublicKeyAllow.Value == nil {
return nil, errPublicKeyAllowNil
}
sessionID, err := sessionIDFromFingerprint(info.PublicKeyFingerprintSha256)
if err != nil {
return nil, err
}
return &Request{
Username: *info.Username,
Hostname: *info.Hostname,
PublicKey: info.PublicKeyAllow.Value.PublicKey,
SessionID: sessionID,
}, nil
}
var sessionInfoTmpl = template.Must(template.New("session-info").Parse(`
User ID: {{.UserId}}
Session ID: {{.Id}}
Expires at: {{.ExpiresAt.AsTime}}
Claims:
{{- range $k, $v := .Claims }}
{{ $k }}: {{ $v.AsSlice }}
{{- end }}
`))

469
pkg/ssh/auth_test.go Normal file
View file

@ -0,0 +1,469 @@
package ssh
import (
"context"
"errors"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/types/known/timestamppb"
extensions_ssh "github.com/pomerium/envoy-custom/api/extensions/filters/network/ssh"
"github.com/pomerium/pomerium/authorize/evaluator"
"github.com/pomerium/pomerium/config"
"github.com/pomerium/pomerium/internal/atomicutil"
"github.com/pomerium/pomerium/internal/testutil/mockidp"
"github.com/pomerium/pomerium/pkg/grpc/databroker"
"github.com/pomerium/pomerium/pkg/grpc/session"
"github.com/pomerium/pomerium/pkg/identity"
"github.com/pomerium/pomerium/pkg/policy/criteria"
"github.com/pomerium/pomerium/pkg/protoutil"
)
func TestHandlePublicKeyMethodRequest(t *testing.T) {
t.Run("no public key fingerprint", func(t *testing.T) {
var a Auth
var req extensions_ssh.PublicKeyMethodRequest
_, err := a.handlePublicKeyMethodRequest(t.Context(), StreamAuthInfo{}, &req)
assert.ErrorContains(t, err, "invalid public key fingerprint")
})
t.Run("evaluate error", func(t *testing.T) {
info := StreamAuthInfo{
Username: ptr("username"),
Hostname: ptr("hostname"),
}
var req extensions_ssh.PublicKeyMethodRequest
req.PublicKeyFingerprintSha256 = []byte("ABCDEFGHIJKLMNOPQRSTUVWXYZ123456")
pe := policyEvaluatorFunc(func(context.Context, *Request) (*evaluator.Result, error) {
return nil, errors.New("error evaluating policy")
})
a := NewAuth(pe, nil, nil, nil)
_, err := a.handlePublicKeyMethodRequest(t.Context(), info, &req)
assert.ErrorContains(t, err, "error evaluating policy")
})
t.Run("allow", func(t *testing.T) {
info := StreamAuthInfo{
Username: ptr("username"),
Hostname: ptr("hostname"),
}
var req extensions_ssh.PublicKeyMethodRequest
req.PublicKeyFingerprintSha256 = []byte("ABCDEFGHIJKLMNOPQRSTUVWXYZ123456")
fakePublicKey := []byte("fake-public-key")
req.PublicKey = fakePublicKey
pe := policyEvaluatorFunc(func(_ context.Context, r *Request) (*evaluator.Result, error) {
assert.Equal(t, r, &Request{
Username: "username",
Hostname: "hostname",
PublicKey: fakePublicKey,
SessionID: "sshkey-SHA256:QUJDREVGR0hJSktMTU5PUFFSU1RVVldYWVoxMjM0NTY=",
})
return &evaluator.Result{
Allow: evaluator.NewRuleResult(true),
Deny: evaluator.NewRuleResult(false),
}, nil
})
a := NewAuth(pe, nil, nil, nil)
res, err := a.HandlePublicKeyMethodRequest(t.Context(), info, &req)
assert.NoError(t, err)
assert.Empty(t, res.RequireAdditionalMethods)
require.NotNil(t, res.Allow)
assert.Equal(t, res.Allow.PublicKey, fakePublicKey)
})
t.Run("deny", func(t *testing.T) {
info := StreamAuthInfo{
Username: ptr("username"),
Hostname: ptr("hostname"),
}
var req extensions_ssh.PublicKeyMethodRequest
req.PublicKeyFingerprintSha256 = []byte("ABCDEFGHIJKLMNOPQRSTUVWXYZ123456")
pe := policyEvaluatorFunc(func(_ context.Context, _ *Request) (*evaluator.Result, error) {
return &evaluator.Result{
Allow: evaluator.NewRuleResult(true),
Deny: evaluator.NewRuleResult(true),
}, nil
})
a := NewAuth(pe, nil, nil, nil)
res, err := a.HandlePublicKeyMethodRequest(t.Context(), info, &req)
assert.NoError(t, err)
assert.Nil(t, res.Allow)
assert.Empty(t, res.RequireAdditionalMethods)
})
t.Run("public key unauthorized", func(t *testing.T) {
info := StreamAuthInfo{
Username: ptr("username"),
Hostname: ptr("hostname"),
}
var req extensions_ssh.PublicKeyMethodRequest
req.PublicKeyFingerprintSha256 = []byte("ABCDEFGHIJKLMNOPQRSTUVWXYZ123456")
pe := policyEvaluatorFunc(func(_ context.Context, _ *Request) (*evaluator.Result, error) {
return &evaluator.Result{
Allow: evaluator.NewRuleResult(false, criteria.ReasonSSHPublickeyUnauthorized),
Deny: evaluator.NewRuleResult(false),
}, nil
})
a := NewAuth(pe, nil, nil, nil)
res, err := a.HandlePublicKeyMethodRequest(t.Context(), info, &req)
assert.NoError(t, err)
assert.Nil(t, res.Allow)
assert.Equal(t, res.RequireAdditionalMethods, []string{MethodPublicKey})
})
t.Run("needs login", func(t *testing.T) {
info := StreamAuthInfo{
Username: ptr("username"),
Hostname: ptr("hostname"),
}
var req extensions_ssh.PublicKeyMethodRequest
req.PublicKeyFingerprintSha256 = []byte("ABCDEFGHIJKLMNOPQRSTUVWXYZ123456")
pe := policyEvaluatorFunc(func(_ context.Context, _ *Request) (*evaluator.Result, error) {
return &evaluator.Result{
Allow: evaluator.NewRuleResult(false),
Deny: evaluator.NewRuleResult(false, criteria.ReasonUserUnauthenticated),
}, nil
})
a := NewAuth(pe, nil, nil, nil)
res, err := a.HandlePublicKeyMethodRequest(t.Context(), info, &req)
assert.NoError(t, err)
assert.NotNil(t, res.Allow)
assert.Equal(t, res.RequireAdditionalMethods, []string{MethodKeyboardInteractive})
})
t.Run("internal command no session", func(t *testing.T) {
client := fakeDataBrokerServiceClient{
get: func(
_ context.Context, _ *databroker.GetRequest, _ ...grpc.CallOption,
) (*databroker.GetResponse, error) {
return nil, status.Error(codes.NotFound, "not found")
},
}
info := StreamAuthInfo{
Username: ptr("username"),
Hostname: ptr(""),
}
var req extensions_ssh.PublicKeyMethodRequest
req.PublicKeyFingerprintSha256 = []byte("ABCDEFGHIJKLMNOPQRSTUVWXYZ123456")
pe := policyEvaluatorFunc(func(_ context.Context, _ *Request) (*evaluator.Result, error) {
return &evaluator.Result{
Allow: evaluator.NewRuleResult(false),
Deny: evaluator.NewRuleResult(false, criteria.ReasonUserUnauthenticated),
}, nil
})
a := NewAuth(pe, client, nil, nil)
res, err := a.HandlePublicKeyMethodRequest(t.Context(), info, &req)
assert.NoError(t, err)
assert.NotNil(t, res.Allow)
assert.Equal(t, res.RequireAdditionalMethods, []string{MethodKeyboardInteractive})
})
t.Run("internal command with session", func(t *testing.T) {
client := fakeDataBrokerServiceClient{
get: func(
_ context.Context, _ *databroker.GetRequest, _ ...grpc.CallOption,
) (*databroker.GetResponse, error) {
return &databroker.GetResponse{
Record: &databroker.Record{
Type: "type.googleapis.com/session.Session",
Id: "abc",
Data: protoutil.NewAny(&session.Session{
Id: "abc",
UserId: "USER-ID",
}),
},
}, nil
},
}
info := StreamAuthInfo{
Username: ptr("username"),
Hostname: ptr(""),
}
var req extensions_ssh.PublicKeyMethodRequest
req.PublicKeyFingerprintSha256 = []byte("ABCDEFGHIJKLMNOPQRSTUVWXYZ123456")
pe := policyEvaluatorFunc(func(_ context.Context, _ *Request) (*evaluator.Result, error) {
return &evaluator.Result{
Allow: evaluator.NewRuleResult(true),
Deny: evaluator.NewRuleResult(false),
}, nil
})
a := NewAuth(pe, client, nil, nil)
res, err := a.HandlePublicKeyMethodRequest(t.Context(), info, &req)
assert.NoError(t, err)
assert.NotNil(t, res.Allow)
assert.Empty(t, res.RequireAdditionalMethods)
})
t.Run("internal command databroker error", func(t *testing.T) {
client := fakeDataBrokerServiceClient{
get: func(
_ context.Context, _ *databroker.GetRequest, _ ...grpc.CallOption,
) (*databroker.GetResponse, error) {
return nil, status.Error(codes.Unknown, "unknown")
},
}
info := StreamAuthInfo{
Username: ptr("username"),
Hostname: ptr(""),
}
var req extensions_ssh.PublicKeyMethodRequest
req.PublicKeyFingerprintSha256 = []byte("ABCDEFGHIJKLMNOPQRSTUVWXYZ123456")
pe := policyEvaluatorFunc(func(_ context.Context, _ *Request) (*evaluator.Result, error) {
return &evaluator.Result{
Allow: evaluator.NewRuleResult(true),
Deny: evaluator.NewRuleResult(false),
}, nil
})
a := NewAuth(pe, client, nil, nil)
_, err := a.HandlePublicKeyMethodRequest(t.Context(), info, &req)
assert.ErrorContains(t, err, "internal error")
})
}
func TestHandleKeyboardInteractiveMethodRequest(t *testing.T) {
t.Run("no public key", func(t *testing.T) {
var a Auth
_, err := a.handleKeyboardInteractiveMethodRequest(t.Context(), StreamAuthInfo{}, nil)
assert.ErrorContains(t, err, "expected PublicKeyAllow message not to be nil")
})
t.Run("ok", func(t *testing.T) {
pe := policyEvaluatorFunc(func(_ context.Context, _ *Request) (*evaluator.Result, error) {
return &evaluator.Result{
Allow: evaluator.NewRuleResult(true),
Deny: evaluator.NewRuleResult(false),
}, nil
})
var putRecords []*databroker.Record
client := fakeDataBrokerServiceClient{
get: func(
_ context.Context, _ *databroker.GetRequest, _ ...grpc.CallOption,
) (*databroker.GetResponse, error) {
return nil, status.Error(codes.NotFound, "not found")
},
put: func(
_ context.Context, in *databroker.PutRequest, _ ...grpc.CallOption,
) (*databroker.PutResponse, error) {
putRecords = append(putRecords, in.Records...)
return &databroker.PutResponse{
Records: in.Records,
}, nil
},
}
cfg := config.Config{
Options: config.NewDefaultOptions(),
}
mockIDP := mockidp.New(mockidp.Config{EnableDeviceAuth: true})
idpURL := mockIDP.Start(t)
cfg.Options.Provider = "oidc"
cfg.Options.ProviderURL = idpURL
cfg.Options.ClientID = "client-id"
cfg.Options.ClientSecret = "client-secret"
a := NewAuth(pe, client, atomicutil.NewValue(&cfg), nil)
info := StreamAuthInfo{
Username: ptr("username"),
Hostname: ptr("hostname"),
PublicKeyAllow: AuthMethodValue[extensions_ssh.PublicKeyAllowResponse]{
Value: &extensions_ssh.PublicKeyAllowResponse{
PublicKey: []byte("fake-public-key"),
},
},
PublicKeyFingerprintSha256: []byte("ABCDEFGHIJKLMNOPQRSTUVWXYZ123456"),
}
res, err := a.HandleKeyboardInteractiveMethodRequest(t.Context(), info, nil, noopQuerier{})
require.NoError(t, err)
assert.NotNil(t, res.Allow)
assert.Empty(t, res.RequireAdditionalMethods)
// A new Session and User record should have been saved to the databroker.
assert.Len(t, putRecords, 2)
assert.Equal(t, "type.googleapis.com/user.User", putRecords[0].Type)
assert.Equal(t, "fake.user@example.com", putRecords[0].Id)
assert.Equal(t, "type.googleapis.com/session.Session", putRecords[1].Type)
assert.Equal(t, "sshkey-SHA256:QUJDREVGR0hJSktMTU5PUFFSU1RVVldYWVoxMjM0NTY=", putRecords[1].Id)
})
t.Run("denied", func(t *testing.T) {
pe := policyEvaluatorFunc(func(_ context.Context, _ *Request) (*evaluator.Result, error) {
return &evaluator.Result{
Allow: evaluator.NewRuleResult(false),
Deny: evaluator.NewRuleResult(false),
}, nil
})
client := fakeDataBrokerServiceClient{
get: func(
_ context.Context, _ *databroker.GetRequest, _ ...grpc.CallOption,
) (*databroker.GetResponse, error) {
return nil, status.Error(codes.NotFound, "not found")
},
put: func(
_ context.Context, in *databroker.PutRequest, _ ...grpc.CallOption,
) (*databroker.PutResponse, error) {
return &databroker.PutResponse{
Records: in.Records,
}, nil
},
}
cfg := config.Config{
Options: config.NewDefaultOptions(),
}
mockIDP := mockidp.New(mockidp.Config{EnableDeviceAuth: true})
idpURL := mockIDP.Start(t)
cfg.Options.Provider = "oidc"
cfg.Options.ProviderURL = idpURL
cfg.Options.ClientID = "client-id"
cfg.Options.ClientSecret = "client-secret"
a := NewAuth(pe, client, atomicutil.NewValue(&cfg), nil)
info := StreamAuthInfo{
Username: ptr("username"),
Hostname: ptr("hostname"),
PublicKeyAllow: AuthMethodValue[extensions_ssh.PublicKeyAllowResponse]{
Value: &extensions_ssh.PublicKeyAllowResponse{
PublicKey: []byte("fake-public-key"),
},
},
PublicKeyFingerprintSha256: []byte("ABCDEFGHIJKLMNOPQRSTUVWXYZ123456"),
}
res, err := a.HandleKeyboardInteractiveMethodRequest(t.Context(), info, nil, noopQuerier{})
require.NoError(t, err)
assert.Nil(t, res.Allow)
assert.Empty(t, res.RequireAdditionalMethods)
})
t.Run("invalid fingerprint", func(t *testing.T) {
cfg := config.Config{
Options: config.NewDefaultOptions(),
}
mockIDP := mockidp.New(mockidp.Config{EnableDeviceAuth: true})
idpURL := mockIDP.Start(t)
cfg.Options.Provider = "oidc"
cfg.Options.ProviderURL = idpURL
cfg.Options.ClientID = "client-id"
cfg.Options.ClientSecret = "client-secret"
a := NewAuth(nil, nil, atomicutil.NewValue(&cfg), nil)
info := StreamAuthInfo{
Username: ptr("username"),
Hostname: ptr("hostname"),
PublicKeyAllow: AuthMethodValue[extensions_ssh.PublicKeyAllowResponse]{
Value: &extensions_ssh.PublicKeyAllowResponse{
PublicKey: []byte("fake-public-key"),
},
},
}
_, err := a.handleKeyboardInteractiveMethodRequest(t.Context(), info, noopQuerier{})
assert.ErrorContains(t, err, "invalid public key fingerprint")
})
}
func TestFormatSession(t *testing.T) {
t.Run("invalid fingerprint", func(t *testing.T) {
var a Auth
info := StreamAuthInfo{
PublicKeyFingerprintSha256: []byte("wrong-length"),
}
_, err := a.FormatSession(t.Context(), info)
assert.ErrorContains(t, err, "invalid public key fingerprint")
})
t.Run("ok", func(t *testing.T) {
client := fakeDataBrokerServiceClient{
get: func(
_ context.Context, in *databroker.GetRequest, _ ...grpc.CallOption,
) (*databroker.GetResponse, error) {
const expectedID = "sshkey-SHA256:QUJDREVGR0hJSktMTU5PUFFSU1RVVldYWVoxMjM0NTY="
assert.Equal(t, in.Type, "type.googleapis.com/session.Session")
assert.Equal(t, in.Id, expectedID)
claims := identity.FlattenedClaims{
"foo": []any{"bar", "baz"},
"quux": []any{42},
}
return &databroker.GetResponse{
Record: &databroker.Record{
Type: "type.googleapis.com/session.Session",
Id: expectedID,
Data: protoutil.NewAny(&session.Session{
Id: expectedID,
UserId: "USER-ID",
ExpiresAt: &timestamppb.Timestamp{Seconds: 1750965358},
Claims: claims.ToPB(),
}),
},
}, nil
},
}
a := NewAuth(nil, client, nil, nil)
info := StreamAuthInfo{
PublicKeyFingerprintSha256: []byte("ABCDEFGHIJKLMNOPQRSTUVWXYZ123456"),
}
b, err := a.FormatSession(t.Context(), info)
assert.NoError(t, err)
assert.Equal(t, string(b), `
User ID: USER-ID
Session ID: sshkey-SHA256:QUJDREVGR0hJSktMTU5PUFFSU1RVVldYWVoxMjM0NTY=
Expires at: 2025-06-26 19:15:58 +0000 UTC
Claims:
foo: [bar baz]
quux: [42]
`)
})
}
func TestDeleteSession(t *testing.T) {
t.Run("invalid fingerprint", func(t *testing.T) {
var a Auth
info := StreamAuthInfo{
PublicKeyFingerprintSha256: []byte("wrong-length"),
}
err := a.DeleteSession(t.Context(), info)
assert.ErrorContains(t, err, "invalid public key fingerprint")
})
t.Run("ok", func(t *testing.T) {
putError := errors.New("sentinel")
client := fakeDataBrokerServiceClient{
put: func(
_ context.Context, in *databroker.PutRequest, _ ...grpc.CallOption,
) (*databroker.PutResponse, error) {
require.Len(t, in.Records, 1)
assert.Equal(t, in.Records[0].Id, "sshkey-SHA256:QUJDREVGR0hJSktMTU5PUFFSU1RVVldYWVoxMjM0NTY=")
assert.NotNil(t, in.Records[0].DeletedAt)
return nil, putError
},
}
a := NewAuth(nil, client, nil, nil)
info := StreamAuthInfo{
PublicKeyFingerprintSha256: []byte("ABCDEFGHIJKLMNOPQRSTUVWXYZ123456"),
}
err := a.DeleteSession(t.Context(), info)
assert.Equal(t, putError, err)
})
}
type policyEvaluatorFunc func(context.Context, *Request) (*evaluator.Result, error)
func (f policyEvaluatorFunc) EvaluateSSH(
ctx context.Context, req *Request,
) (*evaluator.Result, error) {
return f(ctx, req)
}
type fakeDataBrokerServiceClient struct {
databroker.DataBrokerServiceClient
get func(ctx context.Context, in *databroker.GetRequest, opts ...grpc.CallOption) (*databroker.GetResponse, error)
put func(ctx context.Context, in *databroker.PutRequest, opts ...grpc.CallOption) (*databroker.PutResponse, error)
}
func (m fakeDataBrokerServiceClient) Get(ctx context.Context, in *databroker.GetRequest, opts ...grpc.CallOption) (*databroker.GetResponse, error) {
return m.get(ctx, in, opts...)
}
func (m fakeDataBrokerServiceClient) Put(ctx context.Context, in *databroker.PutRequest, opts ...grpc.CallOption) (*databroker.PutResponse, error) {
return m.put(ctx, in, opts...)
}
type noopQuerier struct{}
func (noopQuerier) Prompt(
_ context.Context, _ *extensions_ssh.KeyboardInteractiveInfoPrompts,
) (*extensions_ssh.KeyboardInteractiveInfoPromptResponses, error) {
return nil, nil
}
func ptr[T any](t T) *T {
return &t
}

View file

@ -8,14 +8,12 @@ import (
)
type StreamManager struct {
auth AuthInterface
mu sync.Mutex
activeStreams map[uint64]*StreamHandler
}
func NewStreamManager(auth AuthInterface) *StreamManager {
func NewStreamManager() *StreamManager {
return &StreamManager{
auth: auth,
activeStreams: map[uint64]*StreamHandler{},
}
}
@ -30,13 +28,17 @@ func (sm *StreamManager) LookupStream(streamID uint64) *StreamHandler {
return stream
}
func (sm *StreamManager) NewStreamHandler(cfg *config.Config, downstream *extensions_ssh.DownstreamConnectEvent) *StreamHandler {
func (sm *StreamManager) NewStreamHandler(
cfg *config.Config,
auth AuthInterface,
downstream *extensions_ssh.DownstreamConnectEvent,
) *StreamHandler {
sm.mu.Lock()
defer sm.mu.Unlock()
streamID := downstream.StreamId
writeC := make(chan *extensions_ssh.ServerMessage, 32)
sh := &StreamHandler{
auth: sm.auth,
auth: auth,
config: cfg,
downstream: downstream,
readC: make(chan *extensions_ssh.ClientMessage, 32),

View file

@ -22,7 +22,7 @@ func mustParseWeightedURLs(t *testing.T, urls ...string) []config.WeightedURL {
func TestStreamManager(t *testing.T) {
ctrl := gomock.NewController(t)
auth := mock_ssh.NewMockAuthInterface(ctrl)
m := ssh.NewStreamManager(auth)
m := ssh.NewStreamManager()
cfg := &config.Config{Options: config.NewDefaultOptions()}
cfg.Options.Policies = []config.Policy{
@ -32,7 +32,7 @@ func TestStreamManager(t *testing.T) {
t.Run("LookupStream", func(t *testing.T) {
assert.Nil(t, m.LookupStream(1234))
sh := m.NewStreamHandler(cfg, &extensions_ssh.DownstreamConnectEvent{StreamId: 1234})
sh := m.NewStreamHandler(cfg, auth, &extensions_ssh.DownstreamConnectEvent{StreamId: 1234})
assert.Equal(t, sh, m.LookupStream(1234))
sh.Close()
assert.Nil(t, m.LookupStream(1234))

View file

@ -3,7 +3,7 @@
//
// Generated by this command:
//
// mockgen -typed . AuthInterface
// mockgen -typed -destination ./mock/mock_auth_interface.go . AuthInterface
//
// Package mock_ssh is a generated GoMock package.

View file

@ -37,6 +37,8 @@ type (
KeyboardInteractiveAuthMethodResponse = AuthMethodResponse[extensions_ssh.KeyboardInteractiveAllowResponse]
)
//go:generate go run go.uber.org/mock/mockgen -typed -destination ./mock/mock_auth_interface.go . AuthInterface
type AuthInterface interface {
HandlePublicKeyMethodRequest(ctx context.Context, info StreamAuthInfo, req *extensions_ssh.PublicKeyMethodRequest) (PublicKeyAuthMethodResponse, error)
HandleKeyboardInteractiveMethodRequest(ctx context.Context, info StreamAuthInfo, req *extensions_ssh.KeyboardInteractiveMethodRequest, querier KeyboardInteractiveQuerier) (KeyboardInteractiveAuthMethodResponse, error)
@ -284,8 +286,10 @@ func (sh *StreamHandler) handleAuthRequest(ctx context.Context, req *extensions_
response, err := sh.auth.HandlePublicKeyMethodRequest(ctx, sh.state.StreamAuthInfo, pubkeyReq)
if err != nil {
return err
} else if response.Allow != nil {
partial = true
sh.state.PublicKeyFingerprintSha256 = pubkeyReq.PublicKeyFingerprintSha256
}
partial = response.Allow != nil
sh.state.PublicKeyAllow.Update(response.Allow)
updateMethods(response.RequireAdditionalMethods)
case MethodKeyboardInteractive:

View file

@ -100,7 +100,7 @@ type StreamHandlerSuite struct {
func (s *StreamHandlerSuite) SetupTest() {
s.ctrl = NewController(s.T())
s.mockAuth = mock_ssh.NewMockAuthInterface(s.ctrl)
s.mgr = ssh.NewStreamManager(s.mockAuth)
s.mgr = ssh.NewStreamManager()
s.cleanup = []func(){}
s.errC = make(chan error, 1)
@ -162,7 +162,8 @@ func (s *StreamHandlerSuite) expectError(fn func(), msg string) {
}
func (s *StreamHandlerSuite) startStreamHandler(streamID uint64) *ssh.StreamHandler {
sh := s.mgr.NewStreamHandler(s.cfg, &extensions_ssh.DownstreamConnectEvent{StreamId: streamID})
sh := s.mgr.NewStreamHandler(
s.cfg, s.mockAuth, &extensions_ssh.DownstreamConnectEvent{StreamId: streamID})
s.errC = make(chan error, 1)
ctx, ca := context.WithCancel(s.T().Context())
go func() {
@ -1996,7 +1997,8 @@ func (s *StreamHandlerSuite) TestFormatSession() {
s.mockAuth.EXPECT().
FormatSession(Any(), Any()).
Return([]byte("example"), nil)
sh := s.mgr.NewStreamHandler(s.cfg, &extensions_ssh.DownstreamConnectEvent{StreamId: 1})
sh := s.mgr.NewStreamHandler(
s.cfg, s.mockAuth, &extensions_ssh.DownstreamConnectEvent{StreamId: 1})
ctx, ca := context.WithCancel(context.Background())
ca()
// this will exit immediately, but it will have a state, which is only
@ -2012,7 +2014,8 @@ func (s *StreamHandlerSuite) TestDeleteSession() {
s.mockAuth.EXPECT().
DeleteSession(Any(), Any()).
Return(nil)
sh := s.mgr.NewStreamHandler(s.cfg, &extensions_ssh.DownstreamConnectEvent{StreamId: 1})
sh := s.mgr.NewStreamHandler(
s.cfg, s.mockAuth, &extensions_ssh.DownstreamConnectEvent{StreamId: 1})
ctx, ca := context.WithCancel(context.Background())
ca()
// this will exit immediately, but it will have a state, which is only
@ -2024,7 +2027,8 @@ func (s *StreamHandlerSuite) TestDeleteSession() {
}
func (s *StreamHandlerSuite) TestRunCalledTwice() {
sh := s.mgr.NewStreamHandler(s.cfg, &extensions_ssh.DownstreamConnectEvent{StreamId: 1})
sh := s.mgr.NewStreamHandler(
s.cfg, s.mockAuth, &extensions_ssh.DownstreamConnectEvent{StreamId: 1})
ctx, ca := context.WithCancel(context.Background())
ca()
sh.Run(ctx)
@ -2034,7 +2038,8 @@ func (s *StreamHandlerSuite) TestRunCalledTwice() {
}
func (s *StreamHandlerSuite) TestAllSSHRoutes() {
sh := s.mgr.NewStreamHandler(s.cfg, &extensions_ssh.DownstreamConnectEvent{StreamId: 1})
sh := s.mgr.NewStreamHandler(
s.cfg, s.mockAuth, &extensions_ssh.DownstreamConnectEvent{StreamId: 1})
routes := slices.Collect(sh.AllSSHRoutes())
s.Len(routes, 2)
s.Equal("ssh://host1", routes[0].From)