cryptutil: add streaming crypt helpers

This commit is contained in:
Denis Mishin 2023-10-26 08:58:59 -04:00
parent 58d8f406a9
commit cebd1df947
2 changed files with 159 additions and 0 deletions

120
pkg/cryptutil/stream.go Normal file
View file

@ -0,0 +1,120 @@
package cryptutil
import (
"crypto/cipher"
"crypto/rand"
"encoding/binary"
"errors"
"fmt"
"io"
)
const (
streamBlockSize = 4096
)
// EncryptStream encrypts the src stream and returns the encrypted stream reader
func EncryptStream(src io.Reader, c cipher.AEAD) (io.Reader, error) {
pr, pw := io.Pipe()
go func() {
err := encryptStream(pw, src, c)
if err != nil {
_ = pw.CloseWithError(fmt.Errorf("encrypting stream: %w", err))
} else {
_ = pw.Close()
}
}()
return pr, nil
}
func encryptStream(dst io.Writer, src io.Reader, c cipher.AEAD) error {
buf := make([]byte, streamBlockSize+c.Overhead())
sizeBuf := make([]byte, 4)
nonce := make([]byte, c.NonceSize())
for {
n, err := src.Read(buf[0:streamBlockSize])
if n > 0 {
binary.BigEndian.PutUint32(sizeBuf, uint32(n))
_, err = dst.Write(sizeBuf)
if err != nil {
return fmt.Errorf("writing block size: %w", err)
}
_, err := rand.Read(nonce)
if err != nil {
return fmt.Errorf("generating nonce: %w", err)
}
_, err = dst.Write(nonce)
if err != nil {
return fmt.Errorf("writing nonce: %w", err)
}
_, err = dst.Write(c.Seal(nil, nonce, buf[0:n], sizeBuf))
if err != nil {
return fmt.Errorf("encrypting block: %w", err)
}
}
if errors.Is(err, io.EOF) {
return nil
}
if err != nil {
return fmt.Errorf("reading block: %w", err)
}
}
}
// DecryptStream decrypts the src stream and returns the decrypted stream reader
func DecryptStream(src io.Reader, c cipher.AEAD) (io.Reader, error) {
pr, pw := io.Pipe()
go func() {
err := decryptStream(pw, src, c)
if err != nil {
_ = pw.CloseWithError(fmt.Errorf("decrypting stream: %w", err))
} else {
_ = pw.Close()
}
}()
return pr, nil
}
func decryptStream(dst io.Writer, src io.Reader, c cipher.AEAD) error {
buf := make([]byte, streamBlockSize+c.Overhead())
sizeBuf := make([]byte, 4)
nonce := make([]byte, c.NonceSize())
for {
_, err := io.ReadFull(src, sizeBuf)
if errors.Is(err, io.EOF) {
return nil
}
if err != nil {
return fmt.Errorf("reading block size: %w", err)
}
_, err = io.ReadFull(src, nonce)
if err != nil {
return fmt.Errorf("reading nonce: %w", err)
}
n := binary.BigEndian.Uint32(sizeBuf)
_, err = io.ReadFull(src, buf[0:int(n)+c.Overhead()])
if err != nil {
return fmt.Errorf("reading block: %w", err)
}
plaintext, err := c.Open(nil, nonce, buf[0:int(n)+c.Overhead()], sizeBuf)
if err != nil {
return fmt.Errorf("decrypting block: %w", err)
}
_, err = dst.Write(plaintext)
if err != nil {
return fmt.Errorf("writing block: %w", err)
}
}
}

View file

@ -0,0 +1,39 @@
package cryptutil_test
import (
"bytes"
"crypto/rand"
"io"
"testing"
"github.com/stretchr/testify/require"
"github.com/pomerium/pomerium/pkg/cryptutil"
)
func TestEncryptStream(t *testing.T) {
t.Parallel()
plaintext := make([]byte, 4048*2.5)
_, err := rand.Read(plaintext)
require.NoError(t, err)
// cipher
c, err := cryptutil.NewAEADCipher(cryptutil.NewKey())
require.NoError(t, err)
// encrypt
encrypted := &bytes.Buffer{}
encrypter, err := cryptutil.EncryptStream(bytes.NewBuffer(plaintext), c)
require.NoError(t, err)
_, err = io.Copy(encrypted, encrypter)
require.NoError(t, err)
// decrypt
decrypted := &bytes.Buffer{}
decrypter, err := cryptutil.DecryptStream(encrypted, c)
require.NoError(t, err)
_, err = decrypted.ReadFrom(decrypter)
require.NoError(t, err)
require.Equal(t, plaintext, decrypted.Bytes())
}