diff --git a/internal/mcp/handler.go b/internal/mcp/handler.go index 0ecb86cfa..749d40edd 100644 --- a/internal/mcp/handler.go +++ b/internal/mcp/handler.go @@ -2,14 +2,19 @@ package mcp import ( "context" + "fmt" "net/http" "path" "github.com/gorilla/mux" "github.com/rs/cors" + "go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc" oteltrace "go.opentelemetry.io/otel/trace" + googlegrpc "google.golang.org/grpc" "github.com/pomerium/pomerium/config" + "github.com/pomerium/pomerium/pkg/grpc" + "github.com/pomerium/pomerium/pkg/grpc/databroker" "github.com/pomerium/pomerium/pkg/telemetry/trace" ) @@ -24,19 +29,27 @@ const ( ) type Handler struct { - prefix string - trace oteltrace.TracerProvider + prefix string + trace oteltrace.TracerProvider + storage *Storage } func New( ctx context.Context, prefix string, - _ *config.Config, + cfg *config.Config, ) (*Handler, error) { tracerProvider := trace.NewTracerProvider(ctx, "MCP") + + client, err := getDatabrokerServiceClient(ctx, cfg, tracerProvider) + if err != nil { + return nil, fmt.Errorf("databroker client: %w", err) + } + return &Handler{ - prefix: prefix, - trace: tracerProvider, + prefix: prefix, + trace: tracerProvider, + storage: NewStorage(client), }, nil } @@ -58,3 +71,27 @@ func (srv *Handler) HandlerFunc() http.HandlerFunc { return r.ServeHTTP } + +var outboundGRPCConnection = new(grpc.CachedOutboundGRPClientConn) + +func getDatabrokerServiceClient( + ctx context.Context, + cfg *config.Config, + tracerProvider oteltrace.TracerProvider, +) (databroker.DataBrokerServiceClient, error) { + sharedKey, err := cfg.Options.GetSharedKey() + if err != nil { + return nil, fmt.Errorf("shared key: %w", err) + } + + dataBrokerConn, err := outboundGRPCConnection.Get(ctx, &grpc.OutboundOptions{ + OutboundPort: cfg.OutboundPort, + InstallationID: cfg.Options.InstallationID, + ServiceName: cfg.Options.Services, + SignedJWTKey: sharedKey, + }, googlegrpc.WithStatsHandler(otelgrpc.NewClientHandler(otelgrpc.WithTracerProvider(tracerProvider)))) + if err != nil { + return nil, fmt.Errorf("databroker connection: %w", err) + } + return databroker.NewDataBrokerServiceClient(dataBrokerConn), nil +} diff --git a/internal/mcp/storage.go b/internal/mcp/storage.go new file mode 100644 index 000000000..94af2ac0c --- /dev/null +++ b/internal/mcp/storage.go @@ -0,0 +1,16 @@ +package mcp + +import "github.com/pomerium/pomerium/pkg/grpc/databroker" + +type Storage struct { + client databroker.DataBrokerServiceClient +} + +// NewStorage creates a new Storage instance. +func NewStorage( + client databroker.DataBrokerServiceClient, +) *Storage { + return &Storage{ + client: client, + } +} diff --git a/proxy/handlers.go b/proxy/handlers.go index f5498aed2..5856ba1b7 100644 --- a/proxy/handlers.go +++ b/proxy/handlers.go @@ -25,7 +25,7 @@ func (p *Proxy) registerDashboardHandlers(r *mux.Router, opts *config.Options) * h.Use(middleware.SetHeaders(httputil.HeadersContentSecurityPolicy)) // model context protocol - h.PathPrefix("/mcp").Handler(p.mcp.HandlerFunc()) + h.PathPrefix("/mcp").Handler(p.mcp.Load().HandlerFunc()) // special pomerium endpoints for users to view their session h.Path("/").Handler(httputil.HandlerFunc(p.userInfo)).Methods(http.MethodGet) diff --git a/proxy/proxy.go b/proxy/proxy.go index 064a87d58..935ec79d8 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -64,7 +64,7 @@ type Proxy struct { webauthn *webauthn.Handler tracerProvider oteltrace.TracerProvider logoProvider portal.LogoProvider - mcp *mcp.Handler + mcp *atomicutil.Value[*mcp.Handler] } // New takes a Proxy service from options and a validation function. @@ -87,7 +87,7 @@ func New(ctx context.Context, cfg *config.Config) (*Proxy, error) { currentConfig: atomicutil.NewValue(&config.Config{Options: config.NewDefaultOptions()}), currentRouter: atomicutil.NewValue(httputil.NewRouter()), logoProvider: portal.NewLogoProvider(), - mcp: mcp, + mcp: atomicutil.NewValue(mcp), } p.OnConfigChange(ctx, cfg) p.webauthn = webauthn.New(p.getWebauthnState) @@ -110,6 +110,13 @@ func (p *Proxy) OnConfigChange(ctx context.Context, cfg *config.Config) { return } + mcp, err := mcp.New(ctx, mcp.DefaultPrefix, cfg) + if err != nil { + log.Ctx(ctx).Error().Err(err).Msg("proxy: failed to update proxy state from configuration settings") + } else { + p.mcp.Store(mcp) + } + p.currentConfig.Store(cfg) if err := p.setHandlers(ctx, cfg.Options); err != nil { log.Ctx(ctx).Error().Err(err).Msg("proxy: failed to update proxy handlers from configuration settings")