From 8a89c975d961ea21a3dcb3f0e2a8799cbbfd076a Mon Sep 17 00:00:00 2001 From: Denis Mishin Date: Tue, 8 Jul 2025 09:46:45 -0700 Subject: [PATCH] mcp: delete upstream oauth2 token (#5707) ## Summary Adds `POST /.pomerium/mcp/routes/disconnect` that allows an MCP client application to request upstream OAuth2 tokens to be purged, so that a user may get a new ones with possibly different scopes. ## Related issues Fix https://linear.app/pomerium/issue/ENG-2545/mcp-user-should-be-able-to-purge-their-upstream-oauth2-token ## User Explanation ## Checklist - [x] reference any related issues - [x] updated unit tests - [x] add appropriate label (`enhancement`, `bug`, `breaking`, `dependencies`, `ci`) - [x] ready for review --- internal/mcp/handler.go | 4 +- internal/mcp/handler_connect.go | 107 ++++++++++++++++++++++++++-- internal/mcp/handler_list_routes.go | 12 ++-- internal/mcp/storage.go | 21 ++++++ internal/mcp/storage_test.go | 9 +++ 5 files changed, 141 insertions(+), 12 deletions(-) diff --git a/internal/mcp/handler.go b/internal/mcp/handler.go index a5b70cb06..2e3b1b52a 100644 --- a/internal/mcp/handler.go +++ b/internal/mcp/handler.go @@ -30,6 +30,7 @@ const ( tokenEndpoint = "/token" listRoutesEndpoint = "/routes" connectEndpoint = "/connect" + disconnectEndpoint = "/routes/disconnect" ) type Handler struct { @@ -83,7 +84,8 @@ func (srv *Handler) HandlerFunc() http.HandlerFunc { r.Path(path.Join(srv.prefix, oauthCallbackEndpoint)).Methods(http.MethodGet).HandlerFunc(srv.OAuthCallback) r.Path(path.Join(srv.prefix, tokenEndpoint)).Methods(http.MethodPost).HandlerFunc(srv.Token) r.Path(path.Join(srv.prefix, listRoutesEndpoint)).Methods(http.MethodGet).HandlerFunc(srv.ListRoutes) - r.Path(path.Join(srv.prefix, connectEndpoint)).Methods(http.MethodGet).HandlerFunc(srv.Connect) + r.Path(path.Join(srv.prefix, connectEndpoint)).Methods(http.MethodGet).HandlerFunc(srv.ConnectGet) + r.Path(path.Join(srv.prefix, disconnectEndpoint)).Methods(http.MethodPost).HandlerFunc(srv.DisconnectRoutes) return r.ServeHTTP } diff --git a/internal/mcp/handler_connect.go b/internal/mcp/handler_connect.go index ed7e5aaf7..31640fd39 100644 --- a/internal/mcp/handler_connect.go +++ b/internal/mcp/handler_connect.go @@ -1,6 +1,7 @@ package mcp import ( + "encoding/json" "fmt" "net/http" "net/url" @@ -15,16 +16,11 @@ import ( const InternalConnectClientID = "pomerium-connect-7549ebe0-a67d-4d2b-a90d-d0a483b85f72" -// Connect is a helper method for MCP clients to ensure that the current user +// ConnectGet is a helper method for MCP clients to ensure that the current user // has an active upstream Oauth2 session for the route. // GET /mcp/connect?redirect_url= // It will redirect to the provided redirect_url once the user has an active session. -func (srv *Handler) Connect(w http.ResponseWriter, r *http.Request) { - if r.Method != http.MethodGet { - http.Error(w, "invalid method", http.StatusMethodNotAllowed) - return - } - +func (srv *Handler) ConnectGet(w http.ResponseWriter, r *http.Request) { ctx := r.Context() redirectURL, err := srv.checkClientRedirectURL(r) @@ -118,3 +114,100 @@ func (srv *Handler) checkClientRedirectURL(r *http.Request) (string, error) { } return redirectURL, nil } + +// DisconnectRoutes is a bulk helper method for MCP clients to purge upstream OAuth2 tokens +// for multiple routes. This is necessary because frontend clients cannot execute direct +// DELETE calls to other routes. +// +// POST /mcp/routes/disconnect +// +// Request body should contain a JSON object with a "routes" array: +// +// { +// "routes": ["https://server1.example.com", "https://server2.example.com"] +// } +// +// Response returns the same format as GET /mcp/routes, showing the updated connection status: +// +// { +// "servers": [ +// { +// "name": "Server 1", +// "url": "https://server1.example.com", +// "connected": false, +// "needs_oauth": true +// }, +// { +// "name": "Server 2", +// "url": "https://server2.example.com", +// "connected": false, +// "needs_oauth": true +// } +// ] +// } +func (srv *Handler) DisconnectRoutes(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + + claims, err := getClaimsFromRequest(r) + if err != nil { + log.Ctx(ctx).Error().Err(err).Msg("failed to get claims from request") + http.Error(w, "invalid request", http.StatusBadRequest) + return + } + + userID, ok := getUserIDFromClaims(claims) + if !ok { + log.Ctx(ctx).Error().Msg("user id is not present, this is a misconfigured request") + http.Error(w, "internal server error", http.StatusInternalServerError) + return + } + + type disconnectRequest struct { + Routes []string `json:"routes"` + } + + var req disconnectRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + log.Ctx(ctx).Error().Err(err).Msg("failed to decode disconnect request") + http.Error(w, "invalid request body", http.StatusBadRequest) + return + } + + if len(req.Routes) == 0 { + log.Ctx(ctx).Error().Msg("no routes provided in disconnect request") + http.Error(w, "no routes provided", http.StatusBadRequest) + return + } + + for _, routeURL := range req.Routes { + parsedURL, err := url.Parse(routeURL) + if err != nil { + log.Ctx(ctx).Error().Err(err).Str("url", routeURL).Msg("failed to parse route URL") + continue + } + + host := parsedURL.Host + if host == "" { + log.Ctx(ctx).Error().Str("url", routeURL).Msg("route URL has no host") + continue + } + + requiresUpstreamOAuth2Token := srv.hosts.HasOAuth2ConfigForHost(host) + if !requiresUpstreamOAuth2Token { + log.Ctx(ctx).Debug().Str("host", host).Msg("host does not require oauth2 token - ignoring") + continue + } + + err = srv.storage.DeleteUpstreamOAuth2Token(ctx, host, userID) + if err != nil { + log.Ctx(ctx).Error().Err(err).Str("host", host).Msg("failed to delete upstream oauth2 token") + } + } + + err = srv.listMCPServersForUser(ctx, w, userID) + if err != nil { + log.Ctx(ctx).Error().Err(err).Msg("failed to list MCP servers after disconnect") + http.Error(w, "internal server error", http.StatusInternalServerError) + return + } +} diff --git a/internal/mcp/handler_list_routes.go b/internal/mcp/handler_list_routes.go index d8f784566..7017ca1bf 100644 --- a/internal/mcp/handler_list_routes.go +++ b/internal/mcp/handler_list_routes.go @@ -40,6 +40,10 @@ func (srv *Handler) listMCPServers(w http.ResponseWriter, r *http.Request) error return fmt.Errorf("user id is not present in claims") } + return srv.listMCPServersForUser(r.Context(), w, userID) +} + +func (srv *Handler) listMCPServersForUser(ctx context.Context, w http.ResponseWriter, userID string) error { var servers []serverInfo for v := range srv.hosts.All() { servers = append(servers, serverInfo{ @@ -47,12 +51,12 @@ func (srv *Handler) listMCPServers(w http.ResponseWriter, r *http.Request) error Description: v.Description, LogoURL: v.LogoURL, URL: v.URL, - needsOauth: v.Config != nil, + NeedsOauth: v.Config != nil, host: v.Host, }) } - servers, err = srv.checkHostsConnectedForUser(r.Context(), userID, servers) + servers, err := srv.checkHostsConnectedForUser(ctx, userID, servers) if err != nil { return fmt.Errorf("failed to check hosts connected for user %s: %w", userID, err) } @@ -79,7 +83,7 @@ func (srv *Handler) checkHostsConnectedForUser( ) ([]serverInfo, error) { eg, ctx := errgroup.WithContext(ctx) for i := range servers { - if !servers[i].needsOauth { + if !servers[i].NeedsOauth { servers[i].Connected = true continue } @@ -106,6 +110,6 @@ type serverInfo struct { LogoURL string `json:"logo_url,omitempty"` URL string `json:"url"` Connected bool `json:"connected"` - needsOauth bool `json:"-"` + NeedsOauth bool `json:"needs_oauth"` host string `json:"-"` } diff --git a/internal/mcp/storage.go b/internal/mcp/storage.go index 86f67309e..ac88fe6e0 100644 --- a/internal/mcp/storage.go +++ b/internal/mcp/storage.go @@ -189,3 +189,24 @@ func (storage *Storage) GetUpstreamOAuth2Token( return v, nil } + +// DeleteUpstreamOAuth2Token removes the upstream OAuth2 token for a given host and user ID +func (storage *Storage) DeleteUpstreamOAuth2Token( + ctx context.Context, + host string, + userID string, +) error { + data := protoutil.NewAny(&oauth21proto.TokenResponse{}) + _, err := storage.client.Put(ctx, &databroker.PutRequest{ + Records: []*databroker.Record{{ + Id: fmt.Sprintf("%s|%s", host, userID), + Data: data, + Type: data.TypeUrl, + DeletedAt: timestamppb.Now(), + }}, + }) + if err != nil { + return fmt.Errorf("failed to delete upstream oauth2 token for session: %w", err) + } + return nil +} diff --git a/internal/mcp/storage_test.go b/internal/mcp/storage_test.go index 9b23ec1da..04d2e6a95 100644 --- a/internal/mcp/storage_test.go +++ b/internal/mcp/storage_test.go @@ -102,5 +102,14 @@ func TestStorage(t *testing.T) { _, err = storage.GetUpstreamOAuth2Token(ctx, "host", "non-existent-user-id") assert.Equal(t, codes.NotFound, status.Code(err)) + + err = storage.DeleteUpstreamOAuth2Token(ctx, "host", "user-id") + require.NoError(t, err) + + _, err = storage.GetUpstreamOAuth2Token(ctx, "host", "user-id") + assert.Equal(t, codes.NotFound, status.Code(err)) + + err = storage.DeleteUpstreamOAuth2Token(ctx, "non-existent-host", "user-id") + assert.NoError(t, err) }) }