From 9fa65e069c2306057d779ea223a4268841667fc3 Mon Sep 17 00:00:00 2001 From: Caleb Doxsey Date: Wed, 18 Aug 2021 09:20:08 -0600 Subject: [PATCH] github: support provider URL (#2490) --- internal/httputil/client.go | 12 +++++++----- internal/httputil/client_test.go | 2 +- internal/identity/oauth/github/github.go | 22 ++++++++++++++++------ internal/urlutil/url.go | 18 ++++++++++++++++++ internal/urlutil/url_test.go | 8 ++++++++ 5 files changed, 50 insertions(+), 12 deletions(-) diff --git a/internal/httputil/client.go b/internal/httputil/client.go index 3b6f3fd28..1bcf963f2 100644 --- a/internal/httputil/client.go +++ b/internal/httputil/client.go @@ -80,10 +80,12 @@ func (c *httpClient) Do(req *http.Request) (*http.Response, error) { return c.Client.Do(req) } -// defaultClient avoids leaks by setting an upper limit for timeouts. -var defaultClient = &httpClient{ - &http.Client{Timeout: 1 * time.Minute}, - requestid.NewRoundTripper(http.DefaultTransport), +// getDefaultClient returns an HTTP client that avoids leaks by setting an upper limit for timeouts. +func getDefaultClient() *httpClient { + return &httpClient{ + &http.Client{Timeout: 1 * time.Minute}, + requestid.NewRoundTripper(http.DefaultTransport), + } } // Do provides a simple helper interface to make HTTP requests @@ -113,7 +115,7 @@ func Do(ctx context.Context, method, endpoint, userAgent string, headers map[str req.Header.Set(k, v) } - resp, err := defaultClient.Do(req) + resp, err := getDefaultClient().Do(req) if err != nil { return err } diff --git a/internal/httputil/client_test.go b/internal/httputil/client_test.go index 7bcdb1f12..5967b84e8 100644 --- a/internal/httputil/client_test.go +++ b/internal/httputil/client_test.go @@ -21,5 +21,5 @@ func TestDefaultClient(t *testing.T) { defer ts.Close() req, _ := http.NewRequest(http.MethodGet, ts.URL, nil) req = req.WithContext(requestid.WithValue(context.Background(), "foo")) - _, _ = defaultClient.Do(req) + _, _ = getDefaultClient().Do(req) } diff --git a/internal/identity/oauth/github/github.go b/internal/identity/oauth/github/github.go index e86fac835..36c4c6cc2 100644 --- a/internal/identity/oauth/github/github.go +++ b/internal/identity/oauth/github/github.go @@ -21,6 +21,7 @@ import ( "github.com/pomerium/pomerium/internal/identity/oauth" "github.com/pomerium/pomerium/internal/identity/oidc" "github.com/pomerium/pomerium/internal/log" + "github.com/pomerium/pomerium/internal/urlutil" "github.com/pomerium/pomerium/internal/version" ) @@ -50,7 +51,8 @@ var defaultScopes = []string{"user:email", "read:org"} type Provider struct { Oauth *oauth2.Config - userEndpoint string + userEndpoint string + emailEndpoint string } // New instantiates an OAuth2 provider for Github. @@ -59,6 +61,16 @@ func New(ctx context.Context, o *oauth.Options) (*Provider, error) { if o.ProviderURL == "" { o.ProviderURL = defaultProviderURL } + + // when the default provider url is used, use the Github API endpoint + if o.ProviderURL == defaultProviderURL { + p.userEndpoint = urlutil.Join(githubAPIURL, userPath) + p.emailEndpoint = urlutil.Join(githubAPIURL, emailPath) + } else { + p.userEndpoint = urlutil.Join(o.ProviderURL, userPath) + p.emailEndpoint = urlutil.Join(o.ProviderURL, emailPath) + } + if len(o.Scopes) == 0 { o.Scopes = defaultScopes } @@ -68,11 +80,10 @@ func New(ctx context.Context, o *oauth.Options) (*Provider, error) { Scopes: o.Scopes, RedirectURL: o.RedirectURL.String(), Endpoint: oauth2.Endpoint{ - AuthURL: o.ProviderURL + authURL, - TokenURL: o.ProviderURL + tokenURL, + AuthURL: urlutil.Join(o.ProviderURL, authURL), + TokenURL: urlutil.Join(o.ProviderURL, tokenURL), }, } - p.userEndpoint = githubAPIURL + userPath return &p, nil } @@ -133,8 +144,7 @@ func (p *Provider) userEmail(ctx context.Context, t *oauth2.Token, v interface{} Visibility string `json:"visibility"` } headers := map[string]string{"Authorization": fmt.Sprintf("token %s", t.AccessToken)} - emailURL := githubAPIURL + emailPath - err := httputil.Do(ctx, http.MethodGet, emailURL, version.UserAgent(), headers, nil, &response) + err := httputil.Do(ctx, http.MethodGet, p.emailEndpoint, version.UserAgent(), headers, nil, &response) if err != nil { return err } diff --git a/internal/urlutil/url.go b/internal/urlutil/url.go index cbec61e3b..87a873d72 100644 --- a/internal/urlutil/url.go +++ b/internal/urlutil/url.go @@ -113,3 +113,21 @@ func GetDomainsForURL(u url.URL) []string { func IsTCP(u *url.URL) bool { return u.Scheme == "tcp+http" || u.Scheme == "tcp+https" } + +// Join joins elements of a URL with '/'. +func Join(elements ...string) string { + var builder strings.Builder + appendSlash := false + for i, el := range elements { + if appendSlash { + builder.WriteByte('/') + } + if i > 0 && strings.HasPrefix(el, "/") { + builder.WriteString(el[1:]) + } else { + builder.WriteString(el) + } + appendSlash = !strings.HasSuffix(el, "/") + } + return builder.String() +} diff --git a/internal/urlutil/url_test.go b/internal/urlutil/url_test.go index be0384b2c..06e8ee4c4 100644 --- a/internal/urlutil/url_test.go +++ b/internal/urlutil/url_test.go @@ -7,6 +7,7 @@ import ( "testing" "github.com/google/go-cmp/cmp" + "github.com/stretchr/testify/assert" ) func Test_StripPort(t *testing.T) { @@ -158,3 +159,10 @@ func TestGetDomainsForURL(t *testing.T) { }) } } + +func TestJoin(t *testing.T) { + assert.Equal(t, "/x/y/z/", Join("/x", "y/z/")) + assert.Equal(t, "/x/y/z/", Join("/x/", "y/z/")) + assert.Equal(t, "/x/y/z/", Join("/x", "/y/z/")) + assert.Equal(t, "/x/y/z/", Join("/x/", "/y/z/")) +}