databroker: add encryption for records (#1168)

This commit is contained in:
Caleb Doxsey 2020-07-30 14:04:31 -06:00 committed by GitHub
parent 8cae3f27bb
commit 29fb96a955
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 332 additions and 7 deletions

5
cache/cache.go vendored
View file

@ -67,7 +67,10 @@ func New(opts config.Options) (*Cache, error) {
return nil, err
}
dataBrokerServer := NewDataBrokerServer(localGRPCServer, opts)
dataBrokerServer, err := NewDataBrokerServer(localGRPCServer, opts)
if err != nil {
return nil, err
}
dataBrokerClient := databroker.NewDataBrokerServiceClient(localGRPCConnection)
manager := manager.New(

13
cache/databroker.go vendored
View file

@ -1,10 +1,14 @@
package cache
import (
"encoding/base64"
"fmt"
"google.golang.org/grpc"
"github.com/pomerium/pomerium/config"
internal_databroker "github.com/pomerium/pomerium/internal/databroker"
"github.com/pomerium/pomerium/pkg/cryptutil"
"github.com/pomerium/pomerium/pkg/grpc/databroker"
)
@ -14,12 +18,17 @@ type DataBrokerServer struct {
}
// NewDataBrokerServer creates a new databroker service server.
func NewDataBrokerServer(grpcServer *grpc.Server, opts config.Options) *DataBrokerServer {
func NewDataBrokerServer(grpcServer *grpc.Server, opts config.Options) (*DataBrokerServer, error) {
key, err := base64.StdEncoding.DecodeString(opts.SharedKey)
if err != nil || len(key) != cryptutil.DefaultKeySize {
return nil, fmt.Errorf("shared key is required and must be %d bytes long", cryptutil.DefaultKeySize)
}
internalSrv := internal_databroker.New(
internal_databroker.WithSecret(key),
internal_databroker.WithStorageType(opts.DataBrokerStorageType),
internal_databroker.WithStorageConnectionString(opts.DataBrokerStorageConnectionString),
)
srv := &DataBrokerServer{DataBrokerServiceServer: internalSrv}
databroker.RegisterDataBrokerServiceServer(grpcServer, srv)
return srv
return srv, nil
}

View file

@ -15,6 +15,7 @@ var (
type serverConfig struct {
deletePermanentlyAfter time.Duration
btreeDegree int
secret []byte
storageType string
storageConnectionString string
}
@ -49,6 +50,13 @@ func WithDeletePermanentlyAfter(dur time.Duration) ServerOption {
}
}
// WithSecret sets the secret in the config.
func WithSecret(secret []byte) ServerOption {
return func(cfg *serverConfig) {
cfg.secret = secret
}
}
// WithStorageType sets the storage type.
func WithStorageType(typ string) ServerOption {
return func(cfg *serverConfig) {

View file

@ -346,17 +346,23 @@ func (srv *Server) getDB(recordType string) (storage.Backend, error) {
return db, nil
}
func (srv *Server) newDB(recordType string) (storage.Backend, error) {
func (srv *Server) newDB(recordType string) (db storage.Backend, err error) {
switch srv.cfg.storageType {
case inmemory.Name:
return inmemory.NewDB(recordType, srv.cfg.btreeDegree), nil
db = inmemory.NewDB(recordType, srv.cfg.btreeDegree)
case redis.Name:
db, err := redis.New(srv.cfg.storageConnectionString, recordType, int64(srv.cfg.deletePermanentlyAfter.Seconds()))
db, err = redis.New(srv.cfg.storageConnectionString, recordType, int64(srv.cfg.deletePermanentlyAfter.Seconds()))
if err != nil {
return nil, fmt.Errorf("failed to create new redis storage: %w", err)
}
return db, nil
default:
return nil, fmt.Errorf("unsupported storage type: %s", srv.cfg.storageType)
}
if srv.cfg.secret != nil {
db, err = storage.NewEncryptedBackend(srv.cfg.secret, db)
if err != nil {
return nil, err
}
}
return db, nil
}

View file

@ -4,6 +4,8 @@ import (
"encoding/base64"
"reflect"
"testing"
"github.com/stretchr/testify/assert"
)
func TestEncodeAndDecodeAccessToken(t *testing.T) {
@ -61,6 +63,22 @@ func TestNewAEADCipher(t *testing.T) {
}
}
func BenchmarkAEADCipher(b *testing.B) {
plaintext := []byte("my plain text value")
key := NewKey()
c, err := NewAEADCipher(key)
if !assert.NoError(b, err) {
return
}
ciphertext := Encrypt(c, plaintext, nil)
b.ResetTimer()
for i := 0; i < b.N; i++ {
Decrypt(c, ciphertext, nil)
}
}
func TestNewAEADCipherFromBase64(t *testing.T) {
t.Parallel()
tests := []struct {

134
pkg/storage/encrypted.go Normal file
View file

@ -0,0 +1,134 @@
package storage
import (
"context"
"crypto/cipher"
"github.com/pomerium/pomerium/pkg/cryptutil"
"github.com/pomerium/pomerium/pkg/grpc/databroker"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/known/anypb"
"google.golang.org/protobuf/types/known/wrapperspb"
)
type encryptedBackend struct {
Backend
cipher cipher.AEAD
}
// NewEncryptedBackend creates a new encrypted backend.
func NewEncryptedBackend(secret []byte, underlying Backend) (Backend, error) {
c, err := cryptutil.NewAEADCipher(secret)
if err != nil {
return nil, err
}
return &encryptedBackend{
Backend: underlying,
cipher: c,
}, nil
}
func (e *encryptedBackend) Put(ctx context.Context, id string, data *anypb.Any) error {
encrypted, err := e.encrypt(data)
if err != nil {
return err
}
return e.Backend.Put(ctx, id, encrypted)
}
func (e *encryptedBackend) Get(ctx context.Context, id string) (*databroker.Record, error) {
record, err := e.Backend.Get(ctx, id)
if err != nil {
return nil, err
}
record, err = e.decryptRecord(record)
if err != nil {
return nil, err
}
return record, nil
}
func (e *encryptedBackend) GetAll(ctx context.Context) ([]*databroker.Record, error) {
records, err := e.Backend.GetAll(ctx)
if err != nil {
return nil, err
}
for i := range records {
records[i], err = e.decryptRecord(records[i])
if err != nil {
return nil, err
}
}
return records, nil
}
func (e *encryptedBackend) List(ctx context.Context, sinceVersion string) ([]*databroker.Record, error) {
records, err := e.Backend.List(ctx, sinceVersion)
if err != nil {
return nil, err
}
for i := range records {
records[i], err = e.decryptRecord(records[i])
if err != nil {
return nil, err
}
}
return records, nil
}
func (e *encryptedBackend) decryptRecord(in *databroker.Record) (out *databroker.Record, err error) {
data, err := e.decrypt(in.Data)
if err != nil {
return nil, err
}
// Create a new record so that we don't re-use any internal state
return &databroker.Record{
Version: in.Version,
Type: data.TypeUrl,
Id: in.Id,
Data: data,
CreatedAt: in.CreatedAt,
ModifiedAt: in.ModifiedAt,
DeletedAt: in.DeletedAt,
}, nil
}
func (e *encryptedBackend) decrypt(in *anypb.Any) (out *anypb.Any, err error) {
var encrypted wrapperspb.BytesValue
err = in.UnmarshalTo(&encrypted)
if err != nil {
return nil, err
}
plaintext, err := cryptutil.Decrypt(e.cipher, encrypted.Value, nil)
if err != nil {
return nil, err
}
out = new(anypb.Any)
err = proto.Unmarshal(plaintext, out)
if err != nil {
return nil, err
}
return out, nil
}
func (e *encryptedBackend) encrypt(in *anypb.Any) (out *anypb.Any, err error) {
plaintext, err := proto.Marshal(in)
if err != nil {
return nil, err
}
encrypted := cryptutil.Encrypt(e.cipher, plaintext, nil)
out, err = anypb.New(&wrapperspb.BytesValue{
Value: encrypted,
})
if err != nil {
return nil, err
}
return out, nil
}

View file

@ -0,0 +1,100 @@
package storage
import (
"context"
"errors"
"testing"
"github.com/pomerium/pomerium/pkg/cryptutil"
"github.com/pomerium/pomerium/pkg/grpc/databroker"
"github.com/stretchr/testify/assert"
"google.golang.org/protobuf/types/known/anypb"
"google.golang.org/protobuf/types/known/wrapperspb"
)
func TestEncryptedBackend(t *testing.T) {
ctx := context.Background()
m := map[string]*anypb.Any{}
backend := &mockBackend{
put: func(ctx context.Context, id string, data *anypb.Any) error {
m[id] = data
return nil
},
get: func(ctx context.Context, id string) (*databroker.Record, error) {
data, ok := m[id]
if !ok {
return nil, errors.New("not found")
}
return &databroker.Record{
Id: id,
Data: data,
}, nil
},
getAll: func(ctx context.Context) ([]*databroker.Record, error) {
var records []*databroker.Record
for id, data := range m {
records = append(records, &databroker.Record{
Id: id,
Data: data,
})
}
return records, nil
},
list: func(ctx context.Context, sinceVersion string) ([]*databroker.Record, error) {
var records []*databroker.Record
for id, data := range m {
records = append(records, &databroker.Record{
Id: id,
Data: data,
})
}
return records, nil
},
}
e, err := NewEncryptedBackend(cryptutil.NewKey(), backend)
if !assert.NoError(t, err) {
return
}
any, _ := anypb.New(wrapperspb.String("HELLO WORLD"))
err = e.Put(ctx, "TEST-1", any)
if !assert.NoError(t, err) {
return
}
if assert.NotNil(t, m["TEST-1"], "key should be set") {
assert.NotEqual(t, any.TypeUrl, m["TEST-1"].TypeUrl, "encrypted data should be a bytes type")
assert.NotEqual(t, any.Value, m["TEST-1"].Value, "value should be encrypted")
}
record, err := e.Get(ctx, "TEST-1")
if !assert.NoError(t, err) {
return
}
assert.Equal(t, any.TypeUrl, record.Data.TypeUrl, "type should be preserved")
assert.Equal(t, any.Value, record.Data.Value, "value should be preserved")
assert.Equal(t, any.TypeUrl, record.Type, "record type should be preserved")
records, err := e.GetAll(ctx)
if !assert.NoError(t, err) {
return
}
if assert.Len(t, records, 1) {
assert.Equal(t, any.TypeUrl, records[0].Data.TypeUrl, "type should be preserved")
assert.Equal(t, any.Value, records[0].Data.Value, "value should be preserved")
assert.Equal(t, any.TypeUrl, records[0].Type, "record type should be preserved")
}
records, err = e.List(ctx, "")
if !assert.NoError(t, err) {
return
}
if assert.Len(t, records, 1) {
assert.Equal(t, any.TypeUrl, records[0].Data.TypeUrl, "type should be preserved")
assert.Equal(t, any.Value, records[0].Data.Value, "value should be preserved")
assert.Equal(t, any.TypeUrl, records[0].Type, "record type should be preserved")
}
}

View file

@ -0,0 +1,47 @@
package storage
import (
"context"
"time"
"github.com/pomerium/pomerium/pkg/grpc/databroker"
"google.golang.org/protobuf/types/known/anypb"
)
type mockBackend struct {
put func(ctx context.Context, id string, data *anypb.Any) error
get func(ctx context.Context, id string) (*databroker.Record, error)
getAll func(ctx context.Context) ([]*databroker.Record, error)
list func(ctx context.Context, sinceVersion string) ([]*databroker.Record, error)
delete func(ctx context.Context, id string) error
clearDeleted func(ctx context.Context, cutoff time.Time)
watch func(ctx context.Context) chan struct{}
}
func (m *mockBackend) Put(ctx context.Context, id string, data *anypb.Any) error {
return m.put(ctx, id, data)
}
func (m *mockBackend) Get(ctx context.Context, id string) (*databroker.Record, error) {
return m.get(ctx, id)
}
func (m *mockBackend) GetAll(ctx context.Context) ([]*databroker.Record, error) {
return m.getAll(ctx)
}
func (m *mockBackend) List(ctx context.Context, sinceVersion string) ([]*databroker.Record, error) {
return m.list(ctx, sinceVersion)
}
func (m *mockBackend) Delete(ctx context.Context, id string) error {
return m.delete(ctx, id)
}
func (m *mockBackend) ClearDeleted(ctx context.Context, cutoff time.Time) {
m.clearDeleted(ctx, cutoff)
}
func (m *mockBackend) Watch(ctx context.Context) chan struct{} {
return m.watch(ctx)
}