pomerium/internal/mcp/code.go
2025-04-25 14:47:11 -04:00

85 lines
1.9 KiB
Go

package mcp
import (
"crypto/cipher"
"encoding/base64"
"fmt"
"time"
"github.com/bufbuild/protovalidate-go"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/known/timestamppb"
oauth21proto "github.com/pomerium/pomerium/internal/oauth21/gen"
"github.com/pomerium/pomerium/pkg/cryptutil"
)
// Package-local aliases for the oauth21proto.CodeType values, so callers in
// this package can refer to them without the oauth21proto prefix.
const (
// CodeTypeAuthorization aliases oauth21proto.CodeType_CODE_TYPE_AUTHORIZATION.
CodeTypeAuthorization = oauth21proto.CodeType_CODE_TYPE_AUTHORIZATION
// CodeTypeRefresh aliases oauth21proto.CodeType_CODE_TYPE_REFRESH.
CodeTypeRefresh = oauth21proto.CodeType_CODE_TYPE_REFRESH
// CodeTypeAccess aliases oauth21proto.CodeType_CODE_TYPE_ACCESS.
CodeTypeAccess = oauth21proto.CodeType_CODE_TYPE_ACCESS
)
// CreateCode builds an oauth21proto.Code from typ, id and expires, validates
// it with protovalidate, serializes it with proto.Marshal, encrypts the bytes
// via cryptutil.Encrypt and returns the result encoded with standard base64.
//
// ad is handed to cryptutil.Encrypt as its third argument (presumably the AEAD
// additional data, in which case DecryptCode needs the same value — confirm
// against cryptutil).
//
// An error is returned when expires is the zero time, when validation fails,
// or when the message cannot be marshaled. Every error is prefixed with the
// step that produced it.
func CreateCode(
	typ oauth21proto.CodeType,
	id string,
	expires time.Time,
	ad string,
	cipher cipher.AEAD,
) (string, error) {
	// Reject the zero time up front, before it is converted to a timestamp.
	if expires.IsZero() {
		return "", fmt.Errorf("validate: zero expiration")
	}

	v := oauth21proto.Code{
		Id:        id,
		ExpiresAt: timestamppb.New(expires),
		GrantType: typ,
	}
	if err := protovalidate.Validate(&v); err != nil {
		return "", fmt.Errorf("validate: %w", err)
	}

	b, err := proto.Marshal(&v)
	if err != nil {
		// Wrap with the failing step, matching the other error paths here.
		return "", fmt.Errorf("marshal: %w", err)
	}

	ciphertext := cryptutil.Encrypt(cipher, b, []byte(ad))
	return base64.StdEncoding.EncodeToString(ciphertext), nil
}
// DecryptCode reverses CreateCode: it base64-decodes code (standard encoding),
// decrypts it via cryptutil.Decrypt with cipher and ad, unmarshals the
// plaintext into an oauth21proto.Code and validates it with protovalidate.
//
// The decoded message is then rejected if its GrantType differs from typ, or
// if its ExpiresAt is strictly before now. A code whose expiration equals now
// is still accepted.
//
// On success the decoded *oauth21proto.Code is returned; on any failure the
// result is nil and the error names the step that failed.
func DecryptCode(
	typ oauth21proto.CodeType,
	code string,
	cipher cipher.AEAD,
	ad string,
	now time.Time,
) (*oauth21proto.Code, error) {
	raw, err := base64.StdEncoding.DecodeString(code)
	if err != nil {
		return nil, fmt.Errorf("base64 decode: %w", err)
	}

	serialized, err := cryptutil.Decrypt(cipher, raw, []byte(ad))
	if err != nil {
		return nil, fmt.Errorf("decrypt: %w", err)
	}

	decoded := new(oauth21proto.Code)
	if err := proto.Unmarshal(serialized, decoded); err != nil {
		return nil, fmt.Errorf("unmarshal: %w", err)
	}
	if err := protovalidate.Validate(decoded); err != nil {
		return nil, fmt.Errorf("validate: %w", err)
	}

	// The type check runs before the expiry check, so a code of the wrong
	// type reports a mismatch even if it has also expired.
	switch {
	case decoded.GrantType != typ:
		return nil, fmt.Errorf("code type mismatch: expected %v, got %v", typ, decoded.GrantType)
	case decoded.ExpiresAt.AsTime().Before(now):
		return nil, fmt.Errorf("code expired")
	}

	return decoded, nil
}