mirror of
https://github.com/pomerium/pomerium.git
synced 2025-06-27 06:58:13 +02:00
use pointers
This commit is contained in:
parent
54fb9fea04
commit
c34157d533
4 changed files with 69 additions and 49 deletions
|
@ -999,10 +999,10 @@ func (o *Options) GetSharedKey() ([]byte, error) {
|
|||
}
|
||||
|
||||
// GetHPKEPrivateKey gets the hpke.PrivateKey dervived from the shared key.
|
||||
func (o *Options) GetHPKEPrivateKey() (hpke.PrivateKey, error) {
|
||||
func (o *Options) GetHPKEPrivateKey() (*hpke.PrivateKey, error) {
|
||||
sharedKey, err := o.GetSharedKey()
|
||||
if err != nil {
|
||||
return hpke.PrivateKey{}, err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return hpke.DerivePrivateKey(sharedKey), nil
|
||||
|
|
|
@ -26,48 +26,48 @@ type PrivateKey struct {
|
|||
}
|
||||
|
||||
// DerivePrivateKey derives a private key from a seed. The same seed will always result in the same private key.
|
||||
func DerivePrivateKey(seed []byte) PrivateKey {
|
||||
func DerivePrivateKey(seed []byte) *PrivateKey {
|
||||
pk := kdfID.Extract(seed, nil)
|
||||
data := kdfID.Expand(pk, kdfExpandInfo, uint(kemID.Scheme().SeedSize()))
|
||||
_, key := kemID.Scheme().DeriveKeyPair(data)
|
||||
return PrivateKey{key: key}
|
||||
return &PrivateKey{key: key}
|
||||
}
|
||||
|
||||
// GeneratePrivateKey generates an HPKE private key.
|
||||
func GeneratePrivateKey() (PrivateKey, error) {
|
||||
func GeneratePrivateKey() (*PrivateKey, error) {
|
||||
_, privateKey, err := kemID.Scheme().GenerateKeyPair()
|
||||
if err != nil {
|
||||
return PrivateKey{}, err
|
||||
return nil, err
|
||||
}
|
||||
return PrivateKey{key: privateKey}, nil
|
||||
return &PrivateKey{key: privateKey}, nil
|
||||
}
|
||||
|
||||
// PrivateKeyFromString takes a string and returns a PrivateKey.
|
||||
func PrivateKeyFromString(raw string) (PrivateKey, error) {
|
||||
func PrivateKeyFromString(raw string) (*PrivateKey, error) {
|
||||
bs, err := decode(raw)
|
||||
if err != nil {
|
||||
return PrivateKey{}, err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
key, err := kemID.Scheme().UnmarshalBinaryPrivateKey(bs)
|
||||
if err != nil {
|
||||
return PrivateKey{}, err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return PrivateKey{key: key}, nil
|
||||
return &PrivateKey{key: key}, nil
|
||||
}
|
||||
|
||||
// PublicKey returns the public key for the private key.
|
||||
func (key PrivateKey) PublicKey() PublicKey {
|
||||
if key.key == nil {
|
||||
return PublicKey{}
|
||||
func (key *PrivateKey) PublicKey() *PublicKey {
|
||||
if key == nil || key.key == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return PublicKey{key: key.key.Public()}
|
||||
return &PublicKey{key: key.key.Public()}
|
||||
}
|
||||
|
||||
// MarshalJSON returns the JSON Web Key representation of the private key.
|
||||
func (key PrivateKey) MarshalJSON() ([]byte, error) {
|
||||
func (key *PrivateKey) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(JWK{
|
||||
Type: jwkType,
|
||||
ID: jwkID,
|
||||
|
@ -78,8 +78,8 @@ func (key PrivateKey) MarshalJSON() ([]byte, error) {
|
|||
}
|
||||
|
||||
// String converts the private key into a string.
|
||||
func (key PrivateKey) String() string {
|
||||
if key.key == nil {
|
||||
func (key *PrivateKey) String() string {
|
||||
if key == nil || key.key == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
|
@ -98,22 +98,28 @@ type PublicKey struct {
|
|||
}
|
||||
|
||||
// PublicKeyFromString converts a string into a public key.
|
||||
func PublicKeyFromString(raw string) (PublicKey, error) {
|
||||
func PublicKeyFromString(raw string) (*PublicKey, error) {
|
||||
bs, err := decode(raw)
|
||||
if err != nil {
|
||||
return PublicKey{}, err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
key, err := kemID.Scheme().UnmarshalBinaryPublicKey(bs)
|
||||
if err != nil {
|
||||
return PublicKey{}, err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return PublicKey{key: key}, nil
|
||||
return &PublicKey{key: key}, nil
|
||||
}
|
||||
|
||||
// Equals returns true if the two keys are equivalent.
|
||||
func (key PublicKey) Equals(other PublicKey) bool {
|
||||
func (key *PublicKey) Equals(other *PublicKey) bool {
|
||||
if key == nil && other == nil {
|
||||
return true
|
||||
} else if key == nil || other == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
if key.key == nil && other.key == nil {
|
||||
return true
|
||||
} else if key.key == nil || other.key == nil {
|
||||
|
@ -123,7 +129,7 @@ func (key PublicKey) Equals(other PublicKey) bool {
|
|||
}
|
||||
|
||||
// MarshalJSON returns the JSON Web Key representation of the public key.
|
||||
func (key PublicKey) MarshalJSON() ([]byte, error) {
|
||||
func (key *PublicKey) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(JWK{
|
||||
Type: jwkType,
|
||||
ID: jwkID,
|
||||
|
@ -133,8 +139,8 @@ func (key PublicKey) MarshalJSON() ([]byte, error) {
|
|||
}
|
||||
|
||||
// String converts a public key into a string.
|
||||
func (key PublicKey) String() string {
|
||||
if key.key == nil {
|
||||
func (key *PublicKey) String() string {
|
||||
if key == nil || key.key == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
|
@ -149,10 +155,17 @@ func (key PublicKey) String() string {
|
|||
|
||||
// Seal seales a message using HPKE.
|
||||
func Seal(
|
||||
senderPrivateKey PrivateKey,
|
||||
receiverPublicKey PublicKey,
|
||||
senderPrivateKey *PrivateKey,
|
||||
receiverPublicKey *PublicKey,
|
||||
message []byte,
|
||||
) (sealed []byte, err error) {
|
||||
if senderPrivateKey == nil {
|
||||
return nil, fmt.Errorf("hpke: sender private key cannot be nil")
|
||||
}
|
||||
if receiverPublicKey == nil {
|
||||
return nil, fmt.Errorf("hpke: receiver public key cannot be nil")
|
||||
}
|
||||
|
||||
sender, err := suite.NewSender(receiverPublicKey.key, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("hpke: error creating sender: %w", err)
|
||||
|
@ -173,10 +186,17 @@ func Seal(
|
|||
|
||||
// Open opens a message using HPKE.
|
||||
func Open(
|
||||
receiverPrivateKey PrivateKey,
|
||||
senderPublicKey PublicKey,
|
||||
receiverPrivateKey *PrivateKey,
|
||||
senderPublicKey *PublicKey,
|
||||
sealed []byte,
|
||||
) (message []byte, err error) {
|
||||
if receiverPrivateKey == nil {
|
||||
return nil, fmt.Errorf("hpke: receiver private key cannot be nil")
|
||||
}
|
||||
if senderPublicKey == nil {
|
||||
return nil, fmt.Errorf("hpke: sender public key cannot be nil")
|
||||
}
|
||||
|
||||
encSize := kemID.Scheme().SharedKeySize()
|
||||
if len(sealed) < encSize {
|
||||
return nil, fmt.Errorf("hpke: invalid sealed message")
|
||||
|
|
|
@ -29,25 +29,25 @@ type JWK struct {
|
|||
}
|
||||
|
||||
// FetchPublicKeyFromJWKS fetches the HPKE public key from the JWKS endpoint.
|
||||
func FetchPublicKeyFromJWKS(ctx context.Context, client *http.Client, endpoint string) (PublicKey, error) {
|
||||
func FetchPublicKeyFromJWKS(ctx context.Context, client *http.Client, endpoint string) (*PublicKey, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil)
|
||||
if err != nil {
|
||||
return PublicKey{}, fmt.Errorf("hpke: error building jwks http request: %w", err)
|
||||
return nil, fmt.Errorf("hpke: error building jwks http request: %w", err)
|
||||
}
|
||||
|
||||
res, err := client.Do(req)
|
||||
if err != nil {
|
||||
return PublicKey{}, fmt.Errorf("hpke: error requesting jwks endpoint: %w", err)
|
||||
return nil, fmt.Errorf("hpke: error requesting jwks endpoint: %w", err)
|
||||
}
|
||||
defer res.Body.Close()
|
||||
|
||||
if res.StatusCode/100 != 2 {
|
||||
return PublicKey{}, fmt.Errorf("hpke: error requesting jwks endpoint, invalid status code: %d", res.StatusCode)
|
||||
return nil, fmt.Errorf("hpke: error requesting jwks endpoint, invalid status code: %d", res.StatusCode)
|
||||
}
|
||||
|
||||
bs, err := io.ReadAll(io.LimitReader(res.Body, defaultMaxBodySize))
|
||||
if err != nil {
|
||||
return PublicKey{}, fmt.Errorf("hpke: error reading jwks endpoint: %w", err)
|
||||
return nil, fmt.Errorf("hpke: error reading jwks endpoint: %w", err)
|
||||
}
|
||||
|
||||
var jwks struct {
|
||||
|
@ -55,7 +55,7 @@ func FetchPublicKeyFromJWKS(ctx context.Context, client *http.Client, endpoint s
|
|||
}
|
||||
err = json.Unmarshal(bs, &jwks)
|
||||
if err != nil {
|
||||
return PublicKey{}, fmt.Errorf("hpke: error unmarshaling jwks endpoint: %w", err)
|
||||
return nil, fmt.Errorf("hpke: error unmarshaling jwks endpoint: %w", err)
|
||||
}
|
||||
|
||||
for _, key := range jwks.Keys {
|
||||
|
@ -63,12 +63,12 @@ func FetchPublicKeyFromJWKS(ctx context.Context, client *http.Client, endpoint s
|
|||
return PublicKeyFromString(key.X)
|
||||
}
|
||||
}
|
||||
return PublicKey{}, fmt.Errorf("hpke key not found in JWKS endpoint")
|
||||
return nil, fmt.Errorf("hpke key not found in JWKS endpoint")
|
||||
}
|
||||
|
||||
// A KeyFetcher fetches public keys.
|
||||
type KeyFetcher interface {
|
||||
FetchPublicKey(ctx context.Context) (PublicKey, error)
|
||||
FetchPublicKey(ctx context.Context) (*PublicKey, error)
|
||||
}
|
||||
|
||||
type jwksKeyFetcher struct {
|
||||
|
@ -76,7 +76,7 @@ type jwksKeyFetcher struct {
|
|||
endpoint string
|
||||
}
|
||||
|
||||
func (fetcher *jwksKeyFetcher) FetchPublicKey(ctx context.Context) (PublicKey, error) {
|
||||
func (fetcher *jwksKeyFetcher) FetchPublicKey(ctx context.Context) (*PublicKey, error) {
|
||||
return FetchPublicKeyFromJWKS(ctx, fetcher.client, fetcher.endpoint)
|
||||
}
|
||||
|
||||
|
|
|
@ -18,8 +18,8 @@ func IsEncryptedURL(values url.Values) bool {
|
|||
|
||||
// EncryptURLValues encrypts URL values using the Seal method.
|
||||
func EncryptURLValues(
|
||||
senderPrivateKey PrivateKey,
|
||||
receiverPublicKey PublicKey,
|
||||
senderPrivateKey *PrivateKey,
|
||||
receiverPublicKey *PublicKey,
|
||||
values url.Values,
|
||||
) (encrypted url.Values, err error) {
|
||||
values = withoutHPKEParams(values)
|
||||
|
@ -37,34 +37,34 @@ func EncryptURLValues(
|
|||
|
||||
// DecryptURLValues decrypts URL values using the Open method.
|
||||
func DecryptURLValues(
|
||||
receiverPrivateKey PrivateKey,
|
||||
receiverPrivateKey *PrivateKey,
|
||||
encrypted url.Values,
|
||||
) (senderPublicKey PublicKey, values url.Values, err error) {
|
||||
) (senderPublicKey *PublicKey, values url.Values, err error) {
|
||||
if !encrypted.Has(ParamSenderPublicKey) {
|
||||
return senderPublicKey, nil, fmt.Errorf("hpke: missing sender public key in query parameters")
|
||||
return nil, nil, fmt.Errorf("hpke: missing sender public key in query parameters")
|
||||
}
|
||||
if !encrypted.Has(ParamQuery) {
|
||||
return senderPublicKey, nil, fmt.Errorf("hpke: missing encrypted query in query parameters")
|
||||
return nil, nil, fmt.Errorf("hpke: missing encrypted query in query parameters")
|
||||
}
|
||||
|
||||
senderPublicKey, err = PublicKeyFromString(encrypted.Get(ParamSenderPublicKey))
|
||||
if err != nil {
|
||||
return senderPublicKey, nil, fmt.Errorf("hpke: invalid sender public key parameter: %w", err)
|
||||
return nil, nil, fmt.Errorf("hpke: invalid sender public key parameter: %w", err)
|
||||
}
|
||||
|
||||
sealed, err := decode(encrypted.Get(ParamQuery))
|
||||
if err != nil {
|
||||
return senderPublicKey, nil, fmt.Errorf("hpke: failed decoding query parameter: %w", err)
|
||||
return nil, nil, fmt.Errorf("hpke: failed decoding query parameter: %w", err)
|
||||
}
|
||||
|
||||
message, err := Open(receiverPrivateKey, senderPublicKey, sealed)
|
||||
if err != nil {
|
||||
return senderPublicKey, nil, fmt.Errorf("hpke: failed to open sealed message: %w", err)
|
||||
return nil, nil, fmt.Errorf("hpke: failed to open sealed message: %w", err)
|
||||
}
|
||||
|
||||
decrypted, err := url.ParseQuery(string(message))
|
||||
if err != nil {
|
||||
return senderPublicKey, nil, fmt.Errorf("hpke: invalid query parameter: %w", err)
|
||||
return nil, nil, fmt.Errorf("hpke: invalid query parameter: %w", err)
|
||||
}
|
||||
|
||||
values = withoutHPKEParams(encrypted)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue