mirror of
https://github.com/pomerium/pomerium.git
synced 2025-07-25 20:49:30 +02:00
mcp: authorize: response with code
This commit is contained in:
parent
52af622cc4
commit
565f08db49
7 changed files with 346 additions and 4 deletions
50
internal/mcp/cipher.go
Normal file
50
internal/mcp/cipher.go
Normal file
|
@ -0,0 +1,50 @@
|
|||
package mcp
|
||||
|
||||
import (
|
||||
"crypto/cipher"
|
||||
"crypto/sha256"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"golang.org/x/crypto/hkdf"
|
||||
|
||||
"github.com/pomerium/pomerium/config"
|
||||
"github.com/pomerium/pomerium/pkg/cryptutil"
|
||||
)
|
||||
|
||||
func getCipher(
|
||||
cfg *config.Config,
|
||||
) (cipher.AEAD, error) {
|
||||
secret, err := cfg.Options.GetSharedKey()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("shared key: %w", err)
|
||||
}
|
||||
|
||||
rnd := hkdf.New(sha256.New, secret, nil, nil)
|
||||
cipher, err := initCipher(rnd)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("new aead cipher: %w", err)
|
||||
}
|
||||
return cipher, nil
|
||||
}
|
||||
|
||||
func readKey(r io.Reader) ([]byte, error) {
|
||||
b := make([]byte, cryptutil.DefaultKeySize)
|
||||
_, err := io.ReadFull(r, b)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read from hkdf: %w", err)
|
||||
}
|
||||
return b, nil
|
||||
}
|
||||
|
||||
func initCipher(r io.Reader) (cipher.AEAD, error) {
|
||||
cipherKey, err := readKey(r)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read key: %w", err)
|
||||
}
|
||||
cipher, err := cryptutil.NewAEADCipher(cipherKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("new aead cipher: %w", err)
|
||||
}
|
||||
return cipher, nil
|
||||
}
|
68
internal/mcp/code.go
Normal file
68
internal/mcp/code.go
Normal file
|
@ -0,0 +1,68 @@
|
|||
package mcp
|
||||
|
||||
import (
|
||||
"crypto/cipher"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/bufbuild/protovalidate-go"
|
||||
"google.golang.org/protobuf/proto"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
|
||||
oauth21proto "github.com/pomerium/pomerium/internal/oauth21/gen"
|
||||
"github.com/pomerium/pomerium/pkg/cryptutil"
|
||||
)
|
||||
|
||||
func CreateCode(
|
||||
id string,
|
||||
expires time.Time,
|
||||
ad string,
|
||||
cipher cipher.AEAD,
|
||||
) (string, error) {
|
||||
v := oauth21proto.Code{
|
||||
Id: id,
|
||||
ExpiresAt: timestamppb.New(expires),
|
||||
}
|
||||
|
||||
err := protovalidate.Validate(&v)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("validate: %w", err)
|
||||
}
|
||||
|
||||
b, err := proto.Marshal(&v)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
ciphertext := cryptutil.Encrypt(cipher, b, []byte(ad))
|
||||
return base64.StdEncoding.EncodeToString(ciphertext), nil
|
||||
}
|
||||
|
||||
func DecryptCode(
|
||||
code string,
|
||||
cipher cipher.AEAD,
|
||||
ad string,
|
||||
now time.Time,
|
||||
) (*oauth21proto.Code, error) {
|
||||
b, err := base64.StdEncoding.DecodeString(code)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("base64 decode: %w", err)
|
||||
}
|
||||
plaintext, err := cryptutil.Decrypt(cipher, b, []byte(ad))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("decrypt: %w", err)
|
||||
}
|
||||
var v oauth21proto.Code
|
||||
err = proto.Unmarshal(plaintext, &v)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unmarshal: %w", err)
|
||||
}
|
||||
if v.ExpiresAt == nil {
|
||||
return nil, fmt.Errorf("expiration is nil")
|
||||
}
|
||||
if v.ExpiresAt.AsTime().Before(now) {
|
||||
return nil, fmt.Errorf("code expired")
|
||||
}
|
||||
return &v, nil
|
||||
}
|
|
@ -2,6 +2,7 @@ package mcp
|
|||
|
||||
import (
|
||||
"context"
|
||||
"crypto/cipher"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"path"
|
||||
|
@ -32,6 +33,7 @@ type Handler struct {
|
|||
prefix string
|
||||
trace oteltrace.TracerProvider
|
||||
storage *Storage
|
||||
cipher cipher.AEAD
|
||||
}
|
||||
|
||||
func New(
|
||||
|
@ -46,10 +48,16 @@ func New(
|
|||
return nil, fmt.Errorf("databroker client: %w", err)
|
||||
}
|
||||
|
||||
cipher, err := getCipher(cfg)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get cipher: %w", err)
|
||||
}
|
||||
|
||||
return &Handler{
|
||||
prefix: prefix,
|
||||
trace: tracerProvider,
|
||||
storage: NewStorage(client),
|
||||
cipher: cipher,
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
|
|
@ -1,14 +1,18 @@
|
|||
package mcp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"time"
|
||||
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/log"
|
||||
"github.com/pomerium/pomerium/internal/oauth21"
|
||||
oauth21proto "github.com/pomerium/pomerium/internal/oauth21/gen"
|
||||
)
|
||||
|
||||
// Authorize handles the /authorize endpoint.
|
||||
|
@ -29,6 +33,7 @@ func (srv *Handler) Authorize(w http.ResponseWriter, r *http.Request) {
|
|||
|
||||
client, err := srv.storage.GetClientByID(ctx, v.ClientId)
|
||||
if err != nil && status.Code(err) == codes.NotFound {
|
||||
log.Ctx(ctx).Error().Err(err).Str("id", v.ClientId).Msg("client id not found")
|
||||
oauth21.ErrorResponse(w, http.StatusUnauthorized, oauth21.InvalidClient)
|
||||
return
|
||||
}
|
||||
|
@ -46,12 +51,46 @@ func (srv *Handler) Authorize(w http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
_, err = srv.storage.CreateAuthorizationRequest(ctx, v)
|
||||
id, err := srv.storage.CreateAuthorizationRequest(ctx, v)
|
||||
if err != nil {
|
||||
log.Ctx(ctx).Error().Err(err).Msg("failed to store authorization request")
|
||||
http.Error(w, "cannot create authorization request", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
http.Error(w, "not implemented", http.StatusNotImplemented)
|
||||
srv.AuthorizationResponse(ctx, w, r, id, v)
|
||||
}
|
||||
|
||||
// AuthorizationResponse generates the successful authorization response
|
||||
// see https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-12#section-4.1.2
|
||||
func (srv *Handler) AuthorizationResponse(
|
||||
ctx context.Context,
|
||||
w http.ResponseWriter,
|
||||
r *http.Request,
|
||||
id string,
|
||||
req *oauth21proto.AuthorizationRequest,
|
||||
) {
|
||||
code, err := CreateCode(
|
||||
id,
|
||||
time.Now().Add(time.Minute*10),
|
||||
req.ClientId,
|
||||
srv.cipher,
|
||||
)
|
||||
if err != nil {
|
||||
log.Ctx(ctx).Error().Err(err).Msg("failed to create code")
|
||||
http.Error(w, "internal server error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
to, err := url.Parse(req.GetRedirectUri())
|
||||
if err != nil {
|
||||
log.Ctx(ctx).Error().Err(err).Msg("failed to parse redirect uri")
|
||||
http.Error(w, "invalid redirect uri", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
q := to.Query()
|
||||
q.Set("code", code)
|
||||
q.Set("state", req.GetState())
|
||||
to.RawQuery = q.Encode()
|
||||
http.Redirect(w, r, to.String(), http.StatusFound)
|
||||
}
|
||||
|
|
|
@ -50,14 +50,16 @@ func (storage *Storage) GetClientByID(
|
|||
ctx context.Context,
|
||||
id string,
|
||||
) (*rfc7591v1.ClientMetadata, error) {
|
||||
v := new(rfc7591v1.ClientMetadata)
|
||||
|
||||
rec, err := storage.client.Get(ctx, &databroker.GetRequest{
|
||||
Type: protoutil.GetTypeURL(v),
|
||||
Id: id,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get client by ID: %w", err)
|
||||
}
|
||||
|
||||
v := new(rfc7591v1.ClientMetadata)
|
||||
err = anypb.UnmarshalTo(rec.Record.Data, v, proto.UnmarshalOptions{})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal client registration request: %w", err)
|
||||
|
|
155
internal/oauth21/gen/code.pb.go
Normal file
155
internal/oauth21/gen/code.pb.go
Normal file
|
@ -0,0 +1,155 @@
|
|||
// Code generated by protoc-gen-go. DO NOT EDIT.
|
||||
// versions:
|
||||
// protoc-gen-go v1.36.5
|
||||
// protoc (unknown)
|
||||
// source: code.proto
|
||||
|
||||
package gen
|
||||
|
||||
import (
|
||||
_ "buf.build/gen/go/bufbuild/protovalidate/protocolbuffers/go/buf/validate"
|
||||
protoreflect "google.golang.org/protobuf/reflect/protoreflect"
|
||||
protoimpl "google.golang.org/protobuf/runtime/protoimpl"
|
||||
timestamppb "google.golang.org/protobuf/types/known/timestamppb"
|
||||
reflect "reflect"
|
||||
sync "sync"
|
||||
unsafe "unsafe"
|
||||
)
|
||||
|
||||
const (
|
||||
// Verify that this generated code is sufficiently up-to-date.
|
||||
_ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion)
|
||||
// Verify that runtime/protoimpl is sufficiently up-to-date.
|
||||
_ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20)
|
||||
)
|
||||
|
||||
// Code is a code used in the authorization code flow.
|
||||
type Code struct {
|
||||
state protoimpl.MessageState `protogen:"open.v1"`
|
||||
Id string `protobuf:"bytes,1,opt,name=id,proto3" json:"id,omitempty"`
|
||||
ExpiresAt *timestamppb.Timestamp `protobuf:"bytes,2,opt,name=expires_at,json=expiresAt,proto3" json:"expires_at,omitempty"`
|
||||
unknownFields protoimpl.UnknownFields
|
||||
sizeCache protoimpl.SizeCache
|
||||
}
|
||||
|
||||
func (x *Code) Reset() {
|
||||
*x = Code{}
|
||||
mi := &file_code_proto_msgTypes[0]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
|
||||
func (x *Code) String() string {
|
||||
return protoimpl.X.MessageStringOf(x)
|
||||
}
|
||||
|
||||
func (*Code) ProtoMessage() {}
|
||||
|
||||
func (x *Code) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_code_proto_msgTypes[0]
|
||||
if x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
return ms
|
||||
}
|
||||
return mi.MessageOf(x)
|
||||
}
|
||||
|
||||
// Deprecated: Use Code.ProtoReflect.Descriptor instead.
|
||||
func (*Code) Descriptor() ([]byte, []int) {
|
||||
return file_code_proto_rawDescGZIP(), []int{0}
|
||||
}
|
||||
|
||||
func (x *Code) GetId() string {
|
||||
if x != nil {
|
||||
return x.Id
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (x *Code) GetExpiresAt() *timestamppb.Timestamp {
|
||||
if x != nil {
|
||||
return x.ExpiresAt
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
var File_code_proto protoreflect.FileDescriptor
|
||||
|
||||
var file_code_proto_rawDesc = string([]byte{
|
||||
0x0a, 0x0a, 0x63, 0x6f, 0x64, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x07, 0x6f, 0x61,
|
||||
0x75, 0x74, 0x68, 0x32, 0x31, 0x1a, 0x1f, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2f, 0x70, 0x72,
|
||||
0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2f, 0x74, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70,
|
||||
0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x1a, 0x1b, 0x62, 0x75, 0x66, 0x2f, 0x76, 0x61, 0x6c, 0x69,
|
||||
0x64, 0x61, 0x74, 0x65, 0x2f, 0x76, 0x61, 0x6c, 0x69, 0x64, 0x61, 0x74, 0x65, 0x2e, 0x70, 0x72,
|
||||
0x6f, 0x74, 0x6f, 0x22, 0x65, 0x0a, 0x04, 0x43, 0x6f, 0x64, 0x65, 0x12, 0x1a, 0x0a, 0x02, 0x69,
|
||||
0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x42, 0x0a, 0xba, 0x48, 0x07, 0xc8, 0x01, 0x01, 0x72,
|
||||
0x02, 0x10, 0x01, 0x52, 0x02, 0x69, 0x64, 0x12, 0x41, 0x0a, 0x0a, 0x65, 0x78, 0x70, 0x69, 0x72,
|
||||
0x65, 0x73, 0x5f, 0x61, 0x74, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x67, 0x6f,
|
||||
0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x54, 0x69,
|
||||
0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x42, 0x06, 0xba, 0x48, 0x03, 0xc8, 0x01, 0x01, 0x52,
|
||||
0x09, 0x65, 0x78, 0x70, 0x69, 0x72, 0x65, 0x73, 0x41, 0x74, 0x42, 0x9c, 0x01, 0x0a, 0x0b, 0x63,
|
||||
0x6f, 0x6d, 0x2e, 0x6f, 0x61, 0x75, 0x74, 0x68, 0x32, 0x31, 0x42, 0x09, 0x43, 0x6f, 0x64, 0x65,
|
||||
0x50, 0x72, 0x6f, 0x74, 0x6f, 0x50, 0x01, 0x5a, 0x46, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e,
|
||||
0x63, 0x6f, 0x6d, 0x2f, 0x62, 0x75, 0x66, 0x62, 0x75, 0x69, 0x6c, 0x64, 0x2f, 0x62, 0x75, 0x66,
|
||||
0x2d, 0x65, 0x78, 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x73, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x76,
|
||||
0x61, 0x6c, 0x69, 0x64, 0x61, 0x74, 0x65, 0x2f, 0x71, 0x75, 0x69, 0x63, 0x6b, 0x73, 0x74, 0x61,
|
||||
0x72, 0x74, 0x2d, 0x67, 0x6f, 0x2f, 0x73, 0x74, 0x61, 0x72, 0x74, 0x2f, 0x67, 0x65, 0x6e, 0xa2,
|
||||
0x02, 0x03, 0x4f, 0x58, 0x58, 0xaa, 0x02, 0x07, 0x4f, 0x61, 0x75, 0x74, 0x68, 0x32, 0x31, 0xca,
|
||||
0x02, 0x07, 0x4f, 0x61, 0x75, 0x74, 0x68, 0x32, 0x31, 0xe2, 0x02, 0x13, 0x4f, 0x61, 0x75, 0x74,
|
||||
0x68, 0x32, 0x31, 0x5c, 0x47, 0x50, 0x42, 0x4d, 0x65, 0x74, 0x61, 0x64, 0x61, 0x74, 0x61, 0xea,
|
||||
0x02, 0x07, 0x4f, 0x61, 0x75, 0x74, 0x68, 0x32, 0x31, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f,
|
||||
0x33,
|
||||
})
|
||||
|
||||
var (
|
||||
file_code_proto_rawDescOnce sync.Once
|
||||
file_code_proto_rawDescData []byte
|
||||
)
|
||||
|
||||
func file_code_proto_rawDescGZIP() []byte {
|
||||
file_code_proto_rawDescOnce.Do(func() {
|
||||
file_code_proto_rawDescData = protoimpl.X.CompressGZIP(unsafe.Slice(unsafe.StringData(file_code_proto_rawDesc), len(file_code_proto_rawDesc)))
|
||||
})
|
||||
return file_code_proto_rawDescData
|
||||
}
|
||||
|
||||
var file_code_proto_msgTypes = make([]protoimpl.MessageInfo, 1)
|
||||
var file_code_proto_goTypes = []any{
|
||||
(*Code)(nil), // 0: oauth21.Code
|
||||
(*timestamppb.Timestamp)(nil), // 1: google.protobuf.Timestamp
|
||||
}
|
||||
var file_code_proto_depIdxs = []int32{
|
||||
1, // 0: oauth21.Code.expires_at:type_name -> google.protobuf.Timestamp
|
||||
1, // [1:1] is the sub-list for method output_type
|
||||
1, // [1:1] is the sub-list for method input_type
|
||||
1, // [1:1] is the sub-list for extension type_name
|
||||
1, // [1:1] is the sub-list for extension extendee
|
||||
0, // [0:1] is the sub-list for field type_name
|
||||
}
|
||||
|
||||
func init() { file_code_proto_init() }
|
||||
func file_code_proto_init() {
|
||||
if File_code_proto != nil {
|
||||
return
|
||||
}
|
||||
type x struct{}
|
||||
out := protoimpl.TypeBuilder{
|
||||
File: protoimpl.DescBuilder{
|
||||
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
|
||||
RawDescriptor: unsafe.Slice(unsafe.StringData(file_code_proto_rawDesc), len(file_code_proto_rawDesc)),
|
||||
NumEnums: 0,
|
||||
NumMessages: 1,
|
||||
NumExtensions: 0,
|
||||
NumServices: 0,
|
||||
},
|
||||
GoTypes: file_code_proto_goTypes,
|
||||
DependencyIndexes: file_code_proto_depIdxs,
|
||||
MessageInfos: file_code_proto_msgTypes,
|
||||
}.Build()
|
||||
File_code_proto = out.File
|
||||
file_code_proto_goTypes = nil
|
||||
file_code_proto_depIdxs = nil
|
||||
}
|
20
internal/oauth21/proto/code.proto
Normal file
20
internal/oauth21/proto/code.proto
Normal file
|
@ -0,0 +1,20 @@
|
|||
syntax = "proto3";
|
||||
|
||||
package oauth21;
|
||||
|
||||
import "google/protobuf/timestamp.proto";
|
||||
import "buf/validate/validate.proto";
|
||||
|
||||
option go_package = "github.com/pomerium/pomerium/internal/oauth21/gen";
|
||||
|
||||
// Code is a code used in the authorization code flow.
|
||||
message Code {
|
||||
string id = 1 [
|
||||
(buf.validate.field).required = true,
|
||||
(buf.validate.field).string = {
|
||||
min_len : 1,
|
||||
}
|
||||
];
|
||||
google.protobuf.Timestamp expires_at = 2
|
||||
[ (buf.validate.field).required = true ];
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue