diff --git a/authenticate/identity.go b/authenticate/identity.go index befa18f73..87cb004bd 100644 --- a/authenticate/identity.go +++ b/authenticate/identity.go @@ -6,39 +6,19 @@ import ( oteltrace "go.opentelemetry.io/otel/trace" "github.com/pomerium/pomerium/config" - "github.com/pomerium/pomerium/internal/urlutil" "github.com/pomerium/pomerium/pkg/identity" - "github.com/pomerium/pomerium/pkg/identity/oauth" ) func defaultGetIdentityProvider(ctx context.Context, tracerProvider oteltrace.TracerProvider, options *config.Options, idpID string) (identity.Authenticator, error) { - authenticateURL, err := options.GetAuthenticateURL() + redirectURL, err := options.GetAuthenticateRedirectURL() if err != nil { return nil, err } - redirectURL, err := urlutil.DeepCopy(authenticateURL) - if err != nil { - return nil, err - } - redirectURL.Path = options.AuthenticateCallbackPath - idp, err := options.GetIdentityProviderForID(idpID) if err != nil { return nil, err } - o := oauth.Options{ - RedirectURL: redirectURL, - ProviderName: idp.GetType(), - ProviderURL: idp.GetUrl(), - ClientID: idp.GetClientId(), - ClientSecret: idp.GetClientSecret(), - Scopes: idp.GetScopes(), - AuthCodeOptions: idp.GetRequestParams(), - } - if v := idp.GetAccessTokenAllowedAudiences(); v != nil { - o.AccessTokenAllowedAudiences = &v.Values - } - return identity.NewAuthenticator(ctx, tracerProvider, o) + return identity.GetIdentityProvider(ctx, tracerProvider, idp, redirectURL) } diff --git a/authorize/check_response.go b/authorize/check_response.go index 4665eecae..8eb532f65 100644 --- a/authorize/check_response.go +++ b/authorize/check_response.go @@ -44,15 +44,13 @@ func (a *Authorize) handleResult( // when the user is unauthenticated it means they haven't // logged in yet, so redirect to authenticate - if result.Allow.Reasons.Has(criteria.ReasonUserUnauthenticated) || - result.Deny.Reasons.Has(criteria.ReasonUserUnauthenticated) { + if result.HasReason(criteria.ReasonUserUnauthenticated) { return a.requireLoginResponse(ctx, in, request) } // when the user's device is unauthenticated it means they haven't // registered a webauthn device yet, so redirect to the webauthn flow - if result.Allow.Reasons.Has(criteria.ReasonDeviceUnauthenticated) || - result.Deny.Reasons.Has(criteria.ReasonDeviceUnauthenticated) { + if result.HasReason(criteria.ReasonDeviceUnauthenticated) { return a.requireWebAuthnResponse(ctx, in, request, result) } diff --git a/authorize/evaluator/evaluator.go b/authorize/evaluator/evaluator.go index c5ca31336..dea780e43 100644 --- a/authorize/evaluator/evaluator.go +++ b/authorize/evaluator/evaluator.go @@ -149,6 +149,10 @@ type Result struct { AdditionalLogFields map[log.AuthorizeLogField]any } +func (r *Result) HasReason(reason criteria.Reason) bool { + return r.Allow.Reasons.Has(reason) || r.Deny.Reasons.Has(reason) +} + // An Evaluator evaluates policies. type Evaluator struct { evaluationCount, allowCount, denyCount metric.Int64Counter diff --git a/authorize/ssh_grpc.go b/authorize/ssh_grpc.go index 692c876fb..3074ea255 100644 --- a/authorize/ssh_grpc.go +++ b/authorize/ssh_grpc.go @@ -1,6 +1,7 @@ package authorize import ( + "context" "errors" "io" @@ -9,7 +10,10 @@ import ( "google.golang.org/grpc/status" extensions_ssh "github.com/pomerium/envoy-custom/api/extensions/filters/network/ssh" - "github.com/pomerium/pomerium/pkg/storage" + "github.com/pomerium/pomerium/authorize/evaluator" + "github.com/pomerium/pomerium/internal/log" + "github.com/pomerium/pomerium/pkg/grpc/user" + "github.com/pomerium/pomerium/pkg/ssh" ) func (a *Authorize) ManageStream(stream extensions_ssh.StreamManagement_ManageStreamServer) error { @@ -22,15 +26,16 @@ func (a *Authorize) ManageStream(stream extensions_ssh.StreamManagement_ManageSt if downstream == nil { return status.Errorf(codes.Internal, "first message was not a downstream connected event") } - handler := a.state.Load().ssh.NewStreamHandler(a.currentConfig.Load(), downstream) + + state := a.state.Load() + handler := state.ssh.NewStreamHandler( + a.currentConfig.Load(), + ssh.NewAuth(a, state.dataBrokerClient, a.currentConfig, a.tracerProvider), + downstream, + ) defer handler.Close() eg, ctx := errgroup.WithContext(stream.Context()) - querier := storage.NewCachingQuerier( - storage.NewQuerier(a.state.Load().dataBrokerClient), - storage.GlobalCache, - ) - ctx = storage.WithQuerier(ctx, querier) eg.Go(func() error { for { @@ -87,3 +92,42 @@ func (a *Authorize) ServeChannel(stream extensions_ssh.StreamManagement_ServeCha return handler.ServeChannel(stream) } + +func (a *Authorize) EvaluateSSH(ctx context.Context, req *ssh.Request) (*evaluator.Result, error) { + ctx = a.withQuerierForCheckRequest(ctx) + + evalreq := evaluator.Request{ + HTTP: evaluator.RequestHTTP{ + Hostname: req.Hostname, + }, + SSH: evaluator.RequestSSH{ + Username: req.Username, + PublicKey: req.PublicKey, + }, + Session: evaluator.RequestSession{ + ID: req.SessionID, + }, + } + + if req.Hostname == "" { + evalreq.IsInternal = true + } else { + evalreq.Policy = a.currentConfig.Load().Options.GetRouteForSSHHostname(req.Hostname) + } + + res, err := a.state.Load().evaluator.Evaluate(ctx, &evalreq) + if err != nil { + log.Ctx(ctx).Error().Err(err).Msg("error during OPA evaluation") + return nil, err + } + + s, _ := a.getDataBrokerSessionOrServiceAccount(ctx, req.SessionID, 0) + + var u *user.User + if s != nil { + u, _ = a.getDataBrokerUser(ctx, s.GetUserId()) + } + a.logAuthorizeCheck(ctx, &evalreq, res, s, u) + + return res, nil +} diff --git a/authorize/state.go b/authorize/state.go index bc0a37364..e6b77efe6 100644 --- a/authorize/state.go +++ b/authorize/state.go @@ -72,7 +72,7 @@ func newAuthorizeStateFromConfig( evaluatorOptions = append(evaluatorOptions, evaluator.WithMCPAccessTokenProvider(mcp)) } - state.ssh = ssh.NewStreamManager(nil) // XXX + state.ssh = ssh.NewStreamManager() state.evaluator, err = newPolicyEvaluator(ctx, cfg.Options, store, previousEvaluator, evaluatorOptions...) if err != nil { diff --git a/config/options.go b/config/options.go index e628d587d..798dc7713 100644 --- a/config/options.go +++ b/config/options.go @@ -832,6 +832,21 @@ func (o *Options) GetInternalAuthenticateURL() (*url.URL, error) { return urlutil.ParseAndValidateURL(o.AuthenticateInternalURLString) } +func (o *Options) GetAuthenticateRedirectURL() (*url.URL, error) { + authenticateURL, err := o.GetAuthenticateURL() + if err != nil { + return nil, err + } + + redirectURL, err := urlutil.DeepCopy(authenticateURL) + if err != nil { + return nil, err + } + redirectURL.Path = o.AuthenticateCallbackPath + + return redirectURL, nil +} + // UseStatelessAuthenticateFlow returns true if the stateless authentication // flow should be used (i.e. for hosted authenticate). func (o *Options) UseStatelessAuthenticateFlow() bool { @@ -1054,6 +1069,19 @@ func (o *Options) NumPolicies() int { return len(o.Policies) + len(o.Routes) + len(o.AdditionalPolicies) } +func (o *Options) GetRouteForSSHHostname(hostname string) *Policy { + if hostname == "" { + return nil + } + from := "ssh://" + hostname + for r := range o.GetAllPolicies() { + if r.From == from { + return r + } + } + return nil +} + // GetMetricsBasicAuth gets the metrics basic auth username and password. func (o *Options) GetMetricsBasicAuth() (username, password string, ok bool) { if o.MetricsBasicAuth == "" { diff --git a/internal/authenticateflow/authenticateflow.go b/internal/authenticateflow/authenticateflow.go index 3e6dcc6d8..20befc430 100644 --- a/internal/authenticateflow/authenticateflow.go +++ b/internal/authenticateflow/authenticateflow.go @@ -5,18 +5,14 @@ package authenticateflow import ( "context" - "fmt" "time" oteltrace "go.opentelemetry.io/otel/trace" "google.golang.org/grpc/codes" "google.golang.org/grpc/stats" "google.golang.org/grpc/status" - "google.golang.org/protobuf/types/known/structpb" "github.com/pomerium/pomerium/pkg/grpc" - "github.com/pomerium/pomerium/pkg/grpc/user" - "github.com/pomerium/pomerium/pkg/identity" "github.com/pomerium/pomerium/pkg/telemetry/trace" ) @@ -25,21 +21,6 @@ var timeNow = time.Now var outboundGRPCConnection = new(grpc.CachedOutboundGRPClientConn) -func populateUserFromClaims(u *user.User, claims map[string]any) { - if v, ok := claims["name"]; ok { - u.Name = fmt.Sprint(v) - } - if v, ok := claims["email"]; ok { - u.Email = fmt.Sprint(v) - } - if u.Claims == nil { - u.Claims = make(map[string]*structpb.ListValue) - } - for k, vs := range identity.Claims(claims).Flatten().ToPB() { - u.Claims[k] = vs - } -} - var outboundDatabrokerTraceClientOpts = []trace.ClientStatsHandlerOption{ trace.WithStatsInterceptor(ignoreNotFoundErrors), } diff --git a/internal/authenticateflow/stateful.go b/internal/authenticateflow/stateful.go index 17bc37a7c..f534a0019 100644 --- a/internal/authenticateflow/stateful.go +++ b/internal/authenticateflow/stateful.go @@ -208,7 +208,7 @@ func (s *Stateful) PersistSession( Id: sess.GetUserId(), } } - populateUserFromClaims(u, claims.Claims) + u.PopulateFromClaims(claims.Claims) _, err := databroker.Put(ctx, s.dataBrokerClient, u) if err != nil { return fmt.Errorf("authenticate: error saving user: %w", err) diff --git a/internal/authenticateflow/stateless.go b/internal/authenticateflow/stateless.go index 6060581d3..e9d26b056 100644 --- a/internal/authenticateflow/stateless.go +++ b/internal/authenticateflow/stateless.go @@ -422,7 +422,7 @@ func (s *Stateless) Callback(w http.ResponseWriter, r *http.Request) error { if err != nil { u = &user.User{Id: ss.UserID()} } - populateUserFromClaims(u, profile.GetClaims().AsMap()) + u.PopulateFromClaims(profile.Claims.AsMap()) redirectURI, err := getRedirectURIFromValues(values) if err != nil { diff --git a/internal/testutil/mockidp/mockidp.go b/internal/testutil/mockidp/mockidp.go index acb036637..f9c1a18f3 100644 --- a/internal/testutil/mockidp/mockidp.go +++ b/internal/testutil/mockidp/mockidp.go @@ -69,9 +69,10 @@ func New(cfg Config) *IDP { userLookup[user.ID] = user } return &IDP{ - publicJWK: publicJWK, - signingKey: signingKey, - userLookup: userLookup, + publicJWK: publicJWK, + signingKey: signingKey, + userLookup: userLookup, + enableDeviceAuth: cfg.EnableDeviceAuth, } } diff --git a/pkg/envoy/get-envoy/main.go b/pkg/envoy/get-envoy/main.go index f7193839b..01db3332f 100644 --- a/pkg/envoy/get-envoy/main.go +++ b/pkg/envoy/get-envoy/main.go @@ -16,7 +16,7 @@ import ( ) var ( - envoyVersion = "1.34.1-rc1" + envoyVersion = "1.34.1-rc3" targets = []string{ "darwin-amd64", "darwin-arm64", diff --git a/pkg/grpc/user/user.go b/pkg/grpc/user/user.go index ecfc80413..b8326b018 100644 --- a/pkg/grpc/user/user.go +++ b/pkg/grpc/user/user.go @@ -48,6 +48,17 @@ func (x *ServiceAccount) Validate() error { return nil } +// PopulateFromClaims sets the Name, Email, and Claims fields from a claims map. +func (x *User) PopulateFromClaims(claims map[string]any) { + if v, ok := claims["name"]; ok { + x.Name = fmt.Sprint(v) + } + if v, ok := claims["email"]; ok { + x.Email = fmt.Sprint(v) + } + x.AddClaims(identity.Claims(claims).Flatten()) +} + // AddClaims adds the flattened claims to the user. func (x *User) AddClaims(claims identity.FlattenedClaims) { if x.Claims == nil { diff --git a/pkg/identity/providers.go b/pkg/identity/providers.go index 430709062..8df924af0 100644 --- a/pkg/identity/providers.go +++ b/pkg/identity/providers.go @@ -6,11 +6,13 @@ import ( "context" "fmt" "net/http" + "net/url" "go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp" oteltrace "go.opentelemetry.io/otel/trace" "golang.org/x/oauth2" + identitypb "github.com/pomerium/pomerium/pkg/grpc/identity" "github.com/pomerium/pomerium/pkg/identity/identity" "github.com/pomerium/pomerium/pkg/identity/oauth" "github.com/pomerium/pomerium/pkg/identity/oauth/apple" @@ -92,3 +94,24 @@ func NewAuthenticator(ctx context.Context, tracerProvider oteltrace.TracerProvid return ctor(ctx, &o) } + +func GetIdentityProvider( + ctx context.Context, + tracerProvider oteltrace.TracerProvider, + idp *identitypb.Provider, + redirectURL *url.URL, +) (Authenticator, error) { + o := oauth.Options{ + RedirectURL: redirectURL, + ProviderName: idp.GetType(), + ProviderURL: idp.GetUrl(), + ClientID: idp.GetClientId(), + ClientSecret: idp.GetClientSecret(), + Scopes: idp.GetScopes(), + AuthCodeOptions: idp.GetRequestParams(), + } + if v := idp.GetAccessTokenAllowedAudiences(); v != nil { + o.AccessTokenAllowedAudiences = &v.Values + } + return NewAuthenticator(ctx, tracerProvider, o) +} diff --git a/pkg/ssh/auth.go b/pkg/ssh/auth.go new file mode 100644 index 000000000..c09cef0db --- /dev/null +++ b/pkg/ssh/auth.go @@ -0,0 +1,389 @@ +package ssh + +import ( + "bytes" + "context" + "crypto/sha256" + "encoding/base64" + "errors" + "text/template" + "time" + + oteltrace "go.opentelemetry.io/otel/trace" + "golang.org/x/oauth2" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + "google.golang.org/protobuf/types/known/timestamppb" + + extensions_ssh "github.com/pomerium/envoy-custom/api/extensions/filters/network/ssh" + "github.com/pomerium/pomerium/authorize/evaluator" + "github.com/pomerium/pomerium/config" + "github.com/pomerium/pomerium/internal/atomicutil" + "github.com/pomerium/pomerium/internal/log" + "github.com/pomerium/pomerium/internal/sessions" + "github.com/pomerium/pomerium/pkg/grpc/databroker" + "github.com/pomerium/pomerium/pkg/grpc/session" + "github.com/pomerium/pomerium/pkg/grpc/user" + "github.com/pomerium/pomerium/pkg/identity" + "github.com/pomerium/pomerium/pkg/identity/manager" + "github.com/pomerium/pomerium/pkg/policy/criteria" + "github.com/pomerium/pomerium/pkg/storage" +) + +type PolicyEvaluator interface { + EvaluateSSH(context.Context, *Request) (*evaluator.Result, error) +} + +type Request struct { + Username string + Hostname string + PublicKey []byte + SessionID string +} + +type Auth struct { + evaluator PolicyEvaluator + dataBrokerClient databroker.DataBrokerServiceClient + currentConfig *atomicutil.Value[*config.Config] + tracerProvider oteltrace.TracerProvider +} + +func NewAuth( + evaluator PolicyEvaluator, + client databroker.DataBrokerServiceClient, + currentConfig *atomicutil.Value[*config.Config], + tracerProvider oteltrace.TracerProvider, +) *Auth { + return &Auth{evaluator, client, currentConfig, tracerProvider} +} + +func (a *Auth) HandlePublicKeyMethodRequest( + ctx context.Context, + info StreamAuthInfo, + req *extensions_ssh.PublicKeyMethodRequest, +) (PublicKeyAuthMethodResponse, error) { + resp, err := a.handlePublicKeyMethodRequest(ctx, info, req) + if err != nil { + log.Ctx(ctx).Error().Err(err).Msg("ssh publickey auth request error") + return resp, status.Error(codes.Internal, "internal error") + } + return resp, err +} + +func (a *Auth) handlePublicKeyMethodRequest( + ctx context.Context, + info StreamAuthInfo, + req *extensions_ssh.PublicKeyMethodRequest, +) (PublicKeyAuthMethodResponse, error) { + sessionID, err := sessionIDFromFingerprint(req.PublicKeyFingerprintSha256) + if err != nil { + return PublicKeyAuthMethodResponse{}, err + } + sshreq := &Request{ + Username: *info.Username, + Hostname: *info.Hostname, + PublicKey: req.PublicKey, + SessionID: sessionID, + } + log.Ctx(ctx).Debug(). + Str("username", *info.Username). + Str("hostname", *info.Hostname). + Str("session-id", sessionID). + Msg("ssh publickey auth request") + + // Special case: internal command (e.g. routes portal). + if *info.Hostname == "" { + _, err := session.Get(ctx, a.dataBrokerClient, sessionID) + if status.Code(err) == codes.NotFound { + // Require IdP login. + return PublicKeyAuthMethodResponse{ + Allow: publicKeyAllowResponse(req.PublicKey), + RequireAdditionalMethods: []string{MethodKeyboardInteractive}, + }, nil + } else if err != nil { + return PublicKeyAuthMethodResponse{}, err + } + } + + res, err := a.evaluator.EvaluateSSH(ctx, sshreq) + if err != nil { + return PublicKeyAuthMethodResponse{}, err + } + + // Interpret the results of policy evaluation. + if res.HasReason(criteria.ReasonSSHPublickeyUnauthorized) { + // This public key is not allowed, but the client is free to try a different key. + return PublicKeyAuthMethodResponse{ + RequireAdditionalMethods: []string{MethodPublicKey}, + }, nil + } else if res.HasReason(criteria.ReasonUserUnauthenticated) { + // Mark public key as allowed, to initiate IdP login flow. + return PublicKeyAuthMethodResponse{ + Allow: publicKeyAllowResponse(req.PublicKey), + RequireAdditionalMethods: []string{MethodKeyboardInteractive}, + }, nil + } else if res.Allow.Value && !res.Deny.Value { + // Allowed, no login needed. + return PublicKeyAuthMethodResponse{ + Allow: publicKeyAllowResponse(req.PublicKey), + }, nil + } + // Denied, no login needed. + return PublicKeyAuthMethodResponse{}, nil +} + +func publicKeyAllowResponse(publicKey []byte) *extensions_ssh.PublicKeyAllowResponse { + return &extensions_ssh.PublicKeyAllowResponse{ + PublicKey: publicKey, + Permissions: &extensions_ssh.Permissions{ + PermitPortForwarding: true, + PermitAgentForwarding: true, + PermitX11Forwarding: true, + PermitPty: true, + PermitUserRc: true, + ValidStartTime: timestamppb.New(time.Now().Add(-1 * time.Minute)), + ValidEndTime: timestamppb.New(time.Now().Add(1 * time.Hour)), + }, + } +} + +func (a *Auth) HandleKeyboardInteractiveMethodRequest( + ctx context.Context, + info StreamAuthInfo, + _ *extensions_ssh.KeyboardInteractiveMethodRequest, + querier KeyboardInteractiveQuerier, +) (KeyboardInteractiveAuthMethodResponse, error) { + resp, err := a.handleKeyboardInteractiveMethodRequest(ctx, info, querier) + if err != nil { + log.Ctx(ctx).Error().Err(err).Msg("ssh keyboard-interactive auth request error") + return resp, status.Error(codes.Internal, "internal error") + } + return resp, err +} + +func (a *Auth) handleKeyboardInteractiveMethodRequest( + ctx context.Context, + info StreamAuthInfo, + querier KeyboardInteractiveQuerier, +) (KeyboardInteractiveAuthMethodResponse, error) { + if info.PublicKeyAllow.Value == nil { + // Sanity check: this method is only valid if we already accepted a public key. + return KeyboardInteractiveAuthMethodResponse{}, errPublicKeyAllowNil + } + + log.Ctx(ctx).Debug(). + Str("username", *info.Username). + Str("hostname", *info.Hostname). + Str("publickey-fingerprint", base64.StdEncoding.EncodeToString(info.PublicKeyFingerprintSha256)). + Msg("ssh keyboard-interactive auth request") + + // Initiate the IdP login flow. + err := a.handleLogin(ctx, *info.Hostname, info.PublicKeyFingerprintSha256, querier) + if err != nil { + return KeyboardInteractiveAuthMethodResponse{}, err + } + + if err := a.EvaluateDelayed(ctx, info); err != nil { + // Denied. + return KeyboardInteractiveAuthMethodResponse{}, nil + } + // Allowed. + return KeyboardInteractiveAuthMethodResponse{ + Allow: &extensions_ssh.KeyboardInteractiveAllowResponse{}, + }, nil +} + +func (a *Auth) handleLogin( + ctx context.Context, + hostname string, + publicKeyFingerprint []byte, + querier KeyboardInteractiveQuerier, +) error { + // Initiate the IdP login flow. + authenticator, err := a.getAuthenticator(ctx, hostname) + if err != nil { + return err + } + + resp, err := authenticator.DeviceAuth(ctx) + if err != nil { + return err + } + + // Prompt the user to sign in. + _, _ = querier.Prompt(ctx, &extensions_ssh.KeyboardInteractiveInfoPrompts{ + Name: "Please sign in with " + authenticator.Name() + " to continue", + Instruction: resp.VerificationURIComplete, + Prompts: nil, + }) + + var sessionClaims identity.SessionClaims + token, err := authenticator.DeviceAccessToken(ctx, resp, &sessionClaims) + if err != nil { + return err + } + sessionID, err := sessionIDFromFingerprint(publicKeyFingerprint) + if err != nil { + return err + } + return a.saveSession(ctx, sessionID, &sessionClaims, token) +} + +var errAccessDenied = errors.New("access denied") + +func (a *Auth) EvaluateDelayed(ctx context.Context, info StreamAuthInfo) error { + req, err := sshRequestFromStreamAuthInfo(info) + if err != nil { + return err + } + res, err := a.evaluator.EvaluateSSH(ctx, req) + if err != nil { + return err + } + + if res.Allow.Value && !res.Deny.Value { + return nil + } + return errAccessDenied +} + +func (a *Auth) FormatSession(ctx context.Context, info StreamAuthInfo) ([]byte, error) { + sessionID, err := sessionIDFromFingerprint(info.PublicKeyFingerprintSha256) + if err != nil { + return nil, err + } + session, err := session.Get(ctx, a.dataBrokerClient, sessionID) + if err != nil { + return nil, err + } + var b bytes.Buffer + err = sessionInfoTmpl.Execute(&b, session) + if err != nil { + return nil, err + } + return b.Bytes(), nil +} + +func (a *Auth) DeleteSession(ctx context.Context, info StreamAuthInfo) error { + sessionID, err := sessionIDFromFingerprint(info.PublicKeyFingerprintSha256) + if err != nil { + return err + } + err = session.Delete(ctx, a.dataBrokerClient, sessionID) + a.invalidateCacheForRecord(ctx, &databroker.Record{ + Type: "type.googleapis.com/session.Session", + Id: sessionID, + }) + return err +} + +func (a *Auth) saveSession( + ctx context.Context, + id string, + claims *identity.SessionClaims, + token *oauth2.Token, +) error { + now := time.Now() + nowpb := timestamppb.New(now) + sessionLifetime := a.currentConfig.Load().Options.CookieExpire + + state := sessions.State{ID: id} + if err := claims.Claims.Claims(&state); err != nil { + return err + } + + sess := &session.Session{ + Id: id, + UserId: state.UserID(), + IssuedAt: nowpb, + AccessedAt: nowpb, + ExpiresAt: timestamppb.New(now.Add(sessionLifetime)), + OauthToken: manager.ToOAuthToken(token), + Audience: state.Audience, + } + sess.SetRawIDToken(claims.RawIDToken) + sess.AddClaims(claims.Flatten()) + + u, _ := user.Get(ctx, a.dataBrokerClient, sess.GetUserId()) + if u == nil { + // if no user exists yet, create a new one + u = &user.User{ + Id: sess.GetUserId(), + } + } + u.PopulateFromClaims(claims.Claims) + _, err := databroker.Put(ctx, a.dataBrokerClient, u) + if err != nil { + return err + } + + resp, err := session.Put(ctx, a.dataBrokerClient, sess) + if err != nil { + return err + } + a.invalidateCacheForRecord(ctx, resp.GetRecord()) + return nil +} + +func (a *Auth) invalidateCacheForRecord(ctx context.Context, record *databroker.Record) { + ctx = storage.WithQuerier(ctx, + storage.NewCachingQuerier(storage.NewQuerier(a.dataBrokerClient), storage.GlobalCache)) + storage.InvalidateCacheForDataBrokerRecords(ctx, record) +} + +func (a *Auth) getAuthenticator(ctx context.Context, hostname string) (identity.Authenticator, error) { + opts := a.currentConfig.Load().Options + + redirectURL, err := opts.GetAuthenticateRedirectURL() + if err != nil { + return nil, err + } + + idp, err := opts.GetIdentityProviderForPolicy(opts.GetRouteForSSHHostname(hostname)) + if err != nil { + return nil, err + } + + return identity.GetIdentityProvider(ctx, a.tracerProvider, idp, redirectURL) +} + +var _ AuthInterface = (*Auth)(nil) + +var errInvalidFingerprint = errors.New("invalid public key fingerprint") + +func sessionIDFromFingerprint(sha256fingerprint []byte) (string, error) { + if len(sha256fingerprint) != sha256.Size { + return "", errInvalidFingerprint + } + return "sshkey-SHA256:" + base64.StdEncoding.EncodeToString(sha256fingerprint), nil +} + +var errPublicKeyAllowNil = errors.New("expected PublicKeyAllow message not to be nil") + +// Converts from StreamAuthInfo to an SSHRequest, assuming the PublicKeyAllow field is not nil. +func sshRequestFromStreamAuthInfo(info StreamAuthInfo) (*Request, error) { + if info.PublicKeyAllow.Value == nil { + return nil, errPublicKeyAllowNil + } + sessionID, err := sessionIDFromFingerprint(info.PublicKeyFingerprintSha256) + if err != nil { + return nil, err + } + + return &Request{ + Username: *info.Username, + Hostname: *info.Hostname, + PublicKey: info.PublicKeyAllow.Value.PublicKey, + SessionID: sessionID, + }, nil +} + +var sessionInfoTmpl = template.Must(template.New("session-info").Parse(` +User ID: {{.UserId}} +Session ID: {{.Id}} +Expires at: {{.ExpiresAt.AsTime}} +Claims: +{{- range $k, $v := .Claims }} + {{ $k }}: {{ $v.AsSlice }} +{{- end }} +`)) diff --git a/pkg/ssh/auth_test.go b/pkg/ssh/auth_test.go new file mode 100644 index 000000000..6aa564555 --- /dev/null +++ b/pkg/ssh/auth_test.go @@ -0,0 +1,469 @@ +package ssh + +import ( + "context" + "errors" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + "google.golang.org/protobuf/types/known/timestamppb" + + extensions_ssh "github.com/pomerium/envoy-custom/api/extensions/filters/network/ssh" + "github.com/pomerium/pomerium/authorize/evaluator" + "github.com/pomerium/pomerium/config" + "github.com/pomerium/pomerium/internal/atomicutil" + "github.com/pomerium/pomerium/internal/testutil/mockidp" + "github.com/pomerium/pomerium/pkg/grpc/databroker" + "github.com/pomerium/pomerium/pkg/grpc/session" + "github.com/pomerium/pomerium/pkg/identity" + "github.com/pomerium/pomerium/pkg/policy/criteria" + "github.com/pomerium/pomerium/pkg/protoutil" +) + +func TestHandlePublicKeyMethodRequest(t *testing.T) { + t.Run("no public key fingerprint", func(t *testing.T) { + var a Auth + var req extensions_ssh.PublicKeyMethodRequest + _, err := a.handlePublicKeyMethodRequest(t.Context(), StreamAuthInfo{}, &req) + assert.ErrorContains(t, err, "invalid public key fingerprint") + }) + t.Run("evaluate error", func(t *testing.T) { + info := StreamAuthInfo{ + Username: ptr("username"), + Hostname: ptr("hostname"), + } + var req extensions_ssh.PublicKeyMethodRequest + req.PublicKeyFingerprintSha256 = []byte("ABCDEFGHIJKLMNOPQRSTUVWXYZ123456") + pe := policyEvaluatorFunc(func(context.Context, *Request) (*evaluator.Result, error) { + return nil, errors.New("error evaluating policy") + }) + a := NewAuth(pe, nil, nil, nil) + _, err := a.handlePublicKeyMethodRequest(t.Context(), info, &req) + assert.ErrorContains(t, err, "error evaluating policy") + }) + t.Run("allow", func(t *testing.T) { + info := StreamAuthInfo{ + Username: ptr("username"), + Hostname: ptr("hostname"), + } + var req extensions_ssh.PublicKeyMethodRequest + req.PublicKeyFingerprintSha256 = []byte("ABCDEFGHIJKLMNOPQRSTUVWXYZ123456") + fakePublicKey := []byte("fake-public-key") + req.PublicKey = fakePublicKey + pe := policyEvaluatorFunc(func(_ context.Context, r *Request) (*evaluator.Result, error) { + assert.Equal(t, r, &Request{ + Username: "username", + Hostname: "hostname", + PublicKey: fakePublicKey, + SessionID: "sshkey-SHA256:QUJDREVGR0hJSktMTU5PUFFSU1RVVldYWVoxMjM0NTY=", + }) + return &evaluator.Result{ + Allow: evaluator.NewRuleResult(true), + Deny: evaluator.NewRuleResult(false), + }, nil + }) + a := NewAuth(pe, nil, nil, nil) + res, err := a.HandlePublicKeyMethodRequest(t.Context(), info, &req) + assert.NoError(t, err) + assert.Empty(t, res.RequireAdditionalMethods) + require.NotNil(t, res.Allow) + assert.Equal(t, res.Allow.PublicKey, fakePublicKey) + }) + t.Run("deny", func(t *testing.T) { + info := StreamAuthInfo{ + Username: ptr("username"), + Hostname: ptr("hostname"), + } + var req extensions_ssh.PublicKeyMethodRequest + req.PublicKeyFingerprintSha256 = []byte("ABCDEFGHIJKLMNOPQRSTUVWXYZ123456") + pe := policyEvaluatorFunc(func(_ context.Context, _ *Request) (*evaluator.Result, error) { + return &evaluator.Result{ + Allow: evaluator.NewRuleResult(true), + Deny: evaluator.NewRuleResult(true), + }, nil + }) + a := NewAuth(pe, nil, nil, nil) + res, err := a.HandlePublicKeyMethodRequest(t.Context(), info, &req) + assert.NoError(t, err) + assert.Nil(t, res.Allow) + assert.Empty(t, res.RequireAdditionalMethods) + }) + t.Run("public key unauthorized", func(t *testing.T) { + info := StreamAuthInfo{ + Username: ptr("username"), + Hostname: ptr("hostname"), + } + var req extensions_ssh.PublicKeyMethodRequest + req.PublicKeyFingerprintSha256 = []byte("ABCDEFGHIJKLMNOPQRSTUVWXYZ123456") + pe := policyEvaluatorFunc(func(_ context.Context, _ *Request) (*evaluator.Result, error) { + return &evaluator.Result{ + Allow: evaluator.NewRuleResult(false, criteria.ReasonSSHPublickeyUnauthorized), + Deny: evaluator.NewRuleResult(false), + }, nil + }) + a := NewAuth(pe, nil, nil, nil) + res, err := a.HandlePublicKeyMethodRequest(t.Context(), info, &req) + assert.NoError(t, err) + assert.Nil(t, res.Allow) + assert.Equal(t, res.RequireAdditionalMethods, []string{MethodPublicKey}) + }) + t.Run("needs login", func(t *testing.T) { + info := StreamAuthInfo{ + Username: ptr("username"), + Hostname: ptr("hostname"), + } + var req extensions_ssh.PublicKeyMethodRequest + req.PublicKeyFingerprintSha256 = []byte("ABCDEFGHIJKLMNOPQRSTUVWXYZ123456") + pe := policyEvaluatorFunc(func(_ context.Context, _ *Request) (*evaluator.Result, error) { + return &evaluator.Result{ + Allow: evaluator.NewRuleResult(false), + Deny: evaluator.NewRuleResult(false, criteria.ReasonUserUnauthenticated), + }, nil + }) + a := NewAuth(pe, nil, nil, nil) + res, err := a.HandlePublicKeyMethodRequest(t.Context(), info, &req) + assert.NoError(t, err) + assert.NotNil(t, res.Allow) + assert.Equal(t, res.RequireAdditionalMethods, []string{MethodKeyboardInteractive}) + }) + t.Run("internal command no session", func(t *testing.T) { + client := fakeDataBrokerServiceClient{ + get: func( + _ context.Context, _ *databroker.GetRequest, _ ...grpc.CallOption, + ) (*databroker.GetResponse, error) { + return nil, status.Error(codes.NotFound, "not found") + }, + } + info := StreamAuthInfo{ + Username: ptr("username"), + Hostname: ptr(""), + } + var req extensions_ssh.PublicKeyMethodRequest + req.PublicKeyFingerprintSha256 = []byte("ABCDEFGHIJKLMNOPQRSTUVWXYZ123456") + pe := policyEvaluatorFunc(func(_ context.Context, _ *Request) (*evaluator.Result, error) { + return &evaluator.Result{ + Allow: evaluator.NewRuleResult(false), + Deny: evaluator.NewRuleResult(false, criteria.ReasonUserUnauthenticated), + }, nil + }) + a := NewAuth(pe, client, nil, nil) + res, err := a.HandlePublicKeyMethodRequest(t.Context(), info, &req) + assert.NoError(t, err) + assert.NotNil(t, res.Allow) + assert.Equal(t, res.RequireAdditionalMethods, []string{MethodKeyboardInteractive}) + }) + t.Run("internal command with session", func(t *testing.T) { + client := fakeDataBrokerServiceClient{ + get: func( + _ context.Context, _ *databroker.GetRequest, _ ...grpc.CallOption, + ) (*databroker.GetResponse, error) { + return &databroker.GetResponse{ + Record: &databroker.Record{ + Type: "type.googleapis.com/session.Session", + Id: "abc", + Data: protoutil.NewAny(&session.Session{ + Id: "abc", + UserId: "USER-ID", + }), + }, + }, nil + }, + } + info := StreamAuthInfo{ + Username: ptr("username"), + Hostname: ptr(""), + } + var req extensions_ssh.PublicKeyMethodRequest + req.PublicKeyFingerprintSha256 = []byte("ABCDEFGHIJKLMNOPQRSTUVWXYZ123456") + pe := policyEvaluatorFunc(func(_ context.Context, _ *Request) (*evaluator.Result, error) { + return &evaluator.Result{ + Allow: evaluator.NewRuleResult(true), + Deny: evaluator.NewRuleResult(false), + }, nil + }) + a := NewAuth(pe, client, nil, nil) + res, err := a.HandlePublicKeyMethodRequest(t.Context(), info, &req) + assert.NoError(t, err) + assert.NotNil(t, res.Allow) + assert.Empty(t, res.RequireAdditionalMethods) + }) + t.Run("internal command databroker error", func(t *testing.T) { + client := fakeDataBrokerServiceClient{ + get: func( + _ context.Context, _ *databroker.GetRequest, _ ...grpc.CallOption, + ) (*databroker.GetResponse, error) { + return nil, status.Error(codes.Unknown, "unknown") + }, + } + info := StreamAuthInfo{ + Username: ptr("username"), + Hostname: ptr(""), + } + var req extensions_ssh.PublicKeyMethodRequest + req.PublicKeyFingerprintSha256 = []byte("ABCDEFGHIJKLMNOPQRSTUVWXYZ123456") + pe := policyEvaluatorFunc(func(_ context.Context, _ *Request) (*evaluator.Result, error) { + return &evaluator.Result{ + Allow: evaluator.NewRuleResult(true), + Deny: evaluator.NewRuleResult(false), + }, nil + }) + a := NewAuth(pe, client, nil, nil) + _, err := a.HandlePublicKeyMethodRequest(t.Context(), info, &req) + assert.ErrorContains(t, err, "internal error") + }) +} + +func TestHandleKeyboardInteractiveMethodRequest(t *testing.T) { + t.Run("no public key", func(t *testing.T) { + var a Auth + _, err := a.handleKeyboardInteractiveMethodRequest(t.Context(), StreamAuthInfo{}, nil) + assert.ErrorContains(t, err, "expected PublicKeyAllow message not to be nil") + }) + t.Run("ok", func(t *testing.T) { + pe := policyEvaluatorFunc(func(_ context.Context, _ *Request) (*evaluator.Result, error) { + return &evaluator.Result{ + Allow: evaluator.NewRuleResult(true), + Deny: evaluator.NewRuleResult(false), + }, nil + }) + var putRecords []*databroker.Record + client := fakeDataBrokerServiceClient{ + get: func( + _ context.Context, _ *databroker.GetRequest, _ ...grpc.CallOption, + ) (*databroker.GetResponse, error) { + return nil, status.Error(codes.NotFound, "not found") + }, + put: func( + _ context.Context, in *databroker.PutRequest, _ ...grpc.CallOption, + ) (*databroker.PutResponse, error) { + putRecords = append(putRecords, in.Records...) + return &databroker.PutResponse{ + Records: in.Records, + }, nil + }, + } + cfg := config.Config{ + Options: config.NewDefaultOptions(), + } + mockIDP := mockidp.New(mockidp.Config{EnableDeviceAuth: true}) + idpURL := mockIDP.Start(t) + cfg.Options.Provider = "oidc" + cfg.Options.ProviderURL = idpURL + cfg.Options.ClientID = "client-id" + cfg.Options.ClientSecret = "client-secret" + a := NewAuth(pe, client, atomicutil.NewValue(&cfg), nil) + info := StreamAuthInfo{ + Username: ptr("username"), + Hostname: ptr("hostname"), + PublicKeyAllow: AuthMethodValue[extensions_ssh.PublicKeyAllowResponse]{ + Value: &extensions_ssh.PublicKeyAllowResponse{ + PublicKey: []byte("fake-public-key"), + }, + }, + PublicKeyFingerprintSha256: []byte("ABCDEFGHIJKLMNOPQRSTUVWXYZ123456"), + } + res, err := a.HandleKeyboardInteractiveMethodRequest(t.Context(), info, nil, noopQuerier{}) + require.NoError(t, err) + assert.NotNil(t, res.Allow) + assert.Empty(t, res.RequireAdditionalMethods) + + // A new Session and User record should have been saved to the databroker. + assert.Len(t, putRecords, 2) + + assert.Equal(t, "type.googleapis.com/user.User", putRecords[0].Type) + assert.Equal(t, "fake.user@example.com", putRecords[0].Id) + + assert.Equal(t, "type.googleapis.com/session.Session", putRecords[1].Type) + assert.Equal(t, "sshkey-SHA256:QUJDREVGR0hJSktMTU5PUFFSU1RVVldYWVoxMjM0NTY=", putRecords[1].Id) + }) + t.Run("denied", func(t *testing.T) { + pe := policyEvaluatorFunc(func(_ context.Context, _ *Request) (*evaluator.Result, error) { + return &evaluator.Result{ + Allow: evaluator.NewRuleResult(false), + Deny: evaluator.NewRuleResult(false), + }, nil + }) + client := fakeDataBrokerServiceClient{ + get: func( + _ context.Context, _ *databroker.GetRequest, _ ...grpc.CallOption, + ) (*databroker.GetResponse, error) { + return nil, status.Error(codes.NotFound, "not found") + }, + put: func( + _ context.Context, in *databroker.PutRequest, _ ...grpc.CallOption, + ) (*databroker.PutResponse, error) { + return &databroker.PutResponse{ + Records: in.Records, + }, nil + }, + } + cfg := config.Config{ + Options: config.NewDefaultOptions(), + } + mockIDP := mockidp.New(mockidp.Config{EnableDeviceAuth: true}) + idpURL := mockIDP.Start(t) + cfg.Options.Provider = "oidc" + cfg.Options.ProviderURL = idpURL + cfg.Options.ClientID = "client-id" + cfg.Options.ClientSecret = "client-secret" + a := NewAuth(pe, client, atomicutil.NewValue(&cfg), nil) + info := StreamAuthInfo{ + Username: ptr("username"), + Hostname: ptr("hostname"), + PublicKeyAllow: AuthMethodValue[extensions_ssh.PublicKeyAllowResponse]{ + Value: &extensions_ssh.PublicKeyAllowResponse{ + PublicKey: []byte("fake-public-key"), + }, + }, + PublicKeyFingerprintSha256: []byte("ABCDEFGHIJKLMNOPQRSTUVWXYZ123456"), + } + res, err := a.HandleKeyboardInteractiveMethodRequest(t.Context(), info, nil, noopQuerier{}) + require.NoError(t, err) + assert.Nil(t, res.Allow) + assert.Empty(t, res.RequireAdditionalMethods) + }) + t.Run("invalid fingerprint", func(t *testing.T) { + cfg := config.Config{ + Options: config.NewDefaultOptions(), + } + mockIDP := mockidp.New(mockidp.Config{EnableDeviceAuth: true}) + idpURL := mockIDP.Start(t) + cfg.Options.Provider = "oidc" + cfg.Options.ProviderURL = idpURL + cfg.Options.ClientID = "client-id" + cfg.Options.ClientSecret = "client-secret" + a := NewAuth(nil, nil, atomicutil.NewValue(&cfg), nil) + info := StreamAuthInfo{ + Username: ptr("username"), + Hostname: ptr("hostname"), + PublicKeyAllow: AuthMethodValue[extensions_ssh.PublicKeyAllowResponse]{ + Value: &extensions_ssh.PublicKeyAllowResponse{ + PublicKey: []byte("fake-public-key"), + }, + }, + } + _, err := a.handleKeyboardInteractiveMethodRequest(t.Context(), info, noopQuerier{}) + assert.ErrorContains(t, err, "invalid public key fingerprint") + }) +} + +func TestFormatSession(t *testing.T) { + t.Run("invalid fingerprint", func(t *testing.T) { + var a Auth + info := StreamAuthInfo{ + PublicKeyFingerprintSha256: []byte("wrong-length"), + } + _, err := a.FormatSession(t.Context(), info) + assert.ErrorContains(t, err, "invalid public key fingerprint") + }) + t.Run("ok", func(t *testing.T) { + client := fakeDataBrokerServiceClient{ + get: func( + _ context.Context, in *databroker.GetRequest, _ ...grpc.CallOption, + ) (*databroker.GetResponse, error) { + const expectedID = "sshkey-SHA256:QUJDREVGR0hJSktMTU5PUFFSU1RVVldYWVoxMjM0NTY=" + assert.Equal(t, in.Type, "type.googleapis.com/session.Session") + assert.Equal(t, in.Id, expectedID) + claims := identity.FlattenedClaims{ + "foo": []any{"bar", "baz"}, + "quux": []any{42}, + } + return &databroker.GetResponse{ + Record: &databroker.Record{ + Type: "type.googleapis.com/session.Session", + Id: expectedID, + Data: protoutil.NewAny(&session.Session{ + Id: expectedID, + UserId: "USER-ID", + ExpiresAt: ×tamppb.Timestamp{Seconds: 1750965358}, + Claims: claims.ToPB(), + }), + }, + }, nil + }, + } + a := NewAuth(nil, client, nil, nil) + info := StreamAuthInfo{ + PublicKeyFingerprintSha256: []byte("ABCDEFGHIJKLMNOPQRSTUVWXYZ123456"), + } + b, err := a.FormatSession(t.Context(), info) + assert.NoError(t, err) + assert.Equal(t, string(b), ` +User ID: USER-ID +Session ID: sshkey-SHA256:QUJDREVGR0hJSktMTU5PUFFSU1RVVldYWVoxMjM0NTY= +Expires at: 2025-06-26 19:15:58 +0000 UTC +Claims: + foo: [bar baz] + quux: [42] +`) + }) +} + +func TestDeleteSession(t *testing.T) { + t.Run("invalid fingerprint", func(t *testing.T) { + var a Auth + info := StreamAuthInfo{ + PublicKeyFingerprintSha256: []byte("wrong-length"), + } + err := a.DeleteSession(t.Context(), info) + assert.ErrorContains(t, err, "invalid public key fingerprint") + }) + t.Run("ok", func(t *testing.T) { + putError := errors.New("sentinel") + client := fakeDataBrokerServiceClient{ + put: func( + _ context.Context, in *databroker.PutRequest, _ ...grpc.CallOption, + ) (*databroker.PutResponse, error) { + require.Len(t, in.Records, 1) + assert.Equal(t, in.Records[0].Id, "sshkey-SHA256:QUJDREVGR0hJSktMTU5PUFFSU1RVVldYWVoxMjM0NTY=") + assert.NotNil(t, in.Records[0].DeletedAt) + return nil, putError + }, + } + a := NewAuth(nil, client, nil, nil) + info := StreamAuthInfo{ + PublicKeyFingerprintSha256: []byte("ABCDEFGHIJKLMNOPQRSTUVWXYZ123456"), + } + err := a.DeleteSession(t.Context(), info) + assert.Equal(t, putError, err) + }) +} + +type policyEvaluatorFunc func(context.Context, *Request) (*evaluator.Result, error) + +func (f policyEvaluatorFunc) EvaluateSSH( + ctx context.Context, req *Request, +) (*evaluator.Result, error) { + return f(ctx, req) +} + +type fakeDataBrokerServiceClient struct { + databroker.DataBrokerServiceClient + + get func(ctx context.Context, in *databroker.GetRequest, opts ...grpc.CallOption) (*databroker.GetResponse, error) + put func(ctx context.Context, in *databroker.PutRequest, opts ...grpc.CallOption) (*databroker.PutResponse, error) +} + +func (m fakeDataBrokerServiceClient) Get(ctx context.Context, in *databroker.GetRequest, opts ...grpc.CallOption) (*databroker.GetResponse, error) { + return m.get(ctx, in, opts...) +} + +func (m fakeDataBrokerServiceClient) Put(ctx context.Context, in *databroker.PutRequest, opts ...grpc.CallOption) (*databroker.PutResponse, error) { + return m.put(ctx, in, opts...) +} + +type noopQuerier struct{} + +func (noopQuerier) Prompt( + _ context.Context, _ *extensions_ssh.KeyboardInteractiveInfoPrompts, +) (*extensions_ssh.KeyboardInteractiveInfoPromptResponses, error) { + return nil, nil +} + +func ptr[T any](t T) *T { + return &t +} diff --git a/pkg/ssh/manager.go b/pkg/ssh/manager.go index 7656411a8..0b3b0eefc 100644 --- a/pkg/ssh/manager.go +++ b/pkg/ssh/manager.go @@ -8,14 +8,12 @@ import ( ) type StreamManager struct { - auth AuthInterface mu sync.Mutex activeStreams map[uint64]*StreamHandler } -func NewStreamManager(auth AuthInterface) *StreamManager { +func NewStreamManager() *StreamManager { return &StreamManager{ - auth: auth, activeStreams: map[uint64]*StreamHandler{}, } } @@ -30,13 +28,17 @@ func (sm *StreamManager) LookupStream(streamID uint64) *StreamHandler { return stream } -func (sm *StreamManager) NewStreamHandler(cfg *config.Config, downstream *extensions_ssh.DownstreamConnectEvent) *StreamHandler { +func (sm *StreamManager) NewStreamHandler( + cfg *config.Config, + auth AuthInterface, + downstream *extensions_ssh.DownstreamConnectEvent, +) *StreamHandler { sm.mu.Lock() defer sm.mu.Unlock() streamID := downstream.StreamId writeC := make(chan *extensions_ssh.ServerMessage, 32) sh := &StreamHandler{ - auth: sm.auth, + auth: auth, config: cfg, downstream: downstream, readC: make(chan *extensions_ssh.ClientMessage, 32), diff --git a/pkg/ssh/manager_test.go b/pkg/ssh/manager_test.go index 0cd5aedf3..766a36d66 100644 --- a/pkg/ssh/manager_test.go +++ b/pkg/ssh/manager_test.go @@ -22,7 +22,7 @@ func mustParseWeightedURLs(t *testing.T, urls ...string) []config.WeightedURL { func TestStreamManager(t *testing.T) { ctrl := gomock.NewController(t) auth := mock_ssh.NewMockAuthInterface(ctrl) - m := ssh.NewStreamManager(auth) + m := ssh.NewStreamManager() cfg := &config.Config{Options: config.NewDefaultOptions()} cfg.Options.Policies = []config.Policy{ @@ -32,7 +32,7 @@ func TestStreamManager(t *testing.T) { t.Run("LookupStream", func(t *testing.T) { assert.Nil(t, m.LookupStream(1234)) - sh := m.NewStreamHandler(cfg, &extensions_ssh.DownstreamConnectEvent{StreamId: 1234}) + sh := m.NewStreamHandler(cfg, auth, &extensions_ssh.DownstreamConnectEvent{StreamId: 1234}) assert.Equal(t, sh, m.LookupStream(1234)) sh.Close() assert.Nil(t, m.LookupStream(1234)) diff --git a/pkg/ssh/mock/mock_auth_interface.go b/pkg/ssh/mock/mock_auth_interface.go index c033561e6..694a6dcd4 100644 --- a/pkg/ssh/mock/mock_auth_interface.go +++ b/pkg/ssh/mock/mock_auth_interface.go @@ -3,7 +3,7 @@ // // Generated by this command: // -// mockgen -typed . AuthInterface +// mockgen -typed -destination ./mock/mock_auth_interface.go . AuthInterface // // Package mock_ssh is a generated GoMock package. diff --git a/pkg/ssh/stream.go b/pkg/ssh/stream.go index a577170fc..07c1b884d 100644 --- a/pkg/ssh/stream.go +++ b/pkg/ssh/stream.go @@ -37,6 +37,8 @@ type ( KeyboardInteractiveAuthMethodResponse = AuthMethodResponse[extensions_ssh.KeyboardInteractiveAllowResponse] ) +//go:generate go run go.uber.org/mock/mockgen -typed -destination ./mock/mock_auth_interface.go . AuthInterface + type AuthInterface interface { HandlePublicKeyMethodRequest(ctx context.Context, info StreamAuthInfo, req *extensions_ssh.PublicKeyMethodRequest) (PublicKeyAuthMethodResponse, error) HandleKeyboardInteractiveMethodRequest(ctx context.Context, info StreamAuthInfo, req *extensions_ssh.KeyboardInteractiveMethodRequest, querier KeyboardInteractiveQuerier) (KeyboardInteractiveAuthMethodResponse, error) @@ -284,8 +286,10 @@ func (sh *StreamHandler) handleAuthRequest(ctx context.Context, req *extensions_ response, err := sh.auth.HandlePublicKeyMethodRequest(ctx, sh.state.StreamAuthInfo, pubkeyReq) if err != nil { return err + } else if response.Allow != nil { + partial = true + sh.state.PublicKeyFingerprintSha256 = pubkeyReq.PublicKeyFingerprintSha256 } - partial = response.Allow != nil sh.state.PublicKeyAllow.Update(response.Allow) updateMethods(response.RequireAdditionalMethods) case MethodKeyboardInteractive: diff --git a/pkg/ssh/stream_test.go b/pkg/ssh/stream_test.go index 2046bcc21..4c7320287 100644 --- a/pkg/ssh/stream_test.go +++ b/pkg/ssh/stream_test.go @@ -100,7 +100,7 @@ type StreamHandlerSuite struct { func (s *StreamHandlerSuite) SetupTest() { s.ctrl = NewController(s.T()) s.mockAuth = mock_ssh.NewMockAuthInterface(s.ctrl) - s.mgr = ssh.NewStreamManager(s.mockAuth) + s.mgr = ssh.NewStreamManager() s.cleanup = []func(){} s.errC = make(chan error, 1) @@ -162,7 +162,8 @@ func (s *StreamHandlerSuite) expectError(fn func(), msg string) { } func (s *StreamHandlerSuite) startStreamHandler(streamID uint64) *ssh.StreamHandler { - sh := s.mgr.NewStreamHandler(s.cfg, &extensions_ssh.DownstreamConnectEvent{StreamId: streamID}) + sh := s.mgr.NewStreamHandler( + s.cfg, s.mockAuth, &extensions_ssh.DownstreamConnectEvent{StreamId: streamID}) s.errC = make(chan error, 1) ctx, ca := context.WithCancel(s.T().Context()) go func() { @@ -1996,7 +1997,8 @@ func (s *StreamHandlerSuite) TestFormatSession() { s.mockAuth.EXPECT(). FormatSession(Any(), Any()). Return([]byte("example"), nil) - sh := s.mgr.NewStreamHandler(s.cfg, &extensions_ssh.DownstreamConnectEvent{StreamId: 1}) + sh := s.mgr.NewStreamHandler( + s.cfg, s.mockAuth, &extensions_ssh.DownstreamConnectEvent{StreamId: 1}) ctx, ca := context.WithCancel(context.Background()) ca() // this will exit immediately, but it will have a state, which is only @@ -2012,7 +2014,8 @@ func (s *StreamHandlerSuite) TestDeleteSession() { s.mockAuth.EXPECT(). DeleteSession(Any(), Any()). Return(nil) - sh := s.mgr.NewStreamHandler(s.cfg, &extensions_ssh.DownstreamConnectEvent{StreamId: 1}) + sh := s.mgr.NewStreamHandler( + s.cfg, s.mockAuth, &extensions_ssh.DownstreamConnectEvent{StreamId: 1}) ctx, ca := context.WithCancel(context.Background()) ca() // this will exit immediately, but it will have a state, which is only @@ -2024,7 +2027,8 @@ func (s *StreamHandlerSuite) TestDeleteSession() { } func (s *StreamHandlerSuite) TestRunCalledTwice() { - sh := s.mgr.NewStreamHandler(s.cfg, &extensions_ssh.DownstreamConnectEvent{StreamId: 1}) + sh := s.mgr.NewStreamHandler( + s.cfg, s.mockAuth, &extensions_ssh.DownstreamConnectEvent{StreamId: 1}) ctx, ca := context.WithCancel(context.Background()) ca() sh.Run(ctx) @@ -2034,7 +2038,8 @@ func (s *StreamHandlerSuite) TestRunCalledTwice() { } func (s *StreamHandlerSuite) TestAllSSHRoutes() { - sh := s.mgr.NewStreamHandler(s.cfg, &extensions_ssh.DownstreamConnectEvent{StreamId: 1}) + sh := s.mgr.NewStreamHandler( + s.cfg, s.mockAuth, &extensions_ssh.DownstreamConnectEvent{StreamId: 1}) routes := slices.Collect(sh.AllSSHRoutes()) s.Len(routes, 2) s.Equal("ssh://host1", routes[0].From)