package protoutil import ( "fmt" "sync" "time" "google.golang.org/protobuf/encoding/protojson" "google.golang.org/protobuf/proto" "google.golang.org/protobuf/types/known/anypb" "github.com/pomerium/pomerium/pkg/cryptutil" cryptpb "github.com/pomerium/pomerium/pkg/grpc/crypt" ) // An Encryptor encrypts protobuf messages using a key encryption key and periodically rotated // generated data encryption keys. type Encryptor struct { kek *cryptutil.PublicKeyEncryptionKey rotateEvery time.Duration sync.RWMutex nextRotate time.Time dek *cryptutil.DataEncryptionKey encryptedDEK []byte } // NewEncryptor returns a new protobuf Encryptor. func NewEncryptor(kek *cryptutil.PublicKeyEncryptionKey) *Encryptor { return &Encryptor{ kek: kek, rotateEvery: time.Hour, } } func (enc *Encryptor) getDataEncryptionKey() (*cryptutil.DataEncryptionKey, []byte, error) { // double-checked locking // first time we do a read only lookup enc.RLock() dek, encryptedDEK, err := enc.getDataEncryptionKeyLocked(true) enc.RUnlock() if err != nil { return nil, nil, err } else if dek != nil { return dek, encryptedDEK, nil } // second time we do a read/write lookup enc.Lock() dek, encryptedDEK, err = enc.getDataEncryptionKeyLocked(false) enc.Unlock() return dek, encryptedDEK, err } func (enc *Encryptor) getDataEncryptionKeyLocked(readOnly bool) (*cryptutil.DataEncryptionKey, []byte, error) { needsNewKey := enc.dek == nil || time.Now().After(enc.nextRotate) if !needsNewKey { return enc.dek, enc.encryptedDEK, nil } if readOnly { return nil, nil, nil } // generate a new data encryption key dek, err := cryptutil.GenerateDataEncryptionKey() if err != nil { return nil, nil, err } // seal the data encryption key using the key encryption key encryptedDEK, err := enc.kek.EncryptDataEncryptionKey(dek) if err != nil { return nil, nil, err } enc.dek = dek enc.encryptedDEK = encryptedDEK enc.nextRotate = time.Now().Add(enc.rotateEvery) return enc.dek, enc.encryptedDEK, nil } // Encrypt encrypts a protobuf message. func (enc *Encryptor) Encrypt(msg proto.Message) (*cryptpb.SealedMessage, error) { // get the data encryption key dek, encryptedDEK, err := enc.getDataEncryptionKey() if err != nil { return nil, err } plaintext, err := protojson.Marshal(msg) if err != nil { return nil, err } ciphertext := dek.Encrypt(plaintext) return &cryptpb.SealedMessage{ KeyId: enc.kek.ID(), DataEncryptionKey: encryptedDEK, MessageType: GetTypeURL(msg), EncryptedMessage: ciphertext, }, nil } // A Decryptor decrypts encrypted protobuf messages. type Decryptor struct { keySource cryptutil.KeyEncryptionKeySource dekCache *cryptutil.DataEncryptionKeyCache } // NewDecryptor creates a new decryptor. func NewDecryptor(keySource cryptutil.KeyEncryptionKeySource) *Decryptor { return &Decryptor{ keySource: keySource, dekCache: cryptutil.NewDataEncryptionKeyCache(), } } func (dec *Decryptor) getDataEncryptionKey(keyEncryptionKeyID string, encryptedDEK []byte) (*cryptutil.DataEncryptionKey, error) { // return a dek if its already cached dek, ok := dec.dekCache.Get(encryptedDEK) if ok { return dek, nil } // look up the kek used for this dek kek, err := dec.keySource.GetKeyEncryptionKey(keyEncryptionKeyID) if err != nil { return nil, fmt.Errorf("protoutil: error getting key-encryption-key (%s): %w", keyEncryptionKeyID, err) } // decrypt the dek via the private kek dek, err = kek.DecryptDataEncryptionKey(encryptedDEK) if err != nil { return nil, fmt.Errorf("protoutil: error decrypting data-encryption-key: %w", err) } // cache it for next time dec.dekCache.Put(encryptedDEK, dek) return dek, nil } // Decrypt decrypts an encrypted protobuf message. func (dec *Decryptor) Decrypt(src *cryptpb.SealedMessage) (proto.Message, error) { dek, err := dec.getDataEncryptionKey(src.GetKeyId(), src.GetDataEncryptionKey()) if err != nil { return nil, err } plaintext, err := dek.Decrypt(src.GetEncryptedMessage()) if err != nil { return nil, err } msg, err := (&anypb.Any{TypeUrl: src.GetMessageType()}).UnmarshalNew() if err != nil { return nil, err } err = protojson.Unmarshal(plaintext, msg) if err != nil { return nil, err } return msg, nil }