mirror of
https://github.com/pomerium/pomerium.git
synced 2025-06-01 18:33:19 +02:00
authenticate: encrypt & mac oauth2 callback state
- cryptutil: add hmac & tests - cryptutil: rename cipher / encoders to be more clear - cryptutil: simplify SecureEncoder interface - cryptutil: renamed NewCipherFromBase64 to NewAEADCipherFromBase64 - cryptutil: move key & random generators to helpers Signed-off-by: Bobby DeSimone <bobbydesimone@gmail.com>
This commit is contained in:
parent
3a806c6dfc
commit
7c755d833f
26 changed files with 539 additions and 464 deletions
|
@ -1,12 +1,8 @@
|
|||
package cryptutil
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/sha1"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"sync"
|
||||
"testing"
|
||||
)
|
||||
|
||||
|
@ -14,27 +10,24 @@ func TestEncodeAndDecodeAccessToken(t *testing.T) {
|
|||
plaintext := []byte("my plain text value")
|
||||
|
||||
key := NewKey()
|
||||
c, err := NewCipher(key)
|
||||
c, err := NewAEADCipher(key)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected err: %v", err)
|
||||
}
|
||||
|
||||
ciphertext, err := c.Encrypt(plaintext)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected err: %v", err)
|
||||
}
|
||||
ciphertext := Encrypt(c, plaintext, nil)
|
||||
|
||||
if reflect.DeepEqual(plaintext, ciphertext) {
|
||||
t.Fatalf("plaintext is not encrypted plaintext:%v ciphertext:%x", plaintext, ciphertext)
|
||||
}
|
||||
|
||||
got, err := c.Decrypt(ciphertext)
|
||||
got, err := Decrypt(c, ciphertext, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected err decrypting: %v", err)
|
||||
}
|
||||
|
||||
// if less than 32 bytes, fail
|
||||
_, err = c.Decrypt([]byte("oh"))
|
||||
_, err = Decrypt(c, []byte("oh"), nil)
|
||||
if err == nil {
|
||||
t.Fatalf("should fail if <32 bytes output: %v", err)
|
||||
}
|
||||
|
@ -49,10 +42,11 @@ func TestEncodeAndDecodeAccessToken(t *testing.T) {
|
|||
func TestMarshalAndUnmarshalStruct(t *testing.T) {
|
||||
key := NewKey()
|
||||
|
||||
c, err := NewCipher(key)
|
||||
a, err := NewAEADCipher(key)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected err: %v", err)
|
||||
}
|
||||
c := SecureJSONEncoder{aead: a}
|
||||
|
||||
type TC struct {
|
||||
Field string `json:"field"`
|
||||
|
@ -101,102 +95,7 @@ func TestMarshalAndUnmarshalStruct(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestCipherDataRace(t *testing.T) {
|
||||
cipher, err := NewCipher(NewKey())
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected generating cipher err: %v", err)
|
||||
}
|
||||
|
||||
type TC struct {
|
||||
Field string `json:"field"`
|
||||
}
|
||||
|
||||
wg := &sync.WaitGroup{}
|
||||
for i := 0; i < 100; i++ {
|
||||
wg.Add(1)
|
||||
go func(c *XChaCha20Cipher, wg *sync.WaitGroup) {
|
||||
defer wg.Done()
|
||||
b := make([]byte, 32)
|
||||
_, err := rand.Read(b)
|
||||
if err != nil {
|
||||
t.Fatalf("unexecpted error reading random bytes: %v", err)
|
||||
}
|
||||
|
||||
sha := fmt.Sprintf("%x", sha1.New().Sum(b))
|
||||
tc := &TC{
|
||||
Field: sha,
|
||||
}
|
||||
|
||||
value1, err := c.Marshal(tc)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected err: %v", err)
|
||||
}
|
||||
|
||||
value2, err := c.Marshal(tc)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected err: %v", err)
|
||||
}
|
||||
|
||||
if value1 == value2 {
|
||||
t.Fatalf("expected marshaled values to not be equal %v != %v", value1, value2)
|
||||
}
|
||||
|
||||
got1 := &TC{}
|
||||
err = c.Unmarshal(value1, got1)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected err unmarshalling struct: %v", err)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(got1, tc) {
|
||||
t.Logf("want: %#v", tc)
|
||||
t.Logf(" got: %#v", got1)
|
||||
t.Fatalf("expected structs to be equal")
|
||||
}
|
||||
|
||||
got2 := &TC{}
|
||||
err = c.Unmarshal(value2, got2)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected err unmarshalling struct: %v", err)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(got1, got2) {
|
||||
t.Logf("got2: %#v", got2)
|
||||
t.Logf("got1: %#v", got1)
|
||||
t.Fatalf("expected structs to be equal")
|
||||
}
|
||||
|
||||
}(cipher, wg)
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestGenerateRandomString(t *testing.T) {
|
||||
t.Parallel()
|
||||
tests := []struct {
|
||||
name string
|
||||
c int
|
||||
want int
|
||||
}{
|
||||
{"simple", 32, 32},
|
||||
{"zero", 0, 0},
|
||||
{"negative", -1, 32},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
o := NewRandomStringN(tt.c)
|
||||
b, err := base64.StdEncoding.DecodeString(o)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
got := len(b)
|
||||
if got != tt.want {
|
||||
t.Errorf("NewRandomStringN() = %d, want %d", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestXChaCha20Cipher_Marshal(t *testing.T) {
|
||||
func TestSecureJSONEncoder_Marshal(t *testing.T) {
|
||||
t.Parallel()
|
||||
tests := []struct {
|
||||
name string
|
||||
|
@ -225,20 +124,22 @@ func TestXChaCha20Cipher_Marshal(t *testing.T) {
|
|||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
|
||||
c, err := NewCipher(NewKey())
|
||||
c, err := NewAEADCipher(NewKey())
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected err: %v", err)
|
||||
}
|
||||
_, err = c.Marshal(tt.s)
|
||||
e := SecureJSONEncoder{aead: c}
|
||||
|
||||
_, err = e.Marshal(tt.s)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("XChaCha20Cipher.Marshal() error = %v, wantErr %v", err, tt.wantErr)
|
||||
t.Errorf("SecureJSONEncoder.Marshal() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewCipher(t *testing.T) {
|
||||
func TestNewAEADCipher(t *testing.T) {
|
||||
t.Parallel()
|
||||
tests := []struct {
|
||||
name string
|
||||
|
@ -251,16 +152,16 @@ func TestNewCipher(t *testing.T) {
|
|||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
_, err := NewCipher(tt.secret)
|
||||
_, err := NewAEADCipher(tt.secret)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("NewCipher() error = %v, wantErr %v", err, tt.wantErr)
|
||||
t.Errorf("NewAEADCipher() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewCipherFromBase64(t *testing.T) {
|
||||
func TestNewAEADCipherFromBase64(t *testing.T) {
|
||||
t.Parallel()
|
||||
tests := []struct {
|
||||
name string
|
||||
|
@ -274,34 +175,11 @@ func TestNewCipherFromBase64(t *testing.T) {
|
|||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
_, err := NewCipherFromBase64(tt.s)
|
||||
_, err := NewAEADCipherFromBase64(tt.s)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("NewCipherFromBase64() error = %v, wantErr %v", err, tt.wantErr)
|
||||
t.Errorf("NewAEADCipherFromBase64() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewBase64Key(t *testing.T) {
|
||||
t.Parallel()
|
||||
tests := []struct {
|
||||
name string
|
||||
want int
|
||||
}{
|
||||
{"simple", 32},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
o := NewBase64Key()
|
||||
b, err := base64.StdEncoding.DecodeString(o)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
got := len(b)
|
||||
if got != tt.want {
|
||||
t.Errorf("NewBase64Key() = %d, want %d", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue