mirror of
https://github.com/pomerium/pomerium.git
synced 2025-06-06 21:04:39 +02:00
mcp: add list-routes client helper (#5596)
This commit is contained in:
parent
d2e2f56d57
commit
6caf65a117
8 changed files with 213 additions and 43 deletions
|
@ -293,6 +293,8 @@ var internalPathsNeedingLogin = set.From([]string{
|
|||
"/.pomerium/routes",
|
||||
"/.pomerium/api/v1/routes",
|
||||
"/.pomerium/mcp/authorize",
|
||||
"/.pomerium/mcp/routes",
|
||||
"/.pomerium/mcp/connect",
|
||||
})
|
||||
|
||||
func (e *Evaluator) evaluateInternal(_ context.Context, req *Request) (*PolicyResponse, error) {
|
||||
|
|
|
@ -29,7 +29,7 @@ import (
|
|||
// Policy contains route specific configuration and access settings.
|
||||
type Policy struct {
|
||||
ID string `mapstructure:"-" yaml:"-" json:"-"`
|
||||
Name string `mapstructure:"-" yaml:"-" json:"-"`
|
||||
Name string `mapstructure:"name" yaml:"-" json:"name,omitempty"`
|
||||
Description string `mapstructure:"description" yaml:"description,omitempty" json:"description,omitempty"`
|
||||
LogoURL string `mapstructure:"logo_url" yaml:"logo_url,omitempty" json:"logo_url,omitempty"`
|
||||
|
||||
|
|
|
@ -27,6 +27,8 @@ const (
|
|||
registerEndpoint = "/register"
|
||||
revocationEndpoint = "/revoke"
|
||||
tokenEndpoint = "/token"
|
||||
listRoutesEndpoint = "/routes"
|
||||
connectEndpoint = "/connect"
|
||||
)
|
||||
|
||||
type Handler struct {
|
||||
|
@ -78,6 +80,8 @@ func (srv *Handler) HandlerFunc() http.HandlerFunc {
|
|||
r.Path(path.Join(srv.prefix, authorizationEndpoint)).Methods(http.MethodGet).HandlerFunc(srv.Authorize)
|
||||
r.Path(path.Join(srv.prefix, oauthCallbackEndpoint)).Methods(http.MethodGet).HandlerFunc(srv.OAuthCallback)
|
||||
r.Path(path.Join(srv.prefix, tokenEndpoint)).Methods(http.MethodPost).HandlerFunc(srv.Token)
|
||||
r.Path(path.Join(srv.prefix, listRoutesEndpoint)).Methods(http.MethodGet).HandlerFunc(srv.ListRoutes)
|
||||
r.Path(path.Join(srv.prefix, connectEndpoint)).Methods(http.MethodGet).HandlerFunc(srv.Connect)
|
||||
|
||||
return r.ServeHTTP
|
||||
}
|
||||
|
|
|
@ -82,7 +82,7 @@ func (srv *Handler) Authorize(w http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
requiresUpstreamOAuth2Token := srv.relyingParties.HasConfigForHost(r.Host)
|
||||
requiresUpstreamOAuth2Token := srv.relyingParties.HasOAuth2ConfigForHost(r.Host)
|
||||
var authReqID string
|
||||
var hasUpstreamOAuth2Token bool
|
||||
{
|
||||
|
|
11
internal/mcp/handler_connect.go
Normal file
11
internal/mcp/handler_connect.go
Normal file
|
@ -0,0 +1,11 @@
|
|||
package mcp
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
)
|
||||
|
||||
// Connect is a helper method for MCP clients to ensure that the current user
|
||||
// has an active upstream Oauth2 session for the route.
|
||||
func (srv *Handler) Connect(w http.ResponseWriter, _ *http.Request) {
|
||||
http.Error(w, "not implemented", http.StatusNotImplemented)
|
||||
}
|
110
internal/mcp/handler_list_routes.go
Normal file
110
internal/mcp/handler_list_routes.go
Normal file
|
@ -0,0 +1,110 @@
|
|||
package mcp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"golang.org/x/sync/errgroup"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/log"
|
||||
)
|
||||
|
||||
// ListMCPServers returns a list of MCP servers that are registered,
|
||||
// and whether the current user has access to them.
|
||||
func (srv *Handler) ListRoutes(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodGet {
|
||||
http.Error(w, "invalid method", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
err := srv.listMCPServers(w, r)
|
||||
if err != nil {
|
||||
log.Ctx(r.Context()).Error().Err(err).Msg("failed to list MCP servers")
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func (srv *Handler) listMCPServers(w http.ResponseWriter, r *http.Request) error {
|
||||
claims, err := getClaimsFromRequest(r)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get claims from request: %w", err)
|
||||
}
|
||||
|
||||
userID, ok := getUserIDFromClaims(claims)
|
||||
if !ok {
|
||||
return fmt.Errorf("user id is not present in claims")
|
||||
}
|
||||
|
||||
var servers []serverInfo
|
||||
for v := range srv.relyingParties.All() {
|
||||
servers = append(servers, serverInfo{
|
||||
Name: v.Name,
|
||||
Description: v.Description,
|
||||
LogoURL: v.LogoURL,
|
||||
URL: v.URL,
|
||||
needsOauth: v.Config != nil,
|
||||
})
|
||||
}
|
||||
|
||||
servers, err = srv.checkHostsConnectedForUser(r.Context(), userID, servers)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to check hosts connected for user %s: %w", userID, err)
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Header().Set("Cache-Control", "no-store, no-cache, must-revalidate, max-age=0")
|
||||
w.Header().Set("Pragma", "no-cache")
|
||||
w.Header().Set("Expires", "0")
|
||||
|
||||
type response struct {
|
||||
Servers []serverInfo `json:"servers"`
|
||||
}
|
||||
|
||||
return json.NewEncoder(w).Encode(response{
|
||||
Servers: servers,
|
||||
})
|
||||
}
|
||||
|
||||
func (srv *Handler) checkHostsConnectedForUser(
|
||||
ctx context.Context,
|
||||
userID string,
|
||||
servers []serverInfo,
|
||||
) ([]serverInfo, error) {
|
||||
eg, ctx := errgroup.WithContext(ctx)
|
||||
for i := range servers {
|
||||
if !servers[i].needsOauth {
|
||||
servers[i].Connected = true
|
||||
continue
|
||||
}
|
||||
eg.Go(func() error {
|
||||
_, err := srv.storage.GetUpstreamOAuth2Token(ctx, servers[i].host, userID)
|
||||
if err != nil && status.Code(err) != codes.NotFound {
|
||||
return fmt.Errorf("failed to get oauth2 token for user %s: %w", userID, err)
|
||||
}
|
||||
servers[i].Connected = err == nil
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
err := eg.Wait()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to check hosts connected for user %s: %w", userID, err)
|
||||
}
|
||||
return servers, nil
|
||||
}
|
||||
|
||||
type serverInfo struct {
|
||||
Name string `json:"name,omitempty"`
|
||||
Description string `json:"description,omitempty"`
|
||||
LogoURL string `json:"logo_url,omitempty"`
|
||||
URL string `json:"url"`
|
||||
Connected bool `json:"connected"`
|
||||
needsOauth bool `json:"-"`
|
||||
host string `json:"-"`
|
||||
}
|
|
@ -3,6 +3,8 @@ package mcp
|
|||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"iter"
|
||||
"maps"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"path"
|
||||
|
@ -22,7 +24,16 @@ type OAuth2Configs struct {
|
|||
httpClient *http.Client
|
||||
|
||||
buildOnce sync.Once
|
||||
perHost map[string]*oauth2.Config
|
||||
perHost map[string]HostInfo
|
||||
}
|
||||
|
||||
type HostInfo struct {
|
||||
Name string
|
||||
Description string
|
||||
LogoURL string
|
||||
Host string
|
||||
URL string
|
||||
Config *oauth2.Config
|
||||
}
|
||||
|
||||
func NewOAuthConfig(
|
||||
|
@ -43,39 +54,45 @@ func (r *OAuth2Configs) CodeExchangeForHost(
|
|||
) (*oauth2.Token, error) {
|
||||
r.buildOnce.Do(r.build)
|
||||
cfg, ok := r.perHost[host]
|
||||
if !ok {
|
||||
if !ok || cfg.Config == nil {
|
||||
return nil, fmt.Errorf("no oauth2 config for host %s", host)
|
||||
}
|
||||
|
||||
return cfg.Exchange(ctx, code)
|
||||
return cfg.Config.Exchange(ctx, code)
|
||||
}
|
||||
|
||||
func (r *OAuth2Configs) HasConfigForHost(host string) bool {
|
||||
func (r *OAuth2Configs) HasOAuth2ConfigForHost(host string) bool {
|
||||
r.buildOnce.Do(r.build)
|
||||
_, ok := r.perHost[host]
|
||||
return ok
|
||||
v, ok := r.perHost[host]
|
||||
return ok && v.Config != nil
|
||||
}
|
||||
|
||||
func (r *OAuth2Configs) GetLoginURLForHost(host string, state string) (string, bool) {
|
||||
r.buildOnce.Do(r.build)
|
||||
|
||||
cfg, ok := r.perHost[host]
|
||||
if !ok {
|
||||
if !ok || cfg.Config == nil {
|
||||
return "", false
|
||||
}
|
||||
|
||||
return cfg.AuthCodeURL(state, oauth2.AccessTypeOffline), true
|
||||
return cfg.Config.AuthCodeURL(state, oauth2.AccessTypeOffline), true
|
||||
}
|
||||
|
||||
func (r *OAuth2Configs) All() iter.Seq[HostInfo] {
|
||||
r.buildOnce.Do(r.build)
|
||||
return maps.Values(r.perHost)
|
||||
}
|
||||
|
||||
func (r *OAuth2Configs) build() {
|
||||
r.perHost = BuildOAuthConfig(r.cfg, r.prefix)
|
||||
r.perHost = BuildHostInfo(r.cfg, r.prefix)
|
||||
}
|
||||
|
||||
// BuildOAuthConfig builds a map of OAuth2 configs per host
|
||||
func BuildOAuthConfig(cfg *config.Config, prefix string) map[string]*oauth2.Config {
|
||||
configs := make(map[string]*oauth2.Config)
|
||||
// BuildHostInfo indexes all policies by host
|
||||
// and builds the oauth2.Config for each host if present.
|
||||
func BuildHostInfo(cfg *config.Config, prefix string) map[string]HostInfo {
|
||||
info := make(map[string]HostInfo)
|
||||
for policy := range cfg.Options.GetAllPolicies() {
|
||||
if !policy.IsMCPServer() || policy.MCP.UpstreamOAuth2 == nil {
|
||||
if !policy.IsMCPServer() {
|
||||
continue
|
||||
}
|
||||
u, err := url.Parse(policy.GetFrom())
|
||||
|
@ -83,10 +100,18 @@ func BuildOAuthConfig(cfg *config.Config, prefix string) map[string]*oauth2.Conf
|
|||
continue
|
||||
}
|
||||
host := u.Hostname()
|
||||
if _, ok := configs[host]; ok {
|
||||
if _, ok := info[host]; ok {
|
||||
continue
|
||||
}
|
||||
cfg := &oauth2.Config{
|
||||
v := HostInfo{
|
||||
Name: policy.Name,
|
||||
Description: policy.Description,
|
||||
LogoURL: policy.LogoURL,
|
||||
Host: host,
|
||||
URL: policy.GetFrom(),
|
||||
}
|
||||
if policy.MCP.UpstreamOAuth2 != nil {
|
||||
v.Config = &oauth2.Config{
|
||||
ClientID: policy.MCP.UpstreamOAuth2.ClientID,
|
||||
ClientSecret: policy.MCP.UpstreamOAuth2.ClientSecret,
|
||||
Endpoint: oauth2.Endpoint{
|
||||
|
@ -101,9 +126,10 @@ func BuildOAuthConfig(cfg *config.Config, prefix string) map[string]*oauth2.Conf
|
|||
}).String(),
|
||||
Scopes: policy.MCP.UpstreamOAuth2.Scopes,
|
||||
}
|
||||
configs[host] = cfg
|
||||
}
|
||||
return configs
|
||||
info[host] = v
|
||||
}
|
||||
return info
|
||||
}
|
||||
|
||||
func authStyleEnum(o config.OAuth2EndpointAuthStyle) oauth2.AuthStyle {
|
|
@ -17,13 +17,18 @@ func TestBuildOAuthConfig(t *testing.T) {
|
|||
Options: &config.Options{
|
||||
Policies: []config.Policy{
|
||||
{
|
||||
Name: "test",
|
||||
From: "https://regular.example.com",
|
||||
},
|
||||
{
|
||||
Name: "mcp-1",
|
||||
Description: "description-1",
|
||||
LogoURL: "https://logo.example.com",
|
||||
From: "https://mcp1.example.com",
|
||||
MCP: &config.MCP{},
|
||||
},
|
||||
{
|
||||
Name: "mcp-2",
|
||||
From: "https://mcp2.example.com",
|
||||
MCP: &config.MCP{
|
||||
UpstreamOAuth2: &config.UpstreamOAuth2{
|
||||
|
@ -40,9 +45,20 @@ func TestBuildOAuthConfig(t *testing.T) {
|
|||
},
|
||||
},
|
||||
}
|
||||
got := mcp.BuildOAuthConfig(cfg, "/prefix")
|
||||
diff := cmp.Diff(got, map[string]*oauth2.Config{
|
||||
got := mcp.BuildHostInfo(cfg, "/prefix")
|
||||
diff := cmp.Diff(got, map[string]mcp.HostInfo{
|
||||
"mcp1.example.com": {
|
||||
Name: "mcp-1",
|
||||
Host: "mcp1.example.com",
|
||||
URL: "https://mcp1.example.com",
|
||||
Description: "description-1",
|
||||
LogoURL: "https://logo.example.com",
|
||||
},
|
||||
"mcp2.example.com": {
|
||||
Name: "mcp-2",
|
||||
Host: "mcp2.example.com",
|
||||
URL: "https://mcp2.example.com",
|
||||
Config: &oauth2.Config{
|
||||
ClientID: "client_id",
|
||||
ClientSecret: "client_secret",
|
||||
Endpoint: oauth2.Endpoint{
|
||||
|
@ -52,6 +68,7 @@ func TestBuildOAuthConfig(t *testing.T) {
|
|||
},
|
||||
RedirectURL: "https://mcp2.example.com/prefix/oauth/callback",
|
||||
},
|
||||
},
|
||||
}, cmpopts.IgnoreUnexported(oauth2.Config{}))
|
||||
require.Empty(t, diff)
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue