use pointers

This commit is contained in:
Caleb Doxsey 2022-11-23 08:39:28 -07:00
parent 54fb9fea04
commit c34157d533
4 changed files with 69 additions and 49 deletions

View file

@ -999,10 +999,10 @@ func (o *Options) GetSharedKey() ([]byte, error) {
} }
// GetHPKEPrivateKey gets the hpke.PrivateKey dervived from the shared key. // GetHPKEPrivateKey gets the hpke.PrivateKey dervived from the shared key.
func (o *Options) GetHPKEPrivateKey() (hpke.PrivateKey, error) { func (o *Options) GetHPKEPrivateKey() (*hpke.PrivateKey, error) {
sharedKey, err := o.GetSharedKey() sharedKey, err := o.GetSharedKey()
if err != nil { if err != nil {
return hpke.PrivateKey{}, err return nil, err
} }
return hpke.DerivePrivateKey(sharedKey), nil return hpke.DerivePrivateKey(sharedKey), nil

View file

@ -26,48 +26,48 @@ type PrivateKey struct {
} }
// DerivePrivateKey derives a private key from a seed. The same seed will always result in the same private key. // DerivePrivateKey derives a private key from a seed. The same seed will always result in the same private key.
func DerivePrivateKey(seed []byte) PrivateKey { func DerivePrivateKey(seed []byte) *PrivateKey {
pk := kdfID.Extract(seed, nil) pk := kdfID.Extract(seed, nil)
data := kdfID.Expand(pk, kdfExpandInfo, uint(kemID.Scheme().SeedSize())) data := kdfID.Expand(pk, kdfExpandInfo, uint(kemID.Scheme().SeedSize()))
_, key := kemID.Scheme().DeriveKeyPair(data) _, key := kemID.Scheme().DeriveKeyPair(data)
return PrivateKey{key: key} return &PrivateKey{key: key}
} }
// GeneratePrivateKey generates an HPKE private key. // GeneratePrivateKey generates an HPKE private key.
func GeneratePrivateKey() (PrivateKey, error) { func GeneratePrivateKey() (*PrivateKey, error) {
_, privateKey, err := kemID.Scheme().GenerateKeyPair() _, privateKey, err := kemID.Scheme().GenerateKeyPair()
if err != nil { if err != nil {
return PrivateKey{}, err return nil, err
} }
return PrivateKey{key: privateKey}, nil return &PrivateKey{key: privateKey}, nil
} }
// PrivateKeyFromString takes a string and returns a PrivateKey. // PrivateKeyFromString takes a string and returns a PrivateKey.
func PrivateKeyFromString(raw string) (PrivateKey, error) { func PrivateKeyFromString(raw string) (*PrivateKey, error) {
bs, err := decode(raw) bs, err := decode(raw)
if err != nil { if err != nil {
return PrivateKey{}, err return nil, err
} }
key, err := kemID.Scheme().UnmarshalBinaryPrivateKey(bs) key, err := kemID.Scheme().UnmarshalBinaryPrivateKey(bs)
if err != nil { if err != nil {
return PrivateKey{}, err return nil, err
} }
return PrivateKey{key: key}, nil return &PrivateKey{key: key}, nil
} }
// PublicKey returns the public key for the private key. // PublicKey returns the public key for the private key.
func (key PrivateKey) PublicKey() PublicKey { func (key *PrivateKey) PublicKey() *PublicKey {
if key.key == nil { if key == nil || key.key == nil {
return PublicKey{} return nil
} }
return PublicKey{key: key.key.Public()} return &PublicKey{key: key.key.Public()}
} }
// MarshalJSON returns the JSON Web Key representation of the private key. // MarshalJSON returns the JSON Web Key representation of the private key.
func (key PrivateKey) MarshalJSON() ([]byte, error) { func (key *PrivateKey) MarshalJSON() ([]byte, error) {
return json.Marshal(JWK{ return json.Marshal(JWK{
Type: jwkType, Type: jwkType,
ID: jwkID, ID: jwkID,
@ -78,8 +78,8 @@ func (key PrivateKey) MarshalJSON() ([]byte, error) {
} }
// String converts the private key into a string. // String converts the private key into a string.
func (key PrivateKey) String() string { func (key *PrivateKey) String() string {
if key.key == nil { if key == nil || key.key == nil {
return "" return ""
} }
@ -98,22 +98,28 @@ type PublicKey struct {
} }
// PublicKeyFromString converts a string into a public key. // PublicKeyFromString converts a string into a public key.
func PublicKeyFromString(raw string) (PublicKey, error) { func PublicKeyFromString(raw string) (*PublicKey, error) {
bs, err := decode(raw) bs, err := decode(raw)
if err != nil { if err != nil {
return PublicKey{}, err return nil, err
} }
key, err := kemID.Scheme().UnmarshalBinaryPublicKey(bs) key, err := kemID.Scheme().UnmarshalBinaryPublicKey(bs)
if err != nil { if err != nil {
return PublicKey{}, err return nil, err
} }
return PublicKey{key: key}, nil return &PublicKey{key: key}, nil
} }
// Equals returns true if the two keys are equivalent. // Equals returns true if the two keys are equivalent.
func (key PublicKey) Equals(other PublicKey) bool { func (key *PublicKey) Equals(other *PublicKey) bool {
if key == nil && other == nil {
return true
} else if key == nil || other == nil {
return false
}
if key.key == nil && other.key == nil { if key.key == nil && other.key == nil {
return true return true
} else if key.key == nil || other.key == nil { } else if key.key == nil || other.key == nil {
@ -123,7 +129,7 @@ func (key PublicKey) Equals(other PublicKey) bool {
} }
// MarshalJSON returns the JSON Web Key representation of the public key. // MarshalJSON returns the JSON Web Key representation of the public key.
func (key PublicKey) MarshalJSON() ([]byte, error) { func (key *PublicKey) MarshalJSON() ([]byte, error) {
return json.Marshal(JWK{ return json.Marshal(JWK{
Type: jwkType, Type: jwkType,
ID: jwkID, ID: jwkID,
@ -133,8 +139,8 @@ func (key PublicKey) MarshalJSON() ([]byte, error) {
} }
// String converts a public key into a string. // String converts a public key into a string.
func (key PublicKey) String() string { func (key *PublicKey) String() string {
if key.key == nil { if key == nil || key.key == nil {
return "" return ""
} }
@ -149,10 +155,17 @@ func (key PublicKey) String() string {
// Seal seales a message using HPKE. // Seal seales a message using HPKE.
func Seal( func Seal(
senderPrivateKey PrivateKey, senderPrivateKey *PrivateKey,
receiverPublicKey PublicKey, receiverPublicKey *PublicKey,
message []byte, message []byte,
) (sealed []byte, err error) { ) (sealed []byte, err error) {
if senderPrivateKey == nil {
return nil, fmt.Errorf("hpke: sender private key cannot be nil")
}
if receiverPublicKey == nil {
return nil, fmt.Errorf("hpke: receiver public key cannot be nil")
}
sender, err := suite.NewSender(receiverPublicKey.key, nil) sender, err := suite.NewSender(receiverPublicKey.key, nil)
if err != nil { if err != nil {
return nil, fmt.Errorf("hpke: error creating sender: %w", err) return nil, fmt.Errorf("hpke: error creating sender: %w", err)
@ -173,10 +186,17 @@ func Seal(
// Open opens a message using HPKE. // Open opens a message using HPKE.
func Open( func Open(
receiverPrivateKey PrivateKey, receiverPrivateKey *PrivateKey,
senderPublicKey PublicKey, senderPublicKey *PublicKey,
sealed []byte, sealed []byte,
) (message []byte, err error) { ) (message []byte, err error) {
if receiverPrivateKey == nil {
return nil, fmt.Errorf("hpke: receiver private key cannot be nil")
}
if senderPublicKey == nil {
return nil, fmt.Errorf("hpke: sender public key cannot be nil")
}
encSize := kemID.Scheme().SharedKeySize() encSize := kemID.Scheme().SharedKeySize()
if len(sealed) < encSize { if len(sealed) < encSize {
return nil, fmt.Errorf("hpke: invalid sealed message") return nil, fmt.Errorf("hpke: invalid sealed message")

View file

@ -29,25 +29,25 @@ type JWK struct {
} }
// FetchPublicKeyFromJWKS fetches the HPKE public key from the JWKS endpoint. // FetchPublicKeyFromJWKS fetches the HPKE public key from the JWKS endpoint.
func FetchPublicKeyFromJWKS(ctx context.Context, client *http.Client, endpoint string) (PublicKey, error) { func FetchPublicKeyFromJWKS(ctx context.Context, client *http.Client, endpoint string) (*PublicKey, error) {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil) req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil)
if err != nil { if err != nil {
return PublicKey{}, fmt.Errorf("hpke: error building jwks http request: %w", err) return nil, fmt.Errorf("hpke: error building jwks http request: %w", err)
} }
res, err := client.Do(req) res, err := client.Do(req)
if err != nil { if err != nil {
return PublicKey{}, fmt.Errorf("hpke: error requesting jwks endpoint: %w", err) return nil, fmt.Errorf("hpke: error requesting jwks endpoint: %w", err)
} }
defer res.Body.Close() defer res.Body.Close()
if res.StatusCode/100 != 2 { if res.StatusCode/100 != 2 {
return PublicKey{}, fmt.Errorf("hpke: error requesting jwks endpoint, invalid status code: %d", res.StatusCode) return nil, fmt.Errorf("hpke: error requesting jwks endpoint, invalid status code: %d", res.StatusCode)
} }
bs, err := io.ReadAll(io.LimitReader(res.Body, defaultMaxBodySize)) bs, err := io.ReadAll(io.LimitReader(res.Body, defaultMaxBodySize))
if err != nil { if err != nil {
return PublicKey{}, fmt.Errorf("hpke: error reading jwks endpoint: %w", err) return nil, fmt.Errorf("hpke: error reading jwks endpoint: %w", err)
} }
var jwks struct { var jwks struct {
@ -55,7 +55,7 @@ func FetchPublicKeyFromJWKS(ctx context.Context, client *http.Client, endpoint s
} }
err = json.Unmarshal(bs, &jwks) err = json.Unmarshal(bs, &jwks)
if err != nil { if err != nil {
return PublicKey{}, fmt.Errorf("hpke: error unmarshaling jwks endpoint: %w", err) return nil, fmt.Errorf("hpke: error unmarshaling jwks endpoint: %w", err)
} }
for _, key := range jwks.Keys { for _, key := range jwks.Keys {
@ -63,12 +63,12 @@ func FetchPublicKeyFromJWKS(ctx context.Context, client *http.Client, endpoint s
return PublicKeyFromString(key.X) return PublicKeyFromString(key.X)
} }
} }
return PublicKey{}, fmt.Errorf("hpke key not found in JWKS endpoint") return nil, fmt.Errorf("hpke key not found in JWKS endpoint")
} }
// A KeyFetcher fetches public keys. // A KeyFetcher fetches public keys.
type KeyFetcher interface { type KeyFetcher interface {
FetchPublicKey(ctx context.Context) (PublicKey, error) FetchPublicKey(ctx context.Context) (*PublicKey, error)
} }
type jwksKeyFetcher struct { type jwksKeyFetcher struct {
@ -76,7 +76,7 @@ type jwksKeyFetcher struct {
endpoint string endpoint string
} }
func (fetcher *jwksKeyFetcher) FetchPublicKey(ctx context.Context) (PublicKey, error) { func (fetcher *jwksKeyFetcher) FetchPublicKey(ctx context.Context) (*PublicKey, error) {
return FetchPublicKeyFromJWKS(ctx, fetcher.client, fetcher.endpoint) return FetchPublicKeyFromJWKS(ctx, fetcher.client, fetcher.endpoint)
} }

View file

@ -18,8 +18,8 @@ func IsEncryptedURL(values url.Values) bool {
// EncryptURLValues encrypts URL values using the Seal method. // EncryptURLValues encrypts URL values using the Seal method.
func EncryptURLValues( func EncryptURLValues(
senderPrivateKey PrivateKey, senderPrivateKey *PrivateKey,
receiverPublicKey PublicKey, receiverPublicKey *PublicKey,
values url.Values, values url.Values,
) (encrypted url.Values, err error) { ) (encrypted url.Values, err error) {
values = withoutHPKEParams(values) values = withoutHPKEParams(values)
@ -37,34 +37,34 @@ func EncryptURLValues(
// DecryptURLValues decrypts URL values using the Open method. // DecryptURLValues decrypts URL values using the Open method.
func DecryptURLValues( func DecryptURLValues(
receiverPrivateKey PrivateKey, receiverPrivateKey *PrivateKey,
encrypted url.Values, encrypted url.Values,
) (senderPublicKey PublicKey, values url.Values, err error) { ) (senderPublicKey *PublicKey, values url.Values, err error) {
if !encrypted.Has(ParamSenderPublicKey) { if !encrypted.Has(ParamSenderPublicKey) {
return senderPublicKey, nil, fmt.Errorf("hpke: missing sender public key in query parameters") return nil, nil, fmt.Errorf("hpke: missing sender public key in query parameters")
} }
if !encrypted.Has(ParamQuery) { if !encrypted.Has(ParamQuery) {
return senderPublicKey, nil, fmt.Errorf("hpke: missing encrypted query in query parameters") return nil, nil, fmt.Errorf("hpke: missing encrypted query in query parameters")
} }
senderPublicKey, err = PublicKeyFromString(encrypted.Get(ParamSenderPublicKey)) senderPublicKey, err = PublicKeyFromString(encrypted.Get(ParamSenderPublicKey))
if err != nil { if err != nil {
return senderPublicKey, nil, fmt.Errorf("hpke: invalid sender public key parameter: %w", err) return nil, nil, fmt.Errorf("hpke: invalid sender public key parameter: %w", err)
} }
sealed, err := decode(encrypted.Get(ParamQuery)) sealed, err := decode(encrypted.Get(ParamQuery))
if err != nil { if err != nil {
return senderPublicKey, nil, fmt.Errorf("hpke: failed decoding query parameter: %w", err) return nil, nil, fmt.Errorf("hpke: failed decoding query parameter: %w", err)
} }
message, err := Open(receiverPrivateKey, senderPublicKey, sealed) message, err := Open(receiverPrivateKey, senderPublicKey, sealed)
if err != nil { if err != nil {
return senderPublicKey, nil, fmt.Errorf("hpke: failed to open sealed message: %w", err) return nil, nil, fmt.Errorf("hpke: failed to open sealed message: %w", err)
} }
decrypted, err := url.ParseQuery(string(message)) decrypted, err := url.ParseQuery(string(message))
if err != nil { if err != nil {
return senderPublicKey, nil, fmt.Errorf("hpke: invalid query parameter: %w", err) return nil, nil, fmt.Errorf("hpke: invalid query parameter: %w", err)
} }
values = withoutHPKEParams(encrypted) values = withoutHPKEParams(encrypted)