mirror of
https://github.com/pomerium/pomerium.git
synced 2025-05-02 03:46:29 +02:00
- all: prefer `FormValues` to `ParseForm` with subsequent `Form.Get`s - all: refactor authentication stack to be checked by middleware, and accessible via request context. - all: replace http.ServeMux with gorilla/mux’s router - all: replace custom CSRF checks with gorilla/csrf middleware - authenticate: extract callback path as constant. - internal/config: implement stringer interface for policy - internal/cryptutil: add helper func `NewBase64Key` - internal/cryptutil: rename `GenerateKey` to `NewKey` - internal/cryptutil: rename `GenerateRandomString` to `NewRandomStringN` - internal/middleware: removed alice in favor of gorilla/mux - internal/sessions: remove unused `ValidateRedirectURI` and `ValidateClientSecret` - internal/sessions: replace custom CSRF with gorilla/csrf fork that supports custom handler protection - internal/urlutil: add `SignedRedirectURL` to create hmac'd URLs - internal/urlutil: add `ValidateURL` helper to parse URL options - internal/urlutil: add `GetAbsoluteURL` which takes a request and returns its absolute URL. - proxy: remove holdover state verification checks; we no longer are setting sessions in any proxy routes so we don’t need them. - proxy: replace un-named http.ServeMux with named domain routes. Signed-off-by: Bobby DeSimone <bobbydesimone@gmail.com>
307 lines
6.2 KiB
Go
307 lines
6.2 KiB
Go
package cryptutil
|
|
|
|
import (
|
|
"crypto/rand"
|
|
"crypto/sha1"
|
|
"encoding/base64"
|
|
"fmt"
|
|
"reflect"
|
|
"sync"
|
|
"testing"
|
|
)
|
|
|
|
func TestEncodeAndDecodeAccessToken(t *testing.T) {
|
|
plaintext := []byte("my plain text value")
|
|
|
|
key := NewKey()
|
|
c, err := NewCipher(key)
|
|
if err != nil {
|
|
t.Fatalf("unexpected err: %v", err)
|
|
}
|
|
|
|
ciphertext, err := c.Encrypt(plaintext)
|
|
if err != nil {
|
|
t.Fatalf("unexpected err: %v", err)
|
|
}
|
|
|
|
if reflect.DeepEqual(plaintext, ciphertext) {
|
|
t.Fatalf("plaintext is not encrypted plaintext:%v ciphertext:%x", plaintext, ciphertext)
|
|
}
|
|
|
|
got, err := c.Decrypt(ciphertext)
|
|
if err != nil {
|
|
t.Fatalf("unexpected err decrypting: %v", err)
|
|
}
|
|
|
|
// if less than 32 bytes, fail
|
|
_, err = c.Decrypt([]byte("oh"))
|
|
if err == nil {
|
|
t.Fatalf("should fail if <32 bytes output: %v", err)
|
|
}
|
|
|
|
if !reflect.DeepEqual(got, plaintext) {
|
|
t.Logf(" got: %v", got)
|
|
t.Logf("want: %v", plaintext)
|
|
t.Fatal("got unexpected decrypted value")
|
|
}
|
|
}
|
|
|
|
func TestMarshalAndUnmarshalStruct(t *testing.T) {
|
|
key := NewKey()
|
|
|
|
c, err := NewCipher(key)
|
|
if err != nil {
|
|
t.Fatalf("unexpected err: %v", err)
|
|
}
|
|
|
|
type TC struct {
|
|
Field string `json:"field"`
|
|
}
|
|
|
|
tc := &TC{
|
|
Field: "my plain text value",
|
|
}
|
|
|
|
value1, err := c.Marshal(tc)
|
|
if err != nil {
|
|
t.Fatalf("unexpected err: %v", err)
|
|
}
|
|
|
|
value2, err := c.Marshal(tc)
|
|
if err != nil {
|
|
t.Fatalf("unexpected err: %v", err)
|
|
}
|
|
|
|
if value1 == value2 {
|
|
t.Fatalf("expected marshaled values to not be equal %v != %v", value1, value2)
|
|
}
|
|
|
|
got1 := &TC{}
|
|
err = c.Unmarshal(value1, got1)
|
|
if err != nil {
|
|
t.Fatalf("unexpected err unmarshalling struct: %v", err)
|
|
}
|
|
|
|
if !reflect.DeepEqual(got1, tc) {
|
|
t.Logf("want: %#v", tc)
|
|
t.Logf(" got: %#v", got1)
|
|
t.Fatalf("expected structs to be equal")
|
|
}
|
|
|
|
got2 := &TC{}
|
|
err = c.Unmarshal(value2, got2)
|
|
if err != nil {
|
|
t.Fatalf("unexpected err unmarshalling struct: %v", err)
|
|
}
|
|
|
|
if !reflect.DeepEqual(got1, got2) {
|
|
t.Logf("got2: %#v", got2)
|
|
t.Logf("got1: %#v", got1)
|
|
t.Fatalf("expected structs to be equal")
|
|
}
|
|
}
|
|
|
|
func TestCipherDataRace(t *testing.T) {
|
|
cipher, err := NewCipher(NewKey())
|
|
if err != nil {
|
|
t.Fatalf("unexpected generating cipher err: %v", err)
|
|
}
|
|
|
|
type TC struct {
|
|
Field string `json:"field"`
|
|
}
|
|
|
|
wg := &sync.WaitGroup{}
|
|
for i := 0; i < 100; i++ {
|
|
wg.Add(1)
|
|
go func(c *XChaCha20Cipher, wg *sync.WaitGroup) {
|
|
defer wg.Done()
|
|
b := make([]byte, 32)
|
|
_, err := rand.Read(b)
|
|
if err != nil {
|
|
t.Fatalf("unexecpted error reading random bytes: %v", err)
|
|
}
|
|
|
|
sha := fmt.Sprintf("%x", sha1.New().Sum(b))
|
|
tc := &TC{
|
|
Field: sha,
|
|
}
|
|
|
|
value1, err := c.Marshal(tc)
|
|
if err != nil {
|
|
t.Fatalf("unexpected err: %v", err)
|
|
}
|
|
|
|
value2, err := c.Marshal(tc)
|
|
if err != nil {
|
|
t.Fatalf("unexpected err: %v", err)
|
|
}
|
|
|
|
if value1 == value2 {
|
|
t.Fatalf("expected marshaled values to not be equal %v != %v", value1, value2)
|
|
}
|
|
|
|
got1 := &TC{}
|
|
err = c.Unmarshal(value1, got1)
|
|
if err != nil {
|
|
t.Fatalf("unexpected err unmarshalling struct: %v", err)
|
|
}
|
|
|
|
if !reflect.DeepEqual(got1, tc) {
|
|
t.Logf("want: %#v", tc)
|
|
t.Logf(" got: %#v", got1)
|
|
t.Fatalf("expected structs to be equal")
|
|
}
|
|
|
|
got2 := &TC{}
|
|
err = c.Unmarshal(value2, got2)
|
|
if err != nil {
|
|
t.Fatalf("unexpected err unmarshalling struct: %v", err)
|
|
}
|
|
|
|
if !reflect.DeepEqual(got1, got2) {
|
|
t.Logf("got2: %#v", got2)
|
|
t.Logf("got1: %#v", got1)
|
|
t.Fatalf("expected structs to be equal")
|
|
}
|
|
|
|
}(cipher, wg)
|
|
}
|
|
wg.Wait()
|
|
}
|
|
|
|
func TestGenerateRandomString(t *testing.T) {
|
|
t.Parallel()
|
|
tests := []struct {
|
|
name string
|
|
c int
|
|
want int
|
|
}{
|
|
{"simple", 32, 32},
|
|
{"zero", 0, 0},
|
|
{"negative", -1, 32},
|
|
}
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
o := NewRandomStringN(tt.c)
|
|
b, err := base64.StdEncoding.DecodeString(o)
|
|
if err != nil {
|
|
t.Error(err)
|
|
}
|
|
got := len(b)
|
|
if got != tt.want {
|
|
t.Errorf("NewRandomStringN() = %d, want %d", got, tt.want)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestXChaCha20Cipher_Marshal(t *testing.T) {
|
|
t.Parallel()
|
|
tests := []struct {
|
|
name string
|
|
s interface{}
|
|
wantErr bool
|
|
}{
|
|
{"unsupported type",
|
|
struct {
|
|
Animal string `json:"animal"`
|
|
Func func() `json:"sound"`
|
|
}{
|
|
Animal: "cat",
|
|
Func: func() {},
|
|
},
|
|
true},
|
|
{"simple",
|
|
struct {
|
|
Animal string `json:"animal"`
|
|
Sound string `json:"sound"`
|
|
}{
|
|
Animal: "cat",
|
|
Sound: "meow",
|
|
},
|
|
false},
|
|
}
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
|
|
c, err := NewCipher(NewKey())
|
|
if err != nil {
|
|
t.Fatalf("unexpected err: %v", err)
|
|
}
|
|
_, err = c.Marshal(tt.s)
|
|
if (err != nil) != tt.wantErr {
|
|
t.Errorf("XChaCha20Cipher.Marshal() error = %v, wantErr %v", err, tt.wantErr)
|
|
return
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestNewCipher(t *testing.T) {
|
|
t.Parallel()
|
|
tests := []struct {
|
|
name string
|
|
secret []byte
|
|
wantErr bool
|
|
}{
|
|
{"simple 32 byte key", NewKey(), false},
|
|
{"key too short", []byte("what is entropy"), true},
|
|
{"key too long", []byte(NewRandomStringN(33)), true},
|
|
}
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
_, err := NewCipher(tt.secret)
|
|
if (err != nil) != tt.wantErr {
|
|
t.Errorf("NewCipher() error = %v, wantErr %v", err, tt.wantErr)
|
|
return
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestNewCipherFromBase64(t *testing.T) {
|
|
t.Parallel()
|
|
tests := []struct {
|
|
name string
|
|
s string
|
|
wantErr bool
|
|
}{
|
|
{"simple 32 byte key", base64.StdEncoding.EncodeToString(NewKey()), false},
|
|
{"key too short", base64.StdEncoding.EncodeToString([]byte("what is entropy")), true},
|
|
{"key too long", NewRandomStringN(33), true},
|
|
{"bad base 64", string(NewKey()), true},
|
|
}
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
_, err := NewCipherFromBase64(tt.s)
|
|
if (err != nil) != tt.wantErr {
|
|
t.Errorf("NewCipherFromBase64() error = %v, wantErr %v", err, tt.wantErr)
|
|
return
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestNewBase64Key(t *testing.T) {
|
|
t.Parallel()
|
|
tests := []struct {
|
|
name string
|
|
want int
|
|
}{
|
|
{"simple", 32},
|
|
}
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
o := NewBase64Key()
|
|
b, err := base64.StdEncoding.DecodeString(o)
|
|
if err != nil {
|
|
t.Error(err)
|
|
}
|
|
got := len(b)
|
|
if got != tt.want {
|
|
t.Errorf("NewBase64Key() = %d, want %d", got, tt.want)
|
|
}
|
|
})
|
|
}
|
|
}
|