Merge branch 'wasaga/mcp-client-registration' into wasaga/mcp-authorize-request

This commit is contained in:
Denis Mishin 2025-04-23 22:55:48 -04:00
commit 52af622cc4
6 changed files with 114 additions and 1405 deletions

View file

@ -1,12 +1,13 @@
package mcp package mcp
import ( import (
"encoding/json"
"io" "io"
"net/http" "net/http"
"time"
"github.com/bufbuild/protovalidate-go" "github.com/bufbuild/protovalidate-go"
"google.golang.org/protobuf/encoding/protojson" "google.golang.org/protobuf/encoding/protojson"
"google.golang.org/protobuf/proto"
"github.com/pomerium/pomerium/internal/log" "github.com/pomerium/pomerium/internal/log"
rfc7591v1 "github.com/pomerium/pomerium/internal/rfc7591" rfc7591v1 "github.com/pomerium/pomerium/internal/rfc7591"
@ -33,7 +34,7 @@ func (srv *Handler) RegisterClient(w http.ResponseWriter, r *http.Request) {
return return
} }
v := new(rfc7591v1.ClientRegistrationRequest) v := new(rfc7591v1.ClientMetadata)
err = protojson.Unmarshal(data, v) err = protojson.Unmarshal(data, v)
if err != nil { if err != nil {
log.Ctx(ctx).Error().Err(err).Msg("failed to unmarshal request body") log.Ctx(ctx).Error().Err(err).Msg("failed to unmarshal request body")
@ -44,19 +45,26 @@ func (srv *Handler) RegisterClient(w http.ResponseWriter, r *http.Request) {
err = protovalidate.Validate(v) err = protovalidate.Validate(v)
if err != nil { if err != nil {
log.Ctx(ctx).Error().Err(err).Msg("failed to validate request body") log.Ctx(ctx).Error().Err(err).Msg("failed to validate request body")
clientRegistrationError(w, err, rfc7591v1.ErrorCode_ERROR_CODE_INVALID_CLIENT_METADATA) clientRegistrationBadRequest(w, err)
return return
} }
resp, err := srv.storage.RegisterClient(ctx, v) id, err := srv.storage.RegisterClient(ctx, v)
if err != nil { if err != nil {
log.Ctx(ctx).Error().Err(err).Msg("failed to register client") log.Ctx(ctx).Error().Err(err).Msg("failed to register client")
http.Error(w, "failed to register client", http.StatusInternalServerError) http.Error(w, "failed to register client", http.StatusInternalServerError)
} }
data, err = protojson.MarshalOptions{ resp := struct {
UseProtoNames: true, *rfc7591v1.ClientMetadata
}.Marshal(resp) ClientID string `json:"client_id"`
ClientIDIssuedAt int64 `json:"client_id_issued_at"`
}{
ClientMetadata: v,
ClientID: id,
ClientIDIssuedAt: time.Now().Unix(),
}
data, err = json.Marshal(resp)
if err != nil { if err != nil {
log.Ctx(ctx).Error().Err(err).Msg("failed to marshal response") log.Ctx(ctx).Error().Err(err).Msg("failed to marshal response")
http.Error(w, "failed to marshal response", http.StatusInternalServerError) http.Error(w, "failed to marshal response", http.StatusInternalServerError)
@ -70,13 +78,18 @@ func (srv *Handler) RegisterClient(w http.ResponseWriter, r *http.Request) {
} }
} }
func clientRegistrationError(w http.ResponseWriter, err error, code rfc7591v1.ErrorCode) { func clientRegistrationBadRequest(w http.ResponseWriter, err error) {
v := &rfc7591v1.ClientRegistrationErrorResponse{ v := struct {
Error: code, Error string `json:"error"`
ErrorDescription: proto.String(err.Error()), ErrorDescription string `json:"error_description,omitempty"`
}{
Error: "invalid_client_metadata",
ErrorDescription: err.Error(),
} }
data, _ := protojson.Marshal(v) data, _ := json.Marshal(v)
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
w.Header().Set("Cache-Control", "no-store")
w.Header().Set("Pragma", "no-cache")
w.WriteHeader(http.StatusBadRequest) w.WriteHeader(http.StatusBadRequest)
_, _ = w.Write(data) _, _ = w.Write(data)
} }

View file

@ -29,11 +29,11 @@ func NewStorage(
func (storage *Storage) RegisterClient( func (storage *Storage) RegisterClient(
ctx context.Context, ctx context.Context,
req *rfc7591v1.ClientRegistrationRequest, req *rfc7591v1.ClientMetadata,
) (*rfc7591v1.ClientInformationResponse, error) { ) (string, error) {
data := protoutil.NewAny(req) data := protoutil.NewAny(req)
id := uuid.NewString() id := uuid.NewString()
rec, err := storage.client.Put(ctx, &databroker.PutRequest{ _, err := storage.client.Put(ctx, &databroker.PutRequest{
Records: []*databroker.Record{{ Records: []*databroker.Record{{
Id: id, Id: id,
Data: data, Data: data,
@ -41,20 +41,15 @@ func (storage *Storage) RegisterClient(
}}, }},
}) })
if err != nil { if err != nil {
return nil, err return "", err
} }
if len(rec.Records) == 0 { return id, nil
return nil, fmt.Errorf("no records returned")
}
now := rec.Records[0].GetModifiedAt().Seconds
return getClientInformation(id, now, req), nil
} }
func (storage *Storage) GetClientByID( func (storage *Storage) GetClientByID(
ctx context.Context, ctx context.Context,
id string, id string,
) (*rfc7591v1.ClientRegistrationRequest, error) { ) (*rfc7591v1.ClientMetadata, error) {
rec, err := storage.client.Get(ctx, &databroker.GetRequest{ rec, err := storage.client.Get(ctx, &databroker.GetRequest{
Id: id, Id: id,
}) })
@ -62,7 +57,7 @@ func (storage *Storage) GetClientByID(
return nil, fmt.Errorf("failed to get client by ID: %w", err) return nil, fmt.Errorf("failed to get client by ID: %w", err)
} }
v := new(rfc7591v1.ClientRegistrationRequest) v := new(rfc7591v1.ClientMetadata)
err = anypb.UnmarshalTo(rec.Record.Data, v, proto.UnmarshalOptions{}) err = anypb.UnmarshalTo(rec.Record.Data, v, proto.UnmarshalOptions{})
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to unmarshal client registration request: %w", err) return nil, fmt.Errorf("failed to unmarshal client registration request: %w", err)
@ -89,35 +84,3 @@ func (storage *Storage) CreateAuthorizationRequest(
} }
return id, nil return id, nil
} }
func getClientInformation(
id string,
issuedAt int64,
req *rfc7591v1.ClientRegistrationRequest,
) *rfc7591v1.ClientInformationResponse {
return &rfc7591v1.ClientInformationResponse{
ClientId: id,
ClientIdIssuedAt: proto.Int64(issuedAt),
RedirectUris: req.RedirectUris,
TokenEndpointAuthMethod: req.TokenEndpointAuthMethod,
GrantTypes: req.GrantTypes,
ResponseTypes: req.ResponseTypes,
ClientName: req.ClientName,
ClientNameLocalized: req.ClientNameLocalized,
ClientUri: req.ClientUri,
ClientUriLocalized: req.ClientUriLocalized,
LogoUri: req.LogoUri,
LogoUriLocalized: req.LogoUriLocalized,
Scope: req.Scope,
Contacts: req.Contacts,
TosUri: req.TosUri,
TosUriLocalized: req.TosUriLocalized,
PolicyUri: req.PolicyUri,
PolicyUriLocalized: req.PolicyUriLocalized,
JwksUri: req.JwksUri,
Jwks: req.Jwks,
SoftwareId: req.SoftwareId,
SoftwareVersion: req.SoftwareVersion,
SoftwareStatement: req.SoftwareStatement,
}
}

View file

@ -8,7 +8,7 @@ import (
) )
func ValidateAuthorizationRequest( func ValidateAuthorizationRequest(
client *rfc7591v1.ClientRegistrationRequest, client *rfc7591v1.ClientMetadata,
req *gen.AuthorizationRequest, req *gen.AuthorizationRequest,
) error { ) error {
if err := ValidateAuthorizationRequestRedirectURI(client, req.RedirectUri); err != nil { if err := ValidateAuthorizationRequestRedirectURI(client, req.RedirectUri); err != nil {
@ -18,7 +18,7 @@ func ValidateAuthorizationRequest(
} }
func ValidateAuthorizationRequestRedirectURI( func ValidateAuthorizationRequestRedirectURI(
client *rfc7591v1.ClientRegistrationRequest, client *rfc7591v1.ClientMetadata,
redirectURI *string, redirectURI *string,
) error { ) error {
if len(client.RedirectUris) == 0 { if len(client.RedirectUris) == 0 {

View file

@ -4,23 +4,24 @@ import (
"testing" "testing"
"github.com/zeebo/assert" "github.com/zeebo/assert"
"google.golang.org/protobuf/proto"
"github.com/pomerium/pomerium/internal/oauth21" "github.com/pomerium/pomerium/internal/oauth21"
"github.com/pomerium/pomerium/internal/oauth21/gen" "github.com/pomerium/pomerium/internal/oauth21/gen"
rfc7591v1 "github.com/pomerium/pomerium/internal/rfc7591" rfc7591v1 "github.com/pomerium/pomerium/internal/rfc7591"
) )
func ValidateClientTest(t *testing.T) { func TestValidateRequest(t *testing.T) {
t.Parallel() t.Parallel()
for _, tc := range []struct { for _, tc := range []struct {
name string name string
client *rfc7591v1.ClientRegistrationRequest client *rfc7591v1.ClientMetadata
req *gen.AuthorizationRequest req *gen.AuthorizationRequest
err bool err bool
}{ }{
{ {
"optional redirect_uri, multiple redirect_uris", "optional redirect_uri, multiple redirect_uris",
&rfc7591v1.ClientRegistrationRequest{ &rfc7591v1.ClientMetadata{
RedirectUris: []string{"https://example.com/callback", "https://example.com/other-callback"}, RedirectUris: []string{"https://example.com/callback", "https://example.com/other-callback"},
}, },
&gen.AuthorizationRequest{ &gen.AuthorizationRequest{
@ -28,6 +29,36 @@ func ValidateClientTest(t *testing.T) {
}, },
true, true,
}, },
{
"optional redirect_uri, single redirect_uri",
&rfc7591v1.ClientMetadata{
RedirectUris: []string{"https://example.com/callback"},
},
&gen.AuthorizationRequest{
RedirectUri: nil,
},
false,
},
{
"matching redirect_uri",
&rfc7591v1.ClientMetadata{
RedirectUris: []string{"https://example.com/callback", "https://example.com/other-callback"},
},
&gen.AuthorizationRequest{
RedirectUri: proto.String("https://example.com/callback"),
},
false,
},
{
"non-matching redirect_uri",
&rfc7591v1.ClientMetadata{
RedirectUris: []string{"https://example.com/callback", "https://example.com/other-callback"},
},
&gen.AuthorizationRequest{
RedirectUri: proto.String("https://example.com/invalid-callback"),
},
true,
},
} { } {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
t.Parallel() t.Parallel()

File diff suppressed because it is too large Load diff

View file

@ -285,263 +285,3 @@ message ClientMetadata {
// types" // types"
// }; // };
} }
// Represents the request sent to the Client Registration Endpoint (RFC 7591
// Section 3.1).
message ClientRegistrationRequest {
// Fields correspond to ClientMetadata, indicating requested values.
// REQUIRED for redirect flows.
repeated string redirect_uris = 1 [ (buf.validate.field).repeated = {
min_items : 1,
items : {string : {uri : true, min_len : 1}}
} ];
// OPTIONAL. Default is "client_secret_basic".
optional string token_endpoint_auth_method = 2
[ (buf.validate.field).string = {
in : [ "none", "client_secret_post", "client_secret_basic" ]
} ];
// OPTIONAL. Default is ["authorization_code"].
repeated string grant_types = 3
[ (buf.validate.field).repeated .items.string = {
in : [
"authorization_code",
"implicit",
"password",
"client_credentials",
"refresh_token",
"urn:ietf:params:oauth:grant-type:jwt-bearer",
"urn:ietf:params:oauth:grant-type:saml2-bearer"
]
} ];
// OPTIONAL. Default is ["code"].
repeated string response_types = 4 [
(buf.validate.field).repeated .items.string = {in : [ "code", "token" ]}
];
// OPTIONAL. RECOMMENDED.
optional string client_name = 5
[ (buf.validate.field).string = {min_len : 1, max_len : 255} ];
// OPTIONAL.
map<string, string> client_name_localized = 6 [ (buf.validate.field).map = {
keys : {string : {pattern : "^[a-zA-Z]{1,8}(-[a-zA-Z0-9]{1,8})*$"}},
values : {string : {min_len : 1, max_len : 255}}
} ];
// OPTIONAL. RECOMMENDED.
optional string client_uri = 7 [ (buf.validate.field).string.uri = true ];
// OPTIONAL.
map<string, string> client_uri_localized = 8 [ (buf.validate.field).map = {
keys : {string : {pattern : "^[a-zA-Z]{1,8}(-[a-zA-Z0-9]{1,8})*$"}},
values : {string : {uri : true}}
} ];
// OPTIONAL.
optional string logo_uri = 9 [ (buf.validate.field).string.uri = true ];
// OPTIONAL.
map<string, string> logo_uri_localized = 10 [ (buf.validate.field).map = {
keys : {string : {pattern : "^[a-zA-Z]{1,8}(-[a-zA-Z0-9]{1,8})*$"}},
values : {string : {uri : true}}
} ];
// OPTIONAL.
optional string scope = 11 [
(buf.validate.field).string = {pattern : "^\\S+( \\S+)*$", min_len : 1}
];
// OPTIONAL.
repeated string contacts = 12
[ (buf.validate.field).repeated .items.string.email = true ];
// OPTIONAL.
optional string tos_uri = 13 [ (buf.validate.field).string.uri = true ];
// OPTIONAL.
map<string, string> tos_uri_localized = 14 [ (buf.validate.field).map = {
keys : {string : {pattern : "^[a-zA-Z]{1,8}(-[a-zA-Z0-9]{1,8})*$"}},
values : {string : {uri : true}}
} ];
// OPTIONAL.
optional string policy_uri = 15 [ (buf.validate.field).string.uri = true ];
// OPTIONAL.
map<string, string> policy_uri_localized = 16 [ (buf.validate.field).map = {
keys : {string : {pattern : "^[a-zA-Z]{1,8}(-[a-zA-Z0-9]{1,8})*$"}},
values : {string : {uri : true}}
} ];
// OPTIONAL. Mutually exclusive with `jwks`.
optional string jwks_uri = 17 [ (buf.validate.field).string.uri = true ];
// OPTIONAL. Mutually exclusive with `jwks_uri`.
optional JsonWebKeySet jwks = 18;
// OPTIONAL.
optional string software_id = 19
[ (buf.validate.field).string = {min_len : 1, max_len : 255} ];
// OPTIONAL.
optional string software_version = 20
[ (buf.validate.field).string = {min_len : 1, max_len : 255} ];
// OPTIONAL. A software statement containing client metadata values about the
// client software as claims.
optional string software_statement = 21 [ (buf.validate.field).string = {
min_len : 1,
pattern : "^[a-zA-Z0-9\\-_]+\\.[a-zA-Z0-9\\-_]+\\.[a-zA-Z0-9\\-_]*$"
} ];
// Message level validation to ensure mutual exclusion of jwks and jwks_uri.
option (buf.validate.message).cel = {
id : "client_registration_request.jwks_mutual_exclusion",
expression : "!has(this.jwks_uri) || !has(this.jwks)",
message : "jwks_uri and jwks are mutually exclusive"
};
}
// Represents the successful response from the Client Registration Endpoint (RFC
// 7591 Section 3.2.1).
message ClientInformationResponse {
// REQUIRED. OAuth 2.0 client identifier string issued by the authorization
// server.
string client_id = 1 [
(buf.validate.field).required = true,
(buf.validate.field).string.min_len = 1
];
// OPTIONAL. OAuth 2.0 client secret string. Only issued for confidential
// clients.
optional string client_secret = 2 [ (buf.validate.field).string.min_len = 1 ];
// OPTIONAL. Time at which the client identifier was issued (Unix timestamp,
// seconds since epoch).
optional int64 client_id_issued_at = 3 [ (buf.validate.field).int64.gt = 0 ];
// REQUIRED if "client_secret" is issued, OPTIONAL otherwise. Time at which
// the client secret will expire (Unix timestamp, seconds since epoch), or 0
// if it will not expire.
optional int64 client_secret_expires_at = 4
[ (buf.validate.field).int64.gte = 0 ];
// Contains all registered metadata about this client, reflecting server
// state. REQUIRED if applicable to the client registration.
repeated string redirect_uris = 5 [ (buf.validate.field).repeated = {
min_items : 1,
items : {string : {uri : true, min_len : 1}}
} ];
// OPTIONAL (reflects registered value, may have default).
optional string token_endpoint_auth_method = 6
[ (buf.validate.field).string = {
in : [ "none", "client_secret_post", "client_secret_basic" ]
} ];
// OPTIONAL (reflects registered value, may have default).
repeated string grant_types = 7
[ (buf.validate.field).repeated .items.string = {
in : [
"authorization_code",
"implicit",
"password",
"client_credentials",
"refresh_token",
"urn:ietf:params:oauth:grant-type:jwt-bearer",
"urn:ietf:params:oauth:grant-type:saml2-bearer"
]
} ];
// OPTIONAL (reflects registered value, may have default).
repeated string response_types = 8 [
(buf.validate.field).repeated .items.string = {in : [ "code", "token" ]}
];
// OPTIONAL (reflects registered value).
optional string client_name = 9
[ (buf.validate.field).string = {min_len : 1, max_len : 255} ];
// OPTIONAL (reflects registered value).
map<string, string> client_name_localized = 10 [ (buf.validate.field).map = {
keys : {string : {pattern : "^[a-zA-Z]{1,8}(-[a-zA-Z0-9]{1,8})*$"}},
values : {string : {min_len : 1, max_len : 255}}
} ];
// OPTIONAL (reflects registered value).
optional string client_uri = 11 [ (buf.validate.field).string.uri = true ];
// OPTIONAL (reflects registered value).
map<string, string> client_uri_localized = 12 [ (buf.validate.field).map = {
keys : {string : {pattern : "^[a-zA-Z]{1,8}(-[a-zA-Z0-9]{1,8})*$"}},
values : {string : {uri : true}}
} ];
// OPTIONAL (reflects registered value).
optional string logo_uri = 13 [ (buf.validate.field).string.uri = true ];
// OPTIONAL (reflects registered value).
map<string, string> logo_uri_localized = 14 [ (buf.validate.field).map = {
keys : {string : {pattern : "^[a-zA-Z]{1,8}(-[a-zA-Z0-9]{1,8})*$"}},
values : {string : {uri : true}}
} ];
// OPTIONAL (reflects registered value).
optional string scope = 15 [
(buf.validate.field).string = {pattern : "^\\S+( \\S+)*$", min_len : 1}
];
// OPTIONAL (reflects registered value).
repeated string contacts = 16
[ (buf.validate.field).repeated .items.string.email = true ];
// OPTIONAL (reflects registered value).
optional string tos_uri = 17 [ (buf.validate.field).string.uri = true ];
// OPTIONAL (reflects registered value).
map<string, string> tos_uri_localized = 18 [ (buf.validate.field).map = {
keys : {string : {pattern : "^[a-zA-Z]{1,8}(-[a-zA-Z0-9]{1,8})*$"}},
values : {string : {uri : true}}
} ];
// OPTIONAL (reflects registered value).
optional string policy_uri = 19 [ (buf.validate.field).string.uri = true ];
// OPTIONAL (reflects registered value).
map<string, string> policy_uri_localized = 20 [ (buf.validate.field).map = {
keys : {string : {pattern : "^[a-zA-Z]{1,8}(-[a-zA-Z0-9]{1,8})*$"}},
values : {string : {uri : true}}
} ];
// OPTIONAL (reflects registered value). Mutually exclusive with `jwks`.
optional string jwks_uri = 21 [ (buf.validate.field).string.uri = true ];
// OPTIONAL (reflects registered value). Mutually exclusive with `jwks_uri`.
optional JsonWebKeySet jwks = 22;
// OPTIONAL (reflects registered value).
optional string software_id = 23
[ (buf.validate.field).string = {min_len : 1, max_len : 255} ];
// OPTIONAL (reflects registered value).
optional string software_version = 24
[ (buf.validate.field).string = {min_len : 1, max_len : 255} ];
// OPTIONAL. If a software statement was used in the request, it MUST be
// returned unmodified.
optional string software_statement = 25 [ (buf.validate.field).string = {
min_len : 1,
pattern : "^[a-zA-Z0-9\\-_]+\\.[a-zA-Z0-9\\-_]+\\.[a-zA-Z0-9\\-_]*$"
} ];
// Message level validation
option (buf.validate.message).cel = {
id : "client_info_response.secret_expiry",
// client_secret_expires_at MUST be present if client_secret is present and
// non-empty
expression : "(!has(this.client_secret) || this.client_secret == '') || "
"has(this.client_secret_expires_at)",
message : "client_secret_expires_at is required when client_secret is "
"issued"
};
option (buf.validate.message).cel = {
id : "client_info_response.jwks_mutual_exclusion",
expression : "!has(this.jwks_uri) || !has(this.jwks)",
message : "jwks_uri and jwks fields are mutually exclusive in the response"
};
}
// Standard error codes for client registration errors (RFC 7591 Section 3.2.2).
enum ErrorCode {
ERROR_CODE_UNSPECIFIED = 0;
// The value of one or more redirection URIs is invalid.
ERROR_CODE_INVALID_REDIRECT_URI = 1;
// The value of one of the client metadata fields is invalid.
ERROR_CODE_INVALID_CLIENT_METADATA = 2;
// The software statement presented is invalid.
ERROR_CODE_INVALID_SOFTWARE_STATEMENT = 3;
// The software statement presented is not approved for use by this server.
ERROR_CODE_UNAPPROVED_SOFTWARE_STATEMENT = 4;
}
// Represents the error response from the Client Registration Endpoint (RFC 7591
// Section 3.2.2).
message ClientRegistrationErrorResponse {
// REQUIRED. Single ASCII error code string from the ErrorCode enum.
ErrorCode error = 1 [
(buf.validate.field).required = true,
(buf.validate.field).enum = {
defined_only : true,
not_in : [ 0 ]
}
];
// OPTIONAL. Human-readable ASCII text description of the error.
optional string error_description = 2
[ (buf.validate.field).string = {min_len : 1, max_len : 1024} ];
}