authenticate: fix user-info call for AWS cognito (#792)

This commit is contained in:
Caleb Doxsey 2020-05-27 15:37:42 -06:00 committed by GitHub
parent b16bc5e090
commit 988477c90d
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 140 additions and 1 deletions

View file

@ -125,7 +125,7 @@ func (p *Provider) Authenticate(ctx context.Context, code string, v interface{})
//
// https://openid.net/specs/openid-connect-core-1_0.html#UserInfo
func (p *Provider) updateUserInfo(ctx context.Context, t *oauth2.Token, v interface{}) error {
userInfo, err := p.Provider.UserInfo(ctx, oauth2.StaticTokenSource(t))
userInfo, err := getUserInfo(ctx, p.Provider, oauth2.StaticTokenSource(t))
if err != nil {
return fmt.Errorf("identity/oidc: user info endpoint: %w", err)
}

View file

@ -0,0 +1,65 @@
package oidc
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io/ioutil"
"net/http"
"strconv"
"github.com/coreos/go-oidc"
"golang.org/x/oauth2"
)
// getUserInfo gets the user info for OIDC. We wrap the underlying call because AWS Cognito chose to violate the spec
// and return data in an invalid format. By using our own custom http client, we're able to modify the response to
// make it compliant, and then the rest of the library works as expected.
func getUserInfo(ctx context.Context, provider *oidc.Provider, tokenSource oauth2.TokenSource) (*oidc.UserInfo, error) {
originalClient := http.DefaultClient
if c, ok := ctx.Value(oauth2.HTTPClient).(*http.Client); ok {
originalClient = c
}
client := new(http.Client)
*client = *originalClient
client.Transport = &userInfoRoundTripper{underlying: client.Transport}
ctx = context.WithValue(ctx, oauth2.HTTPClient, client)
return provider.UserInfo(ctx, tokenSource)
}
type userInfoRoundTripper struct {
underlying http.RoundTripper
}
func (transport *userInfoRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
underlying := transport.underlying
if underlying == nil {
underlying = http.DefaultTransport
}
res, err := underlying.RoundTrip(req)
if err != nil {
return nil, err
}
defer res.Body.Close()
bs, err := ioutil.ReadAll(res.Body)
if err != nil {
return nil, err
}
// AWS Cognito returns email_verified as a string, so we'll make it a bool
var userInfo map[string]interface{}
if err := json.Unmarshal(bs, &userInfo); err == nil {
if ev, ok := userInfo["email_verified"]; ok {
userInfo["email_verified"], _ = strconv.ParseBool(fmt.Sprint(ev))
}
bs, _ = json.Marshal(userInfo)
}
res.Body = ioutil.NopCloser(bytes.NewReader(bs))
return res, nil
}

View file

@ -0,0 +1,74 @@
package oidc
import (
"context"
"io"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/coreos/go-oidc"
"github.com/stretchr/testify/assert"
"golang.org/x/oauth2"
)
func TestUserInfoRoundTrip(t *testing.T) {
var srv *httptest.Server
srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/.well-known/openid-configuration":
w.Header().Set("Content-Type", "application/json")
io.WriteString(w, `
{
"authorization_endpoint": "`+srv.URL+`/oauth2/authorize",
"id_token_signing_alg_values_supported": [
"RS256"
],
"issuer": "`+srv.URL+`",
"jwks_uri": "`+srv.URL+`/.well-known/jwks.json",
"response_types_supported": [
"code",
"token"
],
"scopes_supported": [
"openid",
"email",
"phone",
"profile"
],
"subject_types_supported": [
"public"
],
"token_endpoint": "`+srv.URL+`/oauth2/token",
"token_endpoint_auth_methods_supported": [
"client_secret_basic",
"client_secret_post"
],
"userinfo_endpoint": "`+srv.URL+`/oauth2/userInfo"
}`)
case "/oauth2/userInfo":
w.Header().Set("Content-Type", "application/json")
io.WriteString(w, `{ "email_verified": "true" }`)
}
}))
defer srv.Close()
provider, err := oidc.NewProvider(context.Background(), srv.URL)
if !assert.NoError(t, err) {
return
}
token := oauth2.StaticTokenSource(&oauth2.Token{
AccessToken: "access-token",
TokenType: "Bearer",
RefreshToken: "refresh-token",
Expiry: time.Now().Add(time.Minute),
})
userInfo, err := getUserInfo(context.Background(), provider, token)
if !assert.NoError(t, err) {
return
}
assert.True(t, userInfo.EmailVerified)
}