Merge remote-tracking branch 'origin/kenjenkins/ssh-proxy-auth-integration' into experimental/ssh

This commit is contained in:
Joe Kralicky 2025-03-19 20:56:59 +00:00
commit 08252f32df
No known key found for this signature in database
GPG key ID: 75C4875F34A9FB79
15 changed files with 1061 additions and 821 deletions

View file

@ -12,6 +12,8 @@ import (
"github.com/rs/zerolog" "github.com/rs/zerolog"
oteltrace "go.opentelemetry.io/otel/trace" oteltrace "go.opentelemetry.io/otel/trace"
"golang.org/x/sync/errgroup" "golang.org/x/sync/errgroup"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"github.com/pomerium/datasource/pkg/directory" "github.com/pomerium/datasource/pkg/directory"
"github.com/pomerium/pomerium/authorize/evaluator" "github.com/pomerium/pomerium/authorize/evaluator"
@ -19,11 +21,16 @@ import (
"github.com/pomerium/pomerium/config" "github.com/pomerium/pomerium/config"
"github.com/pomerium/pomerium/internal/atomicutil" "github.com/pomerium/pomerium/internal/atomicutil"
"github.com/pomerium/pomerium/internal/log" "github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/sessions"
"github.com/pomerium/pomerium/internal/telemetry/metrics" "github.com/pomerium/pomerium/internal/telemetry/metrics"
"github.com/pomerium/pomerium/internal/telemetry/trace" "github.com/pomerium/pomerium/internal/telemetry/trace"
"github.com/pomerium/pomerium/pkg/contextutil"
"github.com/pomerium/pomerium/pkg/cryptutil" "github.com/pomerium/pomerium/pkg/cryptutil"
"github.com/pomerium/pomerium/pkg/grpc/databroker" "github.com/pomerium/pomerium/pkg/grpc/databroker"
"github.com/pomerium/pomerium/pkg/grpc/user"
"github.com/pomerium/pomerium/pkg/policy/criteria"
"github.com/pomerium/pomerium/pkg/storage" "github.com/pomerium/pomerium/pkg/storage"
"github.com/pomerium/pomerium/pkg/telemetry/requestid"
) )
// Authorize struct holds // Authorize struct holds
@ -181,3 +188,66 @@ func (a *Authorize) OnConfigChange(ctx context.Context, cfg *config.Config) {
} }
} }
} }
type evaluateResult struct {
// Overall allow/deny result.
Allowed bool
// Reasons for the overall result.
Reasons criteria.Reasons
// Reason detail traces. (Populated only if enabled by the policy.)
Traces []contextutil.PolicyEvaluationTrace
}
func (a *Authorize) evaluate(
ctx context.Context,
req *evaluator.Request,
sessionState *sessions.State,
) (*evaluator.Result, error) {
querier := storage.NewCachingQuerier(
storage.NewQuerier(a.state.Load().dataBrokerClient),
a.globalCache,
)
ctx = storage.WithQuerier(ctx, querier)
requestID := requestid.FromContext(ctx)
state := a.state.Load()
var s sessionOrServiceAccount
var u *user.User
var err error
if sessionState != nil {
s, err = a.getDataBrokerSessionOrServiceAccount(ctx, sessionState.ID, sessionState.DatabrokerRecordVersion)
if status.Code(err) == codes.Unavailable {
log.Ctx(ctx).Debug().Str("request-id", requestID).Err(err).Msg("temporary error checking authorization: data broker unavailable")
return nil, err
} else if err != nil {
log.Ctx(ctx).Info().Err(err).Str("request-id", requestID).Msg("missing or invalid session or service account")
sessionState = nil
}
}
if sessionState != nil && s != nil {
u, _ = a.getDataBrokerUser(ctx, s.GetUserId()) // ignore any missing user error
}
res, err := state.evaluator.Evaluate(ctx, req)
if err != nil {
log.Ctx(ctx).Error().Err(err).Str("request-id", requestID).Msg("error during OPA evaluation")
return nil, err
}
a.logAuthorizeCheck(ctx, req, res, s, u)
/*result := &evaluateResult{
Allowed: res.Allow.Value && !res.Deny.Value,
}
// if show error details is enabled, attach the policy evaluation traces
if req.Policy != nil && req.Policy.ShowErrorDetails {
result.Traces = res.Traces
}*/
return res, nil
}

View file

