pomerium/internal/mcp/handler.go
2025-04-23 13:39:27 -04:00

97 lines
2.9 KiB
Go

package mcp
import (
"context"
"fmt"
"net/http"
"path"
"github.com/gorilla/mux"
"github.com/rs/cors"
"go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc"
oteltrace "go.opentelemetry.io/otel/trace"
googlegrpc "google.golang.org/grpc"
"github.com/pomerium/pomerium/config"
"github.com/pomerium/pomerium/pkg/grpc"
"github.com/pomerium/pomerium/pkg/grpc/databroker"
"github.com/pomerium/pomerium/pkg/telemetry/trace"
)
// Path constants for the MCP handler. Each endpoint constant is a suffix that
// HandlerFunc joins onto the Handler's prefix with path.Join.
const (
	// DefaultPrefix is the default path prefix for the MCP endpoints; it is
	// meant to be supplied as the prefix argument of New (callers are not
	// visible in this file).
	DefaultPrefix = "/.pomerium/mcp"

	// authorizationEndpoint is routed to Handler.Authorize (GET).
	authorizationEndpoint = "/authorize"
	// oauthCallbackEndpoint is routed to Handler.OAuthCallback (GET).
	oauthCallbackEndpoint = "/oauth/callback"
	// registerEndpoint is routed to Handler.RegisterClient (POST).
	registerEndpoint = "/register"
	// revocationEndpoint is declared but has no route registered in
	// HandlerFunc in this file.
	revocationEndpoint = "/revoke"
	// tokenEndpoint is routed to Handler.Token (POST).
	tokenEndpoint = "/token"
)
// Handler serves the MCP HTTP endpoints that HandlerFunc registers
// (client registration, authorization, OAuth callback and token).
// Construct it with New.
type Handler struct {
// prefix is the URL path prefix joined in front of every endpoint path.
prefix string
// trace is the tracer provider created in New; the same provider is handed
// to the databroker gRPC client for instrumentation.
trace oteltrace.TracerProvider
// storage is built by NewStorage from the databroker client obtained in New.
storage *Storage
}
// New builds a Handler whose endpoints are mounted under prefix.
//
// It creates a tracer provider named "MCP", obtains a databroker service
// client configured from cfg, and wraps that client with NewStorage. An error
// is returned (wrapped with "databroker client") when the client cannot be
// obtained; in that case the returned Handler is nil.
func New(
	ctx context.Context,
	prefix string,
	cfg *config.Config,
) (*Handler, error) {
	tp := trace.NewTracerProvider(ctx, "MCP")

	databrokerClient, err := getDatabrokerServiceClient(ctx, cfg, tp)
	if err != nil {
		return nil, fmt.Errorf("databroker client: %w", err)
	}

	h := &Handler{
		prefix:  prefix,
		trace:   tp,
		storage: NewStorage(databrokerClient),
	}
	return h, nil
}
// HandlerFunc builds a fresh router for the MCP endpoints and returns its
// ServeHTTP method as an http.HandlerFunc.
//
// The router applies a CORS middleware (any origin; GET, POST and OPTIONS;
// the content-type and mcp-protocol-version headers), answers every OPTIONS
// request with 204 No Content, and maps the register, authorize, OAuth
// callback and token paths under srv.prefix to their handler methods.
func (srv *Handler) HandlerFunc() http.HandlerFunc {
	corsMiddleware := cors.New(cors.Options{
		AllowedMethods: []string{http.MethodGet, http.MethodPost, http.MethodOptions},
		AllowedOrigins: []string{"*"},
		AllowedHeaders: []string{"content-type", "mcp-protocol-version"},
	})

	router := mux.NewRouter()
	router.Use(corsMiddleware.Handler)

	// Catch-all OPTIONS route so that preflight requests match a route
	// (gorilla/mux runs middleware only for matched routes).
	router.Methods(http.MethodOptions).HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusNoContent)
	})

	// Registered in this order, after the OPTIONS catch-all.
	routes := []struct {
		endpoint string
		method   string
		handle   func(http.ResponseWriter, *http.Request)
	}{
		{registerEndpoint, http.MethodPost, srv.RegisterClient},
		{authorizationEndpoint, http.MethodGet, srv.Authorize},
		{oauthCallbackEndpoint, http.MethodGet, srv.OAuthCallback},
		{tokenEndpoint, http.MethodPost, srv.Token},
	}
	for _, rt := range routes {
		router.Path(path.Join(srv.prefix, rt.endpoint)).Methods(rt.method).HandlerFunc(rt.handle)
	}

	return router.ServeHTTP
}
// outboundGRPCConnection is the package-level connection holder from which
// getDatabrokerServiceClient obtains its databroker gRPC connection via Get.
// NOTE(review): per the type name it presumably caches and reuses the
// connection across calls — confirm against pkg/grpc.
var outboundGRPCConnection = new(grpc.CachedOutboundGRPClientConn)
// getDatabrokerServiceClient returns a databroker service client backed by
// the connection obtained from the package-level outboundGRPCConnection.
//
// The connection is requested with outbound options taken from cfg (outbound
// port, installation ID, services, and the shared key as the JWT signing key)
// and an OpenTelemetry gRPC stats handler that uses tracerProvider.
//
// Errors are wrapped with "shared key" when the shared key cannot be read from
// the options, and with "databroker connection" when the connection cannot be
// obtained.
func getDatabrokerServiceClient(
	ctx context.Context,
	cfg *config.Config,
	tracerProvider oteltrace.TracerProvider,
) (databroker.DataBrokerServiceClient, error) {
	key, err := cfg.Options.GetSharedKey()
	if err != nil {
		return nil, fmt.Errorf("shared key: %w", err)
	}

	outboundOpts := &grpc.OutboundOptions{
		OutboundPort:   cfg.OutboundPort,
		InstallationID: cfg.Options.InstallationID,
		ServiceName:    cfg.Options.Services,
		SignedJWTKey:   key,
	}
	tracing := googlegrpc.WithStatsHandler(
		otelgrpc.NewClientHandler(otelgrpc.WithTracerProvider(tracerProvider)),
	)

	conn, err := outboundGRPCConnection.Get(ctx, outboundOpts, tracing)
	if err != nil {
		return nil, fmt.Errorf("databroker connection: %w", err)
	}

	return databroker.NewDataBrokerServiceClient(conn), nil
}