diff --git a/internal/mcp/handler.go b/internal/mcp/handler.go index 52bcfff77..ba9c7255a 100644 --- a/internal/mcp/handler.go +++ b/internal/mcp/handler.go @@ -30,10 +30,11 @@ const ( ) type Handler struct { - prefix string - trace oteltrace.TracerProvider - storage *Storage - cipher cipher.AEAD + prefix string + trace oteltrace.TracerProvider + storage *Storage + cipher cipher.AEAD + relyingParties *OAuth2Configs } func New( @@ -54,10 +55,11 @@ func New( } return &Handler{ - prefix: prefix, - trace: tracerProvider, - storage: NewStorage(client), - cipher: cipher, + prefix: prefix, + trace: tracerProvider, + storage: NewStorage(client), + cipher: cipher, + relyingParties: NewOAuthConfig(cfg, http.DefaultClient), }, nil } diff --git a/internal/mcp/handler_authorization.go b/internal/mcp/handler_authorization.go index 714ae5ae6..7e8c5f0cf 100644 --- a/internal/mcp/handler_authorization.go +++ b/internal/mcp/handler_authorization.go @@ -68,7 +68,12 @@ func (srv *Handler) Authorize(w http.ResponseWriter, r *http.Request) { return } - srv.AuthorizationResponse(ctx, w, r, id, v) + loginURL, ok := srv.relyingParties.GetLoginURLForHost(r.Host, id) + if ok { + http.Redirect(w, r, loginURL, http.StatusFound) + } else { + srv.AuthorizationResponse(ctx, w, r, id, v) + } } // AuthorizationResponse generates the successful authorization response diff --git a/internal/mcp/oauth_config.go b/internal/mcp/oauth_config.go new file mode 100644 index 000000000..648384e56 --- /dev/null +++ b/internal/mcp/oauth_config.go @@ -0,0 +1,93 @@ +package mcp + +import ( + "net/http" + "net/url" + "path" + "sync" + + "golang.org/x/oauth2" + + "github.com/pomerium/pomerium/config" +) + +type OAuth2Configs struct { + cfg *config.Config + prefix string + httpClient *http.Client + + buildOnce sync.Once + perHost map[string]*oauth2.Config +} + +func NewOAuthConfig( + cfg *config.Config, + httpClient *http.Client, +) *OAuth2Configs { + return &OAuth2Configs{ + prefix: DefaultPrefix, + cfg: cfg, + httpClient: httpClient, + } +} + +func (r *OAuth2Configs) GetLoginURLForHost(host string, state string) (string, bool) { + r.buildOnce.Do(r.build) + + cfg, ok := r.perHost[host] + if !ok { + return "", false + } + + return cfg.AuthCodeURL(state, oauth2.AccessTypeOffline), true +} + +func (r *OAuth2Configs) build() { + r.perHost = BuildOAuthConfig(r.cfg, r.prefix) +} + +// BuildOAuthConfig builds a map of OAuth2 configs per host +func BuildOAuthConfig(cfg *config.Config, prefix string) map[string]*oauth2.Config { + configs := make(map[string]*oauth2.Config) + for policy := range cfg.Options.GetAllPolicies() { + if !policy.IsMCPServer() || policy.MCP.UpstreamOAuth2 == nil { + continue + } + u, err := url.Parse(policy.GetFrom()) + if err != nil { + continue + } + host := u.Hostname() + if _, ok := configs[host]; ok { + continue + } + cfg := &oauth2.Config{ + ClientID: policy.MCP.UpstreamOAuth2.ClientID, + ClientSecret: policy.MCP.UpstreamOAuth2.ClientSecret, + Endpoint: oauth2.Endpoint{ + AuthURL: policy.MCP.UpstreamOAuth2.Endpoint.AuthURL, + TokenURL: policy.MCP.UpstreamOAuth2.Endpoint.TokenURL, + AuthStyle: authStyleEnum(policy.MCP.UpstreamOAuth2.Endpoint.AuthStyle), + }, + RedirectURL: (&url.URL{ + Scheme: "https", + Host: host, + Path: path.Join(prefix, oauthCallbackEndpoint), + }).String(), + Scopes: policy.MCP.UpstreamOAuth2.Scopes, + } + configs[host] = cfg + } + return configs +} + +func authStyleEnum(o config.OAuth2EndpointAuthStyle) oauth2.AuthStyle { + switch o { + case config.OAuth2EndpointAuthStyleInHeader: + return oauth2.AuthStyleInHeader + case config.OAuth2EndpointAuthStyleInParams: + return oauth2.AuthStyleInParams + default: + return oauth2.AuthStyleAutoDetect + } +} diff --git a/internal/mcp/oauth_config_test.go b/internal/mcp/oauth_config_test.go new file mode 100644 index 000000000..99fd0551f --- /dev/null +++ b/internal/mcp/oauth_config_test.go @@ -0,0 +1,57 @@ +package mcp_test + +import ( + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + "github.com/stretchr/testify/require" + "golang.org/x/oauth2" + + "github.com/pomerium/pomerium/config" + "github.com/pomerium/pomerium/internal/mcp" +) + +func TestBuildOAuthConfig(t *testing.T) { + cfg := &config.Config{ + Options: &config.Options{ + Policies: []config.Policy{ + { + From: "https://regular.example.com", + }, + { + From: "https://mcp1.example.com", + MCP: &config.MCP{}, + }, + { + From: "https://mcp2.example.com", + MCP: &config.MCP{ + UpstreamOAuth2: &config.UpstreamOAuth2{ + ClientID: "client_id", + ClientSecret: "client_secret", + Endpoint: config.OAuth2Endpoint{ + AuthURL: "https://auth.example.com/auth", + TokenURL: "https://auth.example.com/token", + AuthStyle: config.OAuth2EndpointAuthStyleInParams, + }, + }, + }, + }, + }, + }, + } + got := mcp.BuildOAuthConfig(cfg, "/prefix") + diff := cmp.Diff(got, map[string]*oauth2.Config{ + "mcp2.example.com": { + ClientID: "client_id", + ClientSecret: "client_secret", + Endpoint: oauth2.Endpoint{ + AuthURL: "https://auth.example.com/auth", + TokenURL: "https://auth.example.com/token", + AuthStyle: oauth2.AuthStyleInParams, + }, + RedirectURL: "https://mcp2.example.com/prefix/oauth/callback", + }, + }, cmpopts.IgnoreUnexported(oauth2.Config{})) + require.Empty(t, diff) +}