mirror of
https://github.com/pomerium/pomerium.git
synced 2025-05-19 12:07:18 +02:00
databroker: add encryption for records (#1168)
This commit is contained in:
parent
8cae3f27bb
commit
29fb96a955
8 changed files with 332 additions and 7 deletions
5
cache/cache.go
vendored
5
cache/cache.go
vendored
|
@ -67,7 +67,10 @@ func New(opts config.Options) (*Cache, error) {
|
|||
return nil, err
|
||||
}
|
||||
|
||||
dataBrokerServer := NewDataBrokerServer(localGRPCServer, opts)
|
||||
dataBrokerServer, err := NewDataBrokerServer(localGRPCServer, opts)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
dataBrokerClient := databroker.NewDataBrokerServiceClient(localGRPCConnection)
|
||||
|
||||
manager := manager.New(
|
||||
|
|
13
cache/databroker.go
vendored
13
cache/databroker.go
vendored
|
@ -1,10 +1,14 @@
|
|||
package cache
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
|
||||
"google.golang.org/grpc"
|
||||
|
||||
"github.com/pomerium/pomerium/config"
|
||||
internal_databroker "github.com/pomerium/pomerium/internal/databroker"
|
||||
"github.com/pomerium/pomerium/pkg/cryptutil"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||
)
|
||||
|
||||
|
@ -14,12 +18,17 @@ type DataBrokerServer struct {
|
|||
}
|
||||
|
||||
// NewDataBrokerServer creates a new databroker service server.
|
||||
func NewDataBrokerServer(grpcServer *grpc.Server, opts config.Options) *DataBrokerServer {
|
||||
func NewDataBrokerServer(grpcServer *grpc.Server, opts config.Options) (*DataBrokerServer, error) {
|
||||
key, err := base64.StdEncoding.DecodeString(opts.SharedKey)
|
||||
if err != nil || len(key) != cryptutil.DefaultKeySize {
|
||||
return nil, fmt.Errorf("shared key is required and must be %d bytes long", cryptutil.DefaultKeySize)
|
||||
}
|
||||
internalSrv := internal_databroker.New(
|
||||
internal_databroker.WithSecret(key),
|
||||
internal_databroker.WithStorageType(opts.DataBrokerStorageType),
|
||||
internal_databroker.WithStorageConnectionString(opts.DataBrokerStorageConnectionString),
|
||||
)
|
||||
srv := &DataBrokerServer{DataBrokerServiceServer: internalSrv}
|
||||
databroker.RegisterDataBrokerServiceServer(grpcServer, srv)
|
||||
return srv
|
||||
return srv, nil
|
||||
}
|
||||
|
|
|
@ -15,6 +15,7 @@ var (
|
|||
type serverConfig struct {
|
||||
deletePermanentlyAfter time.Duration
|
||||
btreeDegree int
|
||||
secret []byte
|
||||
storageType string
|
||||
storageConnectionString string
|
||||
}
|
||||
|
@ -49,6 +50,13 @@ func WithDeletePermanentlyAfter(dur time.Duration) ServerOption {
|
|||
}
|
||||
}
|
||||
|
||||
// WithSecret sets the secret in the config.
|
||||
func WithSecret(secret []byte) ServerOption {
|
||||
return func(cfg *serverConfig) {
|
||||
cfg.secret = secret
|
||||
}
|
||||
}
|
||||
|
||||
// WithStorageType sets the storage type.
|
||||
func WithStorageType(typ string) ServerOption {
|
||||
return func(cfg *serverConfig) {
|
||||
|
|
|
@ -346,17 +346,23 @@ func (srv *Server) getDB(recordType string) (storage.Backend, error) {
|
|||
return db, nil
|
||||
}
|
||||
|
||||
func (srv *Server) newDB(recordType string) (storage.Backend, error) {
|
||||
func (srv *Server) newDB(recordType string) (db storage.Backend, err error) {
|
||||
switch srv.cfg.storageType {
|
||||
case inmemory.Name:
|
||||
return inmemory.NewDB(recordType, srv.cfg.btreeDegree), nil
|
||||
db = inmemory.NewDB(recordType, srv.cfg.btreeDegree)
|
||||
case redis.Name:
|
||||
db, err := redis.New(srv.cfg.storageConnectionString, recordType, int64(srv.cfg.deletePermanentlyAfter.Seconds()))
|
||||
db, err = redis.New(srv.cfg.storageConnectionString, recordType, int64(srv.cfg.deletePermanentlyAfter.Seconds()))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create new redis storage: %w", err)
|
||||
}
|
||||
return db, nil
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported storage type: %s", srv.cfg.storageType)
|
||||
}
|
||||
if srv.cfg.secret != nil {
|
||||
db, err = storage.NewEncryptedBackend(srv.cfg.secret, db)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return db, nil
|
||||
}
|
||||
|
|
|
@ -4,6 +4,8 @@ import (
|
|||
"encoding/base64"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestEncodeAndDecodeAccessToken(t *testing.T) {
|
||||
|
@ -61,6 +63,22 @@ func TestNewAEADCipher(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func BenchmarkAEADCipher(b *testing.B) {
|
||||
plaintext := []byte("my plain text value")
|
||||
|
||||
key := NewKey()
|
||||
c, err := NewAEADCipher(key)
|
||||
if !assert.NoError(b, err) {
|
||||
return
|
||||
}
|
||||
|
||||
ciphertext := Encrypt(c, plaintext, nil)
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
Decrypt(c, ciphertext, nil)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewAEADCipherFromBase64(t *testing.T) {
|
||||
t.Parallel()
|
||||
tests := []struct {
|
||||
|
|
134
pkg/storage/encrypted.go
Normal file
134
pkg/storage/encrypted.go
Normal file
|
@ -0,0 +1,134 @@
|
|||
package storage
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/cipher"
|
||||
|
||||
"github.com/pomerium/pomerium/pkg/cryptutil"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||
"google.golang.org/protobuf/proto"
|
||||
"google.golang.org/protobuf/types/known/anypb"
|
||||
"google.golang.org/protobuf/types/known/wrapperspb"
|
||||
)
|
||||
|
||||
type encryptedBackend struct {
|
||||
Backend
|
||||
cipher cipher.AEAD
|
||||
}
|
||||
|
||||
// NewEncryptedBackend creates a new encrypted backend.
|
||||
func NewEncryptedBackend(secret []byte, underlying Backend) (Backend, error) {
|
||||
c, err := cryptutil.NewAEADCipher(secret)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &encryptedBackend{
|
||||
Backend: underlying,
|
||||
cipher: c,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (e *encryptedBackend) Put(ctx context.Context, id string, data *anypb.Any) error {
|
||||
encrypted, err := e.encrypt(data)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return e.Backend.Put(ctx, id, encrypted)
|
||||
}
|
||||
|
||||
func (e *encryptedBackend) Get(ctx context.Context, id string) (*databroker.Record, error) {
|
||||
record, err := e.Backend.Get(ctx, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
record, err = e.decryptRecord(record)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return record, nil
|
||||
}
|
||||
|
||||
func (e *encryptedBackend) GetAll(ctx context.Context) ([]*databroker.Record, error) {
|
||||
records, err := e.Backend.GetAll(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for i := range records {
|
||||
records[i], err = e.decryptRecord(records[i])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return records, nil
|
||||
}
|
||||
|
||||
func (e *encryptedBackend) List(ctx context.Context, sinceVersion string) ([]*databroker.Record, error) {
|
||||
records, err := e.Backend.List(ctx, sinceVersion)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for i := range records {
|
||||
records[i], err = e.decryptRecord(records[i])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return records, nil
|
||||
}
|
||||
|
||||
func (e *encryptedBackend) decryptRecord(in *databroker.Record) (out *databroker.Record, err error) {
|
||||
data, err := e.decrypt(in.Data)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// Create a new record so that we don't re-use any internal state
|
||||
return &databroker.Record{
|
||||
Version: in.Version,
|
||||
Type: data.TypeUrl,
|
||||
Id: in.Id,
|
||||
Data: data,
|
||||
CreatedAt: in.CreatedAt,
|
||||
ModifiedAt: in.ModifiedAt,
|
||||
DeletedAt: in.DeletedAt,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (e *encryptedBackend) decrypt(in *anypb.Any) (out *anypb.Any, err error) {
|
||||
var encrypted wrapperspb.BytesValue
|
||||
err = in.UnmarshalTo(&encrypted)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
plaintext, err := cryptutil.Decrypt(e.cipher, encrypted.Value, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
out = new(anypb.Any)
|
||||
err = proto.Unmarshal(plaintext, out)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (e *encryptedBackend) encrypt(in *anypb.Any) (out *anypb.Any, err error) {
|
||||
plaintext, err := proto.Marshal(in)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
encrypted := cryptutil.Encrypt(e.cipher, plaintext, nil)
|
||||
|
||||
out, err = anypb.New(&wrapperspb.BytesValue{
|
||||
Value: encrypted,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return out, nil
|
||||
}
|
100
pkg/storage/encrypted_test.go
Normal file
100
pkg/storage/encrypted_test.go
Normal file
|
@ -0,0 +1,100 @@
|
|||
package storage
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/pomerium/pomerium/pkg/cryptutil"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"google.golang.org/protobuf/types/known/anypb"
|
||||
"google.golang.org/protobuf/types/known/wrapperspb"
|
||||
)
|
||||
|
||||
func TestEncryptedBackend(t *testing.T) {
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
m := map[string]*anypb.Any{}
|
||||
backend := &mockBackend{
|
||||
put: func(ctx context.Context, id string, data *anypb.Any) error {
|
||||
m[id] = data
|
||||
return nil
|
||||
},
|
||||
get: func(ctx context.Context, id string) (*databroker.Record, error) {
|
||||
data, ok := m[id]
|
||||
if !ok {
|
||||
return nil, errors.New("not found")
|
||||
}
|
||||
return &databroker.Record{
|
||||
Id: id,
|
||||
Data: data,
|
||||
}, nil
|
||||
},
|
||||
getAll: func(ctx context.Context) ([]*databroker.Record, error) {
|
||||
var records []*databroker.Record
|
||||
for id, data := range m {
|
||||
records = append(records, &databroker.Record{
|
||||
Id: id,
|
||||
Data: data,
|
||||
})
|
||||
}
|
||||
return records, nil
|
||||
},
|
||||
list: func(ctx context.Context, sinceVersion string) ([]*databroker.Record, error) {
|
||||
var records []*databroker.Record
|
||||
for id, data := range m {
|
||||
records = append(records, &databroker.Record{
|
||||
Id: id,
|
||||
Data: data,
|
||||
})
|
||||
}
|
||||
return records, nil
|
||||
},
|
||||
}
|
||||
|
||||
e, err := NewEncryptedBackend(cryptutil.NewKey(), backend)
|
||||
if !assert.NoError(t, err) {
|
||||
return
|
||||
}
|
||||
|
||||
any, _ := anypb.New(wrapperspb.String("HELLO WORLD"))
|
||||
|
||||
err = e.Put(ctx, "TEST-1", any)
|
||||
if !assert.NoError(t, err) {
|
||||
return
|
||||
}
|
||||
if assert.NotNil(t, m["TEST-1"], "key should be set") {
|
||||
assert.NotEqual(t, any.TypeUrl, m["TEST-1"].TypeUrl, "encrypted data should be a bytes type")
|
||||
assert.NotEqual(t, any.Value, m["TEST-1"].Value, "value should be encrypted")
|
||||
}
|
||||
|
||||
record, err := e.Get(ctx, "TEST-1")
|
||||
if !assert.NoError(t, err) {
|
||||
return
|
||||
}
|
||||
assert.Equal(t, any.TypeUrl, record.Data.TypeUrl, "type should be preserved")
|
||||
assert.Equal(t, any.Value, record.Data.Value, "value should be preserved")
|
||||
assert.Equal(t, any.TypeUrl, record.Type, "record type should be preserved")
|
||||
|
||||
records, err := e.GetAll(ctx)
|
||||
if !assert.NoError(t, err) {
|
||||
return
|
||||
}
|
||||
if assert.Len(t, records, 1) {
|
||||
assert.Equal(t, any.TypeUrl, records[0].Data.TypeUrl, "type should be preserved")
|
||||
assert.Equal(t, any.Value, records[0].Data.Value, "value should be preserved")
|
||||
assert.Equal(t, any.TypeUrl, records[0].Type, "record type should be preserved")
|
||||
}
|
||||
|
||||
records, err = e.List(ctx, "")
|
||||
if !assert.NoError(t, err) {
|
||||
return
|
||||
}
|
||||
if assert.Len(t, records, 1) {
|
||||
assert.Equal(t, any.TypeUrl, records[0].Data.TypeUrl, "type should be preserved")
|
||||
assert.Equal(t, any.Value, records[0].Data.Value, "value should be preserved")
|
||||
assert.Equal(t, any.TypeUrl, records[0].Type, "record type should be preserved")
|
||||
}
|
||||
}
|
47
pkg/storage/storage_test.go
Normal file
47
pkg/storage/storage_test.go
Normal file
|
@ -0,0 +1,47 @@
|
|||
package storage
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||
"google.golang.org/protobuf/types/known/anypb"
|
||||
)
|
||||
|
||||
type mockBackend struct {
|
||||
put func(ctx context.Context, id string, data *anypb.Any) error
|
||||
get func(ctx context.Context, id string) (*databroker.Record, error)
|
||||
getAll func(ctx context.Context) ([]*databroker.Record, error)
|
||||
list func(ctx context.Context, sinceVersion string) ([]*databroker.Record, error)
|
||||
delete func(ctx context.Context, id string) error
|
||||
clearDeleted func(ctx context.Context, cutoff time.Time)
|
||||
watch func(ctx context.Context) chan struct{}
|
||||
}
|
||||
|
||||
func (m *mockBackend) Put(ctx context.Context, id string, data *anypb.Any) error {
|
||||
return m.put(ctx, id, data)
|
||||
}
|
||||
|
||||
func (m *mockBackend) Get(ctx context.Context, id string) (*databroker.Record, error) {
|
||||
return m.get(ctx, id)
|
||||
}
|
||||
|
||||
func (m *mockBackend) GetAll(ctx context.Context) ([]*databroker.Record, error) {
|
||||
return m.getAll(ctx)
|
||||
}
|
||||
|
||||
func (m *mockBackend) List(ctx context.Context, sinceVersion string) ([]*databroker.Record, error) {
|
||||
return m.list(ctx, sinceVersion)
|
||||
}
|
||||
|
||||
func (m *mockBackend) Delete(ctx context.Context, id string) error {
|
||||
return m.delete(ctx, id)
|
||||
}
|
||||
|
||||
func (m *mockBackend) ClearDeleted(ctx context.Context, cutoff time.Time) {
|
||||
m.clearDeleted(ctx, cutoff)
|
||||
}
|
||||
|
||||
func (m *mockBackend) Watch(ctx context.Context) chan struct{} {
|
||||
return m.watch(ctx)
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue