mirror of
https://github.com/pomerium/pomerium.git
synced 2025-05-11 16:17:39 +02:00
internal/cryptutil: fixed panic on nil pubkey
This commit is contained in:
parent
22fb3a0f7e
commit
c5bcc9bbef
5 changed files with 81 additions and 13 deletions
|
@ -70,11 +70,7 @@ func NewCipher(secret []byte) (*XChaCha20Cipher, error) {
|
|||
// GenerateNonce generates a random nonce.
|
||||
// Panics if source of randomness fails.
|
||||
func (c *XChaCha20Cipher) GenerateNonce() []byte {
|
||||
nonce := make([]byte, c.aead.NonceSize())
|
||||
if _, err := rand.Read(nonce); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return nonce
|
||||
return randomBytes(c.aead.NonceSize())
|
||||
}
|
||||
|
||||
// Encrypt a value using XChaCha20-Poly1305
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
package cryptutil
|
||||
package cryptutil // import "github.com/pomerium/pomerium/internal/cryptutil"
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
|
@ -33,6 +33,12 @@ func TestEncodeAndDecodeAccessToken(t *testing.T) {
|
|||
t.Fatalf("unexpected err decrypting: %v", err)
|
||||
}
|
||||
|
||||
// if less than 32 bytes, fail
|
||||
_, err = c.Decrypt([]byte("oh"))
|
||||
if err == nil {
|
||||
t.Fatalf("should fail if <32 bytes output: %v", err)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(got, plaintext) {
|
||||
t.Logf(" got: %v", got)
|
||||
t.Logf("want: %v", plaintext)
|
||||
|
@ -189,3 +195,67 @@ func TestGenerateRandomString(t *testing.T) {
|
|||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestXChaCha20Cipher_Marshal(t *testing.T) {
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
s interface{}
|
||||
wantErr bool
|
||||
}{
|
||||
{"unsupported type",
|
||||
struct {
|
||||
Animal string `json:"animal"`
|
||||
Func func() `json:"sound"`
|
||||
}{
|
||||
Animal: "cat",
|
||||
Func: func() {},
|
||||
},
|
||||
true},
|
||||
{"simple",
|
||||
struct {
|
||||
Animal string `json:"animal"`
|
||||
Sound string `json:"sound"`
|
||||
}{
|
||||
Animal: "cat",
|
||||
Sound: "meow",
|
||||
},
|
||||
false},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
|
||||
c, err := NewCipher(GenerateKey())
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected err: %v", err)
|
||||
}
|
||||
_, err = c.Marshal(tt.s)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("XChaCha20Cipher.Marshal() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewCipher(t *testing.T) {
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
secret []byte
|
||||
wantErr bool
|
||||
}{
|
||||
{"simple 32 byte key", GenerateKey(), false},
|
||||
{"key too short", []byte("what is entropy"), true},
|
||||
{"key too long", []byte(GenerateRandomString(33)), true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
_, err := NewCipher(tt.secret)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("NewCipher() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
// Package cryptutil provides encoding and decoding routines for various cryptographic structures.
|
||||
package cryptutil
|
||||
package cryptutil // import "github.com/pomerium/pomerium/internal/cryptutil"
|
||||
|
||||
import (
|
||||
"crypto/ecdsa"
|
||||
|
@ -13,11 +13,9 @@ import (
|
|||
// DecodePublicKey decodes a PEM-encoded ECDSA public key.
|
||||
func DecodePublicKey(encodedKey []byte) (*ecdsa.PublicKey, error) {
|
||||
block, _ := pem.Decode(encodedKey)
|
||||
if block == nil || block.Type != "PUBLIC KEY" {
|
||||
return nil, fmt.Errorf("marshal: could not decode PEM block type %s", block.Type)
|
||||
|
||||
if block == nil {
|
||||
return nil, fmt.Errorf("marshal: decoded nil PEM block")
|
||||
}
|
||||
|
||||
pub, err := x509.ParsePKIXPublicKey(block.Bytes)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
package cryptutil
|
||||
package cryptutil // import "github.com/pomerium/pomerium/internal/cryptutil"
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
|
@ -35,6 +35,10 @@ func TestPublicKeyMarshaling(t *testing.T) {
|
|||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
_, err = DecodePublicKey(nil)
|
||||
if err == nil {
|
||||
t.Fatal("expected error")
|
||||
}
|
||||
|
||||
pemBytes, _ := EncodePublicKey(ecKey)
|
||||
if !bytes.Equal(pemBytes, []byte(pemECPublicKeyP256)) {
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
package cryptutil
|
||||
package cryptutil // import "github.com/pomerium/pomerium/internal/cryptutil"
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue