implement session creation

This commit is contained in:
Caleb Doxsey 2025-02-14 14:43:23 -07:00
parent 24b35e26a5
commit b95ad4dbc3
15 changed files with 646 additions and 148 deletions

View file

@ -6,25 +6,11 @@ import (
"github.com/pomerium/pomerium/internal/httputil"
"github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/pkg/authenticateapi"
)
type VerifyAccessTokenRequest struct {
AccessToken string `json:"accessToken"`
IdentityProviderID string `json:"identityProviderId,omitempty"`
}
type VerifyIdentityTokenRequest struct {
IdentityToken string `json:"identityToken"`
IdentityProviderID string `json:"identityProviderId,omitempty"`
}
type VerifyTokenResponse struct {
Valid bool `json:"valid"`
Claims map[string]any `json:"claims,omitempty"`
}
func (a *Authenticate) verifyAccessToken(w http.ResponseWriter, r *http.Request) error {
var req VerifyAccessTokenRequest
var req authenticateapi.VerifyAccessTokenRequest
err := json.NewDecoder(r.Body).Decode(&req)
if err != nil {
return httputil.NewError(http.StatusBadRequest, err)
@ -35,7 +21,7 @@ func (a *Authenticate) verifyAccessToken(w http.ResponseWriter, r *http.Request)
return err
}
var res VerifyTokenResponse
var res authenticateapi.VerifyTokenResponse
claims, err := authenticator.VerifyAccessToken(r.Context(), req.AccessToken)
if err == nil {
res.Valid = true
@ -57,7 +43,7 @@ func (a *Authenticate) verifyAccessToken(w http.ResponseWriter, r *http.Request)
}
func (a *Authenticate) verifyIdentityToken(w http.ResponseWriter, r *http.Request) error {
var req VerifyIdentityTokenRequest
var req authenticateapi.VerifyIdentityTokenRequest
err := json.NewDecoder(r.Body).Decode(&req)
if err != nil {
return httputil.NewError(http.StatusBadRequest, err)
@ -68,7 +54,7 @@ func (a *Authenticate) verifyIdentityToken(w http.ResponseWriter, r *http.Reques
return err
}
var res VerifyTokenResponse
var res authenticateapi.VerifyTokenResponse
claims, err := authenticator.VerifyIdentityToken(r.Context(), req.IdentityToken)
if err == nil {
res.Valid = true

View file

@ -29,7 +29,7 @@ import (
type Authorize struct {
state *atomicutil.Value[*authorizeState]
store *store.Store
currentOptions *atomicutil.Value[*config.Options]
currentConfig *atomicutil.Value[*config.Config]
accessTracker *AccessTracker
globalCache storage.Cache
groupsCacheWarmer *cacheWarmer
@ -43,7 +43,7 @@ func New(ctx context.Context, cfg *config.Config) (*Authorize, error) {
tracerProvider := trace.NewTracerProvider(ctx, "Authorize")
tracer := tracerProvider.Tracer(trace.PomeriumCoreTracer)
a := &Authorize{
currentOptions: config.NewAtomicOptions(),
currentConfig: atomicutil.NewValue(&config.Config{Options: new(config.Options)}),
store: store.New(),
globalCache: storage.NewGlobalCache(time.Minute),
tracerProvider: tracerProvider,
@ -155,7 +155,7 @@ func newPolicyEvaluator(
// OnConfigChange updates internal structures based on config.Options
func (a *Authorize) OnConfigChange(ctx context.Context, cfg *config.Config) {
currentState := a.state.Load()
a.currentOptions.Store(cfg.Options)
a.currentConfig.Store(cfg)
if newState, err := newAuthorizeStateFromConfig(ctx, a.tracerProvider, cfg, a.store, currentState.evaluator); err != nil {
log.Ctx(ctx).Error().Err(err).Msg("authorize: error updating state")
} else {

View file

@ -186,7 +186,7 @@ func (a *Authorize) deniedResponse(
Err: errors.New(reason),
DebugURL: debugEndpoint,
RequestID: requestid.FromContext(ctx),
BrandingOptions: a.currentOptions.Load().BrandingOptions,
BrandingOptions: a.currentConfig.Load().Options.BrandingOptions,
}
httpErr.ErrorResponse(ctx, w, r)
@ -213,7 +213,7 @@ func (a *Authorize) requireLoginResponse(
in *envoy_service_auth_v3.CheckRequest,
request *evaluator.Request,
) (*envoy_service_auth_v3.CheckResponse, error) {
options := a.currentOptions.Load()
options := a.currentConfig.Load().Options
state := a.state.Load()
if !a.shouldRedirect(in) {
@ -251,7 +251,7 @@ func (a *Authorize) requireWebAuthnResponse(
request *evaluator.Request,
result *evaluator.Result,
) (*envoy_service_auth_v3.CheckResponse, error) {
opts := a.currentOptions.Load()
opts := a.currentConfig.Load().Options
state := a.state.Load()
// always assume https scheme
@ -327,7 +327,7 @@ func toEnvoyHeaders(headers http.Header) []*envoy_config_core_v3.HeaderValueOpti
// userInfoEndpointURL returns the user info endpoint url which can be used to debug the user's
// session that lives on the authenticate service.
func (a *Authorize) userInfoEndpointURL(in *envoy_service_auth_v3.CheckRequest) (*url.URL, error) {
opts := a.currentOptions.Load()
opts := a.currentConfig.Load().Options
authenticateURL, err := opts.GetAuthenticateURL()
if err != nil {
return nil, err

View file

@ -127,8 +127,9 @@ func TestAuthorize_okResponse(t *testing.T) {
}},
JWTClaimsHeaders: config.NewJWTClaimHeaders("email"),
}
a := &Authorize{currentOptions: config.NewAtomicOptions(), state: atomicutil.NewValue(new(authorizeState))}
a.currentOptions.Store(opt)
a := &Authorize{currentConfig: atomicutil.NewValue(&config.Config{
Options: opt,
}), state: atomicutil.NewValue(new(authorizeState))}
a.store = store.New()
pe, err := newPolicyEvaluator(context.Background(), opt, a.store, nil)
require.NoError(t, err)
@ -183,15 +184,16 @@ func TestAuthorize_okResponse(t *testing.T) {
func TestAuthorize_deniedResponse(t *testing.T) {
t.Parallel()
a := &Authorize{currentOptions: config.NewAtomicOptions(), state: atomicutil.NewValue(new(authorizeState))}
a.currentOptions.Store(&config.Options{
Policies: []config.Policy{{
From: "https://example.com",
SubPolicies: []config.SubPolicy{{
Rego: []string{"allow = true"},
a := &Authorize{currentConfig: atomicutil.NewValue(&config.Config{
Options: &config.Options{
Policies: []config.Policy{{
From: "https://example.com",
SubPolicies: []config.SubPolicy{{
Rego: []string{"allow = true"},
}},
}},
}},
})
},
}), state: atomicutil.NewValue(new(authorizeState))}
t.Run("json", func(t *testing.T) {
t.Parallel()

View file

@ -3,15 +3,17 @@ package authorize
import (
"context"
"google.golang.org/grpc"
"github.com/pomerium/pomerium/pkg/grpc/databroker"
"github.com/pomerium/pomerium/pkg/grpc/session"
"github.com/pomerium/pomerium/pkg/grpc/user"
"github.com/pomerium/pomerium/pkg/grpcutil"
"github.com/pomerium/pomerium/pkg/storage"
"google.golang.org/grpc"
)
type sessionOrServiceAccount interface {
GetId() string
GetUserId() string
Validate() error
}

View file

@ -3,6 +3,7 @@ package authorize
import (
"context"
"encoding/pem"
"errors"
"io"
"net/http"
"net/url"
@ -21,6 +22,7 @@ import (
"github.com/pomerium/pomerium/internal/sessions"
"github.com/pomerium/pomerium/internal/urlutil"
"github.com/pomerium/pomerium/pkg/contextutil"
"github.com/pomerium/pomerium/pkg/grpc/databroker"
"github.com/pomerium/pomerium/pkg/grpc/user"
"github.com/pomerium/pomerium/pkg/storage"
"github.com/pomerium/pomerium/pkg/telemetry/requestid"
@ -44,31 +46,25 @@ func (a *Authorize) Check(ctx context.Context, in *envoy_service_auth_v3.CheckRe
requestID := requestid.FromHTTPHeader(hreq.Header)
ctx = requestid.WithValue(ctx, requestID)
sessionState, _ := state.sessionStore.LoadSessionStateAndCheckIDP(hreq)
var s sessionOrServiceAccount
var u *user.User
var err error
if sessionState != nil {
s, err = a.getDataBrokerSessionOrServiceAccount(ctx, sessionState.ID, sessionState.DatabrokerRecordVersion)
if status.Code(err) == codes.Unavailable {
log.Ctx(ctx).Debug().Str("request-id", requestID).Err(err).Msg("temporary error checking authorization: data broker unavailable")
return nil, err
} else if err != nil {
log.Ctx(ctx).Info().Err(err).Str("request-id", requestID).Msg("clearing session due to missing or invalid session or service account")
sessionState = nil
}
}
if sessionState != nil && s != nil {
u, _ = a.getDataBrokerUser(ctx, s.GetUserId()) // ignore any missing user error
}
req, err := a.getEvaluatorRequestFromCheckRequest(ctx, in, sessionState)
req, err := a.getEvaluatorRequestFromCheckRequest(ctx, in)
if err != nil {
log.Ctx(ctx).Error().Err(err).Str("request-id", requestID).Msg("error building evaluator request")
return nil, err
}
// load the session
s, err := a.loadSession(ctx, hreq, req)
if err != nil {
return nil, err
}
// if there's a session or service account, load the user
var u *user.User
if s != nil {
req.Session.ID = s.GetId()
u, _ = a.getDataBrokerUser(ctx, s.GetUserId()) // ignore any missing user error
}
res, err := state.evaluator.Evaluate(ctx, req)
if err != nil {
log.Ctx(ctx).Error().Err(err).Str("request-id", requestID).Msg("error during OPA evaluation")
@ -88,10 +84,65 @@ func (a *Authorize) Check(ctx context.Context, in *envoy_service_auth_v3.CheckRe
return resp, err
}
func (a *Authorize) loadSession(
ctx context.Context,
hreq *http.Request,
req *evaluator.Request,
) (s sessionOrServiceAccount, err error) {
requestID := requestid.FromHTTPHeader(hreq.Header)
// attempt to create a session from an incoming idp token
s, err = config.NewIncomingIDPTokenSessionCreator(
func(ctx context.Context, recordType, recordID string) (*databroker.Record, error) {
return getDataBrokerRecord(ctx, recordType, recordID, 0)
},
func(ctx context.Context, records []*databroker.Record) error {
_, err := a.state.Load().dataBrokerClient.Put(ctx, &databroker.PutRequest{
Records: records,
})
if err != nil {
return err
}
// invalidate cache
for _, record := range records {
storage.GetQuerier(ctx).InvalidateCache(ctx, &databroker.QueryRequest{
Type: record.GetType(),
Query: record.GetId(),
Limit: 1,
})
}
return nil
},
).CreateSession(ctx, a.currentConfig.Load(), req.Policy, hreq)
if err == nil {
return s, nil
} else if !errors.Is(err, sessions.ErrNoSessionFound) {
log.Ctx(ctx).Info().
Str("request-id", requestID).
Err(err).
Msg("error creating session for incoming idp token")
}
sessionState, _ := a.state.Load().sessionStore.LoadSessionStateAndCheckIDP(hreq)
if sessionState == nil {
return nil, nil
}
s, err = a.getDataBrokerSessionOrServiceAccount(ctx, sessionState.ID, sessionState.DatabrokerRecordVersion)
if status.Code(err) == codes.Unavailable {
log.Ctx(ctx).Debug().Str("request-id", requestID).Err(err).Msg("temporary error checking authorization: data broker unavailable")
return nil, err
} else if err != nil {
log.Ctx(ctx).Info().Err(err).Str("request-id", requestID).Msg("clearing session due to missing or invalid session or service account")
return nil, nil
}
return s, nil
}
func (a *Authorize) getEvaluatorRequestFromCheckRequest(
ctx context.Context,
in *envoy_service_auth_v3.CheckRequest,
sessionState *sessions.State,
) (*evaluator.Request, error) {
requestURL := getCheckRequestURL(in)
attrs := in.GetAttributes()
@ -106,17 +157,12 @@ func (a *Authorize) getEvaluatorRequestFromCheckRequest(
attrs.GetSource().GetAddress().GetSocketAddress().GetAddress(),
),
}
if sessionState != nil {
req.Session = evaluator.RequestSession{
ID: sessionState.ID,
}
}
req.Policy = a.getMatchingPolicy(envoyconfig.ExtAuthzContextExtensionsRouteID(attrs.GetContextExtensions()))
return req, nil
}
func (a *Authorize) getMatchingPolicy(routeID uint64) *config.Policy {
options := a.currentOptions.Load()
options := a.currentConfig.Load().Options
for p := range options.GetAllPolicies() {
id, _ := p.RouteID()

View file

@ -18,7 +18,6 @@ import (
"github.com/pomerium/pomerium/authorize/evaluator"
"github.com/pomerium/pomerium/config"
"github.com/pomerium/pomerium/internal/atomicutil"
"github.com/pomerium/pomerium/internal/sessions"
"github.com/pomerium/pomerium/internal/testutil"
"github.com/pomerium/pomerium/pkg/grpc/databroker"
"github.com/pomerium/pomerium/pkg/storage"
@ -49,15 +48,16 @@ yE+vPxsiUkvQHdO2fojCkY8jg70jxM+gu59tPDNbw3Uh/2Ij310FgTHsnGQMyA==
-----END CERTIFICATE-----`
func Test_getEvaluatorRequest(t *testing.T) {
a := &Authorize{currentOptions: config.NewAtomicOptions(), state: atomicutil.NewValue(new(authorizeState))}
a.currentOptions.Store(&config.Options{
Policies: []config.Policy{{
From: "https://example.com",
SubPolicies: []config.SubPolicy{{
Rego: []string{"allow = true"},
a := &Authorize{currentConfig: atomicutil.NewValue(&config.Config{
Options: &config.Options{
Policies: []config.Policy{{
From: "https://example.com",
SubPolicies: []config.SubPolicy{{
Rego: []string{"allow = true"},
}},
}},
}},
})
},
}), state: atomicutil.NewValue(new(authorizeState))}
actual, err := a.getEvaluatorRequestFromCheckRequest(context.Background(),
&envoy_service_auth_v3.CheckRequest{
@ -88,13 +88,10 @@ func Test_getEvaluatorRequest(t *testing.T) {
},
},
},
&sessions.State{
ID: "SESSION_ID",
},
)
require.NoError(t, err)
expect := &evaluator.Request{
Policy: &a.currentOptions.Load().Policies[0],
Policy: &a.currentConfig.Load().Options.Policies[0],
Session: evaluator.RequestSession{
ID: "SESSION_ID",
},
@ -117,15 +114,16 @@ func Test_getEvaluatorRequest(t *testing.T) {
}
func Test_getEvaluatorRequestWithPortInHostHeader(t *testing.T) {
a := &Authorize{currentOptions: config.NewAtomicOptions(), state: atomicutil.NewValue(new(authorizeState))}
a.currentOptions.Store(&config.Options{
Policies: []config.Policy{{
From: "https://example.com",
SubPolicies: []config.SubPolicy{{
Rego: []string{"allow = true"},
a := &Authorize{currentConfig: atomicutil.NewValue(&config.Config{
Options: &config.Options{
Policies: []config.Policy{{
From: "https://example.com",
SubPolicies: []config.SubPolicy{{
Rego: []string{"allow = true"},
}},
}},
}},
})
},
}), state: atomicutil.NewValue(new(authorizeState))}
actual, err := a.getEvaluatorRequestFromCheckRequest(context.Background(),
&envoy_service_auth_v3.CheckRequest{
@ -145,10 +143,10 @@ func Test_getEvaluatorRequestWithPortInHostHeader(t *testing.T) {
},
},
},
}, nil)
})
require.NoError(t, err)
expect := &evaluator.Request{
Policy: &a.currentOptions.Load().Policies[0],
Policy: &a.currentConfig.Load().Options.Policies[0],
Session: evaluator.RequestSession{},
HTTP: evaluator.NewRequestHTTP(
http.MethodGet,

View file

@ -31,7 +31,7 @@ func (a *Authorize) logAuthorizeCheck(
impersonateDetails := a.getImpersonateDetails(ctx, s)
evt := log.Ctx(ctx).Info().Str("service", "authorize")
fields := a.currentOptions.Load().GetAuthorizeLogFields()
fields := a.currentConfig.Load().Options.GetAuthorizeLogFields()
for _, field := range fields {
evt = populateLogEvent(ctx, field, evt, in, s, u, hdrs, impersonateDetails, res)
}

View file

@ -67,7 +67,7 @@ func (bearerTokenFormat *BearerTokenFormat) ToPB() *configpb.BearerTokenFormat {
case BearerTokenFormatIDPIdentityToken:
return configpb.BearerTokenFormat_BEARER_TOKEN_FORMAT_IDP_IDENTITY_TOKEN.Enum()
default:
panic(fmt.Sprintf("unknown bearer token format: %s", bearerTokenFormat))
panic(fmt.Sprintf("unknown bearer token format: %v", bearerTokenFormat))
}
}

View file

@ -4,9 +4,10 @@ import (
"errors"
"github.com/mitchellh/mapstructure"
"github.com/pomerium/pomerium/config/otelconfig"
"github.com/spf13/viper"
"google.golang.org/protobuf/encoding/protojson"
"github.com/pomerium/pomerium/config/otelconfig"
)
const (
@ -37,6 +38,7 @@ var ViperPolicyHooks = viper.DecodeHook(mapstructure.ComposeDecodeHookFunc(
DecodePolicyBase64Hook(),
decodeNullBoolHookFunc(),
decodeJWTClaimHeadersHookFunc(),
decodeBearerTokenFormatHookFunc(),
decodeCodecTypeHookFunc(),
decodePPLPolicyHookFunc(),
decodeSANMatcherHookFunc(),

View file

@ -5,15 +5,28 @@ import (
"fmt"
"net/http"
"strings"
"time"
"github.com/google/uuid"
"google.golang.org/protobuf/types/known/timestamppb"
"github.com/pomerium/pomerium/internal/encoding"
"github.com/pomerium/pomerium/internal/encoding/jws"
"github.com/pomerium/pomerium/internal/httputil"
"github.com/pomerium/pomerium/internal/jwtutil"
"github.com/pomerium/pomerium/internal/sessions"
"github.com/pomerium/pomerium/internal/sessions/cookie"
"github.com/pomerium/pomerium/internal/sessions/header"
"github.com/pomerium/pomerium/internal/sessions/queryparam"
"github.com/pomerium/pomerium/internal/urlutil"
"github.com/pomerium/pomerium/pkg/authenticateapi"
"github.com/pomerium/pomerium/pkg/grpc/databroker"
"github.com/pomerium/pomerium/pkg/grpc/session"
"github.com/pomerium/pomerium/pkg/grpc/user"
"github.com/pomerium/pomerium/pkg/grpcutil"
"github.com/pomerium/pomerium/pkg/identity"
"github.com/pomerium/pomerium/pkg/protoutil"
"github.com/pomerium/pomerium/pkg/storage"
)
// A SessionStore saves and loads sessions based on the options.
@ -116,82 +129,253 @@ func (store *SessionStore) SaveSession(w http.ResponseWriter, r *http.Request, v
return store.store.SaveSession(w, r, v)
}
// An IDPTokenSessionHandler handles incoming idp access and identity tokens.
type IDPTokenSessionHandler struct {
options *Options
getSession func(ctx context.Context, id string) (*session.Session, error)
putSession func(ctx context.Context, s *session.Session) error
var (
accessTokenUUIDNamespace = uuid.MustParse("0194f6f8-e760-76a0-8917-e28ac927a34d")
identityTokenUUIDNamespace = uuid.MustParse("0194f6f9-aec0-704e-bb4a-51054f17ad17")
)
type IncomingIDPTokenSessionCreator interface {
CreateSession(ctx context.Context, cfg *Config, policy *Policy, r *http.Request) (*session.Session, error)
}
// NewIDPTokenSessionHandler creates a new IDPTokenSessionHandler.
func NewIDPTokenSessionHandler(
options *Options,
getSession func(ctx context.Context, id string) (*session.Session, error),
putSession func(ctx context.Context, s *session.Session) error,
) *IDPTokenSessionHandler {
return &IDPTokenSessionHandler{
options: options,
getSession: getSession,
putSession: putSession,
type incomingIDPTokenSessionCreator struct {
getRecord func(ctx context.Context, recordType, recordID string) (*databroker.Record, error)
putRecords func(ctx context.Context, records []*databroker.Record) error
}
func NewIncomingIDPTokenSessionCreator(
getRecord func(ctx context.Context, recordType, recordID string) (*databroker.Record, error),
putRecords func(ctx context.Context, records []*databroker.Record) error,
) IncomingIDPTokenSessionCreator {
return &incomingIDPTokenSessionCreator{getRecord: getRecord, putRecords: putRecords}
}
// CreateSession attempts to create a session for incoming idp access and
// identity tokens. If no access or identity token is passed ErrNoSessionFound will be returned.
// If the tokens are not valid an error will be returned.
func (c *incomingIDPTokenSessionCreator) CreateSession(
ctx context.Context,
cfg *Config,
policy *Policy,
r *http.Request,
) (session *session.Session, err error) {
if rawAccessToken, ok := cfg.GetIncomingIDPAccessTokenForPolicy(policy, r); ok {
return c.createSessionAccessToken(ctx, cfg, policy, rawAccessToken)
}
if rawIdentityToken, ok := cfg.GetIncomingIDPIdentityTokenForPolicy(policy, r); ok {
return c.createSessionForIdentityToken(ctx, cfg, policy, rawIdentityToken)
}
return nil, sessions.ErrNoSessionFound
}
// // CreateSessionForIncomingIDPToken creates a session from an incoming idp access or identity token.
// // If no such tokens are found or they are invalid ErrNoSessionFound will be returned.
// func (h *IDPTokenSessionHandler) CreateSessionForIncomingIDPToken(r *http.Request) (*session.Session, error) {
// idp, err := h.options.GetIdentityProviderForRequestURL(urlutil.GetAbsoluteURL(r).String())
// if err != nil {
// return nil, err
// }
func (c *incomingIDPTokenSessionCreator) createSessionAccessToken(
ctx context.Context,
cfg *Config,
policy *Policy,
rawAccessToken string,
) (*session.Session, error) {
sessionID := uuid.NewSHA1(accessTokenUUIDNamespace, []byte(rawAccessToken)).String()
s, err := c.getSession(ctx, sessionID)
if err == nil {
return s, nil
} else if !storage.IsNotFound(err) {
return nil, err
}
// return nil, sessions.ErrNoSessionFound
// }
idp, err := cfg.Options.GetIdentityProviderForPolicy(policy)
if err != nil {
return nil, fmt.Errorf("error getting identity provider to verify access token: %w", err)
}
// func (h *IDPTokenSessionHandler) getIncomingIDPAccessToken(r *http.Request) (rawAccessToken string, ok bool) {
// if h.options.
authenticateURL, transport, err := cfg.resolveAuthenticateURL()
if err != nil {
return nil, fmt.Errorf("error resolving authenticate url to verify access token: %w", err)
}
// return "", false
// }
res, err := authenticateapi.New(authenticateURL, transport).VerifyAccessToken(ctx, &authenticateapi.VerifyAccessTokenRequest{
AccessToken: rawAccessToken,
IdentityProviderID: idp.GetId(),
})
if err != nil {
return nil, fmt.Errorf("error verifying access token: %w", err)
} else if !res.Valid {
return nil, fmt.Errorf("invalid access token")
}
// func (h *IDPTokenSessionHandler) getIncomingIDPIdentityToken(r *http.Request) (rawIdentityToken string, ok bool) {
// return "", false
// }
s = c.newSessionFromIDPClaims(cfg, sessionID, res.Claims)
s.OauthToken = &session.OAuthToken{
TokenType: "Bearer",
AccessToken: rawAccessToken,
ExpiresAt: s.ExpiresAt,
}
u := c.newUserFromIDPClaims(res.Claims)
err = c.putSessionAndUser(ctx, s, u)
if err != nil {
return nil, fmt.Errorf("error saving session and user: %w", err)
}
// func CreateSessionForIncomingIDPToken(
// r *http.Request,
// options *Options,
// policy *Policy,
// getSession func(ctx context.Context, id string) (*session.Session, error),
// putSession func(ctx context.Context, s *session.Session) error)(*session.Session, error) {
// }
return s, nil
}
func (c *incomingIDPTokenSessionCreator) createSessionForIdentityToken(
ctx context.Context,
cfg *Config,
policy *Policy,
rawIdentityToken string,
) (*session.Session, error) {
sessionID := uuid.NewSHA1(identityTokenUUIDNamespace, []byte(rawIdentityToken)).String()
s, err := c.getSession(ctx, sessionID)
if err == nil {
return s, nil
} else if !storage.IsNotFound(err) {
return nil, err
}
idp, err := cfg.Options.GetIdentityProviderForPolicy(policy)
if err != nil {
return nil, fmt.Errorf("error getting identity provider to verify identity token: %w", err)
}
authenticateURL, transport, err := cfg.resolveAuthenticateURL()
if err != nil {
return nil, fmt.Errorf("error resolving authenticate url to verify identity token: %w", err)
}
res, err := authenticateapi.New(authenticateURL, transport).VerifyIdentityToken(ctx, &authenticateapi.VerifyIdentityTokenRequest{
IdentityToken: rawIdentityToken,
IdentityProviderID: idp.GetId(),
})
if err != nil {
return nil, fmt.Errorf("error verifying identity token: %w", err)
} else if !res.Valid {
return nil, fmt.Errorf("invalid identity token")
}
s = c.newSessionFromIDPClaims(cfg, sessionID, res.Claims)
s.SetRawIDToken(rawIdentityToken)
u := c.newUserFromIDPClaims(res.Claims)
err = c.putSessionAndUser(ctx, s, u)
if err != nil {
return nil, fmt.Errorf("error saving session and user: %w", err)
}
return s, nil
}
func (c *incomingIDPTokenSessionCreator) newSessionFromIDPClaims(
cfg *Config,
sessionID string,
claims jwtutil.Claims,
) *session.Session {
now := time.Now()
s := new(session.Session)
s.Id = sessionID
if userID, ok := claims.GetUserID(); ok {
s.UserId = userID
}
if issuedAt, ok := claims.GetIssuedAt(); ok {
s.IssuedAt = timestamppb.New(issuedAt)
} else {
s.IssuedAt = timestamppb.New(now)
}
if expiresAt, ok := claims.GetExpirationTime(); ok {
s.ExpiresAt = timestamppb.New(expiresAt)
} else {
s.ExpiresAt = timestamppb.New(now.Add(cfg.Options.CookieExpire))
}
s.AccessedAt = timestamppb.New(now)
s.AddClaims(identity.Claims(claims).Flatten())
if aud, ok := claims.GetAudience(); ok {
s.Audience = aud
}
return s
}
func (c *incomingIDPTokenSessionCreator) newUserFromIDPClaims(
claims jwtutil.Claims,
) *user.User {
u := new(user.User)
if userID, ok := claims.GetUserID(); ok {
u.Id = userID
}
if name, ok := claims.GetString("name"); ok {
u.Name = name
}
if email, ok := claims.GetString("email"); ok {
u.Email = email
}
u.Claims = identity.Claims(claims).Flatten().ToPB()
return u
}
func (c *incomingIDPTokenSessionCreator) getSession(ctx context.Context, sessionID string) (*session.Session, error) {
record, err := c.getRecord(ctx, grpcutil.GetTypeURL(new(session.Session)), sessionID)
if err != nil {
return nil, err
}
msg, err := record.GetData().UnmarshalNew()
if err != nil {
return nil, storage.ErrNotFound
}
s, ok := msg.(*session.Session)
if !ok {
return nil, storage.ErrNotFound
}
return s, nil
}
func (c *incomingIDPTokenSessionCreator) putSessionAndUser(ctx context.Context, s *session.Session, u *user.User) error {
var records []*databroker.Record
if id := s.GetId(); id != "" {
records = append(records, &databroker.Record{
Type: grpcutil.GetTypeURL(s),
Id: id,
Data: protoutil.NewAny(s),
})
}
if id := u.GetId(); id != "" {
records = append(records, &databroker.Record{
Type: grpcutil.GetTypeURL(u),
Id: id,
Data: protoutil.NewAny(u),
})
}
return c.putRecords(ctx, records)
}
// GetIncomingIDPAccessTokenForPolicy returns the raw idp access token from a request if there is one.
func (options *Options) GetIncomingIDPAccessTokenForPolicy(policy *Policy, r *http.Request) (rawAccessToken string, ok bool) {
func (cfg *Config) GetIncomingIDPAccessTokenForPolicy(policy *Policy, r *http.Request) (rawAccessToken string, ok bool) {
bearerTokenFormat := BearerTokenFormatDefault
if options != nil && options.BearerTokenFormat != nil {
bearerTokenFormat = *options.BearerTokenFormat
if cfg.Options != nil && cfg.Options.BearerTokenFormat != nil {
bearerTokenFormat = *cfg.Options.BearerTokenFormat
}
if policy != nil && policy.BearerTokenFormat != nil {
bearerTokenFormat = *policy.BearerTokenFormat
}
if token := r.Header.Get("X-Pomerium-IDP-Access-Token"); token != "" {
if token := r.Header.Get(httputil.HeaderPomeriumIDPAccessToken); token != "" {
return token, true
}
if auth := r.Header.Get("Authorization"); auth != "" {
prefix := "Pomerium-IDP-Access-Token "
if auth := r.Header.Get(httputil.HeaderAuthorization); auth != "" {
prefix := httputil.AuthorizationTypePomeriumIDPAccessToken + " "
if strings.HasPrefix(strings.ToLower(auth), strings.ToLower(prefix)) {
return strings.TrimPrefix(auth, prefix), true
}
prefix = "Bearer Pomerium-IDP-Access-Token-"
prefix = "Bearer " + httputil.AuthorizationTypePomeriumIDPAccessToken + "-"
if strings.HasPrefix(strings.ToLower(auth), strings.ToLower(prefix)) {
return strings.TrimPrefix(auth, prefix), true
}
prefix = "Bearer "
if strings.HasPrefix(strings.ToLower(auth), strings.ToLower(prefix)) && bearerTokenFormat == BearerTokenFormatIDPAccessToken {
if strings.HasPrefix(strings.ToLower(auth), strings.ToLower(prefix)) &&
bearerTokenFormat == BearerTokenFormatIDPAccessToken {
return strings.TrimPrefix(auth, prefix), true
}
}
@ -200,32 +384,33 @@ func (options *Options) GetIncomingIDPAccessTokenForPolicy(policy *Policy, r *ht
}
// GetIncomingIDPAccessTokenForPolicy returns the raw idp identity token from a request if there is one.
func (options *Options) GetIncomingIDPIdentityTokenForPolicy(policy *Policy, r *http.Request) (rawIdentityToken string, ok bool) {
func (cfg *Config) GetIncomingIDPIdentityTokenForPolicy(policy *Policy, r *http.Request) (rawIdentityToken string, ok bool) {
bearerTokenFormat := BearerTokenFormatDefault
if options != nil && options.BearerTokenFormat != nil {
bearerTokenFormat = *options.BearerTokenFormat
if cfg.Options != nil && cfg.Options.BearerTokenFormat != nil {
bearerTokenFormat = *cfg.Options.BearerTokenFormat
}
if policy != nil && policy.BearerTokenFormat != nil {
bearerTokenFormat = *policy.BearerTokenFormat
}
if token := r.Header.Get("X-Pomerium-IDP-Identity-Token"); token != "" {
if token := r.Header.Get(httputil.HeaderPomeriumIDPIdentityToken); token != "" {
return token, true
}
if auth := r.Header.Get("Authorization"); auth != "" {
prefix := "Pomerium-IDP-Identity-Token "
if auth := r.Header.Get(httputil.HeaderAuthorization); auth != "" {
prefix := httputil.AuthorizationTypePomeriumIDPIdentityToken + " "
if strings.HasPrefix(strings.ToLower(auth), strings.ToLower(prefix)) {
return strings.TrimPrefix(auth, prefix), true
}
prefix = "Bearer Pomerium-IDP-Identity-Token-"
prefix = "Bearer " + httputil.AuthorizationTypePomeriumIDPIdentityToken + "-"
if strings.HasPrefix(strings.ToLower(auth), strings.ToLower(prefix)) {
return strings.TrimPrefix(auth, prefix), true
}
prefix = "Bearer "
if strings.HasPrefix(strings.ToLower(auth), strings.ToLower(prefix)) && bearerTokenFormat == BearerTokenFormatIDPIdentityToken {
if strings.HasPrefix(strings.ToLower(auth), strings.ToLower(prefix)) &&
bearerTokenFormat == BearerTokenFormatIDPIdentityToken {
return strings.TrimPrefix(auth, prefix), true
}
}

View file

@ -1,7 +1,12 @@
package httputil
// AuthorizationTypePomerium is for Authorization: Pomerium JWT... headers
const AuthorizationTypePomerium = "Pomerium"
// Pomerium authorization types
const (
// AuthorizationTypePomerium is for Authorization: Pomerium JWT... headers
AuthorizationTypePomerium = "Pomerium"
AuthorizationTypePomeriumIDPAccessToken = "Pomerium-IDP-Access-Token" //nolint: gosec
AuthorizationTypePomeriumIDPIdentityToken = "Pomerium-IDP-Identity-Token" //nolint: gosec
)
// Standard headers
const (
@ -16,7 +21,9 @@ const (
// HeaderPomeriumAuthorization is the header key for a pomerium authorization JWT. It
// can be used in place of the standard authorization header if that header is being
// used by upstream applications.
HeaderPomeriumAuthorization = "x-pomerium-authorization"
HeaderPomeriumAuthorization = "x-pomerium-authorization"
HeaderPomeriumIDPAccessToken = "x-pomerium-idp-access-token" //nolint: gosec
HeaderPomeriumIDPIdentityToken = "x-pomerium-idp-identity-token" //nolint: gosec
// HeaderPomeriumResponse is set when pomerium itself creates a response,
// as opposed to the upstream application and can be used to distinguish
// between an application error, and a pomerium related error when debugging.

160
internal/jwtutil/jwtutil.go Normal file
View file

@ -0,0 +1,160 @@
// Package jwtutil contains functions for working with JWTs.
package jwtutil
import (
"bytes"
"encoding/json"
"fmt"
"reflect"
"time"
)
// Claims represent claims in a JWT.
type Claims map[string]any
// UnmarshalJSON implements a custom unmarshaller for claims data.
func (claims *Claims) UnmarshalJSON(raw []byte) error {
dst := map[string]any{}
dec := json.NewDecoder(bytes.NewReader(raw))
dec.UseNumber()
err := dec.Decode(&dst)
if err != nil {
return err
}
*claims = Claims(dst)
return nil
}
// registered claims
// GetIssuer gets the iss claim.
func (claims Claims) GetIssuer() (issuer string, ok bool) {
return claims.GetString("iss")
}
// GetSubject gets the sub claim.
func (claims Claims) GetSubject() (subject string, ok bool) {
return claims.GetString("sub")
}
// GetAudience gets the aud claim.
func (claims Claims) GetAudience() (audiences []string, ok bool) {
return claims.GetStringSlice("aud")
}
// GetExpirationTime gets the exp claim.
func (claims Claims) GetExpirationTime() (expirationTime time.Time, ok bool) {
return claims.GetNumericDate("exp")
}
// GetNotBefore gets the nbf claim.
func (claims Claims) GetNotBefore() (notBefore time.Time, ok bool) {
return claims.GetNumericDate("nbf")
}
// GetIssuedAt gets the iat claim.
func (claims Claims) GetIssuedAt() (issuedAt time.Time, ok bool) {
return claims.GetNumericDate("iat")
}
// GetJWTID gets the jti claim.
func (claims Claims) GetJWTID() (jwtID string, ok bool) {
return claims.GetString("jti")
}
// custom claims
// GetUserID returns the oid or sub claim.
func (claims Claims) GetUserID() (userID string, ok bool) {
if oid, ok := claims.GetString("oid"); ok {
return oid, true
}
if sub, ok := claims.GetSubject(); ok {
return sub, true
}
return "", false
}
// GetNumericDate returns the claim as a numeric date.
func (claims Claims) GetNumericDate(name string) (tm time.Time, ok bool) {
if claims == nil {
return tm, false
}
raw, ok := claims[name]
if !ok {
return tm, false
}
switch v := raw.(type) {
case float64:
return time.Unix(int64(v), 0), true
case int64:
return time.Unix(v, 0), true
case json.Number:
i, err := v.Int64()
if err != nil {
if f, err := v.Float64(); err == nil {
i = int64(f)
}
}
if err != nil {
return tm, false
}
return time.Unix(i, 0), true
}
return tm, false
}
// GetString returns the claim as a string.
func (claims Claims) GetString(name string) (value string, ok bool) {
if claims == nil {
return value, false
}
raw, ok := claims[name]
if !ok {
return value, false
}
return toString(raw), true
}
// GetStringSlice returns the claim as a slice of strings.
func (claims Claims) GetStringSlice(name string) (values []string, ok bool) {
if claims == nil {
return nil, false
}
raw, ok := claims[name]
if !ok {
return nil, false
}
return toStringSlice(raw), true
}
func toString(data any) string {
switch v := data.(type) {
case string:
return v
}
return fmt.Sprint(data)
}
func toStringSlice(obj any) []string {
v := reflect.ValueOf(obj)
switch v.Kind() {
case reflect.Slice:
vs := make([]string, v.Len())
for i := 0; i < v.Len(); i++ {
vs[i] = toString(v.Index(i).Interface())
}
return vs
}
return []string{toString(obj)}
}

View file

@ -0,0 +1,109 @@
// Package authenticateapi has the types and methods for the authenticate api.
package authenticateapi
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"github.com/pomerium/pomerium/internal/jwtutil"
)
// VerifyAccessTokenRequest is used to verify access tokens.
type VerifyAccessTokenRequest struct {
AccessToken string `json:"accessToken"`
IdentityProviderID string `json:"identityProviderId,omitempty"`
}
// VerifyIdentityTokenRequest is used to verify identity tokens.
type VerifyIdentityTokenRequest struct {
IdentityToken string `json:"identityToken"`
IdentityProviderID string `json:"identityProviderId,omitempty"`
}
// VerifyTokenResponse is the result of verifying an access or identity token.
type VerifyTokenResponse struct {
Valid bool `json:"valid"`
Claims jwtutil.Claims `json:"claims,omitempty"`
}
// An API is an api client for the authenticate service.
type API struct {
authenticateURL *url.URL
transport http.RoundTripper
}
// New creates a new API client.
func New(
authenticateURL *url.URL,
transport http.RoundTripper,
) *API {
return &API{
authenticateURL: authenticateURL,
transport: transport,
}
}
// VerifyAccessToken verifies an access token.
func (api *API) VerifyAccessToken(ctx context.Context, request *VerifyAccessTokenRequest) (*VerifyTokenResponse, error) {
var response VerifyTokenResponse
err := api.call(ctx, "verify-access-token", request, &response)
if err != nil {
return nil, err
}
return &response, nil
}
// VerifyIdentityToken verifies an identity token.
func (api *API) VerifyIdentityToken(ctx context.Context, request *VerifyIdentityTokenRequest) (*VerifyTokenResponse, error) {
var response VerifyTokenResponse
err := api.call(ctx, "verify-identity-token", request, &response)
if err != nil {
return nil, err
}
return &response, nil
}
func (api *API) call(
ctx context.Context,
endpoint string,
request, response any,
) error {
u := api.authenticateURL.ResolveReference(&url.URL{
Path: "/.pomerium/" + endpoint,
})
body, err := json.Marshal(request)
if err != nil {
return fmt.Errorf("error marshaling %s http request: %w", endpoint, err)
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, u.String(), bytes.NewReader(body))
if err != nil {
return fmt.Errorf("error creating %s http request: %w", endpoint, err)
}
res, err := (&http.Client{
Transport: api.transport,
}).Do(req)
if err != nil {
return fmt.Errorf("error executing %s http request: %w", endpoint, err)
}
defer res.Body.Close()
body, err = io.ReadAll(res.Body)
if err != nil {
return fmt.Errorf("error reading %s http response: %w", endpoint, err)
}
err = json.Unmarshal(body, &response)
if err != nil {
return fmt.Errorf("error reading %s http response (body=%s): %w", endpoint, body, err)
}
return nil
}

View file

@ -17,6 +17,7 @@ import (
"github.com/google/uuid"
"golang.org/x/oauth2"
"github.com/pomerium/pomerium/internal/jwtutil"
"github.com/pomerium/pomerium/pkg/identity/oauth"
pom_oidc "github.com/pomerium/pomerium/pkg/identity/oidc"
)
@ -95,7 +96,7 @@ func (p *Provider) VerifyAccessToken(ctx context.Context, rawAccessToken string)
return nil, fmt.Errorf("error verifying access token: %w", err)
}
claims = map[string]any{}
claims = jwtutil.Claims(map[string]any{})
err = token.Claims(&claims)
if err != nil {
return nil, fmt.Errorf("error unmarshaling access token claims: %w", err)