diff --git a/config/options.go b/config/options.go index fd67289a0..22bfb5b54 100644 --- a/config/options.go +++ b/config/options.go @@ -999,10 +999,10 @@ func (o *Options) GetSharedKey() ([]byte, error) { } // GetHPKEPrivateKey gets the hpke.PrivateKey dervived from the shared key. -func (o *Options) GetHPKEPrivateKey() (hpke.PrivateKey, error) { +func (o *Options) GetHPKEPrivateKey() (*hpke.PrivateKey, error) { sharedKey, err := o.GetSharedKey() if err != nil { - return hpke.PrivateKey{}, err + return nil, err } return hpke.DerivePrivateKey(sharedKey), nil diff --git a/pkg/hpke/hpke.go b/pkg/hpke/hpke.go index 4502254af..5cfb031ac 100644 --- a/pkg/hpke/hpke.go +++ b/pkg/hpke/hpke.go @@ -26,48 +26,48 @@ type PrivateKey struct { } // DerivePrivateKey derives a private key from a seed. The same seed will always result in the same private key. -func DerivePrivateKey(seed []byte) PrivateKey { +func DerivePrivateKey(seed []byte) *PrivateKey { pk := kdfID.Extract(seed, nil) data := kdfID.Expand(pk, kdfExpandInfo, uint(kemID.Scheme().SeedSize())) _, key := kemID.Scheme().DeriveKeyPair(data) - return PrivateKey{key: key} + return &PrivateKey{key: key} } // GeneratePrivateKey generates an HPKE private key. -func GeneratePrivateKey() (PrivateKey, error) { +func GeneratePrivateKey() (*PrivateKey, error) { _, privateKey, err := kemID.Scheme().GenerateKeyPair() if err != nil { - return PrivateKey{}, err + return nil, err } - return PrivateKey{key: privateKey}, nil + return &PrivateKey{key: privateKey}, nil } // PrivateKeyFromString takes a string and returns a PrivateKey. -func PrivateKeyFromString(raw string) (PrivateKey, error) { +func PrivateKeyFromString(raw string) (*PrivateKey, error) { bs, err := decode(raw) if err != nil { - return PrivateKey{}, err + return nil, err } key, err := kemID.Scheme().UnmarshalBinaryPrivateKey(bs) if err != nil { - return PrivateKey{}, err + return nil, err } - return PrivateKey{key: key}, nil + return &PrivateKey{key: key}, nil } // PublicKey returns the public key for the private key. -func (key PrivateKey) PublicKey() PublicKey { - if key.key == nil { - return PublicKey{} +func (key *PrivateKey) PublicKey() *PublicKey { + if key == nil || key.key == nil { + return nil } - return PublicKey{key: key.key.Public()} + return &PublicKey{key: key.key.Public()} } // MarshalJSON returns the JSON Web Key representation of the private key. -func (key PrivateKey) MarshalJSON() ([]byte, error) { +func (key *PrivateKey) MarshalJSON() ([]byte, error) { return json.Marshal(JWK{ Type: jwkType, ID: jwkID, @@ -78,8 +78,8 @@ func (key PrivateKey) MarshalJSON() ([]byte, error) { } // String converts the private key into a string. -func (key PrivateKey) String() string { - if key.key == nil { +func (key *PrivateKey) String() string { + if key == nil || key.key == nil { return "" } @@ -98,22 +98,28 @@ type PublicKey struct { } // PublicKeyFromString converts a string into a public key. -func PublicKeyFromString(raw string) (PublicKey, error) { +func PublicKeyFromString(raw string) (*PublicKey, error) { bs, err := decode(raw) if err != nil { - return PublicKey{}, err + return nil, err } key, err := kemID.Scheme().UnmarshalBinaryPublicKey(bs) if err != nil { - return PublicKey{}, err + return nil, err } - return PublicKey{key: key}, nil + return &PublicKey{key: key}, nil } // Equals returns true if the two keys are equivalent. -func (key PublicKey) Equals(other PublicKey) bool { +func (key *PublicKey) Equals(other *PublicKey) bool { + if key == nil && other == nil { + return true + } else if key == nil || other == nil { + return false + } + if key.key == nil && other.key == nil { return true } else if key.key == nil || other.key == nil { @@ -123,7 +129,7 @@ func (key PublicKey) Equals(other PublicKey) bool { } // MarshalJSON returns the JSON Web Key representation of the public key. -func (key PublicKey) MarshalJSON() ([]byte, error) { +func (key *PublicKey) MarshalJSON() ([]byte, error) { return json.Marshal(JWK{ Type: jwkType, ID: jwkID, @@ -133,8 +139,8 @@ func (key PublicKey) MarshalJSON() ([]byte, error) { } // String converts a public key into a string. -func (key PublicKey) String() string { - if key.key == nil { +func (key *PublicKey) String() string { + if key == nil || key.key == nil { return "" } @@ -149,10 +155,17 @@ func (key PublicKey) String() string { // Seal seales a message using HPKE. func Seal( - senderPrivateKey PrivateKey, - receiverPublicKey PublicKey, + senderPrivateKey *PrivateKey, + receiverPublicKey *PublicKey, message []byte, ) (sealed []byte, err error) { + if senderPrivateKey == nil { + return nil, fmt.Errorf("hpke: sender private key cannot be nil") + } + if receiverPublicKey == nil { + return nil, fmt.Errorf("hpke: receiver public key cannot be nil") + } + sender, err := suite.NewSender(receiverPublicKey.key, nil) if err != nil { return nil, fmt.Errorf("hpke: error creating sender: %w", err) @@ -173,10 +186,17 @@ func Seal( // Open opens a message using HPKE. func Open( - receiverPrivateKey PrivateKey, - senderPublicKey PublicKey, + receiverPrivateKey *PrivateKey, + senderPublicKey *PublicKey, sealed []byte, ) (message []byte, err error) { + if receiverPrivateKey == nil { + return nil, fmt.Errorf("hpke: receiver private key cannot be nil") + } + if senderPublicKey == nil { + return nil, fmt.Errorf("hpke: sender public key cannot be nil") + } + encSize := kemID.Scheme().SharedKeySize() if len(sealed) < encSize { return nil, fmt.Errorf("hpke: invalid sealed message") diff --git a/pkg/hpke/jwks.go b/pkg/hpke/jwks.go index 634bededd..3e71676ee 100644 --- a/pkg/hpke/jwks.go +++ b/pkg/hpke/jwks.go @@ -29,25 +29,25 @@ type JWK struct { } // FetchPublicKeyFromJWKS fetches the HPKE public key from the JWKS endpoint. -func FetchPublicKeyFromJWKS(ctx context.Context, client *http.Client, endpoint string) (PublicKey, error) { +func FetchPublicKeyFromJWKS(ctx context.Context, client *http.Client, endpoint string) (*PublicKey, error) { req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil) if err != nil { - return PublicKey{}, fmt.Errorf("hpke: error building jwks http request: %w", err) + return nil, fmt.Errorf("hpke: error building jwks http request: %w", err) } res, err := client.Do(req) if err != nil { - return PublicKey{}, fmt.Errorf("hpke: error requesting jwks endpoint: %w", err) + return nil, fmt.Errorf("hpke: error requesting jwks endpoint: %w", err) } defer res.Body.Close() if res.StatusCode/100 != 2 { - return PublicKey{}, fmt.Errorf("hpke: error requesting jwks endpoint, invalid status code: %d", res.StatusCode) + return nil, fmt.Errorf("hpke: error requesting jwks endpoint, invalid status code: %d", res.StatusCode) } bs, err := io.ReadAll(io.LimitReader(res.Body, defaultMaxBodySize)) if err != nil { - return PublicKey{}, fmt.Errorf("hpke: error reading jwks endpoint: %w", err) + return nil, fmt.Errorf("hpke: error reading jwks endpoint: %w", err) } var jwks struct { @@ -55,7 +55,7 @@ func FetchPublicKeyFromJWKS(ctx context.Context, client *http.Client, endpoint s } err = json.Unmarshal(bs, &jwks) if err != nil { - return PublicKey{}, fmt.Errorf("hpke: error unmarshaling jwks endpoint: %w", err) + return nil, fmt.Errorf("hpke: error unmarshaling jwks endpoint: %w", err) } for _, key := range jwks.Keys { @@ -63,12 +63,12 @@ func FetchPublicKeyFromJWKS(ctx context.Context, client *http.Client, endpoint s return PublicKeyFromString(key.X) } } - return PublicKey{}, fmt.Errorf("hpke key not found in JWKS endpoint") + return nil, fmt.Errorf("hpke key not found in JWKS endpoint") } // A KeyFetcher fetches public keys. type KeyFetcher interface { - FetchPublicKey(ctx context.Context) (PublicKey, error) + FetchPublicKey(ctx context.Context) (*PublicKey, error) } type jwksKeyFetcher struct { @@ -76,7 +76,7 @@ type jwksKeyFetcher struct { endpoint string } -func (fetcher *jwksKeyFetcher) FetchPublicKey(ctx context.Context) (PublicKey, error) { +func (fetcher *jwksKeyFetcher) FetchPublicKey(ctx context.Context) (*PublicKey, error) { return FetchPublicKeyFromJWKS(ctx, fetcher.client, fetcher.endpoint) } diff --git a/pkg/hpke/url.go b/pkg/hpke/url.go index 29188b4eb..49cbbbc19 100644 --- a/pkg/hpke/url.go +++ b/pkg/hpke/url.go @@ -18,8 +18,8 @@ func IsEncryptedURL(values url.Values) bool { // EncryptURLValues encrypts URL values using the Seal method. func EncryptURLValues( - senderPrivateKey PrivateKey, - receiverPublicKey PublicKey, + senderPrivateKey *PrivateKey, + receiverPublicKey *PublicKey, values url.Values, ) (encrypted url.Values, err error) { values = withoutHPKEParams(values) @@ -37,34 +37,34 @@ func EncryptURLValues( // DecryptURLValues decrypts URL values using the Open method. func DecryptURLValues( - receiverPrivateKey PrivateKey, + receiverPrivateKey *PrivateKey, encrypted url.Values, -) (senderPublicKey PublicKey, values url.Values, err error) { +) (senderPublicKey *PublicKey, values url.Values, err error) { if !encrypted.Has(ParamSenderPublicKey) { - return senderPublicKey, nil, fmt.Errorf("hpke: missing sender public key in query parameters") + return nil, nil, fmt.Errorf("hpke: missing sender public key in query parameters") } if !encrypted.Has(ParamQuery) { - return senderPublicKey, nil, fmt.Errorf("hpke: missing encrypted query in query parameters") + return nil, nil, fmt.Errorf("hpke: missing encrypted query in query parameters") } senderPublicKey, err = PublicKeyFromString(encrypted.Get(ParamSenderPublicKey)) if err != nil { - return senderPublicKey, nil, fmt.Errorf("hpke: invalid sender public key parameter: %w", err) + return nil, nil, fmt.Errorf("hpke: invalid sender public key parameter: %w", err) } sealed, err := decode(encrypted.Get(ParamQuery)) if err != nil { - return senderPublicKey, nil, fmt.Errorf("hpke: failed decoding query parameter: %w", err) + return nil, nil, fmt.Errorf("hpke: failed decoding query parameter: %w", err) } message, err := Open(receiverPrivateKey, senderPublicKey, sealed) if err != nil { - return senderPublicKey, nil, fmt.Errorf("hpke: failed to open sealed message: %w", err) + return nil, nil, fmt.Errorf("hpke: failed to open sealed message: %w", err) } decrypted, err := url.ParseQuery(string(message)) if err != nil { - return senderPublicKey, nil, fmt.Errorf("hpke: invalid query parameter: %w", err) + return nil, nil, fmt.Errorf("hpke: invalid query parameter: %w", err) } values = withoutHPKEParams(encrypted)