From 73d8900c47a5fe5c9cc776b692b39de1d038de4e Mon Sep 17 00:00:00 2001 From: Caleb Doxsey Date: Fri, 16 Dec 2022 13:24:40 -0700 Subject: [PATCH] oidc: fix token revocation (#3810) --- internal/identity/oidc/oidc.go | 24 ++++++----- internal/identity/oidc/oidc_test.go | 65 +++++++++++++++++++++++++++++ 2 files changed, 79 insertions(+), 10 deletions(-) create mode 100644 internal/identity/oidc/oidc_test.go diff --git a/internal/identity/oidc/oidc.go b/internal/identity/oidc/oidc.go index 14a565e00..13fbce41f 100644 --- a/internal/identity/oidc/oidc.go +++ b/internal/identity/oidc/oidc.go @@ -176,6 +176,11 @@ func (p *Provider) UpdateUserInfo(ctx context.Context, t *oauth2.Token, v interf // Group membership is also refreshed. // https://openid.net/specs/openid-connect-core-1_0.html#RefreshTokens func (p *Provider) Refresh(ctx context.Context, t *oauth2.Token, v identity.State) (*oauth2.Token, error) { + oa, err := p.GetOauthConfig() + if err != nil { + return nil, err + } + if t == nil { return nil, ErrMissingAccessToken } @@ -183,11 +188,6 @@ func (p *Provider) Refresh(ctx context.Context, t *oauth2.Token, v identity.Stat return nil, ErrMissingRefreshToken } - oa, err := p.GetOauthConfig() - if err != nil { - return nil, err - } - newToken, err := oa.TokenSource(ctx, t).Token() if err != nil { return nil, fmt.Errorf("identity/oidc: refresh failed: %w", err) @@ -230,6 +230,11 @@ func (p *Provider) getIDToken(ctx context.Context, t *oauth2.Token) (*go_oidc.ID // // https://tools.ietf.org/html/rfc7009#section-2.1 func (p *Provider) Revoke(ctx context.Context, t *oauth2.Token) error { + oa, err := p.GetOauthConfig() + if err != nil { + return err + } + if p.RevocationURL == "" { return ErrRevokeNotImplemented } @@ -237,11 +242,6 @@ func (p *Provider) Revoke(ctx context.Context, t *oauth2.Token) error { return ErrMissingAccessToken } - oa, err := p.GetOauthConfig() - if err != nil { - return err - } - params := url.Values{} params.Add("token", t.AccessToken) params.Add("token_type_hint", "access_token") @@ -263,6 +263,10 @@ func (p *Provider) Revoke(ctx context.Context, t *oauth2.Token) error { // session to be initiated. // https://openid.net/specs/openid-connect-frontchannel-1_0.html#RPInitiated func (p *Provider) LogOut() (*url.URL, error) { + _, err := p.GetProvider() + if err != nil { + return nil, err + } if p.EndSessionURL == "" { return nil, ErrSignoutNotImplemented } diff --git a/internal/identity/oidc/oidc_test.go b/internal/identity/oidc/oidc_test.go new file mode 100644 index 000000000..a15601925 --- /dev/null +++ b/internal/identity/oidc/oidc_test.go @@ -0,0 +1,65 @@ +package oidc + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "net/url" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.org/x/oauth2" + + "github.com/pomerium/pomerium/internal/identity/oauth" +) + +func TestRevoke(t *testing.T) { + ctx, clearTimeout := context.WithTimeout(context.Background(), time.Second*10) + t.Cleanup(clearTimeout) + + var srv *httptest.Server + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + baseURL, err := url.Parse(srv.URL) + require.NoError(t, err) + + w.Header().Set("Content-Type", "application/json") + switch r.URL.Path { + case "/.well-known/openid-configuration": + json.NewEncoder(w).Encode(map[string]any{ + "issuer": baseURL.String(), + "revocation_endpoint": baseURL.ResolveReference(&url.URL{ + Path: "/revoke", + }).String(), + }) + case "/revoke": + assert.Equal(t, "ACCESS_TOKEN", r.FormValue("token")) + assert.Equal(t, "access_token", r.FormValue("token_type_hint")) + assert.Equal(t, "CLIENT_ID", r.FormValue("client_id")) + assert.Equal(t, "CLIENT_SECRET", r.FormValue("client_secret")) + + default: + assert.Failf(t, "unexpected http request", "url: %s", r.URL.String()) + } + }) + srv = httptest.NewServer(handler) + t.Cleanup(srv.Close) + + redirectURL, err := url.Parse(srv.URL) + require.NoError(t, err) + + p, err := New(ctx, &oauth.Options{ + ProviderURL: srv.URL, + RedirectURL: redirectURL, + ClientID: "CLIENT_ID", + ClientSecret: "CLIENT_SECRET", + }) + require.NoError(t, err) + require.NotNil(t, p) + + assert.NoError(t, p.Revoke(ctx, &oauth2.Token{ + AccessToken: "ACCESS_TOKEN", + })) +}