From 63ccf6ab93d7abc8c8f102f817acbc4e53c4e781 Mon Sep 17 00:00:00 2001 From: Denis Mishin Date: Thu, 24 Apr 2025 14:59:12 -0400 Subject: [PATCH] mcp: authorize request (pt1) (#5585) --- authorize/evaluator/evaluator.go | 1 + internal/mcp/handler_authorization.go | 81 +++++- internal/mcp/storage.go | 44 ++++ internal/mcp/storage_test.go | 64 +++++ internal/oauth21/authorize.go | 35 +++ internal/oauth21/buf.gen.yaml | 19 ++ internal/oauth21/buf.lock | 6 + internal/oauth21/buf.yaml | 12 + internal/oauth21/error.go | 42 +++ .../oauth21/gen/authorization_request.pb.go | 239 ++++++++++++++++++ internal/oauth21/generate.go | 4 + internal/oauth21/parse.go | 10 + .../oauth21/proto/authorization_request.proto | 51 ++++ internal/oauth21/validate_client.go | 40 +++ internal/oauth21/validate_client_test.go | 73 ++++++ 15 files changed, 719 insertions(+), 2 deletions(-) create mode 100644 internal/mcp/storage_test.go create mode 100644 internal/oauth21/authorize.go create mode 100644 internal/oauth21/buf.gen.yaml create mode 100644 internal/oauth21/buf.lock create mode 100644 internal/oauth21/buf.yaml create mode 100644 internal/oauth21/error.go create mode 100644 internal/oauth21/gen/authorization_request.pb.go create mode 100644 internal/oauth21/generate.go create mode 100644 internal/oauth21/parse.go create mode 100644 internal/oauth21/proto/authorization_request.proto create mode 100644 internal/oauth21/validate_client.go create mode 100644 internal/oauth21/validate_client_test.go diff --git a/authorize/evaluator/evaluator.go b/authorize/evaluator/evaluator.go index e0d5b7026..b6ebe7c58 100644 --- a/authorize/evaluator/evaluator.go +++ b/authorize/evaluator/evaluator.go @@ -292,6 +292,7 @@ var internalPathsNeedingLogin = set.From([]string{ "/.pomerium/webauthn", "/.pomerium/routes", "/.pomerium/api/v1/routes", + "/.pomerium/mcp/authorize", }) func (e *Evaluator) evaluateInternal(_ context.Context, req *Request) (*PolicyResponse, error) { diff --git a/internal/mcp/handler_authorization.go b/internal/mcp/handler_authorization.go index 1197c99e6..e15fdbb00 100644 --- a/internal/mcp/handler_authorization.go +++ b/internal/mcp/handler_authorization.go @@ -1,10 +1,87 @@ package mcp import ( + "errors" + "fmt" "net/http" + + "github.com/go-jose/go-jose/v3/jwt" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + + "github.com/pomerium/pomerium/internal/httputil" + "github.com/pomerium/pomerium/internal/log" + "github.com/pomerium/pomerium/internal/oauth21" ) // Authorize handles the /authorize endpoint. -func (srv *Handler) Authorize(w http.ResponseWriter, _ *http.Request) { - w.WriteHeader(http.StatusNotImplemented) +func (srv *Handler) Authorize(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + http.Error(w, "invalid method", http.StatusMethodNotAllowed) + return + } + + ctx := r.Context() + + sessionID, err := getSessionFromRequest(r) + if err != nil { + log.Ctx(ctx).Error().Err(err).Msg("session is not present, this is a misconfigured request") + http.Error(w, "internal server error", http.StatusInternalServerError) + return + } + + v, err := oauth21.ParseCodeGrantAuthorizeRequest(r, sessionID) + if err != nil { + log.Ctx(ctx).Error().Err(err).Msg("failed to parse authorization request") + oauth21.ErrorResponse(w, http.StatusBadRequest, oauth21.InvalidRequest) + return + } + + client, err := srv.storage.GetClientByID(ctx, v.ClientId) + if err != nil && status.Code(err) == codes.NotFound { + oauth21.ErrorResponse(w, http.StatusUnauthorized, oauth21.InvalidClient) + return + } + if err != nil { + log.Ctx(ctx).Error().Err(err).Msg("failed to get client") + http.Error(w, "cannot fetch client", http.StatusInternalServerError) + return + } + + if err := oauth21.ValidateAuthorizationRequest(client, v); err != nil { + log.Ctx(ctx).Error().Err(err).Msg("failed to validate authorization request") + ve := oauth21.Error{Code: oauth21.InvalidRequest} + _ = errors.As(err, &ve) + oauth21.ErrorResponse(w, http.StatusBadRequest, ve.Code) + return + } + + _, err = srv.storage.CreateAuthorizationRequest(ctx, v) + if err != nil { + log.Ctx(ctx).Error().Err(err).Msg("failed to store authorization request") + http.Error(w, "cannot create authorization request", http.StatusInternalServerError) + return + } + + http.Error(w, "not implemented", http.StatusNotImplemented) +} + +func getSessionFromRequest(r *http.Request) (string, error) { + h := r.Header.Get(httputil.HeaderPomeriumJWTAssertion) + if h == "" { + return "", fmt.Errorf("missing %s header", httputil.HeaderPomeriumJWTAssertion) + } + + token, err := jwt.ParseSigned(h) + if err != nil { + return "", fmt.Errorf("failed to parse JWT: %w", err) + } + var m map[string]any + _ = token.UnsafeClaimsWithoutVerification(&m) + sessionID, ok := m["sid"].(string) + if !ok { + return "", fmt.Errorf("missing session ID in JWT") + } + + return sessionID, nil } diff --git a/internal/mcp/storage.go b/internal/mcp/storage.go index 77d9ff821..7d988a845 100644 --- a/internal/mcp/storage.go +++ b/internal/mcp/storage.go @@ -2,9 +2,13 @@ package mcp import ( "context" + "fmt" "github.com/google/uuid" + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/types/known/anypb" + oauth21proto "github.com/pomerium/pomerium/internal/oauth21/gen" rfc7591v1 "github.com/pomerium/pomerium/internal/rfc7591" "github.com/pomerium/pomerium/pkg/grpc/databroker" "github.com/pomerium/pomerium/pkg/protoutil" @@ -41,3 +45,43 @@ func (storage *Storage) RegisterClient( } return id, nil } + +func (storage *Storage) GetClientByID( + ctx context.Context, + id string, +) (*rfc7591v1.ClientMetadata, error) { + v := new(rfc7591v1.ClientMetadata) + rec, err := storage.client.Get(ctx, &databroker.GetRequest{ + Type: protoutil.GetTypeURL(v), + Id: id, + }) + if err != nil { + return nil, fmt.Errorf("failed to get client by ID: %w", err) + } + + err = anypb.UnmarshalTo(rec.Record.Data, v, proto.UnmarshalOptions{}) + if err != nil { + return nil, fmt.Errorf("failed to unmarshal client registration request: %w", err) + } + + return v, nil +} + +func (storage *Storage) CreateAuthorizationRequest( + ctx context.Context, + req *oauth21proto.AuthorizationRequest, +) (string, error) { + data := protoutil.NewAny(req) + id := uuid.NewString() + _, err := storage.client.Put(ctx, &databroker.PutRequest{ + Records: []*databroker.Record{{ + Id: id, + Data: data, + Type: data.TypeUrl, + }}, + }) + if err != nil { + return "", err + } + return id, nil +} diff --git a/internal/mcp/storage_test.go b/internal/mcp/storage_test.go new file mode 100644 index 000000000..28b709165 --- /dev/null +++ b/internal/mcp/storage_test.go @@ -0,0 +1,64 @@ +package mcp_test + +import ( + "context" + "net" + "testing" + "time" + + "github.com/stretchr/testify/require" + "go.opentelemetry.io/otel/trace/noop" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials/insecure" + "google.golang.org/grpc/test/bufconn" + + "github.com/pomerium/pomerium/internal/databroker" + "github.com/pomerium/pomerium/internal/mcp" + rfc7591v1 "github.com/pomerium/pomerium/internal/rfc7591" + "github.com/pomerium/pomerium/internal/testutil" + databroker_grpc "github.com/pomerium/pomerium/pkg/grpc/databroker" +) + +func TestStorage(t *testing.T) { + t.Parallel() + + ctx := testutil.GetContext(t, time.Minute*5) + + list := bufconn.Listen(1024 * 1024) + t.Cleanup(func() { + list.Close() + }) + + srv := databroker.New(ctx, noop.NewTracerProvider()) + grpcServer := grpc.NewServer() + databroker_grpc.RegisterDataBrokerServiceServer(grpcServer, srv) + + go func() { + if err := grpcServer.Serve(list); err != nil { + t.Errorf("failed to serve: %v", err) + } + }() + t.Cleanup(func() { + grpcServer.Stop() + }) + + conn, err := grpc.DialContext(ctx, "bufnet", + grpc.WithContextDialer(func(context.Context, string) (net.Conn, error) { + return list.Dial() + }), + grpc.WithTransportCredentials(insecure.NewCredentials())) + require.NoError(t, err) + + client := databroker_grpc.NewDataBrokerServiceClient(conn) + + t.Run("client registration", func(t *testing.T) { + storage := mcp.NewStorage(client) + + id, err := storage.RegisterClient(ctx, &rfc7591v1.ClientMetadata{}) + require.NoError(t, err) + require.NotEmpty(t, id) + + _, err = storage.GetClientByID(ctx, id) + require.NoError(t, err) + }) +} diff --git a/internal/oauth21/authorize.go b/internal/oauth21/authorize.go new file mode 100644 index 000000000..bd0fc0d4d --- /dev/null +++ b/internal/oauth21/authorize.go @@ -0,0 +1,35 @@ +package oauth21 + +import ( + "fmt" + "net/http" + + "github.com/bufbuild/protovalidate-go" + + "github.com/pomerium/pomerium/internal/oauth21/gen" +) + +// ParseCodeGrantAuthorizeRequest parses the authorization request for the code grant flow. +// see https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-12#section-4.1.1 +// scopes are ignored +func ParseCodeGrantAuthorizeRequest(r *http.Request, sessionID string) (*gen.AuthorizationRequest, error) { + if err := r.ParseForm(); err != nil { + return nil, fmt.Errorf("failed to parse form: %w", err) + } + + v := &gen.AuthorizationRequest{ + ClientId: r.Form.Get("client_id"), + RedirectUri: optionalFormParam(r, "redirect_uri"), + ResponseType: r.Form.Get("response_type"), + State: optionalFormParam(r, "state"), + CodeChallenge: r.Form.Get("code_challenge"), + CodeChallengeMethod: optionalFormParam(r, "code_challenge_method"), + SessionId: sessionID, + } + + if err := protovalidate.Validate(v); err != nil { + return nil, fmt.Errorf("invalid request: %w", err) + } + + return v, nil +} diff --git a/internal/oauth21/buf.gen.yaml b/internal/oauth21/buf.gen.yaml new file mode 100644 index 000000000..068b0bde9 --- /dev/null +++ b/internal/oauth21/buf.gen.yaml @@ -0,0 +1,19 @@ + +version: v2 +inputs: + - directory: proto +plugins: + - remote: buf.build/protocolbuffers/go:v1.36.5 + out: gen + opt: + - paths=source_relative +managed: + enabled: true + override: + - file_option: go_package_prefix + value: github.com/bufbuild/buf-examples/protovalidate/quickstart-go/start/gen + # Don't modify any file option or field option for protovalidate. Without + # this, generated Go will fail to compile. + disable: + - file_option: go_package + module: buf.build/bufbuild/protovalidate diff --git a/internal/oauth21/buf.lock b/internal/oauth21/buf.lock new file mode 100644 index 000000000..09123d359 --- /dev/null +++ b/internal/oauth21/buf.lock @@ -0,0 +1,6 @@ +# Generated by buf. DO NOT EDIT. +version: v2 +deps: + - name: buf.build/bufbuild/protovalidate + commit: 7712fb530c574b95bc1d57c0877543c3 + digest: b5:b3e9c9428384357e3b73e4d5a4614328b0a4b1595b10163bbe9483fa16204749274c41797bd49b0d716479c855aa35c1172a94f471fa120ba8369637fd138829 diff --git a/internal/oauth21/buf.yaml b/internal/oauth21/buf.yaml new file mode 100644 index 000000000..67c10b231 --- /dev/null +++ b/internal/oauth21/buf.yaml @@ -0,0 +1,12 @@ + +version: v2 +modules: + - path: proto +deps: + - buf.build/bufbuild/protovalidate +lint: + use: + - STANDARD +breaking: + use: + - FILE diff --git a/internal/oauth21/error.go b/internal/oauth21/error.go new file mode 100644 index 000000000..295690b7a --- /dev/null +++ b/internal/oauth21/error.go @@ -0,0 +1,42 @@ +package oauth21 + +import ( + "encoding/json" + "net/http" +) + +type ErrorCode string + +const ( + // InvalidRequest The request is missing a required parameter, includes an unsupported parameter value (other than grant type), repeats a parameter, includes multiple credentials, utilizes more than one mechanism for authenticating the client, contains a code_verifier although no code_challenge was sent in the authorization request, or is otherwise malformed. + InvalidRequest ErrorCode = "invalid_request" + // InvalidClient Client authentication failed (e.g., unknown client, no client authentication included, or unsupported authentication method). The authorization server MAY return an HTTP 401 (Unauthorized) status code to indicate which HTTP authentication schemes are supported. If the client attempted to authenticate via the Authorization request header field, the authorization server MUST respond with an HTTP 401 (Unauthorized) status code and include the WWW-Authenticate response header field matching the authentication scheme used by the client. + InvalidClient ErrorCode = "invalid_client" + // InvalidGrant The provided authorization grant (e.g., authorization code, resource owner credentials) or refresh token is invalid, expired, revoked, does not match the redirect URI used in the authorization request, or was issued to another client. + InvalidGrant ErrorCode = "invalid_grant" + // UnauthorizedClient The authenticated client is not authorized to use this authorization grant type. + UnauthorizedClient ErrorCode = "unauthorized_client" + // UnsupportedGrantType The authorization grant type is not supported by the authorization server. + UnsupportedGrantType ErrorCode = "unsupported_grant_type" + // InvalidScope The requested scope is invalid, unknown, malformed, or exceeds the scope granted by the resource owner. + InvalidScope ErrorCode = "invalid_scope" +) + +type Error struct { + Code ErrorCode `json:"error"` + Description string `json:"error_description,omitempty"` + ErrorURI string `json:"error_uri,omitempty"` +} + +func (e Error) Error() string { + return string(e.Code) +} + +// ErrorResponse writes an error response according to https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-12#section-3.2.4 +func ErrorResponse(w http.ResponseWriter, hc int, ec ErrorCode) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(hc) + if err := json.NewEncoder(w).Encode(Error{Code: ec}); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + } +} diff --git a/internal/oauth21/gen/authorization_request.pb.go b/internal/oauth21/gen/authorization_request.pb.go new file mode 100644 index 000000000..5f465c9e7 --- /dev/null +++ b/internal/oauth21/gen/authorization_request.pb.go @@ -0,0 +1,239 @@ +// Code generated by protoc-gen-go. DO NOT EDIT. +// versions: +// protoc-gen-go v1.36.5 +// protoc (unknown) +// source: authorization_request.proto + +package gen + +import ( + _ "buf.build/gen/go/bufbuild/protovalidate/protocolbuffers/go/buf/validate" + protoreflect "google.golang.org/protobuf/reflect/protoreflect" + protoimpl "google.golang.org/protobuf/runtime/protoimpl" + reflect "reflect" + sync "sync" + unsafe "unsafe" +) + +const ( + // Verify that this generated code is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) + // Verify that runtime/protoimpl is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) +) + +// modeled based on +// https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-12#section-4.1.1 +type AuthorizationRequest struct { + state protoimpl.MessageState `protogen:"open.v1"` + // The client identifier as described in Section 2.2. + ClientId string `protobuf:"bytes,1,opt,name=client_id,json=clientId,proto3" json:"client_id,omitempty"` + // OPTIONAL if only one redirect URI is registered for this client. REQUIRED + // if multiple redirict URIs are registered for this client. + RedirectUri *string `protobuf:"bytes,2,opt,name=redirect_uri,json=redirectUri,proto3,oneof" json:"redirect_uri,omitempty"` + // REQUIRED. The authorization endpoint supports different sets of request and + // response parameters. The client determines the type of flow by using a + // certain response_type value. This specification defines the value code, + // which must be used to signal that the client wants to use the authorization + // code flow. + ResponseType string `protobuf:"bytes,3,opt,name=response_type,json=responseType,proto3" json:"response_type,omitempty"` + // OPTIONAL. An opaque value used by the client to maintain state between the + // request and callback. The authorization server includes this value when + // redirecting the user agent back to the client. + State *string `protobuf:"bytes,4,opt,name=state,proto3,oneof" json:"state,omitempty"` + // OPTIONAL. The scope of the access request as described by Section 1.4.1. + Scopes []string `protobuf:"bytes,5,rep,name=scopes,proto3" json:"scopes,omitempty"` + // REQUIRED, assumes https://www.rfc-editor.org/rfc/rfc7636.html#section-4.1 + CodeChallenge string `protobuf:"bytes,6,opt,name=code_challenge,json=codeChallenge,proto3" json:"code_challenge,omitempty"` + // OPTIONAL, defaults to plain if not present in the request. Code verifier + // transformation method is S256 or plain. + CodeChallengeMethod *string `protobuf:"bytes,7,opt,name=code_challenge_method,json=codeChallengeMethod,proto3,oneof" json:"code_challenge_method,omitempty"` + // session this authorization request is associated with. + // This is a Pomerium implementation specific field. + SessionId string `protobuf:"bytes,8,opt,name=session_id,json=sessionId,proto3" json:"session_id,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *AuthorizationRequest) Reset() { + *x = AuthorizationRequest{} + mi := &file_authorization_request_proto_msgTypes[0] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *AuthorizationRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*AuthorizationRequest) ProtoMessage() {} + +func (x *AuthorizationRequest) ProtoReflect() protoreflect.Message { + mi := &file_authorization_request_proto_msgTypes[0] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use AuthorizationRequest.ProtoReflect.Descriptor instead. +func (*AuthorizationRequest) Descriptor() ([]byte, []int) { + return file_authorization_request_proto_rawDescGZIP(), []int{0} +} + +func (x *AuthorizationRequest) GetClientId() string { + if x != nil { + return x.ClientId + } + return "" +} + +func (x *AuthorizationRequest) GetRedirectUri() string { + if x != nil && x.RedirectUri != nil { + return *x.RedirectUri + } + return "" +} + +func (x *AuthorizationRequest) GetResponseType() string { + if x != nil { + return x.ResponseType + } + return "" +} + +func (x *AuthorizationRequest) GetState() string { + if x != nil && x.State != nil { + return *x.State + } + return "" +} + +func (x *AuthorizationRequest) GetScopes() []string { + if x != nil { + return x.Scopes + } + return nil +} + +func (x *AuthorizationRequest) GetCodeChallenge() string { + if x != nil { + return x.CodeChallenge + } + return "" +} + +func (x *AuthorizationRequest) GetCodeChallengeMethod() string { + if x != nil && x.CodeChallengeMethod != nil { + return *x.CodeChallengeMethod + } + return "" +} + +func (x *AuthorizationRequest) GetSessionId() string { + if x != nil { + return x.SessionId + } + return "" +} + +var File_authorization_request_proto protoreflect.FileDescriptor + +var file_authorization_request_proto_rawDesc = string([]byte{ + 0x0a, 0x1b, 0x61, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x5f, + 0x72, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x07, 0x6f, + 0x61, 0x75, 0x74, 0x68, 0x32, 0x31, 0x1a, 0x1b, 0x62, 0x75, 0x66, 0x2f, 0x76, 0x61, 0x6c, 0x69, + 0x64, 0x61, 0x74, 0x65, 0x2f, 0x76, 0x61, 0x6c, 0x69, 0x64, 0x61, 0x74, 0x65, 0x2e, 0x70, 0x72, + 0x6f, 0x74, 0x6f, 0x22, 0xaa, 0x03, 0x0a, 0x14, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, + 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x23, 0x0a, 0x09, + 0x63, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x5f, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x42, + 0x06, 0xba, 0x48, 0x03, 0xc8, 0x01, 0x01, 0x52, 0x08, 0x63, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x49, + 0x64, 0x12, 0x26, 0x0a, 0x0c, 0x72, 0x65, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x5f, 0x75, 0x72, + 0x69, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x48, 0x00, 0x52, 0x0b, 0x72, 0x65, 0x64, 0x69, 0x72, + 0x65, 0x63, 0x74, 0x55, 0x72, 0x69, 0x88, 0x01, 0x01, 0x12, 0x33, 0x0a, 0x0d, 0x72, 0x65, 0x73, + 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x5f, 0x74, 0x79, 0x70, 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, + 0x42, 0x0e, 0xba, 0x48, 0x0b, 0xc8, 0x01, 0x01, 0x72, 0x06, 0x52, 0x04, 0x63, 0x6f, 0x64, 0x65, + 0x52, 0x0c, 0x72, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x54, 0x79, 0x70, 0x65, 0x12, 0x19, + 0x0a, 0x05, 0x73, 0x74, 0x61, 0x74, 0x65, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x48, 0x01, 0x52, + 0x05, 0x73, 0x74, 0x61, 0x74, 0x65, 0x88, 0x01, 0x01, 0x12, 0x16, 0x0a, 0x06, 0x73, 0x63, 0x6f, + 0x70, 0x65, 0x73, 0x18, 0x05, 0x20, 0x03, 0x28, 0x09, 0x52, 0x06, 0x73, 0x63, 0x6f, 0x70, 0x65, + 0x73, 0x12, 0x34, 0x0a, 0x0e, 0x63, 0x6f, 0x64, 0x65, 0x5f, 0x63, 0x68, 0x61, 0x6c, 0x6c, 0x65, + 0x6e, 0x67, 0x65, 0x18, 0x06, 0x20, 0x01, 0x28, 0x09, 0x42, 0x0d, 0xba, 0x48, 0x0a, 0xc8, 0x01, + 0x01, 0x72, 0x05, 0x10, 0x2b, 0x18, 0x80, 0x01, 0x52, 0x0d, 0x63, 0x6f, 0x64, 0x65, 0x43, 0x68, + 0x61, 0x6c, 0x6c, 0x65, 0x6e, 0x67, 0x65, 0x12, 0x4b, 0x0a, 0x15, 0x63, 0x6f, 0x64, 0x65, 0x5f, + 0x63, 0x68, 0x61, 0x6c, 0x6c, 0x65, 0x6e, 0x67, 0x65, 0x5f, 0x6d, 0x65, 0x74, 0x68, 0x6f, 0x64, + 0x18, 0x07, 0x20, 0x01, 0x28, 0x09, 0x42, 0x12, 0xba, 0x48, 0x0f, 0x72, 0x0d, 0x52, 0x04, 0x53, + 0x32, 0x35, 0x36, 0x52, 0x05, 0x70, 0x6c, 0x61, 0x69, 0x6e, 0x48, 0x02, 0x52, 0x13, 0x63, 0x6f, + 0x64, 0x65, 0x43, 0x68, 0x61, 0x6c, 0x6c, 0x65, 0x6e, 0x67, 0x65, 0x4d, 0x65, 0x74, 0x68, 0x6f, + 0x64, 0x88, 0x01, 0x01, 0x12, 0x25, 0x0a, 0x0a, 0x73, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x5f, + 0x69, 0x64, 0x18, 0x08, 0x20, 0x01, 0x28, 0x09, 0x42, 0x06, 0xba, 0x48, 0x03, 0xc8, 0x01, 0x01, + 0x52, 0x09, 0x73, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x49, 0x64, 0x42, 0x0f, 0x0a, 0x0d, 0x5f, + 0x72, 0x65, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x5f, 0x75, 0x72, 0x69, 0x42, 0x08, 0x0a, 0x06, + 0x5f, 0x73, 0x74, 0x61, 0x74, 0x65, 0x42, 0x18, 0x0a, 0x16, 0x5f, 0x63, 0x6f, 0x64, 0x65, 0x5f, + 0x63, 0x68, 0x61, 0x6c, 0x6c, 0x65, 0x6e, 0x67, 0x65, 0x5f, 0x6d, 0x65, 0x74, 0x68, 0x6f, 0x64, + 0x42, 0xac, 0x01, 0x0a, 0x0b, 0x63, 0x6f, 0x6d, 0x2e, 0x6f, 0x61, 0x75, 0x74, 0x68, 0x32, 0x31, + 0x42, 0x19, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x52, + 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x50, 0x01, 0x5a, 0x46, 0x67, + 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x62, 0x75, 0x66, 0x62, 0x75, 0x69, + 0x6c, 0x64, 0x2f, 0x62, 0x75, 0x66, 0x2d, 0x65, 0x78, 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x73, 0x2f, + 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x76, 0x61, 0x6c, 0x69, 0x64, 0x61, 0x74, 0x65, 0x2f, 0x71, 0x75, + 0x69, 0x63, 0x6b, 0x73, 0x74, 0x61, 0x72, 0x74, 0x2d, 0x67, 0x6f, 0x2f, 0x73, 0x74, 0x61, 0x72, + 0x74, 0x2f, 0x67, 0x65, 0x6e, 0xa2, 0x02, 0x03, 0x4f, 0x58, 0x58, 0xaa, 0x02, 0x07, 0x4f, 0x61, + 0x75, 0x74, 0x68, 0x32, 0x31, 0xca, 0x02, 0x07, 0x4f, 0x61, 0x75, 0x74, 0x68, 0x32, 0x31, 0xe2, + 0x02, 0x13, 0x4f, 0x61, 0x75, 0x74, 0x68, 0x32, 0x31, 0x5c, 0x47, 0x50, 0x42, 0x4d, 0x65, 0x74, + 0x61, 0x64, 0x61, 0x74, 0x61, 0xea, 0x02, 0x07, 0x4f, 0x61, 0x75, 0x74, 0x68, 0x32, 0x31, 0x62, + 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, +}) + +var ( + file_authorization_request_proto_rawDescOnce sync.Once + file_authorization_request_proto_rawDescData []byte +) + +func file_authorization_request_proto_rawDescGZIP() []byte { + file_authorization_request_proto_rawDescOnce.Do(func() { + file_authorization_request_proto_rawDescData = protoimpl.X.CompressGZIP(unsafe.Slice(unsafe.StringData(file_authorization_request_proto_rawDesc), len(file_authorization_request_proto_rawDesc))) + }) + return file_authorization_request_proto_rawDescData +} + +var file_authorization_request_proto_msgTypes = make([]protoimpl.MessageInfo, 1) +var file_authorization_request_proto_goTypes = []any{ + (*AuthorizationRequest)(nil), // 0: oauth21.AuthorizationRequest +} +var file_authorization_request_proto_depIdxs = []int32{ + 0, // [0:0] is the sub-list for method output_type + 0, // [0:0] is the sub-list for method input_type + 0, // [0:0] is the sub-list for extension type_name + 0, // [0:0] is the sub-list for extension extendee + 0, // [0:0] is the sub-list for field type_name +} + +func init() { file_authorization_request_proto_init() } +func file_authorization_request_proto_init() { + if File_authorization_request_proto != nil { + return + } + file_authorization_request_proto_msgTypes[0].OneofWrappers = []any{} + type x struct{} + out := protoimpl.TypeBuilder{ + File: protoimpl.DescBuilder{ + GoPackagePath: reflect.TypeOf(x{}).PkgPath(), + RawDescriptor: unsafe.Slice(unsafe.StringData(file_authorization_request_proto_rawDesc), len(file_authorization_request_proto_rawDesc)), + NumEnums: 0, + NumMessages: 1, + NumExtensions: 0, + NumServices: 0, + }, + GoTypes: file_authorization_request_proto_goTypes, + DependencyIndexes: file_authorization_request_proto_depIdxs, + MessageInfos: file_authorization_request_proto_msgTypes, + }.Build() + File_authorization_request_proto = out.File + file_authorization_request_proto_goTypes = nil + file_authorization_request_proto_depIdxs = nil +} diff --git a/internal/oauth21/generate.go b/internal/oauth21/generate.go new file mode 100644 index 000000000..ebd89dea9 --- /dev/null +++ b/internal/oauth21/generate.go @@ -0,0 +1,4 @@ +package oauth21 + +//go:generate go run github.com/bufbuild/buf/cmd/buf@v1.53.0 dep update +//go:generate go run github.com/bufbuild/buf/cmd/buf@v1.53.0 generate diff --git a/internal/oauth21/parse.go b/internal/oauth21/parse.go new file mode 100644 index 000000000..892241151 --- /dev/null +++ b/internal/oauth21/parse.go @@ -0,0 +1,10 @@ +package oauth21 + +import "net/http" + +func optionalFormParam(r *http.Request, key string) *string { + if v := r.FormValue(key); v != "" { + return &v + } + return nil +} diff --git a/internal/oauth21/proto/authorization_request.proto b/internal/oauth21/proto/authorization_request.proto new file mode 100644 index 000000000..532167ce2 --- /dev/null +++ b/internal/oauth21/proto/authorization_request.proto @@ -0,0 +1,51 @@ +syntax = "proto3"; + +package oauth21; + +import "buf/validate/validate.proto"; + +option go_package = "github.com/pomerium/pomerium/internal/oauth21/gen"; + +// modeled based on +// https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-12#section-4.1.1 +message AuthorizationRequest { + // The client identifier as described in Section 2.2. + string client_id = 1 [ (buf.validate.field).required = true ]; + + // OPTIONAL if only one redirect URI is registered for this client. REQUIRED + // if multiple redirict URIs are registered for this client. + optional string redirect_uri = 2; + + // REQUIRED. The authorization endpoint supports different sets of request and + // response parameters. The client determines the type of flow by using a + // certain response_type value. This specification defines the value code, + // which must be used to signal that the client wants to use the authorization + // code flow. + string response_type = 3 [ + (buf.validate.field).required = true, + (buf.validate.field).string = {in : [ "code" ]} + ]; + + // OPTIONAL. An opaque value used by the client to maintain state between the + // request and callback. The authorization server includes this value when + // redirecting the user agent back to the client. + optional string state = 4; + + // OPTIONAL. The scope of the access request as described by Section 1.4.1. + repeated string scopes = 5; + + // REQUIRED, assumes https://www.rfc-editor.org/rfc/rfc7636.html#section-4.1 + string code_challenge = 6 [ + (buf.validate.field).required = true, + (buf.validate.field).string = {min_len : 43, max_len : 128} + ]; + + // OPTIONAL, defaults to plain if not present in the request. Code verifier + // transformation method is S256 or plain. + optional string code_challenge_method = 7 + [ (buf.validate.field).string = {in : [ "S256", "plain" ]} ]; + + // session this authorization request is associated with. + // This is a Pomerium implementation specific field. + string session_id = 8 [ (buf.validate.field).required = true ]; +} diff --git a/internal/oauth21/validate_client.go b/internal/oauth21/validate_client.go new file mode 100644 index 000000000..d34ca81b3 --- /dev/null +++ b/internal/oauth21/validate_client.go @@ -0,0 +1,40 @@ +package oauth21 + +import ( + "slices" + + "github.com/pomerium/pomerium/internal/oauth21/gen" + rfc7591v1 "github.com/pomerium/pomerium/internal/rfc7591" +) + +func ValidateAuthorizationRequest( + client *rfc7591v1.ClientMetadata, + req *gen.AuthorizationRequest, +) error { + if err := ValidateAuthorizationRequestRedirectURI(client, req.RedirectUri); err != nil { + return err + } + return nil +} + +func ValidateAuthorizationRequestRedirectURI( + client *rfc7591v1.ClientMetadata, + redirectURI *string, +) error { + if len(client.RedirectUris) == 0 { + return Error{Code: InvalidClient, Description: "client has no redirect URIs"} + } + + if redirectURI == nil { + if len(client.RedirectUris) != 1 { + return Error{Code: InvalidRequest, Description: "client has multiple redirect URIs and none were provided"} + } + return nil + } + + if !slices.Contains(client.RedirectUris, *redirectURI) { + return Error{Code: InvalidGrant, Description: "client redirect URI does not match registered redirect URIs"} + } + + return nil +} diff --git a/internal/oauth21/validate_client_test.go b/internal/oauth21/validate_client_test.go new file mode 100644 index 000000000..8b70e359a --- /dev/null +++ b/internal/oauth21/validate_client_test.go @@ -0,0 +1,73 @@ +package oauth21_test + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "google.golang.org/protobuf/proto" + + "github.com/pomerium/pomerium/internal/oauth21" + "github.com/pomerium/pomerium/internal/oauth21/gen" + rfc7591v1 "github.com/pomerium/pomerium/internal/rfc7591" +) + +func TestValidateRequest(t *testing.T) { + t.Parallel() + for _, tc := range []struct { + name string + client *rfc7591v1.ClientMetadata + req *gen.AuthorizationRequest + err bool + }{ + { + "optional redirect_uri, multiple redirect_uris", + &rfc7591v1.ClientMetadata{ + RedirectUris: []string{"https://example.com/callback", "https://example.com/other-callback"}, + }, + &gen.AuthorizationRequest{ + RedirectUri: nil, + }, + true, + }, + { + "optional redirect_uri, single redirect_uri", + &rfc7591v1.ClientMetadata{ + RedirectUris: []string{"https://example.com/callback"}, + }, + &gen.AuthorizationRequest{ + RedirectUri: nil, + }, + false, + }, + { + "matching redirect_uri", + &rfc7591v1.ClientMetadata{ + RedirectUris: []string{"https://example.com/callback", "https://example.com/other-callback"}, + }, + &gen.AuthorizationRequest{ + RedirectUri: proto.String("https://example.com/callback"), + }, + false, + }, + { + "non-matching redirect_uri", + &rfc7591v1.ClientMetadata{ + RedirectUris: []string{"https://example.com/callback", "https://example.com/other-callback"}, + }, + &gen.AuthorizationRequest{ + RedirectUri: proto.String("https://example.com/invalid-callback"), + }, + true, + }, + } { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + err := oauth21.ValidateAuthorizationRequest(tc.client, tc.req) + if tc.err { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + }) + } +}