mirror of
https://github.com/pomerium/pomerium.git
synced 2025-06-22 20:48:10 +02:00
## Summary Fixes to MCP client registration and token requests. 1. ease some requirements on fields that are RECOMMENDED 2. fill in defaults 3. store both request and response in the client registration 4. check client secret in the /token request ## Related issues - Fixes https://linear.app/pomerium/issue/ENG-2462/mcp-ignore-unknown-grant-types-in-the-client-registration - Fixes https://linear.app/pomerium/issue/ENG-2461/mcp-support-client-secret-in-dynamic-client-registration ## User Explanation <!-- How would you explain this change to the user? If this change doesn't create any user-facing changes, you can leave this blank. If filled out, add the `docs` label --> ## Checklist - [x] reference any related issues - [x] updated unit tests - [x] add appropriate label (`enhancement`, `bug`, `breaking`, `dependencies`, `ci`) - [ ] ready for review
191 lines
4.7 KiB
Go
191 lines
4.7 KiB
Go
package mcp
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
|
|
"github.com/google/uuid"
|
|
"google.golang.org/protobuf/proto"
|
|
"google.golang.org/protobuf/types/known/anypb"
|
|
"google.golang.org/protobuf/types/known/timestamppb"
|
|
|
|
oauth21proto "github.com/pomerium/pomerium/internal/oauth21/gen"
|
|
rfc7591v1 "github.com/pomerium/pomerium/internal/rfc7591"
|
|
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
|
"github.com/pomerium/pomerium/pkg/grpc/session"
|
|
"github.com/pomerium/pomerium/pkg/protoutil"
|
|
)
|
|
|
|
type Storage struct {
|
|
client databroker.DataBrokerServiceClient
|
|
}
|
|
|
|
// NewStorage creates a new Storage instance.
|
|
func NewStorage(
|
|
client databroker.DataBrokerServiceClient,
|
|
) *Storage {
|
|
return &Storage{
|
|
client: client,
|
|
}
|
|
}
|
|
|
|
// RegisterClient stores a dynamic client registration in the databroker under
// a freshly generated UUID and returns that UUID, which is the id GetClient
// expects.
//
// Errors from the databroker Put are wrapped with %w, so callers can still
// inspect the underlying error with errors.Is / errors.As.
func (storage *Storage) RegisterClient(
	ctx context.Context,
	req *rfc7591v1.ClientRegistration,
) (string, error) {
	data := protoutil.NewAny(req)
	id := uuid.NewString()
	_, err := storage.client.Put(ctx, &databroker.PutRequest{
		Records: []*databroker.Record{{
			Id:   id,
			Data: data,
			Type: data.TypeUrl,
		}},
	})
	if err != nil {
		return "", fmt.Errorf("failed to register client: %w", err)
	}
	return id, nil
}
|
|
|
|
// GetClient loads the client registration stored under id (as returned by
// RegisterClient).
//
// It returns an error in three cases:
//   - the databroker lookup fails;
//   - the response carries no record;
//   - the stored payload is not a ClientRegistration.
func (storage *Storage) GetClient(
	ctx context.Context,
	id string,
) (*rfc7591v1.ClientRegistration, error) {
	v := new(rfc7591v1.ClientRegistration)
	rec, err := storage.client.Get(ctx, &databroker.GetRequest{
		Type: protoutil.GetTypeURL(v),
		Id:   id,
	})
	if err != nil {
		return nil, fmt.Errorf("failed to get client by ID: %w", err)
	}

	// Generated protobuf getters are nil-safe. A response without a record
	// therefore reaches anypb.UnmarshalTo as a nil *Any and is reported as an
	// error, instead of panicking on rec.Record.Data.
	err = anypb.UnmarshalTo(rec.GetRecord().GetData(), v, proto.UnmarshalOptions{})
	if err != nil {
		return nil, fmt.Errorf("failed to unmarshal client registration request: %w", err)
	}

	return v, nil
}
|
|
|
|
// CreateAuthorizationRequest stores an authorization request in the
// databroker under a freshly generated UUID and returns that UUID. The UUID
// can then be passed to GetAuthorizationRequest and
// DeleteAuthorizationRequest.
//
// Errors from the databroker Put are wrapped with %w so the underlying error
// remains inspectable by callers.
func (storage *Storage) CreateAuthorizationRequest(
	ctx context.Context,
	req *oauth21proto.AuthorizationRequest,
) (string, error) {
	data := protoutil.NewAny(req)
	id := uuid.NewString()
	_, err := storage.client.Put(ctx, &databroker.PutRequest{
		Records: []*databroker.Record{{
			Id:   id,
			Data: data,
			Type: data.TypeUrl,
		}},
	})
	if err != nil {
		return "", fmt.Errorf("failed to create authorization request: %w", err)
	}
	return id, nil
}
|
|
|
|
// GetAuthorizationRequest loads the authorization request stored under id
// (as returned by CreateAuthorizationRequest).
//
// It returns an error in three cases:
//   - the databroker lookup fails;
//   - the response carries no record;
//   - the stored payload is not an AuthorizationRequest.
func (storage *Storage) GetAuthorizationRequest(
	ctx context.Context,
	id string,
) (*oauth21proto.AuthorizationRequest, error) {
	v := new(oauth21proto.AuthorizationRequest)
	rec, err := storage.client.Get(ctx, &databroker.GetRequest{
		Type: protoutil.GetTypeURL(v),
		Id:   id,
	})
	if err != nil {
		return nil, fmt.Errorf("failed to get authorization request by ID: %w", err)
	}

	// Nil-safe getters: a missing record becomes an unmarshal error rather
	// than a nil-pointer panic on rec.Record.Data.
	err = anypb.UnmarshalTo(rec.GetRecord().GetData(), v, proto.UnmarshalOptions{})
	if err != nil {
		return nil, fmt.Errorf("failed to unmarshal authorization request: %w", err)
	}

	return v, nil
}
|
|
|
|
// DeleteAuthorizationRequest removes the authorization request stored under id.
//
// Deletion is expressed as a Put of a record of the AuthorizationRequest type
// whose DeletedAt is set to the current time. The record carries an empty
// AuthorizationRequest payload.
func (storage *Storage) DeleteAuthorizationRequest(
	ctx context.Context,
	id string,
) error {
	empty := protoutil.NewAny(&oauth21proto.AuthorizationRequest{})
	tombstone := &databroker.Record{
		Id:        id,
		Type:      empty.TypeUrl,
		Data:      empty,
		DeletedAt: timestamppb.Now(),
	}

	if _, err := storage.client.Put(ctx, &databroker.PutRequest{
		Records: []*databroker.Record{tombstone},
	}); err != nil {
		return fmt.Errorf("failed to delete authorization request by ID: %w", err)
	}
	return nil
}
|
|
|
|
// GetSession loads the session record stored in the databroker under id.
//
// It returns an error in three cases:
//   - the databroker lookup fails;
//   - the response carries no record;
//   - the stored payload is not a session.Session.
func (storage *Storage) GetSession(ctx context.Context, id string) (*session.Session, error) {
	v := new(session.Session)
	rec, err := storage.client.Get(ctx, &databroker.GetRequest{
		Type: protoutil.GetTypeURL(v),
		Id:   id,
	})
	if err != nil {
		return nil, fmt.Errorf("failed to get session by ID: %w", err)
	}

	// Nil-safe getters: a missing record becomes an unmarshal error rather
	// than a nil-pointer panic on rec.Record.Data.
	err = anypb.UnmarshalTo(rec.GetRecord().GetData(), v, proto.UnmarshalOptions{})
	if err != nil {
		return nil, fmt.Errorf("failed to unmarshal session: %w", err)
	}

	return v, nil
}
|
|
|
|
// StoreUpstreamOAuth2Token stores the upstream OAuth2 token obtained for the
// given user at the given upstream host.
//
// The record id is the composite key "host|userID", so each (host, user)
// pair holds at most one token. A later call overwrites the earlier one.
// GetUpstreamOAuth2Token reads the token back using the same key format.
func (storage *Storage) StoreUpstreamOAuth2Token(
	ctx context.Context,
	host string,
	userID string,
	token *oauth21proto.TokenResponse,
) error {
	data := protoutil.NewAny(token)
	_, err := storage.client.Put(ctx, &databroker.PutRequest{
		Records: []*databroker.Record{{
			Id:   fmt.Sprintf("%s|%s", host, userID),
			Data: data,
			Type: data.TypeUrl,
		}},
	})
	if err != nil {
		return fmt.Errorf("failed to store upstream oauth2 token for session: %w", err)
	}
	return nil
}
|
|
|
|
// GetUpstreamOAuth2Token loads the upstream OAuth2 token previously stored for
// the given user at the given upstream host. The record id is the composite
// key "host|userID", the same format StoreUpstreamOAuth2Token writes.
//
// It returns an error in three cases:
//   - the databroker lookup fails;
//   - the response carries no record;
//   - the stored payload is not a TokenResponse.
func (storage *Storage) GetUpstreamOAuth2Token(
	ctx context.Context,
	host string,
	userID string,
) (*oauth21proto.TokenResponse, error) {
	v := new(oauth21proto.TokenResponse)
	rec, err := storage.client.Get(ctx, &databroker.GetRequest{
		Type: protoutil.GetTypeURL(v),
		Id:   fmt.Sprintf("%s|%s", host, userID),
	})
	if err != nil {
		return nil, fmt.Errorf("failed to get upstream oauth2 token for session: %w", err)
	}

	// Nil-safe getters: a missing record becomes an unmarshal error rather
	// than a nil-pointer panic on rec.Record.Data.
	err = anypb.UnmarshalTo(rec.GetRecord().GetData(), v, proto.UnmarshalOptions{})
	if err != nil {
		return nil, fmt.Errorf("failed to unmarshal upstream oauth2 token: %w", err)
	}

	return v, nil
}
|