diff --git a/internal/mcp/handler.go b/internal/mcp/handler.go index 47e284d21..0ecb86cfa 100644 --- a/internal/mcp/handler.go +++ b/internal/mcp/handler.go @@ -1,5 +1,18 @@ package mcp +import ( + "context" + "net/http" + "path" + + "github.com/gorilla/mux" + "github.com/rs/cors" + oteltrace "go.opentelemetry.io/otel/trace" + + "github.com/pomerium/pomerium/config" + "github.com/pomerium/pomerium/pkg/telemetry/trace" +) + const ( DefaultPrefix = "/.pomerium/mcp" @@ -9,3 +22,39 @@ const ( revocationEndpoint = "/revoke" tokenEndpoint = "/token" ) + +type Handler struct { + prefix string + trace oteltrace.TracerProvider +} + +func New( + ctx context.Context, + prefix string, + _ *config.Config, +) (*Handler, error) { + tracerProvider := trace.NewTracerProvider(ctx, "MCP") + return &Handler{ + prefix: prefix, + trace: tracerProvider, + }, nil +} + +// HandlerFunc returns a http.HandlerFunc that handles the mcp endpoints. +func (srv *Handler) HandlerFunc() http.HandlerFunc { + r := mux.NewRouter() + r.Use(cors.New(cors.Options{ + AllowedMethods: []string{http.MethodGet, http.MethodPost, http.MethodOptions}, + AllowedOrigins: []string{"*"}, + AllowedHeaders: []string{"content-type", "mcp-protocol-version"}, + }).Handler) + r.Methods(http.MethodOptions).HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusNoContent) + }) + r.Path(path.Join(srv.prefix, registerEndpoint)).Methods(http.MethodPost).HandlerFunc(srv.RegisterClient) + r.Path(path.Join(srv.prefix, authorizationEndpoint)).Methods(http.MethodGet).HandlerFunc(srv.Authorize) + r.Path(path.Join(srv.prefix, oauthCallbackEndpoint)).Methods(http.MethodGet).HandlerFunc(srv.OAuthCallback) + r.Path(path.Join(srv.prefix, tokenEndpoint)).Methods(http.MethodPost).HandlerFunc(srv.Token) + + return r.ServeHTTP +} diff --git a/internal/mcp/handler_authorization.go b/internal/mcp/handler_authorization.go new file mode 100644 index 000000000..1197c99e6 --- /dev/null +++ b/internal/mcp/handler_authorization.go @@ -0,0 +1,10 @@ +package mcp + +import ( + "net/http" +) + +// Authorize handles the /authorize endpoint. +func (srv *Handler) Authorize(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusNotImplemented) +} diff --git a/internal/mcp/handler_oauth_callback.go b/internal/mcp/handler_oauth_callback.go new file mode 100644 index 000000000..911a7aef4 --- /dev/null +++ b/internal/mcp/handler_oauth_callback.go @@ -0,0 +1,9 @@ +package mcp + +import ( + "net/http" +) + +func (srv *Handler) OAuthCallback(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusNotImplemented) +} diff --git a/internal/mcp/handler_register_client.go b/internal/mcp/handler_register_client.go new file mode 100644 index 000000000..c97b07a6c --- /dev/null +++ b/internal/mcp/handler_register_client.go @@ -0,0 +1,11 @@ +package mcp + +import ( + "net/http" +) + +// RegisterClient handles the /register endpoint. +// It is used to register a new client with the MCP server. +func (srv *Handler) RegisterClient(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusNotImplemented) +} diff --git a/internal/mcp/handler_token.go b/internal/mcp/handler_token.go new file mode 100644 index 000000000..673028766 --- /dev/null +++ b/internal/mcp/handler_token.go @@ -0,0 +1,10 @@ +package mcp + +import ( + "net/http" +) + +// Token handles the /token endpoint. +func (srv *Handler) Token(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusNotImplemented) +} diff --git a/proxy/handlers.go b/proxy/handlers.go index de51ae37e..f5498aed2 100644 --- a/proxy/handlers.go +++ b/proxy/handlers.go @@ -24,6 +24,9 @@ func (p *Proxy) registerDashboardHandlers(r *mux.Router, opts *config.Options) * h := httputil.DashboardSubrouter(r) h.Use(middleware.SetHeaders(httputil.HeadersContentSecurityPolicy)) + // model context protocol + h.PathPrefix("/mcp").Handler(p.mcp.HandlerFunc()) + // special pomerium endpoints for users to view their session h.Path("/").Handler(httputil.HandlerFunc(p.userInfo)).Methods(http.MethodGet) h.Path("/device-enrolled").Handler(httputil.HandlerFunc(p.deviceEnrolled)) diff --git a/proxy/proxy.go b/proxy/proxy.go index a6681da9f..064a87d58 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -18,6 +18,7 @@ import ( "github.com/pomerium/pomerium/internal/handlers/webauthn" "github.com/pomerium/pomerium/internal/httputil" "github.com/pomerium/pomerium/internal/log" + "github.com/pomerium/pomerium/internal/mcp" "github.com/pomerium/pomerium/internal/telemetry/metrics" "github.com/pomerium/pomerium/pkg/cryptutil" "github.com/pomerium/pomerium/pkg/storage" @@ -63,6 +64,7 @@ type Proxy struct { webauthn *webauthn.Handler tracerProvider oteltrace.TracerProvider logoProvider portal.LogoProvider + mcp *mcp.Handler } // New takes a Proxy service from options and a validation function. @@ -74,12 +76,18 @@ func New(ctx context.Context, cfg *config.Config) (*Proxy, error) { return nil, err } + mcp, err := mcp.New(ctx, mcp.DefaultPrefix, cfg) + if err != nil { + return nil, fmt.Errorf("proxy: failed to create mcp handler: %w", err) + } + p := &Proxy{ tracerProvider: tracerProvider, state: atomicutil.NewValue(state), currentConfig: atomicutil.NewValue(&config.Config{Options: config.NewDefaultOptions()}), currentRouter: atomicutil.NewValue(httputil.NewRouter()), logoProvider: portal.NewLogoProvider(), + mcp: mcp, } p.OnConfigChange(ctx, cfg) p.webauthn = webauthn.New(p.getWebauthnState)