mirror of
https://github.com/pomerium/pomerium.git
synced 2025-04-28 09:56:31 +02:00
authorize: support authenticating with idp tokens (#5484)
* identity: add support for verifying access and identity tokens * allow overriding with policy option * authenticate: add verify endpoints * wip * implement session creation * add verify test * implement idp token login * fix tests * add pr permission * make session ids route-specific * rename method * add test * add access token test * test for newUserFromIDPClaims * more tests * make the session id per-idp * use type for * add test * remove nil checks
This commit is contained in:
parent
6e22b7a19a
commit
b9fd926618
36 changed files with 2791 additions and 885 deletions
1
.github/workflows/benchmark.yaml
vendored
1
.github/workflows/benchmark.yaml
vendored
|
@ -3,6 +3,7 @@ name: Benchmark
|
|||
permissions:
|
||||
contents: write
|
||||
deployments: write
|
||||
pull-requests: write
|
||||
|
||||
on:
|
||||
push:
|
||||
|
|
|
@ -43,6 +43,16 @@ func (a *Authenticate) Handler() http.Handler {
|
|||
func (a *Authenticate) Mount(r *mux.Router) {
|
||||
r.StrictSlash(true)
|
||||
r.Use(middleware.SetHeaders(httputil.HeadersContentSecurityPolicy))
|
||||
// disable csrf checking for these endpoints
|
||||
r.Use(func(h http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path == "/.pomerium/verify-access-token" ||
|
||||
r.URL.Path == "/.pomerium/verify-identity-token" {
|
||||
r = csrf.UnsafeSkipCheck(r)
|
||||
}
|
||||
h.ServeHTTP(w, r)
|
||||
})
|
||||
})
|
||||
r.Use(func(h http.Handler) http.Handler {
|
||||
options := a.options.Load()
|
||||
state := a.state.Load()
|
||||
|
@ -95,6 +105,8 @@ func (a *Authenticate) mountDashboard(r *mux.Router) {
|
|||
// routes that don't need a session:
|
||||
sr.Path("/sign_out").Handler(httputil.HandlerFunc(a.SignOut))
|
||||
sr.Path("/signed_out").Handler(httputil.HandlerFunc(a.signedOut)).Methods(http.MethodGet)
|
||||
sr.Path("/verify-access-token").Handler(httputil.HandlerFunc(a.verifyAccessToken)).Methods(http.MethodPost)
|
||||
sr.Path("/verify-identity-token").Handler(httputil.HandlerFunc(a.verifyIdentityToken)).Methods(http.MethodPost)
|
||||
|
||||
// routes that need a session:
|
||||
sr = sr.NewRoute().Subrouter()
|
||||
|
|
76
authenticate/handlers_verify.go
Normal file
76
authenticate/handlers_verify.go
Normal file
|
@ -0,0 +1,76 @@
|
|||
package authenticate
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/httputil"
|
||||
"github.com/pomerium/pomerium/internal/log"
|
||||
"github.com/pomerium/pomerium/pkg/authenticateapi"
|
||||
)
|
||||
|
||||
func (a *Authenticate) verifyAccessToken(w http.ResponseWriter, r *http.Request) error {
|
||||
var req authenticateapi.VerifyAccessTokenRequest
|
||||
err := json.NewDecoder(r.Body).Decode(&req)
|
||||
if err != nil {
|
||||
return httputil.NewError(http.StatusBadRequest, err)
|
||||
}
|
||||
|
||||
authenticator, err := a.cfg.getIdentityProvider(r.Context(), a.tracerProvider, a.options.Load(), req.IdentityProviderID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var res authenticateapi.VerifyTokenResponse
|
||||
claims, err := authenticator.VerifyAccessToken(r.Context(), req.AccessToken)
|
||||
if err == nil {
|
||||
res.Valid = true
|
||||
res.Claims = claims
|
||||
} else {
|
||||
res.Valid = false
|
||||
log.Ctx(r.Context()).Info().
|
||||
Err(err).
|
||||
Str("idp", authenticator.Name()).
|
||||
Msg("access token failed verification")
|
||||
}
|
||||
|
||||
err = json.NewEncoder(w).Encode(&res)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *Authenticate) verifyIdentityToken(w http.ResponseWriter, r *http.Request) error {
|
||||
var req authenticateapi.VerifyIdentityTokenRequest
|
||||
err := json.NewDecoder(r.Body).Decode(&req)
|
||||
if err != nil {
|
||||
return httputil.NewError(http.StatusBadRequest, err)
|
||||
}
|
||||
|
||||
authenticator, err := a.cfg.getIdentityProvider(r.Context(), a.tracerProvider, a.options.Load(), req.IdentityProviderID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var res authenticateapi.VerifyTokenResponse
|
||||
claims, err := authenticator.VerifyIdentityToken(r.Context(), req.IdentityToken)
|
||||
if err == nil {
|
||||
res.Valid = true
|
||||
res.Claims = claims
|
||||
} else {
|
||||
res.Valid = false
|
||||
log.Ctx(r.Context()).Info().
|
||||
Err(err).
|
||||
Str("idp", authenticator.Name()).
|
||||
Msg("identity token failed verification")
|
||||
}
|
||||
|
||||
err = json.NewEncoder(w).Encode(&res)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
45
authenticate/handlers_verify_test.go
Normal file
45
authenticate/handlers_verify_test.go
Normal file
|
@ -0,0 +1,45 @@
|
|||
package authenticate_test
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/pomerium/pomerium/authenticate"
|
||||
"github.com/pomerium/pomerium/config"
|
||||
"github.com/pomerium/pomerium/internal/testutil"
|
||||
"github.com/pomerium/pomerium/pkg/cryptutil"
|
||||
)
|
||||
|
||||
func TestVerifyAccessToken(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.GetContext(t, time.Minute)
|
||||
a, err := authenticate.New(ctx, &config.Config{
|
||||
Options: &config.Options{
|
||||
CookieSecret: cryptutil.NewBase64Key(),
|
||||
SharedKey: cryptutil.NewBase64Key(),
|
||||
AuthenticateCallbackPath: "/oauth2/callback",
|
||||
AuthenticateURLString: "https://authenticate.example.com",
|
||||
|
||||
Provider: "oidc",
|
||||
ProviderURL: "http://oidc.example.com",
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
r, err := http.NewRequestWithContext(ctx, http.MethodPost, "https://authenticate.example.com/.pomerium/verify-access-token",
|
||||
strings.NewReader(`{"accessToken":"ACCESS TOKEN"}`))
|
||||
require.NoError(t, err)
|
||||
|
||||
a.Handler().ServeHTTP(w, r)
|
||||
|
||||
assert.Equal(t, 200, w.Code)
|
||||
assert.JSONEq(t, `{"valid":false}`, w.Body.String())
|
||||
}
|
|
@ -3,11 +3,12 @@ package authenticate
|
|||
import (
|
||||
"context"
|
||||
|
||||
oteltrace "go.opentelemetry.io/otel/trace"
|
||||
|
||||
"github.com/pomerium/pomerium/config"
|
||||
"github.com/pomerium/pomerium/internal/urlutil"
|
||||
"github.com/pomerium/pomerium/pkg/identity"
|
||||
"github.com/pomerium/pomerium/pkg/identity/oauth"
|
||||
oteltrace "go.opentelemetry.io/otel/trace"
|
||||
)
|
||||
|
||||
func defaultGetIdentityProvider(ctx context.Context, tracerProvider oteltrace.TracerProvider, options *config.Options, idpID string) (identity.Authenticator, error) {
|
||||
|
@ -26,7 +27,8 @@ func defaultGetIdentityProvider(ctx context.Context, tracerProvider oteltrace.Tr
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return identity.NewAuthenticator(ctx, tracerProvider, oauth.Options{
|
||||
|
||||
o := oauth.Options{
|
||||
RedirectURL: redirectURL,
|
||||
ProviderName: idp.GetType(),
|
||||
ProviderURL: idp.GetUrl(),
|
||||
|
@ -34,5 +36,9 @@ func defaultGetIdentityProvider(ctx context.Context, tracerProvider oteltrace.Tr
|
|||
ClientSecret: idp.GetClientSecret(),
|
||||
Scopes: idp.GetScopes(),
|
||||
AuthCodeOptions: idp.GetRequestParams(),
|
||||
})
|
||||
}
|
||||
if v := idp.GetAccessTokenAllowedAudiences(); v != nil {
|
||||
o.AccessTokenAllowedAudiences = &v.Values
|
||||
}
|
||||
return identity.NewAuthenticator(ctx, tracerProvider, o)
|
||||
}
|
||||
|
|
|
@ -29,7 +29,7 @@ import (
|
|||
type Authorize struct {
|
||||
state *atomicutil.Value[*authorizeState]
|
||||
store *store.Store
|
||||
currentOptions *atomicutil.Value[*config.Options]
|
||||
currentConfig *atomicutil.Value[*config.Config]
|
||||
accessTracker *AccessTracker
|
||||
globalCache storage.Cache
|
||||
groupsCacheWarmer *cacheWarmer
|
||||
|
@ -43,7 +43,7 @@ func New(ctx context.Context, cfg *config.Config) (*Authorize, error) {
|
|||
tracerProvider := trace.NewTracerProvider(ctx, "Authorize")
|
||||
tracer := tracerProvider.Tracer(trace.PomeriumCoreTracer)
|
||||
a := &Authorize{
|
||||
currentOptions: config.NewAtomicOptions(),
|
||||
currentConfig: atomicutil.NewValue(&config.Config{Options: new(config.Options)}),
|
||||
store: store.New(),
|
||||
globalCache: storage.NewGlobalCache(time.Minute),
|
||||
tracerProvider: tracerProvider,
|
||||
|
@ -155,7 +155,7 @@ func newPolicyEvaluator(
|
|||
// OnConfigChange updates internal structures based on config.Options
|
||||
func (a *Authorize) OnConfigChange(ctx context.Context, cfg *config.Config) {
|
||||
currentState := a.state.Load()
|
||||
a.currentOptions.Store(cfg.Options)
|
||||
a.currentConfig.Store(cfg)
|
||||
if newState, err := newAuthorizeStateFromConfig(ctx, a.tracerProvider, cfg, a.store, currentState.evaluator); err != nil {
|
||||
log.Ctx(ctx).Error().Err(err).Msg("authorize: error updating state")
|
||||
} else {
|
||||
|
|
|
@ -186,7 +186,7 @@ func (a *Authorize) deniedResponse(
|
|||
Err: errors.New(reason),
|
||||
DebugURL: debugEndpoint,
|
||||
RequestID: requestid.FromContext(ctx),
|
||||
BrandingOptions: a.currentOptions.Load().BrandingOptions,
|
||||
BrandingOptions: a.currentConfig.Load().Options.BrandingOptions,
|
||||
}
|
||||
httpErr.ErrorResponse(ctx, w, r)
|
||||
|
||||
|
@ -213,7 +213,7 @@ func (a *Authorize) requireLoginResponse(
|
|||
in *envoy_service_auth_v3.CheckRequest,
|
||||
request *evaluator.Request,
|
||||
) (*envoy_service_auth_v3.CheckResponse, error) {
|
||||
options := a.currentOptions.Load()
|
||||
options := a.currentConfig.Load().Options
|
||||
state := a.state.Load()
|
||||
|
||||
if !a.shouldRedirect(in) {
|
||||
|
@ -251,7 +251,7 @@ func (a *Authorize) requireWebAuthnResponse(
|
|||
request *evaluator.Request,
|
||||
result *evaluator.Result,
|
||||
) (*envoy_service_auth_v3.CheckResponse, error) {
|
||||
opts := a.currentOptions.Load()
|
||||
opts := a.currentConfig.Load().Options
|
||||
state := a.state.Load()
|
||||
|
||||
// always assume https scheme
|
||||
|
@ -327,7 +327,7 @@ func toEnvoyHeaders(headers http.Header) []*envoy_config_core_v3.HeaderValueOpti
|
|||
// userInfoEndpointURL returns the user info endpoint url which can be used to debug the user's
|
||||
// session that lives on the authenticate service.
|
||||
func (a *Authorize) userInfoEndpointURL(in *envoy_service_auth_v3.CheckRequest) (*url.URL, error) {
|
||||
opts := a.currentOptions.Load()
|
||||
opts := a.currentConfig.Load().Options
|
||||
authenticateURL, err := opts.GetAuthenticateURL()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
|
|
@ -127,8 +127,9 @@ func TestAuthorize_okResponse(t *testing.T) {
|
|||
}},
|
||||
JWTClaimsHeaders: config.NewJWTClaimHeaders("email"),
|
||||
}
|
||||
a := &Authorize{currentOptions: config.NewAtomicOptions(), state: atomicutil.NewValue(new(authorizeState))}
|
||||
a.currentOptions.Store(opt)
|
||||
a := &Authorize{currentConfig: atomicutil.NewValue(&config.Config{
|
||||
Options: opt,
|
||||
}), state: atomicutil.NewValue(new(authorizeState))}
|
||||
a.store = store.New()
|
||||
pe, err := newPolicyEvaluator(context.Background(), opt, a.store, nil)
|
||||
require.NoError(t, err)
|
||||
|
@ -183,15 +184,16 @@ func TestAuthorize_okResponse(t *testing.T) {
|
|||
func TestAuthorize_deniedResponse(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
a := &Authorize{currentOptions: config.NewAtomicOptions(), state: atomicutil.NewValue(new(authorizeState))}
|
||||
a.currentOptions.Store(&config.Options{
|
||||
Policies: []config.Policy{{
|
||||
From: "https://example.com",
|
||||
SubPolicies: []config.SubPolicy{{
|
||||
Rego: []string{"allow = true"},
|
||||
a := &Authorize{currentConfig: atomicutil.NewValue(&config.Config{
|
||||
Options: &config.Options{
|
||||
Policies: []config.Policy{{
|
||||
From: "https://example.com",
|
||||
SubPolicies: []config.SubPolicy{{
|
||||
Rego: []string{"allow = true"},
|
||||
}},
|
||||
}},
|
||||
}},
|
||||
})
|
||||
},
|
||||
}), state: atomicutil.NewValue(new(authorizeState))}
|
||||
|
||||
t.Run("json", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
|
|
@ -3,15 +3,17 @@ package authorize
|
|||
import (
|
||||
"context"
|
||||
|
||||
"google.golang.org/grpc"
|
||||
|
||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/session"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/user"
|
||||
"github.com/pomerium/pomerium/pkg/grpcutil"
|
||||
"github.com/pomerium/pomerium/pkg/storage"
|
||||
"google.golang.org/grpc"
|
||||
)
|
||||
|
||||
type sessionOrServiceAccount interface {
|
||||
GetId() string
|
||||
GetUserId() string
|
||||
Validate() error
|
||||
}
|
||||
|
|
|
@ -3,6 +3,7 @@ package authorize
|
|||
import (
|
||||
"context"
|
||||
"encoding/pem"
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
|
@ -21,6 +22,7 @@ import (
|
|||
"github.com/pomerium/pomerium/internal/sessions"
|
||||
"github.com/pomerium/pomerium/internal/urlutil"
|
||||
"github.com/pomerium/pomerium/pkg/contextutil"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/user"
|
||||
"github.com/pomerium/pomerium/pkg/storage"
|
||||
"github.com/pomerium/pomerium/pkg/telemetry/requestid"
|
||||
|
@ -44,31 +46,25 @@ func (a *Authorize) Check(ctx context.Context, in *envoy_service_auth_v3.CheckRe
|
|||
requestID := requestid.FromHTTPHeader(hreq.Header)
|
||||
ctx = requestid.WithValue(ctx, requestID)
|
||||
|
||||
sessionState, _ := state.sessionStore.LoadSessionStateAndCheckIDP(hreq)
|
||||
|
||||
var s sessionOrServiceAccount
|
||||
var u *user.User
|
||||
var err error
|
||||
if sessionState != nil {
|
||||
s, err = a.getDataBrokerSessionOrServiceAccount(ctx, sessionState.ID, sessionState.DatabrokerRecordVersion)
|
||||
if status.Code(err) == codes.Unavailable {
|
||||
log.Ctx(ctx).Debug().Str("request-id", requestID).Err(err).Msg("temporary error checking authorization: data broker unavailable")
|
||||
return nil, err
|
||||
} else if err != nil {
|
||||
log.Ctx(ctx).Info().Err(err).Str("request-id", requestID).Msg("clearing session due to missing or invalid session or service account")
|
||||
sessionState = nil
|
||||
}
|
||||
}
|
||||
if sessionState != nil && s != nil {
|
||||
u, _ = a.getDataBrokerUser(ctx, s.GetUserId()) // ignore any missing user error
|
||||
}
|
||||
|
||||
req, err := a.getEvaluatorRequestFromCheckRequest(ctx, in, sessionState)
|
||||
req, err := a.getEvaluatorRequestFromCheckRequest(ctx, in)
|
||||
if err != nil {
|
||||
log.Ctx(ctx).Error().Err(err).Str("request-id", requestID).Msg("error building evaluator request")
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// load the session
|
||||
s, err := a.loadSession(ctx, hreq, req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// if there's a session or service account, load the user
|
||||
var u *user.User
|
||||
if s != nil {
|
||||
req.Session.ID = s.GetId()
|
||||
u, _ = a.getDataBrokerUser(ctx, s.GetUserId()) // ignore any missing user error
|
||||
}
|
||||
|
||||
res, err := state.evaluator.Evaluate(ctx, req)
|
||||
if err != nil {
|
||||
log.Ctx(ctx).Error().Err(err).Str("request-id", requestID).Msg("error during OPA evaluation")
|
||||
|
@ -88,10 +84,65 @@ func (a *Authorize) Check(ctx context.Context, in *envoy_service_auth_v3.CheckRe
|
|||
return resp, err
|
||||
}
|
||||
|
||||
func (a *Authorize) loadSession(
|
||||
ctx context.Context,
|
||||
hreq *http.Request,
|
||||
req *evaluator.Request,
|
||||
) (s sessionOrServiceAccount, err error) {
|
||||
requestID := requestid.FromHTTPHeader(hreq.Header)
|
||||
|
||||
// attempt to create a session from an incoming idp token
|
||||
s, err = config.NewIncomingIDPTokenSessionCreator(
|
||||
func(ctx context.Context, recordType, recordID string) (*databroker.Record, error) {
|
||||
return getDataBrokerRecord(ctx, recordType, recordID, 0)
|
||||
},
|
||||
func(ctx context.Context, records []*databroker.Record) error {
|
||||
_, err := a.state.Load().dataBrokerClient.Put(ctx, &databroker.PutRequest{
|
||||
Records: records,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// invalidate cache
|
||||
for _, record := range records {
|
||||
storage.GetQuerier(ctx).InvalidateCache(ctx, &databroker.QueryRequest{
|
||||
Type: record.GetType(),
|
||||
Query: record.GetId(),
|
||||
Limit: 1,
|
||||
})
|
||||
}
|
||||
return nil
|
||||
},
|
||||
).CreateSession(ctx, a.currentConfig.Load(), req.Policy, hreq)
|
||||
if err == nil {
|
||||
return s, nil
|
||||
} else if !errors.Is(err, sessions.ErrNoSessionFound) {
|
||||
log.Ctx(ctx).Info().
|
||||
Str("request-id", requestID).
|
||||
Err(err).
|
||||
Msg("error creating session for incoming idp token")
|
||||
}
|
||||
|
||||
sessionState, _ := a.state.Load().sessionStore.LoadSessionStateAndCheckIDP(hreq)
|
||||
if sessionState == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
s, err = a.getDataBrokerSessionOrServiceAccount(ctx, sessionState.ID, sessionState.DatabrokerRecordVersion)
|
||||
if status.Code(err) == codes.Unavailable {
|
||||
log.Ctx(ctx).Debug().Str("request-id", requestID).Err(err).Msg("temporary error checking authorization: data broker unavailable")
|
||||
return nil, err
|
||||
} else if err != nil {
|
||||
log.Ctx(ctx).Info().Err(err).Str("request-id", requestID).Msg("clearing session due to missing or invalid session or service account")
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func (a *Authorize) getEvaluatorRequestFromCheckRequest(
|
||||
ctx context.Context,
|
||||
in *envoy_service_auth_v3.CheckRequest,
|
||||
sessionState *sessions.State,
|
||||
) (*evaluator.Request, error) {
|
||||
requestURL := getCheckRequestURL(in)
|
||||
attrs := in.GetAttributes()
|
||||
|
@ -106,17 +157,12 @@ func (a *Authorize) getEvaluatorRequestFromCheckRequest(
|
|||
attrs.GetSource().GetAddress().GetSocketAddress().GetAddress(),
|
||||
),
|
||||
}
|
||||
if sessionState != nil {
|
||||
req.Session = evaluator.RequestSession{
|
||||
ID: sessionState.ID,
|
||||
}
|
||||
}
|
||||
req.Policy = a.getMatchingPolicy(envoyconfig.ExtAuthzContextExtensionsRouteID(attrs.GetContextExtensions()))
|
||||
return req, nil
|
||||
}
|
||||
|
||||
func (a *Authorize) getMatchingPolicy(routeID uint64) *config.Policy {
|
||||
options := a.currentOptions.Load()
|
||||
options := a.currentConfig.Load().Options
|
||||
|
||||
for p := range options.GetAllPolicies() {
|
||||
id, _ := p.RouteID()
|
||||
|
|
|
@ -18,7 +18,6 @@ import (
|
|||
"github.com/pomerium/pomerium/authorize/evaluator"
|
||||
"github.com/pomerium/pomerium/config"
|
||||
"github.com/pomerium/pomerium/internal/atomicutil"
|
||||
"github.com/pomerium/pomerium/internal/sessions"
|
||||
"github.com/pomerium/pomerium/internal/testutil"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||
"github.com/pomerium/pomerium/pkg/storage"
|
||||
|
@ -49,15 +48,16 @@ yE+vPxsiUkvQHdO2fojCkY8jg70jxM+gu59tPDNbw3Uh/2Ij310FgTHsnGQMyA==
|
|||
-----END CERTIFICATE-----`
|
||||
|
||||
func Test_getEvaluatorRequest(t *testing.T) {
|
||||
a := &Authorize{currentOptions: config.NewAtomicOptions(), state: atomicutil.NewValue(new(authorizeState))}
|
||||
a.currentOptions.Store(&config.Options{
|
||||
Policies: []config.Policy{{
|
||||
From: "https://example.com",
|
||||
SubPolicies: []config.SubPolicy{{
|
||||
Rego: []string{"allow = true"},
|
||||
a := &Authorize{currentConfig: atomicutil.NewValue(&config.Config{
|
||||
Options: &config.Options{
|
||||
Policies: []config.Policy{{
|
||||
From: "https://example.com",
|
||||
SubPolicies: []config.SubPolicy{{
|
||||
Rego: []string{"allow = true"},
|
||||
}},
|
||||
}},
|
||||
}},
|
||||
})
|
||||
},
|
||||
}), state: atomicutil.NewValue(new(authorizeState))}
|
||||
|
||||
actual, err := a.getEvaluatorRequestFromCheckRequest(context.Background(),
|
||||
&envoy_service_auth_v3.CheckRequest{
|
||||
|
@ -88,16 +88,10 @@ func Test_getEvaluatorRequest(t *testing.T) {
|
|||
},
|
||||
},
|
||||
},
|
||||
&sessions.State{
|
||||
ID: "SESSION_ID",
|
||||
},
|
||||
)
|
||||
require.NoError(t, err)
|
||||
expect := &evaluator.Request{
|
||||
Policy: &a.currentOptions.Load().Policies[0],
|
||||
Session: evaluator.RequestSession{
|
||||
ID: "SESSION_ID",
|
||||
},
|
||||
Policy: &a.currentConfig.Load().Options.Policies[0],
|
||||
HTTP: evaluator.NewRequestHTTP(
|
||||
http.MethodGet,
|
||||
mustParseURL("http://example.com/some/path?qs=1"),
|
||||
|
@ -117,15 +111,16 @@ func Test_getEvaluatorRequest(t *testing.T) {
|
|||
}
|
||||
|
||||
func Test_getEvaluatorRequestWithPortInHostHeader(t *testing.T) {
|
||||
a := &Authorize{currentOptions: config.NewAtomicOptions(), state: atomicutil.NewValue(new(authorizeState))}
|
||||
a.currentOptions.Store(&config.Options{
|
||||
Policies: []config.Policy{{
|
||||
From: "https://example.com",
|
||||
SubPolicies: []config.SubPolicy{{
|
||||
Rego: []string{"allow = true"},
|
||||
a := &Authorize{currentConfig: atomicutil.NewValue(&config.Config{
|
||||
Options: &config.Options{
|
||||
Policies: []config.Policy{{
|
||||
From: "https://example.com",
|
||||
SubPolicies: []config.SubPolicy{{
|
||||
Rego: []string{"allow = true"},
|
||||
}},
|
||||
}},
|
||||
}},
|
||||
})
|
||||
},
|
||||
}), state: atomicutil.NewValue(new(authorizeState))}
|
||||
|
||||
actual, err := a.getEvaluatorRequestFromCheckRequest(context.Background(),
|
||||
&envoy_service_auth_v3.CheckRequest{
|
||||
|
@ -145,10 +140,10 @@ func Test_getEvaluatorRequestWithPortInHostHeader(t *testing.T) {
|
|||
},
|
||||
},
|
||||
},
|
||||
}, nil)
|
||||
})
|
||||
require.NoError(t, err)
|
||||
expect := &evaluator.Request{
|
||||
Policy: &a.currentOptions.Load().Policies[0],
|
||||
Policy: &a.currentConfig.Load().Options.Policies[0],
|
||||
Session: evaluator.RequestSession{},
|
||||
HTTP: evaluator.NewRequestHTTP(
|
||||
http.MethodGet,
|
||||
|
|
|
@ -31,7 +31,7 @@ func (a *Authorize) logAuthorizeCheck(
|
|||
impersonateDetails := a.getImpersonateDetails(ctx, s)
|
||||
|
||||
evt := log.Ctx(ctx).Info().Str("service", "authorize")
|
||||
fields := a.currentOptions.Load().GetAuthorizeLogFields()
|
||||
fields := a.currentConfig.Load().Options.GetAuthorizeLogFields()
|
||||
for _, field := range fields {
|
||||
evt = populateLogEvent(ctx, field, evt, in, s, u, hdrs, impersonateDetails, res)
|
||||
}
|
||||
|
|
98
config/bearer_token_format.go
Normal file
98
config/bearer_token_format.go
Normal file
|
@ -0,0 +1,98 @@
|
|||
package config
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strings"
|
||||
|
||||
"github.com/mitchellh/mapstructure"
|
||||
|
||||
configpb "github.com/pomerium/pomerium/pkg/grpc/config"
|
||||
)
|
||||
|
||||
// BearerTokenFormat specifies how bearer tokens are interepreted by Pomerium.
|
||||
type BearerTokenFormat string
|
||||
|
||||
// Bearer Token Formats
|
||||
const (
|
||||
BearerTokenFormatUnknown BearerTokenFormat = ""
|
||||
BearerTokenFormatDefault BearerTokenFormat = "default"
|
||||
BearerTokenFormatIDPAccessToken BearerTokenFormat = "idp_access_token"
|
||||
BearerTokenFormatIDPIdentityToken BearerTokenFormat = "idp_identity_token"
|
||||
)
|
||||
|
||||
// ParseBearerTokenFormat parses the BearerTokenFormat.
|
||||
func ParseBearerTokenFormat(raw string) (BearerTokenFormat, error) {
|
||||
switch BearerTokenFormat(strings.TrimSpace(strings.ToLower(raw))) {
|
||||
case BearerTokenFormatUnknown:
|
||||
return BearerTokenFormatUnknown, nil
|
||||
case BearerTokenFormatDefault:
|
||||
return BearerTokenFormatDefault, nil
|
||||
case BearerTokenFormatIDPAccessToken:
|
||||
return BearerTokenFormatIDPAccessToken, nil
|
||||
case BearerTokenFormatIDPIdentityToken:
|
||||
return BearerTokenFormatIDPIdentityToken, nil
|
||||
}
|
||||
return BearerTokenFormatUnknown, fmt.Errorf("invalid bearer token format: %s", raw)
|
||||
}
|
||||
|
||||
func BearerTokenFormatFromPB(pbBearerTokenFormat *configpb.BearerTokenFormat) *BearerTokenFormat {
|
||||
if pbBearerTokenFormat == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
bearerTokenFormat := new(BearerTokenFormat)
|
||||
*bearerTokenFormat = BearerTokenFormatDefault
|
||||
|
||||
switch *pbBearerTokenFormat {
|
||||
case configpb.BearerTokenFormat_BEARER_TOKEN_FORMAT_UNKNOWN:
|
||||
*bearerTokenFormat = BearerTokenFormatUnknown
|
||||
case configpb.BearerTokenFormat_BEARER_TOKEN_FORMAT_DEFAULT:
|
||||
*bearerTokenFormat = BearerTokenFormatDefault
|
||||
case configpb.BearerTokenFormat_BEARER_TOKEN_FORMAT_IDP_ACCESS_TOKEN:
|
||||
*bearerTokenFormat = BearerTokenFormatIDPAccessToken
|
||||
case configpb.BearerTokenFormat_BEARER_TOKEN_FORMAT_IDP_IDENTITY_TOKEN:
|
||||
*bearerTokenFormat = BearerTokenFormatIDPIdentityToken
|
||||
}
|
||||
|
||||
return bearerTokenFormat
|
||||
}
|
||||
|
||||
// ToEnvoy converts the bearer token format into a protobuf enum.
|
||||
func (bearerTokenFormat *BearerTokenFormat) ToPB() *configpb.BearerTokenFormat {
|
||||
if bearerTokenFormat == nil {
|
||||
return nil
|
||||
}
|
||||
switch *bearerTokenFormat {
|
||||
case BearerTokenFormatUnknown:
|
||||
return configpb.BearerTokenFormat_BEARER_TOKEN_FORMAT_UNKNOWN.Enum()
|
||||
case BearerTokenFormatDefault:
|
||||
return configpb.BearerTokenFormat_BEARER_TOKEN_FORMAT_DEFAULT.Enum()
|
||||
case BearerTokenFormatIDPAccessToken:
|
||||
return configpb.BearerTokenFormat_BEARER_TOKEN_FORMAT_IDP_ACCESS_TOKEN.Enum()
|
||||
case BearerTokenFormatIDPIdentityToken:
|
||||
return configpb.BearerTokenFormat_BEARER_TOKEN_FORMAT_IDP_IDENTITY_TOKEN.Enum()
|
||||
default:
|
||||
panic(fmt.Sprintf("unknown bearer token format: %v", bearerTokenFormat))
|
||||
}
|
||||
}
|
||||
|
||||
func decodeBearerTokenFormatHookFunc() mapstructure.DecodeHookFunc {
|
||||
return func(_, t reflect.Type, data any) (any, error) {
|
||||
if t != reflect.TypeFor[BearerTokenFormat]() {
|
||||
return data, nil
|
||||
}
|
||||
|
||||
bs, err := json.Marshal(data)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var raw string
|
||||
err = json.Unmarshal(bs, &raw)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return ParseBearerTokenFormat(raw)
|
||||
}
|
||||
}
|
|
@ -4,9 +4,10 @@ import (
|
|||
"errors"
|
||||
|
||||
"github.com/mitchellh/mapstructure"
|
||||
"github.com/pomerium/pomerium/config/otelconfig"
|
||||
"github.com/spf13/viper"
|
||||
"google.golang.org/protobuf/encoding/protojson"
|
||||
|
||||
"github.com/pomerium/pomerium/config/otelconfig"
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -37,6 +38,7 @@ var ViperPolicyHooks = viper.DecodeHook(mapstructure.ComposeDecodeHookFunc(
|
|||
DecodePolicyBase64Hook(),
|
||||
decodeNullBoolHookFunc(),
|
||||
decodeJWTClaimHeadersHookFunc(),
|
||||
decodeBearerTokenFormatHookFunc(),
|
||||
decodeCodecTypeHookFunc(),
|
||||
decodePPLPolicyHookFunc(),
|
||||
decodeSANMatcherHookFunc(),
|
||||
|
|
|
@ -1,6 +1,8 @@
|
|||
package config
|
||||
|
||||
import (
|
||||
"slices"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/urlutil"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/identity"
|
||||
)
|
||||
|
@ -43,6 +45,11 @@ func (o *Options) GetIdentityProviderForPolicy(policy *Policy) (*identity.Provid
|
|||
Url: o.ProviderURL,
|
||||
RequestParams: o.RequestParams,
|
||||
}
|
||||
if v := o.IDPAccessTokenAllowedAudiences; v != nil {
|
||||
idp.AccessTokenAllowedAudiences = &identity.Provider_StringList{
|
||||
Values: slices.Clone(*v),
|
||||
}
|
||||
}
|
||||
if policy != nil {
|
||||
if policy.IDPClientID != "" {
|
||||
idp.ClientId = policy.IDPClientID
|
||||
|
@ -50,6 +57,11 @@ func (o *Options) GetIdentityProviderForPolicy(policy *Policy) (*identity.Provid
|
|||
if policy.IDPClientSecret != "" {
|
||||
idp.ClientSecret = policy.IDPClientSecret
|
||||
}
|
||||
if v := policy.IDPAccessTokenAllowedAudiences; v != nil {
|
||||
idp.AccessTokenAllowedAudiences = &identity.Provider_StringList{
|
||||
Values: slices.Clone(*v),
|
||||
}
|
||||
}
|
||||
}
|
||||
idp.Id = idp.Hash()
|
||||
return idp, nil
|
||||
|
|
|
@ -15,6 +15,7 @@ import (
|
|||
"os"
|
||||
"path/filepath"
|
||||
"reflect"
|
||||
"slices"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
|
@ -152,12 +153,13 @@ type Options struct {
|
|||
|
||||
// Identity provider configuration variables as specified by RFC6749
|
||||
// https://openid.net/specs/openid-connect-basic-1_0.html#RFC6749
|
||||
ClientID string `mapstructure:"idp_client_id" yaml:"idp_client_id,omitempty"`
|
||||
ClientSecret string `mapstructure:"idp_client_secret" yaml:"idp_client_secret,omitempty"`
|
||||
ClientSecretFile string `mapstructure:"idp_client_secret_file" yaml:"idp_client_secret_file,omitempty"`
|
||||
Provider string `mapstructure:"idp_provider" yaml:"idp_provider,omitempty"`
|
||||
ProviderURL string `mapstructure:"idp_provider_url" yaml:"idp_provider_url,omitempty"`
|
||||
Scopes []string `mapstructure:"idp_scopes" yaml:"idp_scopes,omitempty"`
|
||||
ClientID string `mapstructure:"idp_client_id" yaml:"idp_client_id,omitempty"`
|
||||
ClientSecret string `mapstructure:"idp_client_secret" yaml:"idp_client_secret,omitempty"`
|
||||
ClientSecretFile string `mapstructure:"idp_client_secret_file" yaml:"idp_client_secret_file,omitempty"`
|
||||
Provider string `mapstructure:"idp_provider" yaml:"idp_provider,omitempty"`
|
||||
ProviderURL string `mapstructure:"idp_provider_url" yaml:"idp_provider_url,omitempty"`
|
||||
Scopes []string `mapstructure:"idp_scopes" yaml:"idp_scopes,omitempty"`
|
||||
IDPAccessTokenAllowedAudiences *[]string `mapstructure:"idp_access_token_allowed_audiences" yaml:"idp_access_token_allowed_audiences,omitempty"`
|
||||
|
||||
// RequestParams are custom request params added to the signin request as
|
||||
// part of an Oauth2 code flow.
|
||||
|
@ -194,6 +196,13 @@ type Options struct {
|
|||
// List of JWT claims to insert as x-pomerium-claim-* headers on proxied requests
|
||||
JWTClaimsHeaders JWTClaimHeaders `mapstructure:"jwt_claims_headers" yaml:"jwt_claims_headers,omitempty"`
|
||||
|
||||
// BearerTokenFormat indicates how authorization bearer tokens are interepreted. Possible values:
|
||||
// - "default": Only Bearer tokens prefixed with Pomerium- will be interpreted by Pomerium.
|
||||
// - "idp_access_token": The Bearer token will be interpreted as an IdP access token.
|
||||
// - "idp_identity_token": The Bearer token will be interpreted as an IdP identity token.
|
||||
// When unset "default" will be used.
|
||||
BearerTokenFormat *BearerTokenFormat `mapstructure:"bearer_token_format" yaml:"bearer_token_format,omitempty"`
|
||||
|
||||
// Allowlist of group names/IDs to include in the Pomerium JWT.
|
||||
JWTGroupsFilter JWTGroupsFilter
|
||||
|
||||
|
@ -1487,6 +1496,12 @@ func (o *Options) ApplySettings(ctx context.Context, certsIndex *cryptutil.Certi
|
|||
set(&o.ProviderURL, settings.IdpProviderUrl)
|
||||
setSlice(&o.Scopes, settings.Scopes)
|
||||
setMap(&o.RequestParams, settings.RequestParams)
|
||||
if settings.IdpAccessTokenAllowedAudiences != nil {
|
||||
values := slices.Clone(settings.IdpAccessTokenAllowedAudiences.Values)
|
||||
o.IDPAccessTokenAllowedAudiences = &values
|
||||
} else {
|
||||
o.IDPAccessTokenAllowedAudiences = nil
|
||||
}
|
||||
setSlice(&o.AuthorizeURLStrings, settings.AuthorizeServiceUrls)
|
||||
set(&o.AuthorizeInternalURLString, settings.AuthorizeInternalServiceUrl)
|
||||
set(&o.OverrideCertificateName, settings.OverrideCertificateName)
|
||||
|
@ -1495,6 +1510,7 @@ func (o *Options) ApplySettings(ctx context.Context, certsIndex *cryptutil.Certi
|
|||
set(&o.SigningKey, settings.SigningKey)
|
||||
setMap(&o.SetResponseHeaders, settings.SetResponseHeaders)
|
||||
setMap(&o.JWTClaimsHeaders, settings.JwtClaimsHeaders)
|
||||
o.BearerTokenFormat = BearerTokenFormatFromPB(settings.BearerTokenFormat)
|
||||
if len(settings.JwtGroupsFilter) > 0 {
|
||||
o.JWTGroupsFilter = NewJWTGroupsFilter(settings.JwtGroupsFilter)
|
||||
}
|
||||
|
@ -1591,6 +1607,13 @@ func (o *Options) ToProto() *config.Config {
|
|||
copySrcToOptionalDest(&settings.IdpProviderUrl, &o.ProviderURL)
|
||||
settings.Scopes = o.Scopes
|
||||
settings.RequestParams = o.RequestParams
|
||||
if o.IDPAccessTokenAllowedAudiences != nil {
|
||||
settings.IdpAccessTokenAllowedAudiences = &config.Settings_StringList{
|
||||
Values: slices.Clone(*o.IDPAccessTokenAllowedAudiences),
|
||||
}
|
||||
} else {
|
||||
settings.IdpAccessTokenAllowedAudiences = nil
|
||||
}
|
||||
settings.AuthorizeServiceUrls = o.AuthorizeURLStrings
|
||||
copySrcToOptionalDest(&settings.AuthorizeInternalServiceUrl, &o.AuthorizeInternalURLString)
|
||||
copySrcToOptionalDest(&settings.OverrideCertificateName, &o.OverrideCertificateName)
|
||||
|
@ -1599,6 +1622,7 @@ func (o *Options) ToProto() *config.Config {
|
|||
copySrcToOptionalDest(&settings.SigningKey, valueOrFromFileBase64(o.SigningKey, o.SigningKeyFile))
|
||||
settings.SetResponseHeaders = o.SetResponseHeaders
|
||||
settings.JwtClaimsHeaders = o.JWTClaimsHeaders
|
||||
settings.BearerTokenFormat = o.BearerTokenFormat.ToPB()
|
||||
settings.JwtGroupsFilter = o.JWTGroupsFilter.ToSlice()
|
||||
copyOptionalDuration(&settings.DefaultUpstreamTimeout, o.DefaultUpstreamTimeout)
|
||||
copySrcToOptionalDest(&settings.MetricsAddress, &o.MetricsAddr)
|
||||
|
|
|
@ -8,6 +8,7 @@ import (
|
|||
"net/url"
|
||||
"os"
|
||||
"regexp"
|
||||
"slices"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
|
@ -167,6 +168,12 @@ type Policy struct {
|
|||
// - "hostOnly" (default): Issuer strings will be the hostname of the route, with no scheme or trailing slash.
|
||||
// - "uri": Issuer strings will be a complete URI, including the scheme and ending with a trailing slash.
|
||||
JWTIssuerFormat string `mapstructure:"jwt_issuer_format" yaml:"jwt_issuer_format,omitempty"`
|
||||
// BearerTokenFormat indicates how authorization bearer tokens are interepreted. Possible values:
|
||||
// - "default": Only Bearer tokens prefixed with Pomerium- will be interpreted by Pomerium
|
||||
// - "idp_access_token": The Bearer token will be interpreted as an IdP access token.
|
||||
// - "idp_identity_token": The Bearer token will be interpreted as an IdP identity token.
|
||||
// When unset the global option will be used.
|
||||
BearerTokenFormat *BearerTokenFormat `mapstructure:"bearer_token_format" yaml:"bearer_token_format,omitempty"`
|
||||
|
||||
// Allowlist of group names/IDs to include in the Pomerium JWT.
|
||||
// This expands on any global allowlist set in the main Options.
|
||||
|
@ -186,6 +193,8 @@ type Policy struct {
|
|||
IDPClientID string `mapstructure:"idp_client_id" yaml:"idp_client_id,omitempty"`
|
||||
// IDPClientSecret is the client secret used for the identity provider.
|
||||
IDPClientSecret string `mapstructure:"idp_client_secret" yaml:"idp_client_secret,omitempty"`
|
||||
// IDPAccessTokenAllowedAudiences are the allowed audiences for idp access token validation.
|
||||
IDPAccessTokenAllowedAudiences *[]string `mapstructure:"idp_access_token_allowed_audiences" yaml:"idp_access_token_allowed_audiences,omitempty"`
|
||||
|
||||
// ShowErrorDetails indicates whether or not additional error details should be displayed.
|
||||
ShowErrorDetails bool `mapstructure:"show_error_details" yaml:"show_error_details" json:"show_error_details"`
|
||||
|
@ -332,6 +341,12 @@ func NewPolicyFromProto(pb *configpb.Route) (*Policy, error) {
|
|||
TLSUpstreamServerName: pb.GetTlsUpstreamServerName(),
|
||||
UpstreamTimeout: timeout,
|
||||
}
|
||||
if pb.IdpAccessTokenAllowedAudiences != nil {
|
||||
values := slices.Clone(pb.IdpAccessTokenAllowedAudiences.Values)
|
||||
p.IDPAccessTokenAllowedAudiences = &values
|
||||
} else {
|
||||
p.IDPAccessTokenAllowedAudiences = nil
|
||||
}
|
||||
if pb.Redirect.IsSet() {
|
||||
p.Redirect = &PolicyRedirect{
|
||||
HTTPSRedirect: pb.Redirect.HttpsRedirect,
|
||||
|
@ -380,6 +395,8 @@ func NewPolicyFromProto(pb *configpb.Route) (*Policy, error) {
|
|||
p.JWTIssuerFormat = "uri"
|
||||
}
|
||||
|
||||
p.BearerTokenFormat = BearerTokenFormatFromPB(pb.BearerTokenFormat)
|
||||
|
||||
for _, rwh := range pb.RewriteResponseHeaders {
|
||||
p.RewriteResponseHeaders = append(p.RewriteResponseHeaders, RewriteHeader{
|
||||
Header: rwh.GetHeader(),
|
||||
|
@ -505,6 +522,13 @@ func (p *Policy) ToProto() (*configpb.Route, error) {
|
|||
if p.IDPClientSecret != "" {
|
||||
pb.IdpClientSecret = proto.String(p.IDPClientSecret)
|
||||
}
|
||||
if p.IDPAccessTokenAllowedAudiences != nil {
|
||||
pb.IdpAccessTokenAllowedAudiences = &configpb.Route_StringList{
|
||||
Values: slices.Clone(*p.IDPAccessTokenAllowedAudiences),
|
||||
}
|
||||
} else {
|
||||
pb.IdpAccessTokenAllowedAudiences = nil
|
||||
}
|
||||
if p.Redirect != nil {
|
||||
pb.Redirect = &configpb.RouteRedirect{
|
||||
HttpsRedirect: p.Redirect.HTTPSRedirect,
|
||||
|
@ -538,6 +562,8 @@ func (p *Policy) ToProto() (*configpb.Route, error) {
|
|||
pb.JwtIssuerFormat = configpb.IssuerFormat_IssuerURI
|
||||
}
|
||||
|
||||
pb.BearerTokenFormat = p.BearerTokenFormat.ToPB()
|
||||
|
||||
for _, rwh := range p.RewriteResponseHeaders {
|
||||
pb.RewriteResponseHeaders = append(pb.RewriteResponseHeaders, &configpb.RouteRewriteHeader{
|
||||
Header: rwh.Header,
|
||||
|
|
|
@ -1,16 +1,33 @@
|
|||
package config
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/encoding"
|
||||
"github.com/pomerium/pomerium/internal/encoding/jws"
|
||||
"github.com/pomerium/pomerium/internal/httputil"
|
||||
"github.com/pomerium/pomerium/internal/jwtutil"
|
||||
"github.com/pomerium/pomerium/internal/sessions"
|
||||
"github.com/pomerium/pomerium/internal/sessions/cookie"
|
||||
"github.com/pomerium/pomerium/internal/sessions/header"
|
||||
"github.com/pomerium/pomerium/internal/sessions/queryparam"
|
||||
"github.com/pomerium/pomerium/internal/urlutil"
|
||||
"github.com/pomerium/pomerium/pkg/authenticateapi"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||
identitypb "github.com/pomerium/pomerium/pkg/grpc/identity"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/session"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/user"
|
||||
"github.com/pomerium/pomerium/pkg/grpcutil"
|
||||
"github.com/pomerium/pomerium/pkg/identity"
|
||||
"github.com/pomerium/pomerium/pkg/protoutil"
|
||||
"github.com/pomerium/pomerium/pkg/storage"
|
||||
)
|
||||
|
||||
// A SessionStore saves and loads sessions based on the options.
|
||||
|
@ -112,3 +129,310 @@ func (store *SessionStore) LoadSessionStateAndCheckIDP(r *http.Request) (*sessio
|
|||
func (store *SessionStore) SaveSession(w http.ResponseWriter, r *http.Request, v any) error {
|
||||
return store.store.SaveSession(w, r, v)
|
||||
}
|
||||
|
||||
type IncomingIDPTokenSessionCreator interface {
|
||||
CreateSession(ctx context.Context, cfg *Config, policy *Policy, r *http.Request) (*session.Session, error)
|
||||
}
|
||||
|
||||
type incomingIDPTokenSessionCreator struct {
|
||||
timeNow func() time.Time
|
||||
getRecord func(ctx context.Context, recordType, recordID string) (*databroker.Record, error)
|
||||
putRecords func(ctx context.Context, records []*databroker.Record) error
|
||||
}
|
||||
|
||||
func NewIncomingIDPTokenSessionCreator(
|
||||
getRecord func(ctx context.Context, recordType, recordID string) (*databroker.Record, error),
|
||||
putRecords func(ctx context.Context, records []*databroker.Record) error,
|
||||
) IncomingIDPTokenSessionCreator {
|
||||
return &incomingIDPTokenSessionCreator{timeNow: time.Now, getRecord: getRecord, putRecords: putRecords}
|
||||
}
|
||||
|
||||
// CreateSession attempts to create a session for incoming idp access and
|
||||
// identity tokens. If no access or identity token is passed ErrNoSessionFound will be returned.
|
||||
// If the tokens are not valid an error will be returned.
|
||||
func (c *incomingIDPTokenSessionCreator) CreateSession(
|
||||
ctx context.Context,
|
||||
cfg *Config,
|
||||
policy *Policy,
|
||||
r *http.Request,
|
||||
) (session *session.Session, err error) {
|
||||
if rawAccessToken, ok := cfg.GetIncomingIDPAccessTokenForPolicy(policy, r); ok {
|
||||
return c.createSessionAccessToken(ctx, cfg, policy, rawAccessToken)
|
||||
}
|
||||
|
||||
if rawIdentityToken, ok := cfg.GetIncomingIDPIdentityTokenForPolicy(policy, r); ok {
|
||||
return c.createSessionForIdentityToken(ctx, cfg, policy, rawIdentityToken)
|
||||
}
|
||||
|
||||
return nil, sessions.ErrNoSessionFound
|
||||
}
|
||||
|
||||
func (c *incomingIDPTokenSessionCreator) createSessionAccessToken(
|
||||
ctx context.Context,
|
||||
cfg *Config,
|
||||
policy *Policy,
|
||||
rawAccessToken string,
|
||||
) (*session.Session, error) {
|
||||
idp, err := cfg.Options.GetIdentityProviderForPolicy(policy)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error getting identity provider to verify access token: %w", err)
|
||||
}
|
||||
|
||||
sessionID := getAccessTokenSessionID(idp, rawAccessToken)
|
||||
s, err := c.getSession(ctx, sessionID)
|
||||
if err == nil {
|
||||
return s, nil
|
||||
} else if !storage.IsNotFound(err) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
authenticateURL, transport, err := cfg.resolveAuthenticateURL()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error resolving authenticate url to verify access token: %w", err)
|
||||
}
|
||||
|
||||
res, err := authenticateapi.New(authenticateURL, transport).VerifyAccessToken(ctx, &authenticateapi.VerifyAccessTokenRequest{
|
||||
AccessToken: rawAccessToken,
|
||||
IdentityProviderID: idp.GetId(),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error verifying access token: %w", err)
|
||||
} else if !res.Valid {
|
||||
return nil, fmt.Errorf("invalid access token")
|
||||
}
|
||||
|
||||
s = c.newSessionFromIDPClaims(cfg, sessionID, res.Claims)
|
||||
s.OauthToken = &session.OAuthToken{
|
||||
TokenType: "Bearer",
|
||||
AccessToken: rawAccessToken,
|
||||
ExpiresAt: s.ExpiresAt,
|
||||
}
|
||||
u := c.newUserFromIDPClaims(res.Claims)
|
||||
err = c.putSessionAndUser(ctx, s, u)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error saving session and user: %w", err)
|
||||
}
|
||||
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func (c *incomingIDPTokenSessionCreator) createSessionForIdentityToken(
|
||||
ctx context.Context,
|
||||
cfg *Config,
|
||||
policy *Policy,
|
||||
rawIdentityToken string,
|
||||
) (*session.Session, error) {
|
||||
idp, err := cfg.Options.GetIdentityProviderForPolicy(policy)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error getting identity provider to verify identity token: %w", err)
|
||||
}
|
||||
|
||||
sessionID := getIdentityTokenSessionID(idp, rawIdentityToken)
|
||||
s, err := c.getSession(ctx, sessionID)
|
||||
if err == nil {
|
||||
return s, nil
|
||||
} else if !storage.IsNotFound(err) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
authenticateURL, transport, err := cfg.resolveAuthenticateURL()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error resolving authenticate url to verify identity token: %w", err)
|
||||
}
|
||||
|
||||
res, err := authenticateapi.New(authenticateURL, transport).VerifyIdentityToken(ctx, &authenticateapi.VerifyIdentityTokenRequest{
|
||||
IdentityToken: rawIdentityToken,
|
||||
IdentityProviderID: idp.GetId(),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error verifying identity token: %w", err)
|
||||
} else if !res.Valid {
|
||||
return nil, fmt.Errorf("invalid identity token")
|
||||
}
|
||||
|
||||
s = c.newSessionFromIDPClaims(cfg, sessionID, res.Claims)
|
||||
s.SetRawIDToken(rawIdentityToken)
|
||||
u := c.newUserFromIDPClaims(res.Claims)
|
||||
err = c.putSessionAndUser(ctx, s, u)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error saving session and user: %w", err)
|
||||
}
|
||||
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func (c *incomingIDPTokenSessionCreator) newSessionFromIDPClaims(
|
||||
cfg *Config,
|
||||
sessionID string,
|
||||
claims jwtutil.Claims,
|
||||
) *session.Session {
|
||||
now := c.timeNow()
|
||||
s := new(session.Session)
|
||||
s.Id = sessionID
|
||||
if userID, ok := claims.GetUserID(); ok {
|
||||
s.UserId = userID
|
||||
}
|
||||
if issuedAt, ok := claims.GetIssuedAt(); ok {
|
||||
s.IssuedAt = timestamppb.New(issuedAt)
|
||||
} else {
|
||||
s.IssuedAt = timestamppb.New(now)
|
||||
}
|
||||
if expiresAt, ok := claims.GetExpirationTime(); ok {
|
||||
s.ExpiresAt = timestamppb.New(expiresAt)
|
||||
} else {
|
||||
s.ExpiresAt = timestamppb.New(now.Add(cfg.Options.CookieExpire))
|
||||
}
|
||||
s.AccessedAt = timestamppb.New(now)
|
||||
s.AddClaims(identity.Claims(claims).Flatten())
|
||||
if aud, ok := claims.GetAudience(); ok {
|
||||
s.Audience = aud
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
func (c *incomingIDPTokenSessionCreator) newUserFromIDPClaims(
|
||||
claims jwtutil.Claims,
|
||||
) *user.User {
|
||||
u := new(user.User)
|
||||
if userID, ok := claims.GetUserID(); ok {
|
||||
u.Id = userID
|
||||
}
|
||||
if name, ok := claims.GetString("name"); ok {
|
||||
u.Name = name
|
||||
}
|
||||
if email, ok := claims.GetString("email"); ok {
|
||||
u.Email = email
|
||||
}
|
||||
u.Claims = identity.Claims(claims).Flatten().ToPB()
|
||||
return u
|
||||
}
|
||||
|
||||
func (c *incomingIDPTokenSessionCreator) getSession(ctx context.Context, sessionID string) (*session.Session, error) {
|
||||
record, err := c.getRecord(ctx, grpcutil.GetTypeURL(new(session.Session)), sessionID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
msg, err := record.GetData().UnmarshalNew()
|
||||
if err != nil {
|
||||
return nil, storage.ErrNotFound
|
||||
}
|
||||
|
||||
s, ok := msg.(*session.Session)
|
||||
if !ok {
|
||||
return nil, storage.ErrNotFound
|
||||
}
|
||||
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func (c *incomingIDPTokenSessionCreator) putSessionAndUser(ctx context.Context, s *session.Session, u *user.User) error {
|
||||
var records []*databroker.Record
|
||||
if id := s.GetId(); id != "" {
|
||||
records = append(records, &databroker.Record{
|
||||
Type: grpcutil.GetTypeURL(s),
|
||||
Id: id,
|
||||
Data: protoutil.NewAny(s),
|
||||
})
|
||||
}
|
||||
if id := u.GetId(); id != "" {
|
||||
records = append(records, &databroker.Record{
|
||||
Type: grpcutil.GetTypeURL(u),
|
||||
Id: id,
|
||||
Data: protoutil.NewAny(u),
|
||||
})
|
||||
}
|
||||
return c.putRecords(ctx, records)
|
||||
}
|
||||
|
||||
// GetIncomingIDPAccessTokenForPolicy returns the raw idp access token from a request if there is one.
|
||||
func (cfg *Config) GetIncomingIDPAccessTokenForPolicy(policy *Policy, r *http.Request) (rawAccessToken string, ok bool) {
|
||||
bearerTokenFormat := BearerTokenFormatUnknown
|
||||
if cfg.Options != nil && cfg.Options.BearerTokenFormat != nil {
|
||||
bearerTokenFormat = *cfg.Options.BearerTokenFormat
|
||||
}
|
||||
if policy != nil && policy.BearerTokenFormat != nil {
|
||||
bearerTokenFormat = *policy.BearerTokenFormat
|
||||
}
|
||||
|
||||
if token := r.Header.Get(httputil.HeaderPomeriumIDPAccessToken); token != "" {
|
||||
return token, true
|
||||
}
|
||||
|
||||
if auth := r.Header.Get(httputil.HeaderAuthorization); auth != "" {
|
||||
prefix := httputil.AuthorizationTypePomeriumIDPAccessToken + " "
|
||||
if strings.HasPrefix(strings.ToLower(auth), strings.ToLower(prefix)) {
|
||||
return auth[len(prefix):], true
|
||||
}
|
||||
|
||||
prefix = "Bearer " + httputil.AuthorizationTypePomeriumIDPAccessToken + "-"
|
||||
if strings.HasPrefix(strings.ToLower(auth), strings.ToLower(prefix)) {
|
||||
return auth[len(prefix):], true
|
||||
}
|
||||
|
||||
prefix = "Bearer "
|
||||
if strings.HasPrefix(strings.ToLower(auth), strings.ToLower(prefix)) &&
|
||||
bearerTokenFormat == BearerTokenFormatIDPAccessToken {
|
||||
return auth[len(prefix):], true
|
||||
}
|
||||
}
|
||||
|
||||
return "", false
|
||||
}
|
||||
|
||||
// GetIncomingIDPAccessTokenForPolicy returns the raw idp identity token from a request if there is one.
|
||||
func (cfg *Config) GetIncomingIDPIdentityTokenForPolicy(policy *Policy, r *http.Request) (rawIdentityToken string, ok bool) {
|
||||
bearerTokenFormat := BearerTokenFormatDefault
|
||||
if cfg.Options != nil && cfg.Options.BearerTokenFormat != nil {
|
||||
bearerTokenFormat = *cfg.Options.BearerTokenFormat
|
||||
}
|
||||
if policy != nil && policy.BearerTokenFormat != nil {
|
||||
bearerTokenFormat = *policy.BearerTokenFormat
|
||||
}
|
||||
|
||||
if token := r.Header.Get(httputil.HeaderPomeriumIDPIdentityToken); token != "" {
|
||||
return token, true
|
||||
}
|
||||
|
||||
if auth := r.Header.Get(httputil.HeaderAuthorization); auth != "" {
|
||||
prefix := httputil.AuthorizationTypePomeriumIDPIdentityToken + " "
|
||||
if strings.HasPrefix(strings.ToLower(auth), strings.ToLower(prefix)) {
|
||||
return auth[len(prefix):], true
|
||||
}
|
||||
|
||||
prefix = "Bearer " + httputil.AuthorizationTypePomeriumIDPIdentityToken + "-"
|
||||
if strings.HasPrefix(strings.ToLower(auth), strings.ToLower(prefix)) {
|
||||
return auth[len(prefix):], true
|
||||
}
|
||||
|
||||
prefix = "Bearer "
|
||||
if strings.HasPrefix(strings.ToLower(auth), strings.ToLower(prefix)) &&
|
||||
bearerTokenFormat == BearerTokenFormatIDPIdentityToken {
|
||||
return auth[len(prefix):], true
|
||||
}
|
||||
}
|
||||
|
||||
return "", false
|
||||
}
|
||||
|
||||
var accessTokenUUIDNamespace = uuid.MustParse("0194f6f8-e760-76a0-8917-e28ac927a34d")
|
||||
|
||||
func getAccessTokenSessionID(idp *identitypb.Provider, rawAccessToken string) string {
|
||||
namespace := accessTokenUUIDNamespace
|
||||
// make the session ID per-idp settings
|
||||
if idp != nil {
|
||||
namespace = uuid.NewSHA1(namespace, []byte(idp.GetId()))
|
||||
}
|
||||
return uuid.NewSHA1(namespace, []byte(rawAccessToken)).String()
|
||||
}
|
||||
|
||||
var identityTokenUUIDNamespace = uuid.MustParse("0194f6f9-aec0-704e-bb4a-51054f17ad17")
|
||||
|
||||
func getIdentityTokenSessionID(idp *identitypb.Provider, rawIdentityToken string) string {
|
||||
namespace := identityTokenUUIDNamespace
|
||||
// make the session ID per-idp settings
|
||||
if idp != nil {
|
||||
namespace = uuid.NewSHA1(namespace, []byte(idp.GetId()))
|
||||
}
|
||||
return uuid.NewSHA1(namespace, []byte(rawIdentityToken)).String()
|
||||
}
|
||||
|
|
|
@ -1,20 +1,34 @@
|
|||
package config
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/encoding/jws"
|
||||
"github.com/pomerium/pomerium/internal/httputil"
|
||||
"github.com/pomerium/pomerium/internal/jwtutil"
|
||||
"github.com/pomerium/pomerium/internal/sessions"
|
||||
"github.com/pomerium/pomerium/internal/testutil"
|
||||
"github.com/pomerium/pomerium/internal/urlutil"
|
||||
"github.com/pomerium/pomerium/pkg/authenticateapi"
|
||||
"github.com/pomerium/pomerium/pkg/cryptutil"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||
identitypb "github.com/pomerium/pomerium/pkg/grpc/identity"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/session"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/user"
|
||||
"github.com/pomerium/pomerium/pkg/identity"
|
||||
"github.com/pomerium/pomerium/pkg/storage"
|
||||
)
|
||||
|
||||
func TestSessionStore_LoadSessionState(t *testing.T) {
|
||||
|
@ -164,3 +178,357 @@ func TestGetIdentityProviderDetectsChangesToAuthenticateServiceURL(t *testing.T)
|
|||
assert.NotEqual(t, idp1.GetId(), idp2.GetId(),
|
||||
"identity provider should change when authenticate service url changes")
|
||||
}
|
||||
|
||||
func Test_getTokenSessionID(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
assert.Equal(t, "532b0a3d-b413-50a0-8c9f-e6eb340a05d3", getAccessTokenSessionID(nil, "TOKEN"))
|
||||
assert.Equal(t, "e0b8096c-54dd-5623-8098-5488f9c302db", getIdentityTokenSessionID(nil, "TOKEN"))
|
||||
assert.Equal(t, "9c99d1d0-805e-51cb-b808-772ab654268b", getAccessTokenSessionID(&identitypb.Provider{Id: "IDP1"}, "TOKEN"))
|
||||
assert.Equal(t, "0fe0e289-40bb-5ffe-b328-e290e043a652", getIdentityTokenSessionID(&identitypb.Provider{Id: "IDP1"}, "TOKEN"))
|
||||
}
|
||||
|
||||
func TestGetIncomingIDPAccessTokenForPolicy(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
bearerTokenFormatIDPAccessToken := BearerTokenFormatIDPAccessToken
|
||||
|
||||
for _, tc := range []struct {
|
||||
name string
|
||||
globalBearerTokenFormat *BearerTokenFormat
|
||||
routeBearerTokenFormat *BearerTokenFormat
|
||||
headers http.Header
|
||||
expectedOK bool
|
||||
expectedToken string
|
||||
}{
|
||||
{
|
||||
name: "empty headers",
|
||||
expectedOK: false,
|
||||
},
|
||||
{
|
||||
name: "custom header",
|
||||
headers: http.Header{"X-Pomerium-Idp-Access-Token": {"access token via custom header"}},
|
||||
expectedOK: true,
|
||||
expectedToken: "access token via custom header",
|
||||
},
|
||||
{
|
||||
name: "custom authorization",
|
||||
headers: http.Header{"Authorization": {"Pomerium-Idp-Access-Token access token via custom authorization"}},
|
||||
expectedOK: true,
|
||||
expectedToken: "access token via custom authorization",
|
||||
},
|
||||
{
|
||||
name: "custom bearer",
|
||||
headers: http.Header{"Authorization": {"Bearer Pomerium-Idp-Access-Token-access token via custom bearer"}},
|
||||
expectedOK: true,
|
||||
expectedToken: "access token via custom bearer",
|
||||
},
|
||||
{
|
||||
name: "bearer disabled",
|
||||
headers: http.Header{"Authorization": {"Bearer access token via bearer"}},
|
||||
expectedOK: false,
|
||||
},
|
||||
{
|
||||
name: "bearer enabled via options",
|
||||
globalBearerTokenFormat: &bearerTokenFormatIDPAccessToken,
|
||||
headers: http.Header{"Authorization": {"Bearer access token via bearer"}},
|
||||
expectedOK: true,
|
||||
expectedToken: "access token via bearer",
|
||||
},
|
||||
{
|
||||
name: "bearer enabled via route",
|
||||
routeBearerTokenFormat: &bearerTokenFormatIDPAccessToken,
|
||||
headers: http.Header{"Authorization": {"Bearer access token via bearer"}},
|
||||
expectedOK: true,
|
||||
expectedToken: "access token via bearer",
|
||||
},
|
||||
} {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
cfg := &Config{
|
||||
Options: NewDefaultOptions(),
|
||||
}
|
||||
cfg.Options.BearerTokenFormat = tc.globalBearerTokenFormat
|
||||
|
||||
var route *Policy
|
||||
if tc.routeBearerTokenFormat != nil {
|
||||
route = &Policy{
|
||||
BearerTokenFormat: tc.routeBearerTokenFormat,
|
||||
}
|
||||
}
|
||||
|
||||
r, err := http.NewRequest(http.MethodGet, "https://example.com", nil)
|
||||
require.NoError(t, err)
|
||||
if tc.headers != nil {
|
||||
r.Header = tc.headers
|
||||
}
|
||||
|
||||
actualToken, actualOK := cfg.GetIncomingIDPAccessTokenForPolicy(route, r)
|
||||
assert.Equal(t, tc.expectedOK, actualOK)
|
||||
assert.Equal(t, tc.expectedToken, actualToken)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetIncomingIDPIdentityTokenForPolicy(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
bearerTokenFormatIDPIdentityToken := BearerTokenFormatIDPIdentityToken
|
||||
|
||||
for _, tc := range []struct {
|
||||
name string
|
||||
globalBearerTokenFormat *BearerTokenFormat
|
||||
routeBearerTokenFormat *BearerTokenFormat
|
||||
headers http.Header
|
||||
expectedOK bool
|
||||
expectedToken string
|
||||
}{
|
||||
{
|
||||
name: "empty headers",
|
||||
expectedOK: false,
|
||||
},
|
||||
{
|
||||
name: "custom header",
|
||||
headers: http.Header{"X-Pomerium-Idp-Identity-Token": {"identity token via custom header"}},
|
||||
expectedOK: true,
|
||||
expectedToken: "identity token via custom header",
|
||||
},
|
||||
{
|
||||
name: "custom authorization",
|
||||
headers: http.Header{"Authorization": {"Pomerium-Idp-Identity-Token identity token via custom authorization"}},
|
||||
expectedOK: true,
|
||||
expectedToken: "identity token via custom authorization",
|
||||
},
|
||||
{
|
||||
name: "custom bearer",
|
||||
headers: http.Header{"Authorization": {"Bearer Pomerium-Idp-Identity-Token-identity token via custom bearer"}},
|
||||
expectedOK: true,
|
||||
expectedToken: "identity token via custom bearer",
|
||||
},
|
||||
{
|
||||
name: "bearer disabled",
|
||||
headers: http.Header{"Authorization": {"Bearer identity token via bearer"}},
|
||||
expectedOK: false,
|
||||
},
|
||||
{
|
||||
name: "bearer enabled via options",
|
||||
globalBearerTokenFormat: &bearerTokenFormatIDPIdentityToken,
|
||||
headers: http.Header{"Authorization": {"Bearer identity token via bearer"}},
|
||||
expectedOK: true,
|
||||
expectedToken: "identity token via bearer",
|
||||
},
|
||||
{
|
||||
name: "bearer enabled via route",
|
||||
routeBearerTokenFormat: &bearerTokenFormatIDPIdentityToken,
|
||||
headers: http.Header{"Authorization": {"Bearer identity token via bearer"}},
|
||||
expectedOK: true,
|
||||
expectedToken: "identity token via bearer",
|
||||
},
|
||||
} {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
cfg := &Config{
|
||||
Options: NewDefaultOptions(),
|
||||
}
|
||||
cfg.Options.BearerTokenFormat = tc.globalBearerTokenFormat
|
||||
|
||||
var route *Policy
|
||||
if tc.routeBearerTokenFormat != nil {
|
||||
route = &Policy{
|
||||
BearerTokenFormat: tc.routeBearerTokenFormat,
|
||||
}
|
||||
}
|
||||
|
||||
r, err := http.NewRequest(http.MethodGet, "https://example.com", nil)
|
||||
require.NoError(t, err)
|
||||
if tc.headers != nil {
|
||||
r.Header = tc.headers
|
||||
}
|
||||
|
||||
actualToken, actualOK := cfg.GetIncomingIDPIdentityTokenForPolicy(route, r)
|
||||
assert.Equal(t, tc.expectedOK, actualOK)
|
||||
assert.Equal(t, tc.expectedToken, actualToken)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_newSessionFromIDPClaims(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tm1 := time.Date(2025, 2, 18, 8, 6, 0, 0, time.UTC)
|
||||
tm2 := time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC)
|
||||
tm3 := tm2.Add(time.Hour)
|
||||
|
||||
for _, tc := range []struct {
|
||||
name string
|
||||
sessionID string
|
||||
claims jwtutil.Claims
|
||||
expect *session.Session
|
||||
}{
|
||||
{
|
||||
"empty claims", "S1",
|
||||
nil,
|
||||
&session.Session{
|
||||
Id: "S1",
|
||||
AccessedAt: timestamppb.New(tm1),
|
||||
ExpiresAt: timestamppb.New(tm1.Add(time.Hour * 14)),
|
||||
IssuedAt: timestamppb.New(tm1),
|
||||
},
|
||||
},
|
||||
{
|
||||
"full claims", "S2",
|
||||
jwtutil.Claims{
|
||||
"aud": "https://www.example.com",
|
||||
"sub": "U1",
|
||||
"iat": tm2.Unix(),
|
||||
"exp": tm3.Unix(),
|
||||
},
|
||||
&session.Session{
|
||||
Id: "S2",
|
||||
UserId: "U1",
|
||||
AccessedAt: timestamppb.New(tm1),
|
||||
ExpiresAt: timestamppb.New(tm3),
|
||||
IssuedAt: timestamppb.New(tm2),
|
||||
Audience: []string{"https://www.example.com"},
|
||||
Claims: identity.FlattenedClaims{
|
||||
"aud": {"https://www.example.com"},
|
||||
"sub": {"U1"},
|
||||
"iat": {tm2.Unix()},
|
||||
"exp": {tm3.Unix()},
|
||||
}.ToPB(),
|
||||
},
|
||||
},
|
||||
} {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
cfg := &Config{Options: NewDefaultOptions()}
|
||||
c := &incomingIDPTokenSessionCreator{
|
||||
timeNow: func() time.Time { return tm1 },
|
||||
}
|
||||
actual := c.newSessionFromIDPClaims(cfg, tc.sessionID, tc.claims)
|
||||
testutil.AssertProtoEqual(t, tc.expect, actual)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_newUserFromIDPClaims(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
for _, tc := range []struct {
|
||||
name string
|
||||
claims jwtutil.Claims
|
||||
expect *user.User
|
||||
}{
|
||||
{"empty claims", nil, &user.User{}},
|
||||
{"full claims", jwtutil.Claims{
|
||||
"sub": "USER_ID",
|
||||
"name": "NAME",
|
||||
"email": "EMAIL",
|
||||
}, &user.User{
|
||||
Id: "USER_ID",
|
||||
Name: "NAME",
|
||||
Email: "EMAIL",
|
||||
Claims: identity.FlattenedClaims{
|
||||
"sub": {"USER_ID"},
|
||||
"name": {"NAME"},
|
||||
"email": {"EMAIL"},
|
||||
}.ToPB(),
|
||||
}},
|
||||
} {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
actual := new(incomingIDPTokenSessionCreator).newUserFromIDPClaims(tc.claims)
|
||||
testutil.AssertProtoEqual(t, tc.expect, actual)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIncomingIDPTokenSessionCreator_CreateSession(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("access_token", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("/.pomerium/verify-access-token", func(w http.ResponseWriter, _ *http.Request) {
|
||||
json.NewEncoder(w).Encode(&authenticateapi.VerifyTokenResponse{
|
||||
Valid: true,
|
||||
Claims: jwtutil.Claims{"sub": "U1"},
|
||||
})
|
||||
})
|
||||
srv := httptest.NewTLSServer(mux)
|
||||
|
||||
ctx := testutil.GetContext(t, time.Minute)
|
||||
cfg := &Config{Options: NewDefaultOptions()}
|
||||
cfg.Options.AuthenticateURLString = srv.URL
|
||||
cfg.Options.ClientSecret = "CLIENT_SECRET_1"
|
||||
cfg.Options.ClientID = "CLIENT_ID_1"
|
||||
route := &Policy{}
|
||||
route.IDPClientSecret = "CLIENT_SECRET_2"
|
||||
route.IDPClientID = "CLIENT_ID_2"
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, "https://www.example.com", nil)
|
||||
require.NoError(t, err)
|
||||
req.Header.Set(httputil.HeaderPomeriumIDPAccessToken, "ACCESS_TOKEN")
|
||||
c := NewIncomingIDPTokenSessionCreator(
|
||||
func(_ context.Context, recordType, _ string) (*databroker.Record, error) {
|
||||
assert.Equal(t, "type.googleapis.com/session.Session", recordType)
|
||||
return nil, storage.ErrNotFound
|
||||
},
|
||||
func(_ context.Context, records []*databroker.Record) error {
|
||||
if assert.Len(t, records, 2, "should put session and user") {
|
||||
assert.Equal(t, "type.googleapis.com/session.Session", records[0].Type)
|
||||
assert.Equal(t, "type.googleapis.com/user.User", records[1].Type)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
)
|
||||
s, err := c.CreateSession(ctx, cfg, route, req)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "U1", s.GetUserId())
|
||||
assert.Equal(t, "ACCESS_TOKEN", s.GetOauthToken().GetAccessToken())
|
||||
})
|
||||
t.Run("identity_token", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("/.pomerium/verify-identity-token", func(w http.ResponseWriter, _ *http.Request) {
|
||||
json.NewEncoder(w).Encode(&authenticateapi.VerifyTokenResponse{
|
||||
Valid: true,
|
||||
Claims: jwtutil.Claims{"sub": "U1"},
|
||||
})
|
||||
})
|
||||
srv := httptest.NewTLSServer(mux)
|
||||
|
||||
ctx := testutil.GetContext(t, time.Minute)
|
||||
cfg := &Config{Options: NewDefaultOptions()}
|
||||
cfg.Options.AuthenticateURLString = srv.URL
|
||||
cfg.Options.ClientSecret = "CLIENT_SECRET_1"
|
||||
cfg.Options.ClientID = "CLIENT_ID_1"
|
||||
route := &Policy{}
|
||||
route.IDPClientSecret = "CLIENT_SECRET_2"
|
||||
route.IDPClientID = "CLIENT_ID_2"
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, "https://www.example.com", nil)
|
||||
require.NoError(t, err)
|
||||
req.Header.Set(httputil.HeaderPomeriumIDPIdentityToken, "IDENTITY_TOKEN")
|
||||
c := NewIncomingIDPTokenSessionCreator(
|
||||
func(_ context.Context, recordType, _ string) (*databroker.Record, error) {
|
||||
assert.Equal(t, "type.googleapis.com/session.Session", recordType)
|
||||
return nil, storage.ErrNotFound
|
||||
},
|
||||
func(_ context.Context, records []*databroker.Record) error {
|
||||
if assert.Len(t, records, 2, "should put session and user") {
|
||||
assert.Equal(t, "type.googleapis.com/session.Session", records[0].Type)
|
||||
assert.Equal(t, "type.googleapis.com/user.User", records[1].Type)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
)
|
||||
s, err := c.CreateSession(ctx, cfg, route, req)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "U1", s.GetUserId())
|
||||
})
|
||||
}
|
||||
|
|
|
@ -1,7 +1,12 @@
|
|||
package httputil
|
||||
|
||||
// AuthorizationTypePomerium is for Authorization: Pomerium JWT... headers
|
||||
const AuthorizationTypePomerium = "Pomerium"
|
||||
// Pomerium authorization types
|
||||
const (
|
||||
// AuthorizationTypePomerium is for Authorization: Pomerium JWT... headers
|
||||
AuthorizationTypePomerium = "Pomerium"
|
||||
AuthorizationTypePomeriumIDPAccessToken = "Pomerium-IDP-Access-Token" //nolint: gosec
|
||||
AuthorizationTypePomeriumIDPIdentityToken = "Pomerium-IDP-Identity-Token" //nolint: gosec
|
||||
)
|
||||
|
||||
// Standard headers
|
||||
const (
|
||||
|
@ -16,7 +21,9 @@ const (
|
|||
// HeaderPomeriumAuthorization is the header key for a pomerium authorization JWT. It
|
||||
// can be used in place of the standard authorization header if that header is being
|
||||
// used by upstream applications.
|
||||
HeaderPomeriumAuthorization = "x-pomerium-authorization"
|
||||
HeaderPomeriumAuthorization = "x-pomerium-authorization"
|
||||
HeaderPomeriumIDPAccessToken = "x-pomerium-idp-access-token" //nolint: gosec
|
||||
HeaderPomeriumIDPIdentityToken = "x-pomerium-idp-identity-token" //nolint: gosec
|
||||
// HeaderPomeriumResponse is set when pomerium itself creates a response,
|
||||
// as opposed to the upstream application and can be used to distinguish
|
||||
// between an application error, and a pomerium related error when debugging.
|
||||
|
|
172
internal/jwtutil/jwtutil.go
Normal file
172
internal/jwtutil/jwtutil.go
Normal file
|
@ -0,0 +1,172 @@
|
|||
// Package jwtutil contains functions for working with JWTs.
|
||||
package jwtutil
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Claims represent claims in a JWT.
|
||||
type Claims map[string]any
|
||||
|
||||
// UnmarshalJSON implements a custom unmarshaller for claims data.
|
||||
func (claims *Claims) UnmarshalJSON(raw []byte) error {
|
||||
dst := map[string]any{}
|
||||
dec := json.NewDecoder(bytes.NewReader(raw))
|
||||
dec.UseNumber()
|
||||
err := dec.Decode(&dst)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
*claims = Claims(dst)
|
||||
return nil
|
||||
}
|
||||
|
||||
// registered claims
|
||||
|
||||
// GetIssuer gets the iss claim.
|
||||
func (claims Claims) GetIssuer() (issuer string, ok bool) {
|
||||
return claims.GetString("iss")
|
||||
}
|
||||
|
||||
// GetSubject gets the sub claim.
|
||||
func (claims Claims) GetSubject() (subject string, ok bool) {
|
||||
return claims.GetString("sub")
|
||||
}
|
||||
|
||||
// GetAudience gets the aud claim.
|
||||
func (claims Claims) GetAudience() (audiences []string, ok bool) {
|
||||
return claims.GetStringSlice("aud")
|
||||
}
|
||||
|
||||
// GetExpirationTime gets the exp claim.
|
||||
func (claims Claims) GetExpirationTime() (expirationTime time.Time, ok bool) {
|
||||
return claims.GetNumericDate("exp")
|
||||
}
|
||||
|
||||
// GetNotBefore gets the nbf claim.
|
||||
func (claims Claims) GetNotBefore() (notBefore time.Time, ok bool) {
|
||||
return claims.GetNumericDate("nbf")
|
||||
}
|
||||
|
||||
// GetIssuedAt gets the iat claim.
|
||||
func (claims Claims) GetIssuedAt() (issuedAt time.Time, ok bool) {
|
||||
return claims.GetNumericDate("iat")
|
||||
}
|
||||
|
||||
// GetJWTID gets the jti claim.
|
||||
func (claims Claims) GetJWTID() (jwtID string, ok bool) {
|
||||
return claims.GetString("jti")
|
||||
}
|
||||
|
||||
// custom claims
|
||||
|
||||
// GetUserID returns the oid or sub claim.
|
||||
func (claims Claims) GetUserID() (userID string, ok bool) {
|
||||
if oid, ok := claims.GetString("oid"); ok {
|
||||
return oid, true
|
||||
}
|
||||
|
||||
if sub, ok := claims.GetSubject(); ok {
|
||||
return sub, true
|
||||
}
|
||||
|
||||
return "", false
|
||||
}
|
||||
|
||||
// GetNumericDate returns the claim as a numeric date.
|
||||
func (claims Claims) GetNumericDate(name string) (tm time.Time, ok bool) {
|
||||
if claims == nil {
|
||||
return tm, false
|
||||
}
|
||||
|
||||
raw, ok := claims[name]
|
||||
if !ok {
|
||||
return tm, false
|
||||
}
|
||||
|
||||
switch v := raw.(type) {
|
||||
case float32:
|
||||
return time.Unix(int64(v), 0), true
|
||||
case float64:
|
||||
return time.Unix(int64(v), 0), true
|
||||
case int64:
|
||||
return time.Unix(v, 0), true
|
||||
case int32:
|
||||
return time.Unix(int64(v), 0), true
|
||||
case int16:
|
||||
return time.Unix(int64(v), 0), true
|
||||
case int8:
|
||||
return time.Unix(int64(v), 0), true
|
||||
case int:
|
||||
return time.Unix(int64(v), 0), true
|
||||
case uint64:
|
||||
return time.Unix(int64(v), 0), true
|
||||
case uint32:
|
||||
return time.Unix(int64(v), 0), true
|
||||
case uint16:
|
||||
return time.Unix(int64(v), 0), true
|
||||
case uint8:
|
||||
return time.Unix(int64(v), 0), true
|
||||
case uint:
|
||||
return time.Unix(int64(v), 0), true
|
||||
case json.Number:
|
||||
i, err := v.Int64()
|
||||
if err != nil {
|
||||
if f, err := v.Float64(); err == nil {
|
||||
i = int64(f)
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
return tm, false
|
||||
}
|
||||
return time.Unix(i, 0), true
|
||||
}
|
||||
|
||||
return tm, false
|
||||
}
|
||||
|
||||
// GetString returns the claim as a string.
|
||||
func (claims Claims) GetString(name string) (value string, ok bool) {
|
||||
raw, ok := claims[name]
|
||||
if !ok {
|
||||
return value, false
|
||||
}
|
||||
|
||||
return toString(raw), true
|
||||
}
|
||||
|
||||
// GetStringSlice returns the claim as a slice of strings.
|
||||
func (claims Claims) GetStringSlice(name string) (values []string, ok bool) {
|
||||
raw, ok := claims[name]
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
return toStringSlice(raw), true
|
||||
}
|
||||
|
||||
func toString(data any) string {
|
||||
switch v := data.(type) {
|
||||
case string:
|
||||
return v
|
||||
}
|
||||
return fmt.Sprint(data)
|
||||
}
|
||||
|
||||
func toStringSlice(obj any) []string {
|
||||
v := reflect.ValueOf(obj)
|
||||
switch v.Kind() {
|
||||
case reflect.Slice:
|
||||
vs := make([]string, v.Len())
|
||||
for i := 0; i < v.Len(); i++ {
|
||||
vs[i] = toString(v.Index(i).Interface())
|
||||
}
|
||||
return vs
|
||||
}
|
||||
|
||||
return []string{toString(obj)}
|
||||
}
|
109
pkg/authenticateapi/authenticateapi.go
Normal file
109
pkg/authenticateapi/authenticateapi.go
Normal file
|
@ -0,0 +1,109 @@
|
|||
// Package authenticateapi has the types and methods for the authenticate api.
|
||||
package authenticateapi
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/jwtutil"
|
||||
)
|
||||
|
||||
// VerifyAccessTokenRequest is used to verify access tokens.
|
||||
type VerifyAccessTokenRequest struct {
|
||||
AccessToken string `json:"accessToken"`
|
||||
IdentityProviderID string `json:"identityProviderId,omitempty"`
|
||||
}
|
||||
|
||||
// VerifyIdentityTokenRequest is used to verify identity tokens.
|
||||
type VerifyIdentityTokenRequest struct {
|
||||
IdentityToken string `json:"identityToken"`
|
||||
IdentityProviderID string `json:"identityProviderId,omitempty"`
|
||||
}
|
||||
|
||||
// VerifyTokenResponse is the result of verifying an access or identity token.
|
||||
type VerifyTokenResponse struct {
|
||||
Valid bool `json:"valid"`
|
||||
Claims jwtutil.Claims `json:"claims,omitempty"`
|
||||
}
|
||||
|
||||
// An API is an api client for the authenticate service.
|
||||
type API struct {
|
||||
authenticateURL *url.URL
|
||||
transport http.RoundTripper
|
||||
}
|
||||
|
||||
// New creates a new API client.
|
||||
func New(
|
||||
authenticateURL *url.URL,
|
||||
transport http.RoundTripper,
|
||||
) *API {
|
||||
return &API{
|
||||
authenticateURL: authenticateURL,
|
||||
transport: transport,
|
||||
}
|
||||
}
|
||||
|
||||
// VerifyAccessToken verifies an access token.
|
||||
func (api *API) VerifyAccessToken(ctx context.Context, request *VerifyAccessTokenRequest) (*VerifyTokenResponse, error) {
|
||||
var response VerifyTokenResponse
|
||||
err := api.call(ctx, "verify-access-token", request, &response)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &response, nil
|
||||
}
|
||||
|
||||
// VerifyIdentityToken verifies an identity token.
|
||||
func (api *API) VerifyIdentityToken(ctx context.Context, request *VerifyIdentityTokenRequest) (*VerifyTokenResponse, error) {
|
||||
var response VerifyTokenResponse
|
||||
err := api.call(ctx, "verify-identity-token", request, &response)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &response, nil
|
||||
}
|
||||
|
||||
func (api *API) call(
|
||||
ctx context.Context,
|
||||
endpoint string,
|
||||
request, response any,
|
||||
) error {
|
||||
u := api.authenticateURL.ResolveReference(&url.URL{
|
||||
Path: "/.pomerium/" + endpoint,
|
||||
})
|
||||
|
||||
body, err := json.Marshal(request)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error marshaling %s http request: %w", endpoint, err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, u.String(), bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return fmt.Errorf("error creating %s http request: %w", endpoint, err)
|
||||
}
|
||||
|
||||
res, err := (&http.Client{
|
||||
Transport: api.transport,
|
||||
}).Do(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error executing %s http request: %w", endpoint, err)
|
||||
}
|
||||
defer res.Body.Close()
|
||||
|
||||
body, err = io.ReadAll(res.Body)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error reading %s http response: %w", endpoint, err)
|
||||
}
|
||||
|
||||
err = json.Unmarshal(body, &response)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error reading %s http response (body=%s): %w", endpoint, body, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
File diff suppressed because it is too large
Load diff
|
@ -45,8 +45,17 @@ enum IssuerFormat {
|
|||
IssuerURI = 1;
|
||||
}
|
||||
|
||||
// Next ID: 69.
|
||||
enum BearerTokenFormat {
|
||||
BEARER_TOKEN_FORMAT_UNKNOWN = 0;
|
||||
BEARER_TOKEN_FORMAT_DEFAULT = 1;
|
||||
BEARER_TOKEN_FORMAT_IDP_ACCESS_TOKEN = 2;
|
||||
BEARER_TOKEN_FORMAT_IDP_IDENTITY_TOKEN = 3;
|
||||
}
|
||||
|
||||
// Next ID: 71.
|
||||
message Route {
|
||||
message StringList { repeated string values = 1; }
|
||||
|
||||
string name = 1;
|
||||
string description = 67;
|
||||
string logo_url = 68;
|
||||
|
@ -116,6 +125,7 @@ message Route {
|
|||
bool enable_google_cloud_serverless_authentication = 42;
|
||||
IssuerFormat jwt_issuer_format = 65;
|
||||
repeated string jwt_groups_filter = 66;
|
||||
optional BearerTokenFormat bearer_token_format = 70;
|
||||
|
||||
envoy.config.cluster.v3.Cluster envoy_opts = 36;
|
||||
|
||||
|
@ -130,6 +140,7 @@ message Route {
|
|||
|
||||
optional string idp_client_id = 55;
|
||||
optional string idp_client_secret = 56;
|
||||
optional StringList idp_access_token_allowed_audiences = 69;
|
||||
bool show_error_details = 59;
|
||||
}
|
||||
|
||||
|
@ -149,7 +160,7 @@ message Policy {
|
|||
string remediation = 9;
|
||||
}
|
||||
|
||||
// Next ID: 137.
|
||||
// Next ID: 139.
|
||||
message Settings {
|
||||
message Certificate {
|
||||
bytes cert_bytes = 3;
|
||||
|
@ -188,6 +199,7 @@ message Settings {
|
|||
optional string idp_client_secret = 23;
|
||||
optional string idp_provider = 24;
|
||||
optional string idp_provider_url = 25;
|
||||
optional StringList idp_access_token_allowed_audiences = 137;
|
||||
repeated string scopes = 26;
|
||||
// optional string idp_service_account = 27;
|
||||
// optional google.protobuf.Duration idp_refresh_directory_timeout = 28;
|
||||
|
@ -203,6 +215,7 @@ message Settings {
|
|||
// repeated string jwt_claims_headers = 37;
|
||||
map<string, string> jwt_claims_headers = 63;
|
||||
repeated string jwt_groups_filter = 119;
|
||||
optional BearerTokenFormat bearer_token_format = 138;
|
||||
optional google.protobuf.Duration default_upstream_timeout = 39;
|
||||
optional string metrics_address = 40;
|
||||
optional string metrics_basic_auth = 64;
|
||||
|
|
|
@ -33,8 +33,9 @@ type Provider struct {
|
|||
Type string `protobuf:"bytes,4,opt,name=type,proto3" json:"type,omitempty"`
|
||||
Scopes []string `protobuf:"bytes,5,rep,name=scopes,proto3" json:"scopes,omitempty"`
|
||||
// string service_account = 6;
|
||||
Url string `protobuf:"bytes,7,opt,name=url,proto3" json:"url,omitempty"`
|
||||
RequestParams map[string]string `protobuf:"bytes,8,rep,name=request_params,json=requestParams,proto3" json:"request_params,omitempty" protobuf_key:"bytes,1,opt,name=key,proto3" protobuf_val:"bytes,2,opt,name=value,proto3"`
|
||||
Url string `protobuf:"bytes,7,opt,name=url,proto3" json:"url,omitempty"`
|
||||
RequestParams map[string]string `protobuf:"bytes,8,rep,name=request_params,json=requestParams,proto3" json:"request_params,omitempty" protobuf_key:"bytes,1,opt,name=key,proto3" protobuf_val:"bytes,2,opt,name=value,proto3"`
|
||||
AccessTokenAllowedAudiences *Provider_StringList `protobuf:"bytes,10,opt,name=access_token_allowed_audiences,json=accessTokenAllowedAudiences,proto3,oneof" json:"access_token_allowed_audiences,omitempty"`
|
||||
}
|
||||
|
||||
func (x *Provider) Reset() {
|
||||
|
@ -125,6 +126,13 @@ func (x *Provider) GetRequestParams() map[string]string {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (x *Provider) GetAccessTokenAllowedAudiences() *Provider_StringList {
|
||||
if x != nil {
|
||||
return x.AccessTokenAllowedAudiences
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type Profile struct {
|
||||
state protoimpl.MessageState
|
||||
sizeCache protoimpl.SizeCache
|
||||
|
@ -196,6 +204,53 @@ func (x *Profile) GetClaims() *structpb.Struct {
|
|||
return nil
|
||||
}
|
||||
|
||||
type Provider_StringList struct {
|
||||
state protoimpl.MessageState
|
||||
sizeCache protoimpl.SizeCache
|
||||
unknownFields protoimpl.UnknownFields
|
||||
|
||||
Values []string `protobuf:"bytes,1,rep,name=values,proto3" json:"values,omitempty"`
|
||||
}
|
||||
|
||||
func (x *Provider_StringList) Reset() {
|
||||
*x = Provider_StringList{}
|
||||
if protoimpl.UnsafeEnabled {
|
||||
mi := &file_identity_proto_msgTypes[2]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
}
|
||||
|
||||
func (x *Provider_StringList) String() string {
|
||||
return protoimpl.X.MessageStringOf(x)
|
||||
}
|
||||
|
||||
func (*Provider_StringList) ProtoMessage() {}
|
||||
|
||||
func (x *Provider_StringList) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_identity_proto_msgTypes[2]
|
||||
if protoimpl.UnsafeEnabled && x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
return ms
|
||||
}
|
||||
return mi.MessageOf(x)
|
||||
}
|
||||
|
||||
// Deprecated: Use Provider_StringList.ProtoReflect.Descriptor instead.
|
||||
func (*Provider_StringList) Descriptor() ([]byte, []int) {
|
||||
return file_identity_proto_rawDescGZIP(), []int{0, 0}
|
||||
}
|
||||
|
||||
func (x *Provider_StringList) GetValues() []string {
|
||||
if x != nil {
|
||||
return x.Values
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
var File_identity_proto protoreflect.FileDescriptor
|
||||
|
||||
var file_identity_proto_rawDesc = []byte{
|
||||
|
@ -203,7 +258,7 @@ var file_identity_proto_rawDesc = []byte{
|
|||
0x12, 0x11, 0x70, 0x6f, 0x6d, 0x65, 0x72, 0x69, 0x75, 0x6d, 0x2e, 0x69, 0x64, 0x65, 0x6e, 0x74,
|
||||
0x69, 0x74, 0x79, 0x1a, 0x1c, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2f, 0x70, 0x72, 0x6f, 0x74,
|
||||
0x6f, 0x62, 0x75, 0x66, 0x2f, 0x73, 0x74, 0x72, 0x75, 0x63, 0x74, 0x2e, 0x70, 0x72, 0x6f, 0x74,
|
||||
0x6f, 0x22, 0xed, 0x02, 0x0a, 0x08, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x12, 0x0e,
|
||||
0x6f, 0x22, 0xa8, 0x04, 0x0a, 0x08, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x12, 0x0e,
|
||||
0x0a, 0x02, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x69, 0x64, 0x12, 0x38,
|
||||
0x0a, 0x18, 0x61, 0x75, 0x74, 0x68, 0x65, 0x6e, 0x74, 0x69, 0x63, 0x61, 0x74, 0x65, 0x5f, 0x73,
|
||||
0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x5f, 0x75, 0x72, 0x6c, 0x18, 0x09, 0x20, 0x01, 0x28, 0x09,
|
||||
|
@ -221,25 +276,36 @@ var file_identity_proto_rawDesc = []byte{
|
|||
0x32, 0x2e, 0x2e, 0x70, 0x6f, 0x6d, 0x65, 0x72, 0x69, 0x75, 0x6d, 0x2e, 0x69, 0x64, 0x65, 0x6e,
|
||||
0x74, 0x69, 0x74, 0x79, 0x2e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x2e, 0x52, 0x65,
|
||||
0x71, 0x75, 0x65, 0x73, 0x74, 0x50, 0x61, 0x72, 0x61, 0x6d, 0x73, 0x45, 0x6e, 0x74, 0x72, 0x79,
|
||||
0x52, 0x0d, 0x72, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x50, 0x61, 0x72, 0x61, 0x6d, 0x73, 0x1a,
|
||||
0x40, 0x0a, 0x12, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x50, 0x61, 0x72, 0x61, 0x6d, 0x73,
|
||||
0x45, 0x6e, 0x74, 0x72, 0x79, 0x12, 0x10, 0x0a, 0x03, 0x6b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01,
|
||||
0x28, 0x09, 0x52, 0x03, 0x6b, 0x65, 0x79, 0x12, 0x14, 0x0a, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65,
|
||||
0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x3a, 0x02, 0x38,
|
||||
0x01, 0x22, 0x97, 0x01, 0x0a, 0x07, 0x50, 0x72, 0x6f, 0x66, 0x69, 0x6c, 0x65, 0x12, 0x1f, 0x0a,
|
||||
0x0b, 0x70, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x5f, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01,
|
||||
0x28, 0x09, 0x52, 0x0a, 0x70, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x49, 0x64, 0x12, 0x19,
|
||||
0x0a, 0x08, 0x69, 0x64, 0x5f, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0c,
|
||||
0x52, 0x07, 0x69, 0x64, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x12, 0x1f, 0x0a, 0x0b, 0x6f, 0x61, 0x75,
|
||||
0x74, 0x68, 0x5f, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x0a,
|
||||
0x6f, 0x61, 0x75, 0x74, 0x68, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x12, 0x2f, 0x0a, 0x06, 0x63, 0x6c,
|
||||
0x61, 0x69, 0x6d, 0x73, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x17, 0x2e, 0x67, 0x6f, 0x6f,
|
||||
0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x53, 0x74, 0x72,
|
||||
0x75, 0x63, 0x74, 0x52, 0x06, 0x63, 0x6c, 0x61, 0x69, 0x6d, 0x73, 0x42, 0x30, 0x5a, 0x2e, 0x67,
|
||||
0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x70, 0x6f, 0x6d, 0x65, 0x72, 0x69,
|
||||
0x75, 0x6d, 0x2f, 0x70, 0x6f, 0x6d, 0x65, 0x72, 0x69, 0x75, 0x6d, 0x2f, 0x70, 0x6b, 0x67, 0x2f,
|
||||
0x67, 0x72, 0x70, 0x63, 0x2f, 0x69, 0x64, 0x65, 0x6e, 0x74, 0x69, 0x74, 0x79, 0x62, 0x06, 0x70,
|
||||
0x72, 0x6f, 0x74, 0x6f, 0x33,
|
||||
0x52, 0x0d, 0x72, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x50, 0x61, 0x72, 0x61, 0x6d, 0x73, 0x12,
|
||||
0x70, 0x0a, 0x1e, 0x61, 0x63, 0x63, 0x65, 0x73, 0x73, 0x5f, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x5f,
|
||||
0x61, 0x6c, 0x6c, 0x6f, 0x77, 0x65, 0x64, 0x5f, 0x61, 0x75, 0x64, 0x69, 0x65, 0x6e, 0x63, 0x65,
|
||||
0x73, 0x18, 0x0a, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x26, 0x2e, 0x70, 0x6f, 0x6d, 0x65, 0x72, 0x69,
|
||||
0x75, 0x6d, 0x2e, 0x69, 0x64, 0x65, 0x6e, 0x74, 0x69, 0x74, 0x79, 0x2e, 0x50, 0x72, 0x6f, 0x76,
|
||||
0x69, 0x64, 0x65, 0x72, 0x2e, 0x53, 0x74, 0x72, 0x69, 0x6e, 0x67, 0x4c, 0x69, 0x73, 0x74, 0x48,
|
||||
0x00, 0x52, 0x1b, 0x61, 0x63, 0x63, 0x65, 0x73, 0x73, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x41, 0x6c,
|
||||
0x6c, 0x6f, 0x77, 0x65, 0x64, 0x41, 0x75, 0x64, 0x69, 0x65, 0x6e, 0x63, 0x65, 0x73, 0x88, 0x01,
|
||||
0x01, 0x1a, 0x24, 0x0a, 0x0a, 0x53, 0x74, 0x72, 0x69, 0x6e, 0x67, 0x4c, 0x69, 0x73, 0x74, 0x12,
|
||||
0x16, 0x0a, 0x06, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x09, 0x52,
|
||||
0x06, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x73, 0x1a, 0x40, 0x0a, 0x12, 0x52, 0x65, 0x71, 0x75, 0x65,
|
||||
0x73, 0x74, 0x50, 0x61, 0x72, 0x61, 0x6d, 0x73, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x12, 0x10, 0x0a,
|
||||
0x03, 0x6b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x6b, 0x65, 0x79, 0x12,
|
||||
0x14, 0x0a, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05,
|
||||
0x76, 0x61, 0x6c, 0x75, 0x65, 0x3a, 0x02, 0x38, 0x01, 0x42, 0x21, 0x0a, 0x1f, 0x5f, 0x61, 0x63,
|
||||
0x63, 0x65, 0x73, 0x73, 0x5f, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x5f, 0x61, 0x6c, 0x6c, 0x6f, 0x77,
|
||||
0x65, 0x64, 0x5f, 0x61, 0x75, 0x64, 0x69, 0x65, 0x6e, 0x63, 0x65, 0x73, 0x22, 0x97, 0x01, 0x0a,
|
||||
0x07, 0x50, 0x72, 0x6f, 0x66, 0x69, 0x6c, 0x65, 0x12, 0x1f, 0x0a, 0x0b, 0x70, 0x72, 0x6f, 0x76,
|
||||
0x69, 0x64, 0x65, 0x72, 0x5f, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0a, 0x70,
|
||||
0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x49, 0x64, 0x12, 0x19, 0x0a, 0x08, 0x69, 0x64, 0x5f,
|
||||
0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x07, 0x69, 0x64, 0x54,
|
||||
0x6f, 0x6b, 0x65, 0x6e, 0x12, 0x1f, 0x0a, 0x0b, 0x6f, 0x61, 0x75, 0x74, 0x68, 0x5f, 0x74, 0x6f,
|
||||
0x6b, 0x65, 0x6e, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x0a, 0x6f, 0x61, 0x75, 0x74, 0x68,
|
||||
0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x12, 0x2f, 0x0a, 0x06, 0x63, 0x6c, 0x61, 0x69, 0x6d, 0x73, 0x18,
|
||||
0x04, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x17, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70,
|
||||
0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x53, 0x74, 0x72, 0x75, 0x63, 0x74, 0x52, 0x06,
|
||||
0x63, 0x6c, 0x61, 0x69, 0x6d, 0x73, 0x42, 0x30, 0x5a, 0x2e, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62,
|
||||
0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x70, 0x6f, 0x6d, 0x65, 0x72, 0x69, 0x75, 0x6d, 0x2f, 0x70, 0x6f,
|
||||
0x6d, 0x65, 0x72, 0x69, 0x75, 0x6d, 0x2f, 0x70, 0x6b, 0x67, 0x2f, 0x67, 0x72, 0x70, 0x63, 0x2f,
|
||||
0x69, 0x64, 0x65, 0x6e, 0x74, 0x69, 0x74, 0x79, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
|
||||
}
|
||||
|
||||
var (
|
||||
|
@ -254,21 +320,23 @@ func file_identity_proto_rawDescGZIP() []byte {
|
|||
return file_identity_proto_rawDescData
|
||||
}
|
||||
|
||||
var file_identity_proto_msgTypes = make([]protoimpl.MessageInfo, 3)
|
||||
var file_identity_proto_msgTypes = make([]protoimpl.MessageInfo, 4)
|
||||
var file_identity_proto_goTypes = []any{
|
||||
(*Provider)(nil), // 0: pomerium.identity.Provider
|
||||
(*Profile)(nil), // 1: pomerium.identity.Profile
|
||||
nil, // 2: pomerium.identity.Provider.RequestParamsEntry
|
||||
(*structpb.Struct)(nil), // 3: google.protobuf.Struct
|
||||
(*Provider)(nil), // 0: pomerium.identity.Provider
|
||||
(*Profile)(nil), // 1: pomerium.identity.Profile
|
||||
(*Provider_StringList)(nil), // 2: pomerium.identity.Provider.StringList
|
||||
nil, // 3: pomerium.identity.Provider.RequestParamsEntry
|
||||
(*structpb.Struct)(nil), // 4: google.protobuf.Struct
|
||||
}
|
||||
var file_identity_proto_depIdxs = []int32{
|
||||
2, // 0: pomerium.identity.Provider.request_params:type_name -> pomerium.identity.Provider.RequestParamsEntry
|
||||
3, // 1: pomerium.identity.Profile.claims:type_name -> google.protobuf.Struct
|
||||
2, // [2:2] is the sub-list for method output_type
|
||||
2, // [2:2] is the sub-list for method input_type
|
||||
2, // [2:2] is the sub-list for extension type_name
|
||||
2, // [2:2] is the sub-list for extension extendee
|
||||
0, // [0:2] is the sub-list for field type_name
|
||||
3, // 0: pomerium.identity.Provider.request_params:type_name -> pomerium.identity.Provider.RequestParamsEntry
|
||||
2, // 1: pomerium.identity.Provider.access_token_allowed_audiences:type_name -> pomerium.identity.Provider.StringList
|
||||
4, // 2: pomerium.identity.Profile.claims:type_name -> google.protobuf.Struct
|
||||
3, // [3:3] is the sub-list for method output_type
|
||||
3, // [3:3] is the sub-list for method input_type
|
||||
3, // [3:3] is the sub-list for extension type_name
|
||||
3, // [3:3] is the sub-list for extension extendee
|
||||
0, // [0:3] is the sub-list for field type_name
|
||||
}
|
||||
|
||||
func init() { file_identity_proto_init() }
|
||||
|
@ -301,14 +369,27 @@ func file_identity_proto_init() {
|
|||
return nil
|
||||
}
|
||||
}
|
||||
file_identity_proto_msgTypes[2].Exporter = func(v any, i int) any {
|
||||
switch v := v.(*Provider_StringList); i {
|
||||
case 0:
|
||||
return &v.state
|
||||
case 1:
|
||||
return &v.sizeCache
|
||||
case 2:
|
||||
return &v.unknownFields
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
file_identity_proto_msgTypes[0].OneofWrappers = []any{}
|
||||
type x struct{}
|
||||
out := protoimpl.TypeBuilder{
|
||||
File: protoimpl.DescBuilder{
|
||||
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
|
||||
RawDescriptor: file_identity_proto_rawDesc,
|
||||
NumEnums: 0,
|
||||
NumMessages: 3,
|
||||
NumMessages: 4,
|
||||
NumExtensions: 0,
|
||||
NumServices: 0,
|
||||
},
|
||||
|
|
|
@ -6,6 +6,7 @@ option go_package = "github.com/pomerium/pomerium/pkg/grpc/identity";
|
|||
import "google/protobuf/struct.proto";
|
||||
|
||||
message Provider {
|
||||
message StringList { repeated string values = 1; }
|
||||
string id = 1;
|
||||
string authenticate_service_url = 9;
|
||||
string client_id = 2;
|
||||
|
@ -15,6 +16,7 @@ message Provider {
|
|||
// string service_account = 6;
|
||||
string url = 7;
|
||||
map<string, string> request_params = 8;
|
||||
optional StringList access_token_allowed_audiences = 10;
|
||||
}
|
||||
|
||||
message Profile {
|
||||
|
|
9
pkg/identity/errors.go
Normal file
9
pkg/identity/errors.go
Normal file
|
@ -0,0 +1,9 @@
|
|||
package identity
|
||||
|
||||
import "github.com/pomerium/pomerium/pkg/identity/identity"
|
||||
|
||||
// re-exported errors
|
||||
var (
|
||||
ErrVerifyAccessTokenNotSupported = identity.ErrVerifyAccessTokenNotSupported
|
||||
ErrVerifyIdentityTokenNotSupported = identity.ErrVerifyIdentityTokenNotSupported
|
||||
)
|
9
pkg/identity/identity/errors.go
Normal file
9
pkg/identity/identity/errors.go
Normal file
|
@ -0,0 +1,9 @@
|
|||
package identity
|
||||
|
||||
import "errors"
|
||||
|
||||
// well known errors
|
||||
var (
|
||||
ErrVerifyAccessTokenNotSupported = errors.New("identity: access token verification not supported")
|
||||
ErrVerifyIdentityTokenNotSupported = errors.New("identity: identity token verification not supported")
|
||||
)
|
|
@ -2,6 +2,7 @@ package identity
|
|||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"golang.org/x/oauth2"
|
||||
|
@ -55,3 +56,13 @@ func (mp MockProvider) SignOut(_ http.ResponseWriter, _ *http.Request, _, _, _ s
|
|||
func (mp MockProvider) SignIn(_ http.ResponseWriter, _ *http.Request, _ string) error {
|
||||
return mp.SignInError
|
||||
}
|
||||
|
||||
// VerifyAccessToken verifies an access token.
|
||||
func (mp MockProvider) VerifyAccessToken(_ context.Context, _ string) (claims map[string]any, err error) {
|
||||
return nil, fmt.Errorf("VerifyAccessToken not implemented")
|
||||
}
|
||||
|
||||
// VerifyIdentityToken verifies an identity token.
|
||||
func (mp MockProvider) VerifyIdentityToken(_ context.Context, _ string) (claims map[string]any, err error) {
|
||||
return nil, fmt.Errorf("VerifyIdentityToken not implemented")
|
||||
}
|
||||
|
|
|
@ -182,3 +182,13 @@ func (p *Provider) SignIn(w http.ResponseWriter, r *http.Request, state string)
|
|||
func (p *Provider) SignOut(_ http.ResponseWriter, _ *http.Request, _, _, _ string) error {
|
||||
return oidc.ErrSignoutNotImplemented
|
||||
}
|
||||
|
||||
// VerifyAccessToken verifies an access token.
|
||||
func (p *Provider) VerifyAccessToken(_ context.Context, _ string) (claims map[string]any, err error) {
|
||||
return nil, identity.ErrVerifyAccessTokenNotSupported
|
||||
}
|
||||
|
||||
// VerifyIdentityToken verifies an identity token.
|
||||
func (p *Provider) VerifyIdentityToken(_ context.Context, _ string) (claims map[string]any, err error) {
|
||||
return nil, identity.ErrVerifyIdentityTokenNotSupported
|
||||
}
|
||||
|
|
|
@ -256,3 +256,13 @@ func (p *Provider) SignIn(w http.ResponseWriter, r *http.Request, state string)
|
|||
func (p *Provider) SignOut(_ http.ResponseWriter, _ *http.Request, _, _, _ string) error {
|
||||
return oidc.ErrSignoutNotImplemented
|
||||
}
|
||||
|
||||
// VerifyAccessToken verifies an access token.
|
||||
func (p *Provider) VerifyAccessToken(_ context.Context, _ string) (claims map[string]any, err error) {
|
||||
return nil, identity.ErrVerifyAccessTokenNotSupported
|
||||
}
|
||||
|
||||
// VerifyIdentityToken verifies an identity token.
|
||||
func (p *Provider) VerifyIdentityToken(_ context.Context, _ string) (claims map[string]any, err error) {
|
||||
return nil, identity.ErrVerifyIdentityTokenNotSupported
|
||||
}
|
||||
|
|
|
@ -3,7 +3,9 @@
|
|||
// authorization with Bearer JWT.
|
||||
package oauth
|
||||
|
||||
import "net/url"
|
||||
import (
|
||||
"net/url"
|
||||
)
|
||||
|
||||
// Options contains the fields required for an OAuth 2.0 (inc. OIDC) auth flow.
|
||||
//
|
||||
|
@ -29,4 +31,7 @@ type Options struct {
|
|||
// AuthCodeOptions specifies additional key value pairs query params to add
|
||||
// to the request flow signin url.
|
||||
AuthCodeOptions map[string]string
|
||||
|
||||
// When set validates the audience in access tokens.
|
||||
AccessTokenAllowedAudiences *[]string
|
||||
}
|
||||
|
|
|
@ -10,10 +10,14 @@ import (
|
|||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
go_oidc "github.com/coreos/go-oidc/v3/oidc"
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/oauth2"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/jwtutil"
|
||||
"github.com/pomerium/pomerium/pkg/identity/oauth"
|
||||
pom_oidc "github.com/pomerium/pomerium/pkg/identity/oidc"
|
||||
)
|
||||
|
@ -37,11 +41,13 @@ var defaultAuthCodeOptions = map[string]string{"prompt": "select_account"}
|
|||
// Provider is an Azure implementation of the Authenticator interface.
|
||||
type Provider struct {
|
||||
*pom_oidc.Provider
|
||||
accessTokenAllowedAudiences *[]string
|
||||
}
|
||||
|
||||
// New instantiates an OpenID Connect (OIDC) provider for Azure.
|
||||
func New(ctx context.Context, o *oauth.Options) (*Provider, error) {
|
||||
var p Provider
|
||||
p.accessTokenAllowedAudiences = o.AccessTokenAllowedAudiences
|
||||
var err error
|
||||
if o.ProviderURL == "" {
|
||||
o.ProviderURL = defaultProviderURL
|
||||
|
@ -73,6 +79,59 @@ func (p *Provider) Name() string {
|
|||
return Name
|
||||
}
|
||||
|
||||
// VerifyAccessToken verifies a raw access token.
|
||||
func (p *Provider) VerifyAccessToken(ctx context.Context, rawAccessToken string) (claims map[string]any, err error) {
|
||||
pp, err := p.GetProvider()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error getting oidc provider: %w", err)
|
||||
}
|
||||
|
||||
// azure access tokens are JWTs signed with the same keys as identity tokens
|
||||
verifier := pp.Verifier(&go_oidc.Config{
|
||||
SkipClientIDCheck: true,
|
||||
SkipIssuerCheck: true, // checked later
|
||||
})
|
||||
token, err := verifier.Verify(ctx, rawAccessToken)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error verifying access token: %w", err)
|
||||
}
|
||||
|
||||
claims = jwtutil.Claims(map[string]any{})
|
||||
err = token.Claims(&claims)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error unmarshaling access token claims: %w", err)
|
||||
}
|
||||
|
||||
// verify audience
|
||||
if p.accessTokenAllowedAudiences != nil {
|
||||
if audience, ok := claims["aud"].(string); !ok || !slices.Contains(*p.accessTokenAllowedAudiences, audience) {
|
||||
return nil, fmt.Errorf("error verifying access token audience claim, invalid audience")
|
||||
}
|
||||
}
|
||||
|
||||
err = verifyIssuer(pp, claims)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error verifying access token issuer claim: %w", err)
|
||||
}
|
||||
|
||||
if scope, ok := claims["scp"].(string); ok && slices.Contains(strings.Fields(scope), "openid") {
|
||||
userInfo, err := pp.UserInfo(ctx, oauth2.StaticTokenSource(&oauth2.Token{
|
||||
TokenType: "Bearer",
|
||||
AccessToken: rawAccessToken,
|
||||
}))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error calling user info endpoint: %w", err)
|
||||
}
|
||||
|
||||
err = userInfo.Claims(claims)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error unmarshaling user info claims: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return claims, nil
|
||||
}
|
||||
|
||||
// newProvider overrides the default round tripper for well-known endpoint call that happens
|
||||
// on new provider registration.
|
||||
// By default, the "common" (both public and private domains) responds with
|
||||
|
@ -128,3 +187,55 @@ func (transport *wellKnownConfiguration) RoundTrip(req *http.Request) (*http.Res
|
|||
res.Body = io.NopCloser(bytes.NewReader(bs))
|
||||
return res, nil
|
||||
}
|
||||
|
||||
const (
|
||||
v1IssuerPrefix = "https://sts.windows.net/"
|
||||
v1IssuerSuffix = "/"
|
||||
v2IssuerPrefix = "https://login.microsoftonline.com/"
|
||||
v2IssuerSuffix = "/v2.0"
|
||||
)
|
||||
|
||||
func verifyIssuer(pp *go_oidc.Provider, claims map[string]any) error {
|
||||
tenantID, ok := getTenantIDFromURL(pp.Endpoint().TokenURL)
|
||||
if !ok {
|
||||
return fmt.Errorf("failed to find tenant id")
|
||||
}
|
||||
|
||||
iss, ok := claims["iss"].(string)
|
||||
if !ok {
|
||||
return fmt.Errorf("missing issuer claim")
|
||||
}
|
||||
|
||||
if !(iss == v1IssuerPrefix+tenantID+v1IssuerSuffix || iss == v2IssuerPrefix+tenantID+v2IssuerSuffix) {
|
||||
return fmt.Errorf("invalid issuer: %s", iss)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func getTenantIDFromURL(rawTokenURL string) (string, bool) {
|
||||
// URLs look like:
|
||||
// - https://login.microsoftonline.com/f42bce3b-671c-4162-b24c-00ecc7641897/v2.0
|
||||
// Or:
|
||||
// - https://sts.windows.net/f42bce3b-671c-4162-b24c-00ecc7641897/
|
||||
for _, prefix := range []string{v1IssuerPrefix, v2IssuerPrefix} {
|
||||
path, ok := strings.CutPrefix(rawTokenURL, prefix)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
idx := strings.Index(path, "/")
|
||||
if idx <= 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
rawTenantID := path[:idx]
|
||||
if _, err := uuid.Parse(rawTenantID); err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
return rawTenantID, true
|
||||
}
|
||||
|
||||
return "", false
|
||||
}
|
||||
|
|
|
@ -2,15 +2,27 @@ package azure
|
|||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/go-jose/go-jose/v3"
|
||||
"github.com/go-jose/go-jose/v3/jwt"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/testutil"
|
||||
"github.com/pomerium/pomerium/pkg/identity/identity"
|
||||
"github.com/pomerium/pomerium/pkg/identity/oauth"
|
||||
)
|
||||
|
||||
func TestAuthCodeOptions(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var options oauth.Options
|
||||
p, err := New(context.Background(), &options)
|
||||
require.NoError(t, err)
|
||||
|
@ -21,3 +33,101 @@ func TestAuthCodeOptions(t *testing.T) {
|
|||
require.NoError(t, err)
|
||||
assert.Equal(t, map[string]string{}, p.AuthCodeOptions)
|
||||
}
|
||||
|
||||
func TestVerifyAccessToken(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.GetContext(t, time.Minute)
|
||||
|
||||
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
require.NoError(t, err)
|
||||
|
||||
jwtSigner, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.RS256, Key: privateKey}, nil)
|
||||
require.NoError(t, err)
|
||||
iat := time.Now().Unix()
|
||||
exp := iat + 3600
|
||||
rawAccessToken1, err := jwt.Signed(jwtSigner).Claims(map[string]any{
|
||||
"iss": "https://sts.windows.net/323b4000-7ad7-4ed3-9f4e-adee06ee8bbe/",
|
||||
"aud": "https://client.example.com",
|
||||
"sub": "subject",
|
||||
"exp": exp,
|
||||
"iat": iat,
|
||||
}).CompactSerialize()
|
||||
require.NoError(t, err)
|
||||
rawAccessToken2, err := jwt.Signed(jwtSigner).Claims(map[string]any{
|
||||
"iss": "https://sts.windows.net/323b4000-7ad7-4ed3-9f4e-adee06ee8bbe/",
|
||||
"aud": "https://unexpected.example.com",
|
||||
"sub": "subject",
|
||||
"exp": exp,
|
||||
"iat": iat,
|
||||
}).CompactSerialize()
|
||||
require.NoError(t, err)
|
||||
|
||||
var srvURL string
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("GET /.well-known/openid-configuration", func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json; charset=utf-8")
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
"issuer": srvURL,
|
||||
"authorization_endpoint": srvURL + "/auth",
|
||||
"token_endpoint": "https://sts.windows.net/323b4000-7ad7-4ed3-9f4e-adee06ee8bbe/token",
|
||||
"jwks_uri": srvURL + "/keys",
|
||||
"id_token_signing_alg_values_supported": []any{"RS256"},
|
||||
})
|
||||
})
|
||||
mux.HandleFunc("GET /keys", func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json; charset=utf-8")
|
||||
|
||||
json.NewEncoder(w).Encode(jose.JSONWebKeySet{
|
||||
Keys: []jose.JSONWebKey{
|
||||
{Key: privateKey.Public(), Use: "sig", Algorithm: "RS256"},
|
||||
},
|
||||
})
|
||||
})
|
||||
srv := httptest.NewServer(mux)
|
||||
srvURL = srv.URL
|
||||
|
||||
audiences := []string{"https://other.example.com", "https://client.example.com"}
|
||||
p, err := New(ctx, &oauth.Options{
|
||||
ProviderName: Name,
|
||||
ProviderURL: srv.URL,
|
||||
ClientID: "CLIENT_ID",
|
||||
ClientSecret: "CLIENT_SECRET",
|
||||
AccessTokenAllowedAudiences: &audiences,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
claims, err := p.VerifyAccessToken(ctx, rawAccessToken1)
|
||||
require.NoError(t, err)
|
||||
delete(claims, "iat")
|
||||
delete(claims, "exp")
|
||||
assert.Equal(t, map[string]any{
|
||||
"iss": "https://sts.windows.net/323b4000-7ad7-4ed3-9f4e-adee06ee8bbe/",
|
||||
"aud": "https://client.example.com",
|
||||
"sub": "subject",
|
||||
}, claims)
|
||||
|
||||
_, err = p.VerifyAccessToken(ctx, rawAccessToken2)
|
||||
assert.ErrorContains(t, err, "invalid audience")
|
||||
}
|
||||
|
||||
func TestVerifyIdentityToken(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.GetContext(t, time.Minute)
|
||||
|
||||
mux := http.NewServeMux()
|
||||
srv := httptest.NewServer(mux)
|
||||
|
||||
p, err := New(ctx, &oauth.Options{
|
||||
ProviderName: Name,
|
||||
ProviderURL: srv.URL,
|
||||
ClientID: "CLIENT_ID",
|
||||
ClientSecret: "CLIENT_SECRET",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
claims, err := p.VerifyIdentityToken(ctx, "RAW IDENTITY TOKEN")
|
||||
assert.ErrorIs(t, identity.ErrVerifyIdentityTokenNotSupported, err)
|
||||
assert.Nil(t, claims)
|
||||
}
|
||||
|
|
|
@ -360,3 +360,13 @@ func (p *Provider) SignOut(w http.ResponseWriter, r *http.Request, idTokenHint,
|
|||
httputil.Redirect(w, r, endSessionURL.String(), http.StatusFound)
|
||||
return nil
|
||||
}
|
||||
|
||||
// VerifyAccessToken verifies an access token.
|
||||
func (p *Provider) VerifyAccessToken(_ context.Context, _ string) (claims map[string]any, err error) {
|
||||
return nil, identity.ErrVerifyAccessTokenNotSupported
|
||||
}
|
||||
|
||||
// VerifyIdentityToken verifies an identity token.
|
||||
func (p *Provider) VerifyIdentityToken(_ context.Context, _ string) (claims map[string]any, err error) {
|
||||
return nil, identity.ErrVerifyIdentityTokenNotSupported
|
||||
}
|
||||
|
|
|
@ -8,6 +8,7 @@ import (
|
|||
"net/http"
|
||||
|
||||
"go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp"
|
||||
oteltrace "go.opentelemetry.io/otel/trace"
|
||||
"golang.org/x/oauth2"
|
||||
|
||||
"github.com/pomerium/pomerium/pkg/identity/identity"
|
||||
|
@ -23,7 +24,6 @@ import (
|
|||
"github.com/pomerium/pomerium/pkg/identity/oidc/okta"
|
||||
"github.com/pomerium/pomerium/pkg/identity/oidc/onelogin"
|
||||
"github.com/pomerium/pomerium/pkg/identity/oidc/ping"
|
||||
oteltrace "go.opentelemetry.io/otel/trace"
|
||||
)
|
||||
|
||||
// State is the identity state.
|
||||
|
@ -36,6 +36,8 @@ type Authenticator interface {
|
|||
Revoke(context.Context, *oauth2.Token) error
|
||||
Name() string
|
||||
UpdateUserInfo(ctx context.Context, t *oauth2.Token, v any) error
|
||||
VerifyAccessToken(ctx context.Context, rawAccessToken string) (claims map[string]any, err error)
|
||||
VerifyIdentityToken(ctx context.Context, rawIdentityToken string) (claims map[string]any, err error)
|
||||
|
||||
SignIn(w http.ResponseWriter, r *http.Request, state string) error
|
||||
SignOut(w http.ResponseWriter, r *http.Request, idTokenHint, authenticateSignedOutURL, redirectToURL string) error
|
||||
|
|
Loading…
Add table
Reference in a new issue