mirror of
https://github.com/pomerium/pomerium.git
synced 2025-06-11 15:22:46 +02:00
mcp: add list-routes client helper (#5596)
This commit is contained in:
parent
d2e2f56d57
commit
6caf65a117
8 changed files with 213 additions and 43 deletions
|
@ -293,6 +293,8 @@ var internalPathsNeedingLogin = set.From([]string{
|
||||||
"/.pomerium/routes",
|
"/.pomerium/routes",
|
||||||
"/.pomerium/api/v1/routes",
|
"/.pomerium/api/v1/routes",
|
||||||
"/.pomerium/mcp/authorize",
|
"/.pomerium/mcp/authorize",
|
||||||
|
"/.pomerium/mcp/routes",
|
||||||
|
"/.pomerium/mcp/connect",
|
||||||
})
|
})
|
||||||
|
|
||||||
func (e *Evaluator) evaluateInternal(_ context.Context, req *Request) (*PolicyResponse, error) {
|
func (e *Evaluator) evaluateInternal(_ context.Context, req *Request) (*PolicyResponse, error) {
|
||||||
|
|
|
@ -29,7 +29,7 @@ import (
|
||||||
// Policy contains route specific configuration and access settings.
|
// Policy contains route specific configuration and access settings.
|
||||||
type Policy struct {
|
type Policy struct {
|
||||||
ID string `mapstructure:"-" yaml:"-" json:"-"`
|
ID string `mapstructure:"-" yaml:"-" json:"-"`
|
||||||
Name string `mapstructure:"-" yaml:"-" json:"-"`
|
Name string `mapstructure:"name" yaml:"-" json:"name,omitempty"`
|
||||||
Description string `mapstructure:"description" yaml:"description,omitempty" json:"description,omitempty"`
|
Description string `mapstructure:"description" yaml:"description,omitempty" json:"description,omitempty"`
|
||||||
LogoURL string `mapstructure:"logo_url" yaml:"logo_url,omitempty" json:"logo_url,omitempty"`
|
LogoURL string `mapstructure:"logo_url" yaml:"logo_url,omitempty" json:"logo_url,omitempty"`
|
||||||
|
|
||||||
|
|
|
@ -27,6 +27,8 @@ const (
|
||||||
registerEndpoint = "/register"
|
registerEndpoint = "/register"
|
||||||
revocationEndpoint = "/revoke"
|
revocationEndpoint = "/revoke"
|
||||||
tokenEndpoint = "/token"
|
tokenEndpoint = "/token"
|
||||||
|
listRoutesEndpoint = "/routes"
|
||||||
|
connectEndpoint = "/connect"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Handler struct {
|
type Handler struct {
|
||||||
|
@ -78,6 +80,8 @@ func (srv *Handler) HandlerFunc() http.HandlerFunc {
|
||||||
r.Path(path.Join(srv.prefix, authorizationEndpoint)).Methods(http.MethodGet).HandlerFunc(srv.Authorize)
|
r.Path(path.Join(srv.prefix, authorizationEndpoint)).Methods(http.MethodGet).HandlerFunc(srv.Authorize)
|
||||||
r.Path(path.Join(srv.prefix, oauthCallbackEndpoint)).Methods(http.MethodGet).HandlerFunc(srv.OAuthCallback)
|
r.Path(path.Join(srv.prefix, oauthCallbackEndpoint)).Methods(http.MethodGet).HandlerFunc(srv.OAuthCallback)
|
||||||
r.Path(path.Join(srv.prefix, tokenEndpoint)).Methods(http.MethodPost).HandlerFunc(srv.Token)
|
r.Path(path.Join(srv.prefix, tokenEndpoint)).Methods(http.MethodPost).HandlerFunc(srv.Token)
|
||||||
|
r.Path(path.Join(srv.prefix, listRoutesEndpoint)).Methods(http.MethodGet).HandlerFunc(srv.ListRoutes)
|
||||||
|
r.Path(path.Join(srv.prefix, connectEndpoint)).Methods(http.MethodGet).HandlerFunc(srv.Connect)
|
||||||
|
|
||||||
return r.ServeHTTP
|
return r.ServeHTTP
|
||||||
}
|
}
|
||||||
|
|
|
@ -82,7 +82,7 @@ func (srv *Handler) Authorize(w http.ResponseWriter, r *http.Request) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
requiresUpstreamOAuth2Token := srv.relyingParties.HasConfigForHost(r.Host)
|
requiresUpstreamOAuth2Token := srv.relyingParties.HasOAuth2ConfigForHost(r.Host)
|
||||||
var authReqID string
|
var authReqID string
|
||||||
var hasUpstreamOAuth2Token bool
|
var hasUpstreamOAuth2Token bool
|
||||||
{
|
{
|
||||||
|
|
11
internal/mcp/handler_connect.go
Normal file
11
internal/mcp/handler_connect.go
Normal file
|
@ -0,0 +1,11 @@
|
||||||
|
package mcp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Connect is a helper method for MCP clients to ensure that the current user
|
||||||
|
// has an active upstream Oauth2 session for the route.
|
||||||
|
func (srv *Handler) Connect(w http.ResponseWriter, _ *http.Request) {
|
||||||
|
http.Error(w, "not implemented", http.StatusNotImplemented)
|
||||||
|
}
|
110
internal/mcp/handler_list_routes.go
Normal file
110
internal/mcp/handler_list_routes.go
Normal file
|
@ -0,0 +1,110 @@
|
||||||
|
package mcp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"golang.org/x/sync/errgroup"
|
||||||
|
"google.golang.org/grpc/codes"
|
||||||
|
"google.golang.org/grpc/status"
|
||||||
|
|
||||||
|
"github.com/pomerium/pomerium/internal/log"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ListMCPServers returns a list of MCP servers that are registered,
|
||||||
|
// and whether the current user has access to them.
|
||||||
|
func (srv *Handler) ListRoutes(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.Method != http.MethodGet {
|
||||||
|
http.Error(w, "invalid method", http.StatusMethodNotAllowed)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
err := srv.listMCPServers(w, r)
|
||||||
|
if err != nil {
|
||||||
|
log.Ctx(r.Context()).Error().Err(err).Msg("failed to list MCP servers")
|
||||||
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (srv *Handler) listMCPServers(w http.ResponseWriter, r *http.Request) error {
|
||||||
|
claims, err := getClaimsFromRequest(r)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to get claims from request: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
userID, ok := getUserIDFromClaims(claims)
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf("user id is not present in claims")
|
||||||
|
}
|
||||||
|
|
||||||
|
var servers []serverInfo
|
||||||
|
for v := range srv.relyingParties.All() {
|
||||||
|
servers = append(servers, serverInfo{
|
||||||
|
Name: v.Name,
|
||||||
|
Description: v.Description,
|
||||||
|
LogoURL: v.LogoURL,
|
||||||
|
URL: v.URL,
|
||||||
|
needsOauth: v.Config != nil,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
servers, err = srv.checkHostsConnectedForUser(r.Context(), userID, servers)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to check hosts connected for user %s: %w", userID, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
w.Header().Set("Cache-Control", "no-store, no-cache, must-revalidate, max-age=0")
|
||||||
|
w.Header().Set("Pragma", "no-cache")
|
||||||
|
w.Header().Set("Expires", "0")
|
||||||
|
|
||||||
|
type response struct {
|
||||||
|
Servers []serverInfo `json:"servers"`
|
||||||
|
}
|
||||||
|
|
||||||
|
return json.NewEncoder(w).Encode(response{
|
||||||
|
Servers: servers,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (srv *Handler) checkHostsConnectedForUser(
|
||||||
|
ctx context.Context,
|
||||||
|
userID string,
|
||||||
|
servers []serverInfo,
|
||||||
|
) ([]serverInfo, error) {
|
||||||
|
eg, ctx := errgroup.WithContext(ctx)
|
||||||
|
for i := range servers {
|
||||||
|
if !servers[i].needsOauth {
|
||||||
|
servers[i].Connected = true
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
eg.Go(func() error {
|
||||||
|
_, err := srv.storage.GetUpstreamOAuth2Token(ctx, servers[i].host, userID)
|
||||||
|
if err != nil && status.Code(err) != codes.NotFound {
|
||||||
|
return fmt.Errorf("failed to get oauth2 token for user %s: %w", userID, err)
|
||||||
|
}
|
||||||
|
servers[i].Connected = err == nil
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
err := eg.Wait()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to check hosts connected for user %s: %w", userID, err)
|
||||||
|
}
|
||||||
|
return servers, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type serverInfo struct {
|
||||||
|
Name string `json:"name,omitempty"`
|
||||||
|
Description string `json:"description,omitempty"`
|
||||||
|
LogoURL string `json:"logo_url,omitempty"`
|
||||||
|
URL string `json:"url"`
|
||||||
|
Connected bool `json:"connected"`
|
||||||
|
needsOauth bool `json:"-"`
|
||||||
|
host string `json:"-"`
|
||||||
|
}
|
|
@ -3,6 +3,8 @@ package mcp
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"iter"
|
||||||
|
"maps"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"path"
|
"path"
|
||||||
|
@ -22,7 +24,16 @@ type OAuth2Configs struct {
|
||||||
httpClient *http.Client
|
httpClient *http.Client
|
||||||
|
|
||||||
buildOnce sync.Once
|
buildOnce sync.Once
|
||||||
perHost map[string]*oauth2.Config
|
perHost map[string]HostInfo
|
||||||
|
}
|
||||||
|
|
||||||
|
type HostInfo struct {
|
||||||
|
Name string
|
||||||
|
Description string
|
||||||
|
LogoURL string
|
||||||
|
Host string
|
||||||
|
URL string
|
||||||
|
Config *oauth2.Config
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewOAuthConfig(
|
func NewOAuthConfig(
|
||||||
|
@ -43,39 +54,45 @@ func (r *OAuth2Configs) CodeExchangeForHost(
|
||||||
) (*oauth2.Token, error) {
|
) (*oauth2.Token, error) {
|
||||||
r.buildOnce.Do(r.build)
|
r.buildOnce.Do(r.build)
|
||||||
cfg, ok := r.perHost[host]
|
cfg, ok := r.perHost[host]
|
||||||
if !ok {
|
if !ok || cfg.Config == nil {
|
||||||
return nil, fmt.Errorf("no oauth2 config for host %s", host)
|
return nil, fmt.Errorf("no oauth2 config for host %s", host)
|
||||||
}
|
}
|
||||||
|
|
||||||
return cfg.Exchange(ctx, code)
|
return cfg.Config.Exchange(ctx, code)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *OAuth2Configs) HasConfigForHost(host string) bool {
|
func (r *OAuth2Configs) HasOAuth2ConfigForHost(host string) bool {
|
||||||
r.buildOnce.Do(r.build)
|
r.buildOnce.Do(r.build)
|
||||||
_, ok := r.perHost[host]
|
v, ok := r.perHost[host]
|
||||||
return ok
|
return ok && v.Config != nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *OAuth2Configs) GetLoginURLForHost(host string, state string) (string, bool) {
|
func (r *OAuth2Configs) GetLoginURLForHost(host string, state string) (string, bool) {
|
||||||
r.buildOnce.Do(r.build)
|
r.buildOnce.Do(r.build)
|
||||||
|
|
||||||
cfg, ok := r.perHost[host]
|
cfg, ok := r.perHost[host]
|
||||||
if !ok {
|
if !ok || cfg.Config == nil {
|
||||||
return "", false
|
return "", false
|
||||||
}
|
}
|
||||||
|
|
||||||
return cfg.AuthCodeURL(state, oauth2.AccessTypeOffline), true
|
return cfg.Config.AuthCodeURL(state, oauth2.AccessTypeOffline), true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *OAuth2Configs) All() iter.Seq[HostInfo] {
|
||||||
|
r.buildOnce.Do(r.build)
|
||||||
|
return maps.Values(r.perHost)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *OAuth2Configs) build() {
|
func (r *OAuth2Configs) build() {
|
||||||
r.perHost = BuildOAuthConfig(r.cfg, r.prefix)
|
r.perHost = BuildHostInfo(r.cfg, r.prefix)
|
||||||
}
|
}
|
||||||
|
|
||||||
// BuildOAuthConfig builds a map of OAuth2 configs per host
|
// BuildHostInfo indexes all policies by host
|
||||||
func BuildOAuthConfig(cfg *config.Config, prefix string) map[string]*oauth2.Config {
|
// and builds the oauth2.Config for each host if present.
|
||||||
configs := make(map[string]*oauth2.Config)
|
func BuildHostInfo(cfg *config.Config, prefix string) map[string]HostInfo {
|
||||||
|
info := make(map[string]HostInfo)
|
||||||
for policy := range cfg.Options.GetAllPolicies() {
|
for policy := range cfg.Options.GetAllPolicies() {
|
||||||
if !policy.IsMCPServer() || policy.MCP.UpstreamOAuth2 == nil {
|
if !policy.IsMCPServer() {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
u, err := url.Parse(policy.GetFrom())
|
u, err := url.Parse(policy.GetFrom())
|
||||||
|
@ -83,10 +100,18 @@ func BuildOAuthConfig(cfg *config.Config, prefix string) map[string]*oauth2.Conf
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
host := u.Hostname()
|
host := u.Hostname()
|
||||||
if _, ok := configs[host]; ok {
|
if _, ok := info[host]; ok {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
cfg := &oauth2.Config{
|
v := HostInfo{
|
||||||
|
Name: policy.Name,
|
||||||
|
Description: policy.Description,
|
||||||
|
LogoURL: policy.LogoURL,
|
||||||
|
Host: host,
|
||||||
|
URL: policy.GetFrom(),
|
||||||
|
}
|
||||||
|
if policy.MCP.UpstreamOAuth2 != nil {
|
||||||
|
v.Config = &oauth2.Config{
|
||||||
ClientID: policy.MCP.UpstreamOAuth2.ClientID,
|
ClientID: policy.MCP.UpstreamOAuth2.ClientID,
|
||||||
ClientSecret: policy.MCP.UpstreamOAuth2.ClientSecret,
|
ClientSecret: policy.MCP.UpstreamOAuth2.ClientSecret,
|
||||||
Endpoint: oauth2.Endpoint{
|
Endpoint: oauth2.Endpoint{
|
||||||
|
@ -101,9 +126,10 @@ func BuildOAuthConfig(cfg *config.Config, prefix string) map[string]*oauth2.Conf
|
||||||
}).String(),
|
}).String(),
|
||||||
Scopes: policy.MCP.UpstreamOAuth2.Scopes,
|
Scopes: policy.MCP.UpstreamOAuth2.Scopes,
|
||||||
}
|
}
|
||||||
configs[host] = cfg
|
|
||||||
}
|
}
|
||||||
return configs
|
info[host] = v
|
||||||
|
}
|
||||||
|
return info
|
||||||
}
|
}
|
||||||
|
|
||||||
func authStyleEnum(o config.OAuth2EndpointAuthStyle) oauth2.AuthStyle {
|
func authStyleEnum(o config.OAuth2EndpointAuthStyle) oauth2.AuthStyle {
|
|
@ -17,13 +17,18 @@ func TestBuildOAuthConfig(t *testing.T) {
|
||||||
Options: &config.Options{
|
Options: &config.Options{
|
||||||
Policies: []config.Policy{
|
Policies: []config.Policy{
|
||||||
{
|
{
|
||||||
|
Name: "test",
|
||||||
From: "https://regular.example.com",
|
From: "https://regular.example.com",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
Name: "mcp-1",
|
||||||
|
Description: "description-1",
|
||||||
|
LogoURL: "https://logo.example.com",
|
||||||
From: "https://mcp1.example.com",
|
From: "https://mcp1.example.com",
|
||||||
MCP: &config.MCP{},
|
MCP: &config.MCP{},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
Name: "mcp-2",
|
||||||
From: "https://mcp2.example.com",
|
From: "https://mcp2.example.com",
|
||||||
MCP: &config.MCP{
|
MCP: &config.MCP{
|
||||||
UpstreamOAuth2: &config.UpstreamOAuth2{
|
UpstreamOAuth2: &config.UpstreamOAuth2{
|
||||||
|
@ -40,9 +45,20 @@ func TestBuildOAuthConfig(t *testing.T) {
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
got := mcp.BuildOAuthConfig(cfg, "/prefix")
|
got := mcp.BuildHostInfo(cfg, "/prefix")
|
||||||
diff := cmp.Diff(got, map[string]*oauth2.Config{
|
diff := cmp.Diff(got, map[string]mcp.HostInfo{
|
||||||
|
"mcp1.example.com": {
|
||||||
|
Name: "mcp-1",
|
||||||
|
Host: "mcp1.example.com",
|
||||||
|
URL: "https://mcp1.example.com",
|
||||||
|
Description: "description-1",
|
||||||
|
LogoURL: "https://logo.example.com",
|
||||||
|
},
|
||||||
"mcp2.example.com": {
|
"mcp2.example.com": {
|
||||||
|
Name: "mcp-2",
|
||||||
|
Host: "mcp2.example.com",
|
||||||
|
URL: "https://mcp2.example.com",
|
||||||
|
Config: &oauth2.Config{
|
||||||
ClientID: "client_id",
|
ClientID: "client_id",
|
||||||
ClientSecret: "client_secret",
|
ClientSecret: "client_secret",
|
||||||
Endpoint: oauth2.Endpoint{
|
Endpoint: oauth2.Endpoint{
|
||||||
|
@ -52,6 +68,7 @@ func TestBuildOAuthConfig(t *testing.T) {
|
||||||
},
|
},
|
||||||
RedirectURL: "https://mcp2.example.com/prefix/oauth/callback",
|
RedirectURL: "https://mcp2.example.com/prefix/oauth/callback",
|
||||||
},
|
},
|
||||||
|
},
|
||||||
}, cmpopts.IgnoreUnexported(oauth2.Config{}))
|
}, cmpopts.IgnoreUnexported(oauth2.Config{}))
|
||||||
require.Empty(t, diff)
|
require.Empty(t, diff)
|
||||||
}
|
}
|
Loading…
Add table
Add a link
Reference in a new issue