Merge remote-tracking branch 'origin/kenjenkins/ssh-proxy-auth-integration' into experimental/ssh

This commit is contained in:
Joe Kralicky 2025-03-19 20:56:59 +00:00
commit 08252f32df
No known key found for this signature in database
GPG key ID: 75C4875F34A9FB79
15 changed files with 1061 additions and 821 deletions

View file

@ -12,6 +12,8 @@ import (
"github.com/rs/zerolog"
oteltrace "go.opentelemetry.io/otel/trace"
"golang.org/x/sync/errgroup"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"github.com/pomerium/datasource/pkg/directory"
"github.com/pomerium/pomerium/authorize/evaluator"
@ -19,11 +21,16 @@ import (
"github.com/pomerium/pomerium/config"
"github.com/pomerium/pomerium/internal/atomicutil"
"github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/sessions"
"github.com/pomerium/pomerium/internal/telemetry/metrics"
"github.com/pomerium/pomerium/internal/telemetry/trace"
"github.com/pomerium/pomerium/pkg/contextutil"
"github.com/pomerium/pomerium/pkg/cryptutil"
"github.com/pomerium/pomerium/pkg/grpc/databroker"
"github.com/pomerium/pomerium/pkg/grpc/user"
"github.com/pomerium/pomerium/pkg/policy/criteria"
"github.com/pomerium/pomerium/pkg/storage"
"github.com/pomerium/pomerium/pkg/telemetry/requestid"
)
// Authorize struct holds
@ -181,3 +188,66 @@ func (a *Authorize) OnConfigChange(ctx context.Context, cfg *config.Config) {
}
}
}
type evaluateResult struct {
// Overall allow/deny result.
Allowed bool
// Reasons for the overall result.
Reasons criteria.Reasons
// Reason detail traces. (Populated only if enabled by the policy.)
Traces []contextutil.PolicyEvaluationTrace
}
func (a *Authorize) evaluate(
ctx context.Context,
req *evaluator.Request,
sessionState *sessions.State,
) (*evaluator.Result, error) {
querier := storage.NewCachingQuerier(
storage.NewQuerier(a.state.Load().dataBrokerClient),
a.globalCache,
)
ctx = storage.WithQuerier(ctx, querier)
requestID := requestid.FromContext(ctx)
state := a.state.Load()
var s sessionOrServiceAccount
var u *user.User
var err error
if sessionState != nil {
s, err = a.getDataBrokerSessionOrServiceAccount(ctx, sessionState.ID, sessionState.DatabrokerRecordVersion)
if status.Code(err) == codes.Unavailable {
log.Ctx(ctx).Debug().Str("request-id", requestID).Err(err).Msg("temporary error checking authorization: data broker unavailable")
return nil, err
} else if err != nil {
log.Ctx(ctx).Info().Err(err).Str("request-id", requestID).Msg("missing or invalid session or service account")
sessionState = nil
}
}
if sessionState != nil && s != nil {
u, _ = a.getDataBrokerUser(ctx, s.GetUserId()) // ignore any missing user error
}
res, err := state.evaluator.Evaluate(ctx, req)
if err != nil {
log.Ctx(ctx).Error().Err(err).Str("request-id", requestID).Msg("error during OPA evaluation")
return nil, err
}
a.logAuthorizeCheck(ctx, req, res, s, u)
/*result := &evaluateResult{
Allowed: res.Allow.Value && !res.Deny.Value,
}
// if show error details is enabled, attach the policy evaluation traces
if req.Policy != nil && req.Policy.ShowErrorDetails {
result.Traces = res.Traces
}*/
return res, nil
}

View file

