diff --git a/pkg/cryptutil/dek.go b/pkg/cryptutil/dek.go new file mode 100644 index 000000000..e706b5864 --- /dev/null +++ b/pkg/cryptutil/dek.go @@ -0,0 +1,107 @@ +package cryptutil + +import ( + "crypto/cipher" + "encoding/base64" + "fmt" + + lru "github.com/hashicorp/golang-lru" + "golang.org/x/crypto/chacha20poly1305" +) + +const ( + // DataEncryptionKeySize is the size of a data encryption key. + DataEncryptionKeySize = chacha20poly1305.KeySize + // DataEncryptionKeyCacheSize is the number of DEKs to keep in the LRU cache. + DataEncryptionKeyCacheSize = 20 +) + +// A DataEncryptionKey is an XChaCha20Poly1305 symmetric encryption key. For more details +// see the documentation on KeyEncryptionKeys. +type DataEncryptionKey struct { + data [DataEncryptionKeySize]byte + cipher cipher.AEAD +} + +// NewDataEncryptionKey returns a new DataEncryptionKey from existing bytes. +func NewDataEncryptionKey(raw []byte) (*DataEncryptionKey, error) { + if len(raw) != DataEncryptionKeySize { + return nil, fmt.Errorf("cryptutil: invalid data encryption key, expected %d bytes, got %d", + DataEncryptionKeySize, len(raw)) + } + dek := new(DataEncryptionKey) + copy(dek.data[:], raw) + dek.cipher, _ = chacha20poly1305.NewX(raw) // only errors on invalid size + return dek, nil +} + +// GenerateDataEncryptionKey generates a new random data encryption key. +func GenerateDataEncryptionKey() (*DataEncryptionKey, error) { + raw := randomBytes(DataEncryptionKeySize) + return NewDataEncryptionKey(raw) +} + +// Decrypt decrypts encrypted data using the data encryption key. +func (dek *DataEncryptionKey) Decrypt(ciphertext []byte) ([]byte, error) { + return Decrypt(dek.cipher, ciphertext, nil) +} + +// DecryptString decrypts an encrypted string using the data encryption key and base64 encoding. +func (dek *DataEncryptionKey) DecryptString(ciphertext string) (string, error) { + ciphertextBytes, err := base64.StdEncoding.DecodeString(ciphertext) + if err != nil { + return "", err + } + plaintextBytes, err := dek.Decrypt(ciphertextBytes) + if err != nil { + return "", err + } + return string(plaintextBytes), nil +} + +// Encrypt encrypts data using the data encryption key. +func (dek *DataEncryptionKey) Encrypt(plaintext []byte) []byte { + return Encrypt(dek.cipher, plaintext, nil) +} + +// EncryptString encrypts a string using the data encryption key and base64 encoding. +func (dek *DataEncryptionKey) EncryptString(plaintext string) string { + bs := dek.Encrypt([]byte(plaintext)) + return base64.StdEncoding.EncodeToString(bs) +} + +// KeyBytes returns the private key encryption key's raw bytes. +func (dek *DataEncryptionKey) KeyBytes() []byte { + data := make([]byte, DataEncryptionKeySize) + copy(data, dek.data[:]) + return data +} + +// A DataEncryptionKeyCache caches recently used data encryption keys based on their +// encrypted representation. The cache is safe for concurrent read and write access. +// +// Internally an LRU cache is used and the encrypted DEK bytes are converted to strings +// to allow usage as hash map keys. +type DataEncryptionKeyCache struct { + lru *lru.Cache +} + +// NewDataEncryptionKeyCache creates a new DataEncryptionKeyCache. +func NewDataEncryptionKeyCache() *DataEncryptionKeyCache { + c, _ := lru.New(DataEncryptionKeyCacheSize) // only errors if size <= 0 + return &DataEncryptionKeyCache{lru: c} +} + +// Get returns a data encryption key if available. +func (cache *DataEncryptionKeyCache) Get(encryptedDEK []byte) (*DataEncryptionKey, bool) { + obj, ok := cache.lru.Get(string(encryptedDEK)) + if ok { + return obj.(*DataEncryptionKey), true + } + return nil, false +} + +// Put stores a data encryption key by its encrypted representation. +func (cache *DataEncryptionKeyCache) Put(encryptedDEK []byte, dek *DataEncryptionKey) { + cache.lru.Add(string(encryptedDEK), dek) +} diff --git a/pkg/cryptutil/dek_test.go b/pkg/cryptutil/dek_test.go new file mode 100644 index 000000000..c5c4d219e --- /dev/null +++ b/pkg/cryptutil/dek_test.go @@ -0,0 +1,86 @@ +package cryptutil + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestDataEncryptionKey(t *testing.T) { + t.Run("roundtrip", func(t *testing.T) { + dek, err := GenerateDataEncryptionKey() + require.NoError(t, err) + ciphertext := dek.Encrypt([]byte("HELLO WORLD")) + plaintext, err := dek.Decrypt(ciphertext) + require.NoError(t, err) + require.Equal(t, []byte("HELLO WORLD"), plaintext) + }) + t.Run("roundtrip string", func(t *testing.T) { + dek, err := GenerateDataEncryptionKey() + require.NoError(t, err) + ciphertext := dek.EncryptString(("HELLO WORLD")) + plaintext, err := dek.DecryptString(ciphertext) + require.NoError(t, err) + require.Equal(t, ("HELLO WORLD"), plaintext) + }) + t.Run("KeyBytes", func(t *testing.T) { + dek, err := GenerateDataEncryptionKey() + require.NoError(t, err) + assert.Equal(t, dek.data[:], dek.KeyBytes()) + assert.NotSame(t, dek.data[:], dek.KeyBytes()) + }) + t.Run("invalid key", func(t *testing.T) { + dek, err := NewDataEncryptionKey([]byte("NOT BIG ENOUGH")) + require.Nil(t, dek) + require.Error(t, err) + }) + t.Run("bad data", func(t *testing.T) { + dek, err := GenerateDataEncryptionKey() + require.NoError(t, err) + ciphertext := dek.Encrypt([]byte("HELLO WORLD")) + ciphertext[3]++ + plaintext, err := dek.Decrypt(ciphertext) + require.Error(t, err) + require.Nil(t, plaintext) + }) +} + +func TestDataEncryptionKeyCache(t *testing.T) { + t.Run("roundtrip", func(t *testing.T) { + cache := NewDataEncryptionKeyCache() + kek, err := GenerateKeyEncryptionKey() + require.NoError(t, err) + dek, err := GenerateDataEncryptionKey() + require.NoError(t, err) + ciphertext, err := kek.Public().EncryptDataEncryptionKey(dek) + require.NoError(t, err) + cache.Put(ciphertext, dek) + dek2, ok := cache.Get(ciphertext) + require.True(t, ok) + require.Equal(t, dek, dek2) + }) + t.Run("eviction", func(t *testing.T) { + cache := NewDataEncryptionKeyCache() + kek, err := GenerateKeyEncryptionKey() + require.NoError(t, err) + + dek, err := GenerateDataEncryptionKey() + require.NoError(t, err) + ciphertext, err := kek.Public().EncryptDataEncryptionKey(dek) + require.NoError(t, err) + cache.Put(ciphertext, dek) + + for i := 0; i < DataEncryptionKeyCacheSize; i++ { + dek, err := GenerateDataEncryptionKey() + require.NoError(t, err) + ciphertext, err := kek.Public().EncryptDataEncryptionKey(dek) + require.NoError(t, err) + cache.Put(ciphertext, dek) + } + + dek2, ok := cache.Get(ciphertext) + require.False(t, ok, "should evict the least recently used DEK") + require.Nil(t, dek2) + }) +} diff --git a/pkg/cryptutil/kek.go b/pkg/cryptutil/kek.go new file mode 100644 index 000000000..b790968e2 --- /dev/null +++ b/pkg/cryptutil/kek.go @@ -0,0 +1,166 @@ +package cryptutil + +import ( + "crypto/rand" + "fmt" + + "github.com/btcsuite/btcutil/base58" + "golang.org/x/crypto/curve25519" + "golang.org/x/crypto/nacl/box" +) + +// A KeyEncryptionKey (KEK) is used to implement *envelope encryption*, similar to how data is stored at rest with +// AWS or Google Cloud: +// +// - AWS: https://docs.aws.amazon.com/kms/latest/developerguide/concepts.html#enveloping +// - Google Cloud: https://cloud.google.com/kms/docs/envelope-encryption +// +// Data is encrypted with a data encryption key (DEK) and that key is stored next to the data encrypted with the KEK. +// Finally the KEK id is also stored with the data. +// +// To decrypt the data you first retrieve the KEK, second decrypt the DEK, and finally decrypt the data using the DEK. +// +// - Our KEKs are asymmetric Curve25519 keys. We use the *public* key to encrypt the DEK so only the *private* key can +// decrypt it. +// - Our DEKs are symmetric XChaCha20Poly1305 keys. +// +type KeyEncryptionKey interface { + ID() string + KeyBytes() []byte + + isKeyEncryptionKey() +} + +// KeyEncryptionKeySize is the size of a key encryption key. +const KeyEncryptionKeySize = curve25519.ScalarSize + +// PrivateKeyEncryptionKey is a Curve25519 asymmetric private encryption key used to decrypt data encryption keys. +type PrivateKeyEncryptionKey struct { + id string + data [KeyEncryptionKeySize]byte +} + +func (*PrivateKeyEncryptionKey) isKeyEncryptionKey() {} + +// NewPrivateKeyEncryptionKey creates a new encryption key from existing bytes. +func NewPrivateKeyEncryptionKey(id string, raw []byte) (*PrivateKeyEncryptionKey, error) { + if len(raw) != KeyEncryptionKeySize { + return nil, fmt.Errorf("cryptutil: invalid key encryption key, expected %d bytes, got %d", + KeyEncryptionKeySize, len(raw)) + } + kek := new(PrivateKeyEncryptionKey) + kek.id = id + copy(kek.data[:], raw) + return kek, nil +} + +// GenerateKeyEncryptionKey generates a new random key encryption key. +func GenerateKeyEncryptionKey() (*PrivateKeyEncryptionKey, error) { + raw := randomBytes(KeyEncryptionKeySize) + id := GetKeyEncryptionKeyID(raw) + return NewPrivateKeyEncryptionKey(id, raw) +} + +// GetKeyEncryptionKeyID derives an id from the key encryption key data itself. +func GetKeyEncryptionKeyID(raw []byte) string { + return base58.Encode(Hash("KeyEncryptionKey", raw)) +} + +// Decrypt decrypts data from a NACL anonymous box. +func (kek *PrivateKeyEncryptionKey) Decrypt(ciphertext []byte) ([]byte, error) { + private := kek + public := kek.Public() + + opened, ok := box.OpenAnonymous(nil, ciphertext, &public.data, &private.data) + if !ok { + return nil, fmt.Errorf("cryptutil: anonymous box decrypt failed") + } + return opened, nil +} + +// DecryptDataEncryptionKey decrypts a data encryption key. +func (kek *PrivateKeyEncryptionKey) DecryptDataEncryptionKey(ciphertext []byte) (*DataEncryptionKey, error) { + raw, err := kek.Decrypt(ciphertext) + if err != nil { + return nil, err + } + return NewDataEncryptionKey(raw) +} + +// ID returns the private key's id. +func (kek *PrivateKeyEncryptionKey) ID() string { + return kek.id +} + +// KeyBytes returns the private key encryption key's raw bytes. +func (kek *PrivateKeyEncryptionKey) KeyBytes() []byte { + data := make([]byte, KeyEncryptionKeySize) + copy(data, kek.data[:]) + return data +} + +// Public returns the private key's public key. +func (kek *PrivateKeyEncryptionKey) Public() *PublicKeyEncryptionKey { + // taken from NACL box.GenerateKey + var publicKey [32]byte + curve25519.ScalarBaseMult(&publicKey, &kek.data) + return &PublicKeyEncryptionKey{id: kek.id, data: publicKey} +} + +// PublicKeyEncryptionKey is a Curve25519 asymmetric public encryption key used to encrypt data encryption keys. +type PublicKeyEncryptionKey struct { + id string + data [KeyEncryptionKeySize]byte +} + +func (*PublicKeyEncryptionKey) isKeyEncryptionKey() {} + +// NewPublicKeyEncryptionKey creates a new encryption key from existing bytes. +func NewPublicKeyEncryptionKey(id string, raw []byte) (*PublicKeyEncryptionKey, error) { + if len(raw) != KeyEncryptionKeySize { + return nil, fmt.Errorf("cryptutil: invalid key encryption key, expected %d bytes, got %d", + KeyEncryptionKeySize, len(raw)) + } + kek := new(PublicKeyEncryptionKey) + copy(kek.data[:], raw) + return kek, nil +} + +// ID returns the public key's id. +func (kek *PublicKeyEncryptionKey) ID() string { + return kek.id +} + +// KeyBytes returns the public key's raw bytes. +func (kek *PublicKeyEncryptionKey) KeyBytes() []byte { + data := make([]byte, KeyEncryptionKeySize) + copy(data, kek.data[:]) + return data +} + +// Encrypt encrypts data using a NACL anonymous box. +func (kek *PublicKeyEncryptionKey) Encrypt(plaintext []byte) ([]byte, error) { + sealed, err := box.SealAnonymous(nil, plaintext, &kek.data, rand.Reader) + if err != nil { // only fails on rand.Read errors + return nil, fmt.Errorf("cryptutil: anonymous box encrypt failed: %w", err) + } + return sealed, nil +} + +// EncryptDataEncryptionKey encrypts a DataEncryptionKey. +func (kek *PublicKeyEncryptionKey) EncryptDataEncryptionKey(dek *DataEncryptionKey) ([]byte, error) { + return kek.Encrypt(dek.data[:]) +} + +// A KeyEncryptionKeySource gets private key encryption keys based on their id. +type KeyEncryptionKeySource interface { + GetKeyEncryptionKey(id string) (*PrivateKeyEncryptionKey, error) +} + +// A KeyEncryptionKeySourceFunc implements the KeyEncryptionKeySource interface using a function. +type KeyEncryptionKeySourceFunc func(id string) (*PrivateKeyEncryptionKey, error) + +// GetKeyEncryptionKey gets the key encryption key by calling the underlying function. +func (src KeyEncryptionKeySourceFunc) GetKeyEncryptionKey(id string) (*PrivateKeyEncryptionKey, error) { + return src(id) +} diff --git a/pkg/cryptutil/kek_test.go b/pkg/cryptutil/kek_test.go new file mode 100644 index 000000000..57e47ed7e --- /dev/null +++ b/pkg/cryptutil/kek_test.go @@ -0,0 +1,83 @@ +package cryptutil + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestKeyEncryptionKey(t *testing.T) { + t.Run("roundtrip", func(t *testing.T) { + kek, err := GenerateKeyEncryptionKey() + require.NoError(t, err) + assert.NotEqual(t, make([]byte, KeyEncryptionKeySize), kek.data) + ciphertext, err := kek.Public().Encrypt([]byte("HELLO WORLD")) + require.NoError(t, err) + plaintext, err := kek.Decrypt(ciphertext) + require.NoError(t, err) + require.Equal(t, []byte("HELLO WORLD"), plaintext) + }) + t.Run("anonymous", func(t *testing.T) { + kek, err := GenerateKeyEncryptionKey() + require.NoError(t, err) + kekPublic, err := NewPublicKeyEncryptionKey(kek.ID(), kek.Public().KeyBytes()) + require.NoError(t, err) + ciphertext, err := kekPublic.Encrypt([]byte("HELLO WORLD")) + require.NoError(t, err) + plaintext, err := kek.Decrypt(ciphertext) + require.NoError(t, err) + require.Equal(t, []byte("HELLO WORLD"), plaintext) + }) + t.Run("dek", func(t *testing.T) { + dek, err := GenerateDataEncryptionKey() + require.NoError(t, err) + kek, err := GenerateKeyEncryptionKey() + require.NoError(t, err) + ciphertext, err := kek.Public().EncryptDataEncryptionKey(dek) + require.NoError(t, err) + dek2, err := kek.DecryptDataEncryptionKey(ciphertext) + require.NoError(t, err) + require.Equal(t, dek, dek2) + }) + t.Run("ID", func(t *testing.T) { + kek, err := GenerateKeyEncryptionKey() + require.NoError(t, err) + assert.Equal(t, kek.id, kek.ID()) + }) + t.Run("KeyBytes", func(t *testing.T) { + private, err := GenerateKeyEncryptionKey() + require.NoError(t, err) + assert.Equal(t, private.data[:], private.KeyBytes()) + assert.NotSame(t, private.data[:], private.KeyBytes()) + public := private.Public() + assert.Equal(t, public.data[:], public.KeyBytes()) + assert.NotSame(t, public.data[:], public.KeyBytes()) + }) + t.Run("GetKeyEncryptionKeyID", func(t *testing.T) { + id := GetKeyEncryptionKeyID([]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31}) + assert.Equal(t, "7nfE5LQBMyWq3tmZsDiK5EaT2nMPMvFJWDDEZWWLoni", id) + }) + t.Run("invalid key", func(t *testing.T) { + t.Run("private", func(t *testing.T) { + kek, err := NewPrivateKeyEncryptionKey("TEST", []byte("NOT BIG ENOUGH")) + require.Nil(t, kek) + require.Error(t, err) + }) + t.Run("public", func(t *testing.T) { + kek, err := NewPublicKeyEncryptionKey("TEST", []byte("NOT BIG ENOUGH")) + require.Nil(t, kek) + require.Error(t, err) + }) + }) + t.Run("bad data", func(t *testing.T) { + kek, err := GenerateKeyEncryptionKey() + require.NoError(t, err) + ciphertext, err := kek.Public().Encrypt([]byte("HELLO WORLD")) + require.NoError(t, err) + ciphertext[3]++ + plaintext, err := kek.Decrypt(ciphertext) + require.Error(t, err) + require.Nil(t, plaintext) + }) +}