diff --git a/internal/mcp/handler_token.go b/internal/mcp/handler_token.go index 673028766..a601a746a 100644 --- a/internal/mcp/handler_token.go +++ b/internal/mcp/handler_token.go @@ -2,9 +2,64 @@ package mcp import ( "net/http" + "time" + + "github.com/pomerium/pomerium/internal/log" + "github.com/pomerium/pomerium/internal/oauth21" + oauth21proto "github.com/pomerium/pomerium/internal/oauth21/gen" ) // Token handles the /token endpoint. -func (srv *Handler) Token(w http.ResponseWriter, _ *http.Request) { - w.WriteHeader(http.StatusNotImplemented) +func (srv *Handler) Token(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + http.Error(w, "Method Not Allowed", http.StatusMethodNotAllowed) + return + } + + req, err := oauth21.ParseTokenRequest(r) + if err != nil { + log.Ctx(r.Context()).Error().Err(err).Msg("failed to parse token request") + oauth21.ErrorResponse(w, http.StatusBadRequest, oauth21.InvalidRequest) + return + } + + switch req.GrantType { + case "authorization_code": + srv.handleAuthorizationCodeToken(w, r, req) + default: + oauth21.ErrorResponse(w, http.StatusBadRequest, oauth21.UnsupportedGrantType) + return + } +} + +func (srv *Handler) handleAuthorizationCodeToken(w http.ResponseWriter, r *http.Request, req *oauth21proto.TokenRequest) { + ctx := r.Context() + + if req.ClientId == nil { + oauth21.ErrorResponse(w, http.StatusBadRequest, oauth21.InvalidClient) + return + } + if req.Code == nil { + oauth21.ErrorResponse(w, http.StatusBadRequest, oauth21.InvalidGrant) + return + } + code, err := DecryptCode(CodeTypeAuthorization, *req.Code, srv.cipher, *req.ClientId, time.Now()) + if err != nil { + oauth21.ErrorResponse(w, http.StatusBadRequest, oauth21.InvalidGrant) + return + } + + authReq, err := srv.storage.GetAuthorizationRequest(ctx, code.Id) + if err != nil { + oauth21.ErrorResponse(w, http.StatusBadRequest, oauth21.InvalidGrant) + return + } + + err = AuthorizeTokenRequest(req, authReq) + if err != nil { + oauth21.ErrorResponse(w, http.StatusBadRequest, oauth21.InvalidGrant) + return + } + + http.Error(w, "Not Implemented", http.StatusNotImplemented) } diff --git a/internal/mcp/token.go b/internal/mcp/token.go new file mode 100644 index 000000000..238cecc38 --- /dev/null +++ b/internal/mcp/token.go @@ -0,0 +1,37 @@ +package mcp + +import ( + "fmt" + + "github.com/pomerium/pomerium/internal/oauth21" + "github.com/pomerium/pomerium/internal/oauth21/gen" +) + +func AuthorizeTokenRequest( + tokReq *gen.TokenRequest, + authReq *gen.AuthorizationRequest, +) error { + if tokReq.GrantType != "authorization_code" { + return fmt.Errorf("unexpected grant type: %s", tokReq.GrantType) + } + + if tokReq.ClientId == nil { + return fmt.Errorf("token request: missing client_id") + } else if *tokReq.ClientId != authReq.ClientId { + return fmt.Errorf("token request: client_id does not match authorization request") + } + + if authReq.CodeChallengeMethod == nil || *authReq.CodeChallengeMethod == "plain" { + if !oauth21.VerifyPKCEPlain(*tokReq.CodeVerifier, authReq.CodeChallenge) { + return fmt.Errorf("plain: code verifier does not match code challenge") + } + } else if *authReq.CodeChallengeMethod == "S256" { + if !oauth21.VerifyPKCES256(*tokReq.CodeVerifier, authReq.CodeChallenge) { + return fmt.Errorf("S256: code verifier does not match code challenge") + } + } else { + return fmt.Errorf("unsupported code challenge method: %s", *authReq.CodeChallengeMethod) + } + + return nil +} diff --git a/internal/oauth21/buf.gen.yaml b/internal/oauth21/buf.gen.yaml index 068b0bde9..cf68b29bc 100644 --- a/internal/oauth21/buf.gen.yaml +++ b/internal/oauth21/buf.gen.yaml @@ -11,7 +11,7 @@ managed: enabled: true override: - file_option: go_package_prefix - value: github.com/bufbuild/buf-examples/protovalidate/quickstart-go/start/gen + value: github.com/pomerium/pomerium/internal/oauth21/gen # Don't modify any file option or field option for protovalidate. Without # this, generated Go will fail to compile. disable: diff --git a/internal/oauth21/gen/authorization_request.pb.go b/internal/oauth21/gen/authorization_request.pb.go index 5f465c9e7..8a426080d 100644 --- a/internal/oauth21/gen/authorization_request.pb.go +++ b/internal/oauth21/gen/authorization_request.pb.go @@ -175,18 +175,17 @@ var file_authorization_request_proto_rawDesc = string([]byte{ 0x72, 0x65, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x5f, 0x75, 0x72, 0x69, 0x42, 0x08, 0x0a, 0x06, 0x5f, 0x73, 0x74, 0x61, 0x74, 0x65, 0x42, 0x18, 0x0a, 0x16, 0x5f, 0x63, 0x6f, 0x64, 0x65, 0x5f, 0x63, 0x68, 0x61, 0x6c, 0x6c, 0x65, 0x6e, 0x67, 0x65, 0x5f, 0x6d, 0x65, 0x74, 0x68, 0x6f, 0x64, - 0x42, 0xac, 0x01, 0x0a, 0x0b, 0x63, 0x6f, 0x6d, 0x2e, 0x6f, 0x61, 0x75, 0x74, 0x68, 0x32, 0x31, + 0x42, 0x97, 0x01, 0x0a, 0x0b, 0x63, 0x6f, 0x6d, 0x2e, 0x6f, 0x61, 0x75, 0x74, 0x68, 0x32, 0x31, 0x42, 0x19, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x52, - 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x50, 0x01, 0x5a, 0x46, 0x67, - 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x62, 0x75, 0x66, 0x62, 0x75, 0x69, - 0x6c, 0x64, 0x2f, 0x62, 0x75, 0x66, 0x2d, 0x65, 0x78, 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x73, 0x2f, - 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x76, 0x61, 0x6c, 0x69, 0x64, 0x61, 0x74, 0x65, 0x2f, 0x71, 0x75, - 0x69, 0x63, 0x6b, 0x73, 0x74, 0x61, 0x72, 0x74, 0x2d, 0x67, 0x6f, 0x2f, 0x73, 0x74, 0x61, 0x72, - 0x74, 0x2f, 0x67, 0x65, 0x6e, 0xa2, 0x02, 0x03, 0x4f, 0x58, 0x58, 0xaa, 0x02, 0x07, 0x4f, 0x61, - 0x75, 0x74, 0x68, 0x32, 0x31, 0xca, 0x02, 0x07, 0x4f, 0x61, 0x75, 0x74, 0x68, 0x32, 0x31, 0xe2, - 0x02, 0x13, 0x4f, 0x61, 0x75, 0x74, 0x68, 0x32, 0x31, 0x5c, 0x47, 0x50, 0x42, 0x4d, 0x65, 0x74, - 0x61, 0x64, 0x61, 0x74, 0x61, 0xea, 0x02, 0x07, 0x4f, 0x61, 0x75, 0x74, 0x68, 0x32, 0x31, 0x62, - 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, + 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x50, 0x01, 0x5a, 0x31, 0x67, + 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x70, 0x6f, 0x6d, 0x65, 0x72, 0x69, + 0x75, 0x6d, 0x2f, 0x70, 0x6f, 0x6d, 0x65, 0x72, 0x69, 0x75, 0x6d, 0x2f, 0x69, 0x6e, 0x74, 0x65, + 0x72, 0x6e, 0x61, 0x6c, 0x2f, 0x6f, 0x61, 0x75, 0x74, 0x68, 0x32, 0x31, 0x2f, 0x67, 0x65, 0x6e, + 0xa2, 0x02, 0x03, 0x4f, 0x58, 0x58, 0xaa, 0x02, 0x07, 0x4f, 0x61, 0x75, 0x74, 0x68, 0x32, 0x31, + 0xca, 0x02, 0x07, 0x4f, 0x61, 0x75, 0x74, 0x68, 0x32, 0x31, 0xe2, 0x02, 0x13, 0x4f, 0x61, 0x75, + 0x74, 0x68, 0x32, 0x31, 0x5c, 0x47, 0x50, 0x42, 0x4d, 0x65, 0x74, 0x61, 0x64, 0x61, 0x74, 0x61, + 0xea, 0x02, 0x07, 0x4f, 0x61, 0x75, 0x74, 0x68, 0x32, 0x31, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, + 0x6f, 0x33, }) var ( diff --git a/internal/oauth21/gen/code.pb.go b/internal/oauth21/gen/code.pb.go index e9643099d..94db174d5 100644 --- a/internal/oauth21/gen/code.pb.go +++ b/internal/oauth21/gen/code.pb.go @@ -161,18 +161,16 @@ var file_code_proto_rawDesc = string([]byte{ 0x55, 0x54, 0x48, 0x4f, 0x52, 0x49, 0x5a, 0x41, 0x54, 0x49, 0x4f, 0x4e, 0x10, 0x01, 0x12, 0x14, 0x0a, 0x10, 0x43, 0x4f, 0x44, 0x45, 0x5f, 0x54, 0x59, 0x50, 0x45, 0x5f, 0x41, 0x43, 0x43, 0x45, 0x53, 0x53, 0x10, 0x02, 0x12, 0x15, 0x0a, 0x11, 0x43, 0x4f, 0x44, 0x45, 0x5f, 0x54, 0x59, 0x50, - 0x45, 0x5f, 0x52, 0x45, 0x46, 0x52, 0x45, 0x53, 0x48, 0x10, 0x03, 0x42, 0x9c, 0x01, 0x0a, 0x0b, + 0x45, 0x5f, 0x52, 0x45, 0x46, 0x52, 0x45, 0x53, 0x48, 0x10, 0x03, 0x42, 0x87, 0x01, 0x0a, 0x0b, 0x63, 0x6f, 0x6d, 0x2e, 0x6f, 0x61, 0x75, 0x74, 0x68, 0x32, 0x31, 0x42, 0x09, 0x43, 0x6f, 0x64, - 0x65, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x50, 0x01, 0x5a, 0x46, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, - 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x62, 0x75, 0x66, 0x62, 0x75, 0x69, 0x6c, 0x64, 0x2f, 0x62, 0x75, - 0x66, 0x2d, 0x65, 0x78, 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x73, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, - 0x76, 0x61, 0x6c, 0x69, 0x64, 0x61, 0x74, 0x65, 0x2f, 0x71, 0x75, 0x69, 0x63, 0x6b, 0x73, 0x74, - 0x61, 0x72, 0x74, 0x2d, 0x67, 0x6f, 0x2f, 0x73, 0x74, 0x61, 0x72, 0x74, 0x2f, 0x67, 0x65, 0x6e, - 0xa2, 0x02, 0x03, 0x4f, 0x58, 0x58, 0xaa, 0x02, 0x07, 0x4f, 0x61, 0x75, 0x74, 0x68, 0x32, 0x31, - 0xca, 0x02, 0x07, 0x4f, 0x61, 0x75, 0x74, 0x68, 0x32, 0x31, 0xe2, 0x02, 0x13, 0x4f, 0x61, 0x75, - 0x74, 0x68, 0x32, 0x31, 0x5c, 0x47, 0x50, 0x42, 0x4d, 0x65, 0x74, 0x61, 0x64, 0x61, 0x74, 0x61, - 0xea, 0x02, 0x07, 0x4f, 0x61, 0x75, 0x74, 0x68, 0x32, 0x31, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, - 0x6f, 0x33, + 0x65, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x50, 0x01, 0x5a, 0x31, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, + 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x70, 0x6f, 0x6d, 0x65, 0x72, 0x69, 0x75, 0x6d, 0x2f, 0x70, 0x6f, + 0x6d, 0x65, 0x72, 0x69, 0x75, 0x6d, 0x2f, 0x69, 0x6e, 0x74, 0x65, 0x72, 0x6e, 0x61, 0x6c, 0x2f, + 0x6f, 0x61, 0x75, 0x74, 0x68, 0x32, 0x31, 0x2f, 0x67, 0x65, 0x6e, 0xa2, 0x02, 0x03, 0x4f, 0x58, + 0x58, 0xaa, 0x02, 0x07, 0x4f, 0x61, 0x75, 0x74, 0x68, 0x32, 0x31, 0xca, 0x02, 0x07, 0x4f, 0x61, + 0x75, 0x74, 0x68, 0x32, 0x31, 0xe2, 0x02, 0x13, 0x4f, 0x61, 0x75, 0x74, 0x68, 0x32, 0x31, 0x5c, + 0x47, 0x50, 0x42, 0x4d, 0x65, 0x74, 0x61, 0x64, 0x61, 0x74, 0x61, 0xea, 0x02, 0x07, 0x4f, 0x61, + 0x75, 0x74, 0x68, 0x32, 0x31, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, }) var ( diff --git a/internal/oauth21/gen/token.pb.go b/internal/oauth21/gen/token.pb.go new file mode 100644 index 000000000..f3d04c0ef --- /dev/null +++ b/internal/oauth21/gen/token.pb.go @@ -0,0 +1,254 @@ +// Code generated by protoc-gen-go. DO NOT EDIT. +// versions: +// protoc-gen-go v1.36.5 +// protoc (unknown) +// source: token.proto + +package gen + +import ( + _ "buf.build/gen/go/bufbuild/protovalidate/protocolbuffers/go/buf/validate" + protoreflect "google.golang.org/protobuf/reflect/protoreflect" + protoimpl "google.golang.org/protobuf/runtime/protoimpl" + reflect "reflect" + sync "sync" + unsafe "unsafe" +) + +const ( + // Verify that this generated code is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) + // Verify that runtime/protoimpl is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) +) + +// Represents the request sent to the Token Endpoint (Section 3.2.2). +// Different parameters are required based on the grant_type. +type TokenRequest struct { + state protoimpl.MessageState `protogen:"open.v1"` + // REQUIRED. Identifies the grant type being used. + // See Sections 3.2.2, 4.1.3, 4.2.1, 4.3.1, 4.4. + GrantType string `protobuf:"bytes,1,opt,name=grant_type,json=grantType,proto3" json:"grant_type,omitempty"` + // REQUIRED for grant_type="authorization_code". + // The authorization code received from the authorization server. + Code *string `protobuf:"bytes,2,opt,name=code,proto3,oneof" json:"code,omitempty"` + // REQUIRED for grant_type="authorization_code" if the original authorization request + // included a "code_challenge". MUST NOT be sent otherwise. (Section 4.1.3) + // The original PKCE code verifier string. + CodeVerifier *string `protobuf:"bytes,3,opt,name=code_verifier,json=codeVerifier,proto3,oneof" json:"code_verifier,omitempty"` + // REQUIRED for grant_type="authorization_code" if the client is public + // and not authenticating with the authorization server via other means. (Section 4.1.3) + // Also used for body-parameter client authentication (Section 2.4.1) or + // when grant_type requires public client identification (Section 3.2.2). + ClientId *string `protobuf:"bytes,4,opt,name=client_id,json=clientId,proto3,oneof" json:"client_id,omitempty"` + // REQUIRED for grant_type="refresh_token". + // The refresh token issued to the client. + RefreshToken *string `protobuf:"bytes,5,opt,name=refresh_token,json=refreshToken,proto3,oneof" json:"refresh_token,omitempty"` + // OPTIONAL for grant_type="client_credentials" (Section 4.2.1) or + // grant_type="refresh_token" (Section 4.3.1). + // The requested scope of the access request. Space-delimited list. + Scope *string `protobuf:"bytes,6,opt,name=scope,proto3,oneof" json:"scope,omitempty"` + // REQUIRED when using body parameters for client authentication. + // The client secret. + ClientSecret *string `protobuf:"bytes,7,opt,name=client_secret,json=clientSecret,proto3,oneof" json:"client_secret,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *TokenRequest) Reset() { + *x = TokenRequest{} + mi := &file_token_proto_msgTypes[0] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *TokenRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*TokenRequest) ProtoMessage() {} + +func (x *TokenRequest) ProtoReflect() protoreflect.Message { + mi := &file_token_proto_msgTypes[0] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use TokenRequest.ProtoReflect.Descriptor instead. +func (*TokenRequest) Descriptor() ([]byte, []int) { + return file_token_proto_rawDescGZIP(), []int{0} +} + +func (x *TokenRequest) GetGrantType() string { + if x != nil { + return x.GrantType + } + return "" +} + +func (x *TokenRequest) GetCode() string { + if x != nil && x.Code != nil { + return *x.Code + } + return "" +} + +func (x *TokenRequest) GetCodeVerifier() string { + if x != nil && x.CodeVerifier != nil { + return *x.CodeVerifier + } + return "" +} + +func (x *TokenRequest) GetClientId() string { + if x != nil && x.ClientId != nil { + return *x.ClientId + } + return "" +} + +func (x *TokenRequest) GetRefreshToken() string { + if x != nil && x.RefreshToken != nil { + return *x.RefreshToken + } + return "" +} + +func (x *TokenRequest) GetScope() string { + if x != nil && x.Scope != nil { + return *x.Scope + } + return "" +} + +func (x *TokenRequest) GetClientSecret() string { + if x != nil && x.ClientSecret != nil { + return *x.ClientSecret + } + return "" +} + +var File_token_proto protoreflect.FileDescriptor + +var file_token_proto_rawDesc = string([]byte{ + 0x0a, 0x0b, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x07, 0x6f, + 0x61, 0x75, 0x74, 0x68, 0x32, 0x31, 0x1a, 0x1b, 0x62, 0x75, 0x66, 0x2f, 0x76, 0x61, 0x6c, 0x69, + 0x64, 0x61, 0x74, 0x65, 0x2f, 0x76, 0x61, 0x6c, 0x69, 0x64, 0x61, 0x74, 0x65, 0x2e, 0x70, 0x72, + 0x6f, 0x74, 0x6f, 0x22, 0xc1, 0x06, 0x0a, 0x0c, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x52, 0x65, 0x71, + 0x75, 0x65, 0x73, 0x74, 0x12, 0x5b, 0x0a, 0x0a, 0x67, 0x72, 0x61, 0x6e, 0x74, 0x5f, 0x74, 0x79, + 0x70, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x42, 0x3c, 0xba, 0x48, 0x39, 0x72, 0x37, 0x52, + 0x12, 0x61, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x5f, 0x63, + 0x6f, 0x64, 0x65, 0x52, 0x0d, 0x72, 0x65, 0x66, 0x72, 0x65, 0x73, 0x68, 0x5f, 0x74, 0x6f, 0x6b, + 0x65, 0x6e, 0x52, 0x12, 0x63, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x5f, 0x63, 0x72, 0x65, 0x64, 0x65, + 0x6e, 0x74, 0x69, 0x61, 0x6c, 0x73, 0x52, 0x09, 0x67, 0x72, 0x61, 0x6e, 0x74, 0x54, 0x79, 0x70, + 0x65, 0x12, 0x20, 0x0a, 0x04, 0x63, 0x6f, 0x64, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x42, + 0x07, 0xba, 0x48, 0x04, 0x72, 0x02, 0x10, 0x01, 0x48, 0x00, 0x52, 0x04, 0x63, 0x6f, 0x64, 0x65, + 0x88, 0x01, 0x01, 0x12, 0x34, 0x0a, 0x0d, 0x63, 0x6f, 0x64, 0x65, 0x5f, 0x76, 0x65, 0x72, 0x69, + 0x66, 0x69, 0x65, 0x72, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x42, 0x0a, 0xba, 0x48, 0x07, 0x72, + 0x05, 0x10, 0x2b, 0x18, 0x80, 0x01, 0x48, 0x01, 0x52, 0x0c, 0x63, 0x6f, 0x64, 0x65, 0x56, 0x65, + 0x72, 0x69, 0x66, 0x69, 0x65, 0x72, 0x88, 0x01, 0x01, 0x12, 0x29, 0x0a, 0x09, 0x63, 0x6c, 0x69, + 0x65, 0x6e, 0x74, 0x5f, 0x69, 0x64, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x42, 0x07, 0xba, 0x48, + 0x04, 0x72, 0x02, 0x10, 0x01, 0x48, 0x02, 0x52, 0x08, 0x63, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x49, + 0x64, 0x88, 0x01, 0x01, 0x12, 0x31, 0x0a, 0x0d, 0x72, 0x65, 0x66, 0x72, 0x65, 0x73, 0x68, 0x5f, + 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x42, 0x07, 0xba, 0x48, 0x04, + 0x72, 0x02, 0x10, 0x01, 0x48, 0x03, 0x52, 0x0c, 0x72, 0x65, 0x66, 0x72, 0x65, 0x73, 0x68, 0x54, + 0x6f, 0x6b, 0x65, 0x6e, 0x88, 0x01, 0x01, 0x12, 0x22, 0x0a, 0x05, 0x73, 0x63, 0x6f, 0x70, 0x65, + 0x18, 0x06, 0x20, 0x01, 0x28, 0x09, 0x42, 0x07, 0xba, 0x48, 0x04, 0x72, 0x02, 0x10, 0x01, 0x48, + 0x04, 0x52, 0x05, 0x73, 0x63, 0x6f, 0x70, 0x65, 0x88, 0x01, 0x01, 0x12, 0x31, 0x0a, 0x0d, 0x63, + 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x5f, 0x73, 0x65, 0x63, 0x72, 0x65, 0x74, 0x18, 0x07, 0x20, 0x01, + 0x28, 0x09, 0x42, 0x07, 0xba, 0x48, 0x04, 0x72, 0x02, 0x10, 0x01, 0x48, 0x05, 0x52, 0x0c, 0x63, + 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x53, 0x65, 0x63, 0x72, 0x65, 0x74, 0x88, 0x01, 0x01, 0x3a, 0xef, + 0x02, 0xba, 0x48, 0xeb, 0x02, 0x1a, 0xa8, 0x01, 0x0a, 0x2f, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x5f, + 0x72, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x2e, 0x63, 0x6f, 0x64, 0x65, 0x5f, 0x72, 0x65, 0x71, + 0x75, 0x69, 0x72, 0x65, 0x64, 0x5f, 0x66, 0x6f, 0x72, 0x5f, 0x61, 0x75, 0x74, 0x68, 0x5f, 0x63, + 0x6f, 0x64, 0x65, 0x5f, 0x67, 0x72, 0x61, 0x6e, 0x74, 0x12, 0x38, 0x63, 0x6f, 0x64, 0x65, 0x20, + 0x69, 0x73, 0x20, 0x72, 0x65, 0x71, 0x75, 0x69, 0x72, 0x65, 0x64, 0x20, 0x77, 0x68, 0x65, 0x6e, + 0x20, 0x67, 0x72, 0x61, 0x6e, 0x74, 0x5f, 0x74, 0x79, 0x70, 0x65, 0x20, 0x69, 0x73, 0x20, 0x27, + 0x61, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x5f, 0x63, 0x6f, + 0x64, 0x65, 0x27, 0x1a, 0x3b, 0x28, 0x74, 0x68, 0x69, 0x73, 0x2e, 0x67, 0x72, 0x61, 0x6e, 0x74, + 0x5f, 0x74, 0x79, 0x70, 0x65, 0x20, 0x21, 0x3d, 0x20, 0x27, 0x61, 0x75, 0x74, 0x68, 0x6f, 0x72, + 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x5f, 0x63, 0x6f, 0x64, 0x65, 0x27, 0x29, 0x20, 0x7c, + 0x7c, 0x20, 0x68, 0x61, 0x73, 0x28, 0x74, 0x68, 0x69, 0x73, 0x2e, 0x63, 0x6f, 0x64, 0x65, 0x29, + 0x1a, 0xbd, 0x01, 0x0a, 0x3c, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x5f, 0x72, 0x65, 0x71, 0x75, 0x65, + 0x73, 0x74, 0x2e, 0x72, 0x65, 0x66, 0x72, 0x65, 0x73, 0x68, 0x5f, 0x74, 0x6f, 0x6b, 0x65, 0x6e, + 0x5f, 0x72, 0x65, 0x71, 0x75, 0x69, 0x72, 0x65, 0x64, 0x5f, 0x66, 0x6f, 0x72, 0x5f, 0x72, 0x65, + 0x66, 0x72, 0x65, 0x73, 0x68, 0x5f, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x5f, 0x67, 0x72, 0x61, 0x6e, + 0x74, 0x12, 0x3c, 0x72, 0x65, 0x66, 0x72, 0x65, 0x73, 0x68, 0x5f, 0x74, 0x6f, 0x6b, 0x65, 0x6e, + 0x20, 0x69, 0x73, 0x20, 0x72, 0x65, 0x71, 0x75, 0x69, 0x72, 0x65, 0x64, 0x20, 0x77, 0x68, 0x65, + 0x6e, 0x20, 0x67, 0x72, 0x61, 0x6e, 0x74, 0x5f, 0x74, 0x79, 0x70, 0x65, 0x20, 0x69, 0x73, 0x20, + 0x27, 0x72, 0x65, 0x66, 0x72, 0x65, 0x73, 0x68, 0x5f, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x27, 0x1a, + 0x3f, 0x28, 0x74, 0x68, 0x69, 0x73, 0x2e, 0x67, 0x72, 0x61, 0x6e, 0x74, 0x5f, 0x74, 0x79, 0x70, + 0x65, 0x20, 0x21, 0x3d, 0x20, 0x27, 0x72, 0x65, 0x66, 0x72, 0x65, 0x73, 0x68, 0x5f, 0x74, 0x6f, + 0x6b, 0x65, 0x6e, 0x27, 0x29, 0x20, 0x7c, 0x7c, 0x20, 0x68, 0x61, 0x73, 0x28, 0x74, 0x68, 0x69, + 0x73, 0x2e, 0x72, 0x65, 0x66, 0x72, 0x65, 0x73, 0x68, 0x5f, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x29, + 0x42, 0x07, 0x0a, 0x05, 0x5f, 0x63, 0x6f, 0x64, 0x65, 0x42, 0x10, 0x0a, 0x0e, 0x5f, 0x63, 0x6f, + 0x64, 0x65, 0x5f, 0x76, 0x65, 0x72, 0x69, 0x66, 0x69, 0x65, 0x72, 0x42, 0x0c, 0x0a, 0x0a, 0x5f, + 0x63, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x5f, 0x69, 0x64, 0x42, 0x10, 0x0a, 0x0e, 0x5f, 0x72, 0x65, + 0x66, 0x72, 0x65, 0x73, 0x68, 0x5f, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x42, 0x08, 0x0a, 0x06, 0x5f, + 0x73, 0x63, 0x6f, 0x70, 0x65, 0x42, 0x10, 0x0a, 0x0e, 0x5f, 0x63, 0x6c, 0x69, 0x65, 0x6e, 0x74, + 0x5f, 0x73, 0x65, 0x63, 0x72, 0x65, 0x74, 0x42, 0x88, 0x01, 0x0a, 0x0b, 0x63, 0x6f, 0x6d, 0x2e, + 0x6f, 0x61, 0x75, 0x74, 0x68, 0x32, 0x31, 0x42, 0x0a, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x50, 0x72, + 0x6f, 0x74, 0x6f, 0x50, 0x01, 0x5a, 0x31, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, + 0x6d, 0x2f, 0x70, 0x6f, 0x6d, 0x65, 0x72, 0x69, 0x75, 0x6d, 0x2f, 0x70, 0x6f, 0x6d, 0x65, 0x72, + 0x69, 0x75, 0x6d, 0x2f, 0x69, 0x6e, 0x74, 0x65, 0x72, 0x6e, 0x61, 0x6c, 0x2f, 0x6f, 0x61, 0x75, + 0x74, 0x68, 0x32, 0x31, 0x2f, 0x67, 0x65, 0x6e, 0xa2, 0x02, 0x03, 0x4f, 0x58, 0x58, 0xaa, 0x02, + 0x07, 0x4f, 0x61, 0x75, 0x74, 0x68, 0x32, 0x31, 0xca, 0x02, 0x07, 0x4f, 0x61, 0x75, 0x74, 0x68, + 0x32, 0x31, 0xe2, 0x02, 0x13, 0x4f, 0x61, 0x75, 0x74, 0x68, 0x32, 0x31, 0x5c, 0x47, 0x50, 0x42, + 0x4d, 0x65, 0x74, 0x61, 0x64, 0x61, 0x74, 0x61, 0xea, 0x02, 0x07, 0x4f, 0x61, 0x75, 0x74, 0x68, + 0x32, 0x31, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, +}) + +var ( + file_token_proto_rawDescOnce sync.Once + file_token_proto_rawDescData []byte +) + +func file_token_proto_rawDescGZIP() []byte { + file_token_proto_rawDescOnce.Do(func() { + file_token_proto_rawDescData = protoimpl.X.CompressGZIP(unsafe.Slice(unsafe.StringData(file_token_proto_rawDesc), len(file_token_proto_rawDesc))) + }) + return file_token_proto_rawDescData +} + +var file_token_proto_msgTypes = make([]protoimpl.MessageInfo, 1) +var file_token_proto_goTypes = []any{ + (*TokenRequest)(nil), // 0: oauth21.TokenRequest +} +var file_token_proto_depIdxs = []int32{ + 0, // [0:0] is the sub-list for method output_type + 0, // [0:0] is the sub-list for method input_type + 0, // [0:0] is the sub-list for extension type_name + 0, // [0:0] is the sub-list for extension extendee + 0, // [0:0] is the sub-list for field type_name +} + +func init() { file_token_proto_init() } +func file_token_proto_init() { + if File_token_proto != nil { + return + } + file_token_proto_msgTypes[0].OneofWrappers = []any{} + type x struct{} + out := protoimpl.TypeBuilder{ + File: protoimpl.DescBuilder{ + GoPackagePath: reflect.TypeOf(x{}).PkgPath(), + RawDescriptor: unsafe.Slice(unsafe.StringData(file_token_proto_rawDesc), len(file_token_proto_rawDesc)), + NumEnums: 0, + NumMessages: 1, + NumExtensions: 0, + NumServices: 0, + }, + GoTypes: file_token_proto_goTypes, + DependencyIndexes: file_token_proto_depIdxs, + MessageInfos: file_token_proto_msgTypes, + }.Build() + File_token_proto = out.File + file_token_proto_goTypes = nil + file_token_proto_depIdxs = nil +} diff --git a/internal/oauth21/pkce.go b/internal/oauth21/pkce.go new file mode 100644 index 000000000..fc752fb80 --- /dev/null +++ b/internal/oauth21/pkce.go @@ -0,0 +1,29 @@ +package oauth21 + +import ( + "crypto/sha256" + "crypto/subtle" + "encoding/base64" +) + +// VerifyPKCES256 verifies a PKCE challenge using the S256 method. +// It performs a constant-time comparison to mitigate timing attacks. +// +// - codeVerifier: The verifier string sent by the client in the token request. +// - storedCodeChallenge: The challenge string stored by the server during the authorization request. +// Returns true if the verifier is valid, false otherwise. +func VerifyPKCES256(codeVerifier, storedCodeChallenge string) bool { + sha256Hash := sha256.Sum256([]byte(codeVerifier)) + calculatedChallenge := base64.RawURLEncoding.EncodeToString(sha256Hash[:]) + return subtle.ConstantTimeCompare([]byte(calculatedChallenge), []byte(storedCodeChallenge)) == 1 +} + +// VerifyPKCEPlain verifies a PKCE challenge using the plain method. +// It performs a constant-time comparison to mitigate timing attacks. +// +// - codeVerifier: The verifier string sent by the client in the token request. +// - storedCodeChallenge: The challenge string stored by the server during the authorization request. +// Returns true if the verifier is valid, false otherwise. +func VerifyPKCEPlain(codeVerifier, storedCodeChallenge string) bool { + return subtle.ConstantTimeCompare([]byte(codeVerifier), []byte(storedCodeChallenge)) == 1 +} diff --git a/internal/oauth21/pkce_test.go b/internal/oauth21/pkce_test.go new file mode 100644 index 000000000..d356333d3 --- /dev/null +++ b/internal/oauth21/pkce_test.go @@ -0,0 +1,81 @@ +package oauth21_test + +import ( + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/pomerium/pomerium/internal/oauth21" +) + +// TestVerifyPKCES256 tests the S256 PKCE verification method. +func TestVerifyPKCES256(t *testing.T) { + // Example values from RFC 7636 Appendix B + verifier := "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk" + challenge := "E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM" + + tests := []struct { + name string + verifier string + challenge string + want bool + }{ + { + name: "Correct Verifier", + verifier: verifier, + challenge: challenge, + want: true, + }, + { + name: "Incorrect Verifier", + verifier: "incorrect_verifier_string", + challenge: challenge, + want: false, + }, + { + name: "Incorrect Challenge", + verifier: verifier, + challenge: "incorrect_challenge_string", + want: false, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + got := oauth21.VerifyPKCES256(tc.verifier, tc.challenge) + assert.Equal(t, tc.want, got) + }) + } +} + +// TestVerifyPKCEPlain tests the Plain PKCE verification method. +func TestVerifyPKCEPlain(t *testing.T) { + verifierPlain := "this-is-a-plain-verifier-43-chars-long-askldfj" + + tests := []struct { + name string + verifier string + challenge string + want bool + }{ + { + name: "Correct Verifier", + verifier: verifierPlain, + challenge: verifierPlain, + want: true, + }, + { + name: "Incorrect Verifier", + verifier: "incorrect_verifier_string", + challenge: verifierPlain, + want: false, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + got := oauth21.VerifyPKCEPlain(tc.verifier, tc.challenge) + assert.Equal(t, tc.want, got) + }) + } +} diff --git a/internal/oauth21/proto/token.proto b/internal/oauth21/proto/token.proto new file mode 100644 index 000000000..430c7b29e --- /dev/null +++ b/internal/oauth21/proto/token.proto @@ -0,0 +1,83 @@ +syntax = "proto3"; + +package oauth21; + +import "buf/validate/validate.proto"; + +option go_package = "github.com/pomerium/pomerium/internal/oauth21/gen"; + +// Defines messages for OAuth 2.1 Token Endpoint requests and responses. +// Based on draft-ietf-oauth-v2-1-12. + +// Represents the request sent to the Token Endpoint (Section 3.2.2). +// Different parameters are required based on the grant_type. +message TokenRequest { + // REQUIRED. Identifies the grant type being used. + // See Sections 3.2.2, 4.1.3, 4.2.1, 4.3.1, 4.4. + string grant_type = 1 [ + (buf.validate.field).string = { + in: ["authorization_code", "refresh_token", "client_credentials"], + } + ]; + + // --- Authorization Code Grant Parameters (Section 4.1.3) --- + + // REQUIRED for grant_type="authorization_code". + // The authorization code received from the authorization server. + optional string code = 2 [ + (buf.validate.field).string = { + min_len: 1, + } + ]; + option (buf.validate.message).cel = { + id: "token_request.code_required_for_auth_code_grant", + message: "code is required when grant_type is 'authorization_code'", + expression: "(this.grant_type != 'authorization_code') || has(this.code)", + }; + + // REQUIRED for grant_type="authorization_code" if the original authorization request + // included a "code_challenge". MUST NOT be sent otherwise. (Section 4.1.3) + // The original PKCE code verifier string. + optional string code_verifier = 3 [(buf.validate.field).string = { + min_len: 43, + max_len: 128, + }]; + + // REQUIRED for grant_type="authorization_code" if the client is public + // and not authenticating with the authorization server via other means. (Section 4.1.3) + // Also used for body-parameter client authentication (Section 2.4.1) or + // when grant_type requires public client identification (Section 3.2.2). + optional string client_id = 4 [ + (buf.validate.field).string.min_len = 1 + ]; + + // --- Refresh Token Grant Parameters (Section 4.3.1) --- + + // REQUIRED for grant_type="refresh_token". + // The refresh token issued to the client. + optional string refresh_token = 5 [ + (buf.validate.field).string = { + min_len: 1, + } + ]; + option (buf.validate.message).cel = { + id: "token_request.refresh_token_required_for_refresh_token_grant", + message: "refresh_token is required when grant_type is 'refresh_token'", + expression: "(this.grant_type != 'refresh_token') || has(this.refresh_token)", + }; + + // --- Client Credentials Grant & Refresh Token Grant Parameters --- + + // OPTIONAL for grant_type="client_credentials" (Section 4.2.1) or + // grant_type="refresh_token" (Section 4.3.1). + // The requested scope of the access request. Space-delimited list. + optional string scope = 6 [(buf.validate.field).string.min_len = 1]; + + // --- Client Authentication via Body Parameters (Section 2.4.1) --- + // Used when including credentials directly in the request body instead of e.g. HTTP Basic Auth. + // client_id (field 4) is also used in this case. + + // REQUIRED when using body parameters for client authentication. + // The client secret. + optional string client_secret = 7 [(buf.validate.field).string.min_len = 1]; +} diff --git a/internal/oauth21/token.go b/internal/oauth21/token.go new file mode 100644 index 000000000..00319e9e2 --- /dev/null +++ b/internal/oauth21/token.go @@ -0,0 +1,34 @@ +package oauth21 + +import ( + "fmt" + "net/http" + + "github.com/bufbuild/protovalidate-go" + + "github.com/pomerium/pomerium/internal/oauth21/gen" +) + +func ParseTokenRequest(r *http.Request) (*gen.TokenRequest, error) { + err := r.ParseForm() + if err != nil { + return nil, fmt.Errorf("failed to parse form: %w", err) + } + + v := &gen.TokenRequest{ + GrantType: r.Form.Get("grant_type"), + Code: optionalFormParam(r, "code"), + CodeVerifier: optionalFormParam(r, "code_verifier"), + ClientId: optionalFormParam(r, "client_id"), + RefreshToken: optionalFormParam(r, "refresh_token"), + Scope: optionalFormParam(r, "scope"), + ClientSecret: optionalFormParam(r, "client_secret"), + } + + err = protovalidate.Validate(v) + if err != nil { + return nil, fmt.Errorf("failed to validate token request: %w", err) + } + + return v, nil +} diff --git a/internal/oauth21/validate_token_test.go b/internal/oauth21/validate_token_test.go new file mode 100644 index 000000000..e5182fe23 --- /dev/null +++ b/internal/oauth21/validate_token_test.go @@ -0,0 +1,132 @@ +package oauth21_test + +import ( + "testing" + + "github.com/bufbuild/protovalidate-go" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "google.golang.org/protobuf/proto" + + "github.com/pomerium/pomerium/internal/oauth21/gen" +) + +func TestTokenRequestValidation(t *testing.T) { + validator, err := protovalidate.New() + require.NoError(t, err) + + testCases := []struct { + name string + request *gen.TokenRequest + expectError bool + errorMsg string + }{ + { + name: "valid authorization_code grant", + request: &gen.TokenRequest{ + GrantType: "authorization_code", + Code: proto.String("some_code"), + CodeVerifier: proto.String("code_verifier_should_be_at_least_43_characters_long_123456"), + ClientId: proto.String("client_id"), + }, + expectError: false, + }, + { + name: "missing code for authorization_code grant", + request: &gen.TokenRequest{ + GrantType: "authorization_code", + ClientId: proto.String("client_id"), + }, + expectError: true, + errorMsg: "code is required when grant_type is 'authorization_code'", + }, + { + name: "code_verifier too short", + request: &gen.TokenRequest{ + GrantType: "authorization_code", + Code: proto.String("some_code"), + CodeVerifier: proto.String("too_short"), + ClientId: proto.String("client_id"), + }, + expectError: true, + errorMsg: "value length must be at least 43 characters", + }, + { + name: "valid refresh_token grant", + request: &gen.TokenRequest{ + GrantType: "refresh_token", + RefreshToken: proto.String("refresh_token"), + Scope: proto.String("scope1 scope2"), + }, + expectError: false, + }, + { + name: "missing refresh_token for refresh_token grant", + request: &gen.TokenRequest{ + GrantType: "refresh_token", + }, + expectError: true, + errorMsg: "refresh_token is required when grant_type is 'refresh_token'", + }, + { + name: "valid client_credentials grant", + request: &gen.TokenRequest{ + GrantType: "client_credentials", + ClientId: proto.String("client_id"), + Scope: proto.String("scope1 scope2"), + }, + expectError: false, + }, + { + name: "invalid grant_type", + request: &gen.TokenRequest{ + GrantType: "invalid_grant_type", + }, + expectError: true, + errorMsg: "value must be in list", + }, + { + name: "empty client_id", + request: &gen.TokenRequest{ + GrantType: "client_credentials", + ClientId: proto.String(""), + }, + expectError: true, + errorMsg: "value length must be at least 1", + }, + { + name: "empty scope", + request: &gen.TokenRequest{ + GrantType: "client_credentials", + ClientId: proto.String("client_id"), + Scope: proto.String(""), + }, + expectError: true, + errorMsg: "value length must be at least 1", + }, + { + name: "empty client_secret", + request: &gen.TokenRequest{ + GrantType: "client_credentials", + ClientId: proto.String("client_id"), + ClientSecret: proto.String(""), + }, + expectError: true, + errorMsg: "value length must be at least 1", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + err := validator.Validate(tc.request) + if tc.expectError { + require.Error(t, err) + if tc.errorMsg != "" { + assert.Contains(t, err.Error(), tc.errorMsg) + } + } else { + require.NoError(t, err) + } + }) + } +}