@ -38,6 +38,7 @@ type RequestHTTP struct {
Method string `json:"method"`
Hostname string `json:"hostname"`
Path string `json:"path"`
Query string `json:"-"`
URL string `json:"url"`
Headers map[string]string `json:"headers"`
ClientCertificate ClientCertificateInfo `json:"client_certificate"`

View file

@ -31,6 +31,10 @@ type PolicyResponse struct {
Traces []contextutil.PolicyEvaluationTrace
}
func (r *PolicyResponse) Allowed() bool {
return r.Allow.Value && !r.Deny.Value
}
// NewPolicyResponse creates a new PolicyResponse.
func NewPolicyResponse() *PolicyResponse {
return &PolicyResponse{

View file

@ -18,9 +18,7 @@ import (
"github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/sessions"
"github.com/pomerium/pomerium/internal/urlutil"
"github.com/pomerium/pomerium/pkg/contextutil"
"github.com/pomerium/pomerium/pkg/grpc/databroker"
"github.com/pomerium/pomerium/pkg/grpc/user"
"github.com/pomerium/pomerium/pkg/storage"
"github.com/pomerium/pomerium/pkg/telemetry/requestid"
"google.golang.org/grpc/codes"
@ -33,12 +31,6 @@ func (a *Authorize) Check(ctx context.Context, in *envoy_service_auth_v3.CheckRe
ctx, span := a.tracer.Start(ctx, "authorize.grpc.Check")
defer span.End()
querier := storage.NewCachingQuerier(
storage.NewQuerier(a.state.Load().dataBrokerClient),
a.globalCache,
)
ctx = storage.WithQuerier(ctx, querier)
state := a.state.Load()
// convert the incoming envoy-style http request into a go-style http request
@ -46,41 +38,24 @@ func (a *Authorize) Check(ctx context.Context, in *envoy_service_auth_v3.CheckRe
requestID := requestid.FromHTTPHeader(hreq.Header)
ctx = requestid.WithValue(ctx, requestID)
req, err := a.getEvaluatorRequestFromCheckRequest(ctx, in)
sessionState, _ := state.sessionStore.LoadSessionStateAndCheckIDP(hreq)
req, err := a.getEvaluatorRequestFromCheckRequest(ctx, in, sessionState)
if err != nil {
log.Ctx(ctx).Error().Err(err).Str("request-id", requestID).Msg("error building evaluator request")
return nil, err
}
// load the session
s, err := a.loadSession(ctx, hreq, req)
if err != nil {
return nil, err
}
// if there's a session or service account, load the user
var u *user.User
if s != nil {
req.Session.ID = s.GetId()
u, _ = a.getDataBrokerUser(ctx, s.GetUserId()) // ignore any missing user error
}
res, err := state.evaluator.Evaluate(ctx, req)
res, err := a.evaluate(ctx, req, sessionState)
if err != nil {
log.Ctx(ctx).Error().Err(err).Str("request-id", requestID).Msg("error during OPA evaluation")
return nil, err
}
// if show error details is enabled, attach the policy evaluation traces
if req.Policy != nil && req.Policy.ShowErrorDetails {
ctx = contextutil.WithPolicyEvaluationTraces(ctx, res.Traces)
}
resp, err := a.handleResult(ctx, in, req, res)
if err != nil {
log.Ctx(ctx).Error().Err(err).Str("request-id", requestID).Msg("grpc check ext_authz_error")
}
a.logAuthorizeCheck(ctx, in, res, s, u)
return resp, err
}
@ -144,6 +119,7 @@ func (a *Authorize) loadSession(
func (a *Authorize) getEvaluatorRequestFromCheckRequest(
ctx context.Context,
in *envoy_service_auth_v3.CheckRequest,
sessionState *sessions.State,
) (*evaluator.Request, error) {
requestURL := getCheckRequestURL(in)
attrs := in.GetAttributes()
@ -158,6 +134,11 @@ func (a *Authorize) getEvaluatorRequestFromCheckRequest(
attrs.GetSource().GetAddress().GetSocketAddress().GetAddress(),
),
}
if sessionState != nil {
req.Session = evaluator.RequestSession{
ID: sessionState.ID,
}
}
req.Policy = a.getMatchingPolicy(envoyconfig.ExtAuthzContextExtensionsRouteID(attrs.GetContextExtensions()))
return req, nil
}

View file

@ -88,7 +88,7 @@ func Test_getEvaluatorRequest(t *testing.T) {
},
},
},
)
nil)
require.NoError(t, err)
expect := &evaluator.Request{
Policy: &a.currentConfig.Load().Options.Policies[0],
@ -140,7 +140,7 @@ func Test_getEvaluatorRequestWithPortInHostHeader(t *testing.T) {
},
},
},
})
}, nil)
require.NoError(t, err)
expect := &evaluator.Request{
Policy: &a.currentConfig.Load().Options.Policies[0],

View file

@ -200,7 +200,11 @@ func (s *Store) GetDataBrokerRecord(ctx context.Context, recordType, recordIDOrI
res, err := storage.GetQuerier(ctx).Query(ctx, req, grpc.WaitForReady(true))
if err != nil {
log.Ctx(ctx).Error().Err(err).Msg("authorize/store: error retrieving record")
log.Ctx(ctx).Error().
Str("record-type", recordType).
Str("record-id-or-index", recordIDOrIndex).
Err(err).
Msg("authorize/store: error retrieving record")
return nil
}

View file

@ -2,9 +2,7 @@ package authorize
import (
"context"
"strings"
envoy_service_auth_v3 "github.com/envoyproxy/go-control-plane/envoy/service/auth/v3"
"github.com/go-jose/go-jose/v3/jwt"
"github.com/rs/zerolog"
"go.opentelemetry.io/otel/attribute"
@ -21,19 +19,19 @@ import (
func (a *Authorize) logAuthorizeCheck(
ctx context.Context,
in *envoy_service_auth_v3.CheckRequest,
in *evaluator.Request,
res *evaluator.Result, s sessionOrServiceAccount, u *user.User,
) {
ctx, span := a.tracer.Start(ctx, "authorize.grpc.LogAuthorizeCheck")
defer span.End()
hdrs := getCheckRequestHeaders(in)
hdrs := in.HTTP.Headers
impersonateDetails := a.getImpersonateDetails(ctx, s)
evt := log.Ctx(ctx).Info().Str("service", "authorize")
fields := a.currentConfig.Load().Options.GetAuthorizeLogFields()
for _, field := range fields {
evt = populateLogEvent(ctx, field, evt, in, s, u, hdrs, impersonateDetails, res)
evt = populateLogEvent(ctx, field, evt, in, s, u, impersonateDetails, res)
}
evt = log.HTTPHeaders(evt, fields, hdrs)
@ -134,22 +132,19 @@ func populateLogEvent(
ctx context.Context,
field log.AuthorizeLogField,
evt *zerolog.Event,
in *envoy_service_auth_v3.CheckRequest,
in *evaluator.Request,
s sessionOrServiceAccount,
u *user.User,
hdrs map[string]string,
impersonateDetails *impersonateDetails,
res *evaluator.Result,
) *zerolog.Event {
path, query, _ := strings.Cut(in.GetAttributes().GetRequest().GetHttp().GetPath(), "?")
switch field {
case log.AuthorizeLogFieldCheckRequestID:
return evt.Str(string(field), hdrs["X-Request-Id"])
return evt.Str(string(field), in.HTTP.Headers["X-Request-Id"])
case log.AuthorizeLogFieldEmail:
return evt.Str(string(field), u.GetEmail())
case log.AuthorizeLogFieldHost:
return evt.Str(string(field), in.GetAttributes().GetRequest().GetHttp().GetHost())
return evt.Str(string(field), in.HTTP.Hostname)
case log.AuthorizeLogFieldIDToken:
if s, ok := s.(*session.Session); ok {
evt = evt.Str(string(field), s.GetIdToken().GetRaw())
@ -180,13 +175,13 @@ func populateLogEvent(
}
return evt
case log.AuthorizeLogFieldIP:
return evt.Str(string(field), in.GetAttributes().GetSource().GetAddress().GetSocketAddress().GetAddress())
return evt.Str(string(field), in.HTTP.IP)
case log.AuthorizeLogFieldMethod:
return evt.Str(string(field), in.GetAttributes().GetRequest().GetHttp().GetMethod())
return evt.Str(string(field), in.HTTP.Method)
case log.AuthorizeLogFieldPath:
return evt.Str(string(field), path)
return evt.Str(string(field), in.HTTP.Path)
case log.AuthorizeLogFieldQuery:
return evt.Str(string(field), query)
return evt.Str(string(field), in.HTTP.Query)
case log.AuthorizeLogFieldRequestID:
return evt.Str(string(field), requestid.FromContext(ctx))
case log.AuthorizeLogFieldServiceAccountID:

View file

@ -6,8 +6,6 @@ import (
"strings"
"testing"
envoy_config_core_v3 "github.com/envoyproxy/go-control-plane/envoy/config/core/v3"
envoy_service_auth_v3 "github.com/envoyproxy/go-control-plane/envoy/service/auth/v3"
"github.com/rs/zerolog"
"github.com/stretchr/testify/assert"
@ -24,27 +22,18 @@ func Test_populateLogEvent(t *testing.T) {
ctx := context.Background()
ctx = requestid.WithValue(ctx, "REQUEST-ID")
checkRequest := &envoy_service_auth_v3.CheckRequest{
Attributes: &envoy_service_auth_v3.AttributeContext{
Request: &envoy_service_auth_v3.AttributeContext_Request{
Http: &envoy_service_auth_v3.AttributeContext_HttpRequest{
Host: "HOST",
Path: "https://www.example.com/some/path?a=b",
Method: "GET",
},
},
Source: &envoy_service_auth_v3.AttributeContext_Peer{
Address: &envoy_config_core_v3.Address{
Address: &envoy_config_core_v3.Address_SocketAddress{
SocketAddress: &envoy_config_core_v3.SocketAddress{
Address: "127.0.0.1",
},
},
},
request := &evaluator.Request{
HTTP: evaluator.RequestHTTP{
Method: "GET",
Hostname: "HOST",
Path: "/some/path",
Query: "a=b",
Headers: map[string]string{
"X-Request-Id": "CHECK-REQUEST-ID",
},
IP: "127.0.0.1",
},
}
headers := map[string]string{"X-Request-Id": "CHECK-REQUEST-ID"}
s := &session.Session{
Id: "SESSION-ID",
IdToken: &session.IDToken{
@ -86,7 +75,7 @@ func Test_populateLogEvent(t *testing.T) {
{log.AuthorizeLogFieldImpersonateUserID, s, `{"impersonate-user-id":"IMPERSONATE-USER-ID"}`},
{log.AuthorizeLogFieldIP, s, `{"ip":"127.0.0.1"}`},
{log.AuthorizeLogFieldMethod, s, `{"method":"GET"}`},
{log.AuthorizeLogFieldPath, s, `{"path":"https://www.example.com/some/path"}`},
{log.AuthorizeLogFieldPath, s, `{"path":"/some/path"}`},
{log.AuthorizeLogFieldQuery, s, `{"query":"a=b"}`},
{log.AuthorizeLogFieldRemovedGroupsCount, s, `{"removed-groups-count":42}`},
{log.AuthorizeLogFieldRequestID, s, `{"request-id":"REQUEST-ID"}`},
@ -104,7 +93,7 @@ func Test_populateLogEvent(t *testing.T) {
var buf bytes.Buffer
log := zerolog.New(&buf)
evt := log.Log()
evt = populateLogEvent(ctx, tc.field, evt, checkRequest, tc.s, u, headers, impersonateDetails, res)
evt = populateLogEvent(ctx, tc.field, evt, request, tc.s, u, impersonateDetails, res)
evt.Send()
assert.Equal(t, tc.expect, strings.TrimSpace(buf.String()))

View file

@ -3,6 +3,7 @@ package authorize
import (
"bufio"
"bytes"
"context"
"crypto/sha256"
"encoding/binary"
"errors"
@ -23,14 +24,21 @@ import (
"github.com/klauspost/compress/zstd"
extensions_ssh "github.com/pomerium/envoy-custom/api/extensions/filters/network/ssh"
extensions_session_recording "github.com/pomerium/envoy-custom/api/extensions/filters/network/ssh/filters/session_recording"
"github.com/pomerium/pomerium/authorize/evaluator"
"github.com/pomerium/pomerium/config"
"github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/sessions"
"github.com/pomerium/pomerium/pkg/grpc/session"
"github.com/pomerium/pomerium/pkg/identity"
"github.com/pomerium/pomerium/pkg/identity/manager"
"github.com/pomerium/pomerium/pkg/identity/oauth"
"github.com/pomerium/pomerium/pkg/storage"
gossh "golang.org/x/crypto/ssh"
"golang.org/x/oauth2"
"golang.org/x/sync/errgroup"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/encoding/protodelim"
"google.golang.org/protobuf/encoding/protojson"
"google.golang.org/protobuf/types/known/anypb"
@ -136,6 +144,13 @@ func (a *Authorize) ManageStream(
}
})
// XXX
querier := storage.NewCachingQuerier(
storage.NewQuerier(a.state.Load().dataBrokerClient),
a.globalCache,
)
ctx = storage.WithQuerier(ctx, querier)
eg.Go(func() error {
for {
select {
@ -154,8 +169,8 @@ func (a *Authorize) ManageStream(
var state StreamState
deviceAuthSuccess := &atomic.Bool{}
deviceAuthDone := make(chan struct{})
sessionState := &atomic.Pointer[sessions.State]{}
errC := make(chan error, 1)
a.activeStreamsMu.Lock()
@ -201,62 +216,18 @@ func (a *Authorize) ManageStream(
//
// validate public key here
//
session, err := a.GetPomeriumSession(ctx, pubkeyReq.PublicKey)
if err != nil {
return err // XXX: wrap this error?
}
state.MethodsAuthenticated = append(state.MethodsAuthenticated, "publickey")
state.PublicKey = pubkeyReq.PublicKey
if authReq.Username == "" && authReq.Hostname == "" {
resp := extensions_ssh.ServerMessage{
Message: &extensions_ssh.ServerMessage_AuthResponse{
AuthResponse: &extensions_ssh.AuthenticationResponse{
Response: &extensions_ssh.AuthenticationResponse_Allow{
Allow: &extensions_ssh.AllowResponse{
Username: state.Username,
Target: &extensions_ssh.AllowResponse_Internal{
Internal: &extensions_ssh.InternalTarget{},
},
},
},
},
},
}
sendC <- &resp
continue
} else if authReq.Username == "_mirror" && authReq.Hostname == "" {
resp := extensions_ssh.ServerMessage{
Message: &extensions_ssh.ServerMessage_AuthResponse{
AuthResponse: &extensions_ssh.AuthenticationResponse{
Response: &extensions_ssh.AuthenticationResponse_Allow{
Allow: &extensions_ssh.AllowResponse{
Username: state.Username,
Target: &extensions_ssh.AllowResponse_Internal{
Internal: &extensions_ssh.InternalTarget{},
},
},
},
},
},
}
// id, _ := strconv.ParseUint(authReq.Hostname, 10, 64)
// resp := extensions_ssh.ServerMessage{
// Message: &extensions_ssh.ServerMessage_AuthResponse{
// AuthResponse: &extensions_ssh.AuthenticationResponse{
// Response: &extensions_ssh.AuthenticationResponse_Allow{
// Allow: &extensions_ssh.AllowResponse{
// Target: &extensions_ssh.AllowResponse_MirrorSession{
// MirrorSession: &extensions_ssh.MirrorSessionTarget{
// SourceId: id,
// Mode: extensions_ssh.MirrorSessionTarget_ReadWrite,
// },
// },
// },
// },
// },
// },
// }
sendC <- &resp
continue
} else if authReq.Username != "" && authReq.Hostname == "" {
if authReq.Username == "" {
return fmt.Errorf("no username given")
}
if authReq.Hostname == "" {
resp := extensions_ssh.ServerMessage{
Message: &extensions_ssh.ServerMessage_AuthResponse{
AuthResponse: &extensions_ssh.AuthenticationResponse{
@ -275,7 +246,24 @@ func (a *Authorize) ManageStream(
continue
}
if !slices.Contains(state.MethodsAuthenticated, "keyboard-interactive") {
if session != nil {
// Perform authorize check for this route
req, err := a.getEvaluatorRequestFromSSHAuthRequest(&state)
if err != nil {
return err
}
res, err := a.evaluate(ctx, req, &sessions.State{ID: session.Id})
if err != nil {
return err
}
sendC <- handleEvaluatorResponseForSSH(res, &state)
if res.Allow.Value && !res.Deny.Value {
a.startContinuousAuthorization(ctx, errC, req, session.Id)
}
}
if session == nil && !slices.Contains(state.MethodsAuthenticated, "keyboard-interactive") {
resp := extensions_ssh.ServerMessage{
Message: &extensions_ssh.ServerMessage_AuthResponse{
AuthResponse: &extensions_ssh.AuthenticationResponse{
@ -291,19 +279,12 @@ func (a *Authorize) ManageStream(
sendC <- &resp
}
case "keyboard-interactive":
opts := a.currentOptions.Load()
var route *config.Policy
for r := range opts.GetAllPolicies() {
if r.From == "ssh://"+strings.TrimSuffix(strings.Join([]string{state.Hostname, opts.SSHHostname}, "."), ".") {
route = r
break
}
}
route := a.getSSHRouteForHostname(state.Hostname)
if route == nil {
return fmt.Errorf("invalid route")
}
// sessionState := a.state.Load()
opts := a.currentConfig.Load().Options
idp, err := opts.GetIdentityProviderForPolicy(route)
if err != nil {
return err
@ -357,13 +338,20 @@ func (a *Authorize) ManageStream(
return
}
s := sessions.NewState(idp.Id)
err = claims.Claims.Claims(&s)
claims.Claims.Claims(&s) // XXX
s.ID, err = getSessionIDForSSH(state.PublicKey)
if err != nil {
errC <- fmt.Errorf("error unmarshaling session state: %w", err)
errC <- err
return
}
fmt.Println(token)
deviceAuthSuccess.Store(true)
err = a.PersistSession(ctx, s, claims, token)
if err != nil {
fmt.Println("error from PersistSession:", err)
errC <- fmt.Errorf("error persisting session: %w", err)
return
}
sessionState.Store(s)
close(deviceAuthDone)
}()
}
@ -380,7 +368,7 @@ func (a *Authorize) ManageStream(
case <-deviceAuthDone:
case <-ctx.Done():
}
if deviceAuthSuccess.Load() {
if sessionState.Load() != nil {
state.MethodsAuthenticated = append(state.MethodsAuthenticated, "keyboard-interactive")
} else {
resp := extensions_ssh.ServerMessage{
@ -393,78 +381,24 @@ func (a *Authorize) ManageStream(
},
}
sendC <- &resp
// retryReq := extensions_ssh.KeyboardInteractiveInfoPrompts{
// Name: "",
// Instruction: "Login not successful yet, try again",
// Prompts: []*extensions_ssh.KeyboardInteractiveInfoPrompts_Prompt{
// // {},
// },
// }
// infoReqAny, _ := anypb.New(&retryReq)
// resp := extensions_ssh.ServerMessage{
// Message: &extensions_ssh.ServerMessage_AuthResponse{
// AuthResponse: &extensions_ssh.AuthenticationResponse{
// Response: &extensions_ssh.AuthenticationResponse_InfoRequest{
// InfoRequest: &extensions_ssh.InfoRequest{
// Method: "keyboard-interactive",
// Request: infoReqAny,
// },
// },
// },
// },
// }
// sendC <- &resp
continue
}
if slices.Contains(state.MethodsAuthenticated, "publickey") {
pkData, _ := anypb.New(&extensions_ssh.PublicKeyAllowResponse{
PublicKey: state.PublicKey,
Permissions: &extensions_ssh.Permissions{
PermitPortForwarding: true,
PermitAgentForwarding: true,
PermitX11Forwarding: true,
PermitPty: true,
PermitUserRc: true,
ValidBefore: timestamppb.New(time.Now().Add(-1 * time.Minute)),
ValidAfter: timestamppb.New(time.Now().Add(12 * time.Hour)),
},
})
sessionRecordingExt, _ := anypb.New(&extensions_session_recording.UpstreamTargetExtensionConfig{
RecordingName: fmt.Sprintf("session-%s-at-%s-%d.cast", state.Username, state.Hostname, time.Now().UnixNano()),
Format: extensions_session_recording.Format_AsciicastFormat,
})
authResponse := extensions_ssh.ServerMessage{
Message: &extensions_ssh.ServerMessage_AuthResponse{
AuthResponse: &extensions_ssh.AuthenticationResponse{
Response: &extensions_ssh.AuthenticationResponse_Allow{
Allow: &extensions_ssh.AllowResponse{
Username: state.Username,
Target: &extensions_ssh.AllowResponse_Upstream{
Upstream: &extensions_ssh.UpstreamTarget{
Hostname: state.Hostname,
AllowedMethods: []*extensions_ssh.AllowedMethod{
{
Method: "publickey",
MethodData: pkData,
},
{
Method: "keyboard-interactive",
},
},
Extensions: []*corev3.TypedExtensionConfig{
{
TypedConfig: sessionRecordingExt,
},
},
},
},
},
},
},
},
// Perform authorize check for this route
req, err := a.getEvaluatorRequestFromSSHAuthRequest(&state)
if err != nil {
return err
}
res, err := a.evaluate(ctx, req, sessionState.Load())
if err != nil {
return err
}
sendC <- handleEvaluatorResponseForSSH(res, &state)
if res.Allow.Value && !res.Deny.Value {
a.startContinuousAuthorization(ctx, errC, req, sessionState.Load().ID)
}
sendC <- &authResponse
} else {
resp := extensions_ssh.ServerMessage{
Message: &extensions_ssh.ServerMessage_AuthResponse{
@ -485,8 +419,218 @@ func (a *Authorize) ManageStream(
}
}
}
}
return eg.Wait()
func (a *Authorize) getSSHRouteForHostname(hostname string) *config.Policy {
opts := a.currentConfig.Load().Options
from := "ssh://" + strings.TrimSuffix(strings.Join([]string{hostname, opts.SSHHostname}, "."), ".")
for r := range opts.GetAllPolicies() {
if r.From == from {
return r
}
}
return nil
}
func (a *Authorize) GetPomeriumSession(
ctx context.Context, publicKey []byte,
) (*session.Session, error) {
sessionID, err := getSessionIDForSSH(publicKey)
if err != nil {
return nil, err
}
fmt.Println("session ID:", sessionID) // XXX
session, err := session.Get(ctx, a.GetDataBrokerServiceClient(), sessionID)
if err != nil {
if st, ok := status.FromError(err); ok && st.Code() == codes.NotFound {
return nil, nil
}
return nil, err
}
return session, nil
}
func getSessionIDForSSH(publicKey []byte) (string, error) {
// XXX: get the fingerprint from Envoy rather than computing it here
k, err := gossh.ParsePublicKey(publicKey)
if err != nil {
return "", fmt.Errorf("couldn't parse ssh key: %w", err)
}
return "sshkey-" + gossh.FingerprintSHA256(k), nil
}
func (a *Authorize) getEvaluatorRequestFromSSHAuthRequest(
state *StreamState,
) (*evaluator.Request, error) {
sessionID, err := getSessionIDForSSH(state.PublicKey)
if err != nil {
return nil, err
}
route := a.getSSHRouteForHostname(state.Hostname)
if route == nil {
return nil, fmt.Errorf("no route found for hostname %q", state.Hostname)
}
req := &evaluator.Request{
IsInternal: false,
HTTP: evaluator.RequestHTTP{
Hostname: route.From, // XXX: this is not quite right
// IP: ? // TODO
},
Session: evaluator.RequestSession{
ID: sessionID,
},
Policy: route,
}
return req, nil
}
func handleEvaluatorResponseForSSH(
result *evaluator.Result, state *StreamState,
) *extensions_ssh.ServerMessage {
// fmt.Printf(" *** evaluator result: %+v\n", result)
// TODO: ideally there would be a way to keep this in sync with the logic in check_response.go
allow := result.Allow.Value && !result.Deny.Value
if allow {
pkData, _ := anypb.New(publicKeyAllowResponse(state.PublicKey))
sessionRecordingExt, _ := anypb.New(&extensions_session_recording.UpstreamTargetExtensionConfig{
RecordingName: fmt.Sprintf("session-%s-at-%s-%d.cast", state.Username, state.Hostname, time.Now().UnixNano()),
Format: extensions_session_recording.Format_AsciicastFormat,
})
return &extensions_ssh.ServerMessage{
Message: &extensions_ssh.ServerMessage_AuthResponse{
AuthResponse: &extensions_ssh.AuthenticationResponse{
Response: &extensions_ssh.AuthenticationResponse_Allow{
Allow: &extensions_ssh.AllowResponse{
Username: state.Username,
Target: &extensions_ssh.AllowResponse_Upstream{
Upstream: &extensions_ssh.UpstreamTarget{
Hostname: state.Hostname,
AllowedMethods: []*extensions_ssh.AllowedMethod{
{
Method: "publickey",
MethodData: pkData,
},
{
Method: "keyboard-interactive",
},
},
Extensions: []*corev3.TypedExtensionConfig{
{
TypedConfig: sessionRecordingExt,
},
},
},
},
},
},
},
},
}
}
// XXX: do we want to send an equivalent to the "show error details" output
// in the case of a deny result?
// XXX: this is not quite right -- needs to exactly match the last list of methods
methods := []string{"publickey"}
if slices.Contains(state.MethodsAuthenticated, "keyboard-interactive") {
methods = append(methods, "keyboard-interactive")
}
return &extensions_ssh.ServerMessage{
Message: &extensions_ssh.ServerMessage_AuthResponse{
AuthResponse: &extensions_ssh.AuthenticationResponse{
Response: &extensions_ssh.AuthenticationResponse_Deny{
Deny: &extensions_ssh.DenyResponse{
Methods: methods,
},
},
},
},
}
}
func publicKeyAllowResponse(publicKey []byte) *extensions_ssh.PublicKeyAllowResponse {
return &extensions_ssh.PublicKeyAllowResponse{
PublicKey: publicKey,
Permissions: &extensions_ssh.Permissions{
PermitPortForwarding: true,
PermitAgentForwarding: true,
PermitX11Forwarding: true,
PermitPty: true,
PermitUserRc: true,
ValidBefore: timestamppb.New(time.Now().Add(-1 * time.Minute)),
// XXX: tie this to Pomerium session lifetime?
ValidAfter: timestamppb.New(time.Now().Add(12 * time.Hour)),
},
}
}
// PersistSession stores session and user data in the databroker.
func (a *Authorize) PersistSession(
ctx context.Context,
sessionState *sessions.State, // XXX: consider not using this struct
claims identity.SessionClaims,
accessToken *oauth2.Token,
) error {
now := time.Now()
sessionLifetime := a.currentConfig.Load().Options.CookieExpire
sessionExpiry := timestamppb.New(now.Add(sessionLifetime))
sess := &session.Session{
Id: sessionState.ID,
UserId: sessionState.UserID(),
IssuedAt: timestamppb.New(now),
AccessedAt: timestamppb.New(now),
ExpiresAt: sessionExpiry,
OauthToken: manager.ToOAuthToken(accessToken),
Audience: sessionState.Audience,
}
sess.SetRawIDToken(claims.RawIDToken)
sess.AddClaims(claims.Flatten())
// XXX: do we need to create a user record too?
// compare with Stateful.PersistSession()
res, err := session.Put(ctx, a.GetDataBrokerServiceClient(), sess)
if err != nil {
return err
}
sessionState.DatabrokerServerVersion = res.GetServerVersion()
sessionState.DatabrokerRecordVersion = res.GetRecord().GetVersion()
return nil
}
func (a *Authorize) startContinuousAuthorization(
ctx context.Context,
errC chan<- error,
req *evaluator.Request,
sessionID string,
) {
recheck := func() {
// XXX: probably want to log the results of this evaluation only if it changes
res, _ := a.evaluate(ctx, req, &sessions.State{ID: sessionID})
if !res.Allow.Value || res.Deny.Value {
errC <- fmt.Errorf("no longer authorized")
}
}
ticker := time.NewTicker(time.Second)
go func() {
for {
select {
case <-ticker.C:
recheck()
case <-ctx.Done():
ticker.Stop()
return
}
}
}()
}
// See RFC 4254, section 5.1.
@ -653,7 +797,7 @@ func (a *Authorize) ServeChannel(
switch msg.Request {
case "pty-req":
opts := a.currentOptions.Load()
opts := a.currentConfig.Load().Options
var routes []string
for r := range opts.GetAllPolicies() {
if strings.HasPrefix(r.From, "ssh://") {

4
go.mod
View file

@ -18,8 +18,6 @@ require (
github.com/charmbracelet/bubbles v0.20.0
github.com/charmbracelet/bubbletea v1.3.3
github.com/charmbracelet/lipgloss v1.0.0
github.com/charmbracelet/x/ansi v0.8.0
github.com/charmbracelet/x/term v0.2.1
github.com/cloudflare/circl v1.6.0
github.com/cncf/xds/go v0.0.0-20241223141626-cff3c89139a3
github.com/coreos/go-oidc/v3 v3.12.0
@ -145,6 +143,8 @@ require (
github.com/beorn7/perks v1.0.1 // indirect
github.com/caddyserver/zerossl v0.1.3 // indirect
github.com/cespare/xxhash/v2 v2.3.0 // indirect
github.com/charmbracelet/x/ansi v0.8.0 // indirect
github.com/charmbracelet/x/term v0.2.1 // indirect
github.com/containerd/log v0.1.0 // indirect
github.com/containerd/platforms v0.2.1 // indirect
github.com/cpuguy83/dockercfg v0.3.2 // indirect

View file

@ -5,11 +5,11 @@ package files
import _ "embed" // embed
//go:embed envoy-darwin-arm64
//go:embed envoy
var rawBinary []byte
//go:embed envoy-darwin-arm64.sha256
//go:embed envoy.sha256
var rawChecksum string
//go:embed envoy-darwin-arm64.version
//go:embed envoy.version
var rawVersion string

File diff suppressed because it is too large Load diff

View file

@ -167,7 +167,7 @@ message Policy {
string remediation = 9;
}
// Next ID: 140.
// Next ID: 141.
message Settings {
message Certificate {
bytes cert_bytes = 3;
@ -201,14 +201,15 @@ message Settings {
optional string cookie_secret = 17;
optional string cookie_domain = 18;
// optional bool cookie_secure = 19;
optional bool cookie_http_only = 20;
optional google.protobuf.Duration cookie_expire = 21;
optional string cookie_same_site = 113;
optional string idp_client_id = 22;
optional string idp_client_secret = 23;
optional string idp_provider = 24;
optional string idp_provider_url = 25;
repeated string scopes = 26;
optional bool cookie_http_only = 20;
optional google.protobuf.Duration cookie_expire = 21;
optional string cookie_same_site = 113;
optional string idp_client_id = 22;
optional string idp_client_secret = 23;
optional string idp_provider = 24;
optional string idp_provider_url = 25;
optional StringList idp_access_token_allowed_audiences = 137;
repeated string scopes = 26;
// optional string idp_service_account = 27;
// optional google.protobuf.Duration idp_refresh_directory_timeout = 28;
// optional google.protobuf.Duration idp_refresh_directory_interval = 29;
@ -222,7 +223,9 @@ message Settings {
map<string, string> set_response_headers = 69;
// repeated string jwt_claims_headers = 37;
map<string, string> jwt_claims_headers = 63;
optional IssuerFormat jwt_issuer_format = 139;
repeated string jwt_groups_filter = 119;
optional BearerTokenFormat bearer_token_format = 138;
optional google.protobuf.Duration default_upstream_timeout = 39;
optional string metrics_address = 40;
optional string metrics_basic_auth = 64;
@ -288,7 +291,7 @@ message Settings {
optional bool pass_identity_headers = 117;
map<string, bool> runtime_flags = 118;
optional uint32 http3_advertise_port = 136;
optional string device_auth_client_type = 137;
optional string device_auth_client_type = 140;
}
message DownstreamMtlsSettings {

View file

@ -63,7 +63,6 @@ func (mp MockProvider) SignIn(_ http.ResponseWriter, _ *http.Request, _ string)
return mp.SignInError
}
<<<<<<< HEAD
// VerifyAccessToken verifies an access token.
func (mp MockProvider) VerifyAccessToken(_ context.Context, _ string) (claims map[string]any, err error) {
return nil, fmt.Errorf("VerifyAccessToken not implemented")
@ -72,8 +71,8 @@ func (mp MockProvider) VerifyAccessToken(_ context.Context, _ string) (claims ma
// VerifyIdentityToken verifies an identity token.
func (mp MockProvider) VerifyIdentityToken(_ context.Context, _ string) (claims map[string]any, err error) {
return nil, fmt.Errorf("VerifyIdentityToken not implemented")
||||||| 229ef72e5
=======
}
// DeviceAccessToken implements Authenticator.
func (mp MockProvider) DeviceAccessToken(ctx context.Context, r *oauth2.DeviceAuthResponse, state identity.State) (*oauth2.Token, error) {
return &mp.DeviceAccessTokenResponse, mp.DeviceAccessTokenError
@ -82,5 +81,4 @@ func (mp MockProvider) DeviceAccessToken(ctx context.Context, r *oauth2.DeviceAu
// DeviceAuth implements Authenticator.
func (mp MockProvider) DeviceAuth(_ context.Context) (*oauth2.DeviceAuthResponse, error) {
return &mp.DeviceAuthResponse, mp.DeviceAuthError
>>>>>>> kralicky/ssh-demo
}

View file

@ -168,7 +168,7 @@ func (p *Proxy) ProgrammaticLogin(w http.ResponseWriter, r *http.Request) error
func (p *Proxy) DeviceAuthLogin(w http.ResponseWriter, r *http.Request) error {
state := p.state.Load()
options := p.currentOptions.Load()
options := p.currentConfig.Load().Options
params := url.Values{}
routeUri := urlutil.GetAbsoluteURL(r)