mirror of
https://github.com/pomerium/pomerium.git
synced 2025-08-03 16:59:22 +02:00
authenticate: fix user-info call for AWS cognito (#792)
This commit is contained in:
parent
b16bc5e090
commit
988477c90d
3 changed files with 140 additions and 1 deletions
|
@ -125,7 +125,7 @@ func (p *Provider) Authenticate(ctx context.Context, code string, v interface{})
|
|||
//
|
||||
// https://openid.net/specs/openid-connect-core-1_0.html#UserInfo
|
||||
func (p *Provider) updateUserInfo(ctx context.Context, t *oauth2.Token, v interface{}) error {
|
||||
userInfo, err := p.Provider.UserInfo(ctx, oauth2.StaticTokenSource(t))
|
||||
userInfo, err := getUserInfo(ctx, p.Provider, oauth2.StaticTokenSource(t))
|
||||
if err != nil {
|
||||
return fmt.Errorf("identity/oidc: user info endpoint: %w", err)
|
||||
}
|
||||
|
|
65
internal/identity/oidc/userinfo.go
Normal file
65
internal/identity/oidc/userinfo.go
Normal file
|
@ -0,0 +1,65 @@
|
|||
package oidc
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"strconv"
|
||||
|
||||
"github.com/coreos/go-oidc"
|
||||
"golang.org/x/oauth2"
|
||||
)
|
||||
|
||||
// getUserInfo gets the user info for OIDC. We wrap the underlying call because AWS Cognito chose to violate the spec
|
||||
// and return data in an invalid format. By using our own custom http client, we're able to modify the response to
|
||||
// make it compliant, and then the rest of the library works as expected.
|
||||
func getUserInfo(ctx context.Context, provider *oidc.Provider, tokenSource oauth2.TokenSource) (*oidc.UserInfo, error) {
|
||||
originalClient := http.DefaultClient
|
||||
if c, ok := ctx.Value(oauth2.HTTPClient).(*http.Client); ok {
|
||||
originalClient = c
|
||||
}
|
||||
|
||||
client := new(http.Client)
|
||||
*client = *originalClient
|
||||
client.Transport = &userInfoRoundTripper{underlying: client.Transport}
|
||||
|
||||
ctx = context.WithValue(ctx, oauth2.HTTPClient, client)
|
||||
return provider.UserInfo(ctx, tokenSource)
|
||||
}
|
||||
|
||||
type userInfoRoundTripper struct {
|
||||
underlying http.RoundTripper
|
||||
}
|
||||
|
||||
func (transport *userInfoRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
underlying := transport.underlying
|
||||
if underlying == nil {
|
||||
underlying = http.DefaultTransport
|
||||
}
|
||||
|
||||
res, err := underlying.RoundTrip(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer res.Body.Close()
|
||||
|
||||
bs, err := ioutil.ReadAll(res.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// AWS Cognito returns email_verified as a string, so we'll make it a bool
|
||||
var userInfo map[string]interface{}
|
||||
if err := json.Unmarshal(bs, &userInfo); err == nil {
|
||||
if ev, ok := userInfo["email_verified"]; ok {
|
||||
userInfo["email_verified"], _ = strconv.ParseBool(fmt.Sprint(ev))
|
||||
}
|
||||
bs, _ = json.Marshal(userInfo)
|
||||
}
|
||||
|
||||
res.Body = ioutil.NopCloser(bytes.NewReader(bs))
|
||||
return res, nil
|
||||
}
|
74
internal/identity/oidc/userinfo_test.go
Normal file
74
internal/identity/oidc/userinfo_test.go
Normal file
|
@ -0,0 +1,74 @@
|
|||
package oidc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/coreos/go-oidc"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"golang.org/x/oauth2"
|
||||
)
|
||||
|
||||
func TestUserInfoRoundTrip(t *testing.T) {
|
||||
var srv *httptest.Server
|
||||
srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/.well-known/openid-configuration":
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
io.WriteString(w, `
|
||||
{
|
||||
"authorization_endpoint": "`+srv.URL+`/oauth2/authorize",
|
||||
"id_token_signing_alg_values_supported": [
|
||||
"RS256"
|
||||
],
|
||||
"issuer": "`+srv.URL+`",
|
||||
"jwks_uri": "`+srv.URL+`/.well-known/jwks.json",
|
||||
"response_types_supported": [
|
||||
"code",
|
||||
"token"
|
||||
],
|
||||
"scopes_supported": [
|
||||
"openid",
|
||||
"email",
|
||||
"phone",
|
||||
"profile"
|
||||
],
|
||||
"subject_types_supported": [
|
||||
"public"
|
||||
],
|
||||
"token_endpoint": "`+srv.URL+`/oauth2/token",
|
||||
"token_endpoint_auth_methods_supported": [
|
||||
"client_secret_basic",
|
||||
"client_secret_post"
|
||||
],
|
||||
"userinfo_endpoint": "`+srv.URL+`/oauth2/userInfo"
|
||||
}`)
|
||||
case "/oauth2/userInfo":
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
io.WriteString(w, `{ "email_verified": "true" }`)
|
||||
}
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
provider, err := oidc.NewProvider(context.Background(), srv.URL)
|
||||
if !assert.NoError(t, err) {
|
||||
return
|
||||
}
|
||||
|
||||
token := oauth2.StaticTokenSource(&oauth2.Token{
|
||||
AccessToken: "access-token",
|
||||
TokenType: "Bearer",
|
||||
RefreshToken: "refresh-token",
|
||||
Expiry: time.Now().Add(time.Minute),
|
||||
})
|
||||
|
||||
userInfo, err := getUserInfo(context.Background(), provider, token)
|
||||
if !assert.NoError(t, err) {
|
||||
return
|
||||
}
|
||||
assert.True(t, userInfo.EmailVerified)
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue