diff --git a/internal/cryptutil/encrypt.go b/internal/cryptutil/encrypt.go index ac44774ff..fbfca9a09 100644 --- a/internal/cryptutil/encrypt.go +++ b/internal/cryptutil/encrypt.go @@ -70,11 +70,7 @@ func NewCipher(secret []byte) (*XChaCha20Cipher, error) { // GenerateNonce generates a random nonce. // Panics if source of randomness fails. func (c *XChaCha20Cipher) GenerateNonce() []byte { - nonce := make([]byte, c.aead.NonceSize()) - if _, err := rand.Read(nonce); err != nil { - panic(err) - } - return nonce + return randomBytes(c.aead.NonceSize()) } // Encrypt a value using XChaCha20-Poly1305 diff --git a/internal/cryptutil/encrypt_test.go b/internal/cryptutil/encrypt_test.go index a5f83658f..d46316600 100644 --- a/internal/cryptutil/encrypt_test.go +++ b/internal/cryptutil/encrypt_test.go @@ -1,4 +1,4 @@ -package cryptutil +package cryptutil // import "github.com/pomerium/pomerium/internal/cryptutil" import ( "crypto/rand" @@ -33,6 +33,12 @@ func TestEncodeAndDecodeAccessToken(t *testing.T) { t.Fatalf("unexpected err decrypting: %v", err) } + // if less than 32 bytes, fail + _, err = c.Decrypt([]byte("oh")) + if err == nil { + t.Fatalf("should fail if <32 bytes output: %v", err) + } + if !reflect.DeepEqual(got, plaintext) { t.Logf(" got: %v", got) t.Logf("want: %v", plaintext) @@ -189,3 +195,67 @@ func TestGenerateRandomString(t *testing.T) { }) } } + +func TestXChaCha20Cipher_Marshal(t *testing.T) { + + tests := []struct { + name string + s interface{} + wantErr bool + }{ + {"unsupported type", + struct { + Animal string `json:"animal"` + Func func() `json:"sound"` + }{ + Animal: "cat", + Func: func() {}, + }, + true}, + {"simple", + struct { + Animal string `json:"animal"` + Sound string `json:"sound"` + }{ + Animal: "cat", + Sound: "meow", + }, + false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + + c, err := NewCipher(GenerateKey()) + if err != nil { + t.Fatalf("unexpected err: %v", err) + } + _, err = c.Marshal(tt.s) + if (err != nil) != tt.wantErr { + t.Errorf("XChaCha20Cipher.Marshal() error = %v, wantErr %v", err, tt.wantErr) + return + } + }) + } +} + +func TestNewCipher(t *testing.T) { + + tests := []struct { + name string + secret []byte + wantErr bool + }{ + {"simple 32 byte key", GenerateKey(), false}, + {"key too short", []byte("what is entropy"), true}, + {"key too long", []byte(GenerateRandomString(33)), true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := NewCipher(tt.secret) + if (err != nil) != tt.wantErr { + t.Errorf("NewCipher() error = %v, wantErr %v", err, tt.wantErr) + return + } + }) + } +} diff --git a/internal/cryptutil/marshal.go b/internal/cryptutil/marshal.go index d921ea07d..a2bfcef67 100644 --- a/internal/cryptutil/marshal.go +++ b/internal/cryptutil/marshal.go @@ -1,5 +1,5 @@ // Package cryptutil provides encoding and decoding routines for various cryptographic structures. -package cryptutil +package cryptutil // import "github.com/pomerium/pomerium/internal/cryptutil" import ( "crypto/ecdsa" @@ -13,11 +13,9 @@ import ( // DecodePublicKey decodes a PEM-encoded ECDSA public key. func DecodePublicKey(encodedKey []byte) (*ecdsa.PublicKey, error) { block, _ := pem.Decode(encodedKey) - if block == nil || block.Type != "PUBLIC KEY" { - return nil, fmt.Errorf("marshal: could not decode PEM block type %s", block.Type) - + if block == nil { + return nil, fmt.Errorf("marshal: decoded nil PEM block") } - pub, err := x509.ParsePKIXPublicKey(block.Bytes) if err != nil { return nil, err diff --git a/internal/cryptutil/marshal_test.go b/internal/cryptutil/marshal_test.go index 9a9abc2e0..b62451554 100644 --- a/internal/cryptutil/marshal_test.go +++ b/internal/cryptutil/marshal_test.go @@ -1,4 +1,4 @@ -package cryptutil +package cryptutil // import "github.com/pomerium/pomerium/internal/cryptutil" import ( "bytes" @@ -35,6 +35,10 @@ func TestPublicKeyMarshaling(t *testing.T) { if err != nil { t.Fatal(err) } + _, err = DecodePublicKey(nil) + if err == nil { + t.Fatal("expected error") + } pemBytes, _ := EncodePublicKey(ecKey) if !bytes.Equal(pemBytes, []byte(pemECPublicKeyP256)) { diff --git a/internal/cryptutil/sign_test.go b/internal/cryptutil/sign_test.go index b93848db6..ff31afe9e 100644 --- a/internal/cryptutil/sign_test.go +++ b/internal/cryptutil/sign_test.go @@ -1,4 +1,4 @@ -package cryptutil +package cryptutil // import "github.com/pomerium/pomerium/internal/cryptutil" import ( "testing"