From b56666135310555d1b18c2d2c20a184847c5b4ba Mon Sep 17 00:00:00 2001 From: Denis Mishin Date: Thu, 24 Apr 2025 14:54:31 -0400 Subject: [PATCH] mcp: client registration: store to the databroker (#5584) --- internal/mcp/handler_register_client.go | 88 ++++++++++++++++++++++++- internal/mcp/storage.go | 29 +++++++- 2 files changed, 114 insertions(+), 3 deletions(-) diff --git a/internal/mcp/handler_register_client.go b/internal/mcp/handler_register_client.go index c97b07a6c..9ef54f510 100644 --- a/internal/mcp/handler_register_client.go +++ b/internal/mcp/handler_register_client.go @@ -1,11 +1,95 @@ package mcp import ( + "encoding/json" + "io" "net/http" + "time" + + "github.com/bufbuild/protovalidate-go" + "google.golang.org/protobuf/encoding/protojson" + + "github.com/pomerium/pomerium/internal/log" + rfc7591v1 "github.com/pomerium/pomerium/internal/rfc7591" ) +const maxClientRegistrationPayload = 1024 * 1024 // 1MB + // RegisterClient handles the /register endpoint. // It is used to register a new client with the MCP server. -func (srv *Handler) RegisterClient(w http.ResponseWriter, _ *http.Request) { - w.WriteHeader(http.StatusNotImplemented) +func (srv *Handler) RegisterClient(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + if r.Method != http.MethodPost { + http.Error(w, "invalid method", http.StatusMethodNotAllowed) + return + } + + src := io.LimitReader(r.Body, maxClientRegistrationPayload) + defer r.Body.Close() + + data, err := io.ReadAll(src) + if err != nil { + log.Ctx(ctx).Error().Err(err).Msg("failed to read request body") + http.Error(w, "failed to read request body", http.StatusBadRequest) + return + } + + v := new(rfc7591v1.ClientMetadata) + err = protojson.Unmarshal(data, v) + if err != nil { + log.Ctx(ctx).Error().Err(err).Msg("failed to unmarshal request body") + http.Error(w, "failed to unmarshal request body", http.StatusBadRequest) + return + } + + err = protovalidate.Validate(v) + if err != nil { + log.Ctx(ctx).Error().Err(err).Msg("failed to validate request body") + clientRegistrationBadRequest(w, err) + return + } + + id, err := srv.storage.RegisterClient(ctx, v) + if err != nil { + log.Ctx(ctx).Error().Err(err).Msg("failed to register client") + http.Error(w, "failed to register client", http.StatusInternalServerError) + } + + resp := struct { + *rfc7591v1.ClientMetadata + ClientID string `json:"client_id"` + ClientIDIssuedAt int64 `json:"client_id_issued_at"` + }{ + ClientMetadata: v, + ClientID: id, + ClientIDIssuedAt: time.Now().Unix(), + } + data, err = json.Marshal(resp) + if err != nil { + log.Ctx(ctx).Error().Err(err).Msg("failed to marshal response") + http.Error(w, "failed to marshal response", http.StatusInternalServerError) + return + } + w.Header().Set("Content-Type", "application/json") + _, err = w.Write(data) + if err != nil { + log.Ctx(ctx).Error().Err(err).Msg("failed to write response") + return + } +} + +func clientRegistrationBadRequest(w http.ResponseWriter, err error) { + v := struct { + Error string `json:"error"` + ErrorDescription string `json:"error_description,omitempty"` + }{ + Error: "invalid_client_metadata", + ErrorDescription: err.Error(), + } + data, _ := json.Marshal(v) + w.Header().Set("Content-Type", "application/json") + w.Header().Set("Cache-Control", "no-store") + w.Header().Set("Pragma", "no-cache") + w.WriteHeader(http.StatusBadRequest) + _, _ = w.Write(data) } diff --git a/internal/mcp/storage.go b/internal/mcp/storage.go index 94af2ac0c..77d9ff821 100644 --- a/internal/mcp/storage.go +++ b/internal/mcp/storage.go @@ -1,6 +1,14 @@ package mcp -import "github.com/pomerium/pomerium/pkg/grpc/databroker" +import ( + "context" + + "github.com/google/uuid" + + rfc7591v1 "github.com/pomerium/pomerium/internal/rfc7591" + "github.com/pomerium/pomerium/pkg/grpc/databroker" + "github.com/pomerium/pomerium/pkg/protoutil" +) type Storage struct { client databroker.DataBrokerServiceClient @@ -14,3 +22,22 @@ func NewStorage( client: client, } } + +func (storage *Storage) RegisterClient( + ctx context.Context, + req *rfc7591v1.ClientMetadata, +) (string, error) { + data := protoutil.NewAny(req) + id := uuid.NewString() + _, err := storage.client.Put(ctx, &databroker.PutRequest{ + Records: []*databroker.Record{{ + Id: id, + Data: data, + Type: data.TypeUrl, + }}, + }) + if err != nil { + return "", err + } + return id, nil +}