mirror of
https://github.com/pomerium/pomerium.git
synced 2025-08-02 08:19:23 +02:00
authorize: support authenticating with idp tokens (#5484)
* identity: add support for verifying access and identity tokens * allow overriding with policy option * authenticate: add verify endpoints * wip * implement session creation * add verify test * implement idp token login * fix tests * add pr permission * make session ids route-specific * rename method * add test * add access token test * test for newUserFromIDPClaims * more tests * make the session id per-idp * use type for * add test * remove nil checks
This commit is contained in:
parent
6e22b7a19a
commit
b9fd926618
36 changed files with 2791 additions and 885 deletions
9
pkg/identity/errors.go
Normal file
9
pkg/identity/errors.go
Normal file
|
@ -0,0 +1,9 @@
|
|||
package identity
|
||||
|
||||
import "github.com/pomerium/pomerium/pkg/identity/identity"
|
||||
|
||||
// re-exported errors
|
||||
var (
|
||||
ErrVerifyAccessTokenNotSupported = identity.ErrVerifyAccessTokenNotSupported
|
||||
ErrVerifyIdentityTokenNotSupported = identity.ErrVerifyIdentityTokenNotSupported
|
||||
)
|
9
pkg/identity/identity/errors.go
Normal file
9
pkg/identity/identity/errors.go
Normal file
|
@ -0,0 +1,9 @@
|
|||
package identity
|
||||
|
||||
import "errors"
|
||||
|
||||
// well known errors
|
||||
var (
|
||||
ErrVerifyAccessTokenNotSupported = errors.New("identity: access token verification not supported")
|
||||
ErrVerifyIdentityTokenNotSupported = errors.New("identity: identity token verification not supported")
|
||||
)
|
|
@ -2,6 +2,7 @@ package identity
|
|||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"golang.org/x/oauth2"
|
||||
|
@ -55,3 +56,13 @@ func (mp MockProvider) SignOut(_ http.ResponseWriter, _ *http.Request, _, _, _ s
|
|||
func (mp MockProvider) SignIn(_ http.ResponseWriter, _ *http.Request, _ string) error {
|
||||
return mp.SignInError
|
||||
}
|
||||
|
||||
// VerifyAccessToken verifies an access token.
|
||||
func (mp MockProvider) VerifyAccessToken(_ context.Context, _ string) (claims map[string]any, err error) {
|
||||
return nil, fmt.Errorf("VerifyAccessToken not implemented")
|
||||
}
|
||||
|
||||
// VerifyIdentityToken verifies an identity token.
|
||||
func (mp MockProvider) VerifyIdentityToken(_ context.Context, _ string) (claims map[string]any, err error) {
|
||||
return nil, fmt.Errorf("VerifyIdentityToken not implemented")
|
||||
}
|
||||
|
|
|
@ -182,3 +182,13 @@ func (p *Provider) SignIn(w http.ResponseWriter, r *http.Request, state string)
|
|||
func (p *Provider) SignOut(_ http.ResponseWriter, _ *http.Request, _, _, _ string) error {
|
||||
return oidc.ErrSignoutNotImplemented
|
||||
}
|
||||
|
||||
// VerifyAccessToken verifies an access token.
|
||||
func (p *Provider) VerifyAccessToken(_ context.Context, _ string) (claims map[string]any, err error) {
|
||||
return nil, identity.ErrVerifyAccessTokenNotSupported
|
||||
}
|
||||
|
||||
// VerifyIdentityToken verifies an identity token.
|
||||
func (p *Provider) VerifyIdentityToken(_ context.Context, _ string) (claims map[string]any, err error) {
|
||||
return nil, identity.ErrVerifyIdentityTokenNotSupported
|
||||
}
|
||||
|
|
|
@ -256,3 +256,13 @@ func (p *Provider) SignIn(w http.ResponseWriter, r *http.Request, state string)
|
|||
func (p *Provider) SignOut(_ http.ResponseWriter, _ *http.Request, _, _, _ string) error {
|
||||
return oidc.ErrSignoutNotImplemented
|
||||
}
|
||||
|
||||
// VerifyAccessToken verifies an access token.
|
||||
func (p *Provider) VerifyAccessToken(_ context.Context, _ string) (claims map[string]any, err error) {
|
||||
return nil, identity.ErrVerifyAccessTokenNotSupported
|
||||
}
|
||||
|
||||
// VerifyIdentityToken verifies an identity token.
|
||||
func (p *Provider) VerifyIdentityToken(_ context.Context, _ string) (claims map[string]any, err error) {
|
||||
return nil, identity.ErrVerifyIdentityTokenNotSupported
|
||||
}
|
||||
|
|
|
@ -3,7 +3,9 @@
|
|||
// authorization with Bearer JWT.
|
||||
package oauth
|
||||
|
||||
import "net/url"
|
||||
import (
|
||||
"net/url"
|
||||
)
|
||||
|
||||
// Options contains the fields required for an OAuth 2.0 (inc. OIDC) auth flow.
|
||||
//
|
||||
|
@ -29,4 +31,7 @@ type Options struct {
|
|||
// AuthCodeOptions specifies additional key value pairs query params to add
|
||||
// to the request flow signin url.
|
||||
AuthCodeOptions map[string]string
|
||||
|
||||
// When set validates the audience in access tokens.
|
||||
AccessTokenAllowedAudiences *[]string
|
||||
}
|
||||
|
|
|
@ -10,10 +10,14 @@ import (
|
|||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
go_oidc "github.com/coreos/go-oidc/v3/oidc"
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/oauth2"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/jwtutil"
|
||||
"github.com/pomerium/pomerium/pkg/identity/oauth"
|
||||
pom_oidc "github.com/pomerium/pomerium/pkg/identity/oidc"
|
||||
)
|
||||
|
@ -37,11 +41,13 @@ var defaultAuthCodeOptions = map[string]string{"prompt": "select_account"}
|
|||
// Provider is an Azure implementation of the Authenticator interface.
|
||||
type Provider struct {
|
||||
*pom_oidc.Provider
|
||||
accessTokenAllowedAudiences *[]string
|
||||
}
|
||||
|
||||
// New instantiates an OpenID Connect (OIDC) provider for Azure.
|
||||
func New(ctx context.Context, o *oauth.Options) (*Provider, error) {
|
||||
var p Provider
|
||||
p.accessTokenAllowedAudiences = o.AccessTokenAllowedAudiences
|
||||
var err error
|
||||
if o.ProviderURL == "" {
|
||||
o.ProviderURL = defaultProviderURL
|
||||
|
@ -73,6 +79,59 @@ func (p *Provider) Name() string {
|
|||
return Name
|
||||
}
|
||||
|
||||
// VerifyAccessToken verifies a raw access token.
|
||||
func (p *Provider) VerifyAccessToken(ctx context.Context, rawAccessToken string) (claims map[string]any, err error) {
|
||||
pp, err := p.GetProvider()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error getting oidc provider: %w", err)
|
||||
}
|
||||
|
||||
// azure access tokens are JWTs signed with the same keys as identity tokens
|
||||
verifier := pp.Verifier(&go_oidc.Config{
|
||||
SkipClientIDCheck: true,
|
||||
SkipIssuerCheck: true, // checked later
|
||||
})
|
||||
token, err := verifier.Verify(ctx, rawAccessToken)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error verifying access token: %w", err)
|
||||
}
|
||||
|
||||
claims = jwtutil.Claims(map[string]any{})
|
||||
err = token.Claims(&claims)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error unmarshaling access token claims: %w", err)
|
||||
}
|
||||
|
||||
// verify audience
|
||||
if p.accessTokenAllowedAudiences != nil {
|
||||
if audience, ok := claims["aud"].(string); !ok || !slices.Contains(*p.accessTokenAllowedAudiences, audience) {
|
||||
return nil, fmt.Errorf("error verifying access token audience claim, invalid audience")
|
||||
}
|
||||
}
|
||||
|
||||
err = verifyIssuer(pp, claims)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error verifying access token issuer claim: %w", err)
|
||||
}
|
||||
|
||||
if scope, ok := claims["scp"].(string); ok && slices.Contains(strings.Fields(scope), "openid") {
|
||||
userInfo, err := pp.UserInfo(ctx, oauth2.StaticTokenSource(&oauth2.Token{
|
||||
TokenType: "Bearer",
|
||||
AccessToken: rawAccessToken,
|
||||
}))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error calling user info endpoint: %w", err)
|
||||
}
|
||||
|
||||
err = userInfo.Claims(claims)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error unmarshaling user info claims: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return claims, nil
|
||||
}
|
||||
|
||||
// newProvider overrides the default round tripper for well-known endpoint call that happens
|
||||
// on new provider registration.
|
||||
// By default, the "common" (both public and private domains) responds with
|
||||
|
@ -128,3 +187,55 @@ func (transport *wellKnownConfiguration) RoundTrip(req *http.Request) (*http.Res
|
|||
res.Body = io.NopCloser(bytes.NewReader(bs))
|
||||
return res, nil
|
||||
}
|
||||
|
||||
const (
|
||||
v1IssuerPrefix = "https://sts.windows.net/"
|
||||
v1IssuerSuffix = "/"
|
||||
v2IssuerPrefix = "https://login.microsoftonline.com/"
|
||||
v2IssuerSuffix = "/v2.0"
|
||||
)
|
||||
|
||||
func verifyIssuer(pp *go_oidc.Provider, claims map[string]any) error {
|
||||
tenantID, ok := getTenantIDFromURL(pp.Endpoint().TokenURL)
|
||||
if !ok {
|
||||
return fmt.Errorf("failed to find tenant id")
|
||||
}
|
||||
|
||||
iss, ok := claims["iss"].(string)
|
||||
if !ok {
|
||||
return fmt.Errorf("missing issuer claim")
|
||||
}
|
||||
|
||||
if !(iss == v1IssuerPrefix+tenantID+v1IssuerSuffix || iss == v2IssuerPrefix+tenantID+v2IssuerSuffix) {
|
||||
return fmt.Errorf("invalid issuer: %s", iss)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func getTenantIDFromURL(rawTokenURL string) (string, bool) {
|
||||
// URLs look like:
|
||||
// - https://login.microsoftonline.com/f42bce3b-671c-4162-b24c-00ecc7641897/v2.0
|
||||
// Or:
|
||||
// - https://sts.windows.net/f42bce3b-671c-4162-b24c-00ecc7641897/
|
||||
for _, prefix := range []string{v1IssuerPrefix, v2IssuerPrefix} {
|
||||
path, ok := strings.CutPrefix(rawTokenURL, prefix)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
idx := strings.Index(path, "/")
|
||||
if idx <= 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
rawTenantID := path[:idx]
|
||||
if _, err := uuid.Parse(rawTenantID); err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
return rawTenantID, true
|
||||
}
|
||||
|
||||
return "", false
|
||||
}
|
||||
|
|
|
@ -2,15 +2,27 @@ package azure
|
|||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/go-jose/go-jose/v3"
|
||||
"github.com/go-jose/go-jose/v3/jwt"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/testutil"
|
||||
"github.com/pomerium/pomerium/pkg/identity/identity"
|
||||
"github.com/pomerium/pomerium/pkg/identity/oauth"
|
||||
)
|
||||
|
||||
func TestAuthCodeOptions(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var options oauth.Options
|
||||
p, err := New(context.Background(), &options)
|
||||
require.NoError(t, err)
|
||||
|
@ -21,3 +33,101 @@ func TestAuthCodeOptions(t *testing.T) {
|
|||
require.NoError(t, err)
|
||||
assert.Equal(t, map[string]string{}, p.AuthCodeOptions)
|
||||
}
|
||||
|
||||
func TestVerifyAccessToken(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.GetContext(t, time.Minute)
|
||||
|
||||
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
require.NoError(t, err)
|
||||
|
||||
jwtSigner, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.RS256, Key: privateKey}, nil)
|
||||
require.NoError(t, err)
|
||||
iat := time.Now().Unix()
|
||||
exp := iat + 3600
|
||||
rawAccessToken1, err := jwt.Signed(jwtSigner).Claims(map[string]any{
|
||||
"iss": "https://sts.windows.net/323b4000-7ad7-4ed3-9f4e-adee06ee8bbe/",
|
||||
"aud": "https://client.example.com",
|
||||
"sub": "subject",
|
||||
"exp": exp,
|
||||
"iat": iat,
|
||||
}).CompactSerialize()
|
||||
require.NoError(t, err)
|
||||
rawAccessToken2, err := jwt.Signed(jwtSigner).Claims(map[string]any{
|
||||
"iss": "https://sts.windows.net/323b4000-7ad7-4ed3-9f4e-adee06ee8bbe/",
|
||||
"aud": "https://unexpected.example.com",
|
||||
"sub": "subject",
|
||||
"exp": exp,
|
||||
"iat": iat,
|
||||
}).CompactSerialize()
|
||||
require.NoError(t, err)
|
||||
|
||||
var srvURL string
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("GET /.well-known/openid-configuration", func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json; charset=utf-8")
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
"issuer": srvURL,
|
||||
"authorization_endpoint": srvURL + "/auth",
|
||||
"token_endpoint": "https://sts.windows.net/323b4000-7ad7-4ed3-9f4e-adee06ee8bbe/token",
|
||||
"jwks_uri": srvURL + "/keys",
|
||||
"id_token_signing_alg_values_supported": []any{"RS256"},
|
||||
})
|
||||
})
|
||||
mux.HandleFunc("GET /keys", func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json; charset=utf-8")
|
||||
|
||||
json.NewEncoder(w).Encode(jose.JSONWebKeySet{
|
||||
Keys: []jose.JSONWebKey{
|
||||
{Key: privateKey.Public(), Use: "sig", Algorithm: "RS256"},
|
||||
},
|
||||
})
|
||||
})
|
||||
srv := httptest.NewServer(mux)
|
||||
srvURL = srv.URL
|
||||
|
||||
audiences := []string{"https://other.example.com", "https://client.example.com"}
|
||||
p, err := New(ctx, &oauth.Options{
|
||||
ProviderName: Name,
|
||||
ProviderURL: srv.URL,
|
||||
ClientID: "CLIENT_ID",
|
||||
ClientSecret: "CLIENT_SECRET",
|
||||
AccessTokenAllowedAudiences: &audiences,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
claims, err := p.VerifyAccessToken(ctx, rawAccessToken1)
|
||||
require.NoError(t, err)
|
||||
delete(claims, "iat")
|
||||
delete(claims, "exp")
|
||||
assert.Equal(t, map[string]any{
|
||||
"iss": "https://sts.windows.net/323b4000-7ad7-4ed3-9f4e-adee06ee8bbe/",
|
||||
"aud": "https://client.example.com",
|
||||
"sub": "subject",
|
||||
}, claims)
|
||||
|
||||
_, err = p.VerifyAccessToken(ctx, rawAccessToken2)
|
||||
assert.ErrorContains(t, err, "invalid audience")
|
||||
}
|
||||
|
||||
func TestVerifyIdentityToken(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.GetContext(t, time.Minute)
|
||||
|
||||
mux := http.NewServeMux()
|
||||
srv := httptest.NewServer(mux)
|
||||
|
||||
p, err := New(ctx, &oauth.Options{
|
||||
ProviderName: Name,
|
||||
ProviderURL: srv.URL,
|
||||
ClientID: "CLIENT_ID",
|
||||
ClientSecret: "CLIENT_SECRET",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
claims, err := p.VerifyIdentityToken(ctx, "RAW IDENTITY TOKEN")
|
||||
assert.ErrorIs(t, identity.ErrVerifyIdentityTokenNotSupported, err)
|
||||
assert.Nil(t, claims)
|
||||
}
|
||||
|
|
|
@ -360,3 +360,13 @@ func (p *Provider) SignOut(w http.ResponseWriter, r *http.Request, idTokenHint,
|
|||
httputil.Redirect(w, r, endSessionURL.String(), http.StatusFound)
|
||||
return nil
|
||||
}
|
||||
|
||||
// VerifyAccessToken verifies an access token.
|
||||
func (p *Provider) VerifyAccessToken(_ context.Context, _ string) (claims map[string]any, err error) {
|
||||
return nil, identity.ErrVerifyAccessTokenNotSupported
|
||||
}
|
||||
|
||||
// VerifyIdentityToken verifies an identity token.
|
||||
func (p *Provider) VerifyIdentityToken(_ context.Context, _ string) (claims map[string]any, err error) {
|
||||
return nil, identity.ErrVerifyIdentityTokenNotSupported
|
||||
}
|
||||
|
|
|
@ -8,6 +8,7 @@ import (
|
|||
"net/http"
|
||||
|
||||
"go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp"
|
||||
oteltrace "go.opentelemetry.io/otel/trace"
|
||||
"golang.org/x/oauth2"
|
||||
|
||||
"github.com/pomerium/pomerium/pkg/identity/identity"
|
||||
|
@ -23,7 +24,6 @@ import (
|
|||
"github.com/pomerium/pomerium/pkg/identity/oidc/okta"
|
||||
"github.com/pomerium/pomerium/pkg/identity/oidc/onelogin"
|
||||
"github.com/pomerium/pomerium/pkg/identity/oidc/ping"
|
||||
oteltrace "go.opentelemetry.io/otel/trace"
|
||||
)
|
||||
|
||||
// State is the identity state.
|
||||
|
@ -36,6 +36,8 @@ type Authenticator interface {
|
|||
Revoke(context.Context, *oauth2.Token) error
|
||||
Name() string
|
||||
UpdateUserInfo(ctx context.Context, t *oauth2.Token, v any) error
|
||||
VerifyAccessToken(ctx context.Context, rawAccessToken string) (claims map[string]any, err error)
|
||||
VerifyIdentityToken(ctx context.Context, rawIdentityToken string) (claims map[string]any, err error)
|
||||
|
||||
SignIn(w http.ResponseWriter, r *http.Request, state string) error
|
||||
SignOut(w http.ResponseWriter, r *http.Request, idTokenHint, authenticateSignedOutURL, redirectToURL string) error
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue