pomerium/pkg/cryptutil/stream.go
2023-10-26 08:58:59 -04:00

120 lines
2.6 KiB
Go

package cryptutil
import (
"crypto/cipher"
"crypto/rand"
"encoding/binary"
"errors"
"fmt"
"io"
)
// streamBlockSize is the largest number of plaintext bytes that is sealed
// into a single block of an encrypted stream.
const streamBlockSize = 4 * 1024
// EncryptStream returns a reader that yields the encrypted form of src.
//
// Encryption runs in a background goroutine that feeds an io.Pipe, so
// ciphertext is produced as the returned reader is consumed. A failure while
// encrypting is reported by the returned reader's Read, wrapped with
// "encrypting stream". The error returned by EncryptStream itself is always
// nil.
func EncryptStream(src io.Reader, c cipher.AEAD) (io.Reader, error) {
	reader, writer := io.Pipe()

	go func() {
		var closeErr error
		if err := encryptStream(writer, src, c); err != nil {
			closeErr = fmt.Errorf("encrypting stream: %w", err)
		}
		// CloseWithError(nil) is equivalent to Close: the reader sees io.EOF.
		_ = writer.CloseWithError(closeErr)
	}()

	return reader, nil
}
// encryptStream reads src in chunks of at most streamBlockSize bytes and
// writes each chunk to dst as an independently sealed block:
//
//	4-byte big-endian plaintext length | random nonce | ciphertext + tag
//
// The length prefix is also passed to the AEAD as additional data, so it is
// authenticated together with the block. A fresh random nonce is generated
// for every block.
//
// It returns nil once src reports io.EOF. Any other read error, and any
// write or nonce-generation error, is returned wrapped. Data returned by a
// Read together with an error is still written before the error is acted on.
func encryptStream(dst io.Writer, src io.Reader, c cipher.AEAD) error {
	// The buffer is large enough to hold a full block plus the AEAD overhead,
	// which lets each block be sealed in place without a per-block allocation.
	buf := make([]byte, streamBlockSize+c.Overhead())
	sizeBuf := make([]byte, 4)
	nonce := make([]byte, c.NonceSize())

	for {
		// The read error is kept in its own variable so that the writes
		// below cannot overwrite it before it is inspected.
		n, readErr := src.Read(buf[:streamBlockSize])
		if n > 0 {
			binary.BigEndian.PutUint32(sizeBuf, uint32(n))
			if _, err := dst.Write(sizeBuf); err != nil {
				return fmt.Errorf("writing block size: %w", err)
			}

			if _, err := rand.Read(nonce); err != nil {
				return fmt.Errorf("generating nonce: %w", err)
			}
			if _, err := dst.Write(nonce); err != nil {
				return fmt.Errorf("writing nonce: %w", err)
			}

			// Passing plaintext[:0] as dst reuses the plaintext's storage
			// for the sealed output.
			sealed := c.Seal(buf[:0], nonce, buf[:n], sizeBuf)
			if _, err := dst.Write(sealed); err != nil {
				return fmt.Errorf("encrypting block: %w", err)
			}
		}

		if errors.Is(readErr, io.EOF) {
			return nil
		}
		if readErr != nil {
			return fmt.Errorf("reading block: %w", readErr)
		}
	}
}
// DecryptStream returns a reader that yields the decrypted form of src.
//
// Decryption runs in a background goroutine that feeds an io.Pipe, so
// plaintext is produced as the returned reader is consumed. A failure while
// decrypting is reported by the returned reader's Read, wrapped with
// "decrypting stream". The error returned by DecryptStream itself is always
// nil.
func DecryptStream(src io.Reader, c cipher.AEAD) (io.Reader, error) {
	reader, writer := io.Pipe()

	go func() {
		var closeErr error
		if err := decryptStream(writer, src, c); err != nil {
			closeErr = fmt.Errorf("decrypting stream: %w", err)
		}
		// CloseWithError(nil) is equivalent to Close: the reader sees io.EOF.
		_ = writer.CloseWithError(closeErr)
	}()

	return reader, nil
}
// decryptStream reads a sequence of sealed blocks from src and writes the
// recovered plaintext to dst. Each block is laid out as:
//
//	4-byte big-endian plaintext length | nonce | ciphertext + tag
//
// The length prefix is supplied to the AEAD as additional data, so a block
// whose prefix was altered fails to open.
//
// It returns nil when src ends cleanly at a block boundary. A declared
// length larger than streamBlockSize, a short read inside a block, an
// authentication failure, or a write error is returned as an error.
//
// NOTE(review): blocks are authenticated one at a time and nothing visible
// here binds a block to its position, so a stream cut at a block boundary or
// with reordered blocks is not detected by this function.
func decryptStream(dst io.Writer, src io.Reader, c cipher.AEAD) error {
	overhead := c.Overhead()
	buf := make([]byte, streamBlockSize+overhead)
	sizeBuf := make([]byte, 4)
	nonce := make([]byte, c.NonceSize())

	for {
		if _, err := io.ReadFull(src, sizeBuf); err != nil {
			// io.ReadFull reports io.EOF only when no bytes were read,
			// i.e. the stream ended exactly between two blocks.
			if errors.Is(err, io.EOF) {
				return nil
			}
			return fmt.Errorf("reading block size: %w", err)
		}

		// The length comes straight from the input. Reject anything that
		// would not fit the buffer instead of slicing out of range, which
		// would panic.
		n := binary.BigEndian.Uint32(sizeBuf)
		if n > streamBlockSize {
			return fmt.Errorf("invalid block size: %d exceeds maximum of %d", n, streamBlockSize)
		}

		if _, err := io.ReadFull(src, nonce); err != nil {
			return fmt.Errorf("reading nonce: %w", err)
		}

		sealed := buf[:int(n)+overhead]
		if _, err := io.ReadFull(src, sealed); err != nil {
			return fmt.Errorf("reading block: %w", err)
		}

		// Passing ciphertext[:0] as dst reuses the ciphertext's storage for
		// the plaintext; it is written out before the buffer is refilled.
		plaintext, err := c.Open(sealed[:0], nonce, sealed, sizeBuf)
		if err != nil {
			return fmt.Errorf("decrypting block: %w", err)
		}

		if _, err := dst.Write(plaintext); err != nil {
			return fmt.Errorf("writing block: %w", err)
		}
	}
}