mirror of
https://github.com/pomerium/pomerium.git
synced 2025-07-07 20:08:17 +02:00
## Summary Adds an implementation of the `/.pomerium/mcp/connect` method, which takes a `redirect_url` parameter and ensures the user goes through the required redirects so that their session is hydrated with the upstream OAuth token for the MCP server. The host of the `redirect_url` parameter must match one of the _client_ MCP routes (currently identified by the presence of `mcp: pass_upstream_access_token: true` in the route). ## Related issues Fix https://linear.app/pomerium/issue/ENG-2321/mcp-support-handling-external-oauth-servers ## User Explanation <!-- How would you explain this change to the user? If this change doesn't create any user-facing changes, you can leave this blank. If filled out, add the `docs` label --> ## Checklist - [x] reference any related issues - [x] updated unit tests - [ ] add appropriate label (`enhancement`, `bug`, `breaking`, `dependencies`, `ci`) - [x] ready for review
113 lines
3.5 KiB
Go
113 lines
3.5 KiB
Go
package mcp
|
|
|
|
import (
|
|
"context"
|
|
"crypto/cipher"
|
|
"fmt"
|
|
"net/http"
|
|
"path"
|
|
|
|
"github.com/gorilla/mux"
|
|
"github.com/rs/cors"
|
|
"go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc"
|
|
oteltrace "go.opentelemetry.io/otel/trace"
|
|
"golang.org/x/sync/singleflight"
|
|
googlegrpc "google.golang.org/grpc"
|
|
|
|
"github.com/pomerium/pomerium/config"
|
|
"github.com/pomerium/pomerium/pkg/grpc"
|
|
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
|
"github.com/pomerium/pomerium/pkg/telemetry/trace"
|
|
)
|
|
|
|
// DefaultPrefix is the default URL path prefix for the MCP endpoints.
const DefaultPrefix = "/.pomerium/mcp"

// Endpoint paths. HandlerFunc joins these onto the handler's prefix when it
// registers routes.
const (
	authorizationEndpoint = "/authorize"
	oauthCallbackEndpoint = "/oauth/callback"
	registerEndpoint      = "/register"
	revocationEndpoint    = "/revoke"
	tokenEndpoint         = "/token"
	listRoutesEndpoint    = "/routes"
	connectEndpoint       = "/connect"
)
|
|
|
|
type Handler struct {
|
|
prefix string
|
|
trace oteltrace.TracerProvider
|
|
storage *Storage
|
|
cipher cipher.AEAD
|
|
hosts *HostInfo
|
|
hostsSingleFlight singleflight.Group
|
|
}
|
|
|
|
func New(
|
|
ctx context.Context,
|
|
prefix string,
|
|
cfg *config.Config,
|
|
) (*Handler, error) {
|
|
tracerProvider := trace.NewTracerProvider(ctx, "MCP")
|
|
|
|
client, err := getDatabrokerServiceClient(ctx, cfg, tracerProvider)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("databroker client: %w", err)
|
|
}
|
|
|
|
cipher, err := getCipher(cfg)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("get cipher: %w", err)
|
|
}
|
|
|
|
return &Handler{
|
|
prefix: prefix,
|
|
trace: tracerProvider,
|
|
storage: NewStorage(client),
|
|
cipher: cipher,
|
|
hosts: NewHostInfo(cfg, http.DefaultClient),
|
|
}, nil
|
|
}
|
|
|
|
// HandlerFunc returns a http.HandlerFunc that handles the mcp endpoints.
|
|
func (srv *Handler) HandlerFunc() http.HandlerFunc {
|
|
r := mux.NewRouter()
|
|
r.Use(cors.New(cors.Options{
|
|
AllowedMethods: []string{http.MethodGet, http.MethodPost, http.MethodOptions},
|
|
AllowedOrigins: []string{"*"},
|
|
AllowedHeaders: []string{"content-type", "mcp-protocol-version"},
|
|
}).Handler)
|
|
r.Methods(http.MethodOptions).HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
|
w.WriteHeader(http.StatusNoContent)
|
|
})
|
|
r.Path(path.Join(srv.prefix, registerEndpoint)).Methods(http.MethodPost).HandlerFunc(srv.RegisterClient)
|
|
r.Path(path.Join(srv.prefix, authorizationEndpoint)).Methods(http.MethodGet).HandlerFunc(srv.Authorize)
|
|
r.Path(path.Join(srv.prefix, oauthCallbackEndpoint)).Methods(http.MethodGet).HandlerFunc(srv.OAuthCallback)
|
|
r.Path(path.Join(srv.prefix, tokenEndpoint)).Methods(http.MethodPost).HandlerFunc(srv.Token)
|
|
r.Path(path.Join(srv.prefix, listRoutesEndpoint)).Methods(http.MethodGet).HandlerFunc(srv.ListRoutes)
|
|
r.Path(path.Join(srv.prefix, connectEndpoint)).Methods(http.MethodGet).HandlerFunc(srv.Connect)
|
|
|
|
return r.ServeHTTP
|
|
}
|
|
|
|
// outboundGRPCConnection is the package-level connection holder that
// getDatabrokerServiceClient calls Get on to obtain its gRPC connection.
// NOTE(review): the type name suggests the connection is cached and reused
// across calls — confirm against pkg/grpc.
var outboundGRPCConnection = new(grpc.CachedOutboundGRPClientConn)
|
|
|
|
func getDatabrokerServiceClient(
|
|
ctx context.Context,
|
|
cfg *config.Config,
|
|
tracerProvider oteltrace.TracerProvider,
|
|
) (databroker.DataBrokerServiceClient, error) {
|
|
sharedKey, err := cfg.Options.GetSharedKey()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("shared key: %w", err)
|
|
}
|
|
|
|
dataBrokerConn, err := outboundGRPCConnection.Get(ctx, &grpc.OutboundOptions{
|
|
OutboundPort: cfg.OutboundPort,
|
|
InstallationID: cfg.Options.InstallationID,
|
|
ServiceName: cfg.Options.Services,
|
|
SignedJWTKey: sharedKey,
|
|
}, googlegrpc.WithStatsHandler(otelgrpc.NewClientHandler(otelgrpc.WithTracerProvider(tracerProvider))))
|
|
if err != nil {
|
|
return nil, fmt.Errorf("databroker connection: %w", err)
|
|
}
|
|
return databroker.NewDataBrokerServiceClient(dataBrokerConn), nil
|
|
}
|