mirror of
https://github.com/pomerium/pomerium.git
synced 2025-04-29 18:36:30 +02:00
101 lines
2.9 KiB
Go
101 lines
2.9 KiB
Go
package protoutil
|
|
|
|
import (
|
|
"fmt"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
"google.golang.org/protobuf/encoding/protojson"
|
|
"google.golang.org/protobuf/proto"
|
|
"google.golang.org/protobuf/types/known/structpb"
|
|
"google.golang.org/protobuf/types/known/wrapperspb"
|
|
|
|
"github.com/pomerium/pomerium/pkg/cryptutil"
|
|
)
|
|
|
|
func TestEncryptor_Encrypt(t *testing.T) {
|
|
t.Run("simple", func(t *testing.T) {
|
|
kek, err := cryptutil.GenerateKeyEncryptionKey()
|
|
require.NoError(t, err)
|
|
enc := NewEncryptor(kek.Public())
|
|
sealed, err := enc.Encrypt(wrapperspb.String("HELLO WORLD"))
|
|
require.NoError(t, err)
|
|
require.Equal(t, kek.Public().ID(), sealed.GetKeyId())
|
|
require.NotEmpty(t, sealed.GetDataEncryptionKey())
|
|
require.Equal(t, "type.googleapis.com/google.protobuf.StringValue", sealed.GetMessageType())
|
|
require.NotEmpty(t, sealed.GetEncryptedMessage())
|
|
})
|
|
|
|
t.Run("reuse dek", func(t *testing.T) {
|
|
kek, err := cryptutil.GenerateKeyEncryptionKey()
|
|
require.NoError(t, err)
|
|
enc := NewEncryptor(kek.Public())
|
|
s1, err := enc.Encrypt(wrapperspb.String("HELLO WORLD"))
|
|
require.NoError(t, err)
|
|
s2, err := enc.Encrypt(wrapperspb.String("HELLO WORLD"))
|
|
require.NoError(t, err)
|
|
assert.Equal(t, s1.GetDataEncryptionKey(), s2.GetDataEncryptionKey())
|
|
})
|
|
t.Run("rotate dek", func(t *testing.T) {
|
|
kek, err := cryptutil.GenerateKeyEncryptionKey()
|
|
require.NoError(t, err)
|
|
enc := NewEncryptor(kek.Public())
|
|
s1, err := enc.Encrypt(wrapperspb.String("HELLO WORLD"))
|
|
require.NoError(t, err)
|
|
enc.nextRotate = time.Now()
|
|
s2, err := enc.Encrypt(wrapperspb.String("HELLO WORLD"))
|
|
require.NoError(t, err)
|
|
assert.NotEqual(t, s1.GetDataEncryptionKey(), s2.GetDataEncryptionKey())
|
|
})
|
|
}
|
|
|
|
func TestDecryptor_Decrypt(t *testing.T) {
|
|
expect := wrapperspb.String("HELLO WORLD")
|
|
|
|
kek, err := cryptutil.GenerateKeyEncryptionKey()
|
|
require.NoError(t, err)
|
|
|
|
enc := NewEncryptor(kek.Public())
|
|
sealed, err := enc.Encrypt(expect)
|
|
require.NoError(t, err)
|
|
|
|
dec := NewDecryptor(cryptutil.KeyEncryptionKeySourceFunc(func(id string) (*cryptutil.PrivateKeyEncryptionKey, error) {
|
|
require.Equal(t, kek.ID(), id)
|
|
return kek, nil
|
|
}))
|
|
opened, err := dec.Decrypt(sealed)
|
|
require.NoError(t, err)
|
|
assertProtoEqual(t, expect, opened)
|
|
}
|
|
|
|
func assertProtoEqual(t *testing.T, x, y proto.Message) {
|
|
xbs, _ := protojson.Marshal(x)
|
|
ybs, _ := protojson.Marshal(y)
|
|
assert.True(t, proto.Equal(x, y), "%s != %s", xbs, ybs)
|
|
}
|
|
|
|
func BenchmarkEncrypt(b *testing.B) {
|
|
m := map[string]interface{}{}
|
|
for i := 0; i < 10; i++ {
|
|
mm := map[string]interface{}{}
|
|
for j := 0; j < 10; j++ {
|
|
mm[fmt.Sprintf("key%d", j)] = fmt.Sprintf("value%d", j)
|
|
}
|
|
m[fmt.Sprintf("key%d", i)] = mm
|
|
}
|
|
|
|
obj, err := structpb.NewStruct(m)
|
|
require.NoError(b, err)
|
|
|
|
kek, err := cryptutil.GenerateKeyEncryptionKey()
|
|
require.NoError(b, err)
|
|
enc := NewEncryptor(kek.Public())
|
|
|
|
b.ResetTimer()
|
|
for i := 0; i < b.N; i++ {
|
|
_, err := enc.Encrypt(obj)
|
|
require.NoError(b, err)
|
|
}
|
|
}
|