mirror of
https://github.com/pomerium/pomerium.git
synced 2025-05-24 06:27:17 +02:00
implement session creation
This commit is contained in:
parent
24b35e26a5
commit
b95ad4dbc3
15 changed files with 646 additions and 148 deletions
|
@ -6,25 +6,11 @@ import (
|
||||||
|
|
||||||
"github.com/pomerium/pomerium/internal/httputil"
|
"github.com/pomerium/pomerium/internal/httputil"
|
||||||
"github.com/pomerium/pomerium/internal/log"
|
"github.com/pomerium/pomerium/internal/log"
|
||||||
|
"github.com/pomerium/pomerium/pkg/authenticateapi"
|
||||||
)
|
)
|
||||||
|
|
||||||
type VerifyAccessTokenRequest struct {
|
|
||||||
AccessToken string `json:"accessToken"`
|
|
||||||
IdentityProviderID string `json:"identityProviderId,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type VerifyIdentityTokenRequest struct {
|
|
||||||
IdentityToken string `json:"identityToken"`
|
|
||||||
IdentityProviderID string `json:"identityProviderId,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type VerifyTokenResponse struct {
|
|
||||||
Valid bool `json:"valid"`
|
|
||||||
Claims map[string]any `json:"claims,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *Authenticate) verifyAccessToken(w http.ResponseWriter, r *http.Request) error {
|
func (a *Authenticate) verifyAccessToken(w http.ResponseWriter, r *http.Request) error {
|
||||||
var req VerifyAccessTokenRequest
|
var req authenticateapi.VerifyAccessTokenRequest
|
||||||
err := json.NewDecoder(r.Body).Decode(&req)
|
err := json.NewDecoder(r.Body).Decode(&req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return httputil.NewError(http.StatusBadRequest, err)
|
return httputil.NewError(http.StatusBadRequest, err)
|
||||||
|
@ -35,7 +21,7 @@ func (a *Authenticate) verifyAccessToken(w http.ResponseWriter, r *http.Request)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
var res VerifyTokenResponse
|
var res authenticateapi.VerifyTokenResponse
|
||||||
claims, err := authenticator.VerifyAccessToken(r.Context(), req.AccessToken)
|
claims, err := authenticator.VerifyAccessToken(r.Context(), req.AccessToken)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
res.Valid = true
|
res.Valid = true
|
||||||
|
@ -57,7 +43,7 @@ func (a *Authenticate) verifyAccessToken(w http.ResponseWriter, r *http.Request)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Authenticate) verifyIdentityToken(w http.ResponseWriter, r *http.Request) error {
|
func (a *Authenticate) verifyIdentityToken(w http.ResponseWriter, r *http.Request) error {
|
||||||
var req VerifyIdentityTokenRequest
|
var req authenticateapi.VerifyIdentityTokenRequest
|
||||||
err := json.NewDecoder(r.Body).Decode(&req)
|
err := json.NewDecoder(r.Body).Decode(&req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return httputil.NewError(http.StatusBadRequest, err)
|
return httputil.NewError(http.StatusBadRequest, err)
|
||||||
|
@ -68,7 +54,7 @@ func (a *Authenticate) verifyIdentityToken(w http.ResponseWriter, r *http.Reques
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
var res VerifyTokenResponse
|
var res authenticateapi.VerifyTokenResponse
|
||||||
claims, err := authenticator.VerifyIdentityToken(r.Context(), req.IdentityToken)
|
claims, err := authenticator.VerifyIdentityToken(r.Context(), req.IdentityToken)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
res.Valid = true
|
res.Valid = true
|
||||||
|
|
|
@ -29,7 +29,7 @@ import (
|
||||||
type Authorize struct {
|
type Authorize struct {
|
||||||
state *atomicutil.Value[*authorizeState]
|
state *atomicutil.Value[*authorizeState]
|
||||||
store *store.Store
|
store *store.Store
|
||||||
currentOptions *atomicutil.Value[*config.Options]
|
currentConfig *atomicutil.Value[*config.Config]
|
||||||
accessTracker *AccessTracker
|
accessTracker *AccessTracker
|
||||||
globalCache storage.Cache
|
globalCache storage.Cache
|
||||||
groupsCacheWarmer *cacheWarmer
|
groupsCacheWarmer *cacheWarmer
|
||||||
|
@ -43,7 +43,7 @@ func New(ctx context.Context, cfg *config.Config) (*Authorize, error) {
|
||||||
tracerProvider := trace.NewTracerProvider(ctx, "Authorize")
|
tracerProvider := trace.NewTracerProvider(ctx, "Authorize")
|
||||||
tracer := tracerProvider.Tracer(trace.PomeriumCoreTracer)
|
tracer := tracerProvider.Tracer(trace.PomeriumCoreTracer)
|
||||||
a := &Authorize{
|
a := &Authorize{
|
||||||
currentOptions: config.NewAtomicOptions(),
|
currentConfig: atomicutil.NewValue(&config.Config{Options: new(config.Options)}),
|
||||||
store: store.New(),
|
store: store.New(),
|
||||||
globalCache: storage.NewGlobalCache(time.Minute),
|
globalCache: storage.NewGlobalCache(time.Minute),
|
||||||
tracerProvider: tracerProvider,
|
tracerProvider: tracerProvider,
|
||||||
|
@ -155,7 +155,7 @@ func newPolicyEvaluator(
|
||||||
// OnConfigChange updates internal structures based on config.Options
|
// OnConfigChange updates internal structures based on config.Options
|
||||||
func (a *Authorize) OnConfigChange(ctx context.Context, cfg *config.Config) {
|
func (a *Authorize) OnConfigChange(ctx context.Context, cfg *config.Config) {
|
||||||
currentState := a.state.Load()
|
currentState := a.state.Load()
|
||||||
a.currentOptions.Store(cfg.Options)
|
a.currentConfig.Store(cfg)
|
||||||
if newState, err := newAuthorizeStateFromConfig(ctx, a.tracerProvider, cfg, a.store, currentState.evaluator); err != nil {
|
if newState, err := newAuthorizeStateFromConfig(ctx, a.tracerProvider, cfg, a.store, currentState.evaluator); err != nil {
|
||||||
log.Ctx(ctx).Error().Err(err).Msg("authorize: error updating state")
|
log.Ctx(ctx).Error().Err(err).Msg("authorize: error updating state")
|
||||||
} else {
|
} else {
|
||||||
|
|
|
@ -186,7 +186,7 @@ func (a *Authorize) deniedResponse(
|
||||||
Err: errors.New(reason),
|
Err: errors.New(reason),
|
||||||
DebugURL: debugEndpoint,
|
DebugURL: debugEndpoint,
|
||||||
RequestID: requestid.FromContext(ctx),
|
RequestID: requestid.FromContext(ctx),
|
||||||
BrandingOptions: a.currentOptions.Load().BrandingOptions,
|
BrandingOptions: a.currentConfig.Load().Options.BrandingOptions,
|
||||||
}
|
}
|
||||||
httpErr.ErrorResponse(ctx, w, r)
|
httpErr.ErrorResponse(ctx, w, r)
|
||||||
|
|
||||||
|
@ -213,7 +213,7 @@ func (a *Authorize) requireLoginResponse(
|
||||||
in *envoy_service_auth_v3.CheckRequest,
|
in *envoy_service_auth_v3.CheckRequest,
|
||||||
request *evaluator.Request,
|
request *evaluator.Request,
|
||||||
) (*envoy_service_auth_v3.CheckResponse, error) {
|
) (*envoy_service_auth_v3.CheckResponse, error) {
|
||||||
options := a.currentOptions.Load()
|
options := a.currentConfig.Load().Options
|
||||||
state := a.state.Load()
|
state := a.state.Load()
|
||||||
|
|
||||||
if !a.shouldRedirect(in) {
|
if !a.shouldRedirect(in) {
|
||||||
|
@ -251,7 +251,7 @@ func (a *Authorize) requireWebAuthnResponse(
|
||||||
request *evaluator.Request,
|
request *evaluator.Request,
|
||||||
result *evaluator.Result,
|
result *evaluator.Result,
|
||||||
) (*envoy_service_auth_v3.CheckResponse, error) {
|
) (*envoy_service_auth_v3.CheckResponse, error) {
|
||||||
opts := a.currentOptions.Load()
|
opts := a.currentConfig.Load().Options
|
||||||
state := a.state.Load()
|
state := a.state.Load()
|
||||||
|
|
||||||
// always assume https scheme
|
// always assume https scheme
|
||||||
|
@ -327,7 +327,7 @@ func toEnvoyHeaders(headers http.Header) []*envoy_config_core_v3.HeaderValueOpti
|
||||||
// userInfoEndpointURL returns the user info endpoint url which can be used to debug the user's
|
// userInfoEndpointURL returns the user info endpoint url which can be used to debug the user's
|
||||||
// session that lives on the authenticate service.
|
// session that lives on the authenticate service.
|
||||||
func (a *Authorize) userInfoEndpointURL(in *envoy_service_auth_v3.CheckRequest) (*url.URL, error) {
|
func (a *Authorize) userInfoEndpointURL(in *envoy_service_auth_v3.CheckRequest) (*url.URL, error) {
|
||||||
opts := a.currentOptions.Load()
|
opts := a.currentConfig.Load().Options
|
||||||
authenticateURL, err := opts.GetAuthenticateURL()
|
authenticateURL, err := opts.GetAuthenticateURL()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|
|
@ -127,8 +127,9 @@ func TestAuthorize_okResponse(t *testing.T) {
|
||||||
}},
|
}},
|
||||||
JWTClaimsHeaders: config.NewJWTClaimHeaders("email"),
|
JWTClaimsHeaders: config.NewJWTClaimHeaders("email"),
|
||||||
}
|
}
|
||||||
a := &Authorize{currentOptions: config.NewAtomicOptions(), state: atomicutil.NewValue(new(authorizeState))}
|
a := &Authorize{currentConfig: atomicutil.NewValue(&config.Config{
|
||||||
a.currentOptions.Store(opt)
|
Options: opt,
|
||||||
|
}), state: atomicutil.NewValue(new(authorizeState))}
|
||||||
a.store = store.New()
|
a.store = store.New()
|
||||||
pe, err := newPolicyEvaluator(context.Background(), opt, a.store, nil)
|
pe, err := newPolicyEvaluator(context.Background(), opt, a.store, nil)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
@ -183,15 +184,16 @@ func TestAuthorize_okResponse(t *testing.T) {
|
||||||
func TestAuthorize_deniedResponse(t *testing.T) {
|
func TestAuthorize_deniedResponse(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
a := &Authorize{currentOptions: config.NewAtomicOptions(), state: atomicutil.NewValue(new(authorizeState))}
|
a := &Authorize{currentConfig: atomicutil.NewValue(&config.Config{
|
||||||
a.currentOptions.Store(&config.Options{
|
Options: &config.Options{
|
||||||
Policies: []config.Policy{{
|
Policies: []config.Policy{{
|
||||||
From: "https://example.com",
|
From: "https://example.com",
|
||||||
SubPolicies: []config.SubPolicy{{
|
SubPolicies: []config.SubPolicy{{
|
||||||
Rego: []string{"allow = true"},
|
Rego: []string{"allow = true"},
|
||||||
}},
|
}},
|
||||||
}},
|
}},
|
||||||
})
|
},
|
||||||
|
}), state: atomicutil.NewValue(new(authorizeState))}
|
||||||
|
|
||||||
t.Run("json", func(t *testing.T) {
|
t.Run("json", func(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
|
@ -3,15 +3,17 @@ package authorize
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
|
||||||
|
"google.golang.org/grpc"
|
||||||
|
|
||||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||||
"github.com/pomerium/pomerium/pkg/grpc/session"
|
"github.com/pomerium/pomerium/pkg/grpc/session"
|
||||||
"github.com/pomerium/pomerium/pkg/grpc/user"
|
"github.com/pomerium/pomerium/pkg/grpc/user"
|
||||||
"github.com/pomerium/pomerium/pkg/grpcutil"
|
"github.com/pomerium/pomerium/pkg/grpcutil"
|
||||||
"github.com/pomerium/pomerium/pkg/storage"
|
"github.com/pomerium/pomerium/pkg/storage"
|
||||||
"google.golang.org/grpc"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type sessionOrServiceAccount interface {
|
type sessionOrServiceAccount interface {
|
||||||
|
GetId() string
|
||||||
GetUserId() string
|
GetUserId() string
|
||||||
Validate() error
|
Validate() error
|
||||||
}
|
}
|
||||||
|
|
|
@ -3,6 +3,7 @@ package authorize
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/pem"
|
"encoding/pem"
|
||||||
|
"errors"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
|
@ -21,6 +22,7 @@ import (
|
||||||
"github.com/pomerium/pomerium/internal/sessions"
|
"github.com/pomerium/pomerium/internal/sessions"
|
||||||
"github.com/pomerium/pomerium/internal/urlutil"
|
"github.com/pomerium/pomerium/internal/urlutil"
|
||||||
"github.com/pomerium/pomerium/pkg/contextutil"
|
"github.com/pomerium/pomerium/pkg/contextutil"
|
||||||
|
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||||
"github.com/pomerium/pomerium/pkg/grpc/user"
|
"github.com/pomerium/pomerium/pkg/grpc/user"
|
||||||
"github.com/pomerium/pomerium/pkg/storage"
|
"github.com/pomerium/pomerium/pkg/storage"
|
||||||
"github.com/pomerium/pomerium/pkg/telemetry/requestid"
|
"github.com/pomerium/pomerium/pkg/telemetry/requestid"
|
||||||
|
@ -44,31 +46,25 @@ func (a *Authorize) Check(ctx context.Context, in *envoy_service_auth_v3.CheckRe
|
||||||
requestID := requestid.FromHTTPHeader(hreq.Header)
|
requestID := requestid.FromHTTPHeader(hreq.Header)
|
||||||
ctx = requestid.WithValue(ctx, requestID)
|
ctx = requestid.WithValue(ctx, requestID)
|
||||||
|
|
||||||
sessionState, _ := state.sessionStore.LoadSessionStateAndCheckIDP(hreq)
|
req, err := a.getEvaluatorRequestFromCheckRequest(ctx, in)
|
||||||
|
|
||||||
var s sessionOrServiceAccount
|
|
||||||
var u *user.User
|
|
||||||
var err error
|
|
||||||
if sessionState != nil {
|
|
||||||
s, err = a.getDataBrokerSessionOrServiceAccount(ctx, sessionState.ID, sessionState.DatabrokerRecordVersion)
|
|
||||||
if status.Code(err) == codes.Unavailable {
|
|
||||||
log.Ctx(ctx).Debug().Str("request-id", requestID).Err(err).Msg("temporary error checking authorization: data broker unavailable")
|
|
||||||
return nil, err
|
|
||||||
} else if err != nil {
|
|
||||||
log.Ctx(ctx).Info().Err(err).Str("request-id", requestID).Msg("clearing session due to missing or invalid session or service account")
|
|
||||||
sessionState = nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if sessionState != nil && s != nil {
|
|
||||||
u, _ = a.getDataBrokerUser(ctx, s.GetUserId()) // ignore any missing user error
|
|
||||||
}
|
|
||||||
|
|
||||||
req, err := a.getEvaluatorRequestFromCheckRequest(ctx, in, sessionState)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Ctx(ctx).Error().Err(err).Str("request-id", requestID).Msg("error building evaluator request")
|
log.Ctx(ctx).Error().Err(err).Str("request-id", requestID).Msg("error building evaluator request")
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// load the session
|
||||||
|
s, err := a.loadSession(ctx, hreq, req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// if there's a session or service account, load the user
|
||||||
|
var u *user.User
|
||||||
|
if s != nil {
|
||||||
|
req.Session.ID = s.GetId()
|
||||||
|
u, _ = a.getDataBrokerUser(ctx, s.GetUserId()) // ignore any missing user error
|
||||||
|
}
|
||||||
|
|
||||||
res, err := state.evaluator.Evaluate(ctx, req)
|
res, err := state.evaluator.Evaluate(ctx, req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Ctx(ctx).Error().Err(err).Str("request-id", requestID).Msg("error during OPA evaluation")
|
log.Ctx(ctx).Error().Err(err).Str("request-id", requestID).Msg("error during OPA evaluation")
|
||||||
|
@ -88,10 +84,65 @@ func (a *Authorize) Check(ctx context.Context, in *envoy_service_auth_v3.CheckRe
|
||||||
return resp, err
|
return resp, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a *Authorize) loadSession(
|
||||||
|
ctx context.Context,
|
||||||
|
hreq *http.Request,
|
||||||
|
req *evaluator.Request,
|
||||||
|
) (s sessionOrServiceAccount, err error) {
|
||||||
|
requestID := requestid.FromHTTPHeader(hreq.Header)
|
||||||
|
|
||||||
|
// attempt to create a session from an incoming idp token
|
||||||
|
s, err = config.NewIncomingIDPTokenSessionCreator(
|
||||||
|
func(ctx context.Context, recordType, recordID string) (*databroker.Record, error) {
|
||||||
|
return getDataBrokerRecord(ctx, recordType, recordID, 0)
|
||||||
|
},
|
||||||
|
func(ctx context.Context, records []*databroker.Record) error {
|
||||||
|
_, err := a.state.Load().dataBrokerClient.Put(ctx, &databroker.PutRequest{
|
||||||
|
Records: records,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
// invalidate cache
|
||||||
|
for _, record := range records {
|
||||||
|
storage.GetQuerier(ctx).InvalidateCache(ctx, &databroker.QueryRequest{
|
||||||
|
Type: record.GetType(),
|
||||||
|
Query: record.GetId(),
|
||||||
|
Limit: 1,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
).CreateSession(ctx, a.currentConfig.Load(), req.Policy, hreq)
|
||||||
|
if err == nil {
|
||||||
|
return s, nil
|
||||||
|
} else if !errors.Is(err, sessions.ErrNoSessionFound) {
|
||||||
|
log.Ctx(ctx).Info().
|
||||||
|
Str("request-id", requestID).
|
||||||
|
Err(err).
|
||||||
|
Msg("error creating session for incoming idp token")
|
||||||
|
}
|
||||||
|
|
||||||
|
sessionState, _ := a.state.Load().sessionStore.LoadSessionStateAndCheckIDP(hreq)
|
||||||
|
if sessionState == nil {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
s, err = a.getDataBrokerSessionOrServiceAccount(ctx, sessionState.ID, sessionState.DatabrokerRecordVersion)
|
||||||
|
if status.Code(err) == codes.Unavailable {
|
||||||
|
log.Ctx(ctx).Debug().Str("request-id", requestID).Err(err).Msg("temporary error checking authorization: data broker unavailable")
|
||||||
|
return nil, err
|
||||||
|
} else if err != nil {
|
||||||
|
log.Ctx(ctx).Info().Err(err).Str("request-id", requestID).Msg("clearing session due to missing or invalid session or service account")
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return s, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (a *Authorize) getEvaluatorRequestFromCheckRequest(
|
func (a *Authorize) getEvaluatorRequestFromCheckRequest(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
in *envoy_service_auth_v3.CheckRequest,
|
in *envoy_service_auth_v3.CheckRequest,
|
||||||
sessionState *sessions.State,
|
|
||||||
) (*evaluator.Request, error) {
|
) (*evaluator.Request, error) {
|
||||||
requestURL := getCheckRequestURL(in)
|
requestURL := getCheckRequestURL(in)
|
||||||
attrs := in.GetAttributes()
|
attrs := in.GetAttributes()
|
||||||
|
@ -106,17 +157,12 @@ func (a *Authorize) getEvaluatorRequestFromCheckRequest(
|
||||||
attrs.GetSource().GetAddress().GetSocketAddress().GetAddress(),
|
attrs.GetSource().GetAddress().GetSocketAddress().GetAddress(),
|
||||||
),
|
),
|
||||||
}
|
}
|
||||||
if sessionState != nil {
|
|
||||||
req.Session = evaluator.RequestSession{
|
|
||||||
ID: sessionState.ID,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
req.Policy = a.getMatchingPolicy(envoyconfig.ExtAuthzContextExtensionsRouteID(attrs.GetContextExtensions()))
|
req.Policy = a.getMatchingPolicy(envoyconfig.ExtAuthzContextExtensionsRouteID(attrs.GetContextExtensions()))
|
||||||
return req, nil
|
return req, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Authorize) getMatchingPolicy(routeID uint64) *config.Policy {
|
func (a *Authorize) getMatchingPolicy(routeID uint64) *config.Policy {
|
||||||
options := a.currentOptions.Load()
|
options := a.currentConfig.Load().Options
|
||||||
|
|
||||||
for p := range options.GetAllPolicies() {
|
for p := range options.GetAllPolicies() {
|
||||||
id, _ := p.RouteID()
|
id, _ := p.RouteID()
|
||||||
|
|
|
@ -18,7 +18,6 @@ import (
|
||||||
"github.com/pomerium/pomerium/authorize/evaluator"
|
"github.com/pomerium/pomerium/authorize/evaluator"
|
||||||
"github.com/pomerium/pomerium/config"
|
"github.com/pomerium/pomerium/config"
|
||||||
"github.com/pomerium/pomerium/internal/atomicutil"
|
"github.com/pomerium/pomerium/internal/atomicutil"
|
||||||
"github.com/pomerium/pomerium/internal/sessions"
|
|
||||||
"github.com/pomerium/pomerium/internal/testutil"
|
"github.com/pomerium/pomerium/internal/testutil"
|
||||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||||
"github.com/pomerium/pomerium/pkg/storage"
|
"github.com/pomerium/pomerium/pkg/storage"
|
||||||
|
@ -49,15 +48,16 @@ yE+vPxsiUkvQHdO2fojCkY8jg70jxM+gu59tPDNbw3Uh/2Ij310FgTHsnGQMyA==
|
||||||
-----END CERTIFICATE-----`
|
-----END CERTIFICATE-----`
|
||||||
|
|
||||||
func Test_getEvaluatorRequest(t *testing.T) {
|
func Test_getEvaluatorRequest(t *testing.T) {
|
||||||
a := &Authorize{currentOptions: config.NewAtomicOptions(), state: atomicutil.NewValue(new(authorizeState))}
|
a := &Authorize{currentConfig: atomicutil.NewValue(&config.Config{
|
||||||
a.currentOptions.Store(&config.Options{
|
Options: &config.Options{
|
||||||
Policies: []config.Policy{{
|
Policies: []config.Policy{{
|
||||||
From: "https://example.com",
|
From: "https://example.com",
|
||||||
SubPolicies: []config.SubPolicy{{
|
SubPolicies: []config.SubPolicy{{
|
||||||
Rego: []string{"allow = true"},
|
Rego: []string{"allow = true"},
|
||||||
}},
|
}},
|
||||||
}},
|
}},
|
||||||
})
|
},
|
||||||
|
}), state: atomicutil.NewValue(new(authorizeState))}
|
||||||
|
|
||||||
actual, err := a.getEvaluatorRequestFromCheckRequest(context.Background(),
|
actual, err := a.getEvaluatorRequestFromCheckRequest(context.Background(),
|
||||||
&envoy_service_auth_v3.CheckRequest{
|
&envoy_service_auth_v3.CheckRequest{
|
||||||
|
@ -88,13 +88,10 @@ func Test_getEvaluatorRequest(t *testing.T) {
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
&sessions.State{
|
|
||||||
ID: "SESSION_ID",
|
|
||||||
},
|
|
||||||
)
|
)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
expect := &evaluator.Request{
|
expect := &evaluator.Request{
|
||||||
Policy: &a.currentOptions.Load().Policies[0],
|
Policy: &a.currentConfig.Load().Options.Policies[0],
|
||||||
Session: evaluator.RequestSession{
|
Session: evaluator.RequestSession{
|
||||||
ID: "SESSION_ID",
|
ID: "SESSION_ID",
|
||||||
},
|
},
|
||||||
|
@ -117,15 +114,16 @@ func Test_getEvaluatorRequest(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func Test_getEvaluatorRequestWithPortInHostHeader(t *testing.T) {
|
func Test_getEvaluatorRequestWithPortInHostHeader(t *testing.T) {
|
||||||
a := &Authorize{currentOptions: config.NewAtomicOptions(), state: atomicutil.NewValue(new(authorizeState))}
|
a := &Authorize{currentConfig: atomicutil.NewValue(&config.Config{
|
||||||
a.currentOptions.Store(&config.Options{
|
Options: &config.Options{
|
||||||
Policies: []config.Policy{{
|
Policies: []config.Policy{{
|
||||||
From: "https://example.com",
|
From: "https://example.com",
|
||||||
SubPolicies: []config.SubPolicy{{
|
SubPolicies: []config.SubPolicy{{
|
||||||
Rego: []string{"allow = true"},
|
Rego: []string{"allow = true"},
|
||||||
}},
|
}},
|
||||||
}},
|
}},
|
||||||
})
|
},
|
||||||
|
}), state: atomicutil.NewValue(new(authorizeState))}
|
||||||
|
|
||||||
actual, err := a.getEvaluatorRequestFromCheckRequest(context.Background(),
|
actual, err := a.getEvaluatorRequestFromCheckRequest(context.Background(),
|
||||||
&envoy_service_auth_v3.CheckRequest{
|
&envoy_service_auth_v3.CheckRequest{
|
||||||
|
@ -145,10 +143,10 @@ func Test_getEvaluatorRequestWithPortInHostHeader(t *testing.T) {
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
}, nil)
|
})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
expect := &evaluator.Request{
|
expect := &evaluator.Request{
|
||||||
Policy: &a.currentOptions.Load().Policies[0],
|
Policy: &a.currentConfig.Load().Options.Policies[0],
|
||||||
Session: evaluator.RequestSession{},
|
Session: evaluator.RequestSession{},
|
||||||
HTTP: evaluator.NewRequestHTTP(
|
HTTP: evaluator.NewRequestHTTP(
|
||||||
http.MethodGet,
|
http.MethodGet,
|
||||||
|
|
|
@ -31,7 +31,7 @@ func (a *Authorize) logAuthorizeCheck(
|
||||||
impersonateDetails := a.getImpersonateDetails(ctx, s)
|
impersonateDetails := a.getImpersonateDetails(ctx, s)
|
||||||
|
|
||||||
evt := log.Ctx(ctx).Info().Str("service", "authorize")
|
evt := log.Ctx(ctx).Info().Str("service", "authorize")
|
||||||
fields := a.currentOptions.Load().GetAuthorizeLogFields()
|
fields := a.currentConfig.Load().Options.GetAuthorizeLogFields()
|
||||||
for _, field := range fields {
|
for _, field := range fields {
|
||||||
evt = populateLogEvent(ctx, field, evt, in, s, u, hdrs, impersonateDetails, res)
|
evt = populateLogEvent(ctx, field, evt, in, s, u, hdrs, impersonateDetails, res)
|
||||||
}
|
}
|
||||||
|
|
|
@ -67,7 +67,7 @@ func (bearerTokenFormat *BearerTokenFormat) ToPB() *configpb.BearerTokenFormat {
|
||||||
case BearerTokenFormatIDPIdentityToken:
|
case BearerTokenFormatIDPIdentityToken:
|
||||||
return configpb.BearerTokenFormat_BEARER_TOKEN_FORMAT_IDP_IDENTITY_TOKEN.Enum()
|
return configpb.BearerTokenFormat_BEARER_TOKEN_FORMAT_IDP_IDENTITY_TOKEN.Enum()
|
||||||
default:
|
default:
|
||||||
panic(fmt.Sprintf("unknown bearer token format: %s", bearerTokenFormat))
|
panic(fmt.Sprintf("unknown bearer token format: %v", bearerTokenFormat))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -4,9 +4,10 @@ import (
|
||||||
"errors"
|
"errors"
|
||||||
|
|
||||||
"github.com/mitchellh/mapstructure"
|
"github.com/mitchellh/mapstructure"
|
||||||
"github.com/pomerium/pomerium/config/otelconfig"
|
|
||||||
"github.com/spf13/viper"
|
"github.com/spf13/viper"
|
||||||
"google.golang.org/protobuf/encoding/protojson"
|
"google.golang.org/protobuf/encoding/protojson"
|
||||||
|
|
||||||
|
"github.com/pomerium/pomerium/config/otelconfig"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
@ -37,6 +38,7 @@ var ViperPolicyHooks = viper.DecodeHook(mapstructure.ComposeDecodeHookFunc(
|
||||||
DecodePolicyBase64Hook(),
|
DecodePolicyBase64Hook(),
|
||||||
decodeNullBoolHookFunc(),
|
decodeNullBoolHookFunc(),
|
||||||
decodeJWTClaimHeadersHookFunc(),
|
decodeJWTClaimHeadersHookFunc(),
|
||||||
|
decodeBearerTokenFormatHookFunc(),
|
||||||
decodeCodecTypeHookFunc(),
|
decodeCodecTypeHookFunc(),
|
||||||
decodePPLPolicyHookFunc(),
|
decodePPLPolicyHookFunc(),
|
||||||
decodeSANMatcherHookFunc(),
|
decodeSANMatcherHookFunc(),
|
||||||
|
|
|
@ -5,15 +5,28 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
"google.golang.org/protobuf/types/known/timestamppb"
|
||||||
|
|
||||||
"github.com/pomerium/pomerium/internal/encoding"
|
"github.com/pomerium/pomerium/internal/encoding"
|
||||||
"github.com/pomerium/pomerium/internal/encoding/jws"
|
"github.com/pomerium/pomerium/internal/encoding/jws"
|
||||||
|
"github.com/pomerium/pomerium/internal/httputil"
|
||||||
|
"github.com/pomerium/pomerium/internal/jwtutil"
|
||||||
"github.com/pomerium/pomerium/internal/sessions"
|
"github.com/pomerium/pomerium/internal/sessions"
|
||||||
"github.com/pomerium/pomerium/internal/sessions/cookie"
|
"github.com/pomerium/pomerium/internal/sessions/cookie"
|
||||||
"github.com/pomerium/pomerium/internal/sessions/header"
|
"github.com/pomerium/pomerium/internal/sessions/header"
|
||||||
"github.com/pomerium/pomerium/internal/sessions/queryparam"
|
"github.com/pomerium/pomerium/internal/sessions/queryparam"
|
||||||
"github.com/pomerium/pomerium/internal/urlutil"
|
"github.com/pomerium/pomerium/internal/urlutil"
|
||||||
|
"github.com/pomerium/pomerium/pkg/authenticateapi"
|
||||||
|
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||||
"github.com/pomerium/pomerium/pkg/grpc/session"
|
"github.com/pomerium/pomerium/pkg/grpc/session"
|
||||||
|
"github.com/pomerium/pomerium/pkg/grpc/user"
|
||||||
|
"github.com/pomerium/pomerium/pkg/grpcutil"
|
||||||
|
"github.com/pomerium/pomerium/pkg/identity"
|
||||||
|
"github.com/pomerium/pomerium/pkg/protoutil"
|
||||||
|
"github.com/pomerium/pomerium/pkg/storage"
|
||||||
)
|
)
|
||||||
|
|
||||||
// A SessionStore saves and loads sessions based on the options.
|
// A SessionStore saves and loads sessions based on the options.
|
||||||
|
@ -116,82 +129,253 @@ func (store *SessionStore) SaveSession(w http.ResponseWriter, r *http.Request, v
|
||||||
return store.store.SaveSession(w, r, v)
|
return store.store.SaveSession(w, r, v)
|
||||||
}
|
}
|
||||||
|
|
||||||
// An IDPTokenSessionHandler handles incoming idp access and identity tokens.
|
var (
|
||||||
type IDPTokenSessionHandler struct {
|
accessTokenUUIDNamespace = uuid.MustParse("0194f6f8-e760-76a0-8917-e28ac927a34d")
|
||||||
options *Options
|
identityTokenUUIDNamespace = uuid.MustParse("0194f6f9-aec0-704e-bb4a-51054f17ad17")
|
||||||
getSession func(ctx context.Context, id string) (*session.Session, error)
|
)
|
||||||
putSession func(ctx context.Context, s *session.Session) error
|
|
||||||
|
type IncomingIDPTokenSessionCreator interface {
|
||||||
|
CreateSession(ctx context.Context, cfg *Config, policy *Policy, r *http.Request) (*session.Session, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewIDPTokenSessionHandler creates a new IDPTokenSessionHandler.
|
type incomingIDPTokenSessionCreator struct {
|
||||||
func NewIDPTokenSessionHandler(
|
getRecord func(ctx context.Context, recordType, recordID string) (*databroker.Record, error)
|
||||||
options *Options,
|
putRecords func(ctx context.Context, records []*databroker.Record) error
|
||||||
getSession func(ctx context.Context, id string) (*session.Session, error),
|
|
||||||
putSession func(ctx context.Context, s *session.Session) error,
|
|
||||||
) *IDPTokenSessionHandler {
|
|
||||||
return &IDPTokenSessionHandler{
|
|
||||||
options: options,
|
|
||||||
getSession: getSession,
|
|
||||||
putSession: putSession,
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// // CreateSessionForIncomingIDPToken creates a session from an incoming idp access or identity token.
|
func NewIncomingIDPTokenSessionCreator(
|
||||||
// // If no such tokens are found or they are invalid ErrNoSessionFound will be returned.
|
getRecord func(ctx context.Context, recordType, recordID string) (*databroker.Record, error),
|
||||||
// func (h *IDPTokenSessionHandler) CreateSessionForIncomingIDPToken(r *http.Request) (*session.Session, error) {
|
putRecords func(ctx context.Context, records []*databroker.Record) error,
|
||||||
// idp, err := h.options.GetIdentityProviderForRequestURL(urlutil.GetAbsoluteURL(r).String())
|
) IncomingIDPTokenSessionCreator {
|
||||||
// if err != nil {
|
return &incomingIDPTokenSessionCreator{getRecord: getRecord, putRecords: putRecords}
|
||||||
// return nil, err
|
}
|
||||||
// }
|
|
||||||
|
|
||||||
// return nil, sessions.ErrNoSessionFound
|
// CreateSession attempts to create a session for incoming idp access and
|
||||||
// }
|
// identity tokens. If no access or identity token is passed ErrNoSessionFound will be returned.
|
||||||
|
// If the tokens are not valid an error will be returned.
|
||||||
|
func (c *incomingIDPTokenSessionCreator) CreateSession(
|
||||||
|
ctx context.Context,
|
||||||
|
cfg *Config,
|
||||||
|
policy *Policy,
|
||||||
|
r *http.Request,
|
||||||
|
) (session *session.Session, err error) {
|
||||||
|
if rawAccessToken, ok := cfg.GetIncomingIDPAccessTokenForPolicy(policy, r); ok {
|
||||||
|
return c.createSessionAccessToken(ctx, cfg, policy, rawAccessToken)
|
||||||
|
}
|
||||||
|
|
||||||
// func (h *IDPTokenSessionHandler) getIncomingIDPAccessToken(r *http.Request) (rawAccessToken string, ok bool) {
|
if rawIdentityToken, ok := cfg.GetIncomingIDPIdentityTokenForPolicy(policy, r); ok {
|
||||||
// if h.options.
|
return c.createSessionForIdentityToken(ctx, cfg, policy, rawIdentityToken)
|
||||||
|
}
|
||||||
|
|
||||||
// return "", false
|
return nil, sessions.ErrNoSessionFound
|
||||||
// }
|
}
|
||||||
|
|
||||||
// func (h *IDPTokenSessionHandler) getIncomingIDPIdentityToken(r *http.Request) (rawIdentityToken string, ok bool) {
|
func (c *incomingIDPTokenSessionCreator) createSessionAccessToken(
|
||||||
// return "", false
|
ctx context.Context,
|
||||||
// }
|
cfg *Config,
|
||||||
|
policy *Policy,
|
||||||
|
rawAccessToken string,
|
||||||
|
) (*session.Session, error) {
|
||||||
|
sessionID := uuid.NewSHA1(accessTokenUUIDNamespace, []byte(rawAccessToken)).String()
|
||||||
|
s, err := c.getSession(ctx, sessionID)
|
||||||
|
if err == nil {
|
||||||
|
return s, nil
|
||||||
|
} else if !storage.IsNotFound(err) {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
// func CreateSessionForIncomingIDPToken(
|
idp, err := cfg.Options.GetIdentityProviderForPolicy(policy)
|
||||||
// r *http.Request,
|
if err != nil {
|
||||||
// options *Options,
|
return nil, fmt.Errorf("error getting identity provider to verify access token: %w", err)
|
||||||
// policy *Policy,
|
}
|
||||||
// getSession func(ctx context.Context, id string) (*session.Session, error),
|
|
||||||
// putSession func(ctx context.Context, s *session.Session) error)(*session.Session, error) {
|
authenticateURL, transport, err := cfg.resolveAuthenticateURL()
|
||||||
// }
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("error resolving authenticate url to verify access token: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
res, err := authenticateapi.New(authenticateURL, transport).VerifyAccessToken(ctx, &authenticateapi.VerifyAccessTokenRequest{
|
||||||
|
AccessToken: rawAccessToken,
|
||||||
|
IdentityProviderID: idp.GetId(),
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("error verifying access token: %w", err)
|
||||||
|
} else if !res.Valid {
|
||||||
|
return nil, fmt.Errorf("invalid access token")
|
||||||
|
}
|
||||||
|
|
||||||
|
s = c.newSessionFromIDPClaims(cfg, sessionID, res.Claims)
|
||||||
|
s.OauthToken = &session.OAuthToken{
|
||||||
|
TokenType: "Bearer",
|
||||||
|
AccessToken: rawAccessToken,
|
||||||
|
ExpiresAt: s.ExpiresAt,
|
||||||
|
}
|
||||||
|
u := c.newUserFromIDPClaims(res.Claims)
|
||||||
|
err = c.putSessionAndUser(ctx, s, u)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("error saving session and user: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return s, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *incomingIDPTokenSessionCreator) createSessionForIdentityToken(
|
||||||
|
ctx context.Context,
|
||||||
|
cfg *Config,
|
||||||
|
policy *Policy,
|
||||||
|
rawIdentityToken string,
|
||||||
|
) (*session.Session, error) {
|
||||||
|
sessionID := uuid.NewSHA1(identityTokenUUIDNamespace, []byte(rawIdentityToken)).String()
|
||||||
|
s, err := c.getSession(ctx, sessionID)
|
||||||
|
if err == nil {
|
||||||
|
return s, nil
|
||||||
|
} else if !storage.IsNotFound(err) {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
idp, err := cfg.Options.GetIdentityProviderForPolicy(policy)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("error getting identity provider to verify identity token: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
authenticateURL, transport, err := cfg.resolveAuthenticateURL()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("error resolving authenticate url to verify identity token: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
res, err := authenticateapi.New(authenticateURL, transport).VerifyIdentityToken(ctx, &authenticateapi.VerifyIdentityTokenRequest{
|
||||||
|
IdentityToken: rawIdentityToken,
|
||||||
|
IdentityProviderID: idp.GetId(),
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("error verifying identity token: %w", err)
|
||||||
|
} else if !res.Valid {
|
||||||
|
return nil, fmt.Errorf("invalid identity token")
|
||||||
|
}
|
||||||
|
|
||||||
|
s = c.newSessionFromIDPClaims(cfg, sessionID, res.Claims)
|
||||||
|
s.SetRawIDToken(rawIdentityToken)
|
||||||
|
u := c.newUserFromIDPClaims(res.Claims)
|
||||||
|
err = c.putSessionAndUser(ctx, s, u)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("error saving session and user: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return s, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *incomingIDPTokenSessionCreator) newSessionFromIDPClaims(
|
||||||
|
cfg *Config,
|
||||||
|
sessionID string,
|
||||||
|
claims jwtutil.Claims,
|
||||||
|
) *session.Session {
|
||||||
|
now := time.Now()
|
||||||
|
s := new(session.Session)
|
||||||
|
s.Id = sessionID
|
||||||
|
if userID, ok := claims.GetUserID(); ok {
|
||||||
|
s.UserId = userID
|
||||||
|
}
|
||||||
|
if issuedAt, ok := claims.GetIssuedAt(); ok {
|
||||||
|
s.IssuedAt = timestamppb.New(issuedAt)
|
||||||
|
} else {
|
||||||
|
s.IssuedAt = timestamppb.New(now)
|
||||||
|
}
|
||||||
|
if expiresAt, ok := claims.GetExpirationTime(); ok {
|
||||||
|
s.ExpiresAt = timestamppb.New(expiresAt)
|
||||||
|
} else {
|
||||||
|
s.ExpiresAt = timestamppb.New(now.Add(cfg.Options.CookieExpire))
|
||||||
|
}
|
||||||
|
s.AccessedAt = timestamppb.New(now)
|
||||||
|
s.AddClaims(identity.Claims(claims).Flatten())
|
||||||
|
if aud, ok := claims.GetAudience(); ok {
|
||||||
|
s.Audience = aud
|
||||||
|
}
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *incomingIDPTokenSessionCreator) newUserFromIDPClaims(
|
||||||
|
claims jwtutil.Claims,
|
||||||
|
) *user.User {
|
||||||
|
u := new(user.User)
|
||||||
|
if userID, ok := claims.GetUserID(); ok {
|
||||||
|
u.Id = userID
|
||||||
|
}
|
||||||
|
if name, ok := claims.GetString("name"); ok {
|
||||||
|
u.Name = name
|
||||||
|
}
|
||||||
|
if email, ok := claims.GetString("email"); ok {
|
||||||
|
u.Email = email
|
||||||
|
}
|
||||||
|
u.Claims = identity.Claims(claims).Flatten().ToPB()
|
||||||
|
return u
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *incomingIDPTokenSessionCreator) getSession(ctx context.Context, sessionID string) (*session.Session, error) {
|
||||||
|
record, err := c.getRecord(ctx, grpcutil.GetTypeURL(new(session.Session)), sessionID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
msg, err := record.GetData().UnmarshalNew()
|
||||||
|
if err != nil {
|
||||||
|
return nil, storage.ErrNotFound
|
||||||
|
}
|
||||||
|
|
||||||
|
s, ok := msg.(*session.Session)
|
||||||
|
if !ok {
|
||||||
|
return nil, storage.ErrNotFound
|
||||||
|
}
|
||||||
|
|
||||||
|
return s, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *incomingIDPTokenSessionCreator) putSessionAndUser(ctx context.Context, s *session.Session, u *user.User) error {
|
||||||
|
var records []*databroker.Record
|
||||||
|
if id := s.GetId(); id != "" {
|
||||||
|
records = append(records, &databroker.Record{
|
||||||
|
Type: grpcutil.GetTypeURL(s),
|
||||||
|
Id: id,
|
||||||
|
Data: protoutil.NewAny(s),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
if id := u.GetId(); id != "" {
|
||||||
|
records = append(records, &databroker.Record{
|
||||||
|
Type: grpcutil.GetTypeURL(u),
|
||||||
|
Id: id,
|
||||||
|
Data: protoutil.NewAny(u),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
return c.putRecords(ctx, records)
|
||||||
|
}
|
||||||
|
|
||||||
// GetIncomingIDPAccessTokenForPolicy returns the raw idp access token from a request if there is one.
|
// GetIncomingIDPAccessTokenForPolicy returns the raw idp access token from a request if there is one.
|
||||||
func (options *Options) GetIncomingIDPAccessTokenForPolicy(policy *Policy, r *http.Request) (rawAccessToken string, ok bool) {
|
func (cfg *Config) GetIncomingIDPAccessTokenForPolicy(policy *Policy, r *http.Request) (rawAccessToken string, ok bool) {
|
||||||
bearerTokenFormat := BearerTokenFormatDefault
|
bearerTokenFormat := BearerTokenFormatDefault
|
||||||
if options != nil && options.BearerTokenFormat != nil {
|
if cfg.Options != nil && cfg.Options.BearerTokenFormat != nil {
|
||||||
bearerTokenFormat = *options.BearerTokenFormat
|
bearerTokenFormat = *cfg.Options.BearerTokenFormat
|
||||||
}
|
}
|
||||||
if policy != nil && policy.BearerTokenFormat != nil {
|
if policy != nil && policy.BearerTokenFormat != nil {
|
||||||
bearerTokenFormat = *policy.BearerTokenFormat
|
bearerTokenFormat = *policy.BearerTokenFormat
|
||||||
}
|
}
|
||||||
|
|
||||||
if token := r.Header.Get("X-Pomerium-IDP-Access-Token"); token != "" {
|
if token := r.Header.Get(httputil.HeaderPomeriumIDPAccessToken); token != "" {
|
||||||
return token, true
|
return token, true
|
||||||
}
|
}
|
||||||
|
|
||||||
if auth := r.Header.Get("Authorization"); auth != "" {
|
if auth := r.Header.Get(httputil.HeaderAuthorization); auth != "" {
|
||||||
prefix := "Pomerium-IDP-Access-Token "
|
prefix := httputil.AuthorizationTypePomeriumIDPAccessToken + " "
|
||||||
if strings.HasPrefix(strings.ToLower(auth), strings.ToLower(prefix)) {
|
if strings.HasPrefix(strings.ToLower(auth), strings.ToLower(prefix)) {
|
||||||
return strings.TrimPrefix(auth, prefix), true
|
return strings.TrimPrefix(auth, prefix), true
|
||||||
}
|
}
|
||||||
|
|
||||||
prefix = "Bearer Pomerium-IDP-Access-Token-"
|
prefix = "Bearer " + httputil.AuthorizationTypePomeriumIDPAccessToken + "-"
|
||||||
if strings.HasPrefix(strings.ToLower(auth), strings.ToLower(prefix)) {
|
if strings.HasPrefix(strings.ToLower(auth), strings.ToLower(prefix)) {
|
||||||
return strings.TrimPrefix(auth, prefix), true
|
return strings.TrimPrefix(auth, prefix), true
|
||||||
}
|
}
|
||||||
|
|
||||||
prefix = "Bearer "
|
prefix = "Bearer "
|
||||||
if strings.HasPrefix(strings.ToLower(auth), strings.ToLower(prefix)) && bearerTokenFormat == BearerTokenFormatIDPAccessToken {
|
if strings.HasPrefix(strings.ToLower(auth), strings.ToLower(prefix)) &&
|
||||||
|
bearerTokenFormat == BearerTokenFormatIDPAccessToken {
|
||||||
return strings.TrimPrefix(auth, prefix), true
|
return strings.TrimPrefix(auth, prefix), true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -200,32 +384,33 @@ func (options *Options) GetIncomingIDPAccessTokenForPolicy(policy *Policy, r *ht
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetIncomingIDPAccessTokenForPolicy returns the raw idp identity token from a request if there is one.
|
// GetIncomingIDPAccessTokenForPolicy returns the raw idp identity token from a request if there is one.
|
||||||
func (options *Options) GetIncomingIDPIdentityTokenForPolicy(policy *Policy, r *http.Request) (rawIdentityToken string, ok bool) {
|
func (cfg *Config) GetIncomingIDPIdentityTokenForPolicy(policy *Policy, r *http.Request) (rawIdentityToken string, ok bool) {
|
||||||
bearerTokenFormat := BearerTokenFormatDefault
|
bearerTokenFormat := BearerTokenFormatDefault
|
||||||
if options != nil && options.BearerTokenFormat != nil {
|
if cfg.Options != nil && cfg.Options.BearerTokenFormat != nil {
|
||||||
bearerTokenFormat = *options.BearerTokenFormat
|
bearerTokenFormat = *cfg.Options.BearerTokenFormat
|
||||||
}
|
}
|
||||||
if policy != nil && policy.BearerTokenFormat != nil {
|
if policy != nil && policy.BearerTokenFormat != nil {
|
||||||
bearerTokenFormat = *policy.BearerTokenFormat
|
bearerTokenFormat = *policy.BearerTokenFormat
|
||||||
}
|
}
|
||||||
|
|
||||||
if token := r.Header.Get("X-Pomerium-IDP-Identity-Token"); token != "" {
|
if token := r.Header.Get(httputil.HeaderPomeriumIDPIdentityToken); token != "" {
|
||||||
return token, true
|
return token, true
|
||||||
}
|
}
|
||||||
|
|
||||||
if auth := r.Header.Get("Authorization"); auth != "" {
|
if auth := r.Header.Get(httputil.HeaderAuthorization); auth != "" {
|
||||||
prefix := "Pomerium-IDP-Identity-Token "
|
prefix := httputil.AuthorizationTypePomeriumIDPIdentityToken + " "
|
||||||
if strings.HasPrefix(strings.ToLower(auth), strings.ToLower(prefix)) {
|
if strings.HasPrefix(strings.ToLower(auth), strings.ToLower(prefix)) {
|
||||||
return strings.TrimPrefix(auth, prefix), true
|
return strings.TrimPrefix(auth, prefix), true
|
||||||
}
|
}
|
||||||
|
|
||||||
prefix = "Bearer Pomerium-IDP-Identity-Token-"
|
prefix = "Bearer " + httputil.AuthorizationTypePomeriumIDPIdentityToken + "-"
|
||||||
if strings.HasPrefix(strings.ToLower(auth), strings.ToLower(prefix)) {
|
if strings.HasPrefix(strings.ToLower(auth), strings.ToLower(prefix)) {
|
||||||
return strings.TrimPrefix(auth, prefix), true
|
return strings.TrimPrefix(auth, prefix), true
|
||||||
}
|
}
|
||||||
|
|
||||||
prefix = "Bearer "
|
prefix = "Bearer "
|
||||||
if strings.HasPrefix(strings.ToLower(auth), strings.ToLower(prefix)) && bearerTokenFormat == BearerTokenFormatIDPIdentityToken {
|
if strings.HasPrefix(strings.ToLower(auth), strings.ToLower(prefix)) &&
|
||||||
|
bearerTokenFormat == BearerTokenFormatIDPIdentityToken {
|
||||||
return strings.TrimPrefix(auth, prefix), true
|
return strings.TrimPrefix(auth, prefix), true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,7 +1,12 @@
|
||||||
package httputil
|
package httputil
|
||||||
|
|
||||||
|
// Pomerium authorization types
|
||||||
|
const (
|
||||||
// AuthorizationTypePomerium is for Authorization: Pomerium JWT... headers
|
// AuthorizationTypePomerium is for Authorization: Pomerium JWT... headers
|
||||||
const AuthorizationTypePomerium = "Pomerium"
|
AuthorizationTypePomerium = "Pomerium"
|
||||||
|
AuthorizationTypePomeriumIDPAccessToken = "Pomerium-IDP-Access-Token" //nolint: gosec
|
||||||
|
AuthorizationTypePomeriumIDPIdentityToken = "Pomerium-IDP-Identity-Token" //nolint: gosec
|
||||||
|
)
|
||||||
|
|
||||||
// Standard headers
|
// Standard headers
|
||||||
const (
|
const (
|
||||||
|
@ -17,6 +22,8 @@ const (
|
||||||
// can be used in place of the standard authorization header if that header is being
|
// can be used in place of the standard authorization header if that header is being
|
||||||
// used by upstream applications.
|
// used by upstream applications.
|
||||||
HeaderPomeriumAuthorization = "x-pomerium-authorization"
|
HeaderPomeriumAuthorization = "x-pomerium-authorization"
|
||||||
|
HeaderPomeriumIDPAccessToken = "x-pomerium-idp-access-token" //nolint: gosec
|
||||||
|
HeaderPomeriumIDPIdentityToken = "x-pomerium-idp-identity-token" //nolint: gosec
|
||||||
// HeaderPomeriumResponse is set when pomerium itself creates a response,
|
// HeaderPomeriumResponse is set when pomerium itself creates a response,
|
||||||
// as opposed to the upstream application and can be used to distinguish
|
// as opposed to the upstream application and can be used to distinguish
|
||||||
// between an application error, and a pomerium related error when debugging.
|
// between an application error, and a pomerium related error when debugging.
|
||||||
|
|
160
internal/jwtutil/jwtutil.go
Normal file
160
internal/jwtutil/jwtutil.go
Normal file
|
@ -0,0 +1,160 @@
|
||||||
|
// Package jwtutil contains functions for working with JWTs.
|
||||||
|
package jwtutil
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"reflect"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Claims represent claims in a JWT.
|
||||||
|
type Claims map[string]any
|
||||||
|
|
||||||
|
// UnmarshalJSON implements a custom unmarshaller for claims data.
|
||||||
|
func (claims *Claims) UnmarshalJSON(raw []byte) error {
|
||||||
|
dst := map[string]any{}
|
||||||
|
dec := json.NewDecoder(bytes.NewReader(raw))
|
||||||
|
dec.UseNumber()
|
||||||
|
err := dec.Decode(&dst)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
*claims = Claims(dst)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// registered claims
|
||||||
|
|
||||||
|
// GetIssuer gets the iss claim.
|
||||||
|
func (claims Claims) GetIssuer() (issuer string, ok bool) {
|
||||||
|
return claims.GetString("iss")
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetSubject gets the sub claim.
|
||||||
|
func (claims Claims) GetSubject() (subject string, ok bool) {
|
||||||
|
return claims.GetString("sub")
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAudience gets the aud claim.
|
||||||
|
func (claims Claims) GetAudience() (audiences []string, ok bool) {
|
||||||
|
return claims.GetStringSlice("aud")
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetExpirationTime gets the exp claim.
|
||||||
|
func (claims Claims) GetExpirationTime() (expirationTime time.Time, ok bool) {
|
||||||
|
return claims.GetNumericDate("exp")
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetNotBefore gets the nbf claim.
|
||||||
|
func (claims Claims) GetNotBefore() (notBefore time.Time, ok bool) {
|
||||||
|
return claims.GetNumericDate("nbf")
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetIssuedAt gets the iat claim.
|
||||||
|
func (claims Claims) GetIssuedAt() (issuedAt time.Time, ok bool) {
|
||||||
|
return claims.GetNumericDate("iat")
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetJWTID gets the jti claim.
|
||||||
|
func (claims Claims) GetJWTID() (jwtID string, ok bool) {
|
||||||
|
return claims.GetString("jti")
|
||||||
|
}
|
||||||
|
|
||||||
|
// custom claims
|
||||||
|
|
||||||
|
// GetUserID returns the oid or sub claim.
|
||||||
|
func (claims Claims) GetUserID() (userID string, ok bool) {
|
||||||
|
if oid, ok := claims.GetString("oid"); ok {
|
||||||
|
return oid, true
|
||||||
|
}
|
||||||
|
|
||||||
|
if sub, ok := claims.GetSubject(); ok {
|
||||||
|
return sub, true
|
||||||
|
}
|
||||||
|
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetNumericDate returns the claim as a numeric date.
|
||||||
|
func (claims Claims) GetNumericDate(name string) (tm time.Time, ok bool) {
|
||||||
|
if claims == nil {
|
||||||
|
return tm, false
|
||||||
|
}
|
||||||
|
|
||||||
|
raw, ok := claims[name]
|
||||||
|
if !ok {
|
||||||
|
return tm, false
|
||||||
|
}
|
||||||
|
|
||||||
|
switch v := raw.(type) {
|
||||||
|
case float64:
|
||||||
|
return time.Unix(int64(v), 0), true
|
||||||
|
case int64:
|
||||||
|
return time.Unix(v, 0), true
|
||||||
|
case json.Number:
|
||||||
|
i, err := v.Int64()
|
||||||
|
if err != nil {
|
||||||
|
if f, err := v.Float64(); err == nil {
|
||||||
|
i = int64(f)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return tm, false
|
||||||
|
}
|
||||||
|
return time.Unix(i, 0), true
|
||||||
|
}
|
||||||
|
|
||||||
|
return tm, false
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetString returns the claim as a string.
|
||||||
|
func (claims Claims) GetString(name string) (value string, ok bool) {
|
||||||
|
if claims == nil {
|
||||||
|
return value, false
|
||||||
|
}
|
||||||
|
|
||||||
|
raw, ok := claims[name]
|
||||||
|
if !ok {
|
||||||
|
return value, false
|
||||||
|
}
|
||||||
|
|
||||||
|
return toString(raw), true
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetStringSlice returns the claim as a slice of strings.
|
||||||
|
func (claims Claims) GetStringSlice(name string) (values []string, ok bool) {
|
||||||
|
if claims == nil {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
raw, ok := claims[name]
|
||||||
|
if !ok {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
return toStringSlice(raw), true
|
||||||
|
}
|
||||||
|
|
||||||
|
func toString(data any) string {
|
||||||
|
switch v := data.(type) {
|
||||||
|
case string:
|
||||||
|
return v
|
||||||
|
}
|
||||||
|
return fmt.Sprint(data)
|
||||||
|
}
|
||||||
|
|
||||||
|
func toStringSlice(obj any) []string {
|
||||||
|
v := reflect.ValueOf(obj)
|
||||||
|
switch v.Kind() {
|
||||||
|
case reflect.Slice:
|
||||||
|
vs := make([]string, v.Len())
|
||||||
|
for i := 0; i < v.Len(); i++ {
|
||||||
|
vs[i] = toString(v.Index(i).Interface())
|
||||||
|
}
|
||||||
|
return vs
|
||||||
|
}
|
||||||
|
|
||||||
|
return []string{toString(obj)}
|
||||||
|
}
|
109
pkg/authenticateapi/authenticateapi.go
Normal file
109
pkg/authenticateapi/authenticateapi.go
Normal file
|
@ -0,0 +1,109 @@
|
||||||
|
// Package authenticateapi has the types and methods for the authenticate api.
|
||||||
|
package authenticateapi
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
|
||||||
|
"github.com/pomerium/pomerium/internal/jwtutil"
|
||||||
|
)
|
||||||
|
|
||||||
|
// VerifyAccessTokenRequest is used to verify access tokens.
|
||||||
|
type VerifyAccessTokenRequest struct {
|
||||||
|
AccessToken string `json:"accessToken"`
|
||||||
|
IdentityProviderID string `json:"identityProviderId,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// VerifyIdentityTokenRequest is used to verify identity tokens.
|
||||||
|
type VerifyIdentityTokenRequest struct {
|
||||||
|
IdentityToken string `json:"identityToken"`
|
||||||
|
IdentityProviderID string `json:"identityProviderId,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// VerifyTokenResponse is the result of verifying an access or identity token.
|
||||||
|
type VerifyTokenResponse struct {
|
||||||
|
Valid bool `json:"valid"`
|
||||||
|
Claims jwtutil.Claims `json:"claims,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// An API is an api client for the authenticate service.
|
||||||
|
type API struct {
|
||||||
|
authenticateURL *url.URL
|
||||||
|
transport http.RoundTripper
|
||||||
|
}
|
||||||
|
|
||||||
|
// New creates a new API client.
|
||||||
|
func New(
|
||||||
|
authenticateURL *url.URL,
|
||||||
|
transport http.RoundTripper,
|
||||||
|
) *API {
|
||||||
|
return &API{
|
||||||
|
authenticateURL: authenticateURL,
|
||||||
|
transport: transport,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// VerifyAccessToken verifies an access token.
|
||||||
|
func (api *API) VerifyAccessToken(ctx context.Context, request *VerifyAccessTokenRequest) (*VerifyTokenResponse, error) {
|
||||||
|
var response VerifyTokenResponse
|
||||||
|
err := api.call(ctx, "verify-access-token", request, &response)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &response, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// VerifyIdentityToken verifies an identity token.
|
||||||
|
func (api *API) VerifyIdentityToken(ctx context.Context, request *VerifyIdentityTokenRequest) (*VerifyTokenResponse, error) {
|
||||||
|
var response VerifyTokenResponse
|
||||||
|
err := api.call(ctx, "verify-identity-token", request, &response)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &response, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (api *API) call(
|
||||||
|
ctx context.Context,
|
||||||
|
endpoint string,
|
||||||
|
request, response any,
|
||||||
|
) error {
|
||||||
|
u := api.authenticateURL.ResolveReference(&url.URL{
|
||||||
|
Path: "/.pomerium/" + endpoint,
|
||||||
|
})
|
||||||
|
|
||||||
|
body, err := json.Marshal(request)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("error marshaling %s http request: %w", endpoint, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, u.String(), bytes.NewReader(body))
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("error creating %s http request: %w", endpoint, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
res, err := (&http.Client{
|
||||||
|
Transport: api.transport,
|
||||||
|
}).Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("error executing %s http request: %w", endpoint, err)
|
||||||
|
}
|
||||||
|
defer res.Body.Close()
|
||||||
|
|
||||||
|
body, err = io.ReadAll(res.Body)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("error reading %s http response: %w", endpoint, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = json.Unmarshal(body, &response)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("error reading %s http response (body=%s): %w", endpoint, body, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
|
@ -17,6 +17,7 @@ import (
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
"golang.org/x/oauth2"
|
"golang.org/x/oauth2"
|
||||||
|
|
||||||
|
"github.com/pomerium/pomerium/internal/jwtutil"
|
||||||
"github.com/pomerium/pomerium/pkg/identity/oauth"
|
"github.com/pomerium/pomerium/pkg/identity/oauth"
|
||||||
pom_oidc "github.com/pomerium/pomerium/pkg/identity/oidc"
|
pom_oidc "github.com/pomerium/pomerium/pkg/identity/oidc"
|
||||||
)
|
)
|
||||||
|
@ -95,7 +96,7 @@ func (p *Provider) VerifyAccessToken(ctx context.Context, rawAccessToken string)
|
||||||
return nil, fmt.Errorf("error verifying access token: %w", err)
|
return nil, fmt.Errorf("error verifying access token: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
claims = map[string]any{}
|
claims = jwtutil.Claims(map[string]any{})
|
||||||
err = token.Claims(&claims)
|
err = token.Claims(&claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("error unmarshaling access token claims: %w", err)
|
return nil, fmt.Errorf("error unmarshaling access token claims: %w", err)
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue