package cryptutil // import "github.com/pomerium/pomerium/internal/cryptutil"

import (
	"crypto/rand"
	"crypto/sha1"
	"encoding/base64"
	"fmt"
	"reflect"
	"sync"
	"testing"
)

// TestEncodeAndDecodeAccessToken encrypts a plaintext with a freshly generated
// key, checks that the ciphertext differs from the plaintext, that decrypting
// it yields the original bytes, and that decrypting a two-byte input fails.
func TestEncodeAndDecodeAccessToken(t *testing.T) {
	plaintext := []byte("my plain text value")

	key := GenerateKey()
	c, err := NewCipher(key)
	if err != nil {
		t.Fatalf("unexpected err: %v", err)
	}

	ciphertext, err := c.Encrypt(plaintext)
	if err != nil {
		t.Fatalf("unexpected err: %v", err)
	}

	if reflect.DeepEqual(plaintext, ciphertext) {
		t.Fatalf("plaintext is not encrypted plaintext:%v ciphertext:%x", plaintext, ciphertext)
	}

	got, err := c.Decrypt(ciphertext)
	if err != nil {
		t.Fatalf("unexpected err decrypting: %v", err)
	}

	// if less than 32 bytes, fail
	_, err = c.Decrypt([]byte("oh"))
	if err == nil {
		t.Fatalf("should fail if <32 bytes output: %v", err)
	}

	if !reflect.DeepEqual(got, plaintext) {
		t.Logf(" got: %v", got)
		t.Logf("want: %v", plaintext)
		t.Fatal("got unexpected decrypted value")
	}
}

// TestMarshalAndUnmarshalStruct marshals the same struct twice, expects the
// two marshaled values to differ, and expects both to unmarshal back into
// structs equal to the original.
func TestMarshalAndUnmarshalStruct(t *testing.T) {
	key := GenerateKey()

	c, err := NewCipher(key)
	if err != nil {
		t.Fatalf("unexpected err: %v", err)
	}

	type TC struct {
		Field string `json:"field"`
	}

	tc := &TC{
		Field: "my plain text value",
	}

	value1, err := c.Marshal(tc)
	if err != nil {
		t.Fatalf("unexpected err: %v", err)
	}

	value2, err := c.Marshal(tc)
	if err != nil {
		t.Fatalf("unexpected err: %v", err)
	}

	if value1 == value2 {
		t.Fatalf("expected marshaled values to not be equal %v != %v", value1, value2)
	}

	got1 := &TC{}
	err = c.Unmarshal(value1, got1)
	if err != nil {
		t.Fatalf("unexpected err unmarshalling struct: %v", err)
	}

	if !reflect.DeepEqual(got1, tc) {
		t.Logf("want: %#v", tc)
		t.Logf(" got: %#v", got1)
		t.Fatalf("expected structs to be equal")
	}

	got2 := &TC{}
	err = c.Unmarshal(value2, got2)
	if err != nil {
		t.Fatalf("unexpected err unmarshalling struct: %v", err)
	}

	if !reflect.DeepEqual(got1, got2) {
		t.Logf("got2: %#v", got2)
		t.Logf("got1: %#v", got1)
		t.Fatalf("expected structs to be equal")
	}
}

// TestCipherDataRace runs the marshal/unmarshal round trip from 100 goroutines
// that all share a single cipher, so that running the tests with -race can
// surface unsynchronized access inside the cipher.
func TestCipherDataRace(t *testing.T) {
	cipher, err := NewCipher(GenerateKey())
	if err != nil {
		t.Fatalf("unexpected generating cipher err: %v", err)
	}

	type TC struct {
		Field string `json:"field"`
	}

	wg := &sync.WaitGroup{}
	for i := 0; i < 100; i++ {
		wg.Add(1)
		go func(c *XChaCha20Cipher, wg *sync.WaitGroup) {
			defer wg.Done()

			// FailNow/Fatal must only be called from the goroutine running the
			// test function, so failures in this worker are reported with
			// Errorf followed by an explicit return.
			b := make([]byte, 32)
			if _, err := rand.Read(b); err != nil {
				t.Errorf("unexpected error reading random bytes: %v", err)
				return
			}

			// sha1.Sum digests b. The hash.Hash method Sum(b) would instead
			// append the digest of empty input to b.
			sha := fmt.Sprintf("%x", sha1.Sum(b))
			tc := &TC{
				Field: sha,
			}

			value1, err := c.Marshal(tc)
			if err != nil {
				t.Errorf("unexpected err: %v", err)
				return
			}

			value2, err := c.Marshal(tc)
			if err != nil {
				t.Errorf("unexpected err: %v", err)
				return
			}

			if value1 == value2 {
				t.Errorf("expected marshaled values to not be equal %v != %v", value1, value2)
				return
			}

			got1 := &TC{}
			if err := c.Unmarshal(value1, got1); err != nil {
				t.Errorf("unexpected err unmarshalling struct: %v", err)
				return
			}

			if !reflect.DeepEqual(got1, tc) {
				t.Logf("want: %#v", tc)
				t.Logf(" got: %#v", got1)
				t.Errorf("expected structs to be equal")
				return
			}

			got2 := &TC{}
			if err := c.Unmarshal(value2, got2); err != nil {
				t.Errorf("unexpected err unmarshalling struct: %v", err)
				return
			}

			if !reflect.DeepEqual(got1, got2) {
				t.Logf("got2: %#v", got2)
				t.Logf("got1: %#v", got1)
				t.Errorf("expected structs to be equal")
				return
			}
		}(cipher, wg)
	}
	wg.Wait()
}

// TestGenerateRandomString checks the number of bytes obtained by base64
// decoding the output of GenerateRandomString for several requested sizes.
func TestGenerateRandomString(t *testing.T) {
	t.Parallel()
	tests := []struct {
		name string
		c    int
		want int
	}{
		{"simple", 32, 32},
		{"zero", 0, 0},
		{"negative", -1, 32},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			o := GenerateRandomString(tt.c)
			b, err := base64.StdEncoding.DecodeString(o)
			if err != nil {
				t.Error(err)
			}
			got := len(b)
			if got != tt.want {
				t.Errorf("GenerateRandomString() = %d, want %d", got, tt.want)
			}
		})
	}
}

// TestXChaCha20Cipher_Marshal checks that Marshal returns an error for a
// struct containing a func field and succeeds for a struct of plain strings.
func TestXChaCha20Cipher_Marshal(t *testing.T) {
	tests := []struct {
		name    string
		s       interface{}
		wantErr bool
	}{
		{"unsupported type", struct {
			Animal string `json:"animal"`
			Func   func() `json:"sound"`
		}{
			Animal: "cat",
			Func:   func() {},
		}, true},
		{"simple", struct {
			Animal string `json:"animal"`
			Sound  string `json:"sound"`
		}{
			Animal: "cat",
			Sound:  "meow",
		}, false},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			c, err := NewCipher(GenerateKey())
			if err != nil {
				t.Fatalf("unexpected err: %v", err)
			}
			_, err = c.Marshal(tt.s)
			if (err != nil) != tt.wantErr {
				t.Errorf("XChaCha20Cipher.Marshal() error = %v, wantErr %v", err, tt.wantErr)
				return
			}
		})
	}
}

// TestNewCipher checks that NewCipher accepts a key produced by GenerateKey
// and returns an error for a 15-byte secret and for a secret built from
// GenerateRandomString(33).
func TestNewCipher(t *testing.T) {
	tests := []struct {
		name    string
		secret  []byte
		wantErr bool
	}{
		{"simple 32 byte key", GenerateKey(), false},
		{"key too short", []byte("what is entropy"), true},
		{"key too long", []byte(GenerateRandomString(33)), true},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			_, err := NewCipher(tt.secret)
			if (err != nil) != tt.wantErr {
				t.Errorf("NewCipher() error = %v, wantErr %v", err, tt.wantErr)
				return
			}
		})
	}
}

// TestNewCipherFromBase64 checks that NewCipherFromBase64 accepts a
// base64-encoded key from GenerateKey and returns an error for a short key,
// for the output of GenerateRandomString(33), and for raw key bytes passed
// without base64 encoding.
func TestNewCipherFromBase64(t *testing.T) {
	tests := []struct {
		name    string
		s       string
		wantErr bool
	}{
		{"simple 32 byte key", base64.StdEncoding.EncodeToString(GenerateKey()), false},
		{"key too short", base64.StdEncoding.EncodeToString([]byte("what is entropy")), true},
		{"key too long", GenerateRandomString(33), true},
		{"bad base 64", string(GenerateKey()), true},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			_, err := NewCipherFromBase64(tt.s)
			if (err != nil) != tt.wantErr {
				t.Errorf("NewCipherFromBase64() error = %v, wantErr %v", err, tt.wantErr)
				return
			}
		})
	}
}