mirror of
https://github.com/pomerium/pomerium.git
synced 2025-05-25 23:17:18 +02:00
mcp: authorize request (pt2) (#5586)
This commit is contained in:
parent
63ccf6ab93
commit
9e4947c62f
9 changed files with 567 additions and 6 deletions
72
internal/mcp/code.go
Normal file
72
internal/mcp/code.go
Normal file
|
@ -0,0 +1,72 @@
|
|||
package mcp
|
||||
|
||||
import (
|
||||
"crypto/cipher"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/bufbuild/protovalidate-go"
|
||||
"google.golang.org/protobuf/proto"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
|
||||
oauth21proto "github.com/pomerium/pomerium/internal/oauth21/gen"
|
||||
"github.com/pomerium/pomerium/pkg/cryptutil"
|
||||
)
|
||||
|
||||
func CreateCode(
|
||||
id string,
|
||||
expires time.Time,
|
||||
ad string,
|
||||
cipher cipher.AEAD,
|
||||
) (string, error) {
|
||||
if expires.IsZero() {
|
||||
return "", fmt.Errorf("validate: zero expiration")
|
||||
}
|
||||
|
||||
v := oauth21proto.Code{
|
||||
Id: id,
|
||||
ExpiresAt: timestamppb.New(expires),
|
||||
}
|
||||
|
||||
err := protovalidate.Validate(&v)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("validate: %w", err)
|
||||
}
|
||||
|
||||
b, err := proto.Marshal(&v)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
ciphertext := cryptutil.Encrypt(cipher, b, []byte(ad))
|
||||
return base64.StdEncoding.EncodeToString(ciphertext), nil
|
||||
}
|
||||
|
||||
func DecryptCode(
|
||||
code string,
|
||||
cipher cipher.AEAD,
|
||||
ad string,
|
||||
now time.Time,
|
||||
) (*oauth21proto.Code, error) {
|
||||
b, err := base64.StdEncoding.DecodeString(code)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("base64 decode: %w", err)
|
||||
}
|
||||
plaintext, err := cryptutil.Decrypt(cipher, b, []byte(ad))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("decrypt: %w", err)
|
||||
}
|
||||
var v oauth21proto.Code
|
||||
err = proto.Unmarshal(plaintext, &v)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unmarshal: %w", err)
|
||||
}
|
||||
if v.ExpiresAt == nil {
|
||||
return nil, fmt.Errorf("expiration is nil")
|
||||
}
|
||||
if v.ExpiresAt.AsTime().Before(now) {
|
||||
return nil, fmt.Errorf("code expired")
|
||||
}
|
||||
return &v, nil
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue