authorize: support authenticating with idp tokens (#5484)

* identity: add support for verifying access and identity tokens

* allow overriding with policy option

* authenticate: add verify endpoints

* wip

* implement session creation

* add verify test

* implement idp token login

* fix tests

* add pr permission

* make session ids route-specific

* rename method

* add test

* add access token test

* test for newUserFromIDPClaims

* more tests

* make the session id per-idp

* use type for

* add test

* remove nil checks
This commit is contained in:
Caleb Doxsey 2025-02-18 13:02:06 -07:00 committed by GitHub
parent 6e22b7a19a
commit b9fd926618
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
36 changed files with 2791 additions and 885 deletions

View file

@ -3,6 +3,7 @@ name: Benchmark
permissions:
contents: write
deployments: write
pull-requests: write
on:
push:

View file

@ -43,6 +43,16 @@ func (a *Authenticate) Handler() http.Handler {
func (a *Authenticate) Mount(r *mux.Router) {
r.StrictSlash(true)
r.Use(middleware.SetHeaders(httputil.HeadersContentSecurityPolicy))
// disable csrf checking for these endpoints
r.Use(func(h http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/.pomerium/verify-access-token" ||
r.URL.Path == "/.pomerium/verify-identity-token" {
r = csrf.UnsafeSkipCheck(r)
}
h.ServeHTTP(w, r)
})
})
r.Use(func(h http.Handler) http.Handler {
options := a.options.Load()
state := a.state.Load()
@ -95,6 +105,8 @@ func (a *Authenticate) mountDashboard(r *mux.Router) {
// routes that don't need a session:
sr.Path("/sign_out").Handler(httputil.HandlerFunc(a.SignOut))
sr.Path("/signed_out").Handler(httputil.HandlerFunc(a.signedOut)).Methods(http.MethodGet)
sr.Path("/verify-access-token").Handler(httputil.HandlerFunc(a.verifyAccessToken)).Methods(http.MethodPost)
sr.Path("/verify-identity-token").Handler(httputil.HandlerFunc(a.verifyIdentityToken)).Methods(http.MethodPost)
// routes that need a session:
sr = sr.NewRoute().Subrouter()

View file

@ -0,0 +1,76 @@
package authenticate
import (
"encoding/json"
"net/http"
"github.com/pomerium/pomerium/internal/httputil"
"github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/pkg/authenticateapi"
)
func (a *Authenticate) verifyAccessToken(w http.ResponseWriter, r *http.Request) error {
var req authenticateapi.VerifyAccessTokenRequest
err := json.NewDecoder(r.Body).Decode(&req)
if err != nil {
return httputil.NewError(http.StatusBadRequest, err)
}
authenticator, err := a.cfg.getIdentityProvider(r.Context(), a.tracerProvider, a.options.Load(), req.IdentityProviderID)
if err != nil {
return err
}
var res authenticateapi.VerifyTokenResponse
claims, err := authenticator.VerifyAccessToken(r.Context(), req.AccessToken)
if err == nil {
res.Valid = true
res.Claims = claims
} else {
res.Valid = false
log.Ctx(r.Context()).Info().
Err(err).
Str("idp", authenticator.Name()).
Msg("access token failed verification")
}
err = json.NewEncoder(w).Encode(&res)
if err != nil {
return err
}
return nil
}
func (a *Authenticate) verifyIdentityToken(w http.ResponseWriter, r *http.Request) error {
var req authenticateapi.VerifyIdentityTokenRequest
err := json.NewDecoder(r.Body).Decode(&req)
if err != nil {
return httputil.NewError(http.StatusBadRequest, err)
}
authenticator, err := a.cfg.getIdentityProvider(r.Context(), a.tracerProvider, a.options.Load(), req.IdentityProviderID)
if err != nil {
return err
}
var res authenticateapi.VerifyTokenResponse
claims, err := authenticator.VerifyIdentityToken(r.Context(), req.IdentityToken)
if err == nil {
res.Valid = true
res.Claims = claims
} else {
res.Valid = false
log.Ctx(r.Context()).Info().
Err(err).
Str("idp", authenticator.Name()).
Msg("identity token failed verification")
}
err = json.NewEncoder(w).Encode(&res)
if err != nil {
return err
}
return nil
}

View file

@ -0,0 +1,45 @@
package authenticate_test
import (
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/pomerium/pomerium/authenticate"
"github.com/pomerium/pomerium/config"
"github.com/pomerium/pomerium/internal/testutil"
"github.com/pomerium/pomerium/pkg/cryptutil"
)
func TestVerifyAccessToken(t *testing.T) {
t.Parallel()
ctx := testutil.GetContext(t, time.Minute)
a, err := authenticate.New(ctx, &config.Config{
Options: &config.Options{
CookieSecret: cryptutil.NewBase64Key(),
SharedKey: cryptutil.NewBase64Key(),
AuthenticateCallbackPath: "/oauth2/callback",
AuthenticateURLString: "https://authenticate.example.com",
Provider: "oidc",
ProviderURL: "http://oidc.example.com",
},
})
require.NoError(t, err)
w := httptest.NewRecorder()
r, err := http.NewRequestWithContext(ctx, http.MethodPost, "https://authenticate.example.com/.pomerium/verify-access-token",
strings.NewReader(`{"accessToken":"ACCESS TOKEN"}`))
require.NoError(t, err)
a.Handler().ServeHTTP(w, r)
assert.Equal(t, 200, w.Code)
assert.JSONEq(t, `{"valid":false}`, w.Body.String())
}

View file

@ -3,11 +3,12 @@ package authenticate
import (
"context"
oteltrace "go.opentelemetry.io/otel/trace"
"github.com/pomerium/pomerium/config"
"github.com/pomerium/pomerium/internal/urlutil"
"github.com/pomerium/pomerium/pkg/identity"
"github.com/pomerium/pomerium/pkg/identity/oauth"
oteltrace "go.opentelemetry.io/otel/trace"
)
func defaultGetIdentityProvider(ctx context.Context, tracerProvider oteltrace.TracerProvider, options *config.Options, idpID string) (identity.Authenticator, error) {
@ -26,7 +27,8 @@ func defaultGetIdentityProvider(ctx context.Context, tracerProvider oteltrace.Tr
if err != nil {
return nil, err
}
return identity.NewAuthenticator(ctx, tracerProvider, oauth.Options{
o := oauth.Options{
RedirectURL: redirectURL,
ProviderName: idp.GetType(),
ProviderURL: idp.GetUrl(),
@ -34,5 +36,9 @@ func defaultGetIdentityProvider(ctx context.Context, tracerProvider oteltrace.Tr
ClientSecret: idp.GetClientSecret(),
Scopes: idp.GetScopes(),
AuthCodeOptions: idp.GetRequestParams(),
})
}
if v := idp.GetAccessTokenAllowedAudiences(); v != nil {
o.AccessTokenAllowedAudiences = &v.Values
}
return identity.NewAuthenticator(ctx, tracerProvider, o)
}

View file

@ -29,7 +29,7 @@ import (
type Authorize struct {
state *atomicutil.Value[*authorizeState]
store *store.Store
currentOptions *atomicutil.Value[*config.Options]
currentConfig *atomicutil.Value[*config.Config]
accessTracker *AccessTracker
globalCache storage.Cache
groupsCacheWarmer *cacheWarmer
@ -43,7 +43,7 @@ func New(ctx context.Context, cfg *config.Config) (*Authorize, error) {
tracerProvider := trace.NewTracerProvider(ctx, "Authorize")
tracer := tracerProvider.Tracer(trace.PomeriumCoreTracer)
a := &Authorize{
currentOptions: config.NewAtomicOptions(),
currentConfig: atomicutil.NewValue(&config.Config{Options: new(config.Options)}),
store: store.New(),
globalCache: storage.NewGlobalCache(time.Minute),
tracerProvider: tracerProvider,
@ -155,7 +155,7 @@ func newPolicyEvaluator(
// OnConfigChange updates internal structures based on config.Options
func (a *Authorize) OnConfigChange(ctx context.Context, cfg *config.Config) {
currentState := a.state.Load()
a.currentOptions.Store(cfg.Options)
a.currentConfig.Store(cfg)
if newState, err := newAuthorizeStateFromConfig(ctx, a.tracerProvider, cfg, a.store, currentState.evaluator); err != nil {
log.Ctx(ctx).Error().Err(err).Msg("authorize: error updating state")
} else {

View file

@ -186,7 +186,7 @@ func (a *Authorize) deniedResponse(
Err: errors.New(reason),
DebugURL: debugEndpoint,
RequestID: requestid.FromContext(ctx),
BrandingOptions: a.currentOptions.Load().BrandingOptions,
BrandingOptions: a.currentConfig.Load().Options.BrandingOptions,
}
httpErr.ErrorResponse(ctx, w, r)
@ -213,7 +213,7 @@ func (a *Authorize) requireLoginResponse(
in *envoy_service_auth_v3.CheckRequest,
request *evaluator.Request,
) (*envoy_service_auth_v3.CheckResponse, error) {
options := a.currentOptions.Load()
options := a.currentConfig.Load().Options
state := a.state.Load()
if !a.shouldRedirect(in) {
@ -251,7 +251,7 @@ func (a *Authorize) requireWebAuthnResponse(
request *evaluator.Request,
result *evaluator.Result,
) (*envoy_service_auth_v3.CheckResponse, error) {
opts := a.currentOptions.Load()
opts := a.currentConfig.Load().Options
state := a.state.Load()
// always assume https scheme
@ -327,7 +327,7 @@ func toEnvoyHeaders(headers http.Header) []*envoy_config_core_v3.HeaderValueOpti
// userInfoEndpointURL returns the user info endpoint url which can be used to debug the user's
// session that lives on the authenticate service.
func (a *Authorize) userInfoEndpointURL(in *envoy_service_auth_v3.CheckRequest) (*url.URL, error) {
opts := a.currentOptions.Load()
opts := a.currentConfig.Load().Options
authenticateURL, err := opts.GetAuthenticateURL()
if err != nil {
return nil, err

View file

@ -127,8 +127,9 @@ func TestAuthorize_okResponse(t *testing.T) {
}},
JWTClaimsHeaders: config.NewJWTClaimHeaders("email"),
}
a := &Authorize{currentOptions: config.NewAtomicOptions(), state: atomicutil.NewValue(new(authorizeState))}
a.currentOptions.Store(opt)
a := &Authorize{currentConfig: atomicutil.NewValue(&config.Config{
Options: opt,
}), state: atomicutil.NewValue(new(authorizeState))}
a.store = store.New()
pe, err := newPolicyEvaluator(context.Background(), opt, a.store, nil)
require.NoError(t, err)
@ -183,15 +184,16 @@ func TestAuthorize_okResponse(t *testing.T) {
func TestAuthorize_deniedResponse(t *testing.T) {
t.Parallel()
a := &Authorize{currentOptions: config.NewAtomicOptions(), state: atomicutil.NewValue(new(authorizeState))}
a.currentOptions.Store(&config.Options{
a := &Authorize{currentConfig: atomicutil.NewValue(&config.Config{
Options: &config.Options{
Policies: []config.Policy{{
From: "https://example.com",
SubPolicies: []config.SubPolicy{{
Rego: []string{"allow = true"},
}},
}},
})
},
}), state: atomicutil.NewValue(new(authorizeState))}
t.Run("json", func(t *testing.T) {
t.Parallel()

View file

@ -3,15 +3,17 @@ package authorize
import (
"context"
"google.golang.org/grpc"
"github.com/pomerium/pomerium/pkg/grpc/databroker"
"github.com/pomerium/pomerium/pkg/grpc/session"
"github.com/pomerium/pomerium/pkg/grpc/user"
"github.com/pomerium/pomerium/pkg/grpcutil"
"github.com/pomerium/pomerium/pkg/storage"
"google.golang.org/grpc"
)
type sessionOrServiceAccount interface {
GetId() string
GetUserId() string
Validate() error
}

View file

@ -3,6 +3,7 @@ package authorize
import (
"context"
"encoding/pem"
"errors"
"io"
"net/http"
"net/url"
@ -21,6 +22,7 @@ import (
"github.com/pomerium/pomerium/internal/sessions"
"github.com/pomerium/pomerium/internal/urlutil"
"github.com/pomerium/pomerium/pkg/contextutil"
"github.com/pomerium/pomerium/pkg/grpc/databroker"
"github.com/pomerium/pomerium/pkg/grpc/user"
"github.com/pomerium/pomerium/pkg/storage"
"github.com/pomerium/pomerium/pkg/telemetry/requestid"
@ -44,31 +46,25 @@ func (a *Authorize) Check(ctx context.Context, in *envoy_service_auth_v3.CheckRe
requestID := requestid.FromHTTPHeader(hreq.Header)
ctx = requestid.WithValue(ctx, requestID)
sessionState, _ := state.sessionStore.LoadSessionStateAndCheckIDP(hreq)
var s sessionOrServiceAccount
var u *user.User
var err error
if sessionState != nil {
s, err = a.getDataBrokerSessionOrServiceAccount(ctx, sessionState.ID, sessionState.DatabrokerRecordVersion)
if status.Code(err) == codes.Unavailable {
log.Ctx(ctx).Debug().Str("request-id", requestID).Err(err).Msg("temporary error checking authorization: data broker unavailable")
return nil, err
} else if err != nil {
log.Ctx(ctx).Info().Err(err).Str("request-id", requestID).Msg("clearing session due to missing or invalid session or service account")
sessionState = nil
}
}
if sessionState != nil && s != nil {
u, _ = a.getDataBrokerUser(ctx, s.GetUserId()) // ignore any missing user error
}
req, err := a.getEvaluatorRequestFromCheckRequest(ctx, in, sessionState)
req, err := a.getEvaluatorRequestFromCheckRequest(ctx, in)
if err != nil {
log.Ctx(ctx).Error().Err(err).Str("request-id", requestID).Msg("error building evaluator request")
return nil, err
}
// load the session
s, err := a.loadSession(ctx, hreq, req)
if err != nil {
return nil, err
}
// if there's a session or service account, load the user
var u *user.User
if s != nil {
req.Session.ID = s.GetId()
u, _ = a.getDataBrokerUser(ctx, s.GetUserId()) // ignore any missing user error
}
res, err := state.evaluator.Evaluate(ctx, req)
if err != nil {
log.Ctx(ctx).Error().Err(err).Str("request-id", requestID).Msg("error during OPA evaluation")
@ -88,10 +84,65 @@ func (a *Authorize) Check(ctx context.Context, in *envoy_service_auth_v3.CheckRe
return resp, err
}
func (a *Authorize) loadSession(
ctx context.Context,
hreq *http.Request,
req *evaluator.Request,
) (s sessionOrServiceAccount, err error) {
requestID := requestid.FromHTTPHeader(hreq.Header)
// attempt to create a session from an incoming idp token
s, err = config.NewIncomingIDPTokenSessionCreator(
func(ctx context.Context, recordType, recordID string) (*databroker.Record, error) {
return getDataBrokerRecord(ctx, recordType, recordID, 0)
},
func(ctx context.Context, records []*databroker.Record) error {
_, err := a.state.Load().dataBrokerClient.Put(ctx, &databroker.PutRequest{
Records: records,
})
if err != nil {
return err
}
// invalidate cache
for _, record := range records {
storage.GetQuerier(ctx).InvalidateCache(ctx, &databroker.QueryRequest{
Type: record.GetType(),
Query: record.GetId(),
Limit: 1,
})
}
return nil
},
).CreateSession(ctx, a.currentConfig.Load(), req.Policy, hreq)
if err == nil {
return s, nil
} else if !errors.Is(err, sessions.ErrNoSessionFound) {
log.Ctx(ctx).Info().
Str("request-id", requestID).
Err(err).
Msg("error creating session for incoming idp token")
}
sessionState, _ := a.state.Load().sessionStore.LoadSessionStateAndCheckIDP(hreq)
if sessionState == nil {
return nil, nil
}
s, err = a.getDataBrokerSessionOrServiceAccount(ctx, sessionState.ID, sessionState.DatabrokerRecordVersion)
if status.Code(err) == codes.Unavailable {
log.Ctx(ctx).Debug().Str("request-id", requestID).Err(err).Msg("temporary error checking authorization: data broker unavailable")
return nil, err
} else if err != nil {
log.Ctx(ctx).Info().Err(err).Str("request-id", requestID).Msg("clearing session due to missing or invalid session or service account")
return nil, nil
}
return s, nil
}
func (a *Authorize) getEvaluatorRequestFromCheckRequest(
ctx context.Context,
in *envoy_service_auth_v3.CheckRequest,
sessionState *sessions.State,
) (*evaluator.Request, error) {
requestURL := getCheckRequestURL(in)
attrs := in.GetAttributes()
@ -106,17 +157,12 @@ func (a *Authorize) getEvaluatorRequestFromCheckRequest(
attrs.GetSource().GetAddress().GetSocketAddress().GetAddress(),
),
}
if sessionState != nil {
req.Session = evaluator.RequestSession{
ID: sessionState.ID,
}
}
req.Policy = a.getMatchingPolicy(envoyconfig.ExtAuthzContextExtensionsRouteID(attrs.GetContextExtensions()))
return req, nil
}
func (a *Authorize) getMatchingPolicy(routeID uint64) *config.Policy {
options := a.currentOptions.Load()
options := a.currentConfig.Load().Options
for p := range options.GetAllPolicies() {
id, _ := p.RouteID()

View file

@ -18,7 +18,6 @@ import (
"github.com/pomerium/pomerium/authorize/evaluator"
"github.com/pomerium/pomerium/config"
"github.com/pomerium/pomerium/internal/atomicutil"
"github.com/pomerium/pomerium/internal/sessions"
"github.com/pomerium/pomerium/internal/testutil"
"github.com/pomerium/pomerium/pkg/grpc/databroker"
"github.com/pomerium/pomerium/pkg/storage"
@ -49,15 +48,16 @@ yE+vPxsiUkvQHdO2fojCkY8jg70jxM+gu59tPDNbw3Uh/2Ij310FgTHsnGQMyA==
-----END CERTIFICATE-----`
func Test_getEvaluatorRequest(t *testing.T) {
a := &Authorize{currentOptions: config.NewAtomicOptions(), state: atomicutil.NewValue(new(authorizeState))}
a.currentOptions.Store(&config.Options{
a := &Authorize{currentConfig: atomicutil.NewValue(&config.Config{
Options: &config.Options{
Policies: []config.Policy{{
From: "https://example.com",
SubPolicies: []config.SubPolicy{{
Rego: []string{"allow = true"},
}},
}},
})
},
}), state: atomicutil.NewValue(new(authorizeState))}
actual, err := a.getEvaluatorRequestFromCheckRequest(context.Background(),
&envoy_service_auth_v3.CheckRequest{
@ -88,16 +88,10 @@ func Test_getEvaluatorRequest(t *testing.T) {
},
},
},
&sessions.State{
ID: "SESSION_ID",
},
)
require.NoError(t, err)
expect := &evaluator.Request{
Policy: &a.currentOptions.Load().Policies[0],
Session: evaluator.RequestSession{
ID: "SESSION_ID",
},
Policy: &a.currentConfig.Load().Options.Policies[0],
HTTP: evaluator.NewRequestHTTP(
http.MethodGet,
mustParseURL("http://example.com/some/path?qs=1"),
@ -117,15 +111,16 @@ func Test_getEvaluatorRequest(t *testing.T) {
}
func Test_getEvaluatorRequestWithPortInHostHeader(t *testing.T) {
a := &Authorize{currentOptions: config.NewAtomicOptions(), state: atomicutil.NewValue(new(authorizeState))}
a.currentOptions.Store(&config.Options{
a := &Authorize{currentConfig: atomicutil.NewValue(&config.Config{
Options: &config.Options{
Policies: []config.Policy{{
From: "https://example.com",
SubPolicies: []config.SubPolicy{{
Rego: []string{"allow = true"},
}},
}},
})
},
}), state: atomicutil.NewValue(new(authorizeState))}
actual, err := a.getEvaluatorRequestFromCheckRequest(context.Background(),
&envoy_service_auth_v3.CheckRequest{
@ -145,10 +140,10 @@ func Test_getEvaluatorRequestWithPortInHostHeader(t *testing.T) {
},
},
},
}, nil)
})
require.NoError(t, err)
expect := &evaluator.Request{
Policy: &a.currentOptions.Load().Policies[0],
Policy: &a.currentConfig.Load().Options.Policies[0],
Session: evaluator.RequestSession{},
HTTP: evaluator.NewRequestHTTP(
http.MethodGet,

View file

@ -31,7 +31,7 @@ func (a *Authorize) logAuthorizeCheck(
impersonateDetails := a.getImpersonateDetails(ctx, s)
evt := log.Ctx(ctx).Info().Str("service", "authorize")
fields := a.currentOptions.Load().GetAuthorizeLogFields()
fields := a.currentConfig.Load().Options.GetAuthorizeLogFields()
for _, field := range fields {
evt = populateLogEvent(ctx, field, evt, in, s, u, hdrs, impersonateDetails, res)
}

View file

@ -0,0 +1,98 @@
package config
import (
"encoding/json"
"fmt"
"reflect"
"strings"
"github.com/mitchellh/mapstructure"
configpb "github.com/pomerium/pomerium/pkg/grpc/config"
)
// BearerTokenFormat specifies how bearer tokens are interepreted by Pomerium.
type BearerTokenFormat string
// Bearer Token Formats
const (
BearerTokenFormatUnknown BearerTokenFormat = ""
BearerTokenFormatDefault BearerTokenFormat = "default"
BearerTokenFormatIDPAccessToken BearerTokenFormat = "idp_access_token"
BearerTokenFormatIDPIdentityToken BearerTokenFormat = "idp_identity_token"
)
// ParseBearerTokenFormat parses the BearerTokenFormat.
func ParseBearerTokenFormat(raw string) (BearerTokenFormat, error) {
switch BearerTokenFormat(strings.TrimSpace(strings.ToLower(raw))) {
case BearerTokenFormatUnknown:
return BearerTokenFormatUnknown, nil
case BearerTokenFormatDefault:
return BearerTokenFormatDefault, nil
case BearerTokenFormatIDPAccessToken:
return BearerTokenFormatIDPAccessToken, nil
case BearerTokenFormatIDPIdentityToken:
return BearerTokenFormatIDPIdentityToken, nil
}
return BearerTokenFormatUnknown, fmt.Errorf("invalid bearer token format: %s", raw)
}
func BearerTokenFormatFromPB(pbBearerTokenFormat *configpb.BearerTokenFormat) *BearerTokenFormat {
if pbBearerTokenFormat == nil {
return nil
}
bearerTokenFormat := new(BearerTokenFormat)
*bearerTokenFormat = BearerTokenFormatDefault
switch *pbBearerTokenFormat {
case configpb.BearerTokenFormat_BEARER_TOKEN_FORMAT_UNKNOWN:
*bearerTokenFormat = BearerTokenFormatUnknown
case configpb.BearerTokenFormat_BEARER_TOKEN_FORMAT_DEFAULT:
*bearerTokenFormat = BearerTokenFormatDefault
case configpb.BearerTokenFormat_BEARER_TOKEN_FORMAT_IDP_ACCESS_TOKEN:
*bearerTokenFormat = BearerTokenFormatIDPAccessToken
case configpb.BearerTokenFormat_BEARER_TOKEN_FORMAT_IDP_IDENTITY_TOKEN:
*bearerTokenFormat = BearerTokenFormatIDPIdentityToken
}
return bearerTokenFormat
}
// ToEnvoy converts the bearer token format into a protobuf enum.
func (bearerTokenFormat *BearerTokenFormat) ToPB() *configpb.BearerTokenFormat {
if bearerTokenFormat == nil {
return nil
}
switch *bearerTokenFormat {
case BearerTokenFormatUnknown:
return configpb.BearerTokenFormat_BEARER_TOKEN_FORMAT_UNKNOWN.Enum()
case BearerTokenFormatDefault:
return configpb.BearerTokenFormat_BEARER_TOKEN_FORMAT_DEFAULT.Enum()
case BearerTokenFormatIDPAccessToken:
return configpb.BearerTokenFormat_BEARER_TOKEN_FORMAT_IDP_ACCESS_TOKEN.Enum()
case BearerTokenFormatIDPIdentityToken:
return configpb.BearerTokenFormat_BEARER_TOKEN_FORMAT_IDP_IDENTITY_TOKEN.Enum()
default:
panic(fmt.Sprintf("unknown bearer token format: %v", bearerTokenFormat))
}
}
func decodeBearerTokenFormatHookFunc() mapstructure.DecodeHookFunc {
return func(_, t reflect.Type, data any) (any, error) {
if t != reflect.TypeFor[BearerTokenFormat]() {
return data, nil
}
bs, err := json.Marshal(data)
if err != nil {
return nil, err
}
var raw string
err = json.Unmarshal(bs, &raw)
if err != nil {
return nil, err
}
return ParseBearerTokenFormat(raw)
}
}

View file

@ -4,9 +4,10 @@ import (
"errors"
"github.com/mitchellh/mapstructure"
"github.com/pomerium/pomerium/config/otelconfig"
"github.com/spf13/viper"
"google.golang.org/protobuf/encoding/protojson"
"github.com/pomerium/pomerium/config/otelconfig"
)
const (
@ -37,6 +38,7 @@ var ViperPolicyHooks = viper.DecodeHook(mapstructure.ComposeDecodeHookFunc(
DecodePolicyBase64Hook(),
decodeNullBoolHookFunc(),
decodeJWTClaimHeadersHookFunc(),
decodeBearerTokenFormatHookFunc(),
decodeCodecTypeHookFunc(),
decodePPLPolicyHookFunc(),
decodeSANMatcherHookFunc(),

View file

@ -1,6 +1,8 @@
package config
import (
"slices"
"github.com/pomerium/pomerium/internal/urlutil"
"github.com/pomerium/pomerium/pkg/grpc/identity"
)
@ -43,6 +45,11 @@ func (o *Options) GetIdentityProviderForPolicy(policy *Policy) (*identity.Provid
Url: o.ProviderURL,
RequestParams: o.RequestParams,
}
if v := o.IDPAccessTokenAllowedAudiences; v != nil {
idp.AccessTokenAllowedAudiences = &identity.Provider_StringList{
Values: slices.Clone(*v),
}
}
if policy != nil {
if policy.IDPClientID != "" {
idp.ClientId = policy.IDPClientID
@ -50,6 +57,11 @@ func (o *Options) GetIdentityProviderForPolicy(policy *Policy) (*identity.Provid
if policy.IDPClientSecret != "" {
idp.ClientSecret = policy.IDPClientSecret
}
if v := policy.IDPAccessTokenAllowedAudiences; v != nil {
idp.AccessTokenAllowedAudiences = &identity.Provider_StringList{
Values: slices.Clone(*v),
}
}
}
idp.Id = idp.Hash()
return idp, nil

View file

@ -15,6 +15,7 @@ import (
"os"
"path/filepath"
"reflect"
"slices"
"strings"
"time"
@ -158,6 +159,7 @@ type Options struct {
Provider string `mapstructure:"idp_provider" yaml:"idp_provider,omitempty"`
ProviderURL string `mapstructure:"idp_provider_url" yaml:"idp_provider_url,omitempty"`
Scopes []string `mapstructure:"idp_scopes" yaml:"idp_scopes,omitempty"`
IDPAccessTokenAllowedAudiences *[]string `mapstructure:"idp_access_token_allowed_audiences" yaml:"idp_access_token_allowed_audiences,omitempty"`
// RequestParams are custom request params added to the signin request as
// part of an Oauth2 code flow.
@ -194,6 +196,13 @@ type Options struct {
// List of JWT claims to insert as x-pomerium-claim-* headers on proxied requests
JWTClaimsHeaders JWTClaimHeaders `mapstructure:"jwt_claims_headers" yaml:"jwt_claims_headers,omitempty"`
// BearerTokenFormat indicates how authorization bearer tokens are interepreted. Possible values:
// - "default": Only Bearer tokens prefixed with Pomerium- will be interpreted by Pomerium.
// - "idp_access_token": The Bearer token will be interpreted as an IdP access token.
// - "idp_identity_token": The Bearer token will be interpreted as an IdP identity token.
// When unset "default" will be used.
BearerTokenFormat *BearerTokenFormat `mapstructure:"bearer_token_format" yaml:"bearer_token_format,omitempty"`
// Allowlist of group names/IDs to include in the Pomerium JWT.
JWTGroupsFilter JWTGroupsFilter
@ -1487,6 +1496,12 @@ func (o *Options) ApplySettings(ctx context.Context, certsIndex *cryptutil.Certi
set(&o.ProviderURL, settings.IdpProviderUrl)
setSlice(&o.Scopes, settings.Scopes)
setMap(&o.RequestParams, settings.RequestParams)
if settings.IdpAccessTokenAllowedAudiences != nil {
values := slices.Clone(settings.IdpAccessTokenAllowedAudiences.Values)
o.IDPAccessTokenAllowedAudiences = &values
} else {
o.IDPAccessTokenAllowedAudiences = nil
}
setSlice(&o.AuthorizeURLStrings, settings.AuthorizeServiceUrls)
set(&o.AuthorizeInternalURLString, settings.AuthorizeInternalServiceUrl)
set(&o.OverrideCertificateName, settings.OverrideCertificateName)
@ -1495,6 +1510,7 @@ func (o *Options) ApplySettings(ctx context.Context, certsIndex *cryptutil.Certi
set(&o.SigningKey, settings.SigningKey)
setMap(&o.SetResponseHeaders, settings.SetResponseHeaders)
setMap(&o.JWTClaimsHeaders, settings.JwtClaimsHeaders)
o.BearerTokenFormat = BearerTokenFormatFromPB(settings.BearerTokenFormat)
if len(settings.JwtGroupsFilter) > 0 {
o.JWTGroupsFilter = NewJWTGroupsFilter(settings.JwtGroupsFilter)
}
@ -1591,6 +1607,13 @@ func (o *Options) ToProto() *config.Config {
copySrcToOptionalDest(&settings.IdpProviderUrl, &o.ProviderURL)
settings.Scopes = o.Scopes
settings.RequestParams = o.RequestParams
if o.IDPAccessTokenAllowedAudiences != nil {
settings.IdpAccessTokenAllowedAudiences = &config.Settings_StringList{
Values: slices.Clone(*o.IDPAccessTokenAllowedAudiences),
}
} else {
settings.IdpAccessTokenAllowedAudiences = nil
}
settings.AuthorizeServiceUrls = o.AuthorizeURLStrings
copySrcToOptionalDest(&settings.AuthorizeInternalServiceUrl, &o.AuthorizeInternalURLString)
copySrcToOptionalDest(&settings.OverrideCertificateName, &o.OverrideCertificateName)
@ -1599,6 +1622,7 @@ func (o *Options) ToProto() *config.Config {
copySrcToOptionalDest(&settings.SigningKey, valueOrFromFileBase64(o.SigningKey, o.SigningKeyFile))
settings.SetResponseHeaders = o.SetResponseHeaders
settings.JwtClaimsHeaders = o.JWTClaimsHeaders
settings.BearerTokenFormat = o.BearerTokenFormat.ToPB()
settings.JwtGroupsFilter = o.JWTGroupsFilter.ToSlice()
copyOptionalDuration(&settings.DefaultUpstreamTimeout, o.DefaultUpstreamTimeout)
copySrcToOptionalDest(&settings.MetricsAddress, &o.MetricsAddr)

View file

@ -8,6 +8,7 @@ import (
"net/url"
"os"
"regexp"
"slices"
"sort"
"strings"
"time"
@ -167,6 +168,12 @@ type Policy struct {
// - "hostOnly" (default): Issuer strings will be the hostname of the route, with no scheme or trailing slash.
// - "uri": Issuer strings will be a complete URI, including the scheme and ending with a trailing slash.
JWTIssuerFormat string `mapstructure:"jwt_issuer_format" yaml:"jwt_issuer_format,omitempty"`
// BearerTokenFormat indicates how authorization bearer tokens are interepreted. Possible values:
// - "default": Only Bearer tokens prefixed with Pomerium- will be interpreted by Pomerium
// - "idp_access_token": The Bearer token will be interpreted as an IdP access token.
// - "idp_identity_token": The Bearer token will be interpreted as an IdP identity token.
// When unset the global option will be used.
BearerTokenFormat *BearerTokenFormat `mapstructure:"bearer_token_format" yaml:"bearer_token_format,omitempty"`
// Allowlist of group names/IDs to include in the Pomerium JWT.
// This expands on any global allowlist set in the main Options.
@ -186,6 +193,8 @@ type Policy struct {
IDPClientID string `mapstructure:"idp_client_id" yaml:"idp_client_id,omitempty"`
// IDPClientSecret is the client secret used for the identity provider.
IDPClientSecret string `mapstructure:"idp_client_secret" yaml:"idp_client_secret,omitempty"`
// IDPAccessTokenAllowedAudiences are the allowed audiences for idp access token validation.
IDPAccessTokenAllowedAudiences *[]string `mapstructure:"idp_access_token_allowed_audiences" yaml:"idp_access_token_allowed_audiences,omitempty"`
// ShowErrorDetails indicates whether or not additional error details should be displayed.
ShowErrorDetails bool `mapstructure:"show_error_details" yaml:"show_error_details" json:"show_error_details"`
@ -332,6 +341,12 @@ func NewPolicyFromProto(pb *configpb.Route) (*Policy, error) {
TLSUpstreamServerName: pb.GetTlsUpstreamServerName(),
UpstreamTimeout: timeout,
}
if pb.IdpAccessTokenAllowedAudiences != nil {
values := slices.Clone(pb.IdpAccessTokenAllowedAudiences.Values)
p.IDPAccessTokenAllowedAudiences = &values
} else {
p.IDPAccessTokenAllowedAudiences = nil
}
if pb.Redirect.IsSet() {
p.Redirect = &PolicyRedirect{
HTTPSRedirect: pb.Redirect.HttpsRedirect,
@ -380,6 +395,8 @@ func NewPolicyFromProto(pb *configpb.Route) (*Policy, error) {
p.JWTIssuerFormat = "uri"
}
p.BearerTokenFormat = BearerTokenFormatFromPB(pb.BearerTokenFormat)
for _, rwh := range pb.RewriteResponseHeaders {
p.RewriteResponseHeaders = append(p.RewriteResponseHeaders, RewriteHeader{
Header: rwh.GetHeader(),
@ -505,6 +522,13 @@ func (p *Policy) ToProto() (*configpb.Route, error) {
if p.IDPClientSecret != "" {
pb.IdpClientSecret = proto.String(p.IDPClientSecret)
}
if p.IDPAccessTokenAllowedAudiences != nil {
pb.IdpAccessTokenAllowedAudiences = &configpb.Route_StringList{
Values: slices.Clone(*p.IDPAccessTokenAllowedAudiences),
}
} else {
pb.IdpAccessTokenAllowedAudiences = nil
}
if p.Redirect != nil {
pb.Redirect = &configpb.RouteRedirect{
HttpsRedirect: p.Redirect.HTTPSRedirect,
@ -538,6 +562,8 @@ func (p *Policy) ToProto() (*configpb.Route, error) {
pb.JwtIssuerFormat = configpb.IssuerFormat_IssuerURI
}
pb.BearerTokenFormat = p.BearerTokenFormat.ToPB()
for _, rwh := range p.RewriteResponseHeaders {
pb.RewriteResponseHeaders = append(pb.RewriteResponseHeaders, &configpb.RouteRewriteHeader{
Header: rwh.Header,

View file

@ -1,16 +1,33 @@
package config
import (
"context"
"fmt"
"net/http"
"strings"
"time"
"github.com/google/uuid"
"google.golang.org/protobuf/types/known/timestamppb"
"github.com/pomerium/pomerium/internal/encoding"
"github.com/pomerium/pomerium/internal/encoding/jws"
"github.com/pomerium/pomerium/internal/httputil"
"github.com/pomerium/pomerium/internal/jwtutil"
"github.com/pomerium/pomerium/internal/sessions"
"github.com/pomerium/pomerium/internal/sessions/cookie"
"github.com/pomerium/pomerium/internal/sessions/header"
"github.com/pomerium/pomerium/internal/sessions/queryparam"
"github.com/pomerium/pomerium/internal/urlutil"
"github.com/pomerium/pomerium/pkg/authenticateapi"
"github.com/pomerium/pomerium/pkg/grpc/databroker"
identitypb "github.com/pomerium/pomerium/pkg/grpc/identity"
"github.com/pomerium/pomerium/pkg/grpc/session"
"github.com/pomerium/pomerium/pkg/grpc/user"
"github.com/pomerium/pomerium/pkg/grpcutil"
"github.com/pomerium/pomerium/pkg/identity"
"github.com/pomerium/pomerium/pkg/protoutil"
"github.com/pomerium/pomerium/pkg/storage"
)
// A SessionStore saves and loads sessions based on the options.
@ -112,3 +129,310 @@ func (store *SessionStore) LoadSessionStateAndCheckIDP(r *http.Request) (*sessio
func (store *SessionStore) SaveSession(w http.ResponseWriter, r *http.Request, v any) error {
return store.store.SaveSession(w, r, v)
}
type IncomingIDPTokenSessionCreator interface {
CreateSession(ctx context.Context, cfg *Config, policy *Policy, r *http.Request) (*session.Session, error)
}
type incomingIDPTokenSessionCreator struct {
timeNow func() time.Time
getRecord func(ctx context.Context, recordType, recordID string) (*databroker.Record, error)
putRecords func(ctx context.Context, records []*databroker.Record) error
}
func NewIncomingIDPTokenSessionCreator(
getRecord func(ctx context.Context, recordType, recordID string) (*databroker.Record, error),
putRecords func(ctx context.Context, records []*databroker.Record) error,
) IncomingIDPTokenSessionCreator {
return &incomingIDPTokenSessionCreator{timeNow: time.Now, getRecord: getRecord, putRecords: putRecords}
}
// CreateSession attempts to create a session for incoming idp access and
// identity tokens. If no access or identity token is passed ErrNoSessionFound will be returned.
// If the tokens are not valid an error will be returned.
func (c *incomingIDPTokenSessionCreator) CreateSession(
ctx context.Context,
cfg *Config,
policy *Policy,
r *http.Request,
) (session *session.Session, err error) {
if rawAccessToken, ok := cfg.GetIncomingIDPAccessTokenForPolicy(policy, r); ok {
return c.createSessionAccessToken(ctx, cfg, policy, rawAccessToken)
}
if rawIdentityToken, ok := cfg.GetIncomingIDPIdentityTokenForPolicy(policy, r); ok {
return c.createSessionForIdentityToken(ctx, cfg, policy, rawIdentityToken)
}
return nil, sessions.ErrNoSessionFound
}
func (c *incomingIDPTokenSessionCreator) createSessionAccessToken(
ctx context.Context,
cfg *Config,
policy *Policy,
rawAccessToken string,
) (*session.Session, error) {
idp, err := cfg.Options.GetIdentityProviderForPolicy(policy)
if err != nil {
return nil, fmt.Errorf("error getting identity provider to verify access token: %w", err)
}
sessionID := getAccessTokenSessionID(idp, rawAccessToken)
s, err := c.getSession(ctx, sessionID)
if err == nil {
return s, nil
} else if !storage.IsNotFound(err) {
return nil, err
}
authenticateURL, transport, err := cfg.resolveAuthenticateURL()
if err != nil {
return nil, fmt.Errorf("error resolving authenticate url to verify access token: %w", err)
}
res, err := authenticateapi.New(authenticateURL, transport).VerifyAccessToken(ctx, &authenticateapi.VerifyAccessTokenRequest{
AccessToken: rawAccessToken,
IdentityProviderID: idp.GetId(),
})
if err != nil {
return nil, fmt.Errorf("error verifying access token: %w", err)
} else if !res.Valid {
return nil, fmt.Errorf("invalid access token")
}
s = c.newSessionFromIDPClaims(cfg, sessionID, res.Claims)
s.OauthToken = &session.OAuthToken{
TokenType: "Bearer",
AccessToken: rawAccessToken,
ExpiresAt: s.ExpiresAt,
}
u := c.newUserFromIDPClaims(res.Claims)
err = c.putSessionAndUser(ctx, s, u)
if err != nil {
return nil, fmt.Errorf("error saving session and user: %w", err)
}
return s, nil
}
func (c *incomingIDPTokenSessionCreator) createSessionForIdentityToken(
ctx context.Context,
cfg *Config,
policy *Policy,
rawIdentityToken string,
) (*session.Session, error) {
idp, err := cfg.Options.GetIdentityProviderForPolicy(policy)
if err != nil {
return nil, fmt.Errorf("error getting identity provider to verify identity token: %w", err)
}
sessionID := getIdentityTokenSessionID(idp, rawIdentityToken)
s, err := c.getSession(ctx, sessionID)
if err == nil {
return s, nil
} else if !storage.IsNotFound(err) {
return nil, err
}
authenticateURL, transport, err := cfg.resolveAuthenticateURL()
if err != nil {
return nil, fmt.Errorf("error resolving authenticate url to verify identity token: %w", err)
}
res, err := authenticateapi.New(authenticateURL, transport).VerifyIdentityToken(ctx, &authenticateapi.VerifyIdentityTokenRequest{
IdentityToken: rawIdentityToken,
IdentityProviderID: idp.GetId(),
})
if err != nil {
return nil, fmt.Errorf("error verifying identity token: %w", err)
} else if !res.Valid {
return nil, fmt.Errorf("invalid identity token")
}
s = c.newSessionFromIDPClaims(cfg, sessionID, res.Claims)
s.SetRawIDToken(rawIdentityToken)
u := c.newUserFromIDPClaims(res.Claims)
err = c.putSessionAndUser(ctx, s, u)
if err != nil {
return nil, fmt.Errorf("error saving session and user: %w", err)
}
return s, nil
}
func (c *incomingIDPTokenSessionCreator) newSessionFromIDPClaims(
cfg *Config,
sessionID string,
claims jwtutil.Claims,
) *session.Session {
now := c.timeNow()
s := new(session.Session)
s.Id = sessionID
if userID, ok := claims.GetUserID(); ok {
s.UserId = userID
}
if issuedAt, ok := claims.GetIssuedAt(); ok {
s.IssuedAt = timestamppb.New(issuedAt)
} else {
s.IssuedAt = timestamppb.New(now)
}
if expiresAt, ok := claims.GetExpirationTime(); ok {
s.ExpiresAt = timestamppb.New(expiresAt)
} else {
s.ExpiresAt = timestamppb.New(now.Add(cfg.Options.CookieExpire))
}
s.AccessedAt = timestamppb.New(now)
s.AddClaims(identity.Claims(claims).Flatten())
if aud, ok := claims.GetAudience(); ok {
s.Audience = aud
}
return s
}
func (c *incomingIDPTokenSessionCreator) newUserFromIDPClaims(
claims jwtutil.Claims,
) *user.User {
u := new(user.User)
if userID, ok := claims.GetUserID(); ok {
u.Id = userID
}
if name, ok := claims.GetString("name"); ok {
u.Name = name
}
if email, ok := claims.GetString("email"); ok {
u.Email = email
}
u.Claims = identity.Claims(claims).Flatten().ToPB()
return u
}
func (c *incomingIDPTokenSessionCreator) getSession(ctx context.Context, sessionID string) (*session.Session, error) {
record, err := c.getRecord(ctx, grpcutil.GetTypeURL(new(session.Session)), sessionID)
if err != nil {
return nil, err
}
msg, err := record.GetData().UnmarshalNew()
if err != nil {
return nil, storage.ErrNotFound
}
s, ok := msg.(*session.Session)
if !ok {
return nil, storage.ErrNotFound
}
return s, nil
}
func (c *incomingIDPTokenSessionCreator) putSessionAndUser(ctx context.Context, s *session.Session, u *user.User) error {
var records []*databroker.Record
if id := s.GetId(); id != "" {
records = append(records, &databroker.Record{
Type: grpcutil.GetTypeURL(s),
Id: id,
Data: protoutil.NewAny(s),
})
}
if id := u.GetId(); id != "" {
records = append(records, &databroker.Record{
Type: grpcutil.GetTypeURL(u),
Id: id,
Data: protoutil.NewAny(u),
})
}
return c.putRecords(ctx, records)
}
// GetIncomingIDPAccessTokenForPolicy returns the raw idp access token from a request if there is one.
func (cfg *Config) GetIncomingIDPAccessTokenForPolicy(policy *Policy, r *http.Request) (rawAccessToken string, ok bool) {
bearerTokenFormat := BearerTokenFormatUnknown
if cfg.Options != nil && cfg.Options.BearerTokenFormat != nil {
bearerTokenFormat = *cfg.Options.BearerTokenFormat
}
if policy != nil && policy.BearerTokenFormat != nil {
bearerTokenFormat = *policy.BearerTokenFormat
}
if token := r.Header.Get(httputil.HeaderPomeriumIDPAccessToken); token != "" {
return token, true
}
if auth := r.Header.Get(httputil.HeaderAuthorization); auth != "" {
prefix := httputil.AuthorizationTypePomeriumIDPAccessToken + " "
if strings.HasPrefix(strings.ToLower(auth), strings.ToLower(prefix)) {
return auth[len(prefix):], true
}
prefix = "Bearer " + httputil.AuthorizationTypePomeriumIDPAccessToken + "-"
if strings.HasPrefix(strings.ToLower(auth), strings.ToLower(prefix)) {
return auth[len(prefix):], true
}
prefix = "Bearer "
if strings.HasPrefix(strings.ToLower(auth), strings.ToLower(prefix)) &&
bearerTokenFormat == BearerTokenFormatIDPAccessToken {
return auth[len(prefix):], true
}
}
return "", false
}
// GetIncomingIDPAccessTokenForPolicy returns the raw idp identity token from a request if there is one.
func (cfg *Config) GetIncomingIDPIdentityTokenForPolicy(policy *Policy, r *http.Request) (rawIdentityToken string, ok bool) {
bearerTokenFormat := BearerTokenFormatDefault
if cfg.Options != nil && cfg.Options.BearerTokenFormat != nil {
bearerTokenFormat = *cfg.Options.BearerTokenFormat
}
if policy != nil && policy.BearerTokenFormat != nil {
bearerTokenFormat = *policy.BearerTokenFormat
}
if token := r.Header.Get(httputil.HeaderPomeriumIDPIdentityToken); token != "" {
return token, true
}
if auth := r.Header.Get(httputil.HeaderAuthorization); auth != "" {
prefix := httputil.AuthorizationTypePomeriumIDPIdentityToken + " "
if strings.HasPrefix(strings.ToLower(auth), strings.ToLower(prefix)) {
return auth[len(prefix):], true
}
prefix = "Bearer " + httputil.AuthorizationTypePomeriumIDPIdentityToken + "-"
if strings.HasPrefix(strings.ToLower(auth), strings.ToLower(prefix)) {
return auth[len(prefix):], true
}
prefix = "Bearer "
if strings.HasPrefix(strings.ToLower(auth), strings.ToLower(prefix)) &&
bearerTokenFormat == BearerTokenFormatIDPIdentityToken {
return auth[len(prefix):], true
}
}
return "", false
}
var accessTokenUUIDNamespace = uuid.MustParse("0194f6f8-e760-76a0-8917-e28ac927a34d")
func getAccessTokenSessionID(idp *identitypb.Provider, rawAccessToken string) string {
namespace := accessTokenUUIDNamespace
// make the session ID per-idp settings
if idp != nil {
namespace = uuid.NewSHA1(namespace, []byte(idp.GetId()))
}
return uuid.NewSHA1(namespace, []byte(rawAccessToken)).String()
}
var identityTokenUUIDNamespace = uuid.MustParse("0194f6f9-aec0-704e-bb4a-51054f17ad17")
func getIdentityTokenSessionID(idp *identitypb.Provider, rawIdentityToken string) string {
namespace := identityTokenUUIDNamespace
// make the session ID per-idp settings
if idp != nil {
namespace = uuid.NewSHA1(namespace, []byte(idp.GetId()))
}
return uuid.NewSHA1(namespace, []byte(rawIdentityToken)).String()
}

View file

@ -1,20 +1,34 @@
package config
import (
"context"
"encoding/base64"
"encoding/json"
"net/http"
"net/http/httptest"
"net/url"
"testing"
"time"
"github.com/google/go-cmp/cmp"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/protobuf/types/known/timestamppb"
"github.com/pomerium/pomerium/internal/encoding/jws"
"github.com/pomerium/pomerium/internal/httputil"
"github.com/pomerium/pomerium/internal/jwtutil"
"github.com/pomerium/pomerium/internal/sessions"
"github.com/pomerium/pomerium/internal/testutil"
"github.com/pomerium/pomerium/internal/urlutil"
"github.com/pomerium/pomerium/pkg/authenticateapi"
"github.com/pomerium/pomerium/pkg/cryptutil"
"github.com/pomerium/pomerium/pkg/grpc/databroker"
identitypb "github.com/pomerium/pomerium/pkg/grpc/identity"
"github.com/pomerium/pomerium/pkg/grpc/session"
"github.com/pomerium/pomerium/pkg/grpc/user"
"github.com/pomerium/pomerium/pkg/identity"
"github.com/pomerium/pomerium/pkg/storage"
)
func TestSessionStore_LoadSessionState(t *testing.T) {
@ -164,3 +178,357 @@ func TestGetIdentityProviderDetectsChangesToAuthenticateServiceURL(t *testing.T)
assert.NotEqual(t, idp1.GetId(), idp2.GetId(),
"identity provider should change when authenticate service url changes")
}
func Test_getTokenSessionID(t *testing.T) {
t.Parallel()
assert.Equal(t, "532b0a3d-b413-50a0-8c9f-e6eb340a05d3", getAccessTokenSessionID(nil, "TOKEN"))
assert.Equal(t, "e0b8096c-54dd-5623-8098-5488f9c302db", getIdentityTokenSessionID(nil, "TOKEN"))
assert.Equal(t, "9c99d1d0-805e-51cb-b808-772ab654268b", getAccessTokenSessionID(&identitypb.Provider{Id: "IDP1"}, "TOKEN"))
assert.Equal(t, "0fe0e289-40bb-5ffe-b328-e290e043a652", getIdentityTokenSessionID(&identitypb.Provider{Id: "IDP1"}, "TOKEN"))
}
func TestGetIncomingIDPAccessTokenForPolicy(t *testing.T) {
t.Parallel()
bearerTokenFormatIDPAccessToken := BearerTokenFormatIDPAccessToken
for _, tc := range []struct {
name string
globalBearerTokenFormat *BearerTokenFormat
routeBearerTokenFormat *BearerTokenFormat
headers http.Header
expectedOK bool
expectedToken string
}{
{
name: "empty headers",
expectedOK: false,
},
{
name: "custom header",
headers: http.Header{"X-Pomerium-Idp-Access-Token": {"access token via custom header"}},
expectedOK: true,
expectedToken: "access token via custom header",
},
{
name: "custom authorization",
headers: http.Header{"Authorization": {"Pomerium-Idp-Access-Token access token via custom authorization"}},
expectedOK: true,
expectedToken: "access token via custom authorization",
},
{
name: "custom bearer",
headers: http.Header{"Authorization": {"Bearer Pomerium-Idp-Access-Token-access token via custom bearer"}},
expectedOK: true,
expectedToken: "access token via custom bearer",
},
{
name: "bearer disabled",
headers: http.Header{"Authorization": {"Bearer access token via bearer"}},
expectedOK: false,
},
{
name: "bearer enabled via options",
globalBearerTokenFormat: &bearerTokenFormatIDPAccessToken,
headers: http.Header{"Authorization": {"Bearer access token via bearer"}},
expectedOK: true,
expectedToken: "access token via bearer",
},
{
name: "bearer enabled via route",
routeBearerTokenFormat: &bearerTokenFormatIDPAccessToken,
headers: http.Header{"Authorization": {"Bearer access token via bearer"}},
expectedOK: true,
expectedToken: "access token via bearer",
},
} {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
cfg := &Config{
Options: NewDefaultOptions(),
}
cfg.Options.BearerTokenFormat = tc.globalBearerTokenFormat
var route *Policy
if tc.routeBearerTokenFormat != nil {
route = &Policy{
BearerTokenFormat: tc.routeBearerTokenFormat,
}
}
r, err := http.NewRequest(http.MethodGet, "https://example.com", nil)
require.NoError(t, err)
if tc.headers != nil {
r.Header = tc.headers
}
actualToken, actualOK := cfg.GetIncomingIDPAccessTokenForPolicy(route, r)
assert.Equal(t, tc.expectedOK, actualOK)
assert.Equal(t, tc.expectedToken, actualToken)
})
}
}
func TestGetIncomingIDPIdentityTokenForPolicy(t *testing.T) {
t.Parallel()
bearerTokenFormatIDPIdentityToken := BearerTokenFormatIDPIdentityToken
for _, tc := range []struct {
name string
globalBearerTokenFormat *BearerTokenFormat
routeBearerTokenFormat *BearerTokenFormat
headers http.Header
expectedOK bool
expectedToken string
}{
{
name: "empty headers",
expectedOK: false,
},
{
name: "custom header",
headers: http.Header{"X-Pomerium-Idp-Identity-Token": {"identity token via custom header"}},
expectedOK: true,
expectedToken: "identity token via custom header",
},
{
name: "custom authorization",
headers: http.Header{"Authorization": {"Pomerium-Idp-Identity-Token identity token via custom authorization"}},
expectedOK: true,
expectedToken: "identity token via custom authorization",
},
{
name: "custom bearer",
headers: http.Header{"Authorization": {"Bearer Pomerium-Idp-Identity-Token-identity token via custom bearer"}},
expectedOK: true,
expectedToken: "identity token via custom bearer",
},
{
name: "bearer disabled",
headers: http.Header{"Authorization": {"Bearer identity token via bearer"}},
expectedOK: false,
},
{
name: "bearer enabled via options",
globalBearerTokenFormat: &bearerTokenFormatIDPIdentityToken,
headers: http.Header{"Authorization": {"Bearer identity token via bearer"}},
expectedOK: true,
expectedToken: "identity token via bearer",
},
{
name: "bearer enabled via route",
routeBearerTokenFormat: &bearerTokenFormatIDPIdentityToken,
headers: http.Header{"Authorization": {"Bearer identity token via bearer"}},
expectedOK: true,
expectedToken: "identity token via bearer",
},
} {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
cfg := &Config{
Options: NewDefaultOptions(),
}
cfg.Options.BearerTokenFormat = tc.globalBearerTokenFormat
var route *Policy
if tc.routeBearerTokenFormat != nil {
route = &Policy{
BearerTokenFormat: tc.routeBearerTokenFormat,
}
}
r, err := http.NewRequest(http.MethodGet, "https://example.com", nil)
require.NoError(t, err)
if tc.headers != nil {
r.Header = tc.headers
}
actualToken, actualOK := cfg.GetIncomingIDPIdentityTokenForPolicy(route, r)
assert.Equal(t, tc.expectedOK, actualOK)
assert.Equal(t, tc.expectedToken, actualToken)
})
}
}
func Test_newSessionFromIDPClaims(t *testing.T) {
t.Parallel()
tm1 := time.Date(2025, 2, 18, 8, 6, 0, 0, time.UTC)
tm2 := time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC)
tm3 := tm2.Add(time.Hour)
for _, tc := range []struct {
name string
sessionID string
claims jwtutil.Claims
expect *session.Session
}{
{
"empty claims", "S1",
nil,
&session.Session{
Id: "S1",
AccessedAt: timestamppb.New(tm1),
ExpiresAt: timestamppb.New(tm1.Add(time.Hour * 14)),
IssuedAt: timestamppb.New(tm1),
},
},
{
"full claims", "S2",
jwtutil.Claims{
"aud": "https://www.example.com",
"sub": "U1",
"iat": tm2.Unix(),
"exp": tm3.Unix(),
},
&session.Session{
Id: "S2",
UserId: "U1",
AccessedAt: timestamppb.New(tm1),
ExpiresAt: timestamppb.New(tm3),
IssuedAt: timestamppb.New(tm2),
Audience: []string{"https://www.example.com"},
Claims: identity.FlattenedClaims{
"aud": {"https://www.example.com"},
"sub": {"U1"},
"iat": {tm2.Unix()},
"exp": {tm3.Unix()},
}.ToPB(),
},
},
} {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
cfg := &Config{Options: NewDefaultOptions()}
c := &incomingIDPTokenSessionCreator{
timeNow: func() time.Time { return tm1 },
}
actual := c.newSessionFromIDPClaims(cfg, tc.sessionID, tc.claims)
testutil.AssertProtoEqual(t, tc.expect, actual)
})
}
}
func Test_newUserFromIDPClaims(t *testing.T) {
t.Parallel()
for _, tc := range []struct {
name string
claims jwtutil.Claims
expect *user.User
}{
{"empty claims", nil, &user.User{}},
{"full claims", jwtutil.Claims{
"sub": "USER_ID",
"name": "NAME",
"email": "EMAIL",
}, &user.User{
Id: "USER_ID",
Name: "NAME",
Email: "EMAIL",
Claims: identity.FlattenedClaims{
"sub": {"USER_ID"},
"name": {"NAME"},
"email": {"EMAIL"},
}.ToPB(),
}},
} {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
actual := new(incomingIDPTokenSessionCreator).newUserFromIDPClaims(tc.claims)
testutil.AssertProtoEqual(t, tc.expect, actual)
})
}
}
func TestIncomingIDPTokenSessionCreator_CreateSession(t *testing.T) {
t.Parallel()
t.Run("access_token", func(t *testing.T) {
t.Parallel()
mux := http.NewServeMux()
mux.HandleFunc("/.pomerium/verify-access-token", func(w http.ResponseWriter, _ *http.Request) {
json.NewEncoder(w).Encode(&authenticateapi.VerifyTokenResponse{
Valid: true,
Claims: jwtutil.Claims{"sub": "U1"},
})
})
srv := httptest.NewTLSServer(mux)
ctx := testutil.GetContext(t, time.Minute)
cfg := &Config{Options: NewDefaultOptions()}
cfg.Options.AuthenticateURLString = srv.URL
cfg.Options.ClientSecret = "CLIENT_SECRET_1"
cfg.Options.ClientID = "CLIENT_ID_1"
route := &Policy{}
route.IDPClientSecret = "CLIENT_SECRET_2"
route.IDPClientID = "CLIENT_ID_2"
req, err := http.NewRequestWithContext(ctx, http.MethodGet, "https://www.example.com", nil)
require.NoError(t, err)
req.Header.Set(httputil.HeaderPomeriumIDPAccessToken, "ACCESS_TOKEN")
c := NewIncomingIDPTokenSessionCreator(
func(_ context.Context, recordType, _ string) (*databroker.Record, error) {
assert.Equal(t, "type.googleapis.com/session.Session", recordType)
return nil, storage.ErrNotFound
},
func(_ context.Context, records []*databroker.Record) error {
if assert.Len(t, records, 2, "should put session and user") {
assert.Equal(t, "type.googleapis.com/session.Session", records[0].Type)
assert.Equal(t, "type.googleapis.com/user.User", records[1].Type)
}
return nil
},
)
s, err := c.CreateSession(ctx, cfg, route, req)
assert.NoError(t, err)
assert.Equal(t, "U1", s.GetUserId())
assert.Equal(t, "ACCESS_TOKEN", s.GetOauthToken().GetAccessToken())
})
t.Run("identity_token", func(t *testing.T) {
t.Parallel()
mux := http.NewServeMux()
mux.HandleFunc("/.pomerium/verify-identity-token", func(w http.ResponseWriter, _ *http.Request) {
json.NewEncoder(w).Encode(&authenticateapi.VerifyTokenResponse{
Valid: true,
Claims: jwtutil.Claims{"sub": "U1"},
})
})
srv := httptest.NewTLSServer(mux)
ctx := testutil.GetContext(t, time.Minute)
cfg := &Config{Options: NewDefaultOptions()}
cfg.Options.AuthenticateURLString = srv.URL
cfg.Options.ClientSecret = "CLIENT_SECRET_1"
cfg.Options.ClientID = "CLIENT_ID_1"
route := &Policy{}
route.IDPClientSecret = "CLIENT_SECRET_2"
route.IDPClientID = "CLIENT_ID_2"
req, err := http.NewRequestWithContext(ctx, http.MethodGet, "https://www.example.com", nil)
require.NoError(t, err)
req.Header.Set(httputil.HeaderPomeriumIDPIdentityToken, "IDENTITY_TOKEN")
c := NewIncomingIDPTokenSessionCreator(
func(_ context.Context, recordType, _ string) (*databroker.Record, error) {
assert.Equal(t, "type.googleapis.com/session.Session", recordType)
return nil, storage.ErrNotFound
},
func(_ context.Context, records []*databroker.Record) error {
if assert.Len(t, records, 2, "should put session and user") {
assert.Equal(t, "type.googleapis.com/session.Session", records[0].Type)
assert.Equal(t, "type.googleapis.com/user.User", records[1].Type)
}
return nil
},
)
s, err := c.CreateSession(ctx, cfg, route, req)
assert.NoError(t, err)
assert.Equal(t, "U1", s.GetUserId())
})
}

View file

@ -1,7 +1,12 @@
package httputil
// Pomerium authorization types
const (
// AuthorizationTypePomerium is for Authorization: Pomerium JWT... headers
const AuthorizationTypePomerium = "Pomerium"
AuthorizationTypePomerium = "Pomerium"
AuthorizationTypePomeriumIDPAccessToken = "Pomerium-IDP-Access-Token" //nolint: gosec
AuthorizationTypePomeriumIDPIdentityToken = "Pomerium-IDP-Identity-Token" //nolint: gosec
)
// Standard headers
const (
@ -17,6 +22,8 @@ const (
// can be used in place of the standard authorization header if that header is being
// used by upstream applications.
HeaderPomeriumAuthorization = "x-pomerium-authorization"
HeaderPomeriumIDPAccessToken = "x-pomerium-idp-access-token" //nolint: gosec
HeaderPomeriumIDPIdentityToken = "x-pomerium-idp-identity-token" //nolint: gosec
// HeaderPomeriumResponse is set when pomerium itself creates a response,
// as opposed to the upstream application and can be used to distinguish
// between an application error, and a pomerium related error when debugging.

172
internal/jwtutil/jwtutil.go Normal file
View file

@ -0,0 +1,172 @@
// Package jwtutil contains functions for working with JWTs.
package jwtutil
import (
"bytes"
"encoding/json"
"fmt"
"reflect"
"time"
)
// Claims represent claims in a JWT.
type Claims map[string]any
// UnmarshalJSON implements a custom unmarshaller for claims data.
func (claims *Claims) UnmarshalJSON(raw []byte) error {
dst := map[string]any{}
dec := json.NewDecoder(bytes.NewReader(raw))
dec.UseNumber()
err := dec.Decode(&dst)
if err != nil {
return err
}
*claims = Claims(dst)
return nil
}
// registered claims
// GetIssuer gets the iss claim.
func (claims Claims) GetIssuer() (issuer string, ok bool) {
return claims.GetString("iss")
}
// GetSubject gets the sub claim.
func (claims Claims) GetSubject() (subject string, ok bool) {
return claims.GetString("sub")
}
// GetAudience gets the aud claim.
func (claims Claims) GetAudience() (audiences []string, ok bool) {
return claims.GetStringSlice("aud")
}
// GetExpirationTime gets the exp claim.
func (claims Claims) GetExpirationTime() (expirationTime time.Time, ok bool) {
return claims.GetNumericDate("exp")
}
// GetNotBefore gets the nbf claim.
func (claims Claims) GetNotBefore() (notBefore time.Time, ok bool) {
return claims.GetNumericDate("nbf")
}
// GetIssuedAt gets the iat claim.
func (claims Claims) GetIssuedAt() (issuedAt time.Time, ok bool) {
return claims.GetNumericDate("iat")
}
// GetJWTID gets the jti claim.
func (claims Claims) GetJWTID() (jwtID string, ok bool) {
return claims.GetString("jti")
}
// custom claims
// GetUserID returns the oid or sub claim.
func (claims Claims) GetUserID() (userID string, ok bool) {
if oid, ok := claims.GetString("oid"); ok {
return oid, true
}
if sub, ok := claims.GetSubject(); ok {
return sub, true
}
return "", false
}
// GetNumericDate returns the claim as a numeric date.
func (claims Claims) GetNumericDate(name string) (tm time.Time, ok bool) {
if claims == nil {
return tm, false
}
raw, ok := claims[name]
if !ok {
return tm, false
}
switch v := raw.(type) {
case float32:
return time.Unix(int64(v), 0), true
case float64:
return time.Unix(int64(v), 0), true
case int64:
return time.Unix(v, 0), true
case int32:
return time.Unix(int64(v), 0), true
case int16:
return time.Unix(int64(v), 0), true
case int8:
return time.Unix(int64(v), 0), true
case int:
return time.Unix(int64(v), 0), true
case uint64:
return time.Unix(int64(v), 0), true
case uint32:
return time.Unix(int64(v), 0), true
case uint16:
return time.Unix(int64(v), 0), true
case uint8:
return time.Unix(int64(v), 0), true
case uint:
return time.Unix(int64(v), 0), true
case json.Number:
i, err := v.Int64()
if err != nil {
if f, err := v.Float64(); err == nil {
i = int64(f)
}
}
if err != nil {
return tm, false
}
return time.Unix(i, 0), true
}
return tm, false
}
// GetString returns the claim as a string.
func (claims Claims) GetString(name string) (value string, ok bool) {
raw, ok := claims[name]
if !ok {
return value, false
}
return toString(raw), true
}
// GetStringSlice returns the claim as a slice of strings.
func (claims Claims) GetStringSlice(name string) (values []string, ok bool) {
raw, ok := claims[name]
if !ok {
return nil, false
}
return toStringSlice(raw), true
}
func toString(data any) string {
switch v := data.(type) {
case string:
return v
}
return fmt.Sprint(data)
}
func toStringSlice(obj any) []string {
v := reflect.ValueOf(obj)
switch v.Kind() {
case reflect.Slice:
vs := make([]string, v.Len())
for i := 0; i < v.Len(); i++ {
vs[i] = toString(v.Index(i).Interface())
}
return vs
}
return []string{toString(obj)}
}

View file

@ -0,0 +1,109 @@
// Package authenticateapi has the types and methods for the authenticate api.
package authenticateapi
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"github.com/pomerium/pomerium/internal/jwtutil"
)
// VerifyAccessTokenRequest is used to verify access tokens.
type VerifyAccessTokenRequest struct {
AccessToken string `json:"accessToken"`
IdentityProviderID string `json:"identityProviderId,omitempty"`
}
// VerifyIdentityTokenRequest is used to verify identity tokens.
type VerifyIdentityTokenRequest struct {
IdentityToken string `json:"identityToken"`
IdentityProviderID string `json:"identityProviderId,omitempty"`
}
// VerifyTokenResponse is the result of verifying an access or identity token.
type VerifyTokenResponse struct {
Valid bool `json:"valid"`
Claims jwtutil.Claims `json:"claims,omitempty"`
}
// An API is an api client for the authenticate service.
type API struct {
authenticateURL *url.URL
transport http.RoundTripper
}
// New creates a new API client.
func New(
authenticateURL *url.URL,
transport http.RoundTripper,
) *API {
return &API{
authenticateURL: authenticateURL,
transport: transport,
}
}
// VerifyAccessToken verifies an access token.
func (api *API) VerifyAccessToken(ctx context.Context, request *VerifyAccessTokenRequest) (*VerifyTokenResponse, error) {
var response VerifyTokenResponse
err := api.call(ctx, "verify-access-token", request, &response)
if err != nil {
return nil, err
}
return &response, nil
}
// VerifyIdentityToken verifies an identity token.
func (api *API) VerifyIdentityToken(ctx context.Context, request *VerifyIdentityTokenRequest) (*VerifyTokenResponse, error) {
var response VerifyTokenResponse
err := api.call(ctx, "verify-identity-token", request, &response)
if err != nil {
return nil, err
}
return &response, nil
}
func (api *API) call(
ctx context.Context,
endpoint string,
request, response any,
) error {
u := api.authenticateURL.ResolveReference(&url.URL{
Path: "/.pomerium/" + endpoint,
})
body, err := json.Marshal(request)
if err != nil {
return fmt.Errorf("error marshaling %s http request: %w", endpoint, err)
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, u.String(), bytes.NewReader(body))
if err != nil {
return fmt.Errorf("error creating %s http request: %w", endpoint, err)
}
res, err := (&http.Client{
Transport: api.transport,
}).Do(req)
if err != nil {
return fmt.Errorf("error executing %s http request: %w", endpoint, err)
}
defer res.Body.Close()
body, err = io.ReadAll(res.Body)
if err != nil {
return fmt.Errorf("error reading %s http response: %w", endpoint, err)
}
err = json.Unmarshal(body, &response)
if err != nil {
return fmt.Errorf("error reading %s http response (body=%s): %w", endpoint, body, err)
}
return nil
}

File diff suppressed because it is too large Load diff

View file

@ -45,8 +45,17 @@ enum IssuerFormat {
IssuerURI = 1;
}
// Next ID: 69.
enum BearerTokenFormat {
BEARER_TOKEN_FORMAT_UNKNOWN = 0;
BEARER_TOKEN_FORMAT_DEFAULT = 1;
BEARER_TOKEN_FORMAT_IDP_ACCESS_TOKEN = 2;
BEARER_TOKEN_FORMAT_IDP_IDENTITY_TOKEN = 3;
}
// Next ID: 71.
message Route {
message StringList { repeated string values = 1; }
string name = 1;
string description = 67;
string logo_url = 68;
@ -116,6 +125,7 @@ message Route {
bool enable_google_cloud_serverless_authentication = 42;
IssuerFormat jwt_issuer_format = 65;
repeated string jwt_groups_filter = 66;
optional BearerTokenFormat bearer_token_format = 70;
envoy.config.cluster.v3.Cluster envoy_opts = 36;
@ -130,6 +140,7 @@ message Route {
optional string idp_client_id = 55;
optional string idp_client_secret = 56;
optional StringList idp_access_token_allowed_audiences = 69;
bool show_error_details = 59;
}
@ -149,7 +160,7 @@ message Policy {
string remediation = 9;
}
// Next ID: 137.
// Next ID: 139.
message Settings {
message Certificate {
bytes cert_bytes = 3;
@ -188,6 +199,7 @@ message Settings {
optional string idp_client_secret = 23;
optional string idp_provider = 24;
optional string idp_provider_url = 25;
optional StringList idp_access_token_allowed_audiences = 137;
repeated string scopes = 26;
// optional string idp_service_account = 27;
// optional google.protobuf.Duration idp_refresh_directory_timeout = 28;
@ -203,6 +215,7 @@ message Settings {
// repeated string jwt_claims_headers = 37;
map<string, string> jwt_claims_headers = 63;
repeated string jwt_groups_filter = 119;
optional BearerTokenFormat bearer_token_format = 138;
optional google.protobuf.Duration default_upstream_timeout = 39;
optional string metrics_address = 40;
optional string metrics_basic_auth = 64;

View file

@ -35,6 +35,7 @@ type Provider struct {
// string service_account = 6;
Url string `protobuf:"bytes,7,opt,name=url,proto3" json:"url,omitempty"`
RequestParams map[string]string `protobuf:"bytes,8,rep,name=request_params,json=requestParams,proto3" json:"request_params,omitempty" protobuf_key:"bytes,1,opt,name=key,proto3" protobuf_val:"bytes,2,opt,name=value,proto3"`
AccessTokenAllowedAudiences *Provider_StringList `protobuf:"bytes,10,opt,name=access_token_allowed_audiences,json=accessTokenAllowedAudiences,proto3,oneof" json:"access_token_allowed_audiences,omitempty"`
}
func (x *Provider) Reset() {
@ -125,6 +126,13 @@ func (x *Provider) GetRequestParams() map[string]string {
return nil
}
func (x *Provider) GetAccessTokenAllowedAudiences() *Provider_StringList {
if x != nil {
return x.AccessTokenAllowedAudiences
}
return nil
}
type Profile struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
@ -196,6 +204,53 @@ func (x *Profile) GetClaims() *structpb.Struct {
return nil
}
type Provider_StringList struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
Values []string `protobuf:"bytes,1,rep,name=values,proto3" json:"values,omitempty"`
}
func (x *Provider_StringList) Reset() {
*x = Provider_StringList{}
if protoimpl.UnsafeEnabled {
mi := &file_identity_proto_msgTypes[2]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *Provider_StringList) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*Provider_StringList) ProtoMessage() {}
func (x *Provider_StringList) ProtoReflect() protoreflect.Message {
mi := &file_identity_proto_msgTypes[2]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use Provider_StringList.ProtoReflect.Descriptor instead.
func (*Provider_StringList) Descriptor() ([]byte, []int) {
return file_identity_proto_rawDescGZIP(), []int{0, 0}
}
func (x *Provider_StringList) GetValues() []string {
if x != nil {
return x.Values
}
return nil
}
var File_identity_proto protoreflect.FileDescriptor
var file_identity_proto_rawDesc = []byte{
@ -203,7 +258,7 @@ var file_identity_proto_rawDesc = []byte{
0x12, 0x11, 0x70, 0x6f, 0x6d, 0x65, 0x72, 0x69, 0x75, 0x6d, 0x2e, 0x69, 0x64, 0x65, 0x6e, 0x74,
0x69, 0x74, 0x79, 0x1a, 0x1c, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2f, 0x70, 0x72, 0x6f, 0x74,
0x6f, 0x62, 0x75, 0x66, 0x2f, 0x73, 0x74, 0x72, 0x75, 0x63, 0x74, 0x2e, 0x70, 0x72, 0x6f, 0x74,
0x6f, 0x22, 0xed, 0x02, 0x0a, 0x08, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x12, 0x0e,
0x6f, 0x22, 0xa8, 0x04, 0x0a, 0x08, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x12, 0x0e,
0x0a, 0x02, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x69, 0x64, 0x12, 0x38,
0x0a, 0x18, 0x61, 0x75, 0x74, 0x68, 0x65, 0x6e, 0x74, 0x69, 0x63, 0x61, 0x74, 0x65, 0x5f, 0x73,
0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x5f, 0x75, 0x72, 0x6c, 0x18, 0x09, 0x20, 0x01, 0x28, 0x09,
@ -221,25 +276,36 @@ var file_identity_proto_rawDesc = []byte{
0x32, 0x2e, 0x2e, 0x70, 0x6f, 0x6d, 0x65, 0x72, 0x69, 0x75, 0x6d, 0x2e, 0x69, 0x64, 0x65, 0x6e,
0x74, 0x69, 0x74, 0x79, 0x2e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x2e, 0x52, 0x65,
0x71, 0x75, 0x65, 0x73, 0x74, 0x50, 0x61, 0x72, 0x61, 0x6d, 0x73, 0x45, 0x6e, 0x74, 0x72, 0x79,
0x52, 0x0d, 0x72, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x50, 0x61, 0x72, 0x61, 0x6d, 0x73, 0x1a,
0x40, 0x0a, 0x12, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x50, 0x61, 0x72, 0x61, 0x6d, 0x73,
0x45, 0x6e, 0x74, 0x72, 0x79, 0x12, 0x10, 0x0a, 0x03, 0x6b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01,
0x28, 0x09, 0x52, 0x03, 0x6b, 0x65, 0x79, 0x12, 0x14, 0x0a, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65,
0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x3a, 0x02, 0x38,
0x01, 0x22, 0x97, 0x01, 0x0a, 0x07, 0x50, 0x72, 0x6f, 0x66, 0x69, 0x6c, 0x65, 0x12, 0x1f, 0x0a,
0x0b, 0x70, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x5f, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01,
0x28, 0x09, 0x52, 0x0a, 0x70, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x49, 0x64, 0x12, 0x19,
0x0a, 0x08, 0x69, 0x64, 0x5f, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0c,
0x52, 0x07, 0x69, 0x64, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x12, 0x1f, 0x0a, 0x0b, 0x6f, 0x61, 0x75,
0x74, 0x68, 0x5f, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x0a,
0x6f, 0x61, 0x75, 0x74, 0x68, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x12, 0x2f, 0x0a, 0x06, 0x63, 0x6c,
0x61, 0x69, 0x6d, 0x73, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x17, 0x2e, 0x67, 0x6f, 0x6f,
0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x53, 0x74, 0x72,
0x75, 0x63, 0x74, 0x52, 0x06, 0x63, 0x6c, 0x61, 0x69, 0x6d, 0x73, 0x42, 0x30, 0x5a, 0x2e, 0x67,
0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x70, 0x6f, 0x6d, 0x65, 0x72, 0x69,
0x75, 0x6d, 0x2f, 0x70, 0x6f, 0x6d, 0x65, 0x72, 0x69, 0x75, 0x6d, 0x2f, 0x70, 0x6b, 0x67, 0x2f,
0x67, 0x72, 0x70, 0x63, 0x2f, 0x69, 0x64, 0x65, 0x6e, 0x74, 0x69, 0x74, 0x79, 0x62, 0x06, 0x70,
0x72, 0x6f, 0x74, 0x6f, 0x33,
0x52, 0x0d, 0x72, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x50, 0x61, 0x72, 0x61, 0x6d, 0x73, 0x12,
0x70, 0x0a, 0x1e, 0x61, 0x63, 0x63, 0x65, 0x73, 0x73, 0x5f, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x5f,
0x61, 0x6c, 0x6c, 0x6f, 0x77, 0x65, 0x64, 0x5f, 0x61, 0x75, 0x64, 0x69, 0x65, 0x6e, 0x63, 0x65,
0x73, 0x18, 0x0a, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x26, 0x2e, 0x70, 0x6f, 0x6d, 0x65, 0x72, 0x69,
0x75, 0x6d, 0x2e, 0x69, 0x64, 0x65, 0x6e, 0x74, 0x69, 0x74, 0x79, 0x2e, 0x50, 0x72, 0x6f, 0x76,
0x69, 0x64, 0x65, 0x72, 0x2e, 0x53, 0x74, 0x72, 0x69, 0x6e, 0x67, 0x4c, 0x69, 0x73, 0x74, 0x48,
0x00, 0x52, 0x1b, 0x61, 0x63, 0x63, 0x65, 0x73, 0x73, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x41, 0x6c,
0x6c, 0x6f, 0x77, 0x65, 0x64, 0x41, 0x75, 0x64, 0x69, 0x65, 0x6e, 0x63, 0x65, 0x73, 0x88, 0x01,
0x01, 0x1a, 0x24, 0x0a, 0x0a, 0x53, 0x74, 0x72, 0x69, 0x6e, 0x67, 0x4c, 0x69, 0x73, 0x74, 0x12,
0x16, 0x0a, 0x06, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x09, 0x52,
0x06, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x73, 0x1a, 0x40, 0x0a, 0x12, 0x52, 0x65, 0x71, 0x75, 0x65,
0x73, 0x74, 0x50, 0x61, 0x72, 0x61, 0x6d, 0x73, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x12, 0x10, 0x0a,
0x03, 0x6b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x6b, 0x65, 0x79, 0x12,
0x14, 0x0a, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05,
0x76, 0x61, 0x6c, 0x75, 0x65, 0x3a, 0x02, 0x38, 0x01, 0x42, 0x21, 0x0a, 0x1f, 0x5f, 0x61, 0x63,
0x63, 0x65, 0x73, 0x73, 0x5f, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x5f, 0x61, 0x6c, 0x6c, 0x6f, 0x77,
0x65, 0x64, 0x5f, 0x61, 0x75, 0x64, 0x69, 0x65, 0x6e, 0x63, 0x65, 0x73, 0x22, 0x97, 0x01, 0x0a,
0x07, 0x50, 0x72, 0x6f, 0x66, 0x69, 0x6c, 0x65, 0x12, 0x1f, 0x0a, 0x0b, 0x70, 0x72, 0x6f, 0x76,
0x69, 0x64, 0x65, 0x72, 0x5f, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0a, 0x70,
0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x49, 0x64, 0x12, 0x19, 0x0a, 0x08, 0x69, 0x64, 0x5f,
0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x07, 0x69, 0x64, 0x54,
0x6f, 0x6b, 0x65, 0x6e, 0x12, 0x1f, 0x0a, 0x0b, 0x6f, 0x61, 0x75, 0x74, 0x68, 0x5f, 0x74, 0x6f,
0x6b, 0x65, 0x6e, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x0a, 0x6f, 0x61, 0x75, 0x74, 0x68,
0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x12, 0x2f, 0x0a, 0x06, 0x63, 0x6c, 0x61, 0x69, 0x6d, 0x73, 0x18,
0x04, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x17, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70,
0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x53, 0x74, 0x72, 0x75, 0x63, 0x74, 0x52, 0x06,
0x63, 0x6c, 0x61, 0x69, 0x6d, 0x73, 0x42, 0x30, 0x5a, 0x2e, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62,
0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x70, 0x6f, 0x6d, 0x65, 0x72, 0x69, 0x75, 0x6d, 0x2f, 0x70, 0x6f,
0x6d, 0x65, 0x72, 0x69, 0x75, 0x6d, 0x2f, 0x70, 0x6b, 0x67, 0x2f, 0x67, 0x72, 0x70, 0x63, 0x2f,
0x69, 0x64, 0x65, 0x6e, 0x74, 0x69, 0x74, 0x79, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
}
var (
@ -254,21 +320,23 @@ func file_identity_proto_rawDescGZIP() []byte {
return file_identity_proto_rawDescData
}
var file_identity_proto_msgTypes = make([]protoimpl.MessageInfo, 3)
var file_identity_proto_msgTypes = make([]protoimpl.MessageInfo, 4)
var file_identity_proto_goTypes = []any{
(*Provider)(nil), // 0: pomerium.identity.Provider
(*Profile)(nil), // 1: pomerium.identity.Profile
nil, // 2: pomerium.identity.Provider.RequestParamsEntry
(*structpb.Struct)(nil), // 3: google.protobuf.Struct
(*Provider_StringList)(nil), // 2: pomerium.identity.Provider.StringList
nil, // 3: pomerium.identity.Provider.RequestParamsEntry
(*structpb.Struct)(nil), // 4: google.protobuf.Struct
}
var file_identity_proto_depIdxs = []int32{
2, // 0: pomerium.identity.Provider.request_params:type_name -> pomerium.identity.Provider.RequestParamsEntry
3, // 1: pomerium.identity.Profile.claims:type_name -> google.protobuf.Struct
2, // [2:2] is the sub-list for method output_type
2, // [2:2] is the sub-list for method input_type
2, // [2:2] is the sub-list for extension type_name
2, // [2:2] is the sub-list for extension extendee
0, // [0:2] is the sub-list for field type_name
3, // 0: pomerium.identity.Provider.request_params:type_name -> pomerium.identity.Provider.RequestParamsEntry
2, // 1: pomerium.identity.Provider.access_token_allowed_audiences:type_name -> pomerium.identity.Provider.StringList
4, // 2: pomerium.identity.Profile.claims:type_name -> google.protobuf.Struct
3, // [3:3] is the sub-list for method output_type
3, // [3:3] is the sub-list for method input_type
3, // [3:3] is the sub-list for extension type_name
3, // [3:3] is the sub-list for extension extendee
0, // [0:3] is the sub-list for field type_name
}
func init() { file_identity_proto_init() }
@ -301,14 +369,27 @@ func file_identity_proto_init() {
return nil
}
}
file_identity_proto_msgTypes[2].Exporter = func(v any, i int) any {
switch v := v.(*Provider_StringList); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
}
file_identity_proto_msgTypes[0].OneofWrappers = []any{}
type x struct{}
out := protoimpl.TypeBuilder{
File: protoimpl.DescBuilder{
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
RawDescriptor: file_identity_proto_rawDesc,
NumEnums: 0,
NumMessages: 3,
NumMessages: 4,
NumExtensions: 0,
NumServices: 0,
},

View file

@ -6,6 +6,7 @@ option go_package = "github.com/pomerium/pomerium/pkg/grpc/identity";
import "google/protobuf/struct.proto";
message Provider {
message StringList { repeated string values = 1; }
string id = 1;
string authenticate_service_url = 9;
string client_id = 2;
@ -15,6 +16,7 @@ message Provider {
// string service_account = 6;
string url = 7;
map<string, string> request_params = 8;
optional StringList access_token_allowed_audiences = 10;
}
message Profile {

9
pkg/identity/errors.go Normal file
View file

@ -0,0 +1,9 @@
package identity
import "github.com/pomerium/pomerium/pkg/identity/identity"
// re-exported errors
var (
ErrVerifyAccessTokenNotSupported = identity.ErrVerifyAccessTokenNotSupported
ErrVerifyIdentityTokenNotSupported = identity.ErrVerifyIdentityTokenNotSupported
)

View file

@ -0,0 +1,9 @@
package identity
import "errors"
// well known errors
var (
ErrVerifyAccessTokenNotSupported = errors.New("identity: access token verification not supported")
ErrVerifyIdentityTokenNotSupported = errors.New("identity: identity token verification not supported")
)

View file

@ -2,6 +2,7 @@ package identity
import (
"context"
"fmt"
"net/http"
"golang.org/x/oauth2"
@ -55,3 +56,13 @@ func (mp MockProvider) SignOut(_ http.ResponseWriter, _ *http.Request, _, _, _ s
func (mp MockProvider) SignIn(_ http.ResponseWriter, _ *http.Request, _ string) error {
return mp.SignInError
}
// VerifyAccessToken verifies an access token.
func (mp MockProvider) VerifyAccessToken(_ context.Context, _ string) (claims map[string]any, err error) {
return nil, fmt.Errorf("VerifyAccessToken not implemented")
}
// VerifyIdentityToken verifies an identity token.
func (mp MockProvider) VerifyIdentityToken(_ context.Context, _ string) (claims map[string]any, err error) {
return nil, fmt.Errorf("VerifyIdentityToken not implemented")
}

View file

@ -182,3 +182,13 @@ func (p *Provider) SignIn(w http.ResponseWriter, r *http.Request, state string)
func (p *Provider) SignOut(_ http.ResponseWriter, _ *http.Request, _, _, _ string) error {
return oidc.ErrSignoutNotImplemented
}
// VerifyAccessToken verifies an access token.
func (p *Provider) VerifyAccessToken(_ context.Context, _ string) (claims map[string]any, err error) {
return nil, identity.ErrVerifyAccessTokenNotSupported
}
// VerifyIdentityToken verifies an identity token.
func (p *Provider) VerifyIdentityToken(_ context.Context, _ string) (claims map[string]any, err error) {
return nil, identity.ErrVerifyIdentityTokenNotSupported
}

View file

@ -256,3 +256,13 @@ func (p *Provider) SignIn(w http.ResponseWriter, r *http.Request, state string)
func (p *Provider) SignOut(_ http.ResponseWriter, _ *http.Request, _, _, _ string) error {
return oidc.ErrSignoutNotImplemented
}
// VerifyAccessToken verifies an access token.
func (p *Provider) VerifyAccessToken(_ context.Context, _ string) (claims map[string]any, err error) {
return nil, identity.ErrVerifyAccessTokenNotSupported
}
// VerifyIdentityToken verifies an identity token.
func (p *Provider) VerifyIdentityToken(_ context.Context, _ string) (claims map[string]any, err error) {
return nil, identity.ErrVerifyIdentityTokenNotSupported
}

View file

@ -3,7 +3,9 @@
// authorization with Bearer JWT.
package oauth
import "net/url"
import (
"net/url"
)
// Options contains the fields required for an OAuth 2.0 (inc. OIDC) auth flow.
//
@ -29,4 +31,7 @@ type Options struct {
// AuthCodeOptions specifies additional key value pairs query params to add
// to the request flow signin url.
AuthCodeOptions map[string]string
// When set validates the audience in access tokens.
AccessTokenAllowedAudiences *[]string
}

View file

@ -10,10 +10,14 @@ import (
"fmt"
"io"
"net/http"
"slices"
"strings"
go_oidc "github.com/coreos/go-oidc/v3/oidc"
"github.com/google/uuid"
"golang.org/x/oauth2"
"github.com/pomerium/pomerium/internal/jwtutil"
"github.com/pomerium/pomerium/pkg/identity/oauth"
pom_oidc "github.com/pomerium/pomerium/pkg/identity/oidc"
)
@ -37,11 +41,13 @@ var defaultAuthCodeOptions = map[string]string{"prompt": "select_account"}
// Provider is an Azure implementation of the Authenticator interface.
type Provider struct {
*pom_oidc.Provider
accessTokenAllowedAudiences *[]string
}
// New instantiates an OpenID Connect (OIDC) provider for Azure.
func New(ctx context.Context, o *oauth.Options) (*Provider, error) {
var p Provider
p.accessTokenAllowedAudiences = o.AccessTokenAllowedAudiences
var err error
if o.ProviderURL == "" {
o.ProviderURL = defaultProviderURL
@ -73,6 +79,59 @@ func (p *Provider) Name() string {
return Name
}
// VerifyAccessToken verifies a raw access token.
func (p *Provider) VerifyAccessToken(ctx context.Context, rawAccessToken string) (claims map[string]any, err error) {
pp, err := p.GetProvider()
if err != nil {
return nil, fmt.Errorf("error getting oidc provider: %w", err)
}
// azure access tokens are JWTs signed with the same keys as identity tokens
verifier := pp.Verifier(&go_oidc.Config{
SkipClientIDCheck: true,
SkipIssuerCheck: true, // checked later
})
token, err := verifier.Verify(ctx, rawAccessToken)
if err != nil {
return nil, fmt.Errorf("error verifying access token: %w", err)
}
claims = jwtutil.Claims(map[string]any{})
err = token.Claims(&claims)
if err != nil {
return nil, fmt.Errorf("error unmarshaling access token claims: %w", err)
}
// verify audience
if p.accessTokenAllowedAudiences != nil {
if audience, ok := claims["aud"].(string); !ok || !slices.Contains(*p.accessTokenAllowedAudiences, audience) {
return nil, fmt.Errorf("error verifying access token audience claim, invalid audience")
}
}
err = verifyIssuer(pp, claims)
if err != nil {
return nil, fmt.Errorf("error verifying access token issuer claim: %w", err)
}
if scope, ok := claims["scp"].(string); ok && slices.Contains(strings.Fields(scope), "openid") {
userInfo, err := pp.UserInfo(ctx, oauth2.StaticTokenSource(&oauth2.Token{
TokenType: "Bearer",
AccessToken: rawAccessToken,
}))
if err != nil {
return nil, fmt.Errorf("error calling user info endpoint: %w", err)
}
err = userInfo.Claims(claims)
if err != nil {
return nil, fmt.Errorf("error unmarshaling user info claims: %w", err)
}
}
return claims, nil
}
// newProvider overrides the default round tripper for well-known endpoint call that happens
// on new provider registration.
// By default, the "common" (both public and private domains) responds with
@ -128,3 +187,55 @@ func (transport *wellKnownConfiguration) RoundTrip(req *http.Request) (*http.Res
res.Body = io.NopCloser(bytes.NewReader(bs))
return res, nil
}
const (
v1IssuerPrefix = "https://sts.windows.net/"
v1IssuerSuffix = "/"
v2IssuerPrefix = "https://login.microsoftonline.com/"
v2IssuerSuffix = "/v2.0"
)
func verifyIssuer(pp *go_oidc.Provider, claims map[string]any) error {
tenantID, ok := getTenantIDFromURL(pp.Endpoint().TokenURL)
if !ok {
return fmt.Errorf("failed to find tenant id")
}
iss, ok := claims["iss"].(string)
if !ok {
return fmt.Errorf("missing issuer claim")
}
if !(iss == v1IssuerPrefix+tenantID+v1IssuerSuffix || iss == v2IssuerPrefix+tenantID+v2IssuerSuffix) {
return fmt.Errorf("invalid issuer: %s", iss)
}
return nil
}
func getTenantIDFromURL(rawTokenURL string) (string, bool) {
// URLs look like:
// - https://login.microsoftonline.com/f42bce3b-671c-4162-b24c-00ecc7641897/v2.0
// Or:
// - https://sts.windows.net/f42bce3b-671c-4162-b24c-00ecc7641897/
for _, prefix := range []string{v1IssuerPrefix, v2IssuerPrefix} {
path, ok := strings.CutPrefix(rawTokenURL, prefix)
if !ok {
continue
}
idx := strings.Index(path, "/")
if idx <= 0 {
continue
}
rawTenantID := path[:idx]
if _, err := uuid.Parse(rawTenantID); err != nil {
continue
}
return rawTenantID, true
}
return "", false
}

View file

@ -2,15 +2,27 @@ package azure
import (
"context"
"crypto/rand"
"crypto/rsa"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/go-jose/go-jose/v3"
"github.com/go-jose/go-jose/v3/jwt"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/pomerium/pomerium/internal/testutil"
"github.com/pomerium/pomerium/pkg/identity/identity"
"github.com/pomerium/pomerium/pkg/identity/oauth"
)
func TestAuthCodeOptions(t *testing.T) {
t.Parallel()
var options oauth.Options
p, err := New(context.Background(), &options)
require.NoError(t, err)
@ -21,3 +33,101 @@ func TestAuthCodeOptions(t *testing.T) {
require.NoError(t, err)
assert.Equal(t, map[string]string{}, p.AuthCodeOptions)
}
func TestVerifyAccessToken(t *testing.T) {
t.Parallel()
ctx := testutil.GetContext(t, time.Minute)
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
require.NoError(t, err)
jwtSigner, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.RS256, Key: privateKey}, nil)
require.NoError(t, err)
iat := time.Now().Unix()
exp := iat + 3600
rawAccessToken1, err := jwt.Signed(jwtSigner).Claims(map[string]any{
"iss": "https://sts.windows.net/323b4000-7ad7-4ed3-9f4e-adee06ee8bbe/",
"aud": "https://client.example.com",
"sub": "subject",
"exp": exp,
"iat": iat,
}).CompactSerialize()
require.NoError(t, err)
rawAccessToken2, err := jwt.Signed(jwtSigner).Claims(map[string]any{
"iss": "https://sts.windows.net/323b4000-7ad7-4ed3-9f4e-adee06ee8bbe/",
"aud": "https://unexpected.example.com",
"sub": "subject",
"exp": exp,
"iat": iat,
}).CompactSerialize()
require.NoError(t, err)
var srvURL string
mux := http.NewServeMux()
mux.HandleFunc("GET /.well-known/openid-configuration", func(w http.ResponseWriter, _ *http.Request) {
w.Header().Set("Content-Type", "application/json; charset=utf-8")
json.NewEncoder(w).Encode(map[string]any{
"issuer": srvURL,
"authorization_endpoint": srvURL + "/auth",
"token_endpoint": "https://sts.windows.net/323b4000-7ad7-4ed3-9f4e-adee06ee8bbe/token",
"jwks_uri": srvURL + "/keys",
"id_token_signing_alg_values_supported": []any{"RS256"},
})
})
mux.HandleFunc("GET /keys", func(w http.ResponseWriter, _ *http.Request) {
w.Header().Set("Content-Type", "application/json; charset=utf-8")
json.NewEncoder(w).Encode(jose.JSONWebKeySet{
Keys: []jose.JSONWebKey{
{Key: privateKey.Public(), Use: "sig", Algorithm: "RS256"},
},
})
})
srv := httptest.NewServer(mux)
srvURL = srv.URL
audiences := []string{"https://other.example.com", "https://client.example.com"}
p, err := New(ctx, &oauth.Options{
ProviderName: Name,
ProviderURL: srv.URL,
ClientID: "CLIENT_ID",
ClientSecret: "CLIENT_SECRET",
AccessTokenAllowedAudiences: &audiences,
})
require.NoError(t, err)
claims, err := p.VerifyAccessToken(ctx, rawAccessToken1)
require.NoError(t, err)
delete(claims, "iat")
delete(claims, "exp")
assert.Equal(t, map[string]any{
"iss": "https://sts.windows.net/323b4000-7ad7-4ed3-9f4e-adee06ee8bbe/",
"aud": "https://client.example.com",
"sub": "subject",
}, claims)
_, err = p.VerifyAccessToken(ctx, rawAccessToken2)
assert.ErrorContains(t, err, "invalid audience")
}
func TestVerifyIdentityToken(t *testing.T) {
t.Parallel()
ctx := testutil.GetContext(t, time.Minute)
mux := http.NewServeMux()
srv := httptest.NewServer(mux)
p, err := New(ctx, &oauth.Options{
ProviderName: Name,
ProviderURL: srv.URL,
ClientID: "CLIENT_ID",
ClientSecret: "CLIENT_SECRET",
})
require.NoError(t, err)
claims, err := p.VerifyIdentityToken(ctx, "RAW IDENTITY TOKEN")
assert.ErrorIs(t, identity.ErrVerifyIdentityTokenNotSupported, err)
assert.Nil(t, claims)
}

View file

@ -360,3 +360,13 @@ func (p *Provider) SignOut(w http.ResponseWriter, r *http.Request, idTokenHint,
httputil.Redirect(w, r, endSessionURL.String(), http.StatusFound)
return nil
}
// VerifyAccessToken verifies an access token.
func (p *Provider) VerifyAccessToken(_ context.Context, _ string) (claims map[string]any, err error) {
return nil, identity.ErrVerifyAccessTokenNotSupported
}
// VerifyIdentityToken verifies an identity token.
func (p *Provider) VerifyIdentityToken(_ context.Context, _ string) (claims map[string]any, err error) {
return nil, identity.ErrVerifyIdentityTokenNotSupported
}

View file

@ -8,6 +8,7 @@ import (
"net/http"
"go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp"
oteltrace "go.opentelemetry.io/otel/trace"
"golang.org/x/oauth2"
"github.com/pomerium/pomerium/pkg/identity/identity"
@ -23,7 +24,6 @@ import (
"github.com/pomerium/pomerium/pkg/identity/oidc/okta"
"github.com/pomerium/pomerium/pkg/identity/oidc/onelogin"
"github.com/pomerium/pomerium/pkg/identity/oidc/ping"
oteltrace "go.opentelemetry.io/otel/trace"
)
// State is the identity state.
@ -36,6 +36,8 @@ type Authenticator interface {
Revoke(context.Context, *oauth2.Token) error
Name() string
UpdateUserInfo(ctx context.Context, t *oauth2.Token, v any) error
VerifyAccessToken(ctx context.Context, rawAccessToken string) (claims map[string]any, err error)
VerifyIdentityToken(ctx context.Context, rawIdentityToken string) (claims map[string]any, err error)
SignIn(w http.ResponseWriter, r *http.Request, state string) error
SignOut(w http.ResponseWriter, r *http.Request, idTokenHint, authenticateSignedOutURL, redirectToURL string) error