use pointers

This commit is contained in:
Caleb Doxsey 2022-11-23 08:39:28 -07:00
parent 54fb9fea04
commit c34157d533
4 changed files with 69 additions and 49 deletions

View file

@ -999,10 +999,10 @@ func (o *Options) GetSharedKey() ([]byte, error) {
}
// GetHPKEPrivateKey gets the hpke.PrivateKey dervived from the shared key.
func (o *Options) GetHPKEPrivateKey() (hpke.PrivateKey, error) {
func (o *Options) GetHPKEPrivateKey() (*hpke.PrivateKey, error) {
sharedKey, err := o.GetSharedKey()
if err != nil {
return hpke.PrivateKey{}, err
return nil, err
}
return hpke.DerivePrivateKey(sharedKey), nil

View file

@ -26,48 +26,48 @@ type PrivateKey struct {
}
// DerivePrivateKey derives a private key from a seed. The same seed will always result in the same private key.
func DerivePrivateKey(seed []byte) PrivateKey {
func DerivePrivateKey(seed []byte) *PrivateKey {
pk := kdfID.Extract(seed, nil)
data := kdfID.Expand(pk, kdfExpandInfo, uint(kemID.Scheme().SeedSize()))
_, key := kemID.Scheme().DeriveKeyPair(data)
return PrivateKey{key: key}
return &PrivateKey{key: key}
}
// GeneratePrivateKey generates an HPKE private key.
func GeneratePrivateKey() (PrivateKey, error) {
func GeneratePrivateKey() (*PrivateKey, error) {
_, privateKey, err := kemID.Scheme().GenerateKeyPair()
if err != nil {
return PrivateKey{}, err
return nil, err
}
return PrivateKey{key: privateKey}, nil
return &PrivateKey{key: privateKey}, nil
}
// PrivateKeyFromString takes a string and returns a PrivateKey.
func PrivateKeyFromString(raw string) (PrivateKey, error) {
func PrivateKeyFromString(raw string) (*PrivateKey, error) {
bs, err := decode(raw)
if err != nil {
return PrivateKey{}, err
return nil, err
}
key, err := kemID.Scheme().UnmarshalBinaryPrivateKey(bs)
if err != nil {
return PrivateKey{}, err
return nil, err
}
return PrivateKey{key: key}, nil
return &PrivateKey{key: key}, nil
}
// PublicKey returns the public key for the private key.
func (key PrivateKey) PublicKey() PublicKey {
if key.key == nil {
return PublicKey{}
func (key *PrivateKey) PublicKey() *PublicKey {
if key == nil || key.key == nil {
return nil
}
return PublicKey{key: key.key.Public()}
return &PublicKey{key: key.key.Public()}
}
// MarshalJSON returns the JSON Web Key representation of the private key.
func (key PrivateKey) MarshalJSON() ([]byte, error) {
func (key *PrivateKey) MarshalJSON() ([]byte, error) {
return json.Marshal(JWK{
Type: jwkType,
ID: jwkID,
@ -78,8 +78,8 @@ func (key PrivateKey) MarshalJSON() ([]byte, error) {
}
// String converts the private key into a string.
func (key PrivateKey) String() string {
if key.key == nil {
func (key *PrivateKey) String() string {
if key == nil || key.key == nil {
return ""
}
@ -98,22 +98,28 @@ type PublicKey struct {
}
// PublicKeyFromString converts a string into a public key.
func PublicKeyFromString(raw string) (PublicKey, error) {
func PublicKeyFromString(raw string) (*PublicKey, error) {
bs, err := decode(raw)
if err != nil {
return PublicKey{}, err
return nil, err
}
key, err := kemID.Scheme().UnmarshalBinaryPublicKey(bs)
if err != nil {
return PublicKey{}, err
return nil, err
}
return PublicKey{key: key}, nil
return &PublicKey{key: key}, nil
}
// Equals returns true if the two keys are equivalent.
func (key PublicKey) Equals(other PublicKey) bool {
func (key *PublicKey) Equals(other *PublicKey) bool {
if key == nil && other == nil {
return true
} else if key == nil || other == nil {
return false
}
if key.key == nil && other.key == nil {
return true
} else if key.key == nil || other.key == nil {
@ -123,7 +129,7 @@ func (key PublicKey) Equals(other PublicKey) bool {
}
// MarshalJSON returns the JSON Web Key representation of the public key.
func (key PublicKey) MarshalJSON() ([]byte, error) {
func (key *PublicKey) MarshalJSON() ([]byte, error) {
return json.Marshal(JWK{
Type: jwkType,
ID: jwkID,
@ -133,8 +139,8 @@ func (key PublicKey) MarshalJSON() ([]byte, error) {
}
// String converts a public key into a string.
func (key PublicKey) String() string {
if key.key == nil {
func (key *PublicKey) String() string {
if key == nil || key.key == nil {
return ""
}
@ -149,10 +155,17 @@ func (key PublicKey) String() string {
// Seal seales a message using HPKE.
func Seal(
senderPrivateKey PrivateKey,
receiverPublicKey PublicKey,
senderPrivateKey *PrivateKey,
receiverPublicKey *PublicKey,
message []byte,
) (sealed []byte, err error) {
if senderPrivateKey == nil {
return nil, fmt.Errorf("hpke: sender private key cannot be nil")
}
if receiverPublicKey == nil {
return nil, fmt.Errorf("hpke: receiver public key cannot be nil")
}
sender, err := suite.NewSender(receiverPublicKey.key, nil)
if err != nil {
return nil, fmt.Errorf("hpke: error creating sender: %w", err)
@ -173,10 +186,17 @@ func Seal(
// Open opens a message using HPKE.
func Open(
receiverPrivateKey PrivateKey,
senderPublicKey PublicKey,
receiverPrivateKey *PrivateKey,
senderPublicKey *PublicKey,
sealed []byte,
) (message []byte, err error) {
if receiverPrivateKey == nil {
return nil, fmt.Errorf("hpke: receiver private key cannot be nil")
}
if senderPublicKey == nil {
return nil, fmt.Errorf("hpke: sender public key cannot be nil")
}
encSize := kemID.Scheme().SharedKeySize()
if len(sealed) < encSize {
return nil, fmt.Errorf("hpke: invalid sealed message")

View file

@ -29,25 +29,25 @@ type JWK struct {
}
// FetchPublicKeyFromJWKS fetches the HPKE public key from the JWKS endpoint.
func FetchPublicKeyFromJWKS(ctx context.Context, client *http.Client, endpoint string) (PublicKey, error) {
func FetchPublicKeyFromJWKS(ctx context.Context, client *http.Client, endpoint string) (*PublicKey, error) {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil)
if err != nil {
return PublicKey{}, fmt.Errorf("hpke: error building jwks http request: %w", err)
return nil, fmt.Errorf("hpke: error building jwks http request: %w", err)
}
res, err := client.Do(req)
if err != nil {
return PublicKey{}, fmt.Errorf("hpke: error requesting jwks endpoint: %w", err)
return nil, fmt.Errorf("hpke: error requesting jwks endpoint: %w", err)
}
defer res.Body.Close()
if res.StatusCode/100 != 2 {
return PublicKey{}, fmt.Errorf("hpke: error requesting jwks endpoint, invalid status code: %d", res.StatusCode)
return nil, fmt.Errorf("hpke: error requesting jwks endpoint, invalid status code: %d", res.StatusCode)
}
bs, err := io.ReadAll(io.LimitReader(res.Body, defaultMaxBodySize))
if err != nil {
return PublicKey{}, fmt.Errorf("hpke: error reading jwks endpoint: %w", err)
return nil, fmt.Errorf("hpke: error reading jwks endpoint: %w", err)
}
var jwks struct {
@ -55,7 +55,7 @@ func FetchPublicKeyFromJWKS(ctx context.Context, client *http.Client, endpoint s
}
err = json.Unmarshal(bs, &jwks)
if err != nil {
return PublicKey{}, fmt.Errorf("hpke: error unmarshaling jwks endpoint: %w", err)
return nil, fmt.Errorf("hpke: error unmarshaling jwks endpoint: %w", err)
}
for _, key := range jwks.Keys {
@ -63,12 +63,12 @@ func FetchPublicKeyFromJWKS(ctx context.Context, client *http.Client, endpoint s
return PublicKeyFromString(key.X)
}
}
return PublicKey{}, fmt.Errorf("hpke key not found in JWKS endpoint")
return nil, fmt.Errorf("hpke key not found in JWKS endpoint")
}
// A KeyFetcher fetches public keys.
type KeyFetcher interface {
FetchPublicKey(ctx context.Context) (PublicKey, error)
FetchPublicKey(ctx context.Context) (*PublicKey, error)
}
type jwksKeyFetcher struct {
@ -76,7 +76,7 @@ type jwksKeyFetcher struct {
endpoint string
}
func (fetcher *jwksKeyFetcher) FetchPublicKey(ctx context.Context) (PublicKey, error) {
func (fetcher *jwksKeyFetcher) FetchPublicKey(ctx context.Context) (*PublicKey, error) {
return FetchPublicKeyFromJWKS(ctx, fetcher.client, fetcher.endpoint)
}

View file

@ -18,8 +18,8 @@ func IsEncryptedURL(values url.Values) bool {
// EncryptURLValues encrypts URL values using the Seal method.
func EncryptURLValues(
senderPrivateKey PrivateKey,
receiverPublicKey PublicKey,
senderPrivateKey *PrivateKey,
receiverPublicKey *PublicKey,
values url.Values,
) (encrypted url.Values, err error) {
values = withoutHPKEParams(values)
@ -37,34 +37,34 @@ func EncryptURLValues(
// DecryptURLValues decrypts URL values using the Open method.
func DecryptURLValues(
receiverPrivateKey PrivateKey,
receiverPrivateKey *PrivateKey,
encrypted url.Values,
) (senderPublicKey PublicKey, values url.Values, err error) {
) (senderPublicKey *PublicKey, values url.Values, err error) {
if !encrypted.Has(ParamSenderPublicKey) {
return senderPublicKey, nil, fmt.Errorf("hpke: missing sender public key in query parameters")
return nil, nil, fmt.Errorf("hpke: missing sender public key in query parameters")
}
if !encrypted.Has(ParamQuery) {
return senderPublicKey, nil, fmt.Errorf("hpke: missing encrypted query in query parameters")
return nil, nil, fmt.Errorf("hpke: missing encrypted query in query parameters")
}
senderPublicKey, err = PublicKeyFromString(encrypted.Get(ParamSenderPublicKey))
if err != nil {
return senderPublicKey, nil, fmt.Errorf("hpke: invalid sender public key parameter: %w", err)
return nil, nil, fmt.Errorf("hpke: invalid sender public key parameter: %w", err)
}
sealed, err := decode(encrypted.Get(ParamQuery))
if err != nil {
return senderPublicKey, nil, fmt.Errorf("hpke: failed decoding query parameter: %w", err)
return nil, nil, fmt.Errorf("hpke: failed decoding query parameter: %w", err)
}
message, err := Open(receiverPrivateKey, senderPublicKey, sealed)
if err != nil {
return senderPublicKey, nil, fmt.Errorf("hpke: failed to open sealed message: %w", err)
return nil, nil, fmt.Errorf("hpke: failed to open sealed message: %w", err)
}
decrypted, err := url.ParseQuery(string(message))
if err != nil {
return senderPublicKey, nil, fmt.Errorf("hpke: invalid query parameter: %w", err)
return nil, nil, fmt.Errorf("hpke: invalid query parameter: %w", err)
}
values = withoutHPKEParams(encrypted)