From cebd1df947589c7ae2eea8cb20a2dfc89948463a Mon Sep 17 00:00:00 2001 From: Denis Mishin Date: Thu, 26 Oct 2023 08:58:59 -0400 Subject: [PATCH] cryptutil: add streaming crypt helpers --- pkg/cryptutil/stream.go | 120 +++++++++++++++++++++++++++++++++++ pkg/cryptutil/stream_test.go | 39 ++++++++++++ 2 files changed, 159 insertions(+) create mode 100644 pkg/cryptutil/stream.go create mode 100644 pkg/cryptutil/stream_test.go diff --git a/pkg/cryptutil/stream.go b/pkg/cryptutil/stream.go new file mode 100644 index 000000000..2d8551478 --- /dev/null +++ b/pkg/cryptutil/stream.go @@ -0,0 +1,120 @@ +package cryptutil + +import ( + "crypto/cipher" + "crypto/rand" + "encoding/binary" + "errors" + "fmt" + "io" +) + +const ( + streamBlockSize = 4096 +) + +// EncryptStream encrypts the src stream and returns the encrypted stream reader +func EncryptStream(src io.Reader, c cipher.AEAD) (io.Reader, error) { + pr, pw := io.Pipe() + go func() { + err := encryptStream(pw, src, c) + if err != nil { + _ = pw.CloseWithError(fmt.Errorf("encrypting stream: %w", err)) + } else { + _ = pw.Close() + } + }() + + return pr, nil +} + +func encryptStream(dst io.Writer, src io.Reader, c cipher.AEAD) error { + buf := make([]byte, streamBlockSize+c.Overhead()) + sizeBuf := make([]byte, 4) + nonce := make([]byte, c.NonceSize()) + + for { + n, err := src.Read(buf[0:streamBlockSize]) + if n > 0 { + binary.BigEndian.PutUint32(sizeBuf, uint32(n)) + _, err = dst.Write(sizeBuf) + if err != nil { + return fmt.Errorf("writing block size: %w", err) + } + + _, err := rand.Read(nonce) + if err != nil { + return fmt.Errorf("generating nonce: %w", err) + } + + _, err = dst.Write(nonce) + if err != nil { + return fmt.Errorf("writing nonce: %w", err) + } + + _, err = dst.Write(c.Seal(nil, nonce, buf[0:n], sizeBuf)) + if err != nil { + return fmt.Errorf("encrypting block: %w", err) + } + } + + if errors.Is(err, io.EOF) { + return nil + } + if err != nil { + return fmt.Errorf("reading block: %w", err) + } + } +} + +// DecryptStream decrypts the src stream and returns the decrypted stream reader +func DecryptStream(src io.Reader, c cipher.AEAD) (io.Reader, error) { + pr, pw := io.Pipe() + go func() { + err := decryptStream(pw, src, c) + if err != nil { + _ = pw.CloseWithError(fmt.Errorf("decrypting stream: %w", err)) + } else { + _ = pw.Close() + } + }() + + return pr, nil +} + +func decryptStream(dst io.Writer, src io.Reader, c cipher.AEAD) error { + buf := make([]byte, streamBlockSize+c.Overhead()) + sizeBuf := make([]byte, 4) + nonce := make([]byte, c.NonceSize()) + + for { + _, err := io.ReadFull(src, sizeBuf) + if errors.Is(err, io.EOF) { + return nil + } + if err != nil { + return fmt.Errorf("reading block size: %w", err) + } + + _, err = io.ReadFull(src, nonce) + if err != nil { + return fmt.Errorf("reading nonce: %w", err) + } + + n := binary.BigEndian.Uint32(sizeBuf) + _, err = io.ReadFull(src, buf[0:int(n)+c.Overhead()]) + if err != nil { + return fmt.Errorf("reading block: %w", err) + } + + plaintext, err := c.Open(nil, nonce, buf[0:int(n)+c.Overhead()], sizeBuf) + if err != nil { + return fmt.Errorf("decrypting block: %w", err) + } + + _, err = dst.Write(plaintext) + if err != nil { + return fmt.Errorf("writing block: %w", err) + } + } +} diff --git a/pkg/cryptutil/stream_test.go b/pkg/cryptutil/stream_test.go new file mode 100644 index 000000000..103599c53 --- /dev/null +++ b/pkg/cryptutil/stream_test.go @@ -0,0 +1,39 @@ +package cryptutil_test + +import ( + "bytes" + "crypto/rand" + "io" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/pomerium/pomerium/pkg/cryptutil" +) + +func TestEncryptStream(t *testing.T) { + t.Parallel() + + plaintext := make([]byte, 4048*2.5) + _, err := rand.Read(plaintext) + require.NoError(t, err) + + // cipher + c, err := cryptutil.NewAEADCipher(cryptutil.NewKey()) + require.NoError(t, err) + + // encrypt + encrypted := &bytes.Buffer{} + encrypter, err := cryptutil.EncryptStream(bytes.NewBuffer(plaintext), c) + require.NoError(t, err) + _, err = io.Copy(encrypted, encrypter) + require.NoError(t, err) + + // decrypt + decrypted := &bytes.Buffer{} + decrypter, err := cryptutil.DecryptStream(encrypted, c) + require.NoError(t, err) + _, err = decrypted.ReadFrom(decrypter) + require.NoError(t, err) + require.Equal(t, plaintext, decrypted.Bytes()) +}