mirror of
https://github.com/pomerium/pomerium.git
synced 2025-08-06 10:21:05 +02:00
ssh: implement authorization policy evaluation (#5665)
Implement the pkg/ssh.AuthInterface. Add logic for converting from the ssh stream state to an evaluator request, and for interpreting the results of policy evaluation. Refactor some of the existing authorize logic to make it easier to reuse.
This commit is contained in:
parent
9437cec21d
commit
9678e6a231
20 changed files with 1013 additions and 74 deletions
|
@ -6,39 +6,19 @@ import (
|
|||
oteltrace "go.opentelemetry.io/otel/trace"
|
||||
|
||||
"github.com/pomerium/pomerium/config"
|
||||
"github.com/pomerium/pomerium/internal/urlutil"
|
||||
"github.com/pomerium/pomerium/pkg/identity"
|
||||
"github.com/pomerium/pomerium/pkg/identity/oauth"
|
||||
)
|
||||
|
||||
func defaultGetIdentityProvider(ctx context.Context, tracerProvider oteltrace.TracerProvider, options *config.Options, idpID string) (identity.Authenticator, error) {
|
||||
authenticateURL, err := options.GetAuthenticateURL()
|
||||
redirectURL, err := options.GetAuthenticateRedirectURL()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
redirectURL, err := urlutil.DeepCopy(authenticateURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
redirectURL.Path = options.AuthenticateCallbackPath
|
||||
|
||||
idp, err := options.GetIdentityProviderForID(idpID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
o := oauth.Options{
|
||||
RedirectURL: redirectURL,
|
||||
ProviderName: idp.GetType(),
|
||||
ProviderURL: idp.GetUrl(),
|
||||
ClientID: idp.GetClientId(),
|
||||
ClientSecret: idp.GetClientSecret(),
|
||||
Scopes: idp.GetScopes(),
|
||||
AuthCodeOptions: idp.GetRequestParams(),
|
||||
}
|
||||
if v := idp.GetAccessTokenAllowedAudiences(); v != nil {
|
||||
o.AccessTokenAllowedAudiences = &v.Values
|
||||
}
|
||||
return identity.NewAuthenticator(ctx, tracerProvider, o)
|
||||
return identity.GetIdentityProvider(ctx, tracerProvider, idp, redirectURL)
|
||||
}
|
||||
|
|
|
@ -44,15 +44,13 @@ func (a *Authorize) handleResult(
|
|||
|
||||
// when the user is unauthenticated it means they haven't
|
||||
// logged in yet, so redirect to authenticate
|
||||
if result.Allow.Reasons.Has(criteria.ReasonUserUnauthenticated) ||
|
||||
result.Deny.Reasons.Has(criteria.ReasonUserUnauthenticated) {
|
||||
if result.HasReason(criteria.ReasonUserUnauthenticated) {
|
||||
return a.requireLoginResponse(ctx, in, request)
|
||||
}
|
||||
|
||||
// when the user's device is unauthenticated it means they haven't
|
||||
// registered a webauthn device yet, so redirect to the webauthn flow
|
||||
if result.Allow.Reasons.Has(criteria.ReasonDeviceUnauthenticated) ||
|
||||
result.Deny.Reasons.Has(criteria.ReasonDeviceUnauthenticated) {
|
||||
if result.HasReason(criteria.ReasonDeviceUnauthenticated) {
|
||||
return a.requireWebAuthnResponse(ctx, in, request, result)
|
||||
}
|
||||
|
||||
|
|
|
@ -149,6 +149,10 @@ type Result struct {
|
|||
AdditionalLogFields map[log.AuthorizeLogField]any
|
||||
}
|
||||
|
||||
func (r *Result) HasReason(reason criteria.Reason) bool {
|
||||
return r.Allow.Reasons.Has(reason) || r.Deny.Reasons.Has(reason)
|
||||
}
|
||||
|
||||
// An Evaluator evaluates policies.
|
||||
type Evaluator struct {
|
||||
evaluationCount, allowCount, denyCount metric.Int64Counter
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package authorize
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
|
||||
|
@ -9,7 +10,10 @@ import (
|
|||
"google.golang.org/grpc/status"
|
||||
|
||||
extensions_ssh "github.com/pomerium/envoy-custom/api/extensions/filters/network/ssh"
|
||||
"github.com/pomerium/pomerium/pkg/storage"
|
||||
"github.com/pomerium/pomerium/authorize/evaluator"
|
||||
"github.com/pomerium/pomerium/internal/log"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/user"
|
||||
"github.com/pomerium/pomerium/pkg/ssh"
|
||||
)
|
||||
|
||||
func (a *Authorize) ManageStream(stream extensions_ssh.StreamManagement_ManageStreamServer) error {
|
||||
|
@ -22,15 +26,16 @@ func (a *Authorize) ManageStream(stream extensions_ssh.StreamManagement_ManageSt
|
|||
if downstream == nil {
|
||||
return status.Errorf(codes.Internal, "first message was not a downstream connected event")
|
||||
}
|
||||
handler := a.state.Load().ssh.NewStreamHandler(a.currentConfig.Load(), downstream)
|
||||
|
||||
state := a.state.Load()
|
||||
handler := state.ssh.NewStreamHandler(
|
||||
a.currentConfig.Load(),
|
||||
ssh.NewAuth(a, state.dataBrokerClient, a.currentConfig, a.tracerProvider),
|
||||
downstream,
|
||||
)
|
||||
defer handler.Close()
|
||||
|
||||
eg, ctx := errgroup.WithContext(stream.Context())
|
||||
querier := storage.NewCachingQuerier(
|
||||
storage.NewQuerier(a.state.Load().dataBrokerClient),
|
||||
storage.GlobalCache,
|
||||
)
|
||||
ctx = storage.WithQuerier(ctx, querier)
|
||||
|
||||
eg.Go(func() error {
|
||||
for {
|
||||
|
@ -87,3 +92,42 @@ func (a *Authorize) ServeChannel(stream extensions_ssh.StreamManagement_ServeCha
|
|||
|
||||
return handler.ServeChannel(stream)
|
||||
}
|
||||
|
||||
func (a *Authorize) EvaluateSSH(ctx context.Context, req *ssh.Request) (*evaluator.Result, error) {
|
||||
ctx = a.withQuerierForCheckRequest(ctx)
|
||||
|
||||
evalreq := evaluator.Request{
|
||||
HTTP: evaluator.RequestHTTP{
|
||||
Hostname: req.Hostname,
|
||||
},
|
||||
SSH: evaluator.RequestSSH{
|
||||
Username: req.Username,
|
||||
PublicKey: req.PublicKey,
|
||||
},
|
||||
Session: evaluator.RequestSession{
|
||||
ID: req.SessionID,
|
||||
},
|
||||
}
|
||||
|
||||
if req.Hostname == "" {
|
||||
evalreq.IsInternal = true
|
||||
} else {
|
||||
evalreq.Policy = a.currentConfig.Load().Options.GetRouteForSSHHostname(req.Hostname)
|
||||
}
|
||||
|
||||
res, err := a.state.Load().evaluator.Evaluate(ctx, &evalreq)
|
||||
if err != nil {
|
||||
log.Ctx(ctx).Error().Err(err).Msg("error during OPA evaluation")
|
||||
return nil, err
|
||||
}
|
||||
|
||||
s, _ := a.getDataBrokerSessionOrServiceAccount(ctx, req.SessionID, 0)
|
||||
|
||||
var u *user.User
|
||||
if s != nil {
|
||||
u, _ = a.getDataBrokerUser(ctx, s.GetUserId())
|
||||
}
|
||||
a.logAuthorizeCheck(ctx, &evalreq, res, s, u)
|
||||
|
||||
return res, nil
|
||||
}
|
||||
|
|
|
@ -72,7 +72,7 @@ func newAuthorizeStateFromConfig(
|
|||
evaluatorOptions = append(evaluatorOptions, evaluator.WithMCPAccessTokenProvider(mcp))
|
||||
}
|
||||
|
||||
state.ssh = ssh.NewStreamManager(nil) // XXX
|
||||
state.ssh = ssh.NewStreamManager()
|
||||
|
||||
state.evaluator, err = newPolicyEvaluator(ctx, cfg.Options, store, previousEvaluator, evaluatorOptions...)
|
||||
if err != nil {
|
||||
|
|
|
@ -832,6 +832,21 @@ func (o *Options) GetInternalAuthenticateURL() (*url.URL, error) {
|
|||
return urlutil.ParseAndValidateURL(o.AuthenticateInternalURLString)
|
||||
}
|
||||
|
||||
func (o *Options) GetAuthenticateRedirectURL() (*url.URL, error) {
|
||||
authenticateURL, err := o.GetAuthenticateURL()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
redirectURL, err := urlutil.DeepCopy(authenticateURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
redirectURL.Path = o.AuthenticateCallbackPath
|
||||
|
||||
return redirectURL, nil
|
||||
}
|
||||
|
||||
// UseStatelessAuthenticateFlow returns true if the stateless authentication
|
||||
// flow should be used (i.e. for hosted authenticate).
|
||||
func (o *Options) UseStatelessAuthenticateFlow() bool {
|
||||
|
@ -1054,6 +1069,19 @@ func (o *Options) NumPolicies() int {
|
|||
return len(o.Policies) + len(o.Routes) + len(o.AdditionalPolicies)
|
||||
}
|
||||
|
||||
func (o *Options) GetRouteForSSHHostname(hostname string) *Policy {
|
||||
if hostname == "" {
|
||||
return nil
|
||||
}
|
||||
from := "ssh://" + hostname
|
||||
for r := range o.GetAllPolicies() {
|
||||
if r.From == from {
|
||||
return r
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetMetricsBasicAuth gets the metrics basic auth username and password.
|
||||
func (o *Options) GetMetricsBasicAuth() (username, password string, ok bool) {
|
||||
if o.MetricsBasicAuth == "" {
|
||||
|
|
|
@ -5,18 +5,14 @@ package authenticateflow
|
|||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
oteltrace "go.opentelemetry.io/otel/trace"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/stats"
|
||||
"google.golang.org/grpc/status"
|
||||
"google.golang.org/protobuf/types/known/structpb"
|
||||
|
||||
"github.com/pomerium/pomerium/pkg/grpc"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/user"
|
||||
"github.com/pomerium/pomerium/pkg/identity"
|
||||
"github.com/pomerium/pomerium/pkg/telemetry/trace"
|
||||
)
|
||||
|
||||
|
@ -25,21 +21,6 @@ var timeNow = time.Now
|
|||
|
||||
var outboundGRPCConnection = new(grpc.CachedOutboundGRPClientConn)
|
||||
|
||||
func populateUserFromClaims(u *user.User, claims map[string]any) {
|
||||
if v, ok := claims["name"]; ok {
|
||||
u.Name = fmt.Sprint(v)
|
||||
}
|
||||
if v, ok := claims["email"]; ok {
|
||||
u.Email = fmt.Sprint(v)
|
||||
}
|
||||
if u.Claims == nil {
|
||||
u.Claims = make(map[string]*structpb.ListValue)
|
||||
}
|
||||
for k, vs := range identity.Claims(claims).Flatten().ToPB() {
|
||||
u.Claims[k] = vs
|
||||
}
|
||||
}
|
||||
|
||||
var outboundDatabrokerTraceClientOpts = []trace.ClientStatsHandlerOption{
|
||||
trace.WithStatsInterceptor(ignoreNotFoundErrors),
|
||||
}
|
||||
|
|
|
@ -208,7 +208,7 @@ func (s *Stateful) PersistSession(
|
|||
Id: sess.GetUserId(),
|
||||
}
|
||||
}
|
||||
populateUserFromClaims(u, claims.Claims)
|
||||
u.PopulateFromClaims(claims.Claims)
|
||||
_, err := databroker.Put(ctx, s.dataBrokerClient, u)
|
||||
if err != nil {
|
||||
return fmt.Errorf("authenticate: error saving user: %w", err)
|
||||
|
|
|
@ -422,7 +422,7 @@ func (s *Stateless) Callback(w http.ResponseWriter, r *http.Request) error {
|
|||
if err != nil {
|
||||
u = &user.User{Id: ss.UserID()}
|
||||
}
|
||||
populateUserFromClaims(u, profile.GetClaims().AsMap())
|
||||
u.PopulateFromClaims(profile.Claims.AsMap())
|
||||
|
||||
redirectURI, err := getRedirectURIFromValues(values)
|
||||
if err != nil {
|
||||
|
|
|
@ -72,6 +72,7 @@ func New(cfg Config) *IDP {
|
|||
publicJWK: publicJWK,
|
||||
signingKey: signingKey,
|
||||
userLookup: userLookup,
|
||||
enableDeviceAuth: cfg.EnableDeviceAuth,
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -16,7 +16,7 @@ import (
|
|||
)
|
||||
|
||||
var (
|
||||
envoyVersion = "1.34.1-rc1"
|
||||
envoyVersion = "1.34.1-rc3"
|
||||
targets = []string{
|
||||
"darwin-amd64",
|
||||
"darwin-arm64",
|
||||
|
|
|
@ -48,6 +48,17 @@ func (x *ServiceAccount) Validate() error {
|
|||
return nil
|
||||
}
|
||||
|
||||
// PopulateFromClaims sets the Name, Email, and Claims fields from a claims map.
|
||||
func (x *User) PopulateFromClaims(claims map[string]any) {
|
||||
if v, ok := claims["name"]; ok {
|
||||
x.Name = fmt.Sprint(v)
|
||||
}
|
||||
if v, ok := claims["email"]; ok {
|
||||
x.Email = fmt.Sprint(v)
|
||||
}
|
||||
x.AddClaims(identity.Claims(claims).Flatten())
|
||||
}
|
||||
|
||||
// AddClaims adds the flattened claims to the user.
|
||||
func (x *User) AddClaims(claims identity.FlattenedClaims) {
|
||||
if x.Claims == nil {
|
||||
|
|
|
@ -6,11 +6,13 @@ import (
|
|||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
|
||||
"go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp"
|
||||
oteltrace "go.opentelemetry.io/otel/trace"
|
||||
"golang.org/x/oauth2"
|
||||
|
||||
identitypb "github.com/pomerium/pomerium/pkg/grpc/identity"
|
||||
"github.com/pomerium/pomerium/pkg/identity/identity"
|
||||
"github.com/pomerium/pomerium/pkg/identity/oauth"
|
||||
"github.com/pomerium/pomerium/pkg/identity/oauth/apple"
|
||||
|
@ -92,3 +94,24 @@ func NewAuthenticator(ctx context.Context, tracerProvider oteltrace.TracerProvid
|
|||
|
||||
return ctor(ctx, &o)
|
||||
}
|
||||
|
||||
func GetIdentityProvider(
|
||||
ctx context.Context,
|
||||
tracerProvider oteltrace.TracerProvider,
|
||||
idp *identitypb.Provider,
|
||||
redirectURL *url.URL,
|
||||
) (Authenticator, error) {
|
||||
o := oauth.Options{
|
||||
RedirectURL: redirectURL,
|
||||
ProviderName: idp.GetType(),
|
||||
ProviderURL: idp.GetUrl(),
|
||||
ClientID: idp.GetClientId(),
|
||||
ClientSecret: idp.GetClientSecret(),
|
||||
Scopes: idp.GetScopes(),
|
||||
AuthCodeOptions: idp.GetRequestParams(),
|
||||
}
|
||||
if v := idp.GetAccessTokenAllowedAudiences(); v != nil {
|
||||
o.AccessTokenAllowedAudiences = &v.Values
|
||||
}
|
||||
return NewAuthenticator(ctx, tracerProvider, o)
|
||||
}
|
||||
|
|
389
pkg/ssh/auth.go
Normal file
389
pkg/ssh/auth.go
Normal file
|
@ -0,0 +1,389 @@
|
|||
package ssh
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"text/template"
|
||||
"time"
|
||||
|
||||
oteltrace "go.opentelemetry.io/otel/trace"
|
||||
"golang.org/x/oauth2"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
|
||||
extensions_ssh "github.com/pomerium/envoy-custom/api/extensions/filters/network/ssh"
|
||||
"github.com/pomerium/pomerium/authorize/evaluator"
|
||||
"github.com/pomerium/pomerium/config"
|
||||
"github.com/pomerium/pomerium/internal/atomicutil"
|
||||
"github.com/pomerium/pomerium/internal/log"
|
||||
"github.com/pomerium/pomerium/internal/sessions"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/session"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/user"
|
||||
"github.com/pomerium/pomerium/pkg/identity"
|
||||
"github.com/pomerium/pomerium/pkg/identity/manager"
|
||||
"github.com/pomerium/pomerium/pkg/policy/criteria"
|
||||
"github.com/pomerium/pomerium/pkg/storage"
|
||||
)
|
||||
|
||||
type PolicyEvaluator interface {
|
||||
EvaluateSSH(context.Context, *Request) (*evaluator.Result, error)
|
||||
}
|
||||
|
||||
type Request struct {
|
||||
Username string
|
||||
Hostname string
|
||||
PublicKey []byte
|
||||
SessionID string
|
||||
}
|
||||
|
||||
type Auth struct {
|
||||
evaluator PolicyEvaluator
|
||||
dataBrokerClient databroker.DataBrokerServiceClient
|
||||
currentConfig *atomicutil.Value[*config.Config]
|
||||
tracerProvider oteltrace.TracerProvider
|
||||
}
|
||||
|
||||
func NewAuth(
|
||||
evaluator PolicyEvaluator,
|
||||
client databroker.DataBrokerServiceClient,
|
||||
currentConfig *atomicutil.Value[*config.Config],
|
||||
tracerProvider oteltrace.TracerProvider,
|
||||
) *Auth {
|
||||
return &Auth{evaluator, client, currentConfig, tracerProvider}
|
||||
}
|
||||
|
||||
func (a *Auth) HandlePublicKeyMethodRequest(
|
||||
ctx context.Context,
|
||||
info StreamAuthInfo,
|
||||
req *extensions_ssh.PublicKeyMethodRequest,
|
||||
) (PublicKeyAuthMethodResponse, error) {
|
||||
resp, err := a.handlePublicKeyMethodRequest(ctx, info, req)
|
||||
if err != nil {
|
||||
log.Ctx(ctx).Error().Err(err).Msg("ssh publickey auth request error")
|
||||
return resp, status.Error(codes.Internal, "internal error")
|
||||
}
|
||||
return resp, err
|
||||
}
|
||||
|
||||
func (a *Auth) handlePublicKeyMethodRequest(
|
||||
ctx context.Context,
|
||||
info StreamAuthInfo,
|
||||
req *extensions_ssh.PublicKeyMethodRequest,
|
||||
) (PublicKeyAuthMethodResponse, error) {
|
||||
sessionID, err := sessionIDFromFingerprint(req.PublicKeyFingerprintSha256)
|
||||
if err != nil {
|
||||
return PublicKeyAuthMethodResponse{}, err
|
||||
}
|
||||
sshreq := &Request{
|
||||
Username: *info.Username,
|
||||
Hostname: *info.Hostname,
|
||||
PublicKey: req.PublicKey,
|
||||
SessionID: sessionID,
|
||||
}
|
||||
log.Ctx(ctx).Debug().
|
||||
Str("username", *info.Username).
|
||||
Str("hostname", *info.Hostname).
|
||||
Str("session-id", sessionID).
|
||||
Msg("ssh publickey auth request")
|
||||
|
||||
// Special case: internal command (e.g. routes portal).
|
||||
if *info.Hostname == "" {
|
||||
_, err := session.Get(ctx, a.dataBrokerClient, sessionID)
|
||||
if status.Code(err) == codes.NotFound {
|
||||
// Require IdP login.
|
||||
return PublicKeyAuthMethodResponse{
|
||||
Allow: publicKeyAllowResponse(req.PublicKey),
|
||||
RequireAdditionalMethods: []string{MethodKeyboardInteractive},
|
||||
}, nil
|
||||
} else if err != nil {
|
||||
return PublicKeyAuthMethodResponse{}, err
|
||||
}
|
||||
}
|
||||
|
||||
res, err := a.evaluator.EvaluateSSH(ctx, sshreq)
|
||||
if err != nil {
|
||||
return PublicKeyAuthMethodResponse{}, err
|
||||
}
|
||||
|
||||
// Interpret the results of policy evaluation.
|
||||
if res.HasReason(criteria.ReasonSSHPublickeyUnauthorized) {
|
||||
// This public key is not allowed, but the client is free to try a different key.
|
||||
return PublicKeyAuthMethodResponse{
|
||||
RequireAdditionalMethods: []string{MethodPublicKey},
|
||||
}, nil
|
||||
} else if res.HasReason(criteria.ReasonUserUnauthenticated) {
|
||||
// Mark public key as allowed, to initiate IdP login flow.
|
||||
return PublicKeyAuthMethodResponse{
|
||||
Allow: publicKeyAllowResponse(req.PublicKey),
|
||||
RequireAdditionalMethods: []string{MethodKeyboardInteractive},
|
||||
}, nil
|
||||
} else if res.Allow.Value && !res.Deny.Value {
|
||||
// Allowed, no login needed.
|
||||
return PublicKeyAuthMethodResponse{
|
||||
Allow: publicKeyAllowResponse(req.PublicKey),
|
||||
}, nil
|
||||
}
|
||||
// Denied, no login needed.
|
||||
return PublicKeyAuthMethodResponse{}, nil
|
||||
}
|
||||
|
||||
func publicKeyAllowResponse(publicKey []byte) *extensions_ssh.PublicKeyAllowResponse {
|
||||
return &extensions_ssh.PublicKeyAllowResponse{
|
||||
PublicKey: publicKey,
|
||||
Permissions: &extensions_ssh.Permissions{
|
||||
PermitPortForwarding: true,
|
||||
PermitAgentForwarding: true,
|
||||
PermitX11Forwarding: true,
|
||||
PermitPty: true,
|
||||
PermitUserRc: true,
|
||||
ValidStartTime: timestamppb.New(time.Now().Add(-1 * time.Minute)),
|
||||
ValidEndTime: timestamppb.New(time.Now().Add(1 * time.Hour)),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Auth) HandleKeyboardInteractiveMethodRequest(
|
||||
ctx context.Context,
|
||||
info StreamAuthInfo,
|
||||
_ *extensions_ssh.KeyboardInteractiveMethodRequest,
|
||||
querier KeyboardInteractiveQuerier,
|
||||
) (KeyboardInteractiveAuthMethodResponse, error) {
|
||||
resp, err := a.handleKeyboardInteractiveMethodRequest(ctx, info, querier)
|
||||
if err != nil {
|
||||
log.Ctx(ctx).Error().Err(err).Msg("ssh keyboard-interactive auth request error")
|
||||
return resp, status.Error(codes.Internal, "internal error")
|
||||
}
|
||||
return resp, err
|
||||
}
|
||||
|
||||
func (a *Auth) handleKeyboardInteractiveMethodRequest(
|
||||
ctx context.Context,
|
||||
info StreamAuthInfo,
|
||||
querier KeyboardInteractiveQuerier,
|
||||
) (KeyboardInteractiveAuthMethodResponse, error) {
|
||||
if info.PublicKeyAllow.Value == nil {
|
||||
// Sanity check: this method is only valid if we already accepted a public key.
|
||||
return KeyboardInteractiveAuthMethodResponse{}, errPublicKeyAllowNil
|
||||
}
|
||||
|
||||
log.Ctx(ctx).Debug().
|
||||
Str("username", *info.Username).
|
||||
Str("hostname", *info.Hostname).
|
||||
Str("publickey-fingerprint", base64.StdEncoding.EncodeToString(info.PublicKeyFingerprintSha256)).
|
||||
Msg("ssh keyboard-interactive auth request")
|
||||
|
||||
// Initiate the IdP login flow.
|
||||
err := a.handleLogin(ctx, *info.Hostname, info.PublicKeyFingerprintSha256, querier)
|
||||
if err != nil {
|
||||
return KeyboardInteractiveAuthMethodResponse{}, err
|
||||
}
|
||||
|
||||
if err := a.EvaluateDelayed(ctx, info); err != nil {
|
||||
// Denied.
|
||||
return KeyboardInteractiveAuthMethodResponse{}, nil
|
||||
}
|
||||
// Allowed.
|
||||
return KeyboardInteractiveAuthMethodResponse{
|
||||
Allow: &extensions_ssh.KeyboardInteractiveAllowResponse{},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (a *Auth) handleLogin(
|
||||
ctx context.Context,
|
||||
hostname string,
|
||||
publicKeyFingerprint []byte,
|
||||
querier KeyboardInteractiveQuerier,
|
||||
) error {
|
||||
// Initiate the IdP login flow.
|
||||
authenticator, err := a.getAuthenticator(ctx, hostname)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
resp, err := authenticator.DeviceAuth(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Prompt the user to sign in.
|
||||
_, _ = querier.Prompt(ctx, &extensions_ssh.KeyboardInteractiveInfoPrompts{
|
||||
Name: "Please sign in with " + authenticator.Name() + " to continue",
|
||||
Instruction: resp.VerificationURIComplete,
|
||||
Prompts: nil,
|
||||
})
|
||||
|
||||
var sessionClaims identity.SessionClaims
|
||||
token, err := authenticator.DeviceAccessToken(ctx, resp, &sessionClaims)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
sessionID, err := sessionIDFromFingerprint(publicKeyFingerprint)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return a.saveSession(ctx, sessionID, &sessionClaims, token)
|
||||
}
|
||||
|
||||
var errAccessDenied = errors.New("access denied")
|
||||
|
||||
func (a *Auth) EvaluateDelayed(ctx context.Context, info StreamAuthInfo) error {
|
||||
req, err := sshRequestFromStreamAuthInfo(info)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
res, err := a.evaluator.EvaluateSSH(ctx, req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if res.Allow.Value && !res.Deny.Value {
|
||||
return nil
|
||||
}
|
||||
return errAccessDenied
|
||||
}
|
||||
|
||||
func (a *Auth) FormatSession(ctx context.Context, info StreamAuthInfo) ([]byte, error) {
|
||||
sessionID, err := sessionIDFromFingerprint(info.PublicKeyFingerprintSha256)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
session, err := session.Get(ctx, a.dataBrokerClient, sessionID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var b bytes.Buffer
|
||||
err = sessionInfoTmpl.Execute(&b, session)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return b.Bytes(), nil
|
||||
}
|
||||
|
||||
func (a *Auth) DeleteSession(ctx context.Context, info StreamAuthInfo) error {
|
||||
sessionID, err := sessionIDFromFingerprint(info.PublicKeyFingerprintSha256)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = session.Delete(ctx, a.dataBrokerClient, sessionID)
|
||||
a.invalidateCacheForRecord(ctx, &databroker.Record{
|
||||
Type: "type.googleapis.com/session.Session",
|
||||
Id: sessionID,
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
func (a *Auth) saveSession(
|
||||
ctx context.Context,
|
||||
id string,
|
||||
claims *identity.SessionClaims,
|
||||
token *oauth2.Token,
|
||||
) error {
|
||||
now := time.Now()
|
||||
nowpb := timestamppb.New(now)
|
||||
sessionLifetime := a.currentConfig.Load().Options.CookieExpire
|
||||
|
||||
state := sessions.State{ID: id}
|
||||
if err := claims.Claims.Claims(&state); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
sess := &session.Session{
|
||||
Id: id,
|
||||
UserId: state.UserID(),
|
||||
IssuedAt: nowpb,
|
||||
AccessedAt: nowpb,
|
||||
ExpiresAt: timestamppb.New(now.Add(sessionLifetime)),
|
||||
OauthToken: manager.ToOAuthToken(token),
|
||||
Audience: state.Audience,
|
||||
}
|
||||
sess.SetRawIDToken(claims.RawIDToken)
|
||||
sess.AddClaims(claims.Flatten())
|
||||
|
||||
u, _ := user.Get(ctx, a.dataBrokerClient, sess.GetUserId())
|
||||
if u == nil {
|
||||
// if no user exists yet, create a new one
|
||||
u = &user.User{
|
||||
Id: sess.GetUserId(),
|
||||
}
|
||||
}
|
||||
u.PopulateFromClaims(claims.Claims)
|
||||
_, err := databroker.Put(ctx, a.dataBrokerClient, u)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
resp, err := session.Put(ctx, a.dataBrokerClient, sess)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
a.invalidateCacheForRecord(ctx, resp.GetRecord())
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *Auth) invalidateCacheForRecord(ctx context.Context, record *databroker.Record) {
|
||||
ctx = storage.WithQuerier(ctx,
|
||||
storage.NewCachingQuerier(storage.NewQuerier(a.dataBrokerClient), storage.GlobalCache))
|
||||
storage.InvalidateCacheForDataBrokerRecords(ctx, record)
|
||||
}
|
||||
|
||||
func (a *Auth) getAuthenticator(ctx context.Context, hostname string) (identity.Authenticator, error) {
|
||||
opts := a.currentConfig.Load().Options
|
||||
|
||||
redirectURL, err := opts.GetAuthenticateRedirectURL()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
idp, err := opts.GetIdentityProviderForPolicy(opts.GetRouteForSSHHostname(hostname))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return identity.GetIdentityProvider(ctx, a.tracerProvider, idp, redirectURL)
|
||||
}
|
||||
|
||||
var _ AuthInterface = (*Auth)(nil)
|
||||
|
||||
var errInvalidFingerprint = errors.New("invalid public key fingerprint")
|
||||
|
||||
func sessionIDFromFingerprint(sha256fingerprint []byte) (string, error) {
|
||||
if len(sha256fingerprint) != sha256.Size {
|
||||
return "", errInvalidFingerprint
|
||||
}
|
||||
return "sshkey-SHA256:" + base64.StdEncoding.EncodeToString(sha256fingerprint), nil
|
||||
}
|
||||
|
||||
var errPublicKeyAllowNil = errors.New("expected PublicKeyAllow message not to be nil")
|
||||
|
||||
// Converts from StreamAuthInfo to an SSHRequest, assuming the PublicKeyAllow field is not nil.
|
||||
func sshRequestFromStreamAuthInfo(info StreamAuthInfo) (*Request, error) {
|
||||
if info.PublicKeyAllow.Value == nil {
|
||||
return nil, errPublicKeyAllowNil
|
||||
}
|
||||
sessionID, err := sessionIDFromFingerprint(info.PublicKeyFingerprintSha256)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &Request{
|
||||
Username: *info.Username,
|
||||
Hostname: *info.Hostname,
|
||||
PublicKey: info.PublicKeyAllow.Value.PublicKey,
|
||||
SessionID: sessionID,
|
||||
}, nil
|
||||
}
|
||||
|
||||
var sessionInfoTmpl = template.Must(template.New("session-info").Parse(`
|
||||
User ID: {{.UserId}}
|
||||
Session ID: {{.Id}}
|
||||
Expires at: {{.ExpiresAt.AsTime}}
|
||||
Claims:
|
||||
{{- range $k, $v := .Claims }}
|
||||
{{ $k }}: {{ $v.AsSlice }}
|
||||
{{- end }}
|
||||
`))
|
469
pkg/ssh/auth_test.go
Normal file
469
pkg/ssh/auth_test.go
Normal file
|
@ -0,0 +1,469 @@
|
|||
package ssh
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
|
||||
extensions_ssh "github.com/pomerium/envoy-custom/api/extensions/filters/network/ssh"
|
||||
"github.com/pomerium/pomerium/authorize/evaluator"
|
||||
"github.com/pomerium/pomerium/config"
|
||||
"github.com/pomerium/pomerium/internal/atomicutil"
|
||||
"github.com/pomerium/pomerium/internal/testutil/mockidp"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/session"
|
||||
"github.com/pomerium/pomerium/pkg/identity"
|
||||
"github.com/pomerium/pomerium/pkg/policy/criteria"
|
||||
"github.com/pomerium/pomerium/pkg/protoutil"
|
||||
)
|
||||
|
||||
func TestHandlePublicKeyMethodRequest(t *testing.T) {
|
||||
t.Run("no public key fingerprint", func(t *testing.T) {
|
||||
var a Auth
|
||||
var req extensions_ssh.PublicKeyMethodRequest
|
||||
_, err := a.handlePublicKeyMethodRequest(t.Context(), StreamAuthInfo{}, &req)
|
||||
assert.ErrorContains(t, err, "invalid public key fingerprint")
|
||||
})
|
||||
t.Run("evaluate error", func(t *testing.T) {
|
||||
info := StreamAuthInfo{
|
||||
Username: ptr("username"),
|
||||
Hostname: ptr("hostname"),
|
||||
}
|
||||
var req extensions_ssh.PublicKeyMethodRequest
|
||||
req.PublicKeyFingerprintSha256 = []byte("ABCDEFGHIJKLMNOPQRSTUVWXYZ123456")
|
||||
pe := policyEvaluatorFunc(func(context.Context, *Request) (*evaluator.Result, error) {
|
||||
return nil, errors.New("error evaluating policy")
|
||||
})
|
||||
a := NewAuth(pe, nil, nil, nil)
|
||||
_, err := a.handlePublicKeyMethodRequest(t.Context(), info, &req)
|
||||
assert.ErrorContains(t, err, "error evaluating policy")
|
||||
})
|
||||
t.Run("allow", func(t *testing.T) {
|
||||
info := StreamAuthInfo{
|
||||
Username: ptr("username"),
|
||||
Hostname: ptr("hostname"),
|
||||
}
|
||||
var req extensions_ssh.PublicKeyMethodRequest
|
||||
req.PublicKeyFingerprintSha256 = []byte("ABCDEFGHIJKLMNOPQRSTUVWXYZ123456")
|
||||
fakePublicKey := []byte("fake-public-key")
|
||||
req.PublicKey = fakePublicKey
|
||||
pe := policyEvaluatorFunc(func(_ context.Context, r *Request) (*evaluator.Result, error) {
|
||||
assert.Equal(t, r, &Request{
|
||||
Username: "username",
|
||||
Hostname: "hostname",
|
||||
PublicKey: fakePublicKey,
|
||||
SessionID: "sshkey-SHA256:QUJDREVGR0hJSktMTU5PUFFSU1RVVldYWVoxMjM0NTY=",
|
||||
})
|
||||
return &evaluator.Result{
|
||||
Allow: evaluator.NewRuleResult(true),
|
||||
Deny: evaluator.NewRuleResult(false),
|
||||
}, nil
|
||||
})
|
||||
a := NewAuth(pe, nil, nil, nil)
|
||||
res, err := a.HandlePublicKeyMethodRequest(t.Context(), info, &req)
|
||||
assert.NoError(t, err)
|
||||
assert.Empty(t, res.RequireAdditionalMethods)
|
||||
require.NotNil(t, res.Allow)
|
||||
assert.Equal(t, res.Allow.PublicKey, fakePublicKey)
|
||||
})
|
||||
t.Run("deny", func(t *testing.T) {
|
||||
info := StreamAuthInfo{
|
||||
Username: ptr("username"),
|
||||
Hostname: ptr("hostname"),
|
||||
}
|
||||
var req extensions_ssh.PublicKeyMethodRequest
|
||||
req.PublicKeyFingerprintSha256 = []byte("ABCDEFGHIJKLMNOPQRSTUVWXYZ123456")
|
||||
pe := policyEvaluatorFunc(func(_ context.Context, _ *Request) (*evaluator.Result, error) {
|
||||
return &evaluator.Result{
|
||||
Allow: evaluator.NewRuleResult(true),
|
||||
Deny: evaluator.NewRuleResult(true),
|
||||
}, nil
|
||||
})
|
||||
a := NewAuth(pe, nil, nil, nil)
|
||||
res, err := a.HandlePublicKeyMethodRequest(t.Context(), info, &req)
|
||||
assert.NoError(t, err)
|
||||
assert.Nil(t, res.Allow)
|
||||
assert.Empty(t, res.RequireAdditionalMethods)
|
||||
})
|
||||
t.Run("public key unauthorized", func(t *testing.T) {
|
||||
info := StreamAuthInfo{
|
||||
Username: ptr("username"),
|
||||
Hostname: ptr("hostname"),
|
||||
}
|
||||
var req extensions_ssh.PublicKeyMethodRequest
|
||||
req.PublicKeyFingerprintSha256 = []byte("ABCDEFGHIJKLMNOPQRSTUVWXYZ123456")
|
||||
pe := policyEvaluatorFunc(func(_ context.Context, _ *Request) (*evaluator.Result, error) {
|
||||
return &evaluator.Result{
|
||||
Allow: evaluator.NewRuleResult(false, criteria.ReasonSSHPublickeyUnauthorized),
|
||||
Deny: evaluator.NewRuleResult(false),
|
||||
}, nil
|
||||
})
|
||||
a := NewAuth(pe, nil, nil, nil)
|
||||
res, err := a.HandlePublicKeyMethodRequest(t.Context(), info, &req)
|
||||
assert.NoError(t, err)
|
||||
assert.Nil(t, res.Allow)
|
||||
assert.Equal(t, res.RequireAdditionalMethods, []string{MethodPublicKey})
|
||||
})
|
||||
t.Run("needs login", func(t *testing.T) {
|
||||
info := StreamAuthInfo{
|
||||
Username: ptr("username"),
|
||||
Hostname: ptr("hostname"),
|
||||
}
|
||||
var req extensions_ssh.PublicKeyMethodRequest
|
||||
req.PublicKeyFingerprintSha256 = []byte("ABCDEFGHIJKLMNOPQRSTUVWXYZ123456")
|
||||
pe := policyEvaluatorFunc(func(_ context.Context, _ *Request) (*evaluator.Result, error) {
|
||||
return &evaluator.Result{
|
||||
Allow: evaluator.NewRuleResult(false),
|
||||
Deny: evaluator.NewRuleResult(false, criteria.ReasonUserUnauthenticated),
|
||||
}, nil
|
||||
})
|
||||
a := NewAuth(pe, nil, nil, nil)
|
||||
res, err := a.HandlePublicKeyMethodRequest(t.Context(), info, &req)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, res.Allow)
|
||||
assert.Equal(t, res.RequireAdditionalMethods, []string{MethodKeyboardInteractive})
|
||||
})
|
||||
t.Run("internal command no session", func(t *testing.T) {
|
||||
client := fakeDataBrokerServiceClient{
|
||||
get: func(
|
||||
_ context.Context, _ *databroker.GetRequest, _ ...grpc.CallOption,
|
||||
) (*databroker.GetResponse, error) {
|
||||
return nil, status.Error(codes.NotFound, "not found")
|
||||
},
|
||||
}
|
||||
info := StreamAuthInfo{
|
||||
Username: ptr("username"),
|
||||
Hostname: ptr(""),
|
||||
}
|
||||
var req extensions_ssh.PublicKeyMethodRequest
|
||||
req.PublicKeyFingerprintSha256 = []byte("ABCDEFGHIJKLMNOPQRSTUVWXYZ123456")
|
||||
pe := policyEvaluatorFunc(func(_ context.Context, _ *Request) (*evaluator.Result, error) {
|
||||
return &evaluator.Result{
|
||||
Allow: evaluator.NewRuleResult(false),
|
||||
Deny: evaluator.NewRuleResult(false, criteria.ReasonUserUnauthenticated),
|
||||
}, nil
|
||||
})
|
||||
a := NewAuth(pe, client, nil, nil)
|
||||
res, err := a.HandlePublicKeyMethodRequest(t.Context(), info, &req)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, res.Allow)
|
||||
assert.Equal(t, res.RequireAdditionalMethods, []string{MethodKeyboardInteractive})
|
||||
})
|
||||
t.Run("internal command with session", func(t *testing.T) {
|
||||
client := fakeDataBrokerServiceClient{
|
||||
get: func(
|
||||
_ context.Context, _ *databroker.GetRequest, _ ...grpc.CallOption,
|
||||
) (*databroker.GetResponse, error) {
|
||||
return &databroker.GetResponse{
|
||||
Record: &databroker.Record{
|
||||
Type: "type.googleapis.com/session.Session",
|
||||
Id: "abc",
|
||||
Data: protoutil.NewAny(&session.Session{
|
||||
Id: "abc",
|
||||
UserId: "USER-ID",
|
||||
}),
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
info := StreamAuthInfo{
|
||||
Username: ptr("username"),
|
||||
Hostname: ptr(""),
|
||||
}
|
||||
var req extensions_ssh.PublicKeyMethodRequest
|
||||
req.PublicKeyFingerprintSha256 = []byte("ABCDEFGHIJKLMNOPQRSTUVWXYZ123456")
|
||||
pe := policyEvaluatorFunc(func(_ context.Context, _ *Request) (*evaluator.Result, error) {
|
||||
return &evaluator.Result{
|
||||
Allow: evaluator.NewRuleResult(true),
|
||||
Deny: evaluator.NewRuleResult(false),
|
||||
}, nil
|
||||
})
|
||||
a := NewAuth(pe, client, nil, nil)
|
||||
res, err := a.HandlePublicKeyMethodRequest(t.Context(), info, &req)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, res.Allow)
|
||||
assert.Empty(t, res.RequireAdditionalMethods)
|
||||
})
|
||||
t.Run("internal command databroker error", func(t *testing.T) {
|
||||
client := fakeDataBrokerServiceClient{
|
||||
get: func(
|
||||
_ context.Context, _ *databroker.GetRequest, _ ...grpc.CallOption,
|
||||
) (*databroker.GetResponse, error) {
|
||||
return nil, status.Error(codes.Unknown, "unknown")
|
||||
},
|
||||
}
|
||||
info := StreamAuthInfo{
|
||||
Username: ptr("username"),
|
||||
Hostname: ptr(""),
|
||||
}
|
||||
var req extensions_ssh.PublicKeyMethodRequest
|
||||
req.PublicKeyFingerprintSha256 = []byte("ABCDEFGHIJKLMNOPQRSTUVWXYZ123456")
|
||||
pe := policyEvaluatorFunc(func(_ context.Context, _ *Request) (*evaluator.Result, error) {
|
||||
return &evaluator.Result{
|
||||
Allow: evaluator.NewRuleResult(true),
|
||||
Deny: evaluator.NewRuleResult(false),
|
||||
}, nil
|
||||
})
|
||||
a := NewAuth(pe, client, nil, nil)
|
||||
_, err := a.HandlePublicKeyMethodRequest(t.Context(), info, &req)
|
||||
assert.ErrorContains(t, err, "internal error")
|
||||
})
|
||||
}
|
||||
|
||||
func TestHandleKeyboardInteractiveMethodRequest(t *testing.T) {
|
||||
t.Run("no public key", func(t *testing.T) {
|
||||
var a Auth
|
||||
_, err := a.handleKeyboardInteractiveMethodRequest(t.Context(), StreamAuthInfo{}, nil)
|
||||
assert.ErrorContains(t, err, "expected PublicKeyAllow message not to be nil")
|
||||
})
|
||||
t.Run("ok", func(t *testing.T) {
|
||||
pe := policyEvaluatorFunc(func(_ context.Context, _ *Request) (*evaluator.Result, error) {
|
||||
return &evaluator.Result{
|
||||
Allow: evaluator.NewRuleResult(true),
|
||||
Deny: evaluator.NewRuleResult(false),
|
||||
}, nil
|
||||
})
|
||||
var putRecords []*databroker.Record
|
||||
client := fakeDataBrokerServiceClient{
|
||||
get: func(
|
||||
_ context.Context, _ *databroker.GetRequest, _ ...grpc.CallOption,
|
||||
) (*databroker.GetResponse, error) {
|
||||
return nil, status.Error(codes.NotFound, "not found")
|
||||
},
|
||||
put: func(
|
||||
_ context.Context, in *databroker.PutRequest, _ ...grpc.CallOption,
|
||||
) (*databroker.PutResponse, error) {
|
||||
putRecords = append(putRecords, in.Records...)
|
||||
return &databroker.PutResponse{
|
||||
Records: in.Records,
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
cfg := config.Config{
|
||||
Options: config.NewDefaultOptions(),
|
||||
}
|
||||
mockIDP := mockidp.New(mockidp.Config{EnableDeviceAuth: true})
|
||||
idpURL := mockIDP.Start(t)
|
||||
cfg.Options.Provider = "oidc"
|
||||
cfg.Options.ProviderURL = idpURL
|
||||
cfg.Options.ClientID = "client-id"
|
||||
cfg.Options.ClientSecret = "client-secret"
|
||||
a := NewAuth(pe, client, atomicutil.NewValue(&cfg), nil)
|
||||
info := StreamAuthInfo{
|
||||
Username: ptr("username"),
|
||||
Hostname: ptr("hostname"),
|
||||
PublicKeyAllow: AuthMethodValue[extensions_ssh.PublicKeyAllowResponse]{
|
||||
Value: &extensions_ssh.PublicKeyAllowResponse{
|
||||
PublicKey: []byte("fake-public-key"),
|
||||
},
|
||||
},
|
||||
PublicKeyFingerprintSha256: []byte("ABCDEFGHIJKLMNOPQRSTUVWXYZ123456"),
|
||||
}
|
||||
res, err := a.HandleKeyboardInteractiveMethodRequest(t.Context(), info, nil, noopQuerier{})
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, res.Allow)
|
||||
assert.Empty(t, res.RequireAdditionalMethods)
|
||||
|
||||
// A new Session and User record should have been saved to the databroker.
|
||||
assert.Len(t, putRecords, 2)
|
||||
|
||||
assert.Equal(t, "type.googleapis.com/user.User", putRecords[0].Type)
|
||||
assert.Equal(t, "fake.user@example.com", putRecords[0].Id)
|
||||
|
||||
assert.Equal(t, "type.googleapis.com/session.Session", putRecords[1].Type)
|
||||
assert.Equal(t, "sshkey-SHA256:QUJDREVGR0hJSktMTU5PUFFSU1RVVldYWVoxMjM0NTY=", putRecords[1].Id)
|
||||
})
|
||||
t.Run("denied", func(t *testing.T) {
|
||||
pe := policyEvaluatorFunc(func(_ context.Context, _ *Request) (*evaluator.Result, error) {
|
||||
return &evaluator.Result{
|
||||
Allow: evaluator.NewRuleResult(false),
|
||||
Deny: evaluator.NewRuleResult(false),
|
||||
}, nil
|
||||
})
|
||||
client := fakeDataBrokerServiceClient{
|
||||
get: func(
|
||||
_ context.Context, _ *databroker.GetRequest, _ ...grpc.CallOption,
|
||||
) (*databroker.GetResponse, error) {
|
||||
return nil, status.Error(codes.NotFound, "not found")
|
||||
},
|
||||
put: func(
|
||||
_ context.Context, in *databroker.PutRequest, _ ...grpc.CallOption,
|
||||
) (*databroker.PutResponse, error) {
|
||||
return &databroker.PutResponse{
|
||||
Records: in.Records,
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
cfg := config.Config{
|
||||
Options: config.NewDefaultOptions(),
|
||||
}
|
||||
mockIDP := mockidp.New(mockidp.Config{EnableDeviceAuth: true})
|
||||
idpURL := mockIDP.Start(t)
|
||||
cfg.Options.Provider = "oidc"
|
||||
cfg.Options.ProviderURL = idpURL
|
||||
cfg.Options.ClientID = "client-id"
|
||||
cfg.Options.ClientSecret = "client-secret"
|
||||
a := NewAuth(pe, client, atomicutil.NewValue(&cfg), nil)
|
||||
info := StreamAuthInfo{
|
||||
Username: ptr("username"),
|
||||
Hostname: ptr("hostname"),
|
||||
PublicKeyAllow: AuthMethodValue[extensions_ssh.PublicKeyAllowResponse]{
|
||||
Value: &extensions_ssh.PublicKeyAllowResponse{
|
||||
PublicKey: []byte("fake-public-key"),
|
||||
},
|
||||
},
|
||||
PublicKeyFingerprintSha256: []byte("ABCDEFGHIJKLMNOPQRSTUVWXYZ123456"),
|
||||
}
|
||||
res, err := a.HandleKeyboardInteractiveMethodRequest(t.Context(), info, nil, noopQuerier{})
|
||||
require.NoError(t, err)
|
||||
assert.Nil(t, res.Allow)
|
||||
assert.Empty(t, res.RequireAdditionalMethods)
|
||||
})
|
||||
t.Run("invalid fingerprint", func(t *testing.T) {
|
||||
cfg := config.Config{
|
||||
Options: config.NewDefaultOptions(),
|
||||
}
|
||||
mockIDP := mockidp.New(mockidp.Config{EnableDeviceAuth: true})
|
||||
idpURL := mockIDP.Start(t)
|
||||
cfg.Options.Provider = "oidc"
|
||||
cfg.Options.ProviderURL = idpURL
|
||||
cfg.Options.ClientID = "client-id"
|
||||
cfg.Options.ClientSecret = "client-secret"
|
||||
a := NewAuth(nil, nil, atomicutil.NewValue(&cfg), nil)
|
||||
info := StreamAuthInfo{
|
||||
Username: ptr("username"),
|
||||
Hostname: ptr("hostname"),
|
||||
PublicKeyAllow: AuthMethodValue[extensions_ssh.PublicKeyAllowResponse]{
|
||||
Value: &extensions_ssh.PublicKeyAllowResponse{
|
||||
PublicKey: []byte("fake-public-key"),
|
||||
},
|
||||
},
|
||||
}
|
||||
_, err := a.handleKeyboardInteractiveMethodRequest(t.Context(), info, noopQuerier{})
|
||||
assert.ErrorContains(t, err, "invalid public key fingerprint")
|
||||
})
|
||||
}
|
||||
|
||||
func TestFormatSession(t *testing.T) {
|
||||
t.Run("invalid fingerprint", func(t *testing.T) {
|
||||
var a Auth
|
||||
info := StreamAuthInfo{
|
||||
PublicKeyFingerprintSha256: []byte("wrong-length"),
|
||||
}
|
||||
_, err := a.FormatSession(t.Context(), info)
|
||||
assert.ErrorContains(t, err, "invalid public key fingerprint")
|
||||
})
|
||||
t.Run("ok", func(t *testing.T) {
|
||||
client := fakeDataBrokerServiceClient{
|
||||
get: func(
|
||||
_ context.Context, in *databroker.GetRequest, _ ...grpc.CallOption,
|
||||
) (*databroker.GetResponse, error) {
|
||||
const expectedID = "sshkey-SHA256:QUJDREVGR0hJSktMTU5PUFFSU1RVVldYWVoxMjM0NTY="
|
||||
assert.Equal(t, in.Type, "type.googleapis.com/session.Session")
|
||||
assert.Equal(t, in.Id, expectedID)
|
||||
claims := identity.FlattenedClaims{
|
||||
"foo": []any{"bar", "baz"},
|
||||
"quux": []any{42},
|
||||
}
|
||||
return &databroker.GetResponse{
|
||||
Record: &databroker.Record{
|
||||
Type: "type.googleapis.com/session.Session",
|
||||
Id: expectedID,
|
||||
Data: protoutil.NewAny(&session.Session{
|
||||
Id: expectedID,
|
||||
UserId: "USER-ID",
|
||||
ExpiresAt: ×tamppb.Timestamp{Seconds: 1750965358},
|
||||
Claims: claims.ToPB(),
|
||||
}),
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
a := NewAuth(nil, client, nil, nil)
|
||||
info := StreamAuthInfo{
|
||||
PublicKeyFingerprintSha256: []byte("ABCDEFGHIJKLMNOPQRSTUVWXYZ123456"),
|
||||
}
|
||||
b, err := a.FormatSession(t.Context(), info)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, string(b), `
|
||||
User ID: USER-ID
|
||||
Session ID: sshkey-SHA256:QUJDREVGR0hJSktMTU5PUFFSU1RVVldYWVoxMjM0NTY=
|
||||
Expires at: 2025-06-26 19:15:58 +0000 UTC
|
||||
Claims:
|
||||
foo: [bar baz]
|
||||
quux: [42]
|
||||
`)
|
||||
})
|
||||
}
|
||||
|
||||
func TestDeleteSession(t *testing.T) {
|
||||
t.Run("invalid fingerprint", func(t *testing.T) {
|
||||
var a Auth
|
||||
info := StreamAuthInfo{
|
||||
PublicKeyFingerprintSha256: []byte("wrong-length"),
|
||||
}
|
||||
err := a.DeleteSession(t.Context(), info)
|
||||
assert.ErrorContains(t, err, "invalid public key fingerprint")
|
||||
})
|
||||
t.Run("ok", func(t *testing.T) {
|
||||
putError := errors.New("sentinel")
|
||||
client := fakeDataBrokerServiceClient{
|
||||
put: func(
|
||||
_ context.Context, in *databroker.PutRequest, _ ...grpc.CallOption,
|
||||
) (*databroker.PutResponse, error) {
|
||||
require.Len(t, in.Records, 1)
|
||||
assert.Equal(t, in.Records[0].Id, "sshkey-SHA256:QUJDREVGR0hJSktMTU5PUFFSU1RVVldYWVoxMjM0NTY=")
|
||||
assert.NotNil(t, in.Records[0].DeletedAt)
|
||||
return nil, putError
|
||||
},
|
||||
}
|
||||
a := NewAuth(nil, client, nil, nil)
|
||||
info := StreamAuthInfo{
|
||||
PublicKeyFingerprintSha256: []byte("ABCDEFGHIJKLMNOPQRSTUVWXYZ123456"),
|
||||
}
|
||||
err := a.DeleteSession(t.Context(), info)
|
||||
assert.Equal(t, putError, err)
|
||||
})
|
||||
}
|
||||
|
||||
type policyEvaluatorFunc func(context.Context, *Request) (*evaluator.Result, error)
|
||||
|
||||
func (f policyEvaluatorFunc) EvaluateSSH(
|
||||
ctx context.Context, req *Request,
|
||||
) (*evaluator.Result, error) {
|
||||
return f(ctx, req)
|
||||
}
|
||||
|
||||
type fakeDataBrokerServiceClient struct {
|
||||
databroker.DataBrokerServiceClient
|
||||
|
||||
get func(ctx context.Context, in *databroker.GetRequest, opts ...grpc.CallOption) (*databroker.GetResponse, error)
|
||||
put func(ctx context.Context, in *databroker.PutRequest, opts ...grpc.CallOption) (*databroker.PutResponse, error)
|
||||
}
|
||||
|
||||
func (m fakeDataBrokerServiceClient) Get(ctx context.Context, in *databroker.GetRequest, opts ...grpc.CallOption) (*databroker.GetResponse, error) {
|
||||
return m.get(ctx, in, opts...)
|
||||
}
|
||||
|
||||
func (m fakeDataBrokerServiceClient) Put(ctx context.Context, in *databroker.PutRequest, opts ...grpc.CallOption) (*databroker.PutResponse, error) {
|
||||
return m.put(ctx, in, opts...)
|
||||
}
|
||||
|
||||
type noopQuerier struct{}
|
||||
|
||||
func (noopQuerier) Prompt(
|
||||
_ context.Context, _ *extensions_ssh.KeyboardInteractiveInfoPrompts,
|
||||
) (*extensions_ssh.KeyboardInteractiveInfoPromptResponses, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func ptr[T any](t T) *T {
|
||||
return &t
|
||||
}
|
|
@ -8,14 +8,12 @@ import (
|
|||
)
|
||||
|
||||
type StreamManager struct {
|
||||
auth AuthInterface
|
||||
mu sync.Mutex
|
||||
activeStreams map[uint64]*StreamHandler
|
||||
}
|
||||
|
||||
func NewStreamManager(auth AuthInterface) *StreamManager {
|
||||
func NewStreamManager() *StreamManager {
|
||||
return &StreamManager{
|
||||
auth: auth,
|
||||
activeStreams: map[uint64]*StreamHandler{},
|
||||
}
|
||||
}
|
||||
|
@ -30,13 +28,17 @@ func (sm *StreamManager) LookupStream(streamID uint64) *StreamHandler {
|
|||
return stream
|
||||
}
|
||||
|
||||
func (sm *StreamManager) NewStreamHandler(cfg *config.Config, downstream *extensions_ssh.DownstreamConnectEvent) *StreamHandler {
|
||||
func (sm *StreamManager) NewStreamHandler(
|
||||
cfg *config.Config,
|
||||
auth AuthInterface,
|
||||
downstream *extensions_ssh.DownstreamConnectEvent,
|
||||
) *StreamHandler {
|
||||
sm.mu.Lock()
|
||||
defer sm.mu.Unlock()
|
||||
streamID := downstream.StreamId
|
||||
writeC := make(chan *extensions_ssh.ServerMessage, 32)
|
||||
sh := &StreamHandler{
|
||||
auth: sm.auth,
|
||||
auth: auth,
|
||||
config: cfg,
|
||||
downstream: downstream,
|
||||
readC: make(chan *extensions_ssh.ClientMessage, 32),
|
||||
|
|
|
@ -22,7 +22,7 @@ func mustParseWeightedURLs(t *testing.T, urls ...string) []config.WeightedURL {
|
|||
func TestStreamManager(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
auth := mock_ssh.NewMockAuthInterface(ctrl)
|
||||
m := ssh.NewStreamManager(auth)
|
||||
m := ssh.NewStreamManager()
|
||||
|
||||
cfg := &config.Config{Options: config.NewDefaultOptions()}
|
||||
cfg.Options.Policies = []config.Policy{
|
||||
|
@ -32,7 +32,7 @@ func TestStreamManager(t *testing.T) {
|
|||
|
||||
t.Run("LookupStream", func(t *testing.T) {
|
||||
assert.Nil(t, m.LookupStream(1234))
|
||||
sh := m.NewStreamHandler(cfg, &extensions_ssh.DownstreamConnectEvent{StreamId: 1234})
|
||||
sh := m.NewStreamHandler(cfg, auth, &extensions_ssh.DownstreamConnectEvent{StreamId: 1234})
|
||||
assert.Equal(t, sh, m.LookupStream(1234))
|
||||
sh.Close()
|
||||
assert.Nil(t, m.LookupStream(1234))
|
||||
|
|
|
@ -3,7 +3,7 @@
|
|||
//
|
||||
// Generated by this command:
|
||||
//
|
||||
// mockgen -typed . AuthInterface
|
||||
// mockgen -typed -destination ./mock/mock_auth_interface.go . AuthInterface
|
||||
//
|
||||
|
||||
// Package mock_ssh is a generated GoMock package.
|
||||
|
|
|
@ -37,6 +37,8 @@ type (
|
|||
KeyboardInteractiveAuthMethodResponse = AuthMethodResponse[extensions_ssh.KeyboardInteractiveAllowResponse]
|
||||
)
|
||||
|
||||
//go:generate go run go.uber.org/mock/mockgen -typed -destination ./mock/mock_auth_interface.go . AuthInterface
|
||||
|
||||
type AuthInterface interface {
|
||||
HandlePublicKeyMethodRequest(ctx context.Context, info StreamAuthInfo, req *extensions_ssh.PublicKeyMethodRequest) (PublicKeyAuthMethodResponse, error)
|
||||
HandleKeyboardInteractiveMethodRequest(ctx context.Context, info StreamAuthInfo, req *extensions_ssh.KeyboardInteractiveMethodRequest, querier KeyboardInteractiveQuerier) (KeyboardInteractiveAuthMethodResponse, error)
|
||||
|
@ -284,8 +286,10 @@ func (sh *StreamHandler) handleAuthRequest(ctx context.Context, req *extensions_
|
|||
response, err := sh.auth.HandlePublicKeyMethodRequest(ctx, sh.state.StreamAuthInfo, pubkeyReq)
|
||||
if err != nil {
|
||||
return err
|
||||
} else if response.Allow != nil {
|
||||
partial = true
|
||||
sh.state.PublicKeyFingerprintSha256 = pubkeyReq.PublicKeyFingerprintSha256
|
||||
}
|
||||
partial = response.Allow != nil
|
||||
sh.state.PublicKeyAllow.Update(response.Allow)
|
||||
updateMethods(response.RequireAdditionalMethods)
|
||||
case MethodKeyboardInteractive:
|
||||
|
|
|
@ -100,7 +100,7 @@ type StreamHandlerSuite struct {
|
|||
func (s *StreamHandlerSuite) SetupTest() {
|
||||
s.ctrl = NewController(s.T())
|
||||
s.mockAuth = mock_ssh.NewMockAuthInterface(s.ctrl)
|
||||
s.mgr = ssh.NewStreamManager(s.mockAuth)
|
||||
s.mgr = ssh.NewStreamManager()
|
||||
s.cleanup = []func(){}
|
||||
s.errC = make(chan error, 1)
|
||||
|
||||
|
@ -162,7 +162,8 @@ func (s *StreamHandlerSuite) expectError(fn func(), msg string) {
|
|||
}
|
||||
|
||||
func (s *StreamHandlerSuite) startStreamHandler(streamID uint64) *ssh.StreamHandler {
|
||||
sh := s.mgr.NewStreamHandler(s.cfg, &extensions_ssh.DownstreamConnectEvent{StreamId: streamID})
|
||||
sh := s.mgr.NewStreamHandler(
|
||||
s.cfg, s.mockAuth, &extensions_ssh.DownstreamConnectEvent{StreamId: streamID})
|
||||
s.errC = make(chan error, 1)
|
||||
ctx, ca := context.WithCancel(s.T().Context())
|
||||
go func() {
|
||||
|
@ -1996,7 +1997,8 @@ func (s *StreamHandlerSuite) TestFormatSession() {
|
|||
s.mockAuth.EXPECT().
|
||||
FormatSession(Any(), Any()).
|
||||
Return([]byte("example"), nil)
|
||||
sh := s.mgr.NewStreamHandler(s.cfg, &extensions_ssh.DownstreamConnectEvent{StreamId: 1})
|
||||
sh := s.mgr.NewStreamHandler(
|
||||
s.cfg, s.mockAuth, &extensions_ssh.DownstreamConnectEvent{StreamId: 1})
|
||||
ctx, ca := context.WithCancel(context.Background())
|
||||
ca()
|
||||
// this will exit immediately, but it will have a state, which is only
|
||||
|
@ -2012,7 +2014,8 @@ func (s *StreamHandlerSuite) TestDeleteSession() {
|
|||
s.mockAuth.EXPECT().
|
||||
DeleteSession(Any(), Any()).
|
||||
Return(nil)
|
||||
sh := s.mgr.NewStreamHandler(s.cfg, &extensions_ssh.DownstreamConnectEvent{StreamId: 1})
|
||||
sh := s.mgr.NewStreamHandler(
|
||||
s.cfg, s.mockAuth, &extensions_ssh.DownstreamConnectEvent{StreamId: 1})
|
||||
ctx, ca := context.WithCancel(context.Background())
|
||||
ca()
|
||||
// this will exit immediately, but it will have a state, which is only
|
||||
|
@ -2024,7 +2027,8 @@ func (s *StreamHandlerSuite) TestDeleteSession() {
|
|||
}
|
||||
|
||||
func (s *StreamHandlerSuite) TestRunCalledTwice() {
|
||||
sh := s.mgr.NewStreamHandler(s.cfg, &extensions_ssh.DownstreamConnectEvent{StreamId: 1})
|
||||
sh := s.mgr.NewStreamHandler(
|
||||
s.cfg, s.mockAuth, &extensions_ssh.DownstreamConnectEvent{StreamId: 1})
|
||||
ctx, ca := context.WithCancel(context.Background())
|
||||
ca()
|
||||
sh.Run(ctx)
|
||||
|
@ -2034,7 +2038,8 @@ func (s *StreamHandlerSuite) TestRunCalledTwice() {
|
|||
}
|
||||
|
||||
func (s *StreamHandlerSuite) TestAllSSHRoutes() {
|
||||
sh := s.mgr.NewStreamHandler(s.cfg, &extensions_ssh.DownstreamConnectEvent{StreamId: 1})
|
||||
sh := s.mgr.NewStreamHandler(
|
||||
s.cfg, s.mockAuth, &extensions_ssh.DownstreamConnectEvent{StreamId: 1})
|
||||
routes := slices.Collect(sh.AllSSHRoutes())
|
||||
s.Len(routes, 2)
|
||||
s.Equal("ssh://host1", routes[0].From)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue