mcp: authorize request (pt1) (#5585)

This commit is contained in:
Denis Mishin 2025-04-24 14:59:12 -04:00 committed by GitHub
parent b566661353
commit 63ccf6ab93
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
15 changed files with 719 additions and 2 deletions

View file

@ -292,6 +292,7 @@ var internalPathsNeedingLogin = set.From([]string{
"/.pomerium/webauthn", "/.pomerium/webauthn",
"/.pomerium/routes", "/.pomerium/routes",
"/.pomerium/api/v1/routes", "/.pomerium/api/v1/routes",
"/.pomerium/mcp/authorize",
}) })
func (e *Evaluator) evaluateInternal(_ context.Context, req *Request) (*PolicyResponse, error) { func (e *Evaluator) evaluateInternal(_ context.Context, req *Request) (*PolicyResponse, error) {

View file

@ -1,10 +1,87 @@
package mcp package mcp
import ( import (
"errors"
"fmt"
"net/http" "net/http"
"github.com/go-jose/go-jose/v3/jwt"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"github.com/pomerium/pomerium/internal/httputil"
"github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/oauth21"
) )
// Authorize handles the /authorize endpoint. // Authorize handles the /authorize endpoint.
func (srv *Handler) Authorize(w http.ResponseWriter, _ *http.Request) { func (srv *Handler) Authorize(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNotImplemented) if r.Method != http.MethodGet {
http.Error(w, "invalid method", http.StatusMethodNotAllowed)
return
}
ctx := r.Context()
sessionID, err := getSessionFromRequest(r)
if err != nil {
log.Ctx(ctx).Error().Err(err).Msg("session is not present, this is a misconfigured request")
http.Error(w, "internal server error", http.StatusInternalServerError)
return
}
v, err := oauth21.ParseCodeGrantAuthorizeRequest(r, sessionID)
if err != nil {
log.Ctx(ctx).Error().Err(err).Msg("failed to parse authorization request")
oauth21.ErrorResponse(w, http.StatusBadRequest, oauth21.InvalidRequest)
return
}
client, err := srv.storage.GetClientByID(ctx, v.ClientId)
if err != nil && status.Code(err) == codes.NotFound {
oauth21.ErrorResponse(w, http.StatusUnauthorized, oauth21.InvalidClient)
return
}
if err != nil {
log.Ctx(ctx).Error().Err(err).Msg("failed to get client")
http.Error(w, "cannot fetch client", http.StatusInternalServerError)
return
}
if err := oauth21.ValidateAuthorizationRequest(client, v); err != nil {
log.Ctx(ctx).Error().Err(err).Msg("failed to validate authorization request")
ve := oauth21.Error{Code: oauth21.InvalidRequest}
_ = errors.As(err, &ve)
oauth21.ErrorResponse(w, http.StatusBadRequest, ve.Code)
return
}
_, err = srv.storage.CreateAuthorizationRequest(ctx, v)
if err != nil {
log.Ctx(ctx).Error().Err(err).Msg("failed to store authorization request")
http.Error(w, "cannot create authorization request", http.StatusInternalServerError)
return
}
http.Error(w, "not implemented", http.StatusNotImplemented)
}
func getSessionFromRequest(r *http.Request) (string, error) {
h := r.Header.Get(httputil.HeaderPomeriumJWTAssertion)
if h == "" {
return "", fmt.Errorf("missing %s header", httputil.HeaderPomeriumJWTAssertion)
}
token, err := jwt.ParseSigned(h)
if err != nil {
return "", fmt.Errorf("failed to parse JWT: %w", err)
}
var m map[string]any
_ = token.UnsafeClaimsWithoutVerification(&m)
sessionID, ok := m["sid"].(string)
if !ok {
return "", fmt.Errorf("missing session ID in JWT")
}
return sessionID, nil
} }

View file

@ -2,9 +2,13 @@ package mcp
import ( import (
"context" "context"
"fmt"
"github.com/google/uuid" "github.com/google/uuid"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/known/anypb"
oauth21proto "github.com/pomerium/pomerium/internal/oauth21/gen"
rfc7591v1 "github.com/pomerium/pomerium/internal/rfc7591" rfc7591v1 "github.com/pomerium/pomerium/internal/rfc7591"
"github.com/pomerium/pomerium/pkg/grpc/databroker" "github.com/pomerium/pomerium/pkg/grpc/databroker"
"github.com/pomerium/pomerium/pkg/protoutil" "github.com/pomerium/pomerium/pkg/protoutil"
@ -41,3 +45,43 @@ func (storage *Storage) RegisterClient(
} }
return id, nil return id, nil
} }
func (storage *Storage) GetClientByID(
ctx context.Context,
id string,
) (*rfc7591v1.ClientMetadata, error) {
v := new(rfc7591v1.ClientMetadata)
rec, err := storage.client.Get(ctx, &databroker.GetRequest{
Type: protoutil.GetTypeURL(v),
Id: id,
})
if err != nil {
return nil, fmt.Errorf("failed to get client by ID: %w", err)
}
err = anypb.UnmarshalTo(rec.Record.Data, v, proto.UnmarshalOptions{})
if err != nil {
return nil, fmt.Errorf("failed to unmarshal client registration request: %w", err)
}
return v, nil
}
func (storage *Storage) CreateAuthorizationRequest(
ctx context.Context,
req *oauth21proto.AuthorizationRequest,
) (string, error) {
data := protoutil.NewAny(req)
id := uuid.NewString()
_, err := storage.client.Put(ctx, &databroker.PutRequest{
Records: []*databroker.Record{{
Id: id,
Data: data,
Type: data.TypeUrl,
}},
})
if err != nil {
return "", err
}
return id, nil
}

View file

@ -0,0 +1,64 @@
package mcp_test
import (
"context"
"net"
"testing"
"time"
"github.com/stretchr/testify/require"
"go.opentelemetry.io/otel/trace/noop"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/test/bufconn"
"github.com/pomerium/pomerium/internal/databroker"
"github.com/pomerium/pomerium/internal/mcp"
rfc7591v1 "github.com/pomerium/pomerium/internal/rfc7591"
"github.com/pomerium/pomerium/internal/testutil"
databroker_grpc "github.com/pomerium/pomerium/pkg/grpc/databroker"
)
func TestStorage(t *testing.T) {
t.Parallel()
ctx := testutil.GetContext(t, time.Minute*5)
list := bufconn.Listen(1024 * 1024)
t.Cleanup(func() {
list.Close()
})
srv := databroker.New(ctx, noop.NewTracerProvider())
grpcServer := grpc.NewServer()
databroker_grpc.RegisterDataBrokerServiceServer(grpcServer, srv)
go func() {
if err := grpcServer.Serve(list); err != nil {
t.Errorf("failed to serve: %v", err)
}
}()
t.Cleanup(func() {
grpcServer.Stop()
})
conn, err := grpc.DialContext(ctx, "bufnet",
grpc.WithContextDialer(func(context.Context, string) (net.Conn, error) {
return list.Dial()
}),
grpc.WithTransportCredentials(insecure.NewCredentials()))
require.NoError(t, err)
client := databroker_grpc.NewDataBrokerServiceClient(conn)
t.Run("client registration", func(t *testing.T) {
storage := mcp.NewStorage(client)
id, err := storage.RegisterClient(ctx, &rfc7591v1.ClientMetadata{})
require.NoError(t, err)
require.NotEmpty(t, id)
_, err = storage.GetClientByID(ctx, id)
require.NoError(t, err)
})
}

View file

