mcp: add list-routes client helper (#5596)

This commit is contained in:
Denis Mishin 2025-05-01 15:02:28 -04:00 committed by GitHub
parent d2e2f56d57
commit 6caf65a117
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 213 additions and 43 deletions

View file

@ -293,6 +293,8 @@ var internalPathsNeedingLogin = set.From([]string{
"/.pomerium/routes",
"/.pomerium/api/v1/routes",
"/.pomerium/mcp/authorize",
"/.pomerium/mcp/routes",
"/.pomerium/mcp/connect",
})
func (e *Evaluator) evaluateInternal(_ context.Context, req *Request) (*PolicyResponse, error) {

View file

@ -29,7 +29,7 @@ import (
// Policy contains route specific configuration and access settings.
type Policy struct {
ID string `mapstructure:"-" yaml:"-" json:"-"`
Name string `mapstructure:"-" yaml:"-" json:"-"`
Name string `mapstructure:"name" yaml:"-" json:"name,omitempty"`
Description string `mapstructure:"description" yaml:"description,omitempty" json:"description,omitempty"`
LogoURL string `mapstructure:"logo_url" yaml:"logo_url,omitempty" json:"logo_url,omitempty"`

View file

@ -27,6 +27,8 @@ const (
registerEndpoint = "/register"
revocationEndpoint = "/revoke"
tokenEndpoint = "/token"
listRoutesEndpoint = "/routes"
connectEndpoint = "/connect"
)
type Handler struct {
@ -78,6 +80,8 @@ func (srv *Handler) HandlerFunc() http.HandlerFunc {
r.Path(path.Join(srv.prefix, authorizationEndpoint)).Methods(http.MethodGet).HandlerFunc(srv.Authorize)
r.Path(path.Join(srv.prefix, oauthCallbackEndpoint)).Methods(http.MethodGet).HandlerFunc(srv.OAuthCallback)
r.Path(path.Join(srv.prefix, tokenEndpoint)).Methods(http.MethodPost).HandlerFunc(srv.Token)
r.Path(path.Join(srv.prefix, listRoutesEndpoint)).Methods(http.MethodGet).HandlerFunc(srv.ListRoutes)
r.Path(path.Join(srv.prefix, connectEndpoint)).Methods(http.MethodGet).HandlerFunc(srv.Connect)
return r.ServeHTTP
}

View file

@ -82,7 +82,7 @@ func (srv *Handler) Authorize(w http.ResponseWriter, r *http.Request) {
return
}
requiresUpstreamOAuth2Token := srv.relyingParties.HasConfigForHost(r.Host)
requiresUpstreamOAuth2Token := srv.relyingParties.HasOAuth2ConfigForHost(r.Host)
var authReqID string
var hasUpstreamOAuth2Token bool
{

View file

@ -0,0 +1,11 @@
package mcp
import (
"net/http"
)
// Connect is a helper method for MCP clients to ensure that the current user
// has an active upstream Oauth2 session for the route.
func (srv *Handler) Connect(w http.ResponseWriter, _ *http.Request) {
http.Error(w, "not implemented", http.StatusNotImplemented)
}

View file

@ -0,0 +1,110 @@
package mcp
import (
"context"
"encoding/json"
"fmt"
"net/http"
"golang.org/x/sync/errgroup"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"github.com/pomerium/pomerium/internal/log"
)
// ListMCPServers returns a list of MCP servers that are registered,
// and whether the current user has access to them.
func (srv *Handler) ListRoutes(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet {
http.Error(w, "invalid method", http.StatusMethodNotAllowed)
return
}
err := srv.listMCPServers(w, r)
if err != nil {
log.Ctx(r.Context()).Error().Err(err).Msg("failed to list MCP servers")
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
}
func (srv *Handler) listMCPServers(w http.ResponseWriter, r *http.Request) error {
claims, err := getClaimsFromRequest(r)
if err != nil {
return fmt.Errorf("failed to get claims from request: %w", err)
}
userID, ok := getUserIDFromClaims(claims)
if !ok {
return fmt.Errorf("user id is not present in claims")
}
var servers []serverInfo
for v := range srv.relyingParties.All() {
servers = append(servers, serverInfo{
Name: v.Name,
Description: v.Description,
LogoURL: v.LogoURL,
URL: v.URL,
needsOauth: v.Config != nil,
})
}
servers, err = srv.checkHostsConnectedForUser(r.Context(), userID, servers)
if err != nil {
return fmt.Errorf("failed to check hosts connected for user %s: %w", userID, err)
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
w.Header().Set("Cache-Control", "no-store, no-cache, must-revalidate, max-age=0")
w.Header().Set("Pragma", "no-cache")
w.Header().Set("Expires", "0")
type response struct {
Servers []serverInfo `json:"servers"`
}
return json.NewEncoder(w).Encode(response{
Servers: servers,
})
}
func (srv *Handler) checkHostsConnectedForUser(
ctx context.Context,
userID string,
servers []serverInfo,
) ([]serverInfo, error) {
eg, ctx := errgroup.WithContext(ctx)
for i := range servers {
if !servers[i].needsOauth {
servers[i].Connected = true
continue
}
eg.Go(func() error {
_, err := srv.storage.GetUpstreamOAuth2Token(ctx, servers[i].host, userID)
if err != nil && status.Code(err) != codes.NotFound {
return fmt.Errorf("failed to get oauth2 token for user %s: %w", userID, err)
}
servers[i].Connected = err == nil
return nil
})
}
err := eg.Wait()
if err != nil {
return nil, fmt.Errorf("failed to check hosts connected for user %s: %w", userID, err)
}
return servers, nil
}
type serverInfo struct {
Name string `json:"name,omitempty"`
Description string `json:"description,omitempty"`
LogoURL string `json:"logo_url,omitempty"`
URL string `json:"url"`
Connected bool `json:"connected"`
needsOauth bool `json:"-"`
host string `json:"-"`
}

View file

@ -3,6 +3,8 @@ package mcp
import (
"context"
"fmt"
"iter"
"maps"
"net/http"
"net/url"
"path"
@ -22,7 +24,16 @@ type OAuth2Configs struct {
httpClient *http.Client
buildOnce sync.Once
perHost map[string]*oauth2.Config
perHost map[string]HostInfo
}
type HostInfo struct {
Name string
Description string
LogoURL string
Host string
URL string
Config *oauth2.Config
}
func NewOAuthConfig(
@ -43,39 +54,45 @@ func (r *OAuth2Configs) CodeExchangeForHost(
) (*oauth2.Token, error) {
r.buildOnce.Do(r.build)
cfg, ok := r.perHost[host]
if !ok {
if !ok || cfg.Config == nil {
return nil, fmt.Errorf("no oauth2 config for host %s", host)
}
return cfg.Exchange(ctx, code)
return cfg.Config.Exchange(ctx, code)
}
func (r *OAuth2Configs) HasConfigForHost(host string) bool {
func (r *OAuth2Configs) HasOAuth2ConfigForHost(host string) bool {
r.buildOnce.Do(r.build)
_, ok := r.perHost[host]
return ok
v, ok := r.perHost[host]
return ok && v.Config != nil
}
func (r *OAuth2Configs) GetLoginURLForHost(host string, state string) (string, bool) {
r.buildOnce.Do(r.build)
cfg, ok := r.perHost[host]
if !ok {
if !ok || cfg.Config == nil {
return "", false
}
return cfg.AuthCodeURL(state, oauth2.AccessTypeOffline), true
return cfg.Config.AuthCodeURL(state, oauth2.AccessTypeOffline), true
}
func (r *OAuth2Configs) All() iter.Seq[HostInfo] {
r.buildOnce.Do(r.build)
return maps.Values(r.perHost)
}
func (r *OAuth2Configs) build() {
r.perHost = BuildOAuthConfig(r.cfg, r.prefix)
r.perHost = BuildHostInfo(r.cfg, r.prefix)
}
// BuildOAuthConfig builds a map of OAuth2 configs per host
func BuildOAuthConfig(cfg *config.Config, prefix string) map[string]*oauth2.Config {
configs := make(map[string]*oauth2.Config)
// BuildHostInfo indexes all policies by host
// and builds the oauth2.Config for each host if present.
func BuildHostInfo(cfg *config.Config, prefix string) map[string]HostInfo {
info := make(map[string]HostInfo)
for policy := range cfg.Options.GetAllPolicies() {
if !policy.IsMCPServer() || policy.MCP.UpstreamOAuth2 == nil {
if !policy.IsMCPServer() {
continue
}
u, err := url.Parse(policy.GetFrom())
@ -83,10 +100,18 @@ func BuildOAuthConfig(cfg *config.Config, prefix string) map[string]*oauth2.Conf
continue
}
host := u.Hostname()
if _, ok := configs[host]; ok {
if _, ok := info[host]; ok {
continue
}
cfg := &oauth2.Config{
v := HostInfo{
Name: policy.Name,
Description: policy.Description,
LogoURL: policy.LogoURL,
Host: host,
URL: policy.GetFrom(),
}
if policy.MCP.UpstreamOAuth2 != nil {
v.Config = &oauth2.Config{
ClientID: policy.MCP.UpstreamOAuth2.ClientID,
ClientSecret: policy.MCP.UpstreamOAuth2.ClientSecret,
Endpoint: oauth2.Endpoint{
@ -101,9 +126,10 @@ func BuildOAuthConfig(cfg *config.Config, prefix string) map[string]*oauth2.Conf
}).String(),
Scopes: policy.MCP.UpstreamOAuth2.Scopes,
}
configs[host] = cfg
}
return configs
info[host] = v
}
return info
}
func authStyleEnum(o config.OAuth2EndpointAuthStyle) oauth2.AuthStyle {

View file

@ -17,13 +17,18 @@ func TestBuildOAuthConfig(t *testing.T) {
Options: &config.Options{
Policies: []config.Policy{
{
Name: "test",
From: "https://regular.example.com",
},
{
Name: "mcp-1",
Description: "description-1",
LogoURL: "https://logo.example.com",
From: "https://mcp1.example.com",
MCP: &config.MCP{},
},
{
Name: "mcp-2",
From: "https://mcp2.example.com",
MCP: &config.MCP{
UpstreamOAuth2: &config.UpstreamOAuth2{
@ -40,9 +45,20 @@ func TestBuildOAuthConfig(t *testing.T) {
},
},
}
got := mcp.BuildOAuthConfig(cfg, "/prefix")
diff := cmp.Diff(got, map[string]*oauth2.Config{
got := mcp.BuildHostInfo(cfg, "/prefix")
diff := cmp.Diff(got, map[string]mcp.HostInfo{
"mcp1.example.com": {
Name: "mcp-1",
Host: "mcp1.example.com",
URL: "https://mcp1.example.com",
Description: "description-1",
LogoURL: "https://logo.example.com",
},
"mcp2.example.com": {
Name: "mcp-2",
Host: "mcp2.example.com",
URL: "https://mcp2.example.com",
Config: &oauth2.Config{
ClientID: "client_id",
ClientSecret: "client_secret",
Endpoint: oauth2.Endpoint{
@ -52,6 +68,7 @@ func TestBuildOAuthConfig(t *testing.T) {
},
RedirectURL: "https://mcp2.example.com/prefix/oauth/callback",
},
},
}, cmpopts.IgnoreUnexported(oauth2.Config{}))
require.Empty(t, diff)
}