mirror of
https://github.com/pomerium/pomerium.git
synced 2025-04-28 18:06:34 +02:00
wip
This commit is contained in:
parent
a8650b1749
commit
2c0bd9e434
8 changed files with 134 additions and 20 deletions
|
@ -44,6 +44,12 @@ func (a *Authenticate) Handler() http.Handler {
|
|||
func (a *Authenticate) Mount(r *mux.Router) {
|
||||
r.StrictSlash(true)
|
||||
r.Use(middleware.SetHeaders(httputil.HeadersContentSecurityPolicy))
|
||||
r.Use(func(h http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
r = csrf.UnsafeSkipCheck(r)
|
||||
h.ServeHTTP(w, r)
|
||||
})
|
||||
})
|
||||
r.Use(func(h http.Handler) http.Handler {
|
||||
options := a.options.Load()
|
||||
state := a.state.Load()
|
||||
|
|
|
@ -13,7 +13,6 @@ import (
|
|||
)
|
||||
|
||||
type sessionOrServiceAccount interface {
|
||||
GetId() string
|
||||
GetUserId() string
|
||||
Validate() error
|
||||
}
|
||||
|
|
|
@ -45,10 +45,12 @@ func (a *Authorize) Check(ctx context.Context, in *envoy_service_auth_v3.CheckRe
|
|||
requestID := requestid.FromHTTPHeader(hreq.Header)
|
||||
ctx = requestid.WithValue(ctx, requestID)
|
||||
|
||||
var sessionID string
|
||||
var s sessionOrServiceAccount
|
||||
var u *user.User
|
||||
if sess, err := a.state.Load().idpTokensLoader.LoadSession(hreq); err == nil {
|
||||
s = sess
|
||||
sessionID = sess.GetId()
|
||||
} else if !errors.Is(err, sessions.ErrNoSessionFound) {
|
||||
log.Ctx(ctx).Info().Err(err).Str("request-id", requestID).Msg("error verifying idp tokens")
|
||||
} else {
|
||||
|
@ -60,6 +62,8 @@ func (a *Authorize) Check(ctx context.Context, in *envoy_service_auth_v3.CheckRe
|
|||
return nil, err
|
||||
} else if err != nil {
|
||||
log.Ctx(ctx).Info().Err(err).Str("request-id", requestID).Msg("clearing session due to missing or invalid session or service account")
|
||||
} else {
|
||||
sessionID = sessionState.ID
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -67,7 +71,7 @@ func (a *Authorize) Check(ctx context.Context, in *envoy_service_auth_v3.CheckRe
|
|||
u, _ = a.getDataBrokerUser(ctx, s.GetUserId()) // ignore any missing user error
|
||||
}
|
||||
|
||||
req, err := a.getEvaluatorRequestFromCheckRequest(ctx, in, s.GetId())
|
||||
req, err := a.getEvaluatorRequestFromCheckRequest(ctx, in, sessionID)
|
||||
if err != nil {
|
||||
log.Ctx(ctx).Error().Err(err).Str("request-id", requestID).Msg("error building evaluator request")
|
||||
return nil, err
|
||||
|
|
|
@ -90,7 +90,7 @@ func newAuthorizeStateFromConfig(
|
|||
return nil, err
|
||||
}
|
||||
|
||||
state.idpTokensLoader = idptokens.NewLoader(cfg.Options, state.dataBrokerClient)
|
||||
state.idpTokensLoader = idptokens.NewLoader(cfg, state.dataBrokerClient)
|
||||
|
||||
return state, nil
|
||||
}
|
||||
|
|
|
@ -215,7 +215,7 @@ func (cfg *Config) GetCertificatePool() (*x509.CertPool, error) {
|
|||
|
||||
// GetAuthenticateKeyFetcher returns a key fetcher for the authenticate service
|
||||
func (cfg *Config) GetAuthenticateKeyFetcher() (hpke.KeyFetcher, error) {
|
||||
authenticateURL, transport, err := cfg.resolveAuthenticateURL()
|
||||
authenticateURL, transport, err := cfg.ResolveAuthenticateURL()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -225,7 +225,9 @@ func (cfg *Config) GetAuthenticateKeyFetcher() (hpke.KeyFetcher, error) {
|
|||
return hpke.NewKeyFetcher(hpkeURL, transport), nil
|
||||
}
|
||||
|
||||
func (cfg *Config) resolveAuthenticateURL() (*url.URL, *http.Transport, error) {
|
||||
// ResolveAuthenticateURL resolves the authenticate service URL and returns a transport suitable
|
||||
// for accessing the authenticate service.
|
||||
func (cfg *Config) ResolveAuthenticateURL() (*url.URL, *http.Transport, error) {
|
||||
authenticateURL, err := cfg.Options.GetInternalAuthenticateURL()
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("invalid authenticate service url: %w", err)
|
||||
|
|
|
@ -9,7 +9,7 @@ import (
|
|||
"net/http"
|
||||
"net/url"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/urlutil"
|
||||
"github.com/pomerium/pomerium/config"
|
||||
)
|
||||
|
||||
// endpoints
|
||||
|
@ -39,11 +39,11 @@ type VerifyIdentityTokenRequest struct {
|
|||
|
||||
func apiVerifyAccessToken(
|
||||
ctx context.Context,
|
||||
authenticateServiceURL string,
|
||||
cfg *config.Config,
|
||||
request *VerifyAccessTokenRequest,
|
||||
) (*VerifyTokenResponse, error) {
|
||||
var response VerifyTokenResponse
|
||||
err := api(ctx, authenticateServiceURL, "verify-access-token", request, &response)
|
||||
err := api(ctx, cfg, "verify-access-token", request, &response)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -52,11 +52,11 @@ func apiVerifyAccessToken(
|
|||
|
||||
func apiVerifyIdentityToken(
|
||||
ctx context.Context,
|
||||
authenticateServiceURL string,
|
||||
cfg *config.Config,
|
||||
request *VerifyIdentityTokenRequest,
|
||||
) (*VerifyTokenResponse, error) {
|
||||
var response VerifyTokenResponse
|
||||
err := api(ctx, authenticateServiceURL, "verify-identity-token", request, &response)
|
||||
err := api(ctx, cfg, "verify-identity-token", request, &response)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -65,15 +65,15 @@ func apiVerifyIdentityToken(
|
|||
|
||||
func api(
|
||||
ctx context.Context,
|
||||
authenticateServiceURL string,
|
||||
cfg *config.Config,
|
||||
endpoint string,
|
||||
request, response any,
|
||||
) error {
|
||||
u, err := urlutil.ParseAndValidateURL(authenticateServiceURL)
|
||||
authenticateURL, transport, err := cfg.ResolveAuthenticateURL()
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid authenticate service url: %w", err)
|
||||
}
|
||||
u = u.ResolveReference(&url.URL{
|
||||
u := authenticateURL.ResolveReference(&url.URL{
|
||||
Path: "/.pomerium/" + endpoint,
|
||||
})
|
||||
|
||||
|
@ -87,7 +87,9 @@ func api(
|
|||
return fmt.Errorf("error creating %s http request: %w", endpoint, err)
|
||||
}
|
||||
|
||||
res, err := http.DefaultClient.Do(req)
|
||||
res, err := (&http.Client{
|
||||
Transport: transport,
|
||||
}).Do(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error executing %s http request: %w", endpoint, err)
|
||||
}
|
||||
|
|
|
@ -26,14 +26,14 @@ var (
|
|||
|
||||
// A Loader loads sessions from IdP access and identity tokens.
|
||||
type Loader struct {
|
||||
options *config.Options
|
||||
cfg *config.Config
|
||||
dataBrokerServiceClient databroker.DataBrokerServiceClient
|
||||
}
|
||||
|
||||
// NewLoader creates a new Loader.
|
||||
func NewLoader(options *config.Options, dataBrokerServiceClient databroker.DataBrokerServiceClient) *Loader {
|
||||
func NewLoader(cfg *config.Config, dataBrokerServiceClient databroker.DataBrokerServiceClient) *Loader {
|
||||
return &Loader{
|
||||
options: options,
|
||||
cfg: cfg,
|
||||
dataBrokerServiceClient: dataBrokerServiceClient,
|
||||
}
|
||||
}
|
||||
|
@ -42,7 +42,7 @@ func NewLoader(options *config.Options, dataBrokerServiceClient databroker.DataB
|
|||
func (l *Loader) LoadSession(r *http.Request) (*session.Session, error) {
|
||||
ctx := r.Context()
|
||||
|
||||
idp, err := l.options.GetIdentityProviderForRequestURL(urlutil.GetAbsoluteURL(r).String())
|
||||
idp, err := l.cfg.Options.GetIdentityProviderForRequestURL(urlutil.GetAbsoluteURL(r).String())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -87,7 +87,7 @@ func (l *Loader) loadSessionFromAccessToken(ctx context.Context, idp *identity.P
|
|||
return nil, err
|
||||
}
|
||||
|
||||
res, err := apiVerifyAccessToken(ctx, idp.GetAuthenticateServiceUrl(), &VerifyAccessTokenRequest{
|
||||
res, err := apiVerifyAccessToken(ctx, l.cfg, &VerifyAccessTokenRequest{
|
||||
AccessToken: rawAccessToken,
|
||||
IdentityProviderID: idp.GetId(),
|
||||
})
|
||||
|
@ -121,7 +121,7 @@ func (l *Loader) loadSessionFromIdentityToken(ctx context.Context, idp *identity
|
|||
return nil, err
|
||||
}
|
||||
|
||||
res, err := apiVerifyIdentityToken(ctx, idp.GetAuthenticateServiceUrl(), &VerifyIdentityTokenRequest{
|
||||
res, err := apiVerifyIdentityToken(ctx, l.cfg, &VerifyIdentityTokenRequest{
|
||||
IdentityToken: rawIdentityToken,
|
||||
IdentityProviderID: idp.GetId(),
|
||||
})
|
||||
|
|
|
@ -10,8 +10,11 @@ import (
|
|||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
go_oidc "github.com/coreos/go-oidc/v3/oidc"
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/oauth2"
|
||||
|
||||
"github.com/pomerium/pomerium/pkg/identity/oauth"
|
||||
|
@ -73,6 +76,52 @@ func (p *Provider) Name() string {
|
|||
return Name
|
||||
}
|
||||
|
||||
// VerifyAccessToken verifies a raw access token using the oidc UserInfo endpoint.
|
||||
func (p *Provider) VerifyAccessToken(ctx context.Context, rawAccessToken string) (claims map[string]any, err error) {
|
||||
pp, err := p.GetProvider()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error getting oidc provider: %w", err)
|
||||
}
|
||||
|
||||
verifier := pp.Verifier(&go_oidc.Config{
|
||||
SkipClientIDCheck: true,
|
||||
SkipIssuerCheck: true, // checked later
|
||||
})
|
||||
|
||||
token, err := verifier.Verify(ctx, rawAccessToken)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error verifying access token: %w", err)
|
||||
}
|
||||
|
||||
claims = map[string]any{}
|
||||
err = token.Claims(&claims)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error unmarshaling access token claims: %w", err)
|
||||
}
|
||||
|
||||
err = verifyIssuer(pp, claims)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error verifying access token issuer claim: %w", err)
|
||||
}
|
||||
|
||||
if scope, ok := claims["scp"].(string); ok && slices.Contains(strings.Fields(scope), "openid") {
|
||||
userInfo, err := pp.UserInfo(ctx, oauth2.StaticTokenSource(&oauth2.Token{
|
||||
TokenType: "Bearer",
|
||||
AccessToken: rawAccessToken,
|
||||
}))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error calling user info endpoint: %w", err)
|
||||
}
|
||||
|
||||
err = userInfo.Claims(claims)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error unmarshaling user info claims: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return claims, nil
|
||||
}
|
||||
|
||||
// newProvider overrides the default round tripper for well-known endpoint call that happens
|
||||
// on new provider registration.
|
||||
// By default, the "common" (both public and private domains) responds with
|
||||
|
@ -128,3 +177,55 @@ func (transport *wellKnownConfiguration) RoundTrip(req *http.Request) (*http.Res
|
|||
res.Body = io.NopCloser(bytes.NewReader(bs))
|
||||
return res, nil
|
||||
}
|
||||
|
||||
const (
|
||||
v1IssuerPrefix = "https://sts.windows.net/"
|
||||
v1IssuerSuffix = "/"
|
||||
v2IssuerPrefix = "https://login.microsoftonline.com/"
|
||||
v2IssuerSuffix = "/v2.0"
|
||||
)
|
||||
|
||||
func verifyIssuer(pp *go_oidc.Provider, claims map[string]any) error {
|
||||
tenantID, ok := getTenantIDFromURL(pp.Endpoint().TokenURL)
|
||||
if !ok {
|
||||
return fmt.Errorf("failed to find tenant id")
|
||||
}
|
||||
|
||||
iss, ok := claims["iss"].(string)
|
||||
if !ok {
|
||||
return fmt.Errorf("missing issuer claim")
|
||||
}
|
||||
|
||||
if !(iss == v1IssuerPrefix+tenantID+v1IssuerSuffix || iss == v2IssuerPrefix+tenantID+v2IssuerSuffix) {
|
||||
return fmt.Errorf("invalid issuer: %s", iss)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func getTenantIDFromURL(rawTokenURL string) (string, bool) {
|
||||
// URLs look like:
|
||||
// - https://login.microsoftonline.com/f42bce3b-671c-4162-b24c-00ecc7641897/v2.0
|
||||
// Or:
|
||||
// - https://sts.windows.net/f42bce3b-671c-4162-b24c-00ecc7641897/
|
||||
for _, prefix := range []string{v1IssuerPrefix, v2IssuerPrefix} {
|
||||
path, ok := strings.CutPrefix(rawTokenURL, prefix)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
idx := strings.Index(path, "/")
|
||||
if idx <= 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
rawTenantID := path[:idx]
|
||||
if _, err := uuid.Parse(rawTenantID); err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
return rawTenantID, true
|
||||
}
|
||||
|
||||
return "", false
|
||||
}
|
||||
|
|
Loading…
Add table
Reference in a new issue