databroker: refactor databroker to sync all changes (#1879)

* refactor backend, implement encrypted store

* refactor in-memory store

* wip

* wip

* wip

* add syncer test

* fix redis expiry

* fix linting issues

* fix test by skipping non-config records

* fix backoff import

* fix init issues

* fix query

* wait for initial sync before starting directory sync

* add type to SyncLatest

* add more log messages, fix deadlock in in-memory store, always return server version from SyncLatest

* update sync types and tests

* add redis tests

* skip macos in github actions

* add comments to proto

* split getBackend into separate methods

* handle errors in initVersion

* return different error for not found vs other errors in get

* use exponential backoff for redis transaction retry

* rename raw to result

* use context instead of close channel

* store type urls as constants in databroker

* use timestampb instead of ptypes

* fix group merging not waiting

* change locked names

* update GetAll to return latest record version

* add method to grpcutil to get the type url for a protobuf type
This commit is contained in:
Caleb Doxsey 2021-02-18 15:24:33 -07:00 committed by GitHub
parent b1871b0f2e
commit 5d60cff21e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
66 changed files with 2762 additions and 2871 deletions

View file

@ -98,6 +98,7 @@ linters:
# - wsl
issues:
exclude-use-default: false
# List of regexps of issue texts to exclude, empty list by default.
# But independently from this option we use default exclude patterns,
# it can be disabled by `exclude-use-default: false`. To list all
@ -140,6 +141,10 @@ issues:
# good job Protobuffs!
- "method XXX"
- "SA1019"
# EXC0001 errcheck: Almost all programs ignore errors on these functions and in most cases it's ok
- Error return value of .((os\.)?std(out|err)\..*|.*Close|.*Flush|os\.Remove(All)?|.*print(f|ln)?|os\.(Un)?Setenv). is not checked
exclude-rules:
# https://github.com/go-critic/go-critic/issues/926
@ -179,6 +184,8 @@ issues:
text: "Potential hardcoded credentials"
linters:
- gosec
- linters: [golint]
text: "should have a package comment"
# golangci.com configuration
# https://github.com/golangci/golangci/wiki/Configuration

View file

@ -526,17 +526,17 @@ func (a *Authenticate) saveSessionToDataBroker(
if err != nil {
return fmt.Errorf("authenticate: error retrieving user info: %w", err)
}
_, err = user.Set(ctx, state.dataBrokerClient, mu.User)
_, err = user.Put(ctx, state.dataBrokerClient, mu.User)
if err != nil {
return fmt.Errorf("authenticate: error saving user: %w", err)
}
}
res, err := session.Set(ctx, state.dataBrokerClient, s)
res, err := session.Put(ctx, state.dataBrokerClient, s)
if err != nil {
return fmt.Errorf("authenticate: error saving session: %w", err)
}
sessionState.Version = sessions.Version(res.GetServerVersion())
sessionState.Version = sessions.Version(fmt.Sprint(res.GetServerVersion()))
_, err = state.directoryClient.RefreshUser(ctx, &directory.RefreshUserRequest{
UserId: s.UserId,

View file

@ -23,7 +23,6 @@ import (
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/types/known/emptypb"
"gopkg.in/square/go-jose.v2/jwt"
"github.com/pomerium/pomerium/config"
@ -166,7 +165,7 @@ func TestAuthenticate_SignIn(t *testing.T) {
return &databroker.GetResponse{
Record: &databroker.Record{
Version: "0001",
Version: 1,
Type: data.GetTypeUrl(),
Id: "SESSION_ID",
Data: data,
@ -246,9 +245,6 @@ func TestAuthenticate_SignOut(t *testing.T) {
encryptedEncoder: mock.Encoder{},
sharedEncoder: mock.Encoder{},
dataBrokerClient: mockDataBrokerServiceClient{
delete: func(ctx context.Context, in *databroker.DeleteRequest, opts ...grpc.CallOption) (*emptypb.Empty, error) {
return nil, nil
},
get: func(ctx context.Context, in *databroker.GetRequest, opts ...grpc.CallOption) (*databroker.GetResponse, error) {
data, err := ptypes.MarshalAny(&session.Session{
Id: "SESSION_ID",
@ -259,13 +255,16 @@ func TestAuthenticate_SignOut(t *testing.T) {
return &databroker.GetResponse{
Record: &databroker.Record{
Version: "0001",
Version: 1,
Type: data.GetTypeUrl(),
Id: "SESSION_ID",
Data: data,
},
}, nil
},
put: func(ctx context.Context, in *databroker.PutRequest, opts ...grpc.CallOption) (*databroker.PutResponse, error) {
return nil, nil
},
},
directoryClient: new(mockDirectoryServiceClient),
}),
@ -368,8 +367,8 @@ func TestAuthenticate_OAuthCallback(t *testing.T) {
get: func(ctx context.Context, in *databroker.GetRequest, opts ...grpc.CallOption) (*databroker.GetResponse, error) {
return nil, fmt.Errorf("not implemented")
},
set: func(ctx context.Context, in *databroker.SetRequest, opts ...grpc.CallOption) (*databroker.SetResponse, error) {
return &databroker.SetResponse{Record: &databroker.Record{Data: in.Data}}, nil
put: func(ctx context.Context, in *databroker.PutRequest, opts ...grpc.CallOption) (*databroker.PutResponse, error) {
return nil, nil
},
},
directoryClient: new(mockDirectoryServiceClient),
@ -514,7 +513,7 @@ func TestAuthenticate_SessionValidatorMiddleware(t *testing.T) {
return &databroker.GetResponse{
Record: &databroker.Record{
Version: "0001",
Version: 1,
Type: data.GetTypeUrl(),
Id: "SESSION_ID",
Data: data,
@ -633,7 +632,7 @@ func TestAuthenticate_userInfo(t *testing.T) {
return &databroker.GetResponse{
Record: &databroker.Record{
Version: "0001",
Version: 1,
Type: data.GetTypeUrl(),
Id: "SESSION_ID",
Data: data,
@ -672,21 +671,16 @@ func TestAuthenticate_userInfo(t *testing.T) {
type mockDataBrokerServiceClient struct {
databroker.DataBrokerServiceClient
delete func(ctx context.Context, in *databroker.DeleteRequest, opts ...grpc.CallOption) (*emptypb.Empty, error)
get func(ctx context.Context, in *databroker.GetRequest, opts ...grpc.CallOption) (*databroker.GetResponse, error)
set func(ctx context.Context, in *databroker.SetRequest, opts ...grpc.CallOption) (*databroker.SetResponse, error)
}
func (m mockDataBrokerServiceClient) Delete(ctx context.Context, in *databroker.DeleteRequest, opts ...grpc.CallOption) (*emptypb.Empty, error) {
return m.delete(ctx, in, opts...)
get func(ctx context.Context, in *databroker.GetRequest, opts ...grpc.CallOption) (*databroker.GetResponse, error)
put func(ctx context.Context, in *databroker.PutRequest, opts ...grpc.CallOption) (*databroker.PutResponse, error)
}
func (m mockDataBrokerServiceClient) Get(ctx context.Context, in *databroker.GetRequest, opts ...grpc.CallOption) (*databroker.GetResponse, error) {
return m.get(ctx, in, opts...)
}
func (m mockDataBrokerServiceClient) Set(ctx context.Context, in *databroker.SetRequest, opts ...grpc.CallOption) (*databroker.SetResponse, error) {
return m.set(ctx, in, opts...)
func (m mockDataBrokerServiceClient) Put(ctx context.Context, in *databroker.PutRequest, opts ...grpc.CallOption) (*databroker.PutResponse, error) {
return m.put(ctx, in, opts...)
}
type mockDirectoryServiceClient struct {
@ -729,7 +723,7 @@ func TestAuthenticate_SignOut_CSRF(t *testing.T) {
return &databroker.GetResponse{
Record: &databroker.Record{
Version: "0001",
Version: 1,
Type: data.GetTypeUrl(),
Id: "SESSION_ID",
Data: data,

View file

@ -24,19 +24,16 @@ type Authorize struct {
currentOptions *config.AtomicOptions
templates *template.Template
dataBrokerInitialSync map[string]chan struct{}
dataBrokerInitialSync chan struct{}
}
// New validates and creates a new Authorize service from a set of config options.
func New(cfg *config.Config) (*Authorize, error) {
a := Authorize{
currentOptions: config.NewAtomicOptions(),
store: evaluator.NewStore(),
templates: template.Must(frontend.NewTemplates()),
dataBrokerInitialSync: map[string]chan struct{}{
"type.googleapis.com/directory.Group": make(chan struct{}, 1),
"type.googleapis.com/directory.User": make(chan struct{}, 1),
},
currentOptions: config.NewAtomicOptions(),
store: evaluator.NewStore(),
templates: template.Must(frontend.NewTemplates()),
dataBrokerInitialSync: make(chan struct{}),
}
state, err := newAuthorizeStateFromConfig(cfg, a.store)
@ -48,6 +45,22 @@ func New(cfg *config.Config) (*Authorize, error) {
return &a, nil
}
// Run runs the authorize service.
func (a *Authorize) Run(ctx context.Context) error {
return newDataBrokerSyncer(a).Run(ctx)
}
// WaitForInitialSync blocks until the initial sync is complete.
func (a *Authorize) WaitForInitialSync(ctx context.Context) error {
select {
case <-ctx.Done():
return ctx.Err()
case <-a.dataBrokerInitialSync:
}
log.Info().Msg("initial sync from databroker complete")
return nil
}
func validateOptions(o *config.Options) error {
if _, err := cryptutil.NewAEADCipherFromBase64(o.SharedKey); err != nil {
return fmt.Errorf("authorize: bad 'SHARED_SECRET': %w", err)

View file

@ -115,7 +115,7 @@ func TestEvaluator_Evaluate(t *testing.T) {
},
})
store.UpdateRecord(&databroker.Record{
Version: "1",
Version: 1,
Type: "type.googleapis.com/session.Session",
Id: sessionID,
Data: data,
@ -126,7 +126,7 @@ func TestEvaluator_Evaluate(t *testing.T) {
Email: "foo@example.com",
})
store.UpdateRecord(&databroker.Record{
Version: "1",
Version: 1,
Type: "type.googleapis.com/user.User",
Id: userID,
Data: data,
@ -188,7 +188,7 @@ func BenchmarkEvaluator_Evaluate(b *testing.B) {
},
})
store.UpdateRecord(&databroker.Record{
Version: fmt.Sprint(i),
Version: uint64(i),
Type: "type.googleapis.com/session.Session",
Id: sessionID,
Data: data,
@ -198,7 +198,7 @@ func BenchmarkEvaluator_Evaluate(b *testing.B) {
Id: userID,
})
store.UpdateRecord(&databroker.Record{
Version: fmt.Sprint(i),
Version: uint64(i),
Type: "type.googleapis.com/user.User",
Id: userID,
Data: data,

View file

@ -15,6 +15,7 @@ import (
"github.com/pomerium/pomerium/config"
"github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/pkg/cryptutil"
"github.com/pomerium/pomerium/pkg/grpc/databroker"
)
@ -40,9 +41,8 @@ func NewStoreFromProtos(msgs ...proto.Message) *Store {
}
record := new(databroker.Record)
record.CreatedAt = timestamppb.Now()
record.ModifiedAt = timestamppb.Now()
record.Version = uuid.New().String()
record.Version = cryptutil.NewRandomUInt64()
record.Id = uuid.New().String()
record.Data = any
record.Type = any.TypeUrl
@ -56,8 +56,8 @@ func NewStoreFromProtos(msgs ...proto.Message) *Store {
}
// ClearRecords removes all the records from the store.
func (s *Store) ClearRecords(typeURL string) {
rawPath := fmt.Sprintf("/databroker_data/%s", typeURL)
func (s *Store) ClearRecords() {
rawPath := "/databroker_data"
s.delete(rawPath)
}

View file

@ -27,7 +27,7 @@ func TestStore(t *testing.T) {
}
any, _ := ptypes.MarshalAny(u)
s.UpdateRecord(&databroker.Record{
Version: "v1",
Version: 1,
Type: any.GetTypeUrl(),
Id: u.GetId(),
Data: any,
@ -43,7 +43,7 @@ func TestStore(t *testing.T) {
}, v)
s.UpdateRecord(&databroker.Record{
Version: "v2",
Version: 2,
Type: any.GetTypeUrl(),
Id: u.GetId(),
Data: any,
@ -55,7 +55,7 @@ func TestStore(t *testing.T) {
assert.Nil(t, v)
s.UpdateRecord(&databroker.Record{
Version: "v1",
Version: 3,
Type: any.GetTypeUrl(),
Id: u.GetId(),
Data: any,
@ -65,7 +65,7 @@ func TestStore(t *testing.T) {
assert.NoError(t, err)
assert.NotNil(t, v)
s.ClearRecords("type.googleapis.com/user.User")
s.ClearRecords()
v, err = storage.ReadOne(ctx, s.opaStore, storage.MustParsePath("/databroker_data/type.googleapis.com/user.User/u1"))
assert.Error(t, err)
assert.Nil(t, v)

View file

@ -22,16 +22,11 @@ import (
"github.com/pomerium/pomerium/pkg/grpc/databroker"
"github.com/pomerium/pomerium/pkg/grpc/session"
"github.com/pomerium/pomerium/pkg/grpc/user"
"github.com/pomerium/pomerium/pkg/grpcutil"
envoy_service_auth_v2 "github.com/envoyproxy/go-control-plane/envoy/service/auth/v2"
)
const (
serviceAccountTypeURL = "type.googleapis.com/user.ServiceAccount"
sessionTypeURL = "type.googleapis.com/session.Session"
userTypeURL = "type.googleapis.com/user.User"
)
// Check implements the envoy auth server gRPC endpoint.
func (a *Authorize) Check(ctx context.Context, in *envoy_service_auth_v2.CheckRequest) (*envoy_service_auth_v2.CheckResponse, error) {
ctx, span := trace.StartSpan(ctx, "authorize.grpc.Check")
@ -108,18 +103,18 @@ func (a *Authorize) forceSyncSession(ctx context.Context, sessionID string) inte
state := a.state.Load()
s, ok := a.store.GetRecordData(sessionTypeURL, sessionID).(*session.Session)
s, ok := a.store.GetRecordData(grpcutil.GetTypeURL(new(session.Session)), sessionID).(*session.Session)
if ok {
return s
}
sa, ok := a.store.GetRecordData(serviceAccountTypeURL, sessionID).(*user.ServiceAccount)
sa, ok := a.store.GetRecordData(grpcutil.GetTypeURL(new(user.ServiceAccount)), sessionID).(*user.ServiceAccount)
if ok {
return sa
}
res, err := state.dataBrokerClient.Get(ctx, &databroker.GetRequest{
Type: sessionTypeURL,
Type: grpcutil.GetTypeURL(new(session.Session)),
Id: sessionID,
})
if err != nil {
@ -127,10 +122,10 @@ func (a *Authorize) forceSyncSession(ctx context.Context, sessionID string) inte
return nil
}
if current := a.store.GetRecordData(sessionTypeURL, sessionID); current == nil {
if current := a.store.GetRecordData(grpcutil.GetTypeURL(new(session.Session)), sessionID); current == nil {
a.store.UpdateRecord(res.GetRecord())
}
s, _ = a.store.GetRecordData(sessionTypeURL, sessionID).(*session.Session)
s, _ = a.store.GetRecordData(grpcutil.GetTypeURL(new(session.Session)), sessionID).(*session.Session)
return s
}
@ -141,13 +136,13 @@ func (a *Authorize) forceSyncUser(ctx context.Context, userID string) *user.User
state := a.state.Load()
u, ok := a.store.GetRecordData(userTypeURL, userID).(*user.User)
u, ok := a.store.GetRecordData(grpcutil.GetTypeURL(new(user.User)), userID).(*user.User)
if ok {
return u
}
res, err := state.dataBrokerClient.Get(ctx, &databroker.GetRequest{
Type: userTypeURL,
Type: grpcutil.GetTypeURL(new(user.User)),
Id: userID,
})
if err != nil {
@ -155,10 +150,10 @@ func (a *Authorize) forceSyncUser(ctx context.Context, userID string) *user.User
return nil
}
if current := a.store.GetRecordData(userTypeURL, userID); current == nil {
if current := a.store.GetRecordData(grpcutil.GetTypeURL(new(user.User)), userID); current == nil {
a.store.UpdateRecord(res.GetRecord())
}
u, _ = a.store.GetRecordData(userTypeURL, userID).(*user.User)
u, _ = a.store.GetRecordData(grpcutil.GetTypeURL(new(user.User)), userID).(*user.User)
return u
}

View file

@ -325,7 +325,7 @@ func TestSync(t *testing.T) {
})
return &databroker.GetResponse{
Record: &databroker.Record{
Version: "0001",
Version: 1,
Type: data.GetTypeUrl(),
Id: in.GetId(),
Data: data,
@ -336,7 +336,7 @@ func TestSync(t *testing.T) {
data, _ := ptypes.MarshalAny(&user.User{Id: in.GetId()})
return &databroker.GetResponse{
Record: &databroker.Record{
Version: "0001",
Version: 1,
Type: data.GetTypeUrl(),
Id: in.GetId(),
Data: data,
@ -391,7 +391,7 @@ func TestSync(t *testing.T) {
}
return &databroker.GetResponse{
Record: &databroker.Record{
Version: "0001",
Version: 1,
Type: data.GetTypeUrl(),
Id: in.GetId(),
Data: data,
@ -418,7 +418,7 @@ func TestSync(t *testing.T) {
})
return &databroker.GetResponse{
Record: &databroker.Record{
Version: "0001",
Version: 1,
Type: data.GetTypeUrl(),
Id: in.GetId(),
Data: data,

View file

@ -1,198 +0,0 @@
package authorize
import (
"context"
"io"
"time"
backoff "github.com/cenkalti/backoff/v4"
"golang.org/x/sync/errgroup"
"google.golang.org/protobuf/types/known/emptypb"
"github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/telemetry/trace"
"github.com/pomerium/pomerium/pkg/grpc/databroker"
)
// Run runs the authorize server.
func (a *Authorize) Run(ctx context.Context) error {
eg, ctx := errgroup.WithContext(ctx)
updateTypes := make(chan []string)
eg.Go(func() error {
return a.runTypesSyncer(ctx, updateTypes)
})
eg.Go(func() error {
return a.runDataSyncer(ctx, updateTypes)
})
return eg.Wait()
}
// WaitForInitialSync waits for the initial sync to complete.
func (a *Authorize) WaitForInitialSync(ctx context.Context) error {
for typeURL, ch := range a.dataBrokerInitialSync {
log.Info().Str("type_url", typeURL).Msg("waiting for initial sync")
select {
case <-ch:
case <-ctx.Done():
return ctx.Err()
}
log.Info().Str("type_url", typeURL).Msg("initial sync complete")
}
return nil
}
func (a *Authorize) runTypesSyncer(ctx context.Context, updateTypes chan<- []string) error {
log.Info().Msg("starting type sync")
return tryForever(ctx, func(backoff interface{ Reset() }) error {
ctx, span := trace.StartSpan(ctx, "authorize.dataBrokerClient.Sync")
defer span.End()
stream, err := a.state.Load().dataBrokerClient.SyncTypes(ctx, new(emptypb.Empty))
if err != nil {
return err
}
for {
res, err := stream.Recv()
if err == io.EOF {
return nil
} else if err != nil {
return err
}
backoff.Reset()
select {
case <-stream.Context().Done():
return stream.Context().Err()
case updateTypes <- res.GetTypes():
}
}
})
}
func (a *Authorize) runDataSyncer(ctx context.Context, updateTypes <-chan []string) error {
eg, ctx := errgroup.WithContext(ctx)
eg.Go(func() error {
seen := map[string]struct{}{}
for {
select {
case <-ctx.Done():
return ctx.Err()
case types := <-updateTypes:
for _, dataType := range types {
dataType := dataType
if _, ok := seen[dataType]; !ok {
eg.Go(func() error {
return a.runDataTypeSyncer(ctx, dataType)
})
seen[dataType] = struct{}{}
}
}
}
}
})
return eg.Wait()
}
func (a *Authorize) runDataTypeSyncer(ctx context.Context, typeURL string) error {
var serverVersion, recordVersion string
log.Info().Str("type_url", typeURL).Msg("starting data initial load")
ctx, span := trace.StartSpan(ctx, "authorize.dataBrokerClient.InitialSync")
backoff := backoff.NewExponentialBackOff()
backoff.MaxElapsedTime = 0
for {
res, err := databroker.InitialSync(ctx, a.state.Load().dataBrokerClient, &databroker.SyncRequest{
Type: typeURL,
})
if err != nil {
log.Warn().Err(err).Str("type_url", typeURL).Msg("error getting data")
select {
case <-ctx.Done():
return ctx.Err()
case <-time.After(backoff.NextBackOff()):
}
continue
}
serverVersion = res.GetServerVersion()
for _, record := range res.GetRecords() {
a.store.UpdateRecord(record)
recordVersion = record.GetVersion()
}
break
}
span.End()
if ch, ok := a.dataBrokerInitialSync[typeURL]; ok {
select {
case ch <- struct{}{}:
default:
}
}
log.Info().Str("type_url", typeURL).Msg("starting data syncer")
return tryForever(ctx, func(backoff interface{ Reset() }) error {
ctx, span := trace.StartSpan(ctx, "authorize.dataBrokerClient.Sync")
defer span.End()
stream, err := a.state.Load().dataBrokerClient.Sync(ctx, &databroker.SyncRequest{
ServerVersion: serverVersion,
RecordVersion: recordVersion,
Type: typeURL,
})
if err != nil {
return err
}
for {
res, err := stream.Recv()
if err == io.EOF {
return nil
} else if err != nil {
return err
}
backoff.Reset()
if res.GetServerVersion() != serverVersion {
log.Info().
Str("old_version", serverVersion).
Str("new_version", res.GetServerVersion()).
Str("type_url", typeURL).
Msg("detected new server version, clearing data")
serverVersion = res.GetServerVersion()
recordVersion = ""
a.store.ClearRecords(typeURL)
}
for _, record := range res.GetRecords() {
if record.GetVersion() > recordVersion {
recordVersion = record.GetVersion()
}
}
for _, record := range res.GetRecords() {
a.store.UpdateRecord(record)
}
}
})
}
func tryForever(ctx context.Context, callback func(onSuccess interface{ Reset() }) error) error {
backoff := backoff.NewExponentialBackOff()
backoff.MaxElapsedTime = 0
for {
err := callback(backoff)
if err != nil {
log.Warn().Err(err).Msg("sync error")
}
select {
case <-ctx.Done():
return ctx.Err()
case <-time.After(backoff.NextBackOff()):
}
}
}

41
authorize/sync.go Normal file
View file

@ -0,0 +1,41 @@
package authorize
import (
"context"
"sync"
"github.com/pomerium/pomerium/pkg/grpc/databroker"
)
type dataBrokerSyncer struct {
*databroker.Syncer
authorize *Authorize
signalOnce sync.Once
}
func newDataBrokerSyncer(authorize *Authorize) *dataBrokerSyncer {
syncer := &dataBrokerSyncer{
authorize: authorize,
}
syncer.Syncer = databroker.NewSyncer(syncer)
return syncer
}
func (syncer *dataBrokerSyncer) GetDataBrokerServiceClient() databroker.DataBrokerServiceClient {
return syncer.authorize.state.Load().dataBrokerClient
}
func (syncer *dataBrokerSyncer) ClearRecords(ctx context.Context) {
syncer.authorize.store.ClearRecords()
}
func (syncer *dataBrokerSyncer) UpdateRecords(ctx context.Context, records []*databroker.Record) {
for _, record := range records {
syncer.authorize.store.UpdateRecord(record)
}
// the first time we update records we signal the initial sync
syncer.signalOnce.Do(func() {
close(syncer.authorize.dataBrokerInitialSync)
})
}

View file

@ -43,7 +43,7 @@ func loadCachedCredential(serverURL string) *ExecCredential {
if err != nil {
return nil
}
defer f.Close()
defer func() { _ = f.Close() }()
var creds ExecCredential
err = json.NewDecoder(f).Decode(&creds)

View file

@ -1,3 +1,4 @@
// Package main implements the pomerium-cli.
package main
import (

View file

@ -109,9 +109,9 @@ func decodeJWTClaimHeadersHookFunc() mapstructure.DecodeHookFunc {
// A StringSlice is a slice of strings.
type StringSlice []string
// NewStringSlice creatse a new StringSlice.
// NewStringSlice creates a new StringSlice.
func NewStringSlice(values ...string) StringSlice {
return StringSlice(values)
return values
}
const (
@ -197,7 +197,7 @@ type WeightedURL struct {
LbWeight uint32
}
// Validate validates the WeightedURL.
// Validate validates that the WeightedURL is valid.
func (u *WeightedURL) Validate() error {
if u.URL.Hostname() == "" {
return errHostnameMustBeSpecified
@ -227,6 +227,7 @@ func ParseWeightedURL(dst string) (*WeightedURL, error) {
return &WeightedURL{*u, w}, nil
}
// String returns the WeightedURL as a string.
func (u *WeightedURL) String() string {
str := u.URL.String()
if u.LbWeight == 0 {
@ -235,7 +236,7 @@ func (u *WeightedURL) String() string {
return fmt.Sprintf("{url=%s, weight=%d}", str, u.LbWeight)
}
// WeightedURLs is a slice of WeightedURL.
// WeightedURLs is a slice of WeightedURLs.
type WeightedURLs []WeightedURL
// ParseWeightedUrls parses
@ -285,9 +286,9 @@ func (urls WeightedURLs) Validate() (HasWeight, error) {
}
if noWeight {
return HasWeight(false), nil
return false, nil
}
return HasWeight(true), nil
return true, nil
}
// Flatten converts weighted url array into indidual arrays of urls and weights
@ -311,7 +312,7 @@ func (urls WeightedURLs) Flatten() ([]string, []uint32, error) {
return str, wghts, nil
}
// DecodePolicyBase64Hook creates a mapstructure DecodeHookFunc.
// DecodePolicyBase64Hook returns a mapstructure decode hook for base64 data.
func DecodePolicyBase64Hook() mapstructure.DecodeHookFunc {
return func(f, t reflect.Type, data interface{}) (interface{}, error) {
if t != reflect.TypeOf([]Policy{}) {
@ -332,7 +333,7 @@ func DecodePolicyBase64Hook() mapstructure.DecodeHookFunc {
return nil, fmt.Errorf("base64 decoding policy data: %w", err)
}
out := []map[interface{}]interface{}{}
var out []map[interface{}]interface{}
if err = yaml.Unmarshal(bytes, &out); err != nil {
return nil, fmt.Errorf("parsing base64-encoded policy data as yaml: %w", err)
}
@ -341,7 +342,7 @@ func DecodePolicyBase64Hook() mapstructure.DecodeHookFunc {
}
}
// DecodePolicyHookFunc creates a mapstructure DecodeHookFunc.
// DecodePolicyHookFunc returns a Decode Hook for mapstructure.
func DecodePolicyHookFunc() mapstructure.DecodeHookFunc {
return func(f, t reflect.Type, data interface{}) (interface{}, error) {
if t != reflect.TypeOf(Policy{}) {

View file

@ -1,3 +1,4 @@
// Package databroker contains the databroker service.
package databroker
import (
@ -5,8 +6,6 @@ import (
"encoding/base64"
"sync/atomic"
"github.com/golang/protobuf/ptypes/empty"
"github.com/pomerium/pomerium/config"
"github.com/pomerium/pomerium/internal/databroker"
databrokerpb "github.com/pomerium/pomerium/pkg/grpc/databroker"
@ -52,13 +51,6 @@ func (srv *dataBrokerServer) setKey(cfg *config.Config) {
srv.sharedKey.Store(bs)
}
func (srv *dataBrokerServer) Delete(ctx context.Context, req *databrokerpb.DeleteRequest) (*empty.Empty, error) {
if err := grpcutil.RequireSignedJWT(ctx, srv.sharedKey.Load().([]byte)); err != nil {
return nil, err
}
return srv.server.Delete(ctx, req)
}
func (srv *dataBrokerServer) Get(ctx context.Context, req *databrokerpb.GetRequest) (*databrokerpb.GetResponse, error) {
if err := grpcutil.RequireSignedJWT(ctx, srv.sharedKey.Load().([]byte)); err != nil {
return nil, err
@ -66,13 +58,6 @@ func (srv *dataBrokerServer) Get(ctx context.Context, req *databrokerpb.GetReque
return srv.server.Get(ctx, req)
}
func (srv *dataBrokerServer) GetAll(ctx context.Context, req *databrokerpb.GetAllRequest) (*databrokerpb.GetAllResponse, error) {
if err := grpcutil.RequireSignedJWT(ctx, srv.sharedKey.Load().([]byte)); err != nil {
return nil, err
}
return srv.server.GetAll(ctx, req)
}
func (srv *dataBrokerServer) Query(ctx context.Context, req *databrokerpb.QueryRequest) (*databrokerpb.QueryResponse, error) {
if err := grpcutil.RequireSignedJWT(ctx, srv.sharedKey.Load().([]byte)); err != nil {
return nil, err
@ -80,11 +65,11 @@ func (srv *dataBrokerServer) Query(ctx context.Context, req *databrokerpb.QueryR
return srv.server.Query(ctx, req)
}
func (srv *dataBrokerServer) Set(ctx context.Context, req *databrokerpb.SetRequest) (*databrokerpb.SetResponse, error) {
func (srv *dataBrokerServer) Put(ctx context.Context, req *databrokerpb.PutRequest) (*databrokerpb.PutResponse, error) {
if err := grpcutil.RequireSignedJWT(ctx, srv.sharedKey.Load().([]byte)); err != nil {
return nil, err
}
return srv.server.Set(ctx, req)
return srv.server.Put(ctx, req)
}
func (srv *dataBrokerServer) Sync(req *databrokerpb.SyncRequest, stream databrokerpb.DataBrokerService_SyncServer) error {
@ -94,16 +79,9 @@ func (srv *dataBrokerServer) Sync(req *databrokerpb.SyncRequest, stream databrok
return srv.server.Sync(req, stream)
}
func (srv *dataBrokerServer) GetTypes(ctx context.Context, req *empty.Empty) (*databrokerpb.GetTypesResponse, error) {
if err := grpcutil.RequireSignedJWT(ctx, srv.sharedKey.Load().([]byte)); err != nil {
return nil, err
}
return srv.server.GetTypes(ctx, req)
}
func (srv *dataBrokerServer) SyncTypes(req *empty.Empty, stream databrokerpb.DataBrokerService_SyncTypesServer) error {
func (srv *dataBrokerServer) SyncLatest(req *databrokerpb.SyncLatestRequest, stream databrokerpb.DataBrokerService_SyncLatestServer) error {
if err := grpcutil.RequireSignedJWT(stream.Context(), srv.sharedKey.Load().([]byte)); err != nil {
return err
}
return srv.server.SyncTypes(req, stream)
return srv.server.SyncLatest(req, stream)
}

View file

@ -7,9 +7,10 @@ import (
"testing"
"github.com/golang/protobuf/ptypes"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"google.golang.org/grpc/test/bufconn"
internal_databroker "github.com/pomerium/pomerium/internal/databroker"
@ -51,44 +52,44 @@ func TestServerSync(t *testing.T) {
any, _ := ptypes.MarshalAny(new(user.User))
numRecords := 200
var serverVersion uint64
for i := 0; i < numRecords; i++ {
c.Set(ctx, &databroker.SetRequest{Type: any.TypeUrl, Id: strconv.Itoa(i), Data: any})
res, err := c.Put(ctx, &databroker.PutRequest{
Record: &databroker.Record{
Type: any.TypeUrl,
Id: strconv.Itoa(i),
Data: any,
},
})
require.NoError(t, err)
serverVersion = res.GetServerVersion()
}
t.Run("Sync ok", func(t *testing.T) {
client, _ := c.Sync(ctx, &databroker.SyncRequest{Type: any.GetTypeUrl()})
client, _ := c.Sync(ctx, &databroker.SyncRequest{
ServerVersion: serverVersion,
})
count := 0
for {
res, err := client.Recv()
_, err := client.Recv()
if err != nil {
break
}
count += len(res.Records)
count++
if count == numRecords {
break
}
}
})
t.Run("Error occurred while syncing", func(t *testing.T) {
ctx, cancelFunc := context.WithCancel(ctx)
defer cancelFunc()
client, _ := c.Sync(ctx, &databroker.SyncRequest{Type: any.GetTypeUrl()})
count := 0
numRecordsWanted := 100
cancelFuncCalled := false
for {
res, err := client.Recv()
if err != nil {
assert.True(t, cancelFuncCalled)
break
}
count += len(res.Records)
if count == numRecordsWanted {
cancelFunc()
cancelFuncCalled = true
}
}
t.Run("Aborted", func(t *testing.T) {
client, err := c.Sync(ctx, &databroker.SyncRequest{
ServerVersion: 0,
})
require.NoError(t, err)
_, err = client.Recv()
require.Error(t, err)
require.Equal(t, codes.Aborted, status.Code(err))
})
}
@ -104,19 +105,25 @@ func BenchmarkSync(b *testing.B) {
numRecords := 10000
for i := 0; i < numRecords; i++ {
c.Set(ctx, &databroker.SetRequest{Type: any.TypeUrl, Id: strconv.Itoa(i), Data: any})
_, _ = c.Put(ctx, &databroker.PutRequest{
Record: &databroker.Record{
Type: any.TypeUrl,
Id: strconv.Itoa(i),
Data: any,
},
})
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
client, _ := c.Sync(ctx, &databroker.SyncRequest{Type: any.GetTypeUrl()})
client, _ := c.Sync(ctx, &databroker.SyncRequest{})
count := 0
for {
res, err := client.Recv()
_, err := client.Recv()
if err != nil {
break
}
count += len(res.Records)
count++
if count == numRecords {
break
}

View file

@ -31,10 +31,12 @@ func (c *DataBroker) RefreshUser(ctx context.Context, req *directory.RefreshUser
return nil, err
}
_, err = c.dataBrokerServer.Set(ctx, &databroker.SetRequest{
Type: any.GetTypeUrl(),
Id: u.GetId(),
Data: any,
_, err = c.dataBrokerServer.Put(ctx, &databroker.PutRequest{
Record: &databroker.Record{
Type: any.GetTypeUrl(),
Id: u.GetId(),
Data: any,
},
})
if err != nil {
return nil, err

View file

@ -61,6 +61,7 @@ func main() {
}
}
// Result is a result used for rendering.
type Result struct {
Headers map[string]string `json:"headers"`
Method string `json:"method"`

View file

@ -1,3 +1,4 @@
// Package filemgr defines a Manager for managing files for the controlplane.
package filemgr
import (

View file

@ -13,8 +13,6 @@ var (
// DefaultDeletePermanentlyAfter is the default amount of time to wait before deleting
// a record permanently.
DefaultDeletePermanentlyAfter = time.Hour
// DefaultBTreeDegree is the default number of items to store in each node of the BTree.
DefaultBTreeDegree = 8
// DefaultStorageType is the default storage type that Server use
DefaultStorageType = "memory"
// DefaultGetAllPageSize is the default page size for GetAll calls.
@ -23,7 +21,6 @@ var (
type serverConfig struct {
deletePermanentlyAfter time.Duration
btreeDegree int
secret []byte
storageType string
storageConnectionString string
@ -36,7 +33,6 @@ type serverConfig struct {
func newServerConfig(options ...ServerOption) *serverConfig {
cfg := new(serverConfig)
WithDeletePermanentlyAfter(DefaultDeletePermanentlyAfter)(cfg)
WithBTreeDegree(DefaultBTreeDegree)(cfg)
WithStorageType(DefaultStorageType)(cfg)
WithGetAllPageSize(DefaultGetAllPageSize)(cfg)
for _, option := range options {
@ -48,13 +44,6 @@ func newServerConfig(options ...ServerOption) *serverConfig {
// A ServerOption customizes the server.
type ServerOption func(*serverConfig)
// WithBTreeDegree sets the number of items to store in each node of the BTree.
func WithBTreeDegree(degree int) ServerOption {
return func(cfg *serverConfig) {
cfg.btreeDegree = degree
}
}
// WithDeletePermanentlyAfter sets the deletePermanentlyAfter duration.
// If a record is deleted via Delete, it will be permanently deleted after
// the given duration.

View file

@ -3,12 +3,7 @@ package databroker
import (
"context"
"encoding/base64"
"errors"
"sync"
"time"
"github.com/cenkalti/backoff/v4"
"github.com/golang/protobuf/ptypes"
"github.com/pomerium/pomerium/config"
"github.com/pomerium/pomerium/internal/hashutil"
@ -17,15 +12,9 @@ import (
"github.com/pomerium/pomerium/pkg/grpc"
configpb "github.com/pomerium/pomerium/pkg/grpc/config"
"github.com/pomerium/pomerium/pkg/grpc/databroker"
"github.com/pomerium/pomerium/pkg/grpcutil"
)
var configTypeURL string
func init() {
any, _ := ptypes.MarshalAny(new(configpb.Config))
configTypeURL = any.GetTypeUrl()
}
// ConfigSource provides a new Config source that decorates an underlying config with
// configuration derived from the data broker.
type ConfigSource struct {
@ -35,8 +24,6 @@ type ConfigSource struct {
dbConfigs map[string]*configpb.Config
updaterHash uint64
cancel func()
serverVersion string
recordVersion string
config.ChangeDispatcher
}
@ -188,78 +175,51 @@ func (src *ConfigSource) runUpdater(cfg *config.Config) {
ctx := context.Background()
ctx, src.cancel = context.WithCancel(ctx)
go tryForever(ctx, func(onSuccess func()) error {
src.mu.Lock()
serverVersion, recordVersion := src.serverVersion, src.recordVersion
src.mu.Unlock()
stream, err := client.Sync(ctx, &databroker.SyncRequest{
Type: configTypeURL,
ServerVersion: serverVersion,
RecordVersion: recordVersion,
})
if err != nil {
return err
}
for {
res, err := stream.Recv()
if err != nil {
return err
}
onSuccess()
if len(res.GetRecords()) > 0 {
src.onSync(res.GetRecords())
for _, record := range res.GetRecords() {
recordVersion = record.GetVersion()
}
}
src.mu.Lock()
src.serverVersion, src.recordVersion = res.GetServerVersion(), recordVersion
src.mu.Unlock()
}
})
syncer := databroker.NewSyncer(&syncerHandler{
client: client,
src: src,
}, databroker.WithTypeURL(grpcutil.GetTypeURL(new(configpb.Config))))
go func() { _ = syncer.Run(ctx) }()
}
func (src *ConfigSource) onSync(records []*databroker.Record) {
src.mu.Lock()
// syncerHandler adapts a ConfigSource to the databroker syncer: it receives
// record updates from the syncer and applies them to the source's config map.
type syncerHandler struct {
	src *ConfigSource
	client databroker.DataBrokerServiceClient
}

// GetDataBrokerServiceClient returns the databroker client the syncer should use.
func (s *syncerHandler) GetDataBrokerServiceClient() databroker.DataBrokerServiceClient {
	return s.client
}

// ClearRecords drops every stored config record under the source's lock.
// NOTE(review): presumably invoked by the syncer when it must resync from
// scratch (e.g. server version change) — confirm against databroker.NewSyncer.
func (s *syncerHandler) ClearRecords(ctx context.Context) {
	s.src.mu.Lock()
	s.src.dbConfigs = map[string]*configpb.Config{}
	s.src.mu.Unlock()
}
func (s *syncerHandler) UpdateRecords(ctx context.Context, records []*databroker.Record) {
if len(records) == 0 {
return
}
s.src.mu.Lock()
for _, record := range records {
if record.GetDeletedAt() != nil {
delete(src.dbConfigs, record.GetId())
delete(s.src.dbConfigs, record.GetId())
continue
}
var cfgpb configpb.Config
err := ptypes.UnmarshalAny(record.GetData(), &cfgpb)
err := record.GetData().UnmarshalTo(&cfgpb)
if err != nil {
log.Warn().Err(err).Msg("databroker: error decoding config")
delete(src.dbConfigs, record.GetId())
delete(s.src.dbConfigs, record.GetId())
continue
}
src.dbConfigs[record.GetId()] = &cfgpb
s.src.dbConfigs[record.GetId()] = &cfgpb
}
src.mu.Unlock()
s.src.mu.Unlock()
src.rebuild(false)
}
func tryForever(ctx context.Context, callback func(onSuccess func()) error) {
bo := backoff.NewExponentialBackOff()
bo.MaxElapsedTime = 0
for {
err := callback(bo.Reset)
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
return
} else if err != nil {
log.Warn().Err(err).Msg("sync error")
}
select {
case <-ctx.Done():
return
case <-time.After(bo.NextBackOff()):
}
}
s.src.rebuild(false)
}

View file

@ -7,9 +7,9 @@ import (
"testing"
"time"
"github.com/golang/protobuf/ptypes"
"github.com/stretchr/testify/assert"
"google.golang.org/grpc"
"google.golang.org/protobuf/types/known/anypb"
"github.com/pomerium/pomerium/config"
configpb "github.com/pomerium/pomerium/pkg/grpc/config"
@ -24,7 +24,7 @@ func TestConfigSource(t *testing.T) {
if !assert.NoError(t, err) {
return
}
defer li.Close()
defer func() { _ = li.Close() }()
dataBrokerServer := New()
srv := grpc.NewServer()
@ -45,7 +45,7 @@ func TestConfigSource(t *testing.T) {
})
cfgs <- src.GetConfig()
data, _ := ptypes.MarshalAny(&configpb.Config{
data, _ := anypb.New(&configpb.Config{
Name: "config",
Routes: []*configpb.Route{
{
@ -54,10 +54,12 @@ func TestConfigSource(t *testing.T) {
},
},
})
_, _ = dataBrokerServer.Set(ctx, &databroker.SetRequest{
Type: configTypeURL,
Id: "1",
Data: data,
_, _ = dataBrokerServer.Put(ctx, &databroker.PutRequest{
Record: &databroker.Record{
Type: data.TypeUrl,
Id: "1",
Data: data,
},
})
select {

View file

@ -4,28 +4,24 @@ package databroker
import (
"context"
"crypto/tls"
"errors"
"fmt"
"reflect"
"sort"
"strings"
"sync"
"time"
"github.com/golang/protobuf/ptypes"
"github.com/golang/protobuf/ptypes/empty"
"github.com/google/go-cmp/cmp"
"github.com/google/uuid"
"github.com/rs/zerolog"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/types/known/emptypb"
"google.golang.org/protobuf/types/known/anypb"
"google.golang.org/protobuf/types/known/wrapperspb"
"github.com/pomerium/pomerium/config"
"github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/signal"
"github.com/pomerium/pomerium/internal/telemetry/trace"
"github.com/pomerium/pomerium/pkg/cryptutil"
"github.com/pomerium/pomerium/pkg/grpc/databroker"
"github.com/pomerium/pomerium/pkg/grpcutil"
"github.com/pomerium/pomerium/pkg/storage"
"github.com/pomerium/pomerium/pkg/storage/inmemory"
"github.com/pomerium/pomerium/pkg/storage/redis"
@ -34,81 +30,57 @@ import (
const (
recordTypeServerVersion = "server_version"
serverVersionKey = "version"
syncBatchSize = 100
)
// newUUID returns a new UUID. This make it easy to stub out in tests.
var newUUID = uuid.New
// Server implements the databroker service using an in memory database.
type Server struct {
cfg *serverConfig
log zerolog.Logger
mu sync.RWMutex
version string
byType map[string]storage.Backend
onTypechange *signal.Signal
mu sync.RWMutex
version uint64
backend storage.Backend
}
// New creates a new server.
func New(options ...ServerOption) *Server {
srv := &Server{
log: log.With().Str("service", "databroker").Logger(),
byType: make(map[string]storage.Backend),
onTypechange: signal.New(),
}
srv.UpdateConfig(options...)
go func() {
ticker := time.NewTicker(time.Minute)
defer ticker.Stop()
for range ticker.C {
srv.mu.RLock()
tm := time.Now().Add(-srv.cfg.deletePermanentlyAfter)
srv.mu.RUnlock()
var recordTypes []string
srv.mu.RLock()
for recordType := range srv.byType {
recordTypes = append(recordTypes, recordType)
}
srv.mu.RUnlock()
for _, recordType := range recordTypes {
db, _, err := srv.getDB(recordType, true)
if err != nil {
continue
}
db.ClearDeleted(context.Background(), tm)
}
}
}()
return srv
}
func (srv *Server) initVersion() {
dbServerVersion, _, err := srv.getDB(recordTypeServerVersion, false)
db, _, err := srv.getBackendLocked()
if err != nil {
log.Error().Err(err).Msg("failed to init server version")
return
}
// Get version from storage first.
if r, _ := dbServerVersion.Get(context.Background(), serverVersionKey); r != nil {
var sv databroker.ServerVersion
if err := ptypes.UnmarshalAny(r.GetData(), &sv); err == nil {
srv.log.Debug().Str("server_version", sv.Version).Msg("got db version from DB")
srv.version = sv.Version
r, err := db.Get(context.Background(), recordTypeServerVersion, serverVersionKey)
switch {
case err == nil:
var sv wrapperspb.UInt64Value
if err := r.GetData().UnmarshalTo(&sv); err == nil {
srv.log.Debug().Uint64("server_version", sv.Value).Msg("got db version from Backend")
srv.version = sv.Value
}
return
case errors.Is(err, storage.ErrNotFound): // no server version, so we'll create a new one
case err != nil:
log.Error().Err(err).Msg("failed to retrieve server version")
return
}
srv.version = newUUID().String()
data, _ := ptypes.MarshalAny(&databroker.ServerVersion{Version: srv.version})
if err := dbServerVersion.Put(context.Background(), serverVersionKey, data); err != nil {
srv.version = cryptutil.NewRandomUInt64()
data, _ := anypb.New(wrapperspb.UInt64(srv.version))
if err := db.Put(context.Background(), &databroker.Record{
Type: recordTypeServerVersion,
Id: serverVersionKey,
Data: data,
}); err != nil {
srv.log.Warn().Err(err).Msg("failed to save server version.")
}
}
@ -125,121 +97,49 @@ func (srv *Server) UpdateConfig(options ...ServerOption) {
}
srv.cfg = cfg
for t, db := range srv.byType {
err := db.Close()
if srv.backend != nil {
err := srv.backend.Close()
if err != nil {
log.Warn().Err(err).Msg("databroker: error closing backend")
log.Error().Err(err).Msg("databroker: error closing backend")
}
delete(srv.byType, t)
srv.backend = nil
}
srv.initVersion()
}
// Delete deletes a record from the in-memory list.
func (srv *Server) Delete(ctx context.Context, req *databroker.DeleteRequest) (*empty.Empty, error) {
_, span := trace.StartSpan(ctx, "databroker.grpc.Delete")
defer span.End()
srv.log.Info().
Str("type", req.GetType()).
Str("id", req.GetId()).
Msg("delete")
db, _, err := srv.getDB(req.GetType(), true)
if err != nil {
return nil, err
}
if err := db.Delete(ctx, req.GetId()); err != nil {
return nil, err
}
return new(empty.Empty), nil
}
// Get gets a record from the in-memory list.
func (srv *Server) Get(ctx context.Context, req *databroker.GetRequest) (*databroker.GetResponse, error) {
_, span := trace.StartSpan(ctx, "databroker.grpc.Get")
defer span.End()
srv.log.Info().
Str("peer", grpcutil.GetPeerAddr(ctx)).
Str("type", req.GetType()).
Str("id", req.GetId()).
Msg("get")
db, _, err := srv.getDB(req.GetType(), true)
db, _, err := srv.getBackend()
if err != nil {
return nil, err
}
record, err := db.Get(ctx, req.GetId())
if err != nil {
record, err := db.Get(ctx, req.GetType(), req.GetId())
switch {
case errors.Is(err, storage.ErrNotFound):
return nil, status.Error(codes.NotFound, "record not found")
}
if record.DeletedAt != nil {
case err != nil:
return nil, status.Error(codes.Internal, err.Error())
case record.DeletedAt != nil:
return nil, status.Error(codes.NotFound, "record not found")
}
return &databroker.GetResponse{Record: record}, nil
}
// GetAll gets all the records from the backend.
func (srv *Server) GetAll(ctx context.Context, req *databroker.GetAllRequest) (*databroker.GetAllResponse, error) {
_, span := trace.StartSpan(ctx, "databroker.grpc.GetAll")
defer span.End()
srv.log.Info().
Str("type", req.GetType()).
Msg("get all")
db, version, err := srv.getDB(req.GetType(), true)
if err != nil {
return nil, err
}
all, err := db.List(ctx, "")
if err != nil {
return nil, err
}
// sort by record version
sort.Slice(all, func(i, j int) bool {
return all[i].Version < all[j].Version
})
var recordVersion string
records := make([]*databroker.Record, 0, len(all))
for _, record := range all {
// skip previous page records
if record.GetVersion() <= req.PageToken {
continue
}
recordVersion = record.GetVersion()
if record.DeletedAt == nil {
records = append(records, record)
}
// stop when we've hit the page size
if len(records) >= srv.cfg.getAllPageSize {
break
}
}
nextPageToken := recordVersion
if len(records) < srv.cfg.getAllPageSize {
nextPageToken = ""
}
return &databroker.GetAllResponse{
ServerVersion: version,
RecordVersion: recordVersion,
Records: records,
NextPageToken: nextPageToken,
}, nil
}
// Query queries for records.
func (srv *Server) Query(ctx context.Context, req *databroker.QueryRequest) (*databroker.QueryResponse, error) {
_, span := trace.StartSpan(ctx, "databroker.grpc.Query")
defer span.End()
srv.log.Info().
Str("peer", grpcutil.GetPeerAddr(ctx)).
Str("type", req.GetType()).
Str("query", req.GetQuery()).
Int64("offset", req.GetOffset()).
@ -248,21 +148,25 @@ func (srv *Server) Query(ctx context.Context, req *databroker.QueryRequest) (*da
query := strings.ToLower(req.GetQuery())
db, _, err := srv.getDB(req.GetType(), true)
db, _, err := srv.getBackend()
if err != nil {
return nil, err
}
all, err := db.List(ctx, "")
all, _, err := db.GetAll(ctx)
if err != nil {
return nil, err
}
var filtered []*databroker.Record
for _, record := range all {
if record.DeletedAt == nil && storage.MatchAny(record.GetData(), query) {
filtered = append(filtered, record)
if record.GetType() != req.GetType() {
continue
}
if query != "" && !storage.MatchAny(record.GetData(), query) {
continue
}
filtered = append(filtered, record)
}
records, totalCount := databroker.ApplyOffsetAndLimit(filtered, int(req.GetOffset()), int(req.GetLimit()))
@ -272,197 +176,160 @@ func (srv *Server) Query(ctx context.Context, req *databroker.QueryRequest) (*da
}, nil
}
// Set updates a record in the in-memory list, or adds a new one.
func (srv *Server) Set(ctx context.Context, req *databroker.SetRequest) (*databroker.SetResponse, error) {
_, span := trace.StartSpan(ctx, "databroker.grpc.Set")
// Put updates a record in the in-memory list, or adds a new one.
func (srv *Server) Put(ctx context.Context, req *databroker.PutRequest) (*databroker.PutResponse, error) {
_, span := trace.StartSpan(ctx, "databroker.grpc.Put")
defer span.End()
record := req.GetRecord()
srv.log.Info().
Str("type", req.GetType()).
Str("id", req.GetId()).
Msg("set")
Str("peer", grpcutil.GetPeerAddr(ctx)).
Str("type", record.GetType()).
Str("id", record.GetId()).
Msg("put")
db, version, err := srv.getDB(req.GetType(), true)
db, version, err := srv.getBackend()
if err != nil {
return nil, err
}
if err := db.Put(ctx, req.GetId(), req.GetData()); err != nil {
if err := db.Put(ctx, record); err != nil {
return nil, err
}
record, err := db.Get(ctx, req.GetId())
if err != nil {
return nil, err
}
return &databroker.SetResponse{
Record: record,
return &databroker.PutResponse{
ServerVersion: version,
Record: record,
}, nil
}
// doSync streams to the client all records changed since *recordVersion,
// sending them in batches of at most syncBatchSize records. On success
// *recordVersion is advanced to the version of the last record listed, so
// a subsequent call resumes from there. Returns nil when there is nothing
// new to send.
func (srv *Server) doSync(ctx context.Context,
	serverVersion string, recordVersion *string,
	db storage.Backend, stream databroker.DataBrokerService_SyncServer) error {
	// NOTE(review): assumes List(ctx, v) returns only records with a version
	// greater than v — confirm against the storage.Backend contract.
	updated, err := db.List(ctx, *recordVersion)
	if err != nil {
		return err
	}
	if len(updated) == 0 {
		// nothing changed since the caller's record version
		return nil
	}
	// remember the newest version seen so the next call only picks up later changes
	*recordVersion = updated[len(updated)-1].Version
	// send in fixed-size batches to bound the size of each stream message
	for i := 0; i < len(updated); i += syncBatchSize {
		j := i + syncBatchSize
		if j > len(updated) {
			j = len(updated)
		}
		if err := stream.Send(&databroker.SyncResponse{
			ServerVersion: serverVersion,
			Records: updated[i:j],
		}); err != nil {
			return err
		}
	}
	return nil
}
// Sync streams updates for the given record type.
func (srv *Server) Sync(req *databroker.SyncRequest, stream databroker.DataBrokerService_SyncServer) error {
_, span := trace.StartSpan(stream.Context(), "databroker.grpc.Sync")
defer span.End()
srv.log.Info().
Str("type", req.GetType()).
Str("server_version", req.GetServerVersion()).
Str("record_version", req.GetRecordVersion()).
Str("peer", grpcutil.GetPeerAddr(stream.Context())).
Uint64("server_version", req.GetServerVersion()).
Uint64("record_version", req.GetRecordVersion()).
Msg("sync")
db, serverVersion, err := srv.getDB(req.GetType(), true)
backend, serverVersion, err := srv.getBackend()
if err != nil {
return err
}
recordVersion := req.GetRecordVersion()
// reset record version if the server versions don't match
if req.GetServerVersion() != serverVersion {
serverVersion = req.GetServerVersion()
recordVersion = ""
// send the new server version to the client
err := stream.Send(&databroker.SyncResponse{
ServerVersion: serverVersion,
})
if err != nil {
return err
}
return status.Errorf(codes.Aborted, "invalid server version, expected: %d", req.GetServerVersion())
}
ctx := stream.Context()
ctx, cancel := context.WithCancel(ctx)
defer cancel()
var ch <-chan struct{}
if !req.GetNoWait() {
ch = db.Watch(ctx)
recordStream, err := backend.Sync(ctx, req.GetRecordVersion())
if err != nil {
return err
}
defer func() { _ = recordStream.Close() }()
for recordStream.Next(true) {
err = stream.Send(&databroker.SyncResponse{
ServerVersion: serverVersion,
Record: recordStream.Record(),
})
if err != nil {
return err
}
}
// Do first sync, so we won't miss anything.
if err := srv.doSync(ctx, serverVersion, &recordVersion, db, stream); err != nil {
return recordStream.Err()
}
// SyncLatest returns the latest value of every record in the databroker as a stream of records.
func (srv *Server) SyncLatest(req *databroker.SyncLatestRequest, stream databroker.DataBrokerService_SyncLatestServer) error {
_, span := trace.StartSpan(stream.Context(), "databroker.grpc.SyncLatest")
defer span.End()
srv.log.Info().
Str("peer", grpcutil.GetPeerAddr(stream.Context())).
Str("type", req.GetType()).
Msg("sync latest")
backend, serverVersion, err := srv.getBackend()
if err != nil {
return err
}
if req.GetNoWait() {
return nil
ctx := stream.Context()
ctx, cancel := context.WithCancel(ctx)
defer cancel()
records, latestRecordVersion, err := backend.GetAll(ctx)
if err != nil {
return err
}
for range ch {
if err := srv.doSync(ctx, serverVersion, &recordVersion, db, stream); err != nil {
return err
}
}
return nil
}
// GetTypes returns all the known record types.
func (srv *Server) GetTypes(ctx context.Context, _ *emptypb.Empty) (*databroker.GetTypesResponse, error) {
_, span := trace.StartSpan(ctx, "databroker.grpc.GetTypes")
defer span.End()
var recordTypes []string
srv.mu.RLock()
for recordType := range srv.byType {
recordTypes = append(recordTypes, recordType)
}
srv.mu.RUnlock()
sort.Strings(recordTypes)
return &databroker.GetTypesResponse{
Types: recordTypes,
}, nil
}
// SyncTypes synchronizes all the known record types.
func (srv *Server) SyncTypes(req *emptypb.Empty, stream databroker.DataBrokerService_SyncTypesServer) error {
_, span := trace.StartSpan(stream.Context(), "databroker.grpc.SyncTypes")
defer span.End()
srv.log.Info().
Msg("sync types")
ch := srv.onTypechange.Bind()
defer srv.onTypechange.Unbind(ch)
var prev []string
for {
res, err := srv.GetTypes(stream.Context(), req)
if err != nil {
return err
}
if prev == nil || !reflect.DeepEqual(prev, res.Types) {
err := stream.Send(res)
for _, record := range records {
if req.GetType() == "" || req.GetType() == record.GetType() {
err = stream.Send(&databroker.SyncLatestResponse{
Response: &databroker.SyncLatestResponse_Record{
Record: record,
},
})
if err != nil {
return err
}
prev = res.Types
}
select {
case <-stream.Context().Done():
return stream.Context().Err()
case <-ch:
}
}
// always send the server version last in case there are no records
return stream.Send(&databroker.SyncLatestResponse{
Response: &databroker.SyncLatestResponse_Versions{
Versions: &databroker.Versions{
ServerVersion: serverVersion,
LatestRecordVersion: latestRecordVersion,
},
},
})
}
func (srv *Server) getDB(recordType string, lock bool) (db storage.Backend, version string, err error) {
func (srv *Server) getBackend() (backend storage.Backend, version uint64, err error) {
// double-checked locking:
// first try the read lock, then re-try with the write lock, and finally create a new db if nil
if lock {
srv.mu.RLock()
}
db = srv.byType[recordType]
// first try the read lock, then re-try with the write lock, and finally create a new backend if nil
srv.mu.RLock()
backend = srv.backend
version = srv.version
if lock {
srv.mu.RUnlock()
}
if db == nil {
if lock {
srv.mu.Lock()
}
db = srv.byType[recordType]
srv.mu.RUnlock()
if backend == nil {
srv.mu.Lock()
backend = srv.backend
version = srv.version
var err error
if db == nil {
db, err = srv.newDB(recordType)
srv.byType[recordType] = db
defer srv.onTypechange.Broadcast()
}
if lock {
srv.mu.Unlock()
if backend == nil {
backend, err = srv.newBackendLocked()
srv.backend = backend
}
srv.mu.Unlock()
if err != nil {
return nil, "", err
return nil, 0, err
}
}
return db, version, nil
return backend, version, nil
}
func (srv *Server) newDB(recordType string) (db storage.Backend, err error) {
// getBackendLocked returns the storage backend and the current server
// version, lazily creating and caching the backend on first use. It is
// expected to be called with srv.mu already held (per the *Locked naming
// convention — NOTE(review): confirm against callers such as initVersion).
//
// Fix over the original: the inner `var err error` shadowed the named
// return value, and srv.backend was assigned before the error check;
// now the backend is cached only after successful creation.
func (srv *Server) getBackendLocked() (backend storage.Backend, version uint64, err error) {
	backend = srv.backend
	version = srv.version
	if backend == nil {
		backend, err = srv.newBackendLocked()
		if err != nil {
			return nil, 0, err
		}
		// cache only on success so a failed init is retried next call
		srv.backend = backend
	}
	return backend, version, nil
}
func (srv *Server) newBackendLocked() (backend storage.Backend, err error) {
caCertPool, err := cryptutil.GetCertPool("", srv.cfg.storageCAFile)
if err != nil {
log.Warn().Err(err).Msg("failed to read databroker CA file")
@ -478,11 +345,12 @@ func (srv *Server) newDB(recordType string) (db storage.Backend, err error) {
switch srv.cfg.storageType {
case config.StorageInMemoryName:
return inmemory.NewDB(recordType, srv.cfg.btreeDegree), nil
srv.log.Info().Msg("using in-memory store")
return inmemory.New(), nil
case config.StorageRedisName:
db, err = redis.New(
srv.log.Info().Msg("using redis store")
backend, err = redis.New(
srv.cfg.storageConnectionString,
redis.WithRecordType(recordType),
redis.WithTLSConfig(tlsConfig),
)
if err != nil {
@ -492,10 +360,10 @@ func (srv *Server) newDB(recordType string) (db storage.Backend, err error) {
return nil, fmt.Errorf("unsupported storage type: %s", srv.cfg.storageType)
}
if srv.cfg.secret != nil {
db, err = storage.NewEncryptedBackend(srv.cfg.secret, db)
backend, err = storage.NewEncryptedBackend(srv.cfg.secret, backend)
if err != nil {
return nil, err
}
}
return db, nil
return backend, nil
}

View file

@ -2,101 +2,27 @@ package databroker
import (
"context"
"fmt"
"testing"
"github.com/golang/protobuf/ptypes"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/types/known/anypb"
"google.golang.org/protobuf/types/known/wrapperspb"
"google.golang.org/protobuf/types/known/timestamppb"
"github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/signal"
"github.com/pomerium/pomerium/pkg/grpc/databroker"
"github.com/pomerium/pomerium/pkg/grpc/session"
"github.com/pomerium/pomerium/pkg/storage"
)
func newServer(cfg *serverConfig) *Server {
return &Server{
version: uuid.New().String(),
version: 11,
cfg: cfg,
log: log.With().Str("service", "databroker").Logger(),
byType: make(map[string]storage.Backend),
onTypechange: signal.New(),
}
}
func TestServer_initVersion(t *testing.T) {
cfg := newServerConfig()
t.Run("nil db", func(t *testing.T) {
srvVersion := uuid.New()
oldNewUUID := newUUID
newUUID = func() uuid.UUID {
return srvVersion
}
defer func() { newUUID = oldNewUUID }()
srv := newServer(cfg)
srv.byType[recordTypeServerVersion] = nil
srv.initVersion()
assert.Equal(t, srvVersion.String(), srv.version)
})
t.Run("new server with random version", func(t *testing.T) {
srvVersion := uuid.New()
oldNewUUID := newUUID
newUUID = func() uuid.UUID {
return srvVersion
}
defer func() { newUUID = oldNewUUID }()
srv := newServer(cfg)
ctx := context.Background()
db, _, err := srv.getDB(recordTypeServerVersion, false)
require.NoError(t, err)
r, err := db.Get(ctx, serverVersionKey)
assert.Error(t, err)
assert.Nil(t, r)
srv.initVersion()
assert.Equal(t, srvVersion.String(), srv.version)
r, err = db.Get(ctx, serverVersionKey)
require.NoError(t, err)
assert.NotNil(t, r)
var sv databroker.ServerVersion
assert.NoError(t, ptypes.UnmarshalAny(r.GetData(), &sv))
assert.Equal(t, srvVersion.String(), sv.Version)
})
t.Run("init version twice should get the same version", func(t *testing.T) {
srv := newServer(cfg)
ctx := context.Background()
db, _, err := srv.getDB(recordTypeServerVersion, false)
require.NoError(t, err)
r, err := db.Get(ctx, serverVersionKey)
assert.Error(t, err)
assert.Nil(t, r)
srv.initVersion()
srvVersion := srv.version
r, err = db.Get(ctx, serverVersionKey)
require.NoError(t, err)
assert.NotNil(t, r)
var sv databroker.ServerVersion
assert.NoError(t, ptypes.UnmarshalAny(r.GetData(), &sv))
assert.Equal(t, srvVersion, sv.Version)
// re-init version should get the same value as above
srv.version = "foo"
srv.initVersion()
assert.Equal(t, srvVersion, srv.version)
})
}
func TestServer_Get(t *testing.T) {
cfg := newServerConfig()
t.Run("ignore deleted", func(t *testing.T) {
@ -107,15 +33,22 @@ func TestServer_Get(t *testing.T) {
any, err := anypb.New(s)
assert.NoError(t, err)
srv.Set(context.Background(), &databroker.SetRequest{
Type: any.TypeUrl,
Id: s.Id,
Data: any,
_, err = srv.Put(context.Background(), &databroker.PutRequest{
Record: &databroker.Record{
Type: any.TypeUrl,
Id: s.Id,
Data: any,
},
})
srv.Delete(context.Background(), &databroker.DeleteRequest{
Type: any.TypeUrl,
Id: s.Id,
assert.NoError(t, err)
_, err = srv.Put(context.Background(), &databroker.PutRequest{
Record: &databroker.Record{
Type: any.TypeUrl,
Id: s.Id,
DeletedAt: timestamppb.Now(),
},
})
assert.NoError(t, err)
_, err = srv.Get(context.Background(), &databroker.GetRequest{
Type: any.TypeUrl,
Id: s.Id,
@ -124,61 +57,3 @@ func TestServer_Get(t *testing.T) {
assert.Equal(t, codes.NotFound, status.Code(err))
})
}
func TestServer_GetAll(t *testing.T) {
cfg := newServerConfig(
WithGetAllPageSize(5),
)
t.Run("ignore deleted", func(t *testing.T) {
srv := newServer(cfg)
s := new(session.Session)
s.Id = "1"
any, err := anypb.New(s)
assert.NoError(t, err)
srv.Set(context.Background(), &databroker.SetRequest{
Type: any.TypeUrl,
Id: s.Id,
Data: any,
})
srv.Delete(context.Background(), &databroker.DeleteRequest{
Type: any.TypeUrl,
Id: s.Id,
})
res, err := srv.GetAll(context.Background(), &databroker.GetAllRequest{
Type: any.TypeUrl,
})
assert.NoError(t, err)
assert.Len(t, res.GetRecords(), 0)
})
t.Run("paging", func(t *testing.T) {
srv := newServer(cfg)
any, err := anypb.New(wrapperspb.String("TEST"))
assert.NoError(t, err)
for i := 0; i < 7; i++ {
srv.Set(context.Background(), &databroker.SetRequest{
Type: any.TypeUrl,
Id: fmt.Sprint(i),
Data: any,
})
}
res, err := srv.GetAll(context.Background(), &databroker.GetAllRequest{
Type: any.TypeUrl,
})
assert.NoError(t, err)
assert.Len(t, res.GetRecords(), 5)
assert.Equal(t, res.GetNextPageToken(), "000000000005")
res, err = srv.GetAll(context.Background(), &databroker.GetAllRequest{
Type: any.TypeUrl,
PageToken: res.GetNextPageToken(),
})
assert.NoError(t, err)
assert.Len(t, res.GetRecords(), 2)
assert.Equal(t, res.GetNextPageToken(), "")
})
}

View file

@ -7,14 +7,14 @@ import (
"fmt"
"time"
"github.com/cenkalti/backoff/v4"
"github.com/golang/protobuf/proto"
"github.com/golang/protobuf/ptypes"
"github.com/google/btree"
"github.com/rs/zerolog"
"golang.org/x/oauth2"
"golang.org/x/sync/errgroup"
"golang.org/x/sync/semaphore"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/known/anypb"
"google.golang.org/protobuf/types/known/timestamppb"
"github.com/pomerium/pomerium/internal/directory"
"github.com/pomerium/pomerium/internal/identity/identity"
@ -23,6 +23,7 @@ import (
"github.com/pomerium/pomerium/pkg/grpc/databroker"
"github.com/pomerium/pomerium/pkg/grpc/session"
"github.com/pomerium/pomerium/pkg/grpc/user"
"github.com/pomerium/pomerium/pkg/grpcutil"
)
const (
@ -37,13 +38,8 @@ type Authenticator interface {
}
type (
sessionMessage struct {
record *databroker.Record
session *session.Session
}
userMessage struct {
record *databroker.Record
user *user.User
updateRecordsMessage struct {
records []*databroker.Record
}
)
@ -52,19 +48,13 @@ type Manager struct {
cfg *atomicConfig
log zerolog.Logger
sessions sessionCollection
sessionScheduler *scheduler.Scheduler
userScheduler *scheduler.Scheduler
users userCollection
userScheduler *scheduler.Scheduler
directoryUsers map[string]*directory.User
directoryUsersServerVersion string
directoryUsersRecordVersion string
directoryGroups map[string]*directory.Group
directoryGroupsServerVersion string
directoryGroupsRecordVersion string
sessions sessionCollection
users userCollection
directoryUsers map[string]*directory.User
directoryGroups map[string]*directory.Group
directoryNextRefresh time.Time
@ -79,14 +69,8 @@ func New(
cfg: newAtomicConfig(newConfig()),
log: log.With().Str("service", "identity_manager").Logger(),
sessions: sessionCollection{
BTree: btree.New(8),
},
sessionScheduler: scheduler.New(),
users: userCollection{
BTree: btree.New(8),
},
userScheduler: scheduler.New(),
userScheduler: scheduler.New(),
dataBrokerSemaphore: semaphore.NewWeighted(dataBrokerParallelism),
}
@ -101,52 +85,47 @@ func (mgr *Manager) UpdateConfig(options ...Option) {
// Run runs the manager. This method blocks until an error occurs or the given context is canceled.
func (mgr *Manager) Run(ctx context.Context) error {
err := mgr.initDirectoryGroups(ctx)
if err != nil {
return fmt.Errorf("failed to initialize directory groups: %w", err)
}
update := make(chan updateRecordsMessage, 1)
clear := make(chan struct{}, 1)
err = mgr.initDirectoryUsers(ctx)
if err != nil {
return fmt.Errorf("failed to initialize directory users: %w", err)
}
syncer := newDataBrokerSyncer(mgr.cfg, mgr.log, update, clear)
eg, ctx := errgroup.WithContext(ctx)
updatedSession := make(chan sessionMessage, 1)
eg.Go(func() error {
return mgr.syncSessions(ctx, updatedSession)
return syncer.Run(ctx)
})
updatedUser := make(chan userMessage, 1)
eg.Go(func() error {
return mgr.syncUsers(ctx, updatedUser)
})
updatedDirectoryGroup := make(chan *directory.Group, 1)
eg.Go(func() error {
return mgr.syncDirectoryGroups(ctx, updatedDirectoryGroup)
})
updatedDirectoryUser := make(chan *directory.User, 1)
eg.Go(func() error {
return mgr.syncDirectoryUsers(ctx, updatedDirectoryUser)
})
eg.Go(func() error {
return mgr.refreshLoop(ctx, updatedSession, updatedUser, updatedDirectoryUser, updatedDirectoryGroup)
return mgr.refreshLoop(ctx, update, clear)
})
return eg.Wait()
}
func (mgr *Manager) refreshLoop(
ctx context.Context,
updatedSession <-chan sessionMessage,
updatedUser <-chan userMessage,
updatedDirectoryUser <-chan *directory.User,
updatedDirectoryGroup <-chan *directory.Group,
) error {
func (mgr *Manager) refreshLoop(ctx context.Context, update <-chan updateRecordsMessage, clear <-chan struct{}) error {
// wait for initial sync
select {
case <-ctx.Done():
return ctx.Err()
case <-clear:
mgr.directoryGroups = make(map[string]*directory.Group)
mgr.directoryUsers = make(map[string]*directory.User)
mgr.sessions = sessionCollection{BTree: btree.New(8)}
mgr.users = userCollection{BTree: btree.New(8)}
}
select {
case <-ctx.Done():
case msg := <-update:
mgr.onUpdateRecords(ctx, msg)
}
mgr.log.Info().
Int("directory_groups", len(mgr.directoryGroups)).
Int("directory_users", len(mgr.directoryUsers)).
Int("sessions", mgr.sessions.Len()).
Int("users", mgr.users.Len()).
Msg("initial sync complete")
// start refreshing
maxWait := time.Minute * 10
nextTime := time.Now().Add(maxWait)
if mgr.directoryNextRefresh.Before(nextTime) {
@ -160,14 +139,13 @@ func (mgr *Manager) refreshLoop(
select {
case <-ctx.Done():
return ctx.Err()
case s := <-updatedSession:
mgr.onUpdateSession(ctx, s)
case u := <-updatedUser:
mgr.onUpdateUser(ctx, u)
case du := <-updatedDirectoryUser:
mgr.onUpdateDirectoryUser(ctx, du)
case dg := <-updatedDirectoryGroup:
mgr.onUpdateDirectoryGroup(ctx, dg)
case <-clear:
mgr.directoryGroups = make(map[string]*directory.Group)
mgr.directoryUsers = make(map[string]*directory.User)
mgr.sessions = sessionCollection{BTree: btree.New(8)}
mgr.users = userCollection{BTree: btree.New(8)}
case msg := <-update:
mgr.onUpdateRecords(ctx, msg)
case <-timer.C:
}
@ -249,7 +227,7 @@ func (mgr *Manager) mergeGroups(ctx context.Context, directoryGroups []*director
curDG, ok := mgr.directoryGroups[groupID]
if !ok || !proto.Equal(newDG, curDG) {
id := newDG.GetId()
any, err := ptypes.MarshalAny(newDG)
any, err := anypb.New(newDG)
if err != nil {
mgr.log.Warn().Err(err).Msg("failed to marshal directory group")
return
@ -260,10 +238,12 @@ func (mgr *Manager) mergeGroups(ctx context.Context, directoryGroups []*director
}
defer mgr.dataBrokerSemaphore.Release(1)
_, err = mgr.cfg.Load().dataBrokerClient.Set(ctx, &databroker.SetRequest{
Type: any.GetTypeUrl(),
Id: id,
Data: any,
_, err = mgr.cfg.Load().dataBrokerClient.Put(ctx, &databroker.PutRequest{
Record: &databroker.Record{
Type: any.GetTypeUrl(),
Id: id,
Data: any,
},
})
if err != nil {
return fmt.Errorf("failed to update directory group: %s", id)
@ -277,7 +257,7 @@ func (mgr *Manager) mergeGroups(ctx context.Context, directoryGroups []*director
_, ok := lookup[groupID]
if !ok {
id := curDG.GetId()
any, err := ptypes.MarshalAny(curDG)
any, err := anypb.New(curDG)
if err != nil {
mgr.log.Warn().Err(err).Msg("failed to marshal directory group")
return
@ -288,9 +268,12 @@ func (mgr *Manager) mergeGroups(ctx context.Context, directoryGroups []*director
}
defer mgr.dataBrokerSemaphore.Release(1)
_, err = mgr.cfg.Load().dataBrokerClient.Delete(ctx, &databroker.DeleteRequest{
Type: any.GetTypeUrl(),
Id: id,
_, err = mgr.cfg.Load().dataBrokerClient.Put(ctx, &databroker.PutRequest{
Record: &databroker.Record{
Type: any.GetTypeUrl(),
Id: id,
DeletedAt: timestamppb.Now(),
},
})
if err != nil {
return fmt.Errorf("failed to delete directory group: %s", id)
@ -299,6 +282,10 @@ func (mgr *Manager) mergeGroups(ctx context.Context, directoryGroups []*director
})
}
}
if err := eg.Wait(); err != nil {
mgr.log.Warn().Err(err).Msg("manager: failed to merge groups")
}
}
func (mgr *Manager) mergeUsers(ctx context.Context, directoryUsers []*directory.User) {
@ -313,7 +300,7 @@ func (mgr *Manager) mergeUsers(ctx context.Context, directoryUsers []*directory.
curDU, ok := mgr.directoryUsers[userID]
if !ok || !proto.Equal(newDU, curDU) {
id := newDU.GetId()
any, err := ptypes.MarshalAny(newDU)
any, err := anypb.New(newDU)
if err != nil {
mgr.log.Warn().Err(err).Msg("failed to marshal directory user")
return
@ -325,10 +312,12 @@ func (mgr *Manager) mergeUsers(ctx context.Context, directoryUsers []*directory.
defer mgr.dataBrokerSemaphore.Release(1)
client := mgr.cfg.Load().dataBrokerClient
if _, err := client.Set(ctx, &databroker.SetRequest{
Type: any.GetTypeUrl(),
Id: id,
Data: any,
if _, err := client.Put(ctx, &databroker.PutRequest{
Record: &databroker.Record{
Type: any.GetTypeUrl(),
Id: id,
Data: any,
},
}); err != nil {
return fmt.Errorf("failed to update directory user: %s", id)
}
@ -341,7 +330,7 @@ func (mgr *Manager) mergeUsers(ctx context.Context, directoryUsers []*directory.
_, ok := lookup[userID]
if !ok {
id := curDU.GetId()
any, err := ptypes.MarshalAny(curDU)
any, err := anypb.New(curDU)
if err != nil {
mgr.log.Warn().Err(err).Msg("failed to marshal directory user")
return
@ -353,9 +342,13 @@ func (mgr *Manager) mergeUsers(ctx context.Context, directoryUsers []*directory.
defer mgr.dataBrokerSemaphore.Release(1)
client := mgr.cfg.Load().dataBrokerClient
if _, err := client.Delete(ctx, &databroker.DeleteRequest{
Type: any.GetTypeUrl(),
Id: id,
if _, err := client.Put(ctx, &databroker.PutRequest{
Record: &databroker.Record{
Type: any.GetTypeUrl(),
Id: id,
Data: any,
DeletedAt: timestamppb.Now(),
},
}); err != nil {
return fmt.Errorf("failed to delete directory user (%s): %w", id, err)
}
@ -384,8 +377,8 @@ func (mgr *Manager) refreshSession(ctx context.Context, userID, sessionID string
return
}
expiry, err := ptypes.Timestamp(s.GetExpiresAt())
if err == nil && !expiry.After(time.Now()) {
expiry := s.GetExpiresAt().AsTime()
if !expiry.After(time.Now()) {
mgr.log.Info().
Str("user_id", userID).
Str("session_id", sessionID).
@ -435,7 +428,7 @@ func (mgr *Manager) refreshSession(ctx context.Context, userID, sessionID string
return
}
res, err := session.Set(ctx, mgr.cfg.Load().dataBrokerClient, s.Session)
res, err := session.Put(ctx, mgr.cfg.Load().dataBrokerClient, s.Session)
if err != nil {
mgr.log.Error().Err(err).
Str("user_id", s.GetUserId()).
@ -444,7 +437,7 @@ func (mgr *Manager) refreshSession(ctx context.Context, userID, sessionID string
return
}
mgr.onUpdateSession(ctx, sessionMessage{record: res.GetRecord(), session: s.Session})
mgr.onUpdateSession(ctx, res.GetRecord(), s.Session)
}
func (mgr *Manager) refreshUser(ctx context.Context, userID string) {
@ -486,7 +479,7 @@ func (mgr *Manager) refreshUser(ctx context.Context, userID string) {
continue
}
record, err := user.Set(ctx, mgr.cfg.Load().dataBrokerClient, u.User)
record, err := user.Put(ctx, mgr.cfg.Load().dataBrokerClient, u.User)
if err != nil {
mgr.log.Error().Err(err).
Str("user_id", s.GetUserId()).
@ -495,257 +488,78 @@ func (mgr *Manager) refreshUser(ctx context.Context, userID string) {
continue
}
mgr.onUpdateUser(ctx, userMessage{record: record, user: u.User})
mgr.onUpdateUser(ctx, record, u.User)
}
}
func (mgr *Manager) syncSessions(ctx context.Context, ch chan<- sessionMessage) error {
mgr.log.Info().Msg("syncing sessions")
any, err := ptypes.MarshalAny(new(session.Session))
if err != nil {
return err
}
client, err := mgr.cfg.Load().dataBrokerClient.Sync(ctx, &databroker.SyncRequest{
Type: any.GetTypeUrl(),
})
if err != nil {
return fmt.Errorf("error syncing sessions: %w", err)
}
for {
res, err := client.Recv()
if err != nil {
return fmt.Errorf("error receiving sessions: %w", err)
}
for _, record := range res.GetRecords() {
func (mgr *Manager) onUpdateRecords(ctx context.Context, msg updateRecordsMessage) {
for _, record := range msg.records {
switch record.GetType() {
case grpcutil.GetTypeURL(new(directory.Group)):
var pbDirectoryGroup directory.Group
err := record.GetData().UnmarshalTo(&pbDirectoryGroup)
if err != nil {
mgr.log.Warn().Msgf("error unmarshaling directory group: %s", err)
continue
}
mgr.onUpdateDirectoryGroup(ctx, &pbDirectoryGroup)
case grpcutil.GetTypeURL(new(directory.User)):
var pbDirectoryUser directory.User
err := record.GetData().UnmarshalTo(&pbDirectoryUser)
if err != nil {
mgr.log.Warn().Msgf("error unmarshaling directory user: %s", err)
continue
}
mgr.onUpdateDirectoryUser(ctx, &pbDirectoryUser)
case grpcutil.GetTypeURL(new(session.Session)):
var pbSession session.Session
err := ptypes.UnmarshalAny(record.GetData(), &pbSession)
err := record.GetData().UnmarshalTo(&pbSession)
if err != nil {
return fmt.Errorf("error unmarshaling session: %w", err)
mgr.log.Warn().Msgf("error unmarshaling session: %s", err)
continue
}
select {
case <-ctx.Done():
return ctx.Err()
case ch <- sessionMessage{record: record, session: &pbSession}:
}
}
}
}
func (mgr *Manager) syncUsers(ctx context.Context, ch chan<- userMessage) error {
mgr.log.Info().Msg("syncing users")
any, err := ptypes.MarshalAny(new(user.User))
if err != nil {
return err
}
client, err := mgr.cfg.Load().dataBrokerClient.Sync(ctx, &databroker.SyncRequest{
Type: any.GetTypeUrl(),
})
if err != nil {
return fmt.Errorf("error syncing users: %w", err)
}
for {
res, err := client.Recv()
if err != nil {
return fmt.Errorf("error receiving users: %w", err)
}
for _, record := range res.GetRecords() {
mgr.onUpdateSession(ctx, record, &pbSession)
case grpcutil.GetTypeURL(new(user.User)):
var pbUser user.User
err := ptypes.UnmarshalAny(record.GetData(), &pbUser)
err := record.GetData().UnmarshalTo(&pbUser)
if err != nil {
return fmt.Errorf("error unmarshaling user: %w", err)
}
select {
case <-ctx.Done():
return ctx.Err()
case ch <- userMessage{record: record, user: &pbUser}:
mgr.log.Warn().Msgf("error unmarshaling user: %s", err)
continue
}
}
}
}
func (mgr *Manager) initDirectoryUsers(ctx context.Context) error {
mgr.log.Info().Msg("initializing directory users")
func (mgr *Manager) onUpdateSession(_ context.Context, record *databroker.Record, session *session.Session) {
mgr.sessionScheduler.Remove(toSessionSchedulerKey(session.GetUserId(), session.GetId()))
any, err := ptypes.MarshalAny(new(directory.User))
if err != nil {
return err
}
return exponentialTry(ctx, func() error {
res, err := databroker.InitialSync(ctx, mgr.cfg.Load().dataBrokerClient, &databroker.SyncRequest{
Type: any.GetTypeUrl(),
})
if err != nil {
return fmt.Errorf("error getting all directory users: %w", err)
}
mgr.directoryUsers = map[string]*directory.User{}
for _, record := range res.GetRecords() {
var pbDirectoryUser directory.User
err := ptypes.UnmarshalAny(record.GetData(), &pbDirectoryUser)
if err != nil {
return fmt.Errorf("error unmarshaling directory user: %w", err)
}
mgr.directoryUsers[pbDirectoryUser.GetId()] = &pbDirectoryUser
mgr.directoryUsersRecordVersion = record.GetVersion()
}
mgr.directoryUsersServerVersion = res.GetServerVersion()
mgr.log.Info().Int("count", len(mgr.directoryUsers)).Msg("initialized directory users")
return nil
})
}
func (mgr *Manager) syncDirectoryUsers(ctx context.Context, ch chan<- *directory.User) error {
mgr.log.Info().Msg("syncing directory users")
any, err := ptypes.MarshalAny(new(directory.User))
if err != nil {
return err
}
client, err := mgr.cfg.Load().dataBrokerClient.Sync(ctx, &databroker.SyncRequest{
Type: any.GetTypeUrl(),
ServerVersion: mgr.directoryUsersServerVersion,
RecordVersion: mgr.directoryUsersRecordVersion,
})
if err != nil {
return fmt.Errorf("error syncing directory users: %w", err)
}
for {
res, err := client.Recv()
if err != nil {
return fmt.Errorf("error receiving directory users: %w", err)
}
for _, record := range res.GetRecords() {
var pbDirectoryUser directory.User
err := ptypes.UnmarshalAny(record.GetData(), &pbDirectoryUser)
if err != nil {
return fmt.Errorf("error unmarshaling directory user: %w", err)
}
select {
case <-ctx.Done():
return ctx.Err()
case ch <- &pbDirectoryUser:
}
}
}
}
func (mgr *Manager) initDirectoryGroups(ctx context.Context) error {
mgr.log.Info().Msg("initializing directory groups")
any, err := ptypes.MarshalAny(new(directory.Group))
if err != nil {
return err
}
return exponentialTry(ctx, func() error {
res, err := databroker.InitialSync(ctx, mgr.cfg.Load().dataBrokerClient, &databroker.SyncRequest{
Type: any.GetTypeUrl(),
})
if err != nil {
return fmt.Errorf("error getting all directory groups: %w", err)
}
mgr.directoryGroups = map[string]*directory.Group{}
for _, record := range res.GetRecords() {
var pbDirectoryGroup directory.Group
err := ptypes.UnmarshalAny(record.GetData(), &pbDirectoryGroup)
if err != nil {
return fmt.Errorf("error unmarshaling directory group: %w", err)
}
mgr.directoryGroups[pbDirectoryGroup.GetId()] = &pbDirectoryGroup
mgr.directoryGroupsRecordVersion = record.GetVersion()
}
mgr.directoryGroupsServerVersion = res.GetServerVersion()
mgr.log.Info().Int("count", len(mgr.directoryGroups)).Msg("initialized directory groups")
return nil
})
}
func (mgr *Manager) syncDirectoryGroups(ctx context.Context, ch chan<- *directory.Group) error {
mgr.log.Info().Msg("syncing directory groups")
any, err := ptypes.MarshalAny(new(directory.Group))
if err != nil {
return err
}
client, err := mgr.cfg.Load().dataBrokerClient.Sync(ctx, &databroker.SyncRequest{
Type: any.GetTypeUrl(),
ServerVersion: mgr.directoryGroupsServerVersion,
RecordVersion: mgr.directoryGroupsRecordVersion,
})
if err != nil {
return fmt.Errorf("error syncing directory groups: %w", err)
}
for {
res, err := client.Recv()
if err != nil {
return fmt.Errorf("error receiving directory groups: %w", err)
}
for _, record := range res.GetRecords() {
var pbDirectoryGroup directory.Group
err := ptypes.UnmarshalAny(record.GetData(), &pbDirectoryGroup)
if err != nil {
return fmt.Errorf("error unmarshaling directory group: %w", err)
}
select {
case <-ctx.Done():
return ctx.Err()
case ch <- &pbDirectoryGroup:
}
}
}
}
func (mgr *Manager) onUpdateSession(_ context.Context, msg sessionMessage) {
mgr.sessionScheduler.Remove(toSessionSchedulerKey(msg.session.GetUserId(), msg.session.GetId()))
if msg.record.GetDeletedAt() != nil {
mgr.sessions.Delete(msg.session.GetUserId(), msg.session.GetId())
if record.GetDeletedAt() != nil {
mgr.sessions.Delete(session.GetUserId(), session.GetId())
return
}
// update session
s, _ := mgr.sessions.Get(msg.session.GetUserId(), msg.session.GetId())
s, _ := mgr.sessions.Get(session.GetUserId(), session.GetId())
s.lastRefresh = time.Now()
s.gracePeriod = mgr.cfg.Load().sessionRefreshGracePeriod
s.coolOffDuration = mgr.cfg.Load().sessionRefreshCoolOffDuration
s.Session = msg.session
s.Session = session
mgr.sessions.ReplaceOrInsert(s)
mgr.sessionScheduler.Add(s.NextRefresh(), toSessionSchedulerKey(msg.session.GetUserId(), msg.session.GetId()))
mgr.sessionScheduler.Add(s.NextRefresh(), toSessionSchedulerKey(session.GetUserId(), session.GetId()))
}
func (mgr *Manager) onUpdateUser(_ context.Context, msg userMessage) {
mgr.userScheduler.Remove(msg.user.GetId())
func (mgr *Manager) onUpdateUser(_ context.Context, record *databroker.Record, user *user.User) {
mgr.userScheduler.Remove(user.GetId())
if msg.record.GetDeletedAt() != nil {
mgr.users.Delete(msg.user.GetId())
if record.GetDeletedAt() != nil {
mgr.users.Delete(user.GetId())
return
}
u, _ := mgr.users.Get(msg.user.GetId())
u, _ := mgr.users.Get(user.GetId())
u.lastRefresh = time.Now()
u.refreshInterval = mgr.cfg.Load().groupRefreshInterval
u.User = msg.user
u.User = user
mgr.users.ReplaceOrInsert(u)
mgr.userScheduler.Add(u.NextRefresh(), u.GetId())
}
@ -779,22 +593,3 @@ func isTemporaryError(err error) bool {
}
return false
}
// exponentialTry executes f until it succeeds or ctx is Done, waiting an
// exponentially increasing amount of time between failed attempts.
func exponentialTry(ctx context.Context, f func() error) error {
	// use `bo` so we do not shadow the imported backoff package
	bo := backoff.NewExponentialBackOff()
	// never give up on our own: retry until the context is canceled
	bo.MaxElapsedTime = 0
	for {
		if err := f(); err == nil {
			return nil
		}
		select {
		case <-ctx.Done():
			return ctx.Err()
		case <-time.After(bo.NextBackOff()):
		}
	}
}

View file

@ -0,0 +1,58 @@
package manager
import (
"context"
"github.com/rs/zerolog"
"github.com/pomerium/pomerium/pkg/grpc/databroker"
)
// dataBrokerSyncer implements the databroker SyncerHandler callbacks and
// forwards sync events to the manager's run loop over channels.
type dataBrokerSyncer struct {
	cfg *atomicConfig // hot-reloadable manager configuration
	log zerolog.Logger
	update chan<- updateRecordsMessage // delivers batches of changed records
	clear  chan<- struct{}             // signals that all cached records must be dropped
	syncer *databroker.Syncer // underlying syncer driving the callbacks
}
// newDataBrokerSyncer creates a dataBrokerSyncer that reports sync events on
// the given update and clear channels.
func newDataBrokerSyncer(
	cfg *atomicConfig,
	log zerolog.Logger,
	update chan<- updateRecordsMessage,
	clear chan<- struct{},
) *dataBrokerSyncer {
	s := &dataBrokerSyncer{
		cfg:    cfg,
		log:    log,
		update: update,
		clear:  clear,
	}
	// the syncer calls back into s, so it has to be wired up after construction
	s.syncer = databroker.NewSyncer(s)
	return s
}
// Run runs the underlying databroker syncer until ctx is canceled or the
// syncer is closed.
func (syncer *dataBrokerSyncer) Run(ctx context.Context) error {
	return syncer.syncer.Run(ctx)
}
// ClearRecords tells the manager loop to drop all cached records. It is called
// by the databroker syncer whenever a full re-sync is required.
func (syncer *dataBrokerSyncer) ClearRecords(ctx context.Context) {
	// deliver the clear signal unless the context is canceled first
	select {
	case syncer.clear <- struct{}{}:
	case <-ctx.Done():
	}
}
// GetDataBrokerServiceClient returns the current databroker client from the
// hot-reloadable configuration.
func (syncer *dataBrokerSyncer) GetDataBrokerServiceClient() databroker.DataBrokerServiceClient {
	current := syncer.cfg.Load()
	return current.dataBrokerClient
}
// UpdateRecords forwards a batch of new or changed records to the manager loop.
func (syncer *dataBrokerSyncer) UpdateRecords(ctx context.Context, records []*databroker.Record) {
	// deliver the batch unless the context is canceled first
	select {
	case syncer.update <- updateRecordsMessage{records: records}:
	case <-ctx.Done():
	}
}

View file

@ -0,0 +1,2 @@
// Package registry implements a service registry server.
package registry

View file

@ -9,40 +9,23 @@ import (
"google.golang.org/grpc"
"google.golang.org/grpc/stats"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/types/known/wrapperspb"
)
var statsHandler = &ocgrpc.ServerHandler{}
type testProto struct {
message string
}
func (t testProto) Reset() {}
func (t testProto) ProtoMessage() {}
func (t testProto) String() string {
return t.message
}
func (t testProto) XXX_Size() int {
return len([]byte(t.message))
}
func (t testProto) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
return []byte(t.message), nil
}
type testInvoker struct {
invokeResult error
statsHandler stats.Handler
}
func (t testInvoker) UnaryInvoke(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, opts ...grpc.CallOption) error {
r := reply.(*testProto)
r.message = "hello"
r := reply.(*wrapperspb.StringValue)
r.Value = "hello"
ctx = t.statsHandler.TagRPC(ctx, &stats.RPCTagInfo{FullMethodName: method})
t.statsHandler.HandleRPC(ctx, &stats.InPayload{Client: true, Length: len(r.message)})
t.statsHandler.HandleRPC(ctx, &stats.OutPayload{Client: true, Length: len(r.message)})
t.statsHandler.HandleRPC(ctx, &stats.InPayload{Client: true, Length: len(r.Value)})
t.statsHandler.HandleRPC(ctx, &stats.OutPayload{Client: true, Length: len(r.Value)})
t.statsHandler.HandleRPC(ctx, &stats.End{Client: true, Error: t.invokeResult})
return t.invokeResult
@ -106,7 +89,7 @@ func Test_GRPCClientInterceptor(t *testing.T) {
invokeResult: tt.errorCode,
statsHandler: &ocgrpc.ClientHandler{},
}
var reply testProto
var reply wrapperspb.StringValue
interceptor(context.Background(), tt.method, nil, &reply, newTestCC(t), invoker.UnaryInvoke)

105
internal/testutil/redis.go Normal file
View file

@ -0,0 +1,105 @@
package testutil
import (
"context"
"crypto/tls"
"crypto/x509"
"fmt"
"io/ioutil"
"path/filepath"
"time"
"github.com/go-redis/redis/v8"
"github.com/ory/dockertest/v3"
"github.com/pomerium/pomerium/pkg/cryptutil"
)
const maxWait = time.Minute
// WithTestRedis creates a test redis instance using docker, invokes handler
// with the instance's connection URL, and removes the container afterwards.
func WithTestRedis(useTLS bool, handler func(rawURL string) error) error {
	ctx, clearTimeout := context.WithTimeout(context.Background(), maxWait)
	defer clearTimeout()

	// uses a sensible default on windows (tcp/http) and linux/osx (socket)
	dockerPool, err := dockertest.NewPool("")
	if err != nil {
		return err
	}

	scheme := "redis"
	runOptions := &dockertest.RunOptions{
		Repository: "redis",
		Tag:        "6",
	}
	if useTLS {
		scheme = "rediss"
		runOptions.Mounts = []string{
			filepath.Join(TestDataRoot(), "tls") + ":/tls",
		}
		runOptions.Cmd = []string{
			"--port", "0",
			"--tls-port", "6379",
			"--tls-cert-file", "/tls/redis.crt",
			"--tls-key-file", "/tls/redis.key",
			"--tls-ca-cert-file", "/tls/ca.crt",
		}
	}

	container, err := dockerPool.RunWithOptions(runOptions)
	if err != nil {
		return err
	}
	// make sure the container goes away even if this process dies
	_ = container.Expire(uint(maxWait.Seconds()))

	rawURL := fmt.Sprintf("%s://%s/0", scheme, container.GetHostPort("6379/tcp"))

	// wait until redis is actually accepting connections
	ping := func() error {
		options, err := redis.ParseURL(rawURL)
		if err != nil {
			return err
		}
		if useTLS {
			options.TLSConfig = RedisTLSConfig()
		}
		client := redis.NewClient(options)
		defer client.Close()
		return client.Ping(ctx).Err()
	}
	if err := dockerPool.Retry(ping); err != nil {
		_ = dockerPool.Purge(container)
		return err
	}

	handlerErr := handler(rawURL)
	if err := dockerPool.Purge(container); err != nil {
		return err
	}
	return handlerErr
}
// RedisTLSConfig returns the TLS Config to use with redis.
//
// Panics if the test certificates cannot be loaded, since this helper is only
// ever used from tests.
func RedisTLSConfig() *tls.Config {
	cert, err := cryptutil.CertificateFromFile(
		filepath.Join(TestDataRoot(), "tls", "redis.crt"),
		filepath.Join(TestDataRoot(), "tls", "redis.key"),
	)
	if err != nil {
		panic(err)
	}
	caCertPool := x509.NewCertPool()
	caCert, err := ioutil.ReadFile(filepath.Join(TestDataRoot(), "tls", "ca.crt"))
	if err != nil {
		panic(err)
	}
	// the result was previously ignored, which silently produced a config with
	// an empty CA pool when ca.crt contained no valid PEM certificates
	if !caCertPool.AppendCertsFromPEM(caCert) {
		panic("redis: no valid certificates found in ca.crt")
	}
	tlsConfig := &tls.Config{
		RootCAs:      caCertPool,
		Certificates: []tls.Certificate{*cert},
		MinVersion:   tls.VersionTLS12,
	}
	return tlsConfig
}

View file

@ -3,6 +3,8 @@ package testutil
import (
"encoding/json"
"os"
"path/filepath"
"reflect"
"testing"
@ -32,3 +34,28 @@ func toProtoJSON(protoMsg interface{}) json.RawMessage {
bs, _ := protojson.Marshal(v2)
return bs
}
// ModRoot returns the directory containing the go.mod file.
func ModRoot() string {
dir, err := os.Getwd()
if err != nil {
panic("error getting working directory")
}
for {
if fi, err := os.Stat(filepath.Join(dir, "go.mod")); err == nil && !fi.IsDir() {
return dir
}
d := filepath.Dir(dir)
if d == dir {
break
}
dir = d
}
return ""
}
// TestDataRoot returns the testdata directory.
//
// The path is resolved relative to the module root so that tests can locate
// fixtures no matter which package directory they run from.
func TestDataRoot() string {
	return filepath.Join(ModRoot(), "internal", "testutil", "testdata")
}

View file

@ -3,6 +3,7 @@ package cryptutil
import (
"crypto/rand"
"encoding/base64"
"encoding/binary"
)
// DefaultKeySize is the default key size in bytes.
@ -29,6 +30,13 @@ func NewRandomStringN(c int) string {
return base64.StdEncoding.EncodeToString(randomBytes(c))
}
// NewRandomUInt64 returns a random uint64.
//
// Panics if source of randomness fails.
func NewRandomUInt64() uint64 {
	b := randomBytes(8)
	return binary.LittleEndian.Uint64(b)
}
// randomBytes generates C number of random bytes suitable for cryptographic
// operations.
//

View file

@ -1,3 +1,4 @@
// Package config contains protobuf definitions for config.
package config
// IsSet returns true if one of the route redirect options has been chosen.

View file

@ -3,10 +3,9 @@ package databroker
import (
"context"
"fmt"
"io"
"strings"
"google.golang.org/protobuf/proto"
)
// GetUserID gets the databroker user id from a provider user id.
@ -37,19 +36,17 @@ func ApplyOffsetAndLimit(all []*Record, offset, limit int) (records []*Record, t
return records, len(all)
}
// InitialSync performs a sync with no_wait set to true and then returns all the results.
func InitialSync(ctx context.Context, client DataBrokerServiceClient, in *SyncRequest) (*SyncResponse, error) {
dup := new(SyncRequest)
proto.Merge(dup, in)
dup.NoWait = true
stream, err := client.Sync(ctx, dup)
// InitialSync performs a sync latest and then returns all the results.
func InitialSync(
ctx context.Context,
client DataBrokerServiceClient,
req *SyncLatestRequest,
) (records []*Record, recordVersion, serverVersion uint64, err error) {
stream, err := client.SyncLatest(ctx, req)
if err != nil {
return nil, err
return nil, 0, 0, err
}
finalRes := &SyncResponse{}
loop:
for {
res, err := stream.Recv()
@ -57,12 +54,19 @@ loop:
case err == io.EOF:
break loop
case err != nil:
return nil, err
return nil, 0, 0, err
}
finalRes.ServerVersion = res.GetServerVersion()
finalRes.Records = append(finalRes.Records, res.GetRecords()...)
switch res := res.GetResponse().(type) {
case *SyncLatestResponse_Versions:
recordVersion = res.Versions.GetLatestRecordVersion()
serverVersion = res.Versions.GetServerVersion()
case *SyncLatestResponse_Record:
records = append(records, res.Record)
default:
panic(fmt.Sprintf("unexpected response: %T", res))
}
}
return finalRes, nil
return records, recordVersion, serverVersion, nil
}

File diff suppressed because it is too large Load diff

View file

@ -4,46 +4,27 @@ package databroker;
option go_package = "github.com/pomerium/pomerium/pkg/grpc/databroker";
import "google/protobuf/any.proto";
import "google/protobuf/empty.proto";
import "google/protobuf/timestamp.proto";
message ServerVersion {
string version = 1;
}
message Record {
string version = 1;
uint64 version = 1;
string type = 2;
string id = 3;
google.protobuf.Any data = 4;
google.protobuf.Timestamp created_at = 5;
google.protobuf.Timestamp modified_at = 6;
google.protobuf.Timestamp deleted_at = 7;
google.protobuf.Timestamp modified_at = 5;
google.protobuf.Timestamp deleted_at = 6;
}
message DeleteRequest {
string type = 1;
string id = 2;
message Versions {
// the server version indicates the version of the server storing the data
uint64 server_version = 1;
uint64 latest_record_version = 2;
}
message GetRequest {
string type = 1;
string id = 2;
}
message GetResponse {
Record record = 1;
}
message GetAllRequest {
string type = 1;
string page_token = 2;
}
message GetAllResponse {
repeated Record records = 1;
string server_version = 2;
string record_version = 3;
string next_page_token = 4;
}
message GetResponse { Record record = 1; }
message QueryRequest {
string type = 1;
@ -56,39 +37,39 @@ message QueryResponse {
int64 total_count = 2;
}
message SetRequest {
string type = 1;
string id = 2;
google.protobuf.Any data = 3;
}
message SetResponse {
Record record = 1;
string server_version = 2;
message PutRequest { Record record = 1; }
message PutResponse {
uint64 server_version = 1;
Record record = 2;
}
message SyncRequest {
string server_version = 1;
string record_version = 2;
string type = 3;
bool no_wait = 4;
uint64 server_version = 1;
uint64 record_version = 2;
}
message SyncResponse {
string server_version = 1;
repeated Record records = 2;
uint64 server_version = 1;
Record record = 2;
}
message GetTypesResponse {
repeated string types = 1;
message SyncLatestRequest { string type = 1; }
message SyncLatestResponse {
oneof response {
Record record = 1;
Versions versions = 2;
}
}
// The DataBrokerService stores key-value data.
service DataBrokerService {
rpc Delete(DeleteRequest) returns (google.protobuf.Empty);
// Get gets a record.
rpc Get(GetRequest) returns (GetResponse);
rpc GetAll(GetAllRequest) returns (GetAllResponse);
// Put saves a record.
rpc Put(PutRequest) returns (PutResponse);
// Query queries for records.
rpc Query(QueryRequest) returns (QueryResponse);
rpc Set(SetRequest) returns (SetResponse);
// Sync streams changes to records after the specified version.
rpc Sync(SyncRequest) returns (stream SyncResponse);
rpc GetTypes(google.protobuf.Empty) returns (GetTypesResponse);
rpc SyncTypes(google.protobuf.Empty) returns (stream GetTypesResponse);
// SyncLatest streams the latest version of every record.
rpc SyncLatest(SyncLatestRequest) returns (stream SyncLatestResponse);
}

View file

@ -61,18 +61,26 @@ func TestInitialSync(t *testing.T) {
r1 := new(Record)
r2 := new(Record)
r3 := new(Record)
m := &mockServer{
sync: func(req *SyncRequest, stream DataBrokerService_SyncServer) error {
assert.Equal(t, true, req.GetNoWait())
stream.Send(&SyncResponse{
ServerVersion: "a",
Records: []*Record{r1, r2},
syncLatest: func(req *SyncLatestRequest, stream DataBrokerService_SyncLatestServer) error {
stream.Send(&SyncLatestResponse{
Response: &SyncLatestResponse_Record{
Record: r1,
},
})
stream.Send(&SyncResponse{
ServerVersion: "b",
Records: []*Record{r3},
stream.Send(&SyncLatestResponse{
Response: &SyncLatestResponse_Record{
Record: r2,
},
})
stream.Send(&SyncLatestResponse{
Response: &SyncLatestResponse_Versions{
Versions: &Versions{
LatestRecordVersion: 2,
ServerVersion: 1,
},
},
})
return nil
},
@ -90,20 +98,19 @@ func TestInitialSync(t *testing.T) {
c := NewDataBrokerServiceClient(cc)
res, err := InitialSync(ctx, c, &SyncRequest{
Type: "TEST",
})
records, recordVersion, serverVersion, err := InitialSync(ctx, c, new(SyncLatestRequest))
assert.NoError(t, err)
assert.Equal(t, "b", res.GetServerVersion())
assert.Equal(t, []*Record{r1, r2, r3}, res.GetRecords())
assert.Equal(t, uint64(2), recordVersion)
assert.Equal(t, uint64(1), serverVersion)
assert.Equal(t, []*Record{r1, r2}, records)
}
type mockServer struct {
DataBrokerServiceServer
sync func(*SyncRequest, DataBrokerService_SyncServer) error
syncLatest func(empty *SyncLatestRequest, server DataBrokerService_SyncLatestServer) error
}
func (m *mockServer) Sync(req *SyncRequest, stream DataBrokerService_SyncServer) error {
return m.sync(req, stream)
func (m *mockServer) SyncLatest(req *SyncLatestRequest, stream DataBrokerService_SyncLatestServer) error {
return m.syncLatest(req, stream)
}

View file

@ -0,0 +1,171 @@
package databroker
import (
"context"
"fmt"
"time"
backoff "github.com/cenkalti/backoff/v4"
"github.com/rs/zerolog"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"github.com/pomerium/pomerium/internal/log"
)
// syncerConfig holds the configurable options of a Syncer.
type syncerConfig struct {
	typeURL string // when non-empty, only records of this type are delivered
}

// A SyncerOption customizes the syncer configuration.
type SyncerOption func(cfg *syncerConfig)
// getSyncerConfig applies the given options to a zero-valued config and
// returns the result.
func getSyncerConfig(options ...SyncerOption) *syncerConfig {
	cfg := &syncerConfig{}
	for _, apply := range options {
		apply(cfg)
	}
	return cfg
}
// WithTypeURL restricts the sync'd results to the given type.
func WithTypeURL(typeURL string) SyncerOption {
	return func(c *syncerConfig) { c.typeURL = typeURL }
}
// A SyncerHandler receives sync events from the Syncer.
type SyncerHandler interface {
	// GetDataBrokerServiceClient returns the databroker client to sync against.
	GetDataBrokerServiceClient() DataBrokerServiceClient
	// ClearRecords is called when all previously delivered records must be dropped.
	ClearRecords(ctx context.Context)
	// UpdateRecords delivers a batch of new or changed records.
	UpdateRecords(ctx context.Context, records []*Record)
}
// A Syncer is a helper type for working with Sync and SyncLatest. It will make a call to
// SyncLatest to retrieve the latest version of the data, then begin syncing with a call
// to Sync. If the server version changes `ClearRecords` will be called and the process
// will start over.
type Syncer struct {
	cfg     *syncerConfig
	handler SyncerHandler
	backoff *backoff.ExponentialBackOff // retry delay between failed attempts

	recordVersion uint64 // version of the last record received
	serverVersion uint64 // 0 forces a full SyncLatest before streaming

	closeCtx       context.Context // canceled when Close is called
	closeCtxCancel func()
}
// NewSyncer creates a new Syncer.
func NewSyncer(handler SyncerHandler, options ...SyncerOption) *Syncer {
	ctx, cancel := context.WithCancel(context.Background())

	// MaxElapsedTime = 0 disables the overall retry time limit
	bo := backoff.NewExponentialBackOff()
	bo.MaxElapsedTime = 0

	return &Syncer{
		cfg:     getSyncerConfig(options...),
		handler: handler,
		backoff: bo,

		closeCtx:       ctx,
		closeCtxCancel: cancel,
	}
}
// Close closes the Syncer.
//
// Any in-progress or future call to Run will stop once Close is called.
func (syncer *Syncer) Close() error {
	syncer.closeCtxCancel()
	return nil
}
// Run runs the Syncer until ctx is canceled or the Syncer is closed.
func (syncer *Syncer) Run(ctx context.Context) error {
	ctx, cancel := context.WithCancel(ctx)
	defer cancel()

	// Watch for Close: it cancels the derived context so in-flight init/sync
	// calls stop. Selecting on ctx.Done() too keeps this goroutine from
	// leaking when the caller's context is canceled but Close is never called.
	go func() {
		select {
		case <-syncer.closeCtx.Done():
			cancel()
		case <-ctx.Done():
		}
	}()

	for {
		var err error
		if syncer.serverVersion == 0 {
			// no known server version yet: fetch the full state via SyncLatest
			err = syncer.init(ctx)
		} else {
			// stream incremental changes after the last seen record version
			err = syncer.sync(ctx)
		}
		if err != nil {
			select {
			case <-ctx.Done():
				return ctx.Err()
			case <-time.After(syncer.backoff.NextBackOff()):
			}
		}
	}
}
// init performs a full SyncLatest, replacing all previously delivered records
// and recording the server and record versions to stream from.
func (syncer *Syncer) init(ctx context.Context) error {
	syncer.log().Info().Msg("syncing latest records")
	client := syncer.handler.GetDataBrokerServiceClient()
	records, recordVersion, serverVersion, err := InitialSync(ctx, client, &SyncLatestRequest{
		Type: syncer.cfg.typeURL,
	})
	if err != nil {
		syncer.log().Error().Err(err).Msg("error during initial sync")
		return err
	}
	syncer.backoff.Reset()

	// reset the records as we have to sync latest
	syncer.handler.ClearRecords(ctx)

	syncer.recordVersion = recordVersion
	syncer.serverVersion = serverVersion
	syncer.handler.UpdateRecords(ctx, records)

	return nil
}
// sync streams incremental record changes from the databroker. It returns nil
// when the server version has changed (signaling that a full re-init is
// required) and a non-nil error on any other failure.
func (syncer *Syncer) sync(ctx context.Context) error {
	stream, err := syncer.handler.GetDataBrokerServiceClient().Sync(ctx, &SyncRequest{
		ServerVersion: syncer.serverVersion,
		RecordVersion: syncer.recordVersion,
	})
	if err != nil {
		syncer.log().Error().Err(err).Msg("error during sync")
		return err
	}

	for {
		res, err := stream.Recv()
		if status.Code(err) == codes.Aborted {
			syncer.log().Error().Err(err).Msg("aborted sync due to mismatched server version")
			// server version changed, so re-init
			syncer.serverVersion = 0
			return nil
		} else if err != nil {
			return err
		}

		if syncer.recordVersion != res.GetRecord().GetVersion()-1 {
			// a gap in record versions means a change was missed, so force a
			// full re-init; the Recv error is nil here, so log the actual
			// failure instead of a nil error
			err := fmt.Errorf("missing record version")
			syncer.log().Error().Err(err).
				Uint64("received", res.GetRecord().GetVersion()).
				Msg("aborted sync due to missing record")
			syncer.serverVersion = 0
			return err
		}
		syncer.recordVersion = res.GetRecord().GetVersion()
		if syncer.cfg.typeURL == "" || syncer.cfg.typeURL == res.GetRecord().GetType() {
			syncer.handler.UpdateRecords(ctx, []*Record{res.GetRecord()})
		}
	}
}
// log returns a logger annotated with this syncer's type and current sync
// position, so messages from concurrent syncers can be told apart.
func (syncer *Syncer) log() *zerolog.Logger {
	l := log.With().Str("service", "syncer").
		Str("type", syncer.cfg.typeURL).
		Uint64("server_version", syncer.serverVersion).
		Uint64("record_version", syncer.recordVersion).Logger()
	return &l
}

View file

@ -0,0 +1,222 @@
package databroker
import (
"context"
"net"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"google.golang.org/grpc/test/bufconn"
"github.com/pomerium/pomerium/internal/testutil"
)
// testSyncerHandler is a SyncerHandler whose behavior is supplied through
// function fields, for use in tests.
type testSyncerHandler struct {
	getDataBrokerServiceClient func() DataBrokerServiceClient
	clearRecords               func(ctx context.Context)
	updateRecords              func(ctx context.Context, records []*Record)
}

// GetDataBrokerServiceClient delegates to the getDataBrokerServiceClient field.
func (t testSyncerHandler) GetDataBrokerServiceClient() DataBrokerServiceClient {
	return t.getDataBrokerServiceClient()
}

// ClearRecords delegates to the clearRecords field.
func (t testSyncerHandler) ClearRecords(ctx context.Context) {
	t.clearRecords(ctx)
}

// UpdateRecords delegates to the updateRecords field.
func (t testSyncerHandler) UpdateRecords(ctx context.Context, records []*Record) {
	t.updateRecords(ctx, records)
}
// testServer is a DataBrokerServiceServer whose Sync and SyncLatest behavior
// is supplied through function fields, for use in tests.
type testServer struct {
	DataBrokerServiceServer
	sync       func(request *SyncRequest, server DataBrokerService_SyncServer) error
	syncLatest func(req *SyncLatestRequest, server DataBrokerService_SyncLatestServer) error
}

// Sync delegates to the sync field.
func (t testServer) Sync(request *SyncRequest, server DataBrokerService_SyncServer) error {
	return t.sync(request, server)
}

// SyncLatest delegates to the syncLatest field.
func (t testServer) SyncLatest(req *SyncLatestRequest, server DataBrokerService_SyncLatestServer) error {
	return t.syncLatest(req, server)
}
// TestSyncer drives the Syncer against a scripted in-process gRPC server
// (via bufconn) and verifies the protocol stages: an initial SyncLatest
// download, falling back to SyncLatest when Sync is aborted or the server
// version changes, and re-downloading when Sync reports a record-version gap.
func TestSyncer(t *testing.T) {
	ctx := context.Background()
	ctx, clearTimeout := context.WithTimeout(ctx, time.Second*10)
	defer clearTimeout()
	ctx, cancel := context.WithCancel(ctx)
	defer cancel()

	lis := bufconn.Listen(1)

	// note: there is deliberately no record with version 1003 ("r4"), so
	// sending r5 right after r3 creates a version gap
	r1 := &Record{Version: 1000, Id: "r1"}
	r2 := &Record{Version: 1001, Id: "r2"}
	r3 := &Record{Version: 1002, Id: "r3"}
	r5 := &Record{Version: 1004, Id: "r5"}

	syncCount := 0
	syncLatestCount := 0

	gs := grpc.NewServer()
	RegisterDataBrokerServiceServer(gs, testServer{
		sync: func(request *SyncRequest, server DataBrokerService_SyncServer) error {
			syncCount++
			switch syncCount {
			case 1:
				// transient failure: the syncer should retry
				return status.Error(codes.Internal, "SOME INTERNAL ERROR")
			case 2:
				// aborted: the syncer should fall back to SyncLatest
				return status.Error(codes.Aborted, "ABORTED")
			case 3:
				// send r3 then r5: the missing version 1003 should force the
				// syncer to abandon Sync and re-download via SyncLatest
				_ = server.Send(&SyncResponse{
					ServerVersion: 2001,
					Record:        r3,
				})
				_ = server.Send(&SyncResponse{
					ServerVersion: 2001,
					Record:        r5,
				})
			case 4:
				select {} // block forever
			default:
				t.Fatal("unexpected call to sync", request)
			}
			return nil
		},
		syncLatest: func(req *SyncLatestRequest, server DataBrokerService_SyncLatestServer) error {
			syncLatestCount++
			switch syncLatestCount {
			case 1:
				// initial download: one record at server version 2000
				_ = server.Send(&SyncLatestResponse{
					Response: &SyncLatestResponse_Record{
						Record: r1,
					},
				})
				_ = server.Send(&SyncLatestResponse{
					Response: &SyncLatestResponse_Versions{
						Versions: &Versions{
							LatestRecordVersion: r1.Version,
							ServerVersion:       2000,
						},
					},
				})
			case 2:
				// new server version (2001): the syncer should clear and
				// re-download
				_ = server.Send(&SyncLatestResponse{
					Response: &SyncLatestResponse_Record{
						Record: r2,
					},
				})
				_ = server.Send(&SyncLatestResponse{
					Response: &SyncLatestResponse_Versions{
						Versions: &Versions{
							LatestRecordVersion: r2.Version,
							ServerVersion:       2001,
						},
					},
				})
			case 3:
				// transient failure: the syncer should retry SyncLatest
				return status.Error(codes.Internal, "SOME INTERNAL ERROR")
			case 4:
				// re-download triggered by the record-version gap in sync case 3
				_ = server.Send(&SyncLatestResponse{
					Response: &SyncLatestResponse_Record{
						Record: r3,
					},
				})
				_ = server.Send(&SyncLatestResponse{
					Response: &SyncLatestResponse_Record{
						Record: r5,
					},
				})
				_ = server.Send(&SyncLatestResponse{
					Response: &SyncLatestResponse_Versions{
						Versions: &Versions{
							LatestRecordVersion: r5.Version,
							ServerVersion:       2001,
						},
					},
				})
			default:
				t.Fatal("unexpected call to sync latest")
			}
			return nil
		},
	})
	go func() { _ = gs.Serve(lis) }()

	// dial the in-process server through the bufconn listener
	gc, err := grpc.DialContext(ctx, "bufnet",
		grpc.WithContextDialer(func(context.Context, string) (net.Conn, error) {
			return lis.Dial()
		}),
		grpc.WithInsecure())
	require.NoError(t, err)
	defer func() { _ = gc.Close() }()

	clearCh := make(chan struct{})
	updateCh := make(chan []*Record)
	syncer := NewSyncer(testSyncerHandler{
		getDataBrokerServiceClient: func() DataBrokerServiceClient {
			return NewDataBrokerServiceClient(gc)
		},
		clearRecords: func(ctx context.Context) {
			clearCh <- struct{}{}
		},
		updateRecords: func(ctx context.Context, records []*Record) {
			updateCh <- records
		},
	})
	go func() { _ = syncer.Run(ctx) }()

	// each select below waits for the next expected handler callback; the
	// shared 10s timeout context fails the test if a stage never happens
	select {
	case <-ctx.Done():
		t.Fatal("1. expected call to clear records")
	case <-clearCh:
	}
	select {
	case <-ctx.Done():
		t.Fatal("2. expected call to update records")
	case records := <-updateCh:
		testutil.AssertProtoJSONEqual(t, `[{"id": "r1", "version": "1000"}]`, records)
	}
	select {
	case <-ctx.Done():
		t.Fatal("3. expected call to clear records due to server version change")
	case <-clearCh:
	}
	select {
	case <-ctx.Done():
		t.Fatal("4. expected call to update records")
	case records := <-updateCh:
		testutil.AssertProtoJSONEqual(t, `[{"id": "r2", "version": "1001"}]`, records)
	}
	select {
	case <-ctx.Done():
		t.Fatal("5. expected call to update records from sync")
	case records := <-updateCh:
		testutil.AssertProtoJSONEqual(t, `[{"id": "r3", "version": "1002"}]`, records)
	}
	select {
	case <-ctx.Done():
		t.Fatal("6. expected call to clear records due to skipped version")
	case <-clearCh:
	}
	select {
	case <-ctx.Done():
		t.Fatal("7. expected call to update records")
	case records := <-updateCh:
		testutil.AssertProtoJSONEqual(t, `[{"id": "r3", "version": "1002"}, {"id": "r5", "version": "1004"}]`, records)
	}
	assert.NoError(t, syncer.Close())
}

View file

@ -17,9 +17,13 @@ import (
// Delete deletes a session from the databroker.
func Delete(ctx context.Context, client databroker.DataBrokerServiceClient, sessionID string) error {
any, _ := ptypes.MarshalAny(new(Session))
_, err := client.Delete(ctx, &databroker.DeleteRequest{
Type: any.GetTypeUrl(),
Id: sessionID,
_, err := client.Put(ctx, &databroker.PutRequest{
Record: &databroker.Record{
Type: any.GetTypeUrl(),
Id: sessionID,
Data: any,
DeletedAt: timestamppb.Now(),
},
})
return err
}
@ -44,13 +48,15 @@ func Get(ctx context.Context, client databroker.DataBrokerServiceClient, session
return &s, nil
}
// Set sets a session in the databroker.
func Set(ctx context.Context, client databroker.DataBrokerServiceClient, s *Session) (*databroker.SetResponse, error) {
// Put sets a session in the databroker.
func Put(ctx context.Context, client databroker.DataBrokerServiceClient, s *Session) (*databroker.PutResponse, error) {
any, _ := anypb.New(s)
res, err := client.Set(ctx, &databroker.SetRequest{
Type: any.GetTypeUrl(),
Id: s.Id,
Data: any,
res, err := client.Put(ctx, &databroker.PutRequest{
Record: &databroker.Record{
Type: any.GetTypeUrl(),
Id: s.Id,
Data: any,
},
})
return res, err
}

View file

@ -32,13 +32,15 @@ func Get(ctx context.Context, client databroker.DataBrokerServiceClient, userID
return &u, nil
}
// Set sets a user in the databroker.
func Set(ctx context.Context, client databroker.DataBrokerServiceClient, u *User) (*databroker.Record, error) {
// Put sets a user in the databroker.
func Put(ctx context.Context, client databroker.DataBrokerServiceClient, u *User) (*databroker.Record, error) {
any, _ := anypb.New(u)
res, err := client.Set(ctx, &databroker.SetRequest{
Type: any.GetTypeUrl(),
Id: u.Id,
Data: any,
res, err := client.Put(ctx, &databroker.PutRequest{
Record: &databroker.Record{
Type: any.GetTypeUrl(),
Id: u.Id,
Data: any,
},
})
if err != nil {
return nil, err
@ -46,13 +48,15 @@ func Set(ctx context.Context, client databroker.DataBrokerServiceClient, u *User
return res.GetRecord(), nil
}
// SetServiceAccount sets a service account in the databroker.
func SetServiceAccount(ctx context.Context, client databroker.DataBrokerServiceClient, sa *ServiceAccount) (*databroker.Record, error) {
// PutServiceAccount sets a service account in the databroker.
func PutServiceAccount(ctx context.Context, client databroker.DataBrokerServiceClient, sa *ServiceAccount) (*databroker.Record, error) {
any, _ := anypb.New(sa)
res, err := client.Set(ctx, &databroker.SetRequest{
Type: any.GetTypeUrl(),
Id: sa.GetId(),
Data: any,
res, err := client.Put(ctx, &databroker.PutRequest{
Record: &databroker.Record{
Type: any.GetTypeUrl(),
Id: sa.GetId(),
Data: any,
},
})
if err != nil {
return nil, err

View file

@ -5,6 +5,8 @@ import (
"context"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/peer"
"google.golang.org/protobuf/proto"
)
// SessionIDMetadataKey is the key in the metadata.
@ -52,3 +54,19 @@ func JWTFromGRPCRequest(ctx context.Context) (rawjwt string, ok bool) {
return rawjwts[0], true
}
// GetPeerAddr returns the peer address from the gRPC request context, or an
// empty string if no peer is attached to the context.
func GetPeerAddr(ctx context.Context) string {
	if p, ok := peer.FromContext(ctx); ok {
		return p.Addr.String()
	}
	return ""
}
// GetTypeURL gets the TypeURL for a protobuf message. The format mirrors
// what the anypb package produces: "type.googleapis.com/" followed by the
// message's fully-qualified name.
func GetTypeURL(msg proto.Message) string {
	const urlPrefix = "type.googleapis.com/"
	name := msg.ProtoReflect().Descriptor().FullName()
	return urlPrefix + string(name)
}

View file

@ -12,9 +12,42 @@ import (
"github.com/pomerium/pomerium/pkg/grpc/databroker"
)
// encryptedRecordStream wraps an underlying RecordStream and transparently
// decrypts each record as it is read. The first decryption failure is
// latched in err and surfaced via Err.
type encryptedRecordStream struct {
	underlying RecordStream
	backend    *encryptedBackend
	err        error
}

// Close closes the underlying stream.
func (e *encryptedRecordStream) Close() error {
	return e.underlying.Close()
}

// Next advances the underlying stream; wait is passed through unchanged.
func (e *encryptedRecordStream) Next(wait bool) bool {
	return e.underlying.Next(wait)
}

// Record returns the current record, decrypted. If decryption fails, the
// error is stored for Err and the (possibly nil) decrypted result is
// returned as-is.
func (e *encryptedRecordStream) Record() *databroker.Record {
	r := e.underlying.Record()
	if r != nil {
		var err error
		r, err = e.backend.decryptRecord(r)
		if err != nil {
			e.err = err
		}
	}
	return r
}

// Err returns the latched decryption error if there is one, otherwise the
// underlying stream's error (which is then latched as well).
func (e *encryptedRecordStream) Err() error {
	if e.err == nil {
		e.err = e.underlying.Err()
	}
	return e.err
}
type encryptedBackend struct {
Backend
cipher cipher.AEAD
underlying Backend
cipher cipher.AEAD
}
// NewEncryptedBackend creates a new encrypted backend.
@ -25,21 +58,17 @@ func NewEncryptedBackend(secret []byte, underlying Backend) (Backend, error) {
}
return &encryptedBackend{
Backend: underlying,
cipher: c,
underlying: underlying,
cipher: c,
}, nil
}
func (e *encryptedBackend) Put(ctx context.Context, id string, data *anypb.Any) error {
encrypted, err := e.encrypt(data)
if err != nil {
return err
}
return e.Backend.Put(ctx, id, encrypted)
func (e *encryptedBackend) Close() error {
return e.underlying.Close()
}
func (e *encryptedBackend) Get(ctx context.Context, id string) (*databroker.Record, error) {
record, err := e.Backend.Get(ctx, id)
func (e *encryptedBackend) Get(ctx context.Context, recordType, id string) (*databroker.Record, error) {
record, err := e.underlying.Get(ctx, recordType, id)
if err != nil {
return nil, err
}
@ -50,18 +79,41 @@ func (e *encryptedBackend) Get(ctx context.Context, id string) (*databroker.Reco
return record, nil
}
func (e *encryptedBackend) List(ctx context.Context, sinceVersion string) ([]*databroker.Record, error) {
records, err := e.Backend.List(ctx, sinceVersion)
func (e *encryptedBackend) GetAll(ctx context.Context) ([]*databroker.Record, uint64, error) {
records, version, err := e.underlying.GetAll(ctx)
if err != nil {
return nil, err
return nil, 0, err
}
for i := range records {
records[i], err = e.decryptRecord(records[i])
if err != nil {
return nil, err
return nil, 0, err
}
}
return records, nil
return records, version, nil
}
func (e *encryptedBackend) Put(ctx context.Context, record *databroker.Record) error {
encrypted, err := e.encrypt(record.GetData())
if err != nil {
return err
}
newRecord := proto.Clone(record).(*databroker.Record)
newRecord.Data = encrypted
return e.underlying.Put(ctx, newRecord)
}
func (e *encryptedBackend) Sync(ctx context.Context, version uint64) (RecordStream, error) {
stream, err := e.underlying.Sync(ctx, version)
if err != nil {
return nil, err
}
return &encryptedRecordStream{
underlying: stream,
backend: e,
}, nil
}
func (e *encryptedBackend) decryptRecord(in *databroker.Record) (out *databroker.Record, err error) {
@ -75,13 +127,16 @@ func (e *encryptedBackend) decryptRecord(in *databroker.Record) (out *databroker
Type: data.TypeUrl,
Id: in.Id,
Data: data,
CreatedAt: in.CreatedAt,
ModifiedAt: in.ModifiedAt,
DeletedAt: in.DeletedAt,
}, nil
}
func (e *encryptedBackend) decrypt(in *anypb.Any) (out *anypb.Any, err error) {
if in == nil {
return nil, nil
}
var encrypted wrapperspb.BytesValue
err = in.UnmarshalTo(&encrypted)
if err != nil {

View file

@ -18,11 +18,11 @@ func TestEncryptedBackend(t *testing.T) {
m := map[string]*anypb.Any{}
backend := &mockBackend{
put: func(ctx context.Context, id string, data *anypb.Any) error {
m[id] = data
put: func(ctx context.Context, record *databroker.Record) error {
m[record.GetId()] = record.GetData()
return nil
},
get: func(ctx context.Context, id string) (*databroker.Record, error) {
get: func(ctx context.Context, recordType, id string) (*databroker.Record, error) {
data, ok := m[id]
if !ok {
return nil, errors.New("not found")
@ -32,7 +32,7 @@ func TestEncryptedBackend(t *testing.T) {
Data: data,
}, nil
},
getAll: func(ctx context.Context) ([]*databroker.Record, error) {
getAll: func(ctx context.Context) ([]*databroker.Record, uint64, error) {
var records []*databroker.Record
for id, data := range m {
records = append(records, &databroker.Record{
@ -40,17 +40,7 @@ func TestEncryptedBackend(t *testing.T) {
Data: data,
})
}
return records, nil
},
list: func(ctx context.Context, sinceVersion string) ([]*databroker.Record, error) {
var records []*databroker.Record
for id, data := range m {
records = append(records, &databroker.Record{
Id: id,
Data: data,
})
}
return records, nil
return records, 0, nil
},
}
@ -61,7 +51,11 @@ func TestEncryptedBackend(t *testing.T) {
any, _ := anypb.New(wrapperspb.String("HELLO WORLD"))
err = e.Put(ctx, "TEST-1", any)
err = e.Put(ctx, &databroker.Record{
Type: "",
Id: "TEST-1",
Data: any,
})
if !assert.NoError(t, err) {
return
}
@ -70,7 +64,7 @@ func TestEncryptedBackend(t *testing.T) {
assert.NotEqual(t, any.Value, m["TEST-1"].Value, "value should be encrypted")
}
record, err := e.Get(ctx, "TEST-1")
record, err := e.Get(ctx, "", "TEST-1")
if !assert.NoError(t, err) {
return
}
@ -78,17 +72,7 @@ func TestEncryptedBackend(t *testing.T) {
assert.Equal(t, any.Value, record.Data.Value, "value should be preserved")
assert.Equal(t, any.TypeUrl, record.Type, "record type should be preserved")
records, err := e.List(ctx, "")
if !assert.NoError(t, err) {
return
}
if assert.Len(t, records, 1) {
assert.Equal(t, any.TypeUrl, records[0].Data.TypeUrl, "type should be preserved")
assert.Equal(t, any.Value, records[0].Data.Value, "value should be preserved")
assert.Equal(t, any.TypeUrl, records[0].Type, "record type should be preserved")
}
records, err = e.List(ctx, "")
records, _, err := e.GetAll(ctx)
if !assert.NoError(t, err) {
return
}

View file

@ -0,0 +1,199 @@
// Package inmemory contains an in-memory implementation of the databroker backend.
package inmemory
import (
"context"
"fmt"
"sync"
"sync/atomic"
"time"
"github.com/google/btree"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/known/timestamppb"
"github.com/pomerium/pomerium/internal/signal"
"github.com/pomerium/pomerium/pkg/grpc/databroker"
"github.com/pomerium/pomerium/pkg/storage"
)
// recordKey uniquely identifies a record by its type and id.
type recordKey struct {
	Type string
	ID   string
}

// recordChange is a change-log entry holding a snapshot of a record.
type recordChange struct {
	record *databroker.Record
}

// Less orders changes by ascending record version so the changes btree can
// be scanned from a given version forward.
func (change recordChange) Less(item btree.Item) bool {
	other, ok := item.(recordChange)
	if !ok {
		return false
	}
	return change.record.GetVersion() < other.record.GetVersion()
}
// A Backend stores data in-memory.
type Backend struct {
	cfg *config
	// onChange is broadcast on every Put; record streams bind to it.
	onChange *signal.Signal
	// lastVersion is the last change version handed out (see nextVersion).
	lastVersion uint64
	closeOnce   sync.Once
	// closed is closed by Close and stops the background expiry goroutine.
	closed chan struct{}
	// mu guards lookup and changes.
	mu sync.RWMutex
	// lookup holds the current (non-deleted) record per type+id.
	lookup map[recordKey]*databroker.Record
	// changes holds recordChange entries ordered by version.
	changes *btree.BTree
}
// New creates a new in-memory backend storage. If the configured expiry is
// non-zero, a background goroutine periodically removes change-log entries
// older than the expiry; it stops when the backend is closed.
func New(options ...Option) *Backend {
	cfg := getConfig(options...)
	backend := &Backend{
		cfg:      cfg,
		onChange: signal.New(),
		closed:   make(chan struct{}),
		lookup:   make(map[recordKey]*databroker.Record),
		changes:  btree.New(cfg.degree),
	}
	if cfg.expiry != 0 {
		go func() {
			ticker := time.NewTicker(time.Second)
			defer ticker.Stop()
			for {
				select {
				case <-backend.closed:
					return
				case <-ticker.C:
					backend.removeChangesBefore(time.Now().Add(-cfg.expiry))
				}
			}
		}()
	}
	return backend
}
// removeChangesBefore deletes change-log entries whose ModifiedAt is before
// the cutoff. Changes are ordered by version, and versions increase with
// modification time, so deleting from the minimum until an entry at or after
// the cutoff is found removes exactly the expired prefix.
func (backend *Backend) removeChangesBefore(cutoff time.Time) {
	backend.mu.Lock()
	defer backend.mu.Unlock()
	for item := backend.changes.Min(); item != nil; item = backend.changes.Min() {
		change, ok := item.(recordChange)
		if !ok {
			panic(fmt.Sprintf("invalid type in changes btree: %T", item))
		}
		if !change.record.GetModifiedAt().AsTime().Before(cutoff) {
			// nothing left to remove
			return
		}
		_ = backend.changes.DeleteMin()
	}
}
// Close closes the in-memory store and erases any stored data. It is safe
// to call multiple times and always returns nil.
func (backend *Backend) Close() error {
	backend.closeOnce.Do(func() {
		// signal the expiry goroutine and any streams to stop
		close(backend.closed)

		// drop all stored state so the memory can be reclaimed
		backend.mu.Lock()
		backend.lookup = map[recordKey]*databroker.Record{}
		backend.changes = btree.New(backend.cfg.degree)
		backend.mu.Unlock()
	})
	return nil
}
// Get gets a record from the in-memory store. It returns
// storage.ErrNotFound when no record exists for the type and id.
func (backend *Backend) Get(_ context.Context, recordType, id string) (*databroker.Record, error) {
	backend.mu.RLock()
	defer backend.mu.RUnlock()

	record, ok := backend.lookup[recordKey{Type: recordType, ID: id}]
	if !ok {
		return nil, storage.ErrNotFound
	}
	// return a copy so callers cannot mutate the stored record
	return dup(record), nil
}
// GetAll gets all the records from the in-memory store, along with the
// latest record version. Records are returned as copies so callers cannot
// mutate the stored data; a nil slice is returned when the store is empty.
func (backend *Backend) GetAll(_ context.Context) ([]*databroker.Record, uint64, error) {
	backend.mu.RLock()
	defer backend.mu.RUnlock()
	var records []*databroker.Record
	if len(backend.lookup) > 0 {
		// pre-size the slice to avoid repeated growth while copying
		records = make([]*databroker.Record, 0, len(backend.lookup))
	}
	for _, record := range backend.lookup {
		records = append(records, dup(record))
	}
	return records, backend.lastVersion, nil
}
// Put puts a record into the in-memory store.
//
// The caller's record is mutated: ModifiedAt is set to now and Version is
// set to the next change version. A record with DeletedAt set removes the
// entry from the lookup table; the change-log entry is still recorded so
// syncing streams observe the deletion.
func (backend *Backend) Put(_ context.Context, record *databroker.Record) error {
	if record == nil {
		return fmt.Errorf("records cannot be nil")
	}

	backend.mu.Lock()
	defer backend.mu.Unlock()
	// NOTE(review): deferred calls run LIFO, so Broadcast executes before the
	// mutex is released — this assumes signal.Broadcast never blocks; confirm.
	defer backend.onChange.Broadcast()

	record.ModifiedAt = timestamppb.Now()
	record.Version = backend.nextVersion()
	// store a snapshot in the change log (copied so later caller mutations
	// don't affect it)
	backend.changes.ReplaceOrInsert(recordChange{record: dup(record)})

	key := recordKey{Type: record.GetType(), ID: record.GetId()}
	if record.GetDeletedAt() != nil {
		delete(backend.lookup, key)
	} else {
		backend.lookup[key] = dup(record)
	}
	return nil
}
// Sync returns a record stream for any changes after version. The error is
// always nil for the in-memory backend.
func (backend *Backend) Sync(ctx context.Context, version uint64) (storage.RecordStream, error) {
	stream := newRecordStream(ctx, backend, version)
	return stream, nil
}
// getSince returns copies of all change-log records with a version strictly
// greater than the given version, in ascending version order.
func (backend *Backend) getSince(version uint64) []*databroker.Record {
	backend.mu.RLock()
	defer backend.mu.RUnlock()

	var records []*databroker.Record
	pivot := recordChange{record: &databroker.Record{Version: version}}
	backend.changes.AscendGreaterOrEqual(pivot, func(item btree.Item) bool {
		change, ok := item.(recordChange)
		if !ok {
			panic(fmt.Sprintf("invalid type in changes btree: %T", item))
		}
		// AscendGreaterOrEqual is inclusive, so drop the pivot version itself
		if record := change.record; record.GetVersion() != version {
			records = append(records, dup(record))
		}
		return true
	})
	return records
}
// nextVersion increments and returns the backend's version counter.
// NOTE(review): callers (Put) already hold backend.mu, so the atomic add is
// presumably belt-and-braces rather than required — confirm before relying
// on lock-free calls.
func (backend *Backend) nextVersion() uint64 {
	return atomic.AddUint64(&backend.lastVersion, 1)
}

// dup returns a deep copy of a record.
func dup(record *databroker.Record) *databroker.Record {
	return proto.Clone(record).(*databroker.Record)
}

View file

@ -0,0 +1,152 @@
package inmemory
import (
"context"
"fmt"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/sync/errgroup"
"google.golang.org/protobuf/types/known/anypb"
"google.golang.org/protobuf/types/known/timestamppb"
"github.com/pomerium/pomerium/pkg/grpc/databroker"
)
// TestBackend covers the basic behavior of the in-memory backend: missing
// records, round-tripping a record, deletion via a DeletedAt put, and
// GetAll's record count and version reporting.
func TestBackend(t *testing.T) {
	ctx := context.Background()
	backend := New()
	defer func() { _ = backend.Close() }()
	t.Run("get missing record", func(t *testing.T) {
		record, err := backend.Get(ctx, "TYPE", "abcd")
		require.Error(t, err)
		assert.Nil(t, record)
	})
	t.Run("get record", func(t *testing.T) {
		data := new(anypb.Any)
		assert.NoError(t, backend.Put(ctx, &databroker.Record{
			Type: "TYPE",
			Id:   "abcd",
			Data: data,
		}))
		record, err := backend.Get(ctx, "TYPE", "abcd")
		require.NoError(t, err)
		if assert.NotNil(t, record) {
			assert.Equal(t, data, record.Data)
			assert.Nil(t, record.DeletedAt)
			assert.Equal(t, "abcd", record.Id)
			assert.NotNil(t, record.ModifiedAt)
			assert.Equal(t, "TYPE", record.Type)
			// the first write to the backend gets version 1
			assert.Equal(t, uint64(1), record.Version)
		}
	})
	t.Run("delete record", func(t *testing.T) {
		// a Put with DeletedAt set removes the record from the lookup
		assert.NoError(t, backend.Put(ctx, &databroker.Record{
			Type:      "TYPE",
			Id:        "abcd",
			DeletedAt: timestamppb.Now(),
		}))
		record, err := backend.Get(ctx, "TYPE", "abcd")
		assert.Error(t, err)
		assert.Nil(t, record)
	})
	t.Run("get all records", func(t *testing.T) {
		for i := 0; i < 1000; i++ {
			assert.NoError(t, backend.Put(ctx, &databroker.Record{
				Type: "TYPE",
				Id:   fmt.Sprint(i),
			}))
		}
		records, version, err := backend.GetAll(ctx)
		assert.NoError(t, err)
		assert.Len(t, records, 1000)
		// 1000 puts here plus 2 from the earlier subtests
		assert.Equal(t, uint64(1002), version)
	})
}
// TestExpiry verifies that removeChangesBefore drops change-log entries
// older than the cutoff. WithExpiry(0) disables the automatic cleanup
// goroutine so the test can trigger cleanup manually.
func TestExpiry(t *testing.T) {
	ctx := context.Background()
	backend := New(WithExpiry(0))
	defer func() { _ = backend.Close() }()

	for i := 0; i < 1000; i++ {
		assert.NoError(t, backend.Put(ctx, &databroker.Record{
			Type: "TYPE",
			Id:   fmt.Sprint(i),
		}))
	}
	// all 1000 changes should be visible through a sync stream
	stream, err := backend.Sync(ctx, 0)
	require.NoError(t, err)
	var records []*databroker.Record
	for stream.Next(false) {
		records = append(records, stream.Record())
	}
	_ = stream.Close()
	require.Len(t, records, 1000)

	// a cutoff in the future expires every change
	backend.removeChangesBefore(time.Now().Add(time.Second))

	stream, err = backend.Sync(ctx, 0)
	require.NoError(t, err)
	records = nil
	for stream.Next(false) {
		records = append(records, stream.Record())
	}
	_ = stream.Close()
	require.Len(t, records, 0)
}
// TestConcurrency runs GetAll and Put concurrently to catch data races or
// deadlocks in the backend's locking (run with -race).
func TestConcurrency(t *testing.T) {
	ctx := context.Background()
	backend := New()
	// close the backend so its expiry goroutine does not outlive the test
	// (matches the other tests in this file, which also defer Close)
	defer func() { _ = backend.Close() }()

	eg, ctx := errgroup.WithContext(ctx)
	eg.Go(func() error {
		for i := 0; i < 1000; i++ {
			_, _, _ = backend.GetAll(ctx)
		}
		return nil
	})
	eg.Go(func() error {
		for i := 0; i < 1000; i++ {
			_ = backend.Put(ctx, &databroker.Record{
				Id: fmt.Sprint(i),
			})
		}
		return nil
	})
	assert.NoError(t, eg.Wait())
}
// TestStream verifies that a sync stream opened before any writes receives
// every record in order with sequential versions, while the writes happen
// concurrently in another goroutine.
func TestStream(t *testing.T) {
	ctx := context.Background()
	backend := New()
	defer func() { _ = backend.Close() }()

	stream, err := backend.Sync(ctx, 0)
	require.NoError(t, err)
	defer func() { _ = stream.Close() }()

	eg, ctx := errgroup.WithContext(ctx)
	eg.Go(func() error {
		// reader: Next(true) blocks until the writer produces the next record
		for i := 0; i < 10000; i++ {
			assert.True(t, stream.Next(true))
			assert.Equal(t, "TYPE", stream.Record().GetType())
			assert.Equal(t, fmt.Sprint(i), stream.Record().GetId())
			// versions are assigned sequentially starting at 1
			assert.Equal(t, uint64(i+1), stream.Record().GetVersion())
		}
		return nil
	})
	eg.Go(func() error {
		for i := 0; i < 10000; i++ {
			assert.NoError(t, backend.Put(ctx, &databroker.Record{
				Type: "TYPE",
				Id:   fmt.Sprint(i),
			}))
		}
		return nil
	})
	require.NoError(t, eg.Wait())
}

View file

@ -0,0 +1,36 @@
package inmemory
import "time"
type config struct {
degree int
expiry time.Duration
}
// An Option customizes the in-memory backend.
type Option func(cfg *config)
func getConfig(options ...Option) *config {
cfg := &config{
degree: 16,
expiry: time.Hour,
}
for _, option := range options {
option(cfg)
}
return cfg
}
// WithBTreeDegree sets the btree degree of the changes btree.
func WithBTreeDegree(degree int) Option {
return func(cfg *config) {
cfg.degree = degree
}
}
// WithExpiry sets the expiry for changes.
func WithExpiry(expiry time.Duration) Option {
return func(cfg *config) {
cfg.expiry = expiry
}
}

View file

@ -1,200 +0,0 @@
// Package inmemory is the in-memory database using b-trees.
package inmemory
import (
"context"
"errors"
"fmt"
"sync"
"sync/atomic"
"time"
"github.com/golang/protobuf/ptypes"
"github.com/google/btree"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/known/anypb"
"github.com/pomerium/pomerium/config"
"github.com/pomerium/pomerium/internal/signal"
"github.com/pomerium/pomerium/pkg/grpc/databroker"
"github.com/pomerium/pomerium/pkg/storage"
)
// Name is the storage type name for inmemory backend.
const Name = config.StorageInMemoryName
var _ storage.Backend = (*DB)(nil)
// byIDRecord orders records by id for the byID btree.
type byIDRecord struct {
	*databroker.Record
}

func (k byIDRecord) Less(than btree.Item) bool {
	k2, _ := than.(byIDRecord)
	return k.GetId() < k2.GetId()
}

// byVersionRecord orders records by version string for the byVersion btree.
// Versions are zero-padded hex strings (see replaceOrInsert), so
// lexicographic order matches numeric order.
type byVersionRecord struct {
	*databroker.Record
}

func (k byVersionRecord) Less(than btree.Item) bool {
	k2, _ := than.(byVersionRecord)
	return k.GetVersion() < k2.GetVersion()
}
// DB is an in-memory database of records using b-trees.
type DB struct {
	// recordType is stamped onto every record stored in this DB.
	recordType string
	// lastVersion increases monotonically; accessed atomically.
	lastVersion uint64

	// mu guards byID, byVersion, and deletedIDs.
	mu sync.RWMutex
	// byID holds records ordered by id.
	byID *btree.BTree
	// byVersion holds records ordered by version string.
	byVersion *btree.BTree
	// deletedIDs are ids tombstoned by Delete, reaped by ClearDeleted.
	deletedIDs []string
	// onchange is broadcast on every mutation; Watch binds to it.
	onchange *signal.Signal

	closeOnce sync.Once
	closed    chan struct{}
}

// NewDB creates a new in-memory database for the given record type.
func NewDB(recordType string, btreeDegree int) *DB {
	s := signal.New()
	return &DB{
		recordType: recordType,
		byID:       btree.New(btreeDegree),
		byVersion:  btree.New(btreeDegree),
		onchange:   s,
		closed:     make(chan struct{}),
	}
}
// ClearDeleted clears all the currently deleted records older than the given cutoff.
func (db *DB) ClearDeleted(_ context.Context, cutoff time.Time) {
	db.mu.Lock()
	defer db.mu.Unlock()

	var remaining []string
	for _, id := range db.deletedIDs {
		// NOTE(review): if the id is missing from byID, the failed type
		// assertion yields a zero byIDRecord, whose nil DeletedAt reads as
		// the zero time and is treated as expired — presumably deletedIDs
		// only holds ids still present in byID; confirm.
		record, _ := db.byID.Get(byIDRecord{Record: &databroker.Record{Id: id}}).(byIDRecord)
		ts := record.GetDeletedAt().AsTime()
		if ts.Before(cutoff) {
			// permanently remove from both indexes
			db.byID.Delete(record)
			db.byVersion.Delete(byVersionRecord(record))
		} else {
			// not old enough yet; keep the tombstone for a later pass
			remaining = append(remaining, id)
		}
	}
	db.deletedIDs = remaining
}
// Close closes the database. Any watchers will be closed.
func (db *DB) Close() error {
	db.closeOnce.Do(func() {
		close(db.closed)
	})
	return nil
}

// Delete marks a record as deleted (a soft delete: the record stays in the
// indexes with DeletedAt set until ClearDeleted reaps it).
func (db *DB) Delete(_ context.Context, id string) error {
	db.mu.Lock()
	defer db.mu.Unlock()
	// deferred calls run LIFO, so Broadcast fires before the lock is released
	defer db.onchange.Broadcast()
	db.replaceOrInsert(id, func(record *databroker.Record) {
		record.DeletedAt = ptypes.TimestampNow()
		// remember the id so ClearDeleted can reap it later
		db.deletedIDs = append(db.deletedIDs, id)
	})
	return nil
}
// Get gets a record from the db.
func (db *DB) Get(_ context.Context, id string) (*databroker.Record, error) {
	db.mu.RLock()
	defer db.mu.RUnlock()
	record, ok := db.byID.Get(byIDRecord{Record: &databroker.Record{Id: id}}).(byIDRecord)
	if !ok {
		return nil, errors.New("not found")
	}
	// NOTE(review): the stored record is returned without cloning, so the
	// caller shares memory with the store.
	return record.Record, nil
}

// GetAll gets all the records in the db, in ascending id order.
func (db *DB) GetAll(_ context.Context) ([]*databroker.Record, error) {
	db.mu.RLock()
	defer db.mu.RUnlock()
	var records []*databroker.Record
	db.byID.Ascend(func(item btree.Item) bool {
		records = append(records, item.(byIDRecord).Record)
		return true
	})
	return records, nil
}
// List lists all the changes since the given version, in ascending version
// order. Versions are zero-padded hex strings, so lexicographic btree order
// matches numeric order.
func (db *DB) List(_ context.Context, sinceVersion string) ([]*databroker.Record, error) {
	db.mu.RLock()
	defer db.mu.RUnlock()
	var records []*databroker.Record
	db.byVersion.AscendGreaterOrEqual(byVersionRecord{Record: &databroker.Record{Version: sinceVersion}}, func(i btree.Item) bool {
		record := i.(byVersionRecord)
		// strictly greater: skip the record at sinceVersion itself
		if record.Version > sinceVersion {
			records = append(records, record.Record)
		}
		return true
	})
	return records, nil
}

// Put replaces or inserts a record in the db.
func (db *DB) Put(_ context.Context, id string, data *anypb.Any) error {
	db.mu.Lock()
	defer db.mu.Unlock()
	// notify watchers of the change (runs before the lock is released)
	defer db.onchange.Broadcast()
	db.replaceOrInsert(id, func(record *databroker.Record) {
		record.Data = data
	})
	return nil
}
// Watch returns the underlying signal.Signal binding channel to the caller.
// Then the caller can listen to the channel for detecting changes. The
// channel is closed and unbound when either the db is closed or the
// caller's context is canceled.
func (db *DB) Watch(ctx context.Context) <-chan struct{} {
	ch := db.onchange.Bind()
	go func() {
		select {
		case <-db.closed:
		case <-ctx.Done():
		}
		close(ch)
		db.onchange.Unbind(ch)
	}()
	return ch
}

// replaceOrInsert looks up the record with the given id (creating it if
// missing), applies f to a private copy, stamps the bookkeeping fields
// (timestamps, type, id, next version), and reinserts it into both btrees.
// Callers must hold db.mu.
func (db *DB) replaceOrInsert(id string, f func(record *databroker.Record)) {
	record, ok := db.byID.Get(byIDRecord{Record: &databroker.Record{Id: id}}).(byIDRecord)
	if ok {
		// remove the stale version entry; the id entry is replaced below
		db.byVersion.Delete(byVersionRecord(record))
		// clone before mutating so existing readers keep a stable view
		record.Record = proto.Clone(record.Record).(*databroker.Record)
	} else {
		record.Record = new(databroker.Record)
	}
	f(record.Record)
	if record.CreatedAt == nil {
		record.CreatedAt = ptypes.TimestampNow()
	}
	record.ModifiedAt = ptypes.TimestampNow()
	record.Type = db.recordType
	record.Id = id
	// zero-padded hex keeps lexicographic order aligned with numeric order
	record.Version = fmt.Sprintf("%012X", atomic.AddUint64(&db.lastVersion, 1))
	db.byID.ReplaceOrInsert(record)
	db.byVersion.ReplaceOrInsert(byVersionRecord(record))
}

View file

@ -1,108 +0,0 @@
package inmemory
import (
"context"
"fmt"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/sync/errgroup"
"google.golang.org/protobuf/types/known/anypb"
)
// TestDB covers the lifecycle of the legacy in-memory DB: get/put, soft
// delete, reaping deleted records, listing by version, and double delete.
func TestDB(t *testing.T) {
	ctx := context.Background()
	db := NewDB("example", 2)
	t.Run("get missing record", func(t *testing.T) {
		record, err := db.Get(ctx, "abcd")
		require.Error(t, err)
		assert.Nil(t, record)
	})
	t.Run("get record", func(t *testing.T) {
		data := new(anypb.Any)
		assert.NoError(t, db.Put(ctx, "abcd", data))
		record, err := db.Get(ctx, "abcd")
		require.NoError(t, err)
		if assert.NotNil(t, record) {
			assert.NotNil(t, record.CreatedAt)
			assert.Equal(t, data, record.Data)
			assert.Nil(t, record.DeletedAt)
			assert.Equal(t, "abcd", record.Id)
			assert.NotNil(t, record.ModifiedAt)
			assert.Equal(t, "example", record.Type)
			// first write gets the zero-padded hex version 1
			assert.Equal(t, "000000000001", record.Version)
		}
	})
	t.Run("delete record", func(t *testing.T) {
		// Delete is a soft delete: the record remains with DeletedAt set
		assert.NoError(t, db.Delete(ctx, "abcd"))
		record, err := db.Get(ctx, "abcd")
		require.NoError(t, err)
		if assert.NotNil(t, record) {
			assert.NotNil(t, record.DeletedAt)
		}
	})
	t.Run("clear deleted", func(t *testing.T) {
		// a future cutoff reaps the soft-deleted record for good
		db.ClearDeleted(ctx, time.Now().Add(time.Second))
		record, err := db.Get(ctx, "abcd")
		require.Error(t, err)
		assert.Nil(t, record)
	})
	t.Run("keep remaining", func(t *testing.T) {
		data := new(anypb.Any)
		assert.NoError(t, db.Put(ctx, "abcd", data))
		assert.NoError(t, db.Delete(ctx, "abcd"))
		// a cutoff in the past keeps the tombstoned record
		db.ClearDeleted(ctx, time.Now().Add(-10*time.Second))
		record, err := db.Get(ctx, "abcd")
		require.NoError(t, err)
		assert.NotNil(t, record)
		db.ClearDeleted(ctx, time.Now().Add(time.Second))
	})
	t.Run("list", func(t *testing.T) {
		for i := 0; i < 10; i++ {
			data := new(anypb.Any)
			assert.NoError(t, db.Put(ctx, fmt.Sprintf("%02d", i), data))
		}
		records, err := db.List(ctx, "")
		require.NoError(t, err)
		assert.Len(t, records, 10)
		// versions are strictly-greater-than the since argument
		records, err = db.List(ctx, "00000000000A")
		require.NoError(t, err)
		assert.Len(t, records, 4)
		records, err = db.List(ctx, "00000000000F")
		require.NoError(t, err)
		assert.Len(t, records, 0)
	})
	t.Run("delete twice", func(t *testing.T) {
		data := new(anypb.Any)
		assert.NoError(t, db.Put(ctx, "abcd", data))
		assert.NoError(t, db.Delete(ctx, "abcd"))
		assert.NoError(t, db.Delete(ctx, "abcd"))
		db.ClearDeleted(ctx, time.Now().Add(time.Second))
		record, err := db.Get(ctx, "abcd")
		require.Error(t, err)
		assert.Nil(t, record)
	})
}
// TestConcurrency runs List and Put concurrently to catch data races in the
// DB's locking (run with -race).
// NOTE(review): the errors from db.Put and eg.Wait are silently discarded.
func TestConcurrency(t *testing.T) {
	ctx := context.Background()
	db := NewDB("example", 2)

	eg, ctx := errgroup.WithContext(ctx)
	eg.Go(func() error {
		for i := 0; i < 1000; i++ {
			_, _ = db.List(ctx, "")
		}
		return nil
	})
	eg.Go(func() error {
		for i := 0; i < 1000; i++ {
			db.Put(ctx, fmt.Sprint(i), new(anypb.Any))
		}
		return nil
	})
	eg.Wait()
}

View file

@ -0,0 +1,101 @@
package inmemory
import (
"context"
"sync"
"github.com/pomerium/pomerium/pkg/grpc/databroker"
"github.com/pomerium/pomerium/pkg/storage"
)
// recordStream implements storage.RecordStream over the in-memory backend's
// change log. It pulls batches of changes after its current version and
// wakes up on the backend's onChange signal when it has to wait for more.
type recordStream struct {
	ctx     context.Context
	backend *Backend

	// changed is bound to backend.onChange and fires on every Put.
	changed chan struct{}
	// ready holds records fetched from the backend but not yet consumed.
	ready []*databroker.Record
	// version is the version of the newest record seen so far.
	version uint64

	closeOnce sync.Once
	closed    chan struct{}
}

// newRecordStream creates a stream yielding all changes with a version
// greater than the given version.
func newRecordStream(ctx context.Context, backend *Backend, version uint64) *recordStream {
	stream := &recordStream{
		ctx:     ctx,
		backend: backend,
		changed: backend.onChange.Bind(),
		version: version,
		closed:  make(chan struct{}),
	}
	return stream
}
// fill replaces the ready queue with any records newer than the stream's
// version and advances the version to the newest record retrieved.
func (stream *recordStream) fill() {
	ready := stream.backend.getSince(stream.version)
	if n := len(ready); n > 0 {
		// records are sorted by version, so the last one is the newest
		stream.version = ready[n-1].GetVersion()
	}
	stream.ready = ready
}
// Close unbinds the stream from the backend's change signal and marks the
// stream closed, waking any blocked Next. Safe to call multiple times;
// always returns nil.
func (stream *recordStream) Close() error {
	stream.closeOnce.Do(func() {
		stream.backend.onChange.Unbind(stream.changed)
		close(stream.closed)
	})
	return nil
}
// Next advances the stream and reports whether a record is available via
// Record. The first call positions the stream at the first record; later
// calls pop the previously returned one. With wait=true, Next blocks until
// a new record arrives, the stream is closed, or the context is canceled;
// with wait=false it returns false immediately when nothing is pending.
func (stream *recordStream) Next(wait bool) bool {
	// drop the record returned by the previous call
	if len(stream.ready) > 0 {
		stream.ready = stream.ready[1:]
	}
	if len(stream.ready) > 0 {
		return true
	}

	for {
		stream.fill()
		if len(stream.ready) > 0 {
			return true
		}
		if wait {
			select {
			case <-stream.ctx.Done():
				return false
			case <-stream.closed:
				return false
			case <-stream.changed:
				// query for records again
			}
		} else {
			return false
		}
	}
}
// Record returns the current record in the stream, or nil when Next has not
// produced one.
func (stream *recordStream) Record() *databroker.Record {
	if len(stream.ready) == 0 {
		return nil
	}
	return stream.ready[0]
}
// Err reports why the stream stopped: the context's error when the context
// is done, ErrStreamClosed when the stream or the backend has been closed,
// or nil while the stream is still live.
// NOTE(review): when several conditions hold at once, select picks a ready
// case at random, so the specific error returned is not deterministic.
func (stream *recordStream) Err() error {
	select {
	case <-stream.ctx.Done():
		return stream.ctx.Err()
	case <-stream.closed:
		return storage.ErrStreamClosed
	case <-stream.backend.closed:
		return storage.ErrStreamClosed
	default:
		return nil
	}
}

View file

@ -0,0 +1,31 @@
package redis
import (
"context"
"time"
"github.com/go-redis/redis/v8"
pomeriumconfig "github.com/pomerium/pomerium/config"
"github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/telemetry/metrics"
)
// logger adapts the pomerium structured logger to the go-redis internal
// logging interface.
type logger struct {
}

// Printf logs a redis client message at info level with a service field.
// NOTE(review): the ctx parameter is ignored, so log data attached to the
// context is dropped — confirm this is intentional.
func (l logger) Printf(ctx context.Context, format string, v ...interface{}) {
	log.Info().Str("service", "redis").Msgf(format, v...)
}

// init routes go-redis internal logging through the pomerium logger.
func init() {
	redis.SetLogger(logger{})
}
// recordOperation emits a storage-operation metric for the redis backend,
// tagging the operation name and any error, with the elapsed time measured
// from startTime.
func recordOperation(ctx context.Context, startTime time.Time, operation string, err error) {
	tags := &metrics.StorageOperationTags{
		Operation: operation,
		Error:     err,
		Backend:   pomeriumconfig.StorageRedisName,
	}
	metrics.RecordStorageOperation(ctx, tags, time.Since(startTime))
}

View file

@ -2,32 +2,34 @@ package redis
import (
"crypto/tls"
"time"
)
type dbConfig struct {
tls *tls.Config
recordType string
type config struct {
tls *tls.Config
expiry time.Duration
}
// Option customizes a DB.
type Option func(*dbConfig)
// Option customizes a Backend.
type Option func(*config)
// WithRecordType sets the record type in the config. The backend is scoped to
// a single record type, which is used to namespace its redis keys.
func WithRecordType(recordType string) Option {
	return func(cfg *dbConfig) {
		cfg.recordType = recordType
	}
}
// WithTLSConfig sets the tls.Config which DB uses.
// WithTLSConfig sets the tls.Config which Backend uses.
func WithTLSConfig(tlsConfig *tls.Config) Option {
return func(cfg *dbConfig) {
return func(cfg *config) {
cfg.tls = tlsConfig
}
}
func getConfig(options ...Option) *dbConfig {
cfg := new(dbConfig)
// WithExpiry sets how long change entries are retained before being removed
// by the backend's periodic cleanup.
func WithExpiry(expiry time.Duration) Option {
	return func(cfg *config) { cfg.expiry = expiry }
}
func getConfig(options ...Option) *config {
cfg := new(config)
WithExpiry(time.Hour * 24)(cfg)
for _, o := range options {
o(cfg)
}

View file

@ -5,29 +5,30 @@ import (
"context"
"errors"
"fmt"
"strconv"
"sync"
"time"
"github.com/cenkalti/backoff/v4"
redis "github.com/go-redis/redis/v8"
"github.com/golang/protobuf/proto"
"google.golang.org/protobuf/types/known/anypb"
"google.golang.org/protobuf/types/known/timestamppb"
"github.com/pomerium/pomerium/config"
"github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/signal"
"github.com/pomerium/pomerium/internal/telemetry/metrics"
"github.com/pomerium/pomerium/internal/telemetry/trace"
"github.com/pomerium/pomerium/pkg/grpc/databroker"
"github.com/pomerium/pomerium/pkg/storage"
)
// Name of the storage backend.
const Name = config.StorageRedisName
const (
maxTransactionRetries = 100
watchPollInterval = 30 * time.Second
lastVersionKey = "pomerium.last_version"
lastVersionChKey = "pomerium.last_version_ch"
recordHashKey = "pomerium.records"
changesSetKey = "pomerium.changes"
)
// custom errors
@ -35,21 +36,24 @@ var (
ErrExceededMaxRetries = errors.New("redis: transaction reached maximum number of retries")
)
// DB implements the storage.Backend on top of redis.
type DB struct {
cfg *dbConfig
// Backend implements the storage.Backend on top of redis.
type Backend struct {
cfg *config
client *redis.Client
client *redis.Client
onChange *signal.Signal
closeOnce sync.Once
closed chan struct{}
}
// New creates a new redis storage backend.
func New(rawURL string, options ...Option) (*DB, error) {
db := &DB{
cfg: getConfig(options...),
closed: make(chan struct{}),
func New(rawURL string, options ...Option) (*Backend, error) {
cfg := getConfig(options...)
backend := &Backend{
cfg: cfg,
closed: make(chan struct{}),
onChange: signal.New(),
}
opts, err := redis.ParseURL(rawURL)
if err != nil {
@ -57,194 +61,150 @@ func New(rawURL string, options ...Option) (*DB, error) {
}
// when using TLS, the TLS config will not be set to nil, in which case we replace it with our own
if opts.TLSConfig != nil {
opts.TLSConfig = db.cfg.tls
opts.TLSConfig = backend.cfg.tls
}
db.client = redis.NewClient(opts)
metrics.AddRedisMetrics(db.client.PoolStats)
return db, nil
}
backend.client = redis.NewClient(opts)
metrics.AddRedisMetrics(backend.client.PoolStats)
go backend.listenForVersionChanges()
if cfg.expiry != 0 {
go func() {
ticker := time.NewTicker(time.Minute)
defer ticker.Stop()
for {
select {
case <-backend.closed:
return
case <-ticker.C:
}
// ClearDeleted clears all the deleted records older than the cutoff time.
func (db *DB) ClearDeleted(ctx context.Context, cutoff time.Time) {
var err error
_, span := trace.StartSpan(ctx, "databroker.redis.ClearDeleted")
defer span.End()
defer func(start time.Time) { recordOperation(ctx, start, "clear_deleted", err) }(time.Now())
ids, _ := db.client.SMembers(ctx, formatDeletedSetKey(db.cfg.recordType)).Result()
records, _ := redisGetRecords(ctx, db.client, db.cfg.recordType, ids)
_, err = db.client.Pipelined(ctx, func(p redis.Pipeliner) error {
for _, record := range records {
if record.GetDeletedAt().AsTime().Before(cutoff) {
p.HDel(ctx, formatRecordsKey(db.cfg.recordType), record.GetId())
p.ZRem(ctx, formatVersionSetKey(db.cfg.recordType), record.GetId())
p.SRem(ctx, formatDeletedSetKey(db.cfg.recordType), record.GetId())
backend.removeChangesBefore(time.Now().Add(-cfg.expiry))
}
}
return nil
})
}()
}
return backend, nil
}
// Close closes the underlying redis connection and any watchers.
func (db *DB) Close() error {
func (backend *Backend) Close() error {
var err error
db.closeOnce.Do(func() {
err = db.client.Close()
close(db.closed)
backend.closeOnce.Do(func() {
err = backend.client.Close()
close(backend.closed)
})
return err
}
// Delete marks a record as deleted.
func (db *DB) Delete(ctx context.Context, id string) (err error) {
_, span := trace.StartSpan(ctx, "databroker.redis.Delete")
defer span.End()
defer func(start time.Time) { recordOperation(ctx, start, "delete", err) }(time.Now())
var record *databroker.Record
err = db.incrementVersion(ctx,
func(tx *redis.Tx, version int64) error {
var err error
record, err = redisGetRecord(ctx, tx, db.cfg.recordType, id)
if errors.Is(err, redis.Nil) {
// nothing to do, as the record doesn't exist
return nil
} else if err != nil {
return err
}
// mark it as deleted
record.DeletedAt = timestamppb.Now()
record.Version = formatVersion(version)
return nil
},
func(p redis.Pipeliner, version int64) error {
err := redisSetRecord(ctx, p, db.cfg.recordType, record)
if err != nil {
return err
}
// add it to the collection of deleted entries
p.SAdd(ctx, formatDeletedSetKey(db.cfg.recordType), record.GetId())
return nil
})
return err
}
// Get gets a record.
func (db *DB) Get(ctx context.Context, id string) (record *databroker.Record, err error) {
// Get gets a record from redis.
func (backend *Backend) Get(ctx context.Context, recordType, id string) (_ *databroker.Record, err error) {
_, span := trace.StartSpan(ctx, "databroker.redis.Get")
defer span.End()
defer func(start time.Time) { recordOperation(ctx, start, "get", err) }(time.Now())
record, err = redisGetRecord(ctx, db.client, db.cfg.recordType, id)
return record, err
}
key, field := getHashKey(recordType, id)
cmd := backend.client.HGet(ctx, key, field)
raw, err := cmd.Result()
if err == redis.Nil {
return nil, storage.ErrNotFound
} else if err != nil {
return nil, err
}
// List lists all the records changed since the sinceVersion. Records are sorted in version order.
func (db *DB) List(ctx context.Context, sinceVersion string) (records []*databroker.Record, err error) {
_, span := trace.StartSpan(ctx, "databroker.redis.List")
defer span.End()
defer func(start time.Time) { recordOperation(ctx, start, "list", err) }(time.Now())
var ids []string
ids, err = redisListIDsSince(ctx, db.client, db.cfg.recordType, sinceVersion)
var record databroker.Record
err = proto.Unmarshal([]byte(raw), &record)
if err != nil {
return nil, err
}
records, err = redisGetRecords(ctx, db.client, db.cfg.recordType, ids)
return records, err
return &record, nil
}
// Put updates a record.
func (db *DB) Put(ctx context.Context, id string, data *anypb.Any) (err error) {
// GetAll gets all the records from redis along with the latest record version.
// When no record has ever been written the version is 0 and the record slice
// is empty.
func (backend *Backend) GetAll(ctx context.Context) (records []*databroker.Record, latestRecordVersion uint64, err error) {
	_, span := trace.StartSpan(ctx, "databroker.redis.GetAll")
	defer span.End()
	defer func(start time.Time) { recordOperation(ctx, start, "getall", err) }(time.Now())

	p := backend.client.Pipeline()
	lastVersionCmd := p.Get(ctx, lastVersionKey)
	resultsCmd := p.HVals(ctx, recordHashKey)
	// Pipeline Exec returns the first command error, which includes redis.Nil
	// when the last-version key has never been set. Treat redis.Nil as "no
	// data yet" so the per-command handling below can run.
	_, err = p.Exec(ctx)
	if err != nil && !errors.Is(err, redis.Nil) {
		return nil, 0, err
	}

	latestRecordVersion, err = lastVersionCmd.Uint64()
	if errors.Is(err, redis.Nil) {
		latestRecordVersion = 0
	} else if err != nil {
		return nil, 0, err
	}

	results, err := resultsCmd.Result()
	if err != nil && !errors.Is(err, redis.Nil) {
		return nil, 0, err
	}

	for _, result := range results {
		var record databroker.Record
		// skip (but log) any value that does not decode as a record
		if err := proto.Unmarshal([]byte(result), &record); err != nil {
			log.Warn().Err(err).Msg("redis: invalid record detected")
			continue
		}
		records = append(records, &record)
	}
	return records, latestRecordVersion, nil
}
// Put puts a record into redis.
func (backend *Backend) Put(ctx context.Context, record *databroker.Record) (err error) {
_, span := trace.StartSpan(ctx, "databroker.redis.Put")
defer span.End()
defer func(start time.Time) { recordOperation(ctx, start, "put", err) }(time.Now())
var record *databroker.Record
err = db.incrementVersion(ctx,
func(tx *redis.Tx, version int64) error {
var err error
record, err = redisGetRecord(ctx, db.client, db.cfg.recordType, id)
if errors.Is(err, redis.Nil) {
record = new(databroker.Record)
record.CreatedAt = timestamppb.Now()
} else if err != nil {
return backend.incrementVersion(ctx,
func(tx *redis.Tx, version uint64) error {
record.ModifiedAt = timestamppb.Now()
record.Version = version
return nil
},
func(p redis.Pipeliner, version uint64) error {
bs, err := proto.Marshal(record)
if err != nil {
return err
}
record.ModifiedAt = timestamppb.Now()
record.Type = db.cfg.recordType
record.Id = id
record.Data = data
record.Version = formatVersion(version)
key, field := getHashKey(record.GetType(), record.GetId())
if record.DeletedAt != nil {
p.HDel(ctx, key, field)
} else {
p.HSet(ctx, key, field, bs)
}
p.ZAdd(ctx, changesSetKey, &redis.Z{
Score: float64(version),
Member: bs,
})
return nil
},
func(p redis.Pipeliner, version int64) error {
return redisSetRecord(ctx, p, db.cfg.recordType, record)
})
return err
}
// Watch returns a channel that is signaled any time the last version is incremented (ie on Put/Delete).
func (db *DB) Watch(ctx context.Context) <-chan struct{} {
s := signal.New()
ch := s.Bind()
go func() {
defer s.Unbind(ch)
defer close(ch)
// force a check
poll := time.NewTicker(watchPollInterval)
defer poll.Stop()
// use pub/sub for quicker notify
pubsub := db.client.Subscribe(ctx, formatLastVersionChannelKey(db.cfg.recordType))
defer func() { _ = pubsub.Close() }()
pubsubCh := pubsub.Channel()
var lastVersion int64
for {
v, err := redisGetLastVersion(ctx, db.client, db.cfg.recordType)
if err != nil {
log.Error().Err(err).Msg("redis: error retrieving last version")
} else if v != lastVersion {
// don't broadcast the first time
if lastVersion != 0 {
s.Broadcast()
}
lastVersion = v
}
select {
case <-ctx.Done():
return
case <-db.closed:
return
case <-poll.C:
case <-pubsubCh:
// re-check
}
}
}()
return ch
// Sync returns a record stream of any records changed after the specified version.
// The stream reads from the changes sorted-set, polling periodically and also
// waking on the backend's change signal. It never returns a non-nil error.
func (backend *Backend) Sync(ctx context.Context, version uint64) (storage.RecordStream, error) {
	return newRecordStream(ctx, backend, version), nil
}
// incrementVersion increments the last version key, runs the code in `query`, then attempts to commit the code in
// `commit`. If the last version changes in the interim, we will retry the transaction.
func (db *DB) incrementVersion(ctx context.Context,
query func(tx *redis.Tx, version int64) error,
commit func(p redis.Pipeliner, version int64) error,
func (backend *Backend) incrementVersion(ctx context.Context,
query func(tx *redis.Tx, version uint64) error,
commit func(p redis.Pipeliner, version uint64) error,
) error {
// code is modeled on https://pkg.go.dev/github.com/go-redis/redis/v8#example-Client.Watch
txf := func(tx *redis.Tx) error {
version, err := redisGetLastVersion(ctx, tx, db.cfg.recordType)
if err != nil {
version, err := tx.Get(ctx, lastVersionKey).Uint64()
if errors.Is(err, redis.Nil) {
version = 0
} else if err != nil {
return err
}
version++
@ -260,16 +220,23 @@ func (db *DB) incrementVersion(ctx context.Context,
if err != nil {
return err
}
p.Set(ctx, formatLastVersionKey(db.cfg.recordType), version, 0)
p.Publish(ctx, formatLastVersionChannelKey(db.cfg.recordType), version)
p.Set(ctx, lastVersionKey, version, 0)
p.Publish(ctx, lastVersionChKey, version)
return nil
})
return err
}
bo := backoff.NewExponentialBackOff()
bo.MaxElapsedTime = 0
for i := 0; i < maxTransactionRetries; i++ {
err := db.client.Watch(ctx, txf, formatLastVersionKey(db.cfg.recordType))
err := backend.client.Watch(ctx, txf, lastVersionKey)
if errors.Is(err, redis.TxFailedErr) {
select {
case <-ctx.Done():
return ctx.Err()
case <-time.After(bo.NextBackOff()):
}
continue // retry
} else if err != nil {
return err
@ -281,121 +248,81 @@ func (db *DB) incrementVersion(ctx context.Context,
return ErrExceededMaxRetries
}
func redisGetLastVersion(ctx context.Context, c redis.Cmdable, recordType string) (int64, error) {
version, err := c.Get(ctx, formatLastVersionKey(recordType)).Int64()
if errors.Is(err, redis.Nil) {
version = 0
} else if err != nil {
return 0, err
func (backend *Backend) listenForVersionChanges() {
ctx, cancel := context.WithCancel(context.Background())
go func() {
<-backend.closed
cancel()
}()
bo := backoff.NewExponentialBackOff()
bo.MaxElapsedTime = 0
outer:
for {
pubsub := backend.client.Subscribe(ctx, lastVersionChKey)
for {
msg, err := pubsub.Receive(ctx)
if err != nil {
_ = pubsub.Close()
select {
case <-ctx.Done():
return
case <-time.After(bo.NextBackOff()):
}
continue outer
}
bo.Reset()
switch msg.(type) {
case *redis.Message:
backend.onChange.Broadcast()
}
}
}
return version, nil
}
// redisGetRecord fetches a single record by id, returning redis.Nil when the
// record does not exist.
func redisGetRecord(ctx context.Context, c redis.Cmdable, recordType string, id string) (*databroker.Record, error) {
	records, err := redisGetRecords(ctx, c, recordType, []string{id})
	switch {
	case err != nil:
		return nil, err
	case len(records) == 0:
		// mirror the go-redis convention for a missing key
		return nil, redis.Nil
	default:
		return records[0], nil
	}
}
func redisGetRecords(ctx context.Context, c redis.Cmdable, recordType string, ids []string) ([]*databroker.Record, error) {
if len(ids) == 0 {
return nil, nil
}
results, err := c.HMGet(ctx, formatRecordsKey(recordType), ids...).Result()
if err != nil {
return nil, err
}
records := make([]*databroker.Record, 0, len(results))
for _, result := range results {
// results are returned as either nil or a string
if result == nil {
continue
}
rawstr, ok := result.(string)
if !ok {
continue
}
var record databroker.Record
err := proto.Unmarshal([]byte(rawstr), &record)
func (backend *Backend) removeChangesBefore(cutoff time.Time) {
ctx := context.Background()
for {
cmd := backend.client.ZRangeByScore(ctx, changesSetKey, &redis.ZRangeBy{
Min: "-inf",
Max: "+inf",
Offset: 0,
Count: 1,
})
results, err := cmd.Result()
if err != nil {
continue
log.Error().Err(err).Msg("redis: error retrieving changes for expiration")
return
}
// nothing left to do
if len(results) == 0 {
return
}
var record databroker.Record
err = proto.Unmarshal([]byte(results[0]), &record)
if err != nil {
log.Warn().Err(err).Msg("redis: invalid record detected")
record.ModifiedAt = timestamppb.New(cutoff.Add(-time.Second)) // set the modified so will delete it
}
// if the record's modified timestamp is after the cutoff, we're all done, so break
if record.GetModifiedAt().AsTime().After(cutoff) {
break
}
// remove the record
err = backend.client.ZRem(ctx, changesSetKey, results[0]).Err()
if err != nil {
log.Error().Err(err).Msg("redis: error removing member")
return
}
records = append(records, &record)
}
return records, nil
}
// redisListIDsSince returns the ids of records whose version is strictly
// greater than sinceVersion (a hex-encoded version string; unparseable input
// is treated as version 0), in ascending version order.
func redisListIDsSince(ctx context.Context,
	c redis.Cmdable, recordType string,
	sinceVersion string,
) ([]string, error) {
	version, err := strconv.ParseInt(sinceVersion, 16, 64)
	if err != nil {
		version = 0
	}
	// "(" makes the lower bound exclusive
	return c.ZRangeByScore(ctx, formatVersionSetKey(recordType), &redis.ZRangeBy{
		Min: fmt.Sprintf("(%d", version),
		Max: "+inf",
	}).Result()
}
// redisSetRecord queues pipeline commands that store the record: the
// serialized payload goes into the per-type hash keyed by id, and the id is
// added to the per-type sorted set scored by version so ranged listing works.
// An unparseable version string is treated as version 0.
func redisSetRecord(ctx context.Context, p redis.Pipeliner, recordType string, record *databroker.Record) error {
	version, err := strconv.ParseInt(record.GetVersion(), 16, 64)
	if err != nil {
		version = 0
	}

	raw, err := proto.Marshal(record)
	if err != nil {
		return err
	}

	// payload by id
	p.HSet(ctx, formatRecordsKey(recordType), record.GetId(), string(raw))
	// id ordered by version
	p.ZAdd(ctx, formatVersionSetKey(recordType), &redis.Z{
		Score:  float64(version),
		Member: record.GetId(),
	})
	return nil
}
// formatDeletedSetKey returns the redis set key tracking deleted record ids
// for the given record type.
func formatDeletedSetKey(recordType string) string {
	return fmt.Sprint(recordType, "_deleted_set")
}
// formatLastVersionChannelKey returns the pub/sub channel name used to notify
// about version changes for the given record type.
func formatLastVersionChannelKey(recordType string) string {
	return fmt.Sprint(recordType, "_last_version_ch")
}
// formatLastVersionKey returns the redis key holding the last version counter
// for the given record type.
func formatLastVersionKey(recordType string) string {
	return fmt.Sprint(recordType, "_last_version")
}
// formatRecordsKey returns the redis hash key holding records of the given
// type; the record type itself is used as the key.
func formatRecordsKey(recordType string) string {
	return recordType
}
// formatVersion renders a version as a zero-padded 12-digit decimal string so
// that lexicographic ordering of the strings matches numeric ordering.
func formatVersion(version int64) string {
	return fmt.Sprintf("%012d", version)
}
// formatVersionSetKey returns the redis sorted-set key that orders record ids
// by version for the given record type.
func formatVersionSetKey(recordType string) string {
	return fmt.Sprint(recordType, "_version_set")
}
func recordOperation(ctx context.Context, startTime time.Time, operation string, err error) {
metrics.RecordStorageOperation(ctx, &metrics.StorageOperationTags{
Operation: operation,
Error: err,
Backend: Name,
}, time.Since(startTime))
// getHashKey returns the redis hash key and field used to store a record.
// All records share one hash; the field encodes "{type}/{id}".
func getHashKey(recordType, id string) (key, field string) {
	return recordHashKey, recordType + "/" + id
}

View file

@ -2,192 +2,172 @@ package redis
import (
"context"
"crypto/tls"
"crypto/x509"
"fmt"
"io/ioutil"
"os"
"runtime"
"strings"
"testing"
"time"
"github.com/ory/dockertest/v3"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/protobuf/types/known/anypb"
"google.golang.org/protobuf/types/known/timestamppb"
"github.com/pomerium/pomerium/pkg/cryptutil"
"github.com/pomerium/pomerium/pkg/grpc/directory"
"github.com/pomerium/pomerium/internal/testutil"
"github.com/pomerium/pomerium/pkg/grpc/databroker"
)
var db *DB
// cleanup flushes every key in the test redis instance so each test case
// starts from an empty database.
func cleanup(ctx context.Context, db *DB, t *testing.T) {
	require.NoError(t, db.client.FlushAll(ctx).Err())
}
// tlsConfig builds a client tls.Config for tests that connect to a
// TLS-enabled redis (a "rediss" URL); it returns nil for plain connections.
// The client certificate and CA bundle are loaded from ./testdata/tls.
func tlsConfig(rawURL string, t *testing.T) *tls.Config {
	if !strings.HasPrefix(rawURL, "rediss") {
		return nil
	}
	cert, err := cryptutil.CertificateFromFile("./testdata/tls/redis.crt", "./testdata/tls/redis.key")
	require.NoError(t, err)
	caCertPool := x509.NewCertPool()
	caCert, err := ioutil.ReadFile("./testdata/tls/ca.crt")
	require.NoError(t, err)
	caCertPool.AppendCertsFromPEM(caCert)
	tlsConfig := &tls.Config{
		RootCAs:      caCertPool,
		Certificates: []tls.Certificate{*cert},
	}
	return tlsConfig
}
// runWithRedisDockerImage starts a redis container via dockertest, waits until
// the package-level db handle can connect to it, runs testFunc, and purges
// the container afterwards.
func runWithRedisDockerImage(t *testing.T, runOpts *dockertest.RunOptions, withTLS bool, testFunc func(t *testing.T)) {
	pool, err := dockertest.NewPool("")
	if err != nil {
		t.Fatalf("Could not connect to docker: %s", err)
	}
	resource, err := pool.RunWithOptions(runOpts)
	if err != nil {
		t.Fatalf("Could not start resource: %s", err)
	}
	// hard-kill the container after 30s in case the test hangs
	// NOTE(review): the error return of Expire is ignored here.
	resource.Expire(30)
	defer func() {
		if err := pool.Purge(resource); err != nil {
			t.Fatalf("Could not purge resource: %s", err)
		}
	}()
	scheme := "redis"
	if withTLS {
		scheme = "rediss"
	}
	address := fmt.Sprintf(scheme+"://localhost:%s/0", resource.GetPort("6379/tcp"))
	// retry until the container accepts connections; this also assigns the
	// shared db used by the test functions
	if err := pool.Retry(func() error {
		var err error
		db, err = New(address, WithRecordType("record_type"), WithTLSConfig(tlsConfig(address, t)))
		if err != nil {
			return err
		}
		err = db.client.Ping(context.Background()).Err()
		return err
	}); err != nil {
		t.Fatalf("Could not connect to docker: %s", err)
	}
	testFunc(t)
}
func TestDB(t *testing.T) {
func TestBackend(t *testing.T) {
if os.Getenv("GITHUB_ACTION") != "" && runtime.GOOS == "darwin" {
t.Skip("Github action can not run docker on MacOS")
}
cwd, err := os.Getwd()
assert.NoError(t, err)
tlsCmd := []string{
"--port", "0",
"--tls-port", "6379",
"--tls-cert-file", "/tls/redis.crt",
"--tls-key-file", "/tls/redis.key",
"--tls-ca-cert-file", "/tls/ca.crt",
}
tests := []struct {
name string
withTLS bool
runOpts *dockertest.RunOptions
}{
{"redis", false, &dockertest.RunOptions{Repository: "redis", Tag: "latest"}},
{"redis TLS", true, &dockertest.RunOptions{Repository: "redis", Tag: "latest", Cmd: tlsCmd, Mounts: []string{cwd + "/testdata/tls:/tls"}}},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
runWithRedisDockerImage(t, tc.runOpts, tc.withTLS, testDB)
})
for _, useTLS := range []bool{true, false} {
require.NoError(t, testutil.WithTestRedis(useTLS, func(rawURL string) error {
ctx := context.Background()
var opts []Option
if useTLS {
opts = append(opts, WithTLSConfig(testutil.RedisTLSConfig()))
}
backend, err := New(rawURL, opts...)
require.NoError(t, err)
defer func() { _ = backend.Close() }()
t.Run("get missing record", func(t *testing.T) {
record, err := backend.Get(ctx, "TYPE", "abcd")
require.Error(t, err)
assert.Nil(t, record)
})
t.Run("get record", func(t *testing.T) {
data := new(anypb.Any)
assert.NoError(t, backend.Put(ctx, &databroker.Record{
Type: "TYPE",
Id: "abcd",
Data: data,
}))
record, err := backend.Get(ctx, "TYPE", "abcd")
require.NoError(t, err)
if assert.NotNil(t, record) {
assert.Equal(t, data, record.Data)
assert.Nil(t, record.DeletedAt)
assert.Equal(t, "abcd", record.Id)
assert.NotNil(t, record.ModifiedAt)
assert.Equal(t, "TYPE", record.Type)
assert.Equal(t, uint64(1), record.Version)
}
})
t.Run("delete record", func(t *testing.T) {
assert.NoError(t, backend.Put(ctx, &databroker.Record{
Type: "TYPE",
Id: "abcd",
DeletedAt: timestamppb.Now(),
}))
record, err := backend.Get(ctx, "TYPE", "abcd")
assert.Error(t, err)
assert.Nil(t, record)
})
t.Run("get all records", func(t *testing.T) {
for i := 0; i < 1000; i++ {
assert.NoError(t, backend.Put(ctx, &databroker.Record{
Type: "TYPE",
Id: fmt.Sprint(i),
}))
}
records, version, err := backend.GetAll(ctx)
assert.NoError(t, err)
assert.Len(t, records, 1000)
assert.Equal(t, uint64(1002), version)
})
return nil
}))
}
}
func testDB(t *testing.T) {
ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
users := []*directory.User{
{Id: "u1", GroupIds: []string{"test", "admin"}},
{Id: "u2"},
{Id: "u3", GroupIds: []string{"test"}},
func TestChangeSignal(t *testing.T) {
if os.Getenv("GITHUB_ACTION") != "" && runtime.GOOS == "darwin" {
t.Skip("Github action can not run docker on MacOS")
}
ids := []string{"a", "b", "c"}
id := ids[0]
t.Run("get missing record", func(t *testing.T) {
record, err := db.Get(ctx, id)
assert.Error(t, err)
assert.Nil(t, record)
})
t.Run("get record", func(t *testing.T) {
data, _ := anypb.New(users[0])
assert.NoError(t, db.Put(ctx, id, data))
record, err := db.Get(ctx, id)
ctx := context.Background()
ctx, clearTimeout := context.WithTimeout(ctx, time.Second*10)
defer clearTimeout()
ctx, cancel := context.WithCancel(ctx)
defer cancel()
require.NoError(t, testutil.WithTestRedis(false, func(rawURL string) error {
backend1, err := New(rawURL)
require.NoError(t, err)
if assert.NotNil(t, record) {
assert.NotNil(t, record.CreatedAt)
assert.NotEmpty(t, record.Data)
assert.Nil(t, record.DeletedAt)
assert.Equal(t, "a", record.Id)
assert.NotNil(t, record.ModifiedAt)
assert.Equal(t, "000000000001", record.Version)
}
})
t.Run("delete record", func(t *testing.T) {
original, err := db.Get(ctx, id)
defer func() { _ = backend1.Close() }()
backend2, err := New(rawURL)
require.NoError(t, err)
assert.NoError(t, db.Delete(ctx, id))
record, err := db.Get(ctx, id)
require.NoError(t, err)
require.NotNil(t, record)
assert.NotNil(t, record.DeletedAt)
assert.NotEqual(t, original.GetVersion(), record.GetVersion())
})
t.Run("clear deleted", func(t *testing.T) {
db.ClearDeleted(ctx, time.Now().Add(time.Second))
record, err := db.Get(ctx, id)
assert.Error(t, err)
assert.Nil(t, record)
})
t.Run("list", func(t *testing.T) {
cleanup(ctx, db, t)
defer func() { _ = backend2.Close() }()
for i := 0; i < 10; i++ {
id := fmt.Sprintf("%02d", i)
data := new(anypb.Any)
assert.NoError(t, db.Put(ctx, id, data))
}
ch := backend1.onChange.Bind()
defer backend1.onChange.Unbind(ch)
records, err := db.List(ctx, "")
assert.NoError(t, err)
assert.Len(t, records, 10)
records, err = db.List(ctx, "000000000005")
assert.NoError(t, err)
assert.Len(t, records, 5)
records, err = db.List(ctx, "000000000010")
assert.NoError(t, err)
assert.Len(t, records, 0)
})
t.Run("watch", func(t *testing.T) {
ch := db.Watch(ctx)
time.Sleep(time.Second)
go db.Put(ctx, "WATCH", new(anypb.Any))
go func() {
ticker := time.NewTicker(time.Millisecond * 100)
defer ticker.Stop()
for {
_ = backend2.Put(ctx, &databroker.Record{
Type: "TYPE",
Id: "ID",
})
select {
case <-ctx.Done():
return
case <-ticker.C:
}
}
}()
select {
case <-ch:
case <-time.After(time.Second * 10):
t.Error("expected watch signal on put")
case <-ctx.Done():
t.Fatal("expected signal to be fired when another backend triggers a change")
}
})
return nil
}))
}
// TestExpiry verifies that removeChangesBefore purges entries from the
// changes set: after writing 1000 records a sync from version 0 sees all of
// them, and after removing changes older than a future cutoff a fresh sync
// sees none.
func TestExpiry(t *testing.T) {
	if os.Getenv("GITHUB_ACTION") != "" && runtime.GOOS == "darwin" {
		t.Skip("Github action can not run docker on MacOS")
	}
	ctx := context.Background()
	require.NoError(t, testutil.WithTestRedis(false, func(rawURL string) error {
		// expiry 0 disables the background cleanup goroutine so the test can
		// invoke removeChangesBefore directly
		backend, err := New(rawURL, WithExpiry(0))
		require.NoError(t, err)
		defer func() { _ = backend.Close() }()
		for i := 0; i < 1000; i++ {
			assert.NoError(t, backend.Put(ctx, &databroker.Record{
				Type: "TYPE",
				Id:   fmt.Sprint(i),
			}))
		}
		stream, err := backend.Sync(ctx, 0)
		require.NoError(t, err)
		var records []*databroker.Record
		// non-blocking drain of the stream
		for stream.Next(false) {
			records = append(records, stream.Record())
		}
		_ = stream.Close()
		require.Len(t, records, 1000)
		// cutoff is in the future, so every change qualifies for removal
		backend.removeChangesBefore(time.Now().Add(time.Second))
		stream, err = backend.Sync(ctx, 0)
		require.NoError(t, err)
		records = nil
		for stream.Next(false) {
			records = append(records, stream.Record())
		}
		_ = stream.Close()
		require.Len(t, records, 0)
		return nil
	}))
}

105
pkg/storage/redis/stream.go Normal file
View file

@ -0,0 +1,105 @@
package redis
import (
"context"
"fmt"
"sync"
"time"
"github.com/go-redis/redis/v8"
"github.com/golang/protobuf/proto"
"github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/pkg/grpc/databroker"
)
// recordStream streams records out of the redis changes sorted-set in
// ascending version order.
type recordStream struct {
	ctx     context.Context
	backend *Backend
	changed chan struct{} // signaled by the backend when the last version changes
	version uint64        // version of the last record returned; queries start after it
	record  *databroker.Record // most recent record read by Next
	err     error              // sticky error; once set, Next always returns false
	closeOnce sync.Once
	closed    chan struct{}
}
// newRecordStream creates a stream over records with versions greater than
// the given version. It binds to the backend's change signal so blocking
// reads wake promptly on writes; Close must be called to unbind.
func newRecordStream(ctx context.Context, backend *Backend, version uint64) *recordStream {
	return &recordStream{
		ctx:     ctx,
		backend: backend,
		changed: backend.onChange.Bind(),
		version: version,
		closed:  make(chan struct{}),
	}
}
// Close stops the stream: it unbinds the change listener from the backend and
// closes the closed channel so a blocked Next returns. closeOnce makes it
// safe to call multiple times. It always returns nil.
func (stream *recordStream) Close() error {
	stream.closeOnce.Do(func() {
		stream.backend.onChange.Unbind(stream.changed)
		close(stream.closed)
	})
	return nil
}
// Next advances the stream to the next changed record. If block is true it
// waits — on the poll ticker or the backend's change signal — until a record
// with a version greater than the stream's current version appears; otherwise
// it returns false immediately when nothing newer is available. It returns
// false once the context is done, the stream is closed, or a redis error
// occurs (retrievable via Err).
func (stream *recordStream) Next(block bool) bool {
	// a previous error is sticky: the stream is dead
	if stream.err != nil {
		return false
	}

	ticker := time.NewTicker(watchPollInterval)
	defer ticker.Stop()

	for {
		// fetch the single oldest change strictly after the current version
		cmd := stream.backend.client.ZRangeByScore(stream.ctx, changesSetKey, &redis.ZRangeBy{
			Min:    fmt.Sprintf("(%d", stream.version),
			Max:    "+inf",
			Offset: 0,
			Count:  1,
		})
		results, err := cmd.Result()
		if err != nil {
			stream.err = err
			return false
		}

		if len(results) > 0 {
			result := results[0]

			var record databroker.Record
			err = proto.Unmarshal([]byte(result), &record)
			if err != nil {
				log.Warn().Err(err).Msg("redis: invalid record detected")
				stream.version++
			} else {
				stream.record = &record
				// advance to the record's own version when it is ahead of the
				// local counter: versions in the change set can have gaps
				// (e.g. after expired changes are removed), and bumping by
				// one at a time would re-read the same member forever
				if v := record.GetVersion(); v > stream.version {
					stream.version = v
				} else {
					stream.version++
				}
			}
			return true
		}

		if !block {
			return false
		}
		select {
		case <-stream.ctx.Done():
			stream.err = stream.ctx.Err()
			return false
		case <-stream.closed:
			return false
		case <-ticker.C: // poll again
		case <-stream.changed: // a change was signaled; query again
		}
	}
}
// Record returns the record most recently read by Next, or nil if Next has
// not yet returned a record.
func (stream *recordStream) Record() *databroker.Record {
	return stream.record
}

// Err returns the first error encountered while streaming, if any.
func (stream *recordStream) Err() error {
	return stream.err
}

View file

@ -3,8 +3,8 @@ package storage
import (
"context"
"errors"
"strings"
"time"
"google.golang.org/protobuf/reflect/protoreflect"
"google.golang.org/protobuf/types/known/anypb"
@ -13,30 +13,39 @@ import (
"github.com/pomerium/pomerium/pkg/grpc/databroker"
)
// Errors returned by storage backends and record streams.
var (
	// ErrNotFound is returned by Get when no record exists for the given id.
	ErrNotFound = errors.New("record not found")
	// ErrStreamClosed is returned by a RecordStream after it has been closed.
	ErrStreamClosed = errors.New("record stream closed")
)
// A RecordStream is a stream of records. Implementations in this package
// yield records in ascending version order.
type RecordStream interface {
	// Close closes the record stream and releases any underlying resources.
	Close() error
	// Next is called to retrieve the next record. If one is available it will
	// be returned immediately. If none is available and block is true, the method
	// will block until one is available or an error occurs. The error should be
	// checked with a call to `.Err()`.
	Next(block bool) bool
	// Record returns the current record.
	Record() *databroker.Record
	// Err returns any error that occurred while streaming.
	Err() error
}
// Backend is the interface required for a storage backend.
type Backend interface {
// Close closes the backend.
Close() error
// Put is used to insert or update a record.
Put(ctx context.Context, id string, data *anypb.Any) error
// Get is used to retrieve a record.
Get(ctx context.Context, id string) (*databroker.Record, error)
// List is used to retrieve all the records since a version.
List(ctx context.Context, sinceVersion string) ([]*databroker.Record, error)
// Delete is used to mark a record as deleted.
Delete(ctx context.Context, id string) error
// ClearDeleted is used clear marked delete records.
ClearDeleted(ctx context.Context, cutoff time.Time)
// Watch returns a channel to the caller. The channel is used to notify
// about changes that happen in storage. When ctx is finished, Watch will close
// the channel.
Watch(ctx context.Context) <-chan struct{}
Get(ctx context.Context, recordType, id string) (*databroker.Record, error)
// GetAll gets all the records.
GetAll(ctx context.Context) (records []*databroker.Record, version uint64, err error)
// Put is used to insert or update a record.
Put(ctx context.Context, record *databroker.Record) error
// Sync syncs record changes after the specified version.
Sync(ctx context.Context, version uint64) (RecordStream, error)
}
// MatchAny searches any data with a query.

View file

@ -3,7 +3,6 @@ package storage
import (
"context"
"testing"
"time"
"github.com/stretchr/testify/assert"
"google.golang.org/protobuf/types/known/anypb"
@ -13,50 +12,29 @@ import (
)
type mockBackend struct {
put func(ctx context.Context, id string, data *anypb.Any) error
get func(ctx context.Context, id string) (*databroker.Record, error)
getAll func(ctx context.Context) ([]*databroker.Record, error)
list func(ctx context.Context, sinceVersion string) ([]*databroker.Record, error)
delete func(ctx context.Context, id string) error
clearDeleted func(ctx context.Context, cutoff time.Time)
query func(ctx context.Context, query string, offset, limit int) ([]*databroker.Record, int, error)
watch func(ctx context.Context) <-chan struct{}
put func(ctx context.Context, record *databroker.Record) error
get func(ctx context.Context, recordType, id string) (*databroker.Record, error)
getAll func(ctx context.Context) ([]*databroker.Record, uint64, error)
}
// Close implements the backend interface; the mock has nothing to release.
func (m *mockBackend) Close() error {
	return nil
}
func (m *mockBackend) Put(ctx context.Context, id string, data *anypb.Any) error {
return m.put(ctx, id, data)
func (m *mockBackend) Put(ctx context.Context, record *databroker.Record) error {
return m.put(ctx, record)
}
func (m *mockBackend) Get(ctx context.Context, id string) (*databroker.Record, error) {
return m.get(ctx, id)
func (m *mockBackend) Get(ctx context.Context, recordType, id string) (*databroker.Record, error) {
return m.get(ctx, recordType, id)
}
func (m *mockBackend) GetAll(ctx context.Context) ([]*databroker.Record, error) {
func (m *mockBackend) GetAll(ctx context.Context) ([]*databroker.Record, uint64, error) {
return m.getAll(ctx)
}
// List delegates to the test-configured list func.
func (m *mockBackend) List(ctx context.Context, sinceVersion string) ([]*databroker.Record, error) {
	return m.list(ctx, sinceVersion)
}

// Delete delegates to the test-configured delete func.
func (m *mockBackend) Delete(ctx context.Context, id string) error {
	return m.delete(ctx, id)
}

// ClearDeleted delegates to the test-configured clearDeleted func.
func (m *mockBackend) ClearDeleted(ctx context.Context, cutoff time.Time) {
	m.clearDeleted(ctx, cutoff)
}

// Query delegates to the test-configured query func.
func (m *mockBackend) Query(ctx context.Context, query string, offset, limit int) ([]*databroker.Record, int, error) {
	return m.query(ctx, query, offset, limit)
}
func (m *mockBackend) Watch(ctx context.Context) <-chan struct{} {
return m.watch(ctx)
func (m *mockBackend) Sync(ctx context.Context, version uint64) (RecordStream, error) {
panic("implement me")
}
func TestMatchAny(t *testing.T) {