pomerium/internal/mcp/handler.go
Denis Mishin b944e68232
mcp: implement connect (#5640)
## Summary

Adds an implementation of the `/.pomerium/mcp/connect` endpoint, which takes a
`redirect_url` parameter and ensures the user goes through the required
redirects so that their session is hydrated with the upstream OAuth token
for the MCP server.
The `redirect_url` parameter host must match one of the _client_ MCP
routes (currently identified by the presence of `mcp:
pass_upstream_access_token: true` in the route).

## Related issues

Fix
https://linear.app/pomerium/issue/ENG-2321/mcp-support-handling-external-oauth-servers

## User Explanation

<!-- How would you explain this change to the user? If this
change doesn't create any user-facing changes, you can leave
this blank. If filled out, add the `docs` label -->

## Checklist

- [x] reference any related issues
- [x] updated unit tests
- [ ] add appropriate label (`enhancement`, `bug`, `breaking`,
`dependencies`, `ci`)
- [x] ready for review
2025-06-02 17:19:34 -04:00

113 lines
3.5 KiB
Go

package mcp
import (
"context"
"crypto/cipher"
"fmt"
"net/http"
"path"
"github.com/gorilla/mux"
"github.com/rs/cors"
"go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc"
oteltrace "go.opentelemetry.io/otel/trace"
"golang.org/x/sync/singleflight"
googlegrpc "google.golang.org/grpc"
"github.com/pomerium/pomerium/config"
"github.com/pomerium/pomerium/pkg/grpc"
"github.com/pomerium/pomerium/pkg/grpc/databroker"
"github.com/pomerium/pomerium/pkg/telemetry/trace"
)
// DefaultPrefix is the default path prefix for the MCP endpoints.
const DefaultPrefix = "/.pomerium/mcp"

// Endpoint paths, relative to the handler prefix. HandlerFunc joins each of
// these onto the prefix when registering routes.
const (
	authorizationEndpoint = "/authorize"
	connectEndpoint       = "/connect"
	listRoutesEndpoint    = "/routes"
	oauthCallbackEndpoint = "/oauth/callback"
	registerEndpoint      = "/register"
	// NOTE(review): no route for revocationEndpoint is registered in
	// HandlerFunc as it stands; confirm whether that is intentional.
	revocationEndpoint = "/revoke"
	tokenEndpoint      = "/token"
)
// Handler serves the MCP HTTP endpoints. Build one with New and expose it
// through HandlerFunc.
type Handler struct {
// prefix is the URL path prefix under which HandlerFunc mounts each endpoint.
prefix string
// trace is the tracer provider created in New under the name "MCP".
trace oteltrace.TracerProvider
// storage wraps the databroker service client (see NewStorage).
storage *Storage
// cipher is the AEAD obtained from getCipher(cfg) in New.
cipher cipher.AEAD
// hosts is built by NewHostInfo from the config and http.DefaultClient.
hosts *HostInfo
// hostsSingleFlight is left at its zero value by New; presumably it
// deduplicates concurrent host lookups — confirm against its users.
hostsSingleFlight singleflight.Group
}
// New builds a Handler whose endpoints are mounted under prefix.
//
// It creates an "MCP" tracer provider, connects to the databroker to back the
// handler's storage, and derives the AEAD cipher from cfg. An error from either
// the databroker client or the cipher setup is wrapped and returned, in which
// case the returned *Handler is nil.
func New(ctx context.Context, prefix string, cfg *config.Config) (*Handler, error) {
	tp := trace.NewTracerProvider(ctx, "MCP")

	databrokerClient, err := getDatabrokerServiceClient(ctx, cfg, tp)
	if err != nil {
		return nil, fmt.Errorf("databroker client: %w", err)
	}

	// Named aead rather than cipher so the crypto/cipher import is not shadowed.
	aead, err := getCipher(cfg)
	if err != nil {
		return nil, fmt.Errorf("get cipher: %w", err)
	}

	h := &Handler{
		prefix:  prefix,
		trace:   tp,
		storage: NewStorage(databrokerClient),
		cipher:  aead,
		// NOTE(review): http.DefaultClient has no timeout; confirm that
		// NewHostInfo bounds its requests some other way (e.g. via context).
		hosts: NewHostInfo(cfg, http.DefaultClient),
	}
	return h, nil
}
// HandlerFunc returns an http.HandlerFunc that dispatches requests to the MCP
// endpoints registered under the handler's prefix.
//
// CORS is applied to every matched route: any origin is allowed, the methods
// GET, POST and OPTIONS are allowed, and the allowed request headers are
// "content-type" and "mcp-protocol-version".
func (srv *Handler) HandlerFunc() http.HandlerFunc {
	corsMiddleware := cors.New(cors.Options{
		AllowedMethods: []string{http.MethodGet, http.MethodPost, http.MethodOptions},
		AllowedOrigins: []string{"*"},
		AllowedHeaders: []string{"content-type", "mcp-protocol-version"},
	})

	router := mux.NewRouter()
	router.Use(corsMiddleware.Handler)

	// Catch-all OPTIONS route. gorilla/mux only runs middleware for requests
	// that match a route, so this gives the CORS middleware something to run
	// on; any OPTIONS request that reaches the handler gets 204 No Content.
	router.Methods(http.MethodOptions).HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusNoContent)
	})

	// Endpoint table; registration order is the same as the table order.
	endpoints := []struct {
		suffix string
		method string
		handle func(http.ResponseWriter, *http.Request)
	}{
		{registerEndpoint, http.MethodPost, srv.RegisterClient},
		{authorizationEndpoint, http.MethodGet, srv.Authorize},
		{oauthCallbackEndpoint, http.MethodGet, srv.OAuthCallback},
		{tokenEndpoint, http.MethodPost, srv.Token},
		{listRoutesEndpoint, http.MethodGet, srv.ListRoutes},
		{connectEndpoint, http.MethodGet, srv.Connect},
	}
	for _, ep := range endpoints {
		router.Path(path.Join(srv.prefix, ep.suffix)).Methods(ep.method).HandlerFunc(ep.handle)
	}

	return router.ServeHTTP
}
// outboundGRPCConnection is the package-level connection holder that
// getDatabrokerServiceClient calls Get on to obtain the databroker connection.
// NOTE(review): the type name suggests the connection is cached across
// calls — confirm in pkg/grpc.
var outboundGRPCConnection = new(grpc.CachedOutboundGRPClientConn)
// getDatabrokerServiceClient returns a databroker service client backed by the
// package-level outbound gRPC connection.
//
// The connection is configured from cfg (outbound port, installation ID and
// services), uses the shared key as its signed-JWT key, and reports client
// stats through an OpenTelemetry handler bound to tracerProvider. Failures to
// read the shared key or to obtain the connection are wrapped and returned.
func getDatabrokerServiceClient(
	ctx context.Context,
	cfg *config.Config,
	tracerProvider oteltrace.TracerProvider,
) (databroker.DataBrokerServiceClient, error) {
	signingKey, err := cfg.Options.GetSharedKey()
	if err != nil {
		return nil, fmt.Errorf("shared key: %w", err)
	}

	outboundOpts := &grpc.OutboundOptions{
		OutboundPort:   cfg.OutboundPort,
		InstallationID: cfg.Options.InstallationID,
		ServiceName:    cfg.Options.Services,
		SignedJWTKey:   signingKey,
	}
	statsHandler := otelgrpc.NewClientHandler(otelgrpc.WithTracerProvider(tracerProvider))

	conn, err := outboundGRPCConnection.Get(ctx, outboundOpts, googlegrpc.WithStatsHandler(statsHandler))
	if err != nil {
		return nil, fmt.Errorf("databroker connection: %w", err)
	}

	return databroker.NewDataBrokerServiceClient(conn), nil
}