@ -38,6 +38,7 @@ type RequestHTTP struct {
Method string `json:"method"` Method string `json:"method"`
Hostname string `json:"hostname"` Hostname string `json:"hostname"`
Path string `json:"path"` Path string `json:"path"`
Query string `json:"-"`
URL string `json:"url"` URL string `json:"url"`
Headers map[string]string `json:"headers"` Headers map[string]string `json:"headers"`
ClientCertificate ClientCertificateInfo `json:"client_certificate"` ClientCertificate ClientCertificateInfo `json:"client_certificate"`

View file

@ -31,6 +31,10 @@ type PolicyResponse struct {
Traces []contextutil.PolicyEvaluationTrace Traces []contextutil.PolicyEvaluationTrace
} }
func (r *PolicyResponse) Allowed() bool {
return r.Allow.Value && !r.Deny.Value
}
// NewPolicyResponse creates a new PolicyResponse. // NewPolicyResponse creates a new PolicyResponse.
func NewPolicyResponse() *PolicyResponse { func NewPolicyResponse() *PolicyResponse {
return &PolicyResponse{ return &PolicyResponse{

View file

@ -18,9 +18,7 @@ import (
"github.com/pomerium/pomerium/internal/log" "github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/sessions" "github.com/pomerium/pomerium/internal/sessions"
"github.com/pomerium/pomerium/internal/urlutil" "github.com/pomerium/pomerium/internal/urlutil"
"github.com/pomerium/pomerium/pkg/contextutil"
"github.com/pomerium/pomerium/pkg/grpc/databroker" "github.com/pomerium/pomerium/pkg/grpc/databroker"
"github.com/pomerium/pomerium/pkg/grpc/user"
"github.com/pomerium/pomerium/pkg/storage" "github.com/pomerium/pomerium/pkg/storage"
"github.com/pomerium/pomerium/pkg/telemetry/requestid" "github.com/pomerium/pomerium/pkg/telemetry/requestid"
"google.golang.org/grpc/codes" "google.golang.org/grpc/codes"
@ -33,12 +31,6 @@ func (a *Authorize) Check(ctx context.Context, in *envoy_service_auth_v3.CheckRe
ctx, span := a.tracer.Start(ctx, "authorize.grpc.Check") ctx, span := a.tracer.Start(ctx, "authorize.grpc.Check")
defer span.End() defer span.End()
querier := storage.NewCachingQuerier(
storage.NewQuerier(a.state.Load().dataBrokerClient),
a.globalCache,
)
ctx = storage.WithQuerier(ctx, querier)
state := a.state.Load() state := a.state.Load()
// convert the incoming envoy-style http request into a go-style http request // convert the incoming envoy-style http request into a go-style http request
@ -46,41 +38,24 @@ func (a *Authorize) Check(ctx context.Context, in *envoy_service_auth_v3.CheckRe
requestID := requestid.FromHTTPHeader(hreq.Header) requestID := requestid.FromHTTPHeader(hreq.Header)
ctx = requestid.WithValue(ctx, requestID) ctx = requestid.WithValue(ctx, requestID)
req, err := a.getEvaluatorRequestFromCheckRequest(ctx, in) sessionState, _ := state.sessionStore.LoadSessionStateAndCheckIDP(hreq)
req, err := a.getEvaluatorRequestFromCheckRequest(ctx, in, sessionState)
if err != nil { if err != nil {
log.Ctx(ctx).Error().Err(err).Str("request-id", requestID).Msg("error building evaluator request") log.Ctx(ctx).Error().Err(err).Str("request-id", requestID).Msg("error building evaluator request")
return nil, err return nil, err
} }
// load the session res, err := a.evaluate(ctx, req, sessionState)
s, err := a.loadSession(ctx, hreq, req)
if err != nil {
return nil, err
}
// if there's a session or service account, load the user
var u *user.User
if s != nil {
req.Session.ID = s.GetId()
u, _ = a.getDataBrokerUser(ctx, s.GetUserId()) // ignore any missing user error
}
res, err := state.evaluator.Evaluate(ctx, req)
if err != nil { if err != nil {
log.Ctx(ctx).Error().Err(err).Str("request-id", requestID).Msg("error during OPA evaluation") log.Ctx(ctx).Error().Err(err).Str("request-id", requestID).Msg("error during OPA evaluation")
return nil, err return nil, err
} }
// if show error details is enabled, attach the policy evaluation traces
if req.Policy != nil && req.Policy.ShowErrorDetails {
ctx = contextutil.WithPolicyEvaluationTraces(ctx, res.Traces)
}
resp, err := a.handleResult(ctx, in, req, res) resp, err := a.handleResult(ctx, in, req, res)
if err != nil { if err != nil {
log.Ctx(ctx).Error().Err(err).Str("request-id", requestID).Msg("grpc check ext_authz_error") log.Ctx(ctx).Error().Err(err).Str("request-id", requestID).Msg("grpc check ext_authz_error")
} }
a.logAuthorizeCheck(ctx, in, res, s, u)
return resp, err return resp, err
} }
@ -144,6 +119,7 @@ func (a *Authorize) loadSession(
func (a *Authorize) getEvaluatorRequestFromCheckRequest( func (a *Authorize) getEvaluatorRequestFromCheckRequest(
ctx context.Context, ctx context.Context,
in *envoy_service_auth_v3.CheckRequest, in *envoy_service_auth_v3.CheckRequest,
sessionState *sessions.State,
) (*evaluator.Request, error) { ) (*evaluator.Request, error) {
requestURL := getCheckRequestURL(in) requestURL := getCheckRequestURL(in)
attrs := in.GetAttributes() attrs := in.GetAttributes()
@ -158,6 +134,11 @@ func (a *Authorize) getEvaluatorRequestFromCheckRequest(
attrs.GetSource().GetAddress().GetSocketAddress().GetAddress(), attrs.GetSource().GetAddress().GetSocketAddress().GetAddress(),
), ),
} }
if sessionState != nil {
req.Session = evaluator.RequestSession{
ID: sessionState.ID,
}
}
req.Policy = a.getMatchingPolicy(envoyconfig.ExtAuthzContextExtensionsRouteID(attrs.GetContextExtensions())) req.Policy = a.getMatchingPolicy(envoyconfig.ExtAuthzContextExtensionsRouteID(attrs.GetContextExtensions()))
return req, nil return req, nil
} }

View file

@ -88,7 +88,7 @@ func Test_getEvaluatorRequest(t *testing.T) {
}, },
}, },
}, },
) nil)
require.NoError(t, err) require.NoError(t, err)
expect := &evaluator.Request{ expect := &evaluator.Request{
Policy: &a.currentConfig.Load().Options.Policies[0], Policy: &a.currentConfig.Load().Options.Policies[0],
@ -140,7 +140,7 @@ func Test_getEvaluatorRequestWithPortInHostHeader(t *testing.T) {
}, },
}, },
}, },
}) }, nil)
require.NoError(t, err) require.NoError(t, err)
expect := &evaluator.Request{ expect := &evaluator.Request{
Policy: &a.currentConfig.Load().Options.Policies[0], Policy: &a.currentConfig.Load().Options.Policies[0],

View file

@ -200,7 +200,11 @@ func (s *Store) GetDataBrokerRecord(ctx context.Context, recordType, recordIDOrI
res, err := storage.GetQuerier(ctx).Query(ctx, req, grpc.WaitForReady(true)) res, err := storage.GetQuerier(ctx).Query(ctx, req, grpc.WaitForReady(true))
if err != nil { if err != nil {
log.Ctx(ctx).Error().Err(err).Msg("authorize/store: error retrieving record") log.Ctx(ctx).Error().
Str("record-type", recordType).
Str("record-id-or-index", recordIDOrIndex).
Err(err).
Msg("authorize/store: error retrieving record")
return nil return nil
} }

View file

@ -2,9 +2,7 @@ package authorize
import ( import (
"context" "context"
"strings"
envoy_service_auth_v3 "github.com/envoyproxy/go-control-plane/envoy/service/auth/v3"
"github.com/go-jose/go-jose/v3/jwt" "github.com/go-jose/go-jose/v3/jwt"
"github.com/rs/zerolog" "github.com/rs/zerolog"
"go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/attribute"
@ -21,19 +19,19 @@ import (
func (a *Authorize) logAuthorizeCheck( func (a *Authorize) logAuthorizeCheck(
ctx context.Context, ctx context.Context,
in *envoy_service_auth_v3.CheckRequest, in *evaluator.Request,
res *evaluator.Result, s sessionOrServiceAccount, u *user.User, res *evaluator.Result, s sessionOrServiceAccount, u *user.User,
) { ) {
ctx, span := a.tracer.Start(ctx, "authorize.grpc.LogAuthorizeCheck") ctx, span := a.tracer.Start(ctx, "authorize.grpc.LogAuthorizeCheck")
defer span.End() defer span.End()
hdrs := getCheckRequestHeaders(in) hdrs := in.HTTP.Headers
impersonateDetails := a.getImpersonateDetails(ctx, s) impersonateDetails := a.getImpersonateDetails(ctx, s)
evt := log.Ctx(ctx).Info().Str("service", "authorize") evt := log.Ctx(ctx).Info().Str("service", "authorize")
fields := a.currentConfig.Load().Options.GetAuthorizeLogFields() fields := a.currentConfig.Load().Options.GetAuthorizeLogFields()
for _, field := range fields { for _, field := range fields {
evt = populateLogEvent(ctx, field, evt, in, s, u, hdrs, impersonateDetails, res) evt = populateLogEvent(ctx, field, evt, in, s, u, impersonateDetails, res)
} }
evt = log.HTTPHeaders(evt, fields, hdrs) evt = log.HTTPHeaders(evt, fields, hdrs)
@ -134,22 +132,19 @@ func populateLogEvent(
ctx context.Context, ctx context.Context,
field log.AuthorizeLogField, field log.AuthorizeLogField,
evt *zerolog.Event, evt *zerolog.Event,
in *envoy_service_auth_v3.CheckRequest, in *evaluator.Request,
s sessionOrServiceAccount, s sessionOrServiceAccount,
u *user.User, u *user.User,
hdrs map[string]string,
impersonateDetails *impersonateDetails, impersonateDetails *impersonateDetails,
res *evaluator.Result, res *evaluator.Result,
) *zerolog.Event { ) *zerolog.Event {
path, query, _ := strings.Cut(in.GetAttributes().GetRequest().GetHttp().GetPath(), "?")
switch field { switch field {
case log.AuthorizeLogFieldCheckRequestID: case log.AuthorizeLogFieldCheckRequestID:
return evt.Str(string(field), hdrs["X-Request-Id"]) return evt.Str(string(field), in.HTTP.Headers["X-Request-Id"])
case log.AuthorizeLogFieldEmail: case log.AuthorizeLogFieldEmail:
return evt.Str(string(field), u.GetEmail()) return evt.Str(string(field), u.GetEmail())
case log.AuthorizeLogFieldHost: case log.AuthorizeLogFieldHost:
return evt.Str(string(field), in.GetAttributes().GetRequest().GetHttp().GetHost()) return evt.Str(string(field), in.HTTP.Hostname)
case log.AuthorizeLogFieldIDToken: case log.AuthorizeLogFieldIDToken:
if s, ok := s.(*session.Session); ok { if s, ok := s.(*session.Session); ok {
evt = evt.Str(string(field), s.GetIdToken().GetRaw()) evt = evt.Str(string(field), s.GetIdToken().GetRaw())
@ -180,13 +175,13 @@ func populateLogEvent(
} }
return evt return evt
case log.AuthorizeLogFieldIP: case log.AuthorizeLogFieldIP:
return evt.Str(string(field), in.GetAttributes().GetSource().GetAddress().GetSocketAddress().GetAddress()) return evt.Str(string(field), in.HTTP.IP)
case log.AuthorizeLogFieldMethod: case log.AuthorizeLogFieldMethod:
return evt.Str(string(field), in.GetAttributes().GetRequest().GetHttp().GetMethod()) return evt.Str(string(field), in.HTTP.Method)
case log.AuthorizeLogFieldPath: case log.AuthorizeLogFieldPath:
return evt.Str(string(field), path) return evt.Str(string(field), in.HTTP.Path)
case log.AuthorizeLogFieldQuery: case log.AuthorizeLogFieldQuery:
return evt.Str(string(field), query) return evt.Str(string(field), in.HTTP.Query)
case log.AuthorizeLogFieldRequestID: case log.AuthorizeLogFieldRequestID:
return evt.Str(string(field), requestid.FromContext(ctx)) return evt.Str(string(field), requestid.FromContext(ctx))
case log.AuthorizeLogFieldServiceAccountID: case log.AuthorizeLogFieldServiceAccountID:

View file

@ -6,8 +6,6 @@ import (
"strings" "strings"
"testing" "testing"
envoy_config_core_v3 "github.com/envoyproxy/go-control-plane/envoy/config/core/v3"
envoy_service_auth_v3 "github.com/envoyproxy/go-control-plane/envoy/service/auth/v3"
"github.com/rs/zerolog" "github.com/rs/zerolog"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
@ -24,27 +22,18 @@ func Test_populateLogEvent(t *testing.T) {
ctx := context.Background() ctx := context.Background()
ctx = requestid.WithValue(ctx, "REQUEST-ID") ctx = requestid.WithValue(ctx, "REQUEST-ID")
checkRequest := &envoy_service_auth_v3.CheckRequest{ request := &evaluator.Request{
Attributes: &envoy_service_auth_v3.AttributeContext{ HTTP: evaluator.RequestHTTP{
Request: &envoy_service_auth_v3.AttributeContext_Request{ Method: "GET",
Http: &envoy_service_auth_v3.AttributeContext_HttpRequest{ Hostname: "HOST",
Host: "HOST", Path: "/some/path",
Path: "https://www.example.com/some/path?a=b", Query: "a=b",
Method: "GET", Headers: map[string]string{
}, "X-Request-Id": "CHECK-REQUEST-ID",
},
Source: &envoy_service_auth_v3.AttributeContext_Peer{
Address: &envoy_config_core_v3.Address{
Address: &envoy_config_core_v3.Address_SocketAddress{
SocketAddress: &envoy_config_core_v3.SocketAddress{
Address: "127.0.0.1",
},
},
},
}, },
IP: "127.0.0.1",
}, },
} }
headers := map[string]string{"X-Request-Id": "CHECK-REQUEST-ID"}
s := &session.Session{ s := &session.Session{
Id: "SESSION-ID", Id: "SESSION-ID",
IdToken: &session.IDToken{ IdToken: &session.IDToken{
@ -86,7 +75,7 @@ func Test_populateLogEvent(t *testing.T) {
{log.AuthorizeLogFieldImpersonateUserID, s, `{"impersonate-user-id":"IMPERSONATE-USER-ID"}`}, {log.AuthorizeLogFieldImpersonateUserID, s, `{"impersonate-user-id":"IMPERSONATE-USER-ID"}`},
{log.AuthorizeLogFieldIP, s, `{"ip":"127.0.0.1"}`}, {log.AuthorizeLogFieldIP, s, `{"ip":"127.0.0.1"}`},
{log.AuthorizeLogFieldMethod, s, `{"method":"GET"}`}, {log.AuthorizeLogFieldMethod, s, `{"method":"GET"}`},
{log.AuthorizeLogFieldPath, s, `{"path":"https://www.example.com/some/path"}`}, {log.AuthorizeLogFieldPath, s, `{"path":"/some/path"}`},
{log.AuthorizeLogFieldQuery, s, `{"query":"a=b"}`}, {log.AuthorizeLogFieldQuery, s, `{"query":"a=b"}`},
{log.AuthorizeLogFieldRemovedGroupsCount, s, `{"removed-groups-count":42}`}, {log.AuthorizeLogFieldRemovedGroupsCount, s, `{"removed-groups-count":42}`},
{log.AuthorizeLogFieldRequestID, s, `{"request-id":"REQUEST-ID"}`}, {log.AuthorizeLogFieldRequestID, s, `{"request-id":"REQUEST-ID"}`},
@ -104,7 +93,7 @@ func Test_populateLogEvent(t *testing.T) {
var buf bytes.Buffer var buf bytes.Buffer
log := zerolog.New(&buf) log := zerolog.New(&buf)
evt := log.Log() evt := log.Log()
evt = populateLogEvent(ctx, tc.field, evt, checkRequest, tc.s, u, headers, impersonateDetails, res) evt = populateLogEvent(ctx, tc.field, evt, request, tc.s, u, impersonateDetails, res)
evt.Send() evt.Send()
assert.Equal(t, tc.expect, strings.TrimSpace(buf.String())) assert.Equal(t, tc.expect, strings.TrimSpace(buf.String()))

View file

@ -3,6 +3,7 @@ package authorize
import ( import (
"bufio" "bufio"
"bytes" "bytes"
"context"
"crypto/sha256" "crypto/sha256"
"encoding/binary" "encoding/binary"
"errors" "errors"
@ -23,14 +24,21 @@ import (
"github.com/klauspost/compress/zstd" "github.com/klauspost/compress/zstd"
extensions_ssh "github.com/pomerium/envoy-custom/api/extensions/filters/network/ssh" extensions_ssh "github.com/pomerium/envoy-custom/api/extensions/filters/network/ssh"
extensions_session_recording "github.com/pomerium/envoy-custom/api/extensions/filters/network/ssh/filters/session_recording" extensions_session_recording "github.com/pomerium/envoy-custom/api/extensions/filters/network/ssh/filters/session_recording"
"github.com/pomerium/pomerium/authorize/evaluator"
"github.com/pomerium/pomerium/config" "github.com/pomerium/pomerium/config"
"github.com/pomerium/pomerium/internal/log" "github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/sessions" "github.com/pomerium/pomerium/internal/sessions"
"github.com/pomerium/pomerium/pkg/grpc/session"
"github.com/pomerium/pomerium/pkg/identity" "github.com/pomerium/pomerium/pkg/identity"
"github.com/pomerium/pomerium/pkg/identity/manager"
"github.com/pomerium/pomerium/pkg/identity/oauth" "github.com/pomerium/pomerium/pkg/identity/oauth"
"github.com/pomerium/pomerium/pkg/storage"
gossh "golang.org/x/crypto/ssh" gossh "golang.org/x/crypto/ssh"
"golang.org/x/oauth2"
"golang.org/x/sync/errgroup" "golang.org/x/sync/errgroup"
"google.golang.org/grpc" "google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/encoding/protodelim" "google.golang.org/protobuf/encoding/protodelim"
"google.golang.org/protobuf/encoding/protojson" "google.golang.org/protobuf/encoding/protojson"
"google.golang.org/protobuf/types/known/anypb" "google.golang.org/protobuf/types/known/anypb"
@ -136,6 +144,13 @@ func (a *Authorize) ManageStream(
} }
}) })
// XXX
querier := storage.NewCachingQuerier(
storage.NewQuerier(a.state.Load().dataBrokerClient),
a.globalCache,
)
ctx = storage.WithQuerier(ctx, querier)
eg.Go(func() error { eg.Go(func() error {
for { for {
select { select {
@ -154,8 +169,8 @@ func (a *Authorize) ManageStream(
var state StreamState var state StreamState
deviceAuthSuccess := &atomic.Bool{}
deviceAuthDone := make(chan struct{}) deviceAuthDone := make(chan struct{})
sessionState := &atomic.Pointer[sessions.State]{}
errC := make(chan error, 1) errC := make(chan error, 1)
a.activeStreamsMu.Lock() a.activeStreamsMu.Lock()
@ -201,62 +216,18 @@ func (a *Authorize) ManageStream(
// //
// validate public key here // validate public key here
// //
session, err := a.GetPomeriumSession(ctx, pubkeyReq.PublicKey)
if err != nil {
return err // XXX: wrap this error?
}
state.MethodsAuthenticated = append(state.MethodsAuthenticated, "publickey") state.MethodsAuthenticated = append(state.MethodsAuthenticated, "publickey")
state.PublicKey = pubkeyReq.PublicKey state.PublicKey = pubkeyReq.PublicKey
if authReq.Username == "" && authReq.Hostname == "" { if authReq.Username == "" {
resp := extensions_ssh.ServerMessage{ return fmt.Errorf("no username given")
Message: &extensions_ssh.ServerMessage_AuthResponse{ }
AuthResponse: &extensions_ssh.AuthenticationResponse{ if authReq.Hostname == "" {
Response: &extensions_ssh.AuthenticationResponse_Allow{
Allow: &extensions_ssh.AllowResponse{
Username: state.Username,
Target: &extensions_ssh.AllowResponse_Internal{
Internal: &extensions_ssh.InternalTarget{},
},
},
},
},
},
}
sendC <- &resp
continue
} else if authReq.Username == "_mirror" && authReq.Hostname == "" {
resp := extensions_ssh.ServerMessage{
Message: &extensions_ssh.ServerMessage_AuthResponse{
AuthResponse: &extensions_ssh.AuthenticationResponse{
Response: &extensions_ssh.AuthenticationResponse_Allow{
Allow: &extensions_ssh.AllowResponse{
Username: state.Username,
Target: &extensions_ssh.AllowResponse_Internal{
Internal: &extensions_ssh.InternalTarget{},
},
},
},
},
},
}
// id, _ := strconv.ParseUint(authReq.Hostname, 10, 64)
// resp := extensions_ssh.ServerMessage{
// Message: &extensions_ssh.ServerMessage_AuthResponse{
// AuthResponse: &extensions_ssh.AuthenticationResponse{
// Response: &extensions_ssh.AuthenticationResponse_Allow{
// Allow: &extensions_ssh.AllowResponse{
// Target: &extensions_ssh.AllowResponse_MirrorSession{
// MirrorSession: &extensions_ssh.MirrorSessionTarget{
// SourceId: id,
// Mode: extensions_ssh.MirrorSessionTarget_ReadWrite,
// },
// },
// },
// },
// },
// },
// }
sendC <- &resp
continue
} else if authReq.Username != "" && authReq.Hostname == "" {
resp := extensions_ssh.ServerMessage{ resp := extensions_ssh.ServerMessage{
Message: &extensions_ssh.ServerMessage_AuthResponse{ Message: &extensions_ssh.ServerMessage_AuthResponse{
AuthResponse: &extensions_ssh.AuthenticationResponse{ AuthResponse: &extensions_ssh.AuthenticationResponse{
@ -275,7 +246,24 @@ func (a *Authorize) ManageStream(
continue continue
} }
if !slices.Contains(state.MethodsAuthenticated, "keyboard-interactive") { if session != nil {
// Perform authorize check for this route
req, err := a.getEvaluatorRequestFromSSHAuthRequest(&state)
if err != nil {
return err
}
res, err := a.evaluate(ctx, req, &sessions.State{ID: session.Id})
if err != nil {
return err
}
sendC <- handleEvaluatorResponseForSSH(res, &state)
if res.Allow.Value && !res.Deny.Value {
a.startContinuousAuthorization(ctx, errC, req, session.Id)
}
}
if session == nil && !slices.Contains(state.MethodsAuthenticated, "keyboard-interactive") {
resp := extensions_ssh.ServerMessage{ resp := extensions_ssh.ServerMessage{
Message: &extensions_ssh.ServerMessage_AuthResponse{ Message: &extensions_ssh.ServerMessage_AuthResponse{
AuthResponse: &extensions_ssh.AuthenticationResponse{ AuthResponse: &extensions_ssh.AuthenticationResponse{
@ -291,19 +279,12 @@ func (a *Authorize) ManageStream(
sendC <- &resp sendC <- &resp
} }
case "keyboard-interactive": case "keyboard-interactive":
opts := a.currentOptions.Load() route := a.getSSHRouteForHostname(state.Hostname)
var route *config.Policy
for r := range opts.GetAllPolicies() {
if r.From == "ssh://"+strings.TrimSuffix(strings.Join([]string{state.Hostname, opts.SSHHostname}, "."), ".") {
route = r
break
}
}
if route == nil { if route == nil {
return fmt.Errorf("invalid route") return fmt.Errorf("invalid route")
} }
// sessionState := a.state.Load()
opts := a.currentConfig.Load().Options
idp, err := opts.GetIdentityProviderForPolicy(route) idp, err := opts.GetIdentityProviderForPolicy(route)
if err != nil { if err != nil {
return err return err
@ -357,13 +338,20 @@ func (a *Authorize) ManageStream(
return return
} }
s := sessions.NewState(idp.Id) s := sessions.NewState(idp.Id)
err = claims.Claims.Claims(&s) claims.Claims.Claims(&s) // XXX
s.ID, err = getSessionIDForSSH(state.PublicKey)
if err != nil { if err != nil {
errC <- fmt.Errorf("error unmarshaling session state: %w", err) errC <- err
return return
} }
fmt.Println(token) fmt.Println(token)
deviceAuthSuccess.Store(true) err = a.PersistSession(ctx, s, claims, token)
if err != nil {
fmt.Println("error from PersistSession:", err)
errC <- fmt.Errorf("error persisting session: %w", err)
return
}
sessionState.Store(s)
close(deviceAuthDone) close(deviceAuthDone)
}() }()
} }
@ -380,7 +368,7 @@ func (a *Authorize) ManageStream(
case <-deviceAuthDone: case <-deviceAuthDone:
case <-ctx.Done(): case <-ctx.Done():
} }
if deviceAuthSuccess.Load() { if sessionState.Load() != nil {
state.MethodsAuthenticated = append(state.MethodsAuthenticated, "keyboard-interactive") state.MethodsAuthenticated = append(state.MethodsAuthenticated, "keyboard-interactive")
} else { } else {
resp := extensions_ssh.ServerMessage{ resp := extensions_ssh.ServerMessage{
@ -393,78 +381,24 @@ func (a *Authorize) ManageStream(
}, },
} }
sendC <- &resp sendC <- &resp
// retryReq := extensions_ssh.KeyboardInteractiveInfoPrompts{
// Name: "",
// Instruction: "Login not successful yet, try again",
// Prompts: []*extensions_ssh.KeyboardInteractiveInfoPrompts_Prompt{
// // {},
// },
// }
// infoReqAny, _ := anypb.New(&retryReq)
// resp := extensions_ssh.ServerMessage{
// Message: &extensions_ssh.ServerMessage_AuthResponse{
// AuthResponse: &extensions_ssh.AuthenticationResponse{
// Response: &extensions_ssh.AuthenticationResponse_InfoRequest{
// InfoRequest: &extensions_ssh.InfoRequest{
// Method: "keyboard-interactive",
// Request: infoReqAny,
// },
// },
// },
// },
// }
// sendC <- &resp
continue continue
} }
if slices.Contains(state.MethodsAuthenticated, "publickey") { if slices.Contains(state.MethodsAuthenticated, "publickey") {
pkData, _ := anypb.New(&extensions_ssh.PublicKeyAllowResponse{ // Perform authorize check for this route
PublicKey: state.PublicKey, req, err := a.getEvaluatorRequestFromSSHAuthRequest(&state)
Permissions: &extensions_ssh.Permissions{ if err != nil {
PermitPortForwarding: true, return err
PermitAgentForwarding: true, }
PermitX11Forwarding: true, res, err := a.evaluate(ctx, req, sessionState.Load())
PermitPty: true, if err != nil {
PermitUserRc: true, return err
ValidBefore: timestamppb.New(time.Now().Add(-1 * time.Minute)), }
ValidAfter: timestamppb.New(time.Now().Add(12 * time.Hour)), sendC <- handleEvaluatorResponseForSSH(res, &state)
},
}) if res.Allow.Value && !res.Deny.Value {
sessionRecordingExt, _ := anypb.New(&extensions_session_recording.UpstreamTargetExtensionConfig{ a.startContinuousAuthorization(ctx, errC, req, sessionState.Load().ID)
RecordingName: fmt.Sprintf("session-%s-at-%s-%d.cast", state.Username, state.Hostname, time.Now().UnixNano()),
Format: extensions_session_recording.Format_AsciicastFormat,
})
authResponse := extensions_ssh.ServerMessage{
Message: &extensions_ssh.ServerMessage_AuthResponse{
AuthResponse: &extensions_ssh.AuthenticationResponse{
Response: &extensions_ssh.AuthenticationResponse_Allow{
Allow: &extensions_ssh.AllowResponse{
Username: state.Username,
Target: &extensions_ssh.AllowResponse_Upstream{
Upstream: &extensions_ssh.UpstreamTarget{
Hostname: state.Hostname,
AllowedMethods: []*extensions_ssh.AllowedMethod{
{
Method: "publickey",
MethodData: pkData,
},
{
Method: "keyboard-interactive",
},
},
Extensions: []*corev3.TypedExtensionConfig{
{
TypedConfig: sessionRecordingExt,
},
},
},
},
},
},
},
},
} }
sendC <- &authResponse
} else { } else {
resp := extensions_ssh.ServerMessage{ resp := extensions_ssh.ServerMessage{
Message: &extensions_ssh.ServerMessage_AuthResponse{ Message: &extensions_ssh.ServerMessage_AuthResponse{
@ -485,8 +419,218 @@ func (a *Authorize) ManageStream(
} }
} }
} }
}
return eg.Wait() func (a *Authorize) getSSHRouteForHostname(hostname string) *config.Policy {
opts := a.currentConfig.Load().Options
from := "ssh://" + strings.TrimSuffix(strings.Join([]string{hostname, opts.SSHHostname}, "."), ".")
for r := range opts.GetAllPolicies() {
if r.From == from {
return r
}
}
return nil
}
func (a *Authorize) GetPomeriumSession(
ctx context.Context, publicKey []byte,
) (*session.Session, error) {
sessionID, err := getSessionIDForSSH(publicKey)
if err != nil {
return nil, err
}
fmt.Println("session ID:", sessionID) // XXX
session, err := session.Get(ctx, a.GetDataBrokerServiceClient(), sessionID)
if err != nil {
if st, ok := status.FromError(err); ok && st.Code() == codes.NotFound {
return nil, nil
}
return nil, err
}
return session, nil
}
func getSessionIDForSSH(publicKey []byte) (string, error) {
// XXX: get the fingerprint from Envoy rather than computing it here
k, err := gossh.ParsePublicKey(publicKey)
if err != nil {
return "", fmt.Errorf("couldn't parse ssh key: %w", err)
}
return "sshkey-" + gossh.FingerprintSHA256(k), nil
}
func (a *Authorize) getEvaluatorRequestFromSSHAuthRequest(
state *StreamState,
) (*evaluator.Request, error) {
sessionID, err := getSessionIDForSSH(state.PublicKey)
if err != nil {
return nil, err
}
route := a.getSSHRouteForHostname(state.Hostname)
if route == nil {
return nil, fmt.Errorf("no route found for hostname %q", state.Hostname)
}
req := &evaluator.Request{
IsInternal: false,
HTTP: evaluator.RequestHTTP{
Hostname: route.From, // XXX: this is not quite right
// IP: ? // TODO
},
Session: evaluator.RequestSession{
ID: sessionID,
},
Policy: route,
}
return req, nil
}
func handleEvaluatorResponseForSSH(
result *evaluator.Result, state *StreamState,
) *extensions_ssh.ServerMessage {
// fmt.Printf(" *** evaluator result: %+v\n", result)
// TODO: ideally there would be a way to keep this in sync with the logic in check_response.go
allow := result.Allow.Value && !result.Deny.Value
if allow {
pkData, _ := anypb.New(publicKeyAllowResponse(state.PublicKey))
sessionRecordingExt, _ := anypb.New(&extensions_session_recording.UpstreamTargetExtensionConfig{
RecordingName: fmt.Sprintf("session-%s-at-%s-%d.cast", state.Username, state.Hostname, time.Now().UnixNano()),
Format: extensions_session_recording.Format_AsciicastFormat,
})
return &extensions_ssh.ServerMessage{
Message: &extensions_ssh.ServerMessage_AuthResponse{
AuthResponse: &extensions_ssh.AuthenticationResponse{
Response: &extensions_ssh.AuthenticationResponse_Allow{
Allow: &extensions_ssh.AllowResponse{
Username: state.Username,
Target: &extensions_ssh.AllowResponse_Upstream{
Upstream: &extensions_ssh.UpstreamTarget{
Hostname: state.Hostname,
AllowedMethods: []*extensions_ssh.AllowedMethod{
{
Method: "publickey",
MethodData: pkData,
},
{
Method: "keyboard-interactive",
},
},
Extensions: []*corev3.TypedExtensionConfig{
{
TypedConfig: sessionRecordingExt,
},
},
},
},
},
},
},
},
}
}
// XXX: do we want to send an equivalent to the "show error details" output
// in the case of a deny result?
// XXX: this is not quite right -- needs to exactly match the last list of methods
methods := []string{"publickey"}
if slices.Contains(state.MethodsAuthenticated, "keyboard-interactive") {
methods = append(methods, "keyboard-interactive")
}
return &extensions_ssh.ServerMessage{
Message: &extensions_ssh.ServerMessage_AuthResponse{
AuthResponse: &extensions_ssh.AuthenticationResponse{
Response: &extensions_ssh.AuthenticationResponse_Deny{
Deny: &extensions_ssh.DenyResponse{
Methods: methods,
},
},
},
},
}
}
func publicKeyAllowResponse(publicKey []byte) *extensions_ssh.PublicKeyAllowResponse {
return &extensions_ssh.PublicKeyAllowResponse{
PublicKey: publicKey,
Permissions: &extensions_ssh.Permissions{
PermitPortForwarding: true,
PermitAgentForwarding: true,
PermitX11Forwarding: true,
PermitPty: true,
PermitUserRc: true,
ValidBefore: timestamppb.New(time.Now().Add(-1 * time.Minute)),
// XXX: tie this to Pomerium session lifetime?
ValidAfter: timestamppb.New(time.Now().Add(12 * time.Hour)),
},
}
}
// PersistSession stores session and user data in the databroker.
func (a *Authorize) PersistSession(
ctx context.Context,
sessionState *sessions.State, // XXX: consider not using this struct
claims identity.SessionClaims,
accessToken *oauth2.Token,
) error {
now := time.Now()
sessionLifetime := a.currentConfig.Load().Options.CookieExpire
sessionExpiry := timestamppb.New(now.Add(sessionLifetime))
sess := &session.Session{
Id: sessionState.ID,
UserId: sessionState.UserID(),
IssuedAt: timestamppb.New(now),
AccessedAt: timestamppb.New(now),
ExpiresAt: sessionExpiry,
OauthToken: manager.ToOAuthToken(accessToken),
Audience: sessionState.Audience,
}
sess.SetRawIDToken(claims.RawIDToken)
sess.AddClaims(claims.Flatten())
// XXX: do we need to create a user record too?
// compare with Stateful.PersistSession()
res, err := session.Put(ctx, a.GetDataBrokerServiceClient(), sess)
if err != nil {
return err
}
sessionState.DatabrokerServerVersion = res.GetServerVersion()
sessionState.DatabrokerRecordVersion = res.GetRecord().GetVersion()
return nil
}
func (a *Authorize) startContinuousAuthorization(
ctx context.Context,
errC chan<- error,
req *evaluator.Request,
sessionID string,
) {
recheck := func() {
// XXX: probably want to log the results of this evaluation only if it changes
res, _ := a.evaluate(ctx, req, &sessions.State{ID: sessionID})
if !res.Allow.Value || res.Deny.Value {
errC <- fmt.Errorf("no longer authorized")
}
}
ticker := time.NewTicker(time.Second)
go func() {
for {
select {
case <-ticker.C:
recheck()
case <-ctx.Done():
ticker.Stop()
return
}
}
}()
} }
// See RFC 4254, section 5.1. // See RFC 4254, section 5.1.
@ -653,7 +797,7 @@ func (a *Authorize) ServeChannel(
switch msg.Request { switch msg.Request {
case "pty-req": case "pty-req":
opts := a.currentOptions.Load() opts := a.currentConfig.Load().Options
var routes []string var routes []string
for r := range opts.GetAllPolicies() { for r := range opts.GetAllPolicies() {
if strings.HasPrefix(r.From, "ssh://") { if strings.HasPrefix(r.From, "ssh://") {

4
go.mod
View file

@ -18,8 +18,6 @@ require (
github.com/charmbracelet/bubbles v0.20.0 github.com/charmbracelet/bubbles v0.20.0
github.com/charmbracelet/bubbletea v1.3.3 github.com/charmbracelet/bubbletea v1.3.3
github.com/charmbracelet/lipgloss v1.0.0 github.com/charmbracelet/lipgloss v1.0.0
github.com/charmbracelet/x/ansi v0.8.0
github.com/charmbracelet/x/term v0.2.1
github.com/cloudflare/circl v1.6.0 github.com/cloudflare/circl v1.6.0
github.com/cncf/xds/go v0.0.0-20241223141626-cff3c89139a3 github.com/cncf/xds/go v0.0.0-20241223141626-cff3c89139a3
github.com/coreos/go-oidc/v3 v3.12.0 github.com/coreos/go-oidc/v3 v3.12.0
@ -145,6 +143,8 @@ require (
github.com/beorn7/perks v1.0.1 // indirect github.com/beorn7/perks v1.0.1 // indirect
github.com/caddyserver/zerossl v0.1.3 // indirect github.com/caddyserver/zerossl v0.1.3 // indirect
github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect
github.com/charmbracelet/x/ansi v0.8.0 // indirect
github.com/charmbracelet/x/term v0.2.1 // indirect
github.com/containerd/log v0.1.0 // indirect github.com/containerd/log v0.1.0 // indirect
github.com/containerd/platforms v0.2.1 // indirect github.com/containerd/platforms v0.2.1 // indirect
github.com/cpuguy83/dockercfg v0.3.2 // indirect github.com/cpuguy83/dockercfg v0.3.2 // indirect

View file

@ -5,11 +5,11 @@ package files
import _ "embed" // embed import _ "embed" // embed
//go:embed envoy-darwin-arm64 //go:embed envoy
var rawBinary []byte var rawBinary []byte
//go:embed envoy-darwin-arm64.sha256 //go:embed envoy.sha256
var rawChecksum string var rawChecksum string
//go:embed envoy-darwin-arm64.version //go:embed envoy.version
var rawVersion string var rawVersion string

File diff suppressed because it is too large Load diff

View file

@ -167,7 +167,7 @@ message Policy {
string remediation = 9; string remediation = 9;
} }
// Next ID: 140. // Next ID: 141.
message Settings { message Settings {
message Certificate { message Certificate {
bytes cert_bytes = 3; bytes cert_bytes = 3;
@ -201,14 +201,15 @@ message Settings {
optional string cookie_secret = 17; optional string cookie_secret = 17;
optional string cookie_domain = 18; optional string cookie_domain = 18;
// optional bool cookie_secure = 19; // optional bool cookie_secure = 19;
optional bool cookie_http_only = 20; optional bool cookie_http_only = 20;
optional google.protobuf.Duration cookie_expire = 21; optional google.protobuf.Duration cookie_expire = 21;
optional string cookie_same_site = 113; optional string cookie_same_site = 113;
optional string idp_client_id = 22; optional string idp_client_id = 22;
optional string idp_client_secret = 23; optional string idp_client_secret = 23;
optional string idp_provider = 24; optional string idp_provider = 24;
optional string idp_provider_url = 25; optional string idp_provider_url = 25;
repeated string scopes = 26; optional StringList idp_access_token_allowed_audiences = 137;
repeated string scopes = 26;
// optional string idp_service_account = 27; // optional string idp_service_account = 27;
// optional google.protobuf.Duration idp_refresh_directory_timeout = 28; // optional google.protobuf.Duration idp_refresh_directory_timeout = 28;
// optional google.protobuf.Duration idp_refresh_directory_interval = 29; // optional google.protobuf.Duration idp_refresh_directory_interval = 29;
@ -222,7 +223,9 @@ message Settings {
map<string, string> set_response_headers = 69; map<string, string> set_response_headers = 69;
// repeated string jwt_claims_headers = 37; // repeated string jwt_claims_headers = 37;
map<string, string> jwt_claims_headers = 63; map<string, string> jwt_claims_headers = 63;
optional IssuerFormat jwt_issuer_format = 139;
repeated string jwt_groups_filter = 119; repeated string jwt_groups_filter = 119;
optional BearerTokenFormat bearer_token_format = 138;
optional google.protobuf.Duration default_upstream_timeout = 39; optional google.protobuf.Duration default_upstream_timeout = 39;
optional string metrics_address = 40; optional string metrics_address = 40;
optional string metrics_basic_auth = 64; optional string metrics_basic_auth = 64;
@ -288,7 +291,7 @@ message Settings {
optional bool pass_identity_headers = 117; optional bool pass_identity_headers = 117;
map<string, bool> runtime_flags = 118; map<string, bool> runtime_flags = 118;
optional uint32 http3_advertise_port = 136; optional uint32 http3_advertise_port = 136;
optional string device_auth_client_type = 137; optional string device_auth_client_type = 140;
} }
message DownstreamMtlsSettings { message DownstreamMtlsSettings {

View file

@ -63,7 +63,6 @@ func (mp MockProvider) SignIn(_ http.ResponseWriter, _ *http.Request, _ string)
return mp.SignInError return mp.SignInError
} }
<<<<<<< HEAD
// VerifyAccessToken verifies an access token. // VerifyAccessToken verifies an access token.
func (mp MockProvider) VerifyAccessToken(_ context.Context, _ string) (claims map[string]any, err error) { func (mp MockProvider) VerifyAccessToken(_ context.Context, _ string) (claims map[string]any, err error) {
return nil, fmt.Errorf("VerifyAccessToken not implemented") return nil, fmt.Errorf("VerifyAccessToken not implemented")
@ -72,8 +71,8 @@ func (mp MockProvider) VerifyAccessToken(_ context.Context, _ string) (claims ma
// VerifyIdentityToken verifies an identity token. // VerifyIdentityToken verifies an identity token.
func (mp MockProvider) VerifyIdentityToken(_ context.Context, _ string) (claims map[string]any, err error) { func (mp MockProvider) VerifyIdentityToken(_ context.Context, _ string) (claims map[string]any, err error) {
return nil, fmt.Errorf("VerifyIdentityToken not implemented") return nil, fmt.Errorf("VerifyIdentityToken not implemented")
||||||| 229ef72e5 }
=======
// DeviceAccessToken implements Authenticator. // DeviceAccessToken implements Authenticator.
func (mp MockProvider) DeviceAccessToken(ctx context.Context, r *oauth2.DeviceAuthResponse, state identity.State) (*oauth2.Token, error) { func (mp MockProvider) DeviceAccessToken(ctx context.Context, r *oauth2.DeviceAuthResponse, state identity.State) (*oauth2.Token, error) {
return &mp.DeviceAccessTokenResponse, mp.DeviceAccessTokenError return &mp.DeviceAccessTokenResponse, mp.DeviceAccessTokenError
@ -82,5 +81,4 @@ func (mp MockProvider) DeviceAccessToken(ctx context.Context, r *oauth2.DeviceAu
// DeviceAuth implements Authenticator. // DeviceAuth implements Authenticator.
func (mp MockProvider) DeviceAuth(_ context.Context) (*oauth2.DeviceAuthResponse, error) { func (mp MockProvider) DeviceAuth(_ context.Context) (*oauth2.DeviceAuthResponse, error) {
return &mp.DeviceAuthResponse, mp.DeviceAuthError return &mp.DeviceAuthResponse, mp.DeviceAuthError
>>>>>>> kralicky/ssh-demo
} }

View file

@ -168,7 +168,7 @@ func (p *Proxy) ProgrammaticLogin(w http.ResponseWriter, r *http.Request) error
func (p *Proxy) DeviceAuthLogin(w http.ResponseWriter, r *http.Request) error { func (p *Proxy) DeviceAuthLogin(w http.ResponseWriter, r *http.Request) error {
state := p.state.Load() state := p.state.Load()
options := p.currentOptions.Load() options := p.currentConfig.Load().Options
params := url.Values{} params := url.Values{}
routeUri := urlutil.GetAbsoluteURL(r) routeUri := urlutil.GetAbsoluteURL(r)