mcp: add oauth metadata endpoint (#5579)

This commit is contained in:
Denis Mishin 2025-04-23 12:24:00 -04:00 committed by GitHub
parent 2e7d1c7f12
commit cb0e8aaf06
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
10 changed files with 324 additions and 32 deletions

View file

@ -1273,23 +1273,27 @@ func (o *Options) GetAllRouteableGRPCHosts() ([]string, error) {
}
// GetAllRouteableHTTPHosts returns all the possible HTTP hosts handled by the Pomerium options.
func (o *Options) GetAllRouteableHTTPHosts() ([]string, error) {
func (o *Options) GetAllRouteableHTTPHosts() ([]string, map[string]bool, error) {
hosts := goset.NewTreeSet(cmp.Compare[string])
mcpHosts := make(map[string]bool)
if IsAuthenticate(o.Services) {
if o.AuthenticateInternalURLString != "" {
authenticateURL, err := o.GetInternalAuthenticateURL()
if err != nil {
return nil, err
return nil, nil, err
}
hosts.InsertSlice(urlutil.GetDomainsForURL(authenticateURL, !o.IsRuntimeFlagSet(RuntimeFlagMatchAnyIncomingPort)))
domains := urlutil.GetDomainsForURL(authenticateURL, !o.IsRuntimeFlagSet(RuntimeFlagMatchAnyIncomingPort))
hosts.InsertSlice(domains)
}
if o.AuthenticateURLString != "" {
authenticateURL, err := o.GetAuthenticateURL()
if err != nil {
return nil, err
return nil, nil, err
}
hosts.InsertSlice(urlutil.GetDomainsForURL(authenticateURL, !o.IsRuntimeFlagSet(RuntimeFlagMatchAnyIncomingPort)))
domains := urlutil.GetDomainsForURL(authenticateURL, !o.IsRuntimeFlagSet(RuntimeFlagMatchAnyIncomingPort))
hosts.InsertSlice(domains)
}
}
@ -1298,18 +1302,35 @@ func (o *Options) GetAllRouteableHTTPHosts() ([]string, error) {
for policy := range o.GetAllPolicies() {
fromURL, err := urlutil.ParseAndValidateURL(policy.From)
if err != nil {
return nil, err
return nil, nil, err
}
domains := urlutil.GetDomainsForURL(fromURL, !o.IsRuntimeFlagSet(RuntimeFlagMatchAnyIncomingPort))
hosts.InsertSlice(domains)
// Track if the domains are associated with an MCP policy
if policy.IsMCP() {
for _, domain := range domains {
mcpHosts[domain] = true
}
}
hosts.InsertSlice(urlutil.GetDomainsForURL(fromURL, !o.IsRuntimeFlagSet(RuntimeFlagMatchAnyIncomingPort)))
if policy.TLSDownstreamServerName != "" {
tlsURL := fromURL.ResolveReference(&url.URL{Host: policy.TLSDownstreamServerName})
hosts.InsertSlice(urlutil.GetDomainsForURL(tlsURL, !o.IsRuntimeFlagSet(RuntimeFlagMatchAnyIncomingPort)))
tlsDomains := urlutil.GetDomainsForURL(tlsURL, !o.IsRuntimeFlagSet(RuntimeFlagMatchAnyIncomingPort))
hosts.InsertSlice(tlsDomains)
// Track if the TLS domains are associated with an MCP policy
if policy.IsMCP() {
for _, domain := range tlsDomains {
mcpHosts[domain] = true
}
}
}
}
}
return hosts.Slice(), nil
return hosts.Slice(), mcpHosts, nil
}
// GetClientSecret gets the client secret.