mcp: redirect to upstream oauth2 for authentication (#5594)

This commit is contained in:
Denis Mishin 2025-05-01 12:16:44 -04:00 committed by GitHub
parent 5b024a8ada
commit 561b6040b5
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 166 additions and 9 deletions

View file

@ -34,6 +34,7 @@ type Handler struct {
trace oteltrace.TracerProvider trace oteltrace.TracerProvider
storage *Storage storage *Storage
cipher cipher.AEAD cipher cipher.AEAD
relyingParties *OAuth2Configs
} }
func New( func New(
@ -58,6 +59,7 @@ func New(
trace: tracerProvider, trace: tracerProvider,
storage: NewStorage(client), storage: NewStorage(client),
cipher: cipher, cipher: cipher,
relyingParties: NewOAuthConfig(cfg, http.DefaultClient),
}, nil }, nil
} }

View file

@ -68,8 +68,13 @@ func (srv *Handler) Authorize(w http.ResponseWriter, r *http.Request) {
return return
} }
loginURL, ok := srv.relyingParties.GetLoginURLForHost(r.Host, id)
if ok {
http.Redirect(w, r, loginURL, http.StatusFound)
} else {
srv.AuthorizationResponse(ctx, w, r, id, v) srv.AuthorizationResponse(ctx, w, r, id, v)
} }
}
// AuthorizationResponse generates the successful authorization response // AuthorizationResponse generates the successful authorization response
// see https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-12#section-4.1.2 // see https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-12#section-4.1.2

View file

@ -0,0 +1,93 @@
package mcp
import (
"net/http"
"net/url"
"path"
"sync"
"golang.org/x/oauth2"
"github.com/pomerium/pomerium/config"
)
type OAuth2Configs struct {
cfg *config.Config
prefix string
httpClient *http.Client
buildOnce sync.Once
perHost map[string]*oauth2.Config
}
func NewOAuthConfig(
cfg *config.Config,
httpClient *http.Client,
) *OAuth2Configs {
return &OAuth2Configs{
prefix: DefaultPrefix,
cfg: cfg,
httpClient: httpClient,
}
}
func (r *OAuth2Configs) GetLoginURLForHost(host string, state string) (string, bool) {
r.buildOnce.Do(r.build)
cfg, ok := r.perHost[host]
if !ok {
return "", false
}
return cfg.AuthCodeURL(state, oauth2.AccessTypeOffline), true
}
func (r *OAuth2Configs) build() {
r.perHost = BuildOAuthConfig(r.cfg, r.prefix)
}
// BuildOAuthConfig builds a map of OAuth2 configs per host
func BuildOAuthConfig(cfg *config.Config, prefix string) map[string]*oauth2.Config {
configs := make(map[string]*oauth2.Config)
for policy := range cfg.Options.GetAllPolicies() {
if !policy.IsMCPServer() || policy.MCP.UpstreamOAuth2 == nil {
continue
}
u, err := url.Parse(policy.GetFrom())
if err != nil {
continue
}
host := u.Hostname()
if _, ok := configs[host]; ok {
continue
}
cfg := &oauth2.Config{
ClientID: policy.MCP.UpstreamOAuth2.ClientID,
ClientSecret: policy.MCP.UpstreamOAuth2.ClientSecret,
Endpoint: oauth2.Endpoint{
AuthURL: policy.MCP.UpstreamOAuth2.Endpoint.AuthURL,
TokenURL: policy.MCP.UpstreamOAuth2.Endpoint.TokenURL,
AuthStyle: authStyleEnum(policy.MCP.UpstreamOAuth2.Endpoint.AuthStyle),
},
RedirectURL: (&url.URL{
Scheme: "https",
Host: host,
Path: path.Join(prefix, oauthCallbackEndpoint),
}).String(),
Scopes: policy.MCP.UpstreamOAuth2.Scopes,
}
configs[host] = cfg
}
return configs
}
func authStyleEnum(o config.OAuth2EndpointAuthStyle) oauth2.AuthStyle {
switch o {
case config.OAuth2EndpointAuthStyleInHeader:
return oauth2.AuthStyleInHeader
case config.OAuth2EndpointAuthStyleInParams:
return oauth2.AuthStyleInParams
default:
return oauth2.AuthStyleAutoDetect
}
}

View file

@ -0,0 +1,57 @@
package mcp_test
import (
"testing"
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"github.com/stretchr/testify/require"
"golang.org/x/oauth2"
"github.com/pomerium/pomerium/config"
"github.com/pomerium/pomerium/internal/mcp"
)
func TestBuildOAuthConfig(t *testing.T) {
cfg := &config.Config{
Options: &config.Options{
Policies: []config.Policy{
{
From: "https://regular.example.com",
},
{
From: "https://mcp1.example.com",
MCP: &config.MCP{},
},
{
From: "https://mcp2.example.com",
MCP: &config.MCP{
UpstreamOAuth2: &config.UpstreamOAuth2{
ClientID: "client_id",
ClientSecret: "client_secret",
Endpoint: config.OAuth2Endpoint{
AuthURL: "https://auth.example.com/auth",
TokenURL: "https://auth.example.com/token",
AuthStyle: config.OAuth2EndpointAuthStyleInParams,
},
},
},
},
},
},
}
got := mcp.BuildOAuthConfig(cfg, "/prefix")
diff := cmp.Diff(got, map[string]*oauth2.Config{
"mcp2.example.com": {
ClientID: "client_id",
ClientSecret: "client_secret",
Endpoint: oauth2.Endpoint{
AuthURL: "https://auth.example.com/auth",
TokenURL: "https://auth.example.com/token",
AuthStyle: oauth2.AuthStyleInParams,
},
RedirectURL: "https://mcp2.example.com/prefix/oauth/callback",
},
}, cmpopts.IgnoreUnexported(oauth2.Config{}))
require.Empty(t, diff)
}