diff --git a/cache/cache.go b/cache/cache.go index a3b63a6f2..6e41c9931 100644 --- a/cache/cache.go +++ b/cache/cache.go @@ -67,7 +67,10 @@ func New(opts config.Options) (*Cache, error) { return nil, err } - dataBrokerServer := NewDataBrokerServer(localGRPCServer, opts) + dataBrokerServer, err := NewDataBrokerServer(localGRPCServer, opts) + if err != nil { + return nil, err + } dataBrokerClient := databroker.NewDataBrokerServiceClient(localGRPCConnection) manager := manager.New( diff --git a/cache/databroker.go b/cache/databroker.go index 3dceb385b..0574bf8aa 100644 --- a/cache/databroker.go +++ b/cache/databroker.go @@ -1,10 +1,14 @@ package cache import ( + "encoding/base64" + "fmt" + "google.golang.org/grpc" "github.com/pomerium/pomerium/config" internal_databroker "github.com/pomerium/pomerium/internal/databroker" + "github.com/pomerium/pomerium/pkg/cryptutil" "github.com/pomerium/pomerium/pkg/grpc/databroker" ) @@ -14,12 +18,17 @@ type DataBrokerServer struct { } // NewDataBrokerServer creates a new databroker service server. -func NewDataBrokerServer(grpcServer *grpc.Server, opts config.Options) *DataBrokerServer { +func NewDataBrokerServer(grpcServer *grpc.Server, opts config.Options) (*DataBrokerServer, error) { + key, err := base64.StdEncoding.DecodeString(opts.SharedKey) + if err != nil || len(key) != cryptutil.DefaultKeySize { + return nil, fmt.Errorf("shared key is required and must be %d bytes long", cryptutil.DefaultKeySize) + } internalSrv := internal_databroker.New( + internal_databroker.WithSecret(key), internal_databroker.WithStorageType(opts.DataBrokerStorageType), internal_databroker.WithStorageConnectionString(opts.DataBrokerStorageConnectionString), ) srv := &DataBrokerServer{DataBrokerServiceServer: internalSrv} databroker.RegisterDataBrokerServiceServer(grpcServer, srv) - return srv + return srv, nil } diff --git a/internal/databroker/config.go b/internal/databroker/config.go index dbe4814b6..3979e2138 100644 --- a/internal/databroker/config.go +++ b/internal/databroker/config.go @@ -15,6 +15,7 @@ var ( type serverConfig struct { deletePermanentlyAfter time.Duration btreeDegree int + secret []byte storageType string storageConnectionString string } @@ -49,6 +50,13 @@ func WithDeletePermanentlyAfter(dur time.Duration) ServerOption { } } +// WithSecret sets the secret in the config. +func WithSecret(secret []byte) ServerOption { + return func(cfg *serverConfig) { + cfg.secret = secret + } +} + // WithStorageType sets the storage type. func WithStorageType(typ string) ServerOption { return func(cfg *serverConfig) { diff --git a/internal/databroker/server.go b/internal/databroker/server.go index d86045b03..d32a39faf 100644 --- a/internal/databroker/server.go +++ b/internal/databroker/server.go @@ -346,17 +346,23 @@ func (srv *Server) getDB(recordType string) (storage.Backend, error) { return db, nil } -func (srv *Server) newDB(recordType string) (storage.Backend, error) { +func (srv *Server) newDB(recordType string) (db storage.Backend, err error) { switch srv.cfg.storageType { case inmemory.Name: - return inmemory.NewDB(recordType, srv.cfg.btreeDegree), nil + db = inmemory.NewDB(recordType, srv.cfg.btreeDegree) case redis.Name: - db, err := redis.New(srv.cfg.storageConnectionString, recordType, int64(srv.cfg.deletePermanentlyAfter.Seconds())) + db, err = redis.New(srv.cfg.storageConnectionString, recordType, int64(srv.cfg.deletePermanentlyAfter.Seconds())) if err != nil { return nil, fmt.Errorf("failed to create new redis storage: %w", err) } - return db, nil default: return nil, fmt.Errorf("unsupported storage type: %s", srv.cfg.storageType) } + if srv.cfg.secret != nil { + db, err = storage.NewEncryptedBackend(srv.cfg.secret, db) + if err != nil { + return nil, err + } + } + return db, nil } diff --git a/pkg/cryptutil/encrypt_test.go b/pkg/cryptutil/encrypt_test.go index cfb4d62c4..98d364f64 100644 --- a/pkg/cryptutil/encrypt_test.go +++ b/pkg/cryptutil/encrypt_test.go @@ -4,6 +4,8 @@ import ( "encoding/base64" "reflect" "testing" + + "github.com/stretchr/testify/assert" ) func TestEncodeAndDecodeAccessToken(t *testing.T) { @@ -61,6 +63,22 @@ func TestNewAEADCipher(t *testing.T) { } } +func BenchmarkAEADCipher(b *testing.B) { + plaintext := []byte("my plain text value") + + key := NewKey() + c, err := NewAEADCipher(key) + if !assert.NoError(b, err) { + return + } + + ciphertext := Encrypt(c, plaintext, nil) + b.ResetTimer() + for i := 0; i < b.N; i++ { + Decrypt(c, ciphertext, nil) + } +} + func TestNewAEADCipherFromBase64(t *testing.T) { t.Parallel() tests := []struct { diff --git a/pkg/storage/encrypted.go b/pkg/storage/encrypted.go new file mode 100644 index 000000000..e964f7f3f --- /dev/null +++ b/pkg/storage/encrypted.go @@ -0,0 +1,134 @@ +package storage + +import ( + "context" + "crypto/cipher" + + "github.com/pomerium/pomerium/pkg/cryptutil" + "github.com/pomerium/pomerium/pkg/grpc/databroker" + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/types/known/anypb" + "google.golang.org/protobuf/types/known/wrapperspb" +) + +type encryptedBackend struct { + Backend + cipher cipher.AEAD +} + +// NewEncryptedBackend creates a new encrypted backend. +func NewEncryptedBackend(secret []byte, underlying Backend) (Backend, error) { + c, err := cryptutil.NewAEADCipher(secret) + if err != nil { + return nil, err + } + + return &encryptedBackend{ + Backend: underlying, + cipher: c, + }, nil +} + +func (e *encryptedBackend) Put(ctx context.Context, id string, data *anypb.Any) error { + encrypted, err := e.encrypt(data) + if err != nil { + return err + } + return e.Backend.Put(ctx, id, encrypted) +} + +func (e *encryptedBackend) Get(ctx context.Context, id string) (*databroker.Record, error) { + record, err := e.Backend.Get(ctx, id) + if err != nil { + return nil, err + } + record, err = e.decryptRecord(record) + if err != nil { + return nil, err + } + return record, nil +} + +func (e *encryptedBackend) GetAll(ctx context.Context) ([]*databroker.Record, error) { + records, err := e.Backend.GetAll(ctx) + if err != nil { + return nil, err + } + for i := range records { + records[i], err = e.decryptRecord(records[i]) + if err != nil { + return nil, err + } + } + return records, nil +} + +func (e *encryptedBackend) List(ctx context.Context, sinceVersion string) ([]*databroker.Record, error) { + records, err := e.Backend.List(ctx, sinceVersion) + if err != nil { + return nil, err + } + for i := range records { + records[i], err = e.decryptRecord(records[i]) + if err != nil { + return nil, err + } + } + return records, nil +} + +func (e *encryptedBackend) decryptRecord(in *databroker.Record) (out *databroker.Record, err error) { + data, err := e.decrypt(in.Data) + if err != nil { + return nil, err + } + // Create a new record so that we don't re-use any internal state + return &databroker.Record{ + Version: in.Version, + Type: data.TypeUrl, + Id: in.Id, + Data: data, + CreatedAt: in.CreatedAt, + ModifiedAt: in.ModifiedAt, + DeletedAt: in.DeletedAt, + }, nil +} + +func (e *encryptedBackend) decrypt(in *anypb.Any) (out *anypb.Any, err error) { + var encrypted wrapperspb.BytesValue + err = in.UnmarshalTo(&encrypted) + if err != nil { + return nil, err + } + + plaintext, err := cryptutil.Decrypt(e.cipher, encrypted.Value, nil) + if err != nil { + return nil, err + } + + out = new(anypb.Any) + err = proto.Unmarshal(plaintext, out) + if err != nil { + return nil, err + } + + return out, nil +} + +func (e *encryptedBackend) encrypt(in *anypb.Any) (out *anypb.Any, err error) { + plaintext, err := proto.Marshal(in) + if err != nil { + return nil, err + } + + encrypted := cryptutil.Encrypt(e.cipher, plaintext, nil) + + out, err = anypb.New(&wrapperspb.BytesValue{ + Value: encrypted, + }) + if err != nil { + return nil, err + } + + return out, nil +} diff --git a/pkg/storage/encrypted_test.go b/pkg/storage/encrypted_test.go new file mode 100644 index 000000000..5f7f6e51f --- /dev/null +++ b/pkg/storage/encrypted_test.go @@ -0,0 +1,100 @@ +package storage + +import ( + "context" + "errors" + "testing" + + "github.com/pomerium/pomerium/pkg/cryptutil" + "github.com/pomerium/pomerium/pkg/grpc/databroker" + "github.com/stretchr/testify/assert" + "google.golang.org/protobuf/types/known/anypb" + "google.golang.org/protobuf/types/known/wrapperspb" +) + +func TestEncryptedBackend(t *testing.T) { + + ctx := context.Background() + + m := map[string]*anypb.Any{} + backend := &mockBackend{ + put: func(ctx context.Context, id string, data *anypb.Any) error { + m[id] = data + return nil + }, + get: func(ctx context.Context, id string) (*databroker.Record, error) { + data, ok := m[id] + if !ok { + return nil, errors.New("not found") + } + return &databroker.Record{ + Id: id, + Data: data, + }, nil + }, + getAll: func(ctx context.Context) ([]*databroker.Record, error) { + var records []*databroker.Record + for id, data := range m { + records = append(records, &databroker.Record{ + Id: id, + Data: data, + }) + } + return records, nil + }, + list: func(ctx context.Context, sinceVersion string) ([]*databroker.Record, error) { + var records []*databroker.Record + for id, data := range m { + records = append(records, &databroker.Record{ + Id: id, + Data: data, + }) + } + return records, nil + }, + } + + e, err := NewEncryptedBackend(cryptutil.NewKey(), backend) + if !assert.NoError(t, err) { + return + } + + any, _ := anypb.New(wrapperspb.String("HELLO WORLD")) + + err = e.Put(ctx, "TEST-1", any) + if !assert.NoError(t, err) { + return + } + if assert.NotNil(t, m["TEST-1"], "key should be set") { + assert.NotEqual(t, any.TypeUrl, m["TEST-1"].TypeUrl, "encrypted data should be a bytes type") + assert.NotEqual(t, any.Value, m["TEST-1"].Value, "value should be encrypted") + } + + record, err := e.Get(ctx, "TEST-1") + if !assert.NoError(t, err) { + return + } + assert.Equal(t, any.TypeUrl, record.Data.TypeUrl, "type should be preserved") + assert.Equal(t, any.Value, record.Data.Value, "value should be preserved") + assert.Equal(t, any.TypeUrl, record.Type, "record type should be preserved") + + records, err := e.GetAll(ctx) + if !assert.NoError(t, err) { + return + } + if assert.Len(t, records, 1) { + assert.Equal(t, any.TypeUrl, records[0].Data.TypeUrl, "type should be preserved") + assert.Equal(t, any.Value, records[0].Data.Value, "value should be preserved") + assert.Equal(t, any.TypeUrl, records[0].Type, "record type should be preserved") + } + + records, err = e.List(ctx, "") + if !assert.NoError(t, err) { + return + } + if assert.Len(t, records, 1) { + assert.Equal(t, any.TypeUrl, records[0].Data.TypeUrl, "type should be preserved") + assert.Equal(t, any.Value, records[0].Data.Value, "value should be preserved") + assert.Equal(t, any.TypeUrl, records[0].Type, "record type should be preserved") + } +} diff --git a/pkg/storage/storage_test.go b/pkg/storage/storage_test.go new file mode 100644 index 000000000..a9953a7ba --- /dev/null +++ b/pkg/storage/storage_test.go @@ -0,0 +1,47 @@ +package storage + +import ( + "context" + "time" + + "github.com/pomerium/pomerium/pkg/grpc/databroker" + "google.golang.org/protobuf/types/known/anypb" +) + +type mockBackend struct { + put func(ctx context.Context, id string, data *anypb.Any) error + get func(ctx context.Context, id string) (*databroker.Record, error) + getAll func(ctx context.Context) ([]*databroker.Record, error) + list func(ctx context.Context, sinceVersion string) ([]*databroker.Record, error) + delete func(ctx context.Context, id string) error + clearDeleted func(ctx context.Context, cutoff time.Time) + watch func(ctx context.Context) chan struct{} +} + +func (m *mockBackend) Put(ctx context.Context, id string, data *anypb.Any) error { + return m.put(ctx, id, data) +} + +func (m *mockBackend) Get(ctx context.Context, id string) (*databroker.Record, error) { + return m.get(ctx, id) +} + +func (m *mockBackend) GetAll(ctx context.Context) ([]*databroker.Record, error) { + return m.getAll(ctx) +} + +func (m *mockBackend) List(ctx context.Context, sinceVersion string) ([]*databroker.Record, error) { + return m.list(ctx, sinceVersion) +} + +func (m *mockBackend) Delete(ctx context.Context, id string) error { + return m.delete(ctx, id) +} + +func (m *mockBackend) ClearDeleted(ctx context.Context, cutoff time.Time) { + m.clearDeleted(ctx, cutoff) +} + +func (m *mockBackend) Watch(ctx context.Context) chan struct{} { + return m.watch(ctx) +}