@ -0,0 +1,35 @@
package oauth21
import (
"fmt"
"net/http"
"github.com/bufbuild/protovalidate-go"
"github.com/pomerium/pomerium/internal/oauth21/gen"
)
// ParseCodeGrantAuthorizeRequest parses the authorization request for the code grant flow.
// see https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-12#section-4.1.1
// scopes are ignored
func ParseCodeGrantAuthorizeRequest(r *http.Request, sessionID string) (*gen.AuthorizationRequest, error) {
if err := r.ParseForm(); err != nil {
return nil, fmt.Errorf("failed to parse form: %w", err)
}
v := &gen.AuthorizationRequest{
ClientId: r.Form.Get("client_id"),
RedirectUri: optionalFormParam(r, "redirect_uri"),
ResponseType: r.Form.Get("response_type"),
State: optionalFormParam(r, "state"),
CodeChallenge: r.Form.Get("code_challenge"),
CodeChallengeMethod: optionalFormParam(r, "code_challenge_method"),
SessionId: sessionID,
}
if err := protovalidate.Validate(v); err != nil {
return nil, fmt.Errorf("invalid request: %w", err)
}
return v, nil
}

View file

@ -0,0 +1,19 @@
version: v2
inputs:
- directory: proto
plugins:
- remote: buf.build/protocolbuffers/go:v1.36.5
out: gen
opt:
- paths=source_relative
managed:
enabled: true
override:
- file_option: go_package_prefix
value: github.com/bufbuild/buf-examples/protovalidate/quickstart-go/start/gen
# Don't modify any file option or field option for protovalidate. Without
# this, generated Go will fail to compile.
disable:
- file_option: go_package
module: buf.build/bufbuild/protovalidate

View file

@ -0,0 +1,6 @@
# Generated by buf. DO NOT EDIT.
version: v2
deps:
- name: buf.build/bufbuild/protovalidate
commit: 7712fb530c574b95bc1d57c0877543c3
digest: b5:b3e9c9428384357e3b73e4d5a4614328b0a4b1595b10163bbe9483fa16204749274c41797bd49b0d716479c855aa35c1172a94f471fa120ba8369637fd138829

12
internal/oauth21/buf.yaml Normal file
View file

@ -0,0 +1,12 @@
version: v2
modules:
- path: proto
deps:
- buf.build/bufbuild/protovalidate
lint:
use:
- STANDARD
breaking:
use:
- FILE

42
internal/oauth21/error.go Normal file
View file

@ -0,0 +1,42 @@
package oauth21
import (
"encoding/json"
"net/http"
)
type ErrorCode string
const (
// InvalidRequest The request is missing a required parameter, includes an unsupported parameter value (other than grant type), repeats a parameter, includes multiple credentials, utilizes more than one mechanism for authenticating the client, contains a code_verifier although no code_challenge was sent in the authorization request, or is otherwise malformed.
InvalidRequest ErrorCode = "invalid_request"
// InvalidClient Client authentication failed (e.g., unknown client, no client authentication included, or unsupported authentication method). The authorization server MAY return an HTTP 401 (Unauthorized) status code to indicate which HTTP authentication schemes are supported. If the client attempted to authenticate via the Authorization request header field, the authorization server MUST respond with an HTTP 401 (Unauthorized) status code and include the WWW-Authenticate response header field matching the authentication scheme used by the client.
InvalidClient ErrorCode = "invalid_client"
// InvalidGrant The provided authorization grant (e.g., authorization code, resource owner credentials) or refresh token is invalid, expired, revoked, does not match the redirect URI used in the authorization request, or was issued to another client.
InvalidGrant ErrorCode = "invalid_grant"
// UnauthorizedClient The authenticated client is not authorized to use this authorization grant type.
UnauthorizedClient ErrorCode = "unauthorized_client"
// UnsupportedGrantType The authorization grant type is not supported by the authorization server.
UnsupportedGrantType ErrorCode = "unsupported_grant_type"
// InvalidScope The requested scope is invalid, unknown, malformed, or exceeds the scope granted by the resource owner.
InvalidScope ErrorCode = "invalid_scope"
)
type Error struct {
Code ErrorCode `json:"error"`
Description string `json:"error_description,omitempty"`
ErrorURI string `json:"error_uri,omitempty"`
}
func (e Error) Error() string {
return string(e.Code)
}
// ErrorResponse writes an error response according to https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-12#section-3.2.4
func ErrorResponse(w http.ResponseWriter, hc int, ec ErrorCode) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(hc)
if err := json.NewEncoder(w).Encode(Error{Code: ec}); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
}
}

View file

@ -0,0 +1,239 @@
// Code generated by protoc-gen-go. DO NOT EDIT.
// versions:
// protoc-gen-go v1.36.5
// protoc (unknown)
// source: authorization_request.proto
package gen
import (
_ "buf.build/gen/go/bufbuild/protovalidate/protocolbuffers/go/buf/validate"
protoreflect "google.golang.org/protobuf/reflect/protoreflect"
protoimpl "google.golang.org/protobuf/runtime/protoimpl"
reflect "reflect"
sync "sync"
unsafe "unsafe"
)
const (
// Verify that this generated code is sufficiently up-to-date.
_ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion)
// Verify that runtime/protoimpl is sufficiently up-to-date.
_ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20)
)
// modeled based on
// https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-12#section-4.1.1
type AuthorizationRequest struct {
state protoimpl.MessageState `protogen:"open.v1"`
// The client identifier as described in Section 2.2.
ClientId string `protobuf:"bytes,1,opt,name=client_id,json=clientId,proto3" json:"client_id,omitempty"`
// OPTIONAL if only one redirect URI is registered for this client. REQUIRED
// if multiple redirict URIs are registered for this client.
RedirectUri *string `protobuf:"bytes,2,opt,name=redirect_uri,json=redirectUri,proto3,oneof" json:"redirect_uri,omitempty"`
// REQUIRED. The authorization endpoint supports different sets of request and
// response parameters. The client determines the type of flow by using a
// certain response_type value. This specification defines the value code,
// which must be used to signal that the client wants to use the authorization
// code flow.
ResponseType string `protobuf:"bytes,3,opt,name=response_type,json=responseType,proto3" json:"response_type,omitempty"`
// OPTIONAL. An opaque value used by the client to maintain state between the
// request and callback. The authorization server includes this value when
// redirecting the user agent back to the client.
State *string `protobuf:"bytes,4,opt,name=state,proto3,oneof" json:"state,omitempty"`
// OPTIONAL. The scope of the access request as described by Section 1.4.1.
Scopes []string `protobuf:"bytes,5,rep,name=scopes,proto3" json:"scopes,omitempty"`
// REQUIRED, assumes https://www.rfc-editor.org/rfc/rfc7636.html#section-4.1
CodeChallenge string `protobuf:"bytes,6,opt,name=code_challenge,json=codeChallenge,proto3" json:"code_challenge,omitempty"`
// OPTIONAL, defaults to plain if not present in the request. Code verifier
// transformation method is S256 or plain.
CodeChallengeMethod *string `protobuf:"bytes,7,opt,name=code_challenge_method,json=codeChallengeMethod,proto3,oneof" json:"code_challenge_method,omitempty"`
// session this authorization request is associated with.
// This is a Pomerium implementation specific field.
SessionId string `protobuf:"bytes,8,opt,name=session_id,json=sessionId,proto3" json:"session_id,omitempty"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
func (x *AuthorizationRequest) Reset() {
*x = AuthorizationRequest{}
mi := &file_authorization_request_proto_msgTypes[0]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
func (x *AuthorizationRequest) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*AuthorizationRequest) ProtoMessage() {}
func (x *AuthorizationRequest) ProtoReflect() protoreflect.Message {
mi := &file_authorization_request_proto_msgTypes[0]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use AuthorizationRequest.ProtoReflect.Descriptor instead.
func (*AuthorizationRequest) Descriptor() ([]byte, []int) {
return file_authorization_request_proto_rawDescGZIP(), []int{0}
}
func (x *AuthorizationRequest) GetClientId() string {
if x != nil {
return x.ClientId
}
return ""
}
func (x *AuthorizationRequest) GetRedirectUri() string {
if x != nil && x.RedirectUri != nil {
return *x.RedirectUri
}
return ""
}
func (x *AuthorizationRequest) GetResponseType() string {
if x != nil {
return x.ResponseType
}
return ""
}
func (x *AuthorizationRequest) GetState() string {
if x != nil && x.State != nil {
return *x.State
}
return ""
}
func (x *AuthorizationRequest) GetScopes() []string {
if x != nil {
return x.Scopes
}
return nil
}
func (x *AuthorizationRequest) GetCodeChallenge() string {
if x != nil {
return x.CodeChallenge
}
return ""
}
func (x *AuthorizationRequest) GetCodeChallengeMethod() string {
if x != nil && x.CodeChallengeMethod != nil {
return *x.CodeChallengeMethod
}
return ""
}
func (x *AuthorizationRequest) GetSessionId() string {
if x != nil {
return x.SessionId
}
return ""
}
var File_authorization_request_proto protoreflect.FileDescriptor
var file_authorization_request_proto_rawDesc = string([]byte{
0x0a, 0x1b, 0x61, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x5f,
0x72, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x07, 0x6f,
0x61, 0x75, 0x74, 0x68, 0x32, 0x31, 0x1a, 0x1b, 0x62, 0x75, 0x66, 0x2f, 0x76, 0x61, 0x6c, 0x69,
0x64, 0x61, 0x74, 0x65, 0x2f, 0x76, 0x61, 0x6c, 0x69, 0x64, 0x61, 0x74, 0x65, 0x2e, 0x70, 0x72,
0x6f, 0x74, 0x6f, 0x22, 0xaa, 0x03, 0x0a, 0x14, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a,
0x61, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x23, 0x0a, 0x09,
0x63, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x5f, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x42,
0x06, 0xba, 0x48, 0x03, 0xc8, 0x01, 0x01, 0x52, 0x08, 0x63, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x49,
0x64, 0x12, 0x26, 0x0a, 0x0c, 0x72, 0x65, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x5f, 0x75, 0x72,
0x69, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x48, 0x00, 0x52, 0x0b, 0x72, 0x65, 0x64, 0x69, 0x72,
0x65, 0x63, 0x74, 0x55, 0x72, 0x69, 0x88, 0x01, 0x01, 0x12, 0x33, 0x0a, 0x0d, 0x72, 0x65, 0x73,
0x70, 0x6f, 0x6e, 0x73, 0x65, 0x5f, 0x74, 0x79, 0x70, 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09,
0x42, 0x0e, 0xba, 0x48, 0x0b, 0xc8, 0x01, 0x01, 0x72, 0x06, 0x52, 0x04, 0x63, 0x6f, 0x64, 0x65,
0x52, 0x0c, 0x72, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x54, 0x79, 0x70, 0x65, 0x12, 0x19,
0x0a, 0x05, 0x73, 0x74, 0x61, 0x74, 0x65, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x48, 0x01, 0x52,
0x05, 0x73, 0x74, 0x61, 0x74, 0x65, 0x88, 0x01, 0x01, 0x12, 0x16, 0x0a, 0x06, 0x73, 0x63, 0x6f,
0x70, 0x65, 0x73, 0x18, 0x05, 0x20, 0x03, 0x28, 0x09, 0x52, 0x06, 0x73, 0x63, 0x6f, 0x70, 0x65,
0x73, 0x12, 0x34, 0x0a, 0x0e, 0x63, 0x6f, 0x64, 0x65, 0x5f, 0x63, 0x68, 0x61, 0x6c, 0x6c, 0x65,
0x6e, 0x67, 0x65, 0x18, 0x06, 0x20, 0x01, 0x28, 0x09, 0x42, 0x0d, 0xba, 0x48, 0x0a, 0xc8, 0x01,
0x01, 0x72, 0x05, 0x10, 0x2b, 0x18, 0x80, 0x01, 0x52, 0x0d, 0x63, 0x6f, 0x64, 0x65, 0x43, 0x68,
0x61, 0x6c, 0x6c, 0x65, 0x6e, 0x67, 0x65, 0x12, 0x4b, 0x0a, 0x15, 0x63, 0x6f, 0x64, 0x65, 0x5f,
0x63, 0x68, 0x61, 0x6c, 0x6c, 0x65, 0x6e, 0x67, 0x65, 0x5f, 0x6d, 0x65, 0x74, 0x68, 0x6f, 0x64,
0x18, 0x07, 0x20, 0x01, 0x28, 0x09, 0x42, 0x12, 0xba, 0x48, 0x0f, 0x72, 0x0d, 0x52, 0x04, 0x53,
0x32, 0x35, 0x36, 0x52, 0x05, 0x70, 0x6c, 0x61, 0x69, 0x6e, 0x48, 0x02, 0x52, 0x13, 0x63, 0x6f,
0x64, 0x65, 0x43, 0x68, 0x61, 0x6c, 0x6c, 0x65, 0x6e, 0x67, 0x65, 0x4d, 0x65, 0x74, 0x68, 0x6f,
0x64, 0x88, 0x01, 0x01, 0x12, 0x25, 0x0a, 0x0a, 0x73, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x5f,
0x69, 0x64, 0x18, 0x08, 0x20, 0x01, 0x28, 0x09, 0x42, 0x06, 0xba, 0x48, 0x03, 0xc8, 0x01, 0x01,
0x52, 0x09, 0x73, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x49, 0x64, 0x42, 0x0f, 0x0a, 0x0d, 0x5f,
0x72, 0x65, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x5f, 0x75, 0x72, 0x69, 0x42, 0x08, 0x0a, 0x06,
0x5f, 0x73, 0x74, 0x61, 0x74, 0x65, 0x42, 0x18, 0x0a, 0x16, 0x5f, 0x63, 0x6f, 0x64, 0x65, 0x5f,
0x63, 0x68, 0x61, 0x6c, 0x6c, 0x65, 0x6e, 0x67, 0x65, 0x5f, 0x6d, 0x65, 0x74, 0x68, 0x6f, 0x64,
0x42, 0xac, 0x01, 0x0a, 0x0b, 0x63, 0x6f, 0x6d, 0x2e, 0x6f, 0x61, 0x75, 0x74, 0x68, 0x32, 0x31,
0x42, 0x19, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x52,
0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x50, 0x01, 0x5a, 0x46, 0x67,
0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x62, 0x75, 0x66, 0x62, 0x75, 0x69,
0x6c, 0x64, 0x2f, 0x62, 0x75, 0x66, 0x2d, 0x65, 0x78, 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x73, 0x2f,
0x70, 0x72, 0x6f, 0x74, 0x6f, 0x76, 0x61, 0x6c, 0x69, 0x64, 0x61, 0x74, 0x65, 0x2f, 0x71, 0x75,
0x69, 0x63, 0x6b, 0x73, 0x74, 0x61, 0x72, 0x74, 0x2d, 0x67, 0x6f, 0x2f, 0x73, 0x74, 0x61, 0x72,
0x74, 0x2f, 0x67, 0x65, 0x6e, 0xa2, 0x02, 0x03, 0x4f, 0x58, 0x58, 0xaa, 0x02, 0x07, 0x4f, 0x61,
0x75, 0x74, 0x68, 0x32, 0x31, 0xca, 0x02, 0x07, 0x4f, 0x61, 0x75, 0x74, 0x68, 0x32, 0x31, 0xe2,
0x02, 0x13, 0x4f, 0x61, 0x75, 0x74, 0x68, 0x32, 0x31, 0x5c, 0x47, 0x50, 0x42, 0x4d, 0x65, 0x74,
0x61, 0x64, 0x61, 0x74, 0x61, 0xea, 0x02, 0x07, 0x4f, 0x61, 0x75, 0x74, 0x68, 0x32, 0x31, 0x62,
0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
})
var (
file_authorization_request_proto_rawDescOnce sync.Once
file_authorization_request_proto_rawDescData []byte
)
func file_authorization_request_proto_rawDescGZIP() []byte {
file_authorization_request_proto_rawDescOnce.Do(func() {
file_authorization_request_proto_rawDescData = protoimpl.X.CompressGZIP(unsafe.Slice(unsafe.StringData(file_authorization_request_proto_rawDesc), len(file_authorization_request_proto_rawDesc)))
})
return file_authorization_request_proto_rawDescData
}
var file_authorization_request_proto_msgTypes = make([]protoimpl.MessageInfo, 1)
var file_authorization_request_proto_goTypes = []any{
(*AuthorizationRequest)(nil), // 0: oauth21.AuthorizationRequest
}
var file_authorization_request_proto_depIdxs = []int32{
0, // [0:0] is the sub-list for method output_type
0, // [0:0] is the sub-list for method input_type
0, // [0:0] is the sub-list for extension type_name
0, // [0:0] is the sub-list for extension extendee
0, // [0:0] is the sub-list for field type_name
}
func init() { file_authorization_request_proto_init() }
func file_authorization_request_proto_init() {
if File_authorization_request_proto != nil {
return
}
file_authorization_request_proto_msgTypes[0].OneofWrappers = []any{}
type x struct{}
out := protoimpl.TypeBuilder{
File: protoimpl.DescBuilder{
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
RawDescriptor: unsafe.Slice(unsafe.StringData(file_authorization_request_proto_rawDesc), len(file_authorization_request_proto_rawDesc)),
NumEnums: 0,
NumMessages: 1,
NumExtensions: 0,
NumServices: 0,
},
GoTypes: file_authorization_request_proto_goTypes,
DependencyIndexes: file_authorization_request_proto_depIdxs,
MessageInfos: file_authorization_request_proto_msgTypes,
}.Build()
File_authorization_request_proto = out.File
file_authorization_request_proto_goTypes = nil
file_authorization_request_proto_depIdxs = nil
}

View file

@ -0,0 +1,4 @@
package oauth21
//go:generate go run github.com/bufbuild/buf/cmd/buf@v1.53.0 dep update
//go:generate go run github.com/bufbuild/buf/cmd/buf@v1.53.0 generate

10
internal/oauth21/parse.go Normal file
View file

@ -0,0 +1,10 @@
package oauth21
import "net/http"
func optionalFormParam(r *http.Request, key string) *string {
if v := r.FormValue(key); v != "" {
return &v
}
return nil
}

View file

@ -0,0 +1,51 @@
syntax = "proto3";
package oauth21;
import "buf/validate/validate.proto";
option go_package = "github.com/pomerium/pomerium/internal/oauth21/gen";
// modeled based on
// https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-12#section-4.1.1
message AuthorizationRequest {
// The client identifier as described in Section 2.2.
string client_id = 1 [ (buf.validate.field).required = true ];
// OPTIONAL if only one redirect URI is registered for this client. REQUIRED
// if multiple redirict URIs are registered for this client.
optional string redirect_uri = 2;
// REQUIRED. The authorization endpoint supports different sets of request and
// response parameters. The client determines the type of flow by using a
// certain response_type value. This specification defines the value code,
// which must be used to signal that the client wants to use the authorization
// code flow.
string response_type = 3 [
(buf.validate.field).required = true,
(buf.validate.field).string = {in : [ "code" ]}
];
// OPTIONAL. An opaque value used by the client to maintain state between the
// request and callback. The authorization server includes this value when
// redirecting the user agent back to the client.
optional string state = 4;
// OPTIONAL. The scope of the access request as described by Section 1.4.1.
repeated string scopes = 5;
// REQUIRED, assumes https://www.rfc-editor.org/rfc/rfc7636.html#section-4.1
string code_challenge = 6 [
(buf.validate.field).required = true,
(buf.validate.field).string = {min_len : 43, max_len : 128}
];
// OPTIONAL, defaults to plain if not present in the request. Code verifier
// transformation method is S256 or plain.
optional string code_challenge_method = 7
[ (buf.validate.field).string = {in : [ "S256", "plain" ]} ];
// session this authorization request is associated with.
// This is a Pomerium implementation specific field.
string session_id = 8 [ (buf.validate.field).required = true ];
}

View file

@ -0,0 +1,40 @@
package oauth21
import (
"slices"
"github.com/pomerium/pomerium/internal/oauth21/gen"
rfc7591v1 "github.com/pomerium/pomerium/internal/rfc7591"
)
func ValidateAuthorizationRequest(
client *rfc7591v1.ClientMetadata,
req *gen.AuthorizationRequest,
) error {
if err := ValidateAuthorizationRequestRedirectURI(client, req.RedirectUri); err != nil {
return err
}
return nil
}
func ValidateAuthorizationRequestRedirectURI(
client *rfc7591v1.ClientMetadata,
redirectURI *string,
) error {
if len(client.RedirectUris) == 0 {
return Error{Code: InvalidClient, Description: "client has no redirect URIs"}
}
if redirectURI == nil {
if len(client.RedirectUris) != 1 {
return Error{Code: InvalidRequest, Description: "client has multiple redirect URIs and none were provided"}
}
return nil
}
if !slices.Contains(client.RedirectUris, *redirectURI) {
return Error{Code: InvalidGrant, Description: "client redirect URI does not match registered redirect URIs"}
}
return nil
}

View file

@ -0,0 +1,73 @@
package oauth21_test
import (
"testing"
"github.com/stretchr/testify/assert"
"google.golang.org/protobuf/proto"
"github.com/pomerium/pomerium/internal/oauth21"
"github.com/pomerium/pomerium/internal/oauth21/gen"
rfc7591v1 "github.com/pomerium/pomerium/internal/rfc7591"
)
func TestValidateRequest(t *testing.T) {
t.Parallel()
for _, tc := range []struct {
name string
client *rfc7591v1.ClientMetadata
req *gen.AuthorizationRequest
err bool
}{
{
"optional redirect_uri, multiple redirect_uris",
&rfc7591v1.ClientMetadata{
RedirectUris: []string{"https://example.com/callback", "https://example.com/other-callback"},
},
&gen.AuthorizationRequest{
RedirectUri: nil,
},
true,
},
{
"optional redirect_uri, single redirect_uri",
&rfc7591v1.ClientMetadata{
RedirectUris: []string{"https://example.com/callback"},
},
&gen.AuthorizationRequest{
RedirectUri: nil,
},
false,
},
{
"matching redirect_uri",
&rfc7591v1.ClientMetadata{
RedirectUris: []string{"https://example.com/callback", "https://example.com/other-callback"},
},
&gen.AuthorizationRequest{
RedirectUri: proto.String("https://example.com/callback"),
},
false,
},
{
"non-matching redirect_uri",
&rfc7591v1.ClientMetadata{
RedirectUris: []string{"https://example.com/callback", "https://example.com/other-callback"},
},
&gen.AuthorizationRequest{
RedirectUri: proto.String("https://example.com/invalid-callback"),
},
true,
},
} {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
err := oauth21.ValidateAuthorizationRequest(tc.client, tc.req)
if tc.err {
assert.Error(t, err)
} else {
assert.NoError(t, err)
}
})
}
}