mirror of
https://github.com/pomerium/pomerium.git
synced 2025-05-17 11:07:18 +02:00
databroker: refactor databroker to sync all changes (#1879)
* refactor backend, implement encrypted store * refactor in-memory store * wip * wip * wip * add syncer test * fix redis expiry * fix linting issues * fix test by skipping non-config records * fix backoff import * fix init issues * fix query * wait for initial sync before starting directory sync * add type to SyncLatest * add more log messages, fix deadlock in in-memory store, always return server version from SyncLatest * update sync types and tests * add redis tests * skip macos in github actions * add comments to proto * split getBackend into separate methods * handle errors in initVersion * return different error for not found vs other errors in get * use exponential backoff for redis transaction retry * rename raw to result * use context instead of close channel * store type urls as constants in databroker * use timestampb instead of ptypes * fix group merging not waiting * change locked names * update GetAll to return latest record version * add method to grpcutil to get the type url for a protobuf type
This commit is contained in:
parent
b1871b0f2e
commit
5d60cff21e
66 changed files with 2762 additions and 2871 deletions
|
@ -98,6 +98,7 @@ linters:
|
|||
# - wsl
|
||||
|
||||
issues:
|
||||
exclude-use-default: false
|
||||
# List of regexps of issue texts to exclude, empty list by default.
|
||||
# But independently from this option we use default exclude patterns,
|
||||
# it can be disabled by `exclude-use-default: false`. To list all
|
||||
|
@ -140,6 +141,10 @@ issues:
|
|||
# good job Protobuffs!
|
||||
- "method XXX"
|
||||
- "SA1019"
|
||||
# EXC0001 errcheck: Almost all programs ignore errors on these functions and in most cases it's ok
|
||||
- Error return value of .((os\.)?std(out|err)\..*|.*Close|.*Flush|os\.Remove(All)?|.*print(f|ln)?|os\.(Un)?Setenv). is not checked
|
||||
|
||||
|
||||
|
||||
exclude-rules:
|
||||
# https://github.com/go-critic/go-critic/issues/926
|
||||
|
@ -179,6 +184,8 @@ issues:
|
|||
text: "Potential hardcoded credentials"
|
||||
linters:
|
||||
- gosec
|
||||
- linters: [golint]
|
||||
text: "should have a package comment"
|
||||
|
||||
# golangci.com configuration
|
||||
# https://github.com/golangci/golangci/wiki/Configuration
|
||||
|
|
|
@ -526,17 +526,17 @@ func (a *Authenticate) saveSessionToDataBroker(
|
|||
if err != nil {
|
||||
return fmt.Errorf("authenticate: error retrieving user info: %w", err)
|
||||
}
|
||||
_, err = user.Set(ctx, state.dataBrokerClient, mu.User)
|
||||
_, err = user.Put(ctx, state.dataBrokerClient, mu.User)
|
||||
if err != nil {
|
||||
return fmt.Errorf("authenticate: error saving user: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
res, err := session.Set(ctx, state.dataBrokerClient, s)
|
||||
res, err := session.Put(ctx, state.dataBrokerClient, s)
|
||||
if err != nil {
|
||||
return fmt.Errorf("authenticate: error saving session: %w", err)
|
||||
}
|
||||
sessionState.Version = sessions.Version(res.GetServerVersion())
|
||||
sessionState.Version = sessions.Version(fmt.Sprint(res.GetServerVersion()))
|
||||
|
||||
_, err = state.directoryClient.RefreshUser(ctx, &directory.RefreshUserRequest{
|
||||
UserId: s.UserId,
|
||||
|
|
|
@ -23,7 +23,6 @@ import (
|
|||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
"google.golang.org/protobuf/types/known/emptypb"
|
||||
"gopkg.in/square/go-jose.v2/jwt"
|
||||
|
||||
"github.com/pomerium/pomerium/config"
|
||||
|
@ -166,7 +165,7 @@ func TestAuthenticate_SignIn(t *testing.T) {
|
|||
|
||||
return &databroker.GetResponse{
|
||||
Record: &databroker.Record{
|
||||
Version: "0001",
|
||||
Version: 1,
|
||||
Type: data.GetTypeUrl(),
|
||||
Id: "SESSION_ID",
|
||||
Data: data,
|
||||
|
@ -246,9 +245,6 @@ func TestAuthenticate_SignOut(t *testing.T) {
|
|||
encryptedEncoder: mock.Encoder{},
|
||||
sharedEncoder: mock.Encoder{},
|
||||
dataBrokerClient: mockDataBrokerServiceClient{
|
||||
delete: func(ctx context.Context, in *databroker.DeleteRequest, opts ...grpc.CallOption) (*emptypb.Empty, error) {
|
||||
return nil, nil
|
||||
},
|
||||
get: func(ctx context.Context, in *databroker.GetRequest, opts ...grpc.CallOption) (*databroker.GetResponse, error) {
|
||||
data, err := ptypes.MarshalAny(&session.Session{
|
||||
Id: "SESSION_ID",
|
||||
|
@ -259,13 +255,16 @@ func TestAuthenticate_SignOut(t *testing.T) {
|
|||
|
||||
return &databroker.GetResponse{
|
||||
Record: &databroker.Record{
|
||||
Version: "0001",
|
||||
Version: 1,
|
||||
Type: data.GetTypeUrl(),
|
||||
Id: "SESSION_ID",
|
||||
Data: data,
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
put: func(ctx context.Context, in *databroker.PutRequest, opts ...grpc.CallOption) (*databroker.PutResponse, error) {
|
||||
return nil, nil
|
||||
},
|
||||
},
|
||||
directoryClient: new(mockDirectoryServiceClient),
|
||||
}),
|
||||
|
@ -368,8 +367,8 @@ func TestAuthenticate_OAuthCallback(t *testing.T) {
|
|||
get: func(ctx context.Context, in *databroker.GetRequest, opts ...grpc.CallOption) (*databroker.GetResponse, error) {
|
||||
return nil, fmt.Errorf("not implemented")
|
||||
},
|
||||
set: func(ctx context.Context, in *databroker.SetRequest, opts ...grpc.CallOption) (*databroker.SetResponse, error) {
|
||||
return &databroker.SetResponse{Record: &databroker.Record{Data: in.Data}}, nil
|
||||
put: func(ctx context.Context, in *databroker.PutRequest, opts ...grpc.CallOption) (*databroker.PutResponse, error) {
|
||||
return nil, nil
|
||||
},
|
||||
},
|
||||
directoryClient: new(mockDirectoryServiceClient),
|
||||
|
@ -514,7 +513,7 @@ func TestAuthenticate_SessionValidatorMiddleware(t *testing.T) {
|
|||
|
||||
return &databroker.GetResponse{
|
||||
Record: &databroker.Record{
|
||||
Version: "0001",
|
||||
Version: 1,
|
||||
Type: data.GetTypeUrl(),
|
||||
Id: "SESSION_ID",
|
||||
Data: data,
|
||||
|
@ -633,7 +632,7 @@ func TestAuthenticate_userInfo(t *testing.T) {
|
|||
|
||||
return &databroker.GetResponse{
|
||||
Record: &databroker.Record{
|
||||
Version: "0001",
|
||||
Version: 1,
|
||||
Type: data.GetTypeUrl(),
|
||||
Id: "SESSION_ID",
|
||||
Data: data,
|
||||
|
@ -672,21 +671,16 @@ func TestAuthenticate_userInfo(t *testing.T) {
|
|||
type mockDataBrokerServiceClient struct {
|
||||
databroker.DataBrokerServiceClient
|
||||
|
||||
delete func(ctx context.Context, in *databroker.DeleteRequest, opts ...grpc.CallOption) (*emptypb.Empty, error)
|
||||
get func(ctx context.Context, in *databroker.GetRequest, opts ...grpc.CallOption) (*databroker.GetResponse, error)
|
||||
set func(ctx context.Context, in *databroker.SetRequest, opts ...grpc.CallOption) (*databroker.SetResponse, error)
|
||||
}
|
||||
|
||||
func (m mockDataBrokerServiceClient) Delete(ctx context.Context, in *databroker.DeleteRequest, opts ...grpc.CallOption) (*emptypb.Empty, error) {
|
||||
return m.delete(ctx, in, opts...)
|
||||
get func(ctx context.Context, in *databroker.GetRequest, opts ...grpc.CallOption) (*databroker.GetResponse, error)
|
||||
put func(ctx context.Context, in *databroker.PutRequest, opts ...grpc.CallOption) (*databroker.PutResponse, error)
|
||||
}
|
||||
|
||||
func (m mockDataBrokerServiceClient) Get(ctx context.Context, in *databroker.GetRequest, opts ...grpc.CallOption) (*databroker.GetResponse, error) {
|
||||
return m.get(ctx, in, opts...)
|
||||
}
|
||||
|
||||
func (m mockDataBrokerServiceClient) Set(ctx context.Context, in *databroker.SetRequest, opts ...grpc.CallOption) (*databroker.SetResponse, error) {
|
||||
return m.set(ctx, in, opts...)
|
||||
func (m mockDataBrokerServiceClient) Put(ctx context.Context, in *databroker.PutRequest, opts ...grpc.CallOption) (*databroker.PutResponse, error) {
|
||||
return m.put(ctx, in, opts...)
|
||||
}
|
||||
|
||||
type mockDirectoryServiceClient struct {
|
||||
|
@ -729,7 +723,7 @@ func TestAuthenticate_SignOut_CSRF(t *testing.T) {
|
|||
|
||||
return &databroker.GetResponse{
|
||||
Record: &databroker.Record{
|
||||
Version: "0001",
|
||||
Version: 1,
|
||||
Type: data.GetTypeUrl(),
|
||||
Id: "SESSION_ID",
|
||||
Data: data,
|
||||
|
|
|
@ -24,19 +24,16 @@ type Authorize struct {
|
|||
currentOptions *config.AtomicOptions
|
||||
templates *template.Template
|
||||
|
||||
dataBrokerInitialSync map[string]chan struct{}
|
||||
dataBrokerInitialSync chan struct{}
|
||||
}
|
||||
|
||||
// New validates and creates a new Authorize service from a set of config options.
|
||||
func New(cfg *config.Config) (*Authorize, error) {
|
||||
a := Authorize{
|
||||
currentOptions: config.NewAtomicOptions(),
|
||||
store: evaluator.NewStore(),
|
||||
templates: template.Must(frontend.NewTemplates()),
|
||||
dataBrokerInitialSync: map[string]chan struct{}{
|
||||
"type.googleapis.com/directory.Group": make(chan struct{}, 1),
|
||||
"type.googleapis.com/directory.User": make(chan struct{}, 1),
|
||||
},
|
||||
currentOptions: config.NewAtomicOptions(),
|
||||
store: evaluator.NewStore(),
|
||||
templates: template.Must(frontend.NewTemplates()),
|
||||
dataBrokerInitialSync: make(chan struct{}),
|
||||
}
|
||||
|
||||
state, err := newAuthorizeStateFromConfig(cfg, a.store)
|
||||
|
@ -48,6 +45,22 @@ func New(cfg *config.Config) (*Authorize, error) {
|
|||
return &a, nil
|
||||
}
|
||||
|
||||
// Run runs the authorize service.
|
||||
func (a *Authorize) Run(ctx context.Context) error {
|
||||
return newDataBrokerSyncer(a).Run(ctx)
|
||||
}
|
||||
|
||||
// WaitForInitialSync blocks until the initial sync is complete.
|
||||
func (a *Authorize) WaitForInitialSync(ctx context.Context) error {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-a.dataBrokerInitialSync:
|
||||
}
|
||||
log.Info().Msg("initial sync from databroker complete")
|
||||
return nil
|
||||
}
|
||||
|
||||
func validateOptions(o *config.Options) error {
|
||||
if _, err := cryptutil.NewAEADCipherFromBase64(o.SharedKey); err != nil {
|
||||
return fmt.Errorf("authorize: bad 'SHARED_SECRET': %w", err)
|
||||
|
|
|
@ -115,7 +115,7 @@ func TestEvaluator_Evaluate(t *testing.T) {
|
|||
},
|
||||
})
|
||||
store.UpdateRecord(&databroker.Record{
|
||||
Version: "1",
|
||||
Version: 1,
|
||||
Type: "type.googleapis.com/session.Session",
|
||||
Id: sessionID,
|
||||
Data: data,
|
||||
|
@ -126,7 +126,7 @@ func TestEvaluator_Evaluate(t *testing.T) {
|
|||
Email: "foo@example.com",
|
||||
})
|
||||
store.UpdateRecord(&databroker.Record{
|
||||
Version: "1",
|
||||
Version: 1,
|
||||
Type: "type.googleapis.com/user.User",
|
||||
Id: userID,
|
||||
Data: data,
|
||||
|
@ -188,7 +188,7 @@ func BenchmarkEvaluator_Evaluate(b *testing.B) {
|
|||
},
|
||||
})
|
||||
store.UpdateRecord(&databroker.Record{
|
||||
Version: fmt.Sprint(i),
|
||||
Version: uint64(i),
|
||||
Type: "type.googleapis.com/session.Session",
|
||||
Id: sessionID,
|
||||
Data: data,
|
||||
|
@ -198,7 +198,7 @@ func BenchmarkEvaluator_Evaluate(b *testing.B) {
|
|||
Id: userID,
|
||||
})
|
||||
store.UpdateRecord(&databroker.Record{
|
||||
Version: fmt.Sprint(i),
|
||||
Version: uint64(i),
|
||||
Type: "type.googleapis.com/user.User",
|
||||
Id: userID,
|
||||
Data: data,
|
||||
|
|
|
@ -15,6 +15,7 @@ import (
|
|||
|
||||
"github.com/pomerium/pomerium/config"
|
||||
"github.com/pomerium/pomerium/internal/log"
|
||||
"github.com/pomerium/pomerium/pkg/cryptutil"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||
)
|
||||
|
||||
|
@ -40,9 +41,8 @@ func NewStoreFromProtos(msgs ...proto.Message) *Store {
|
|||
}
|
||||
|
||||
record := new(databroker.Record)
|
||||
record.CreatedAt = timestamppb.Now()
|
||||
record.ModifiedAt = timestamppb.Now()
|
||||
record.Version = uuid.New().String()
|
||||
record.Version = cryptutil.NewRandomUInt64()
|
||||
record.Id = uuid.New().String()
|
||||
record.Data = any
|
||||
record.Type = any.TypeUrl
|
||||
|
@ -56,8 +56,8 @@ func NewStoreFromProtos(msgs ...proto.Message) *Store {
|
|||
}
|
||||
|
||||
// ClearRecords removes all the records from the store.
|
||||
func (s *Store) ClearRecords(typeURL string) {
|
||||
rawPath := fmt.Sprintf("/databroker_data/%s", typeURL)
|
||||
func (s *Store) ClearRecords() {
|
||||
rawPath := "/databroker_data"
|
||||
s.delete(rawPath)
|
||||
}
|
||||
|
||||
|
|
|
@ -27,7 +27,7 @@ func TestStore(t *testing.T) {
|
|||
}
|
||||
any, _ := ptypes.MarshalAny(u)
|
||||
s.UpdateRecord(&databroker.Record{
|
||||
Version: "v1",
|
||||
Version: 1,
|
||||
Type: any.GetTypeUrl(),
|
||||
Id: u.GetId(),
|
||||
Data: any,
|
||||
|
@ -43,7 +43,7 @@ func TestStore(t *testing.T) {
|
|||
}, v)
|
||||
|
||||
s.UpdateRecord(&databroker.Record{
|
||||
Version: "v2",
|
||||
Version: 2,
|
||||
Type: any.GetTypeUrl(),
|
||||
Id: u.GetId(),
|
||||
Data: any,
|
||||
|
@ -55,7 +55,7 @@ func TestStore(t *testing.T) {
|
|||
assert.Nil(t, v)
|
||||
|
||||
s.UpdateRecord(&databroker.Record{
|
||||
Version: "v1",
|
||||
Version: 3,
|
||||
Type: any.GetTypeUrl(),
|
||||
Id: u.GetId(),
|
||||
Data: any,
|
||||
|
@ -65,7 +65,7 @@ func TestStore(t *testing.T) {
|
|||
assert.NoError(t, err)
|
||||
assert.NotNil(t, v)
|
||||
|
||||
s.ClearRecords("type.googleapis.com/user.User")
|
||||
s.ClearRecords()
|
||||
v, err = storage.ReadOne(ctx, s.opaStore, storage.MustParsePath("/databroker_data/type.googleapis.com/user.User/u1"))
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, v)
|
||||
|
|
|
@ -22,16 +22,11 @@ import (
|
|||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/session"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/user"
|
||||
"github.com/pomerium/pomerium/pkg/grpcutil"
|
||||
|
||||
envoy_service_auth_v2 "github.com/envoyproxy/go-control-plane/envoy/service/auth/v2"
|
||||
)
|
||||
|
||||
const (
|
||||
serviceAccountTypeURL = "type.googleapis.com/user.ServiceAccount"
|
||||
sessionTypeURL = "type.googleapis.com/session.Session"
|
||||
userTypeURL = "type.googleapis.com/user.User"
|
||||
)
|
||||
|
||||
// Check implements the envoy auth server gRPC endpoint.
|
||||
func (a *Authorize) Check(ctx context.Context, in *envoy_service_auth_v2.CheckRequest) (*envoy_service_auth_v2.CheckResponse, error) {
|
||||
ctx, span := trace.StartSpan(ctx, "authorize.grpc.Check")
|
||||
|
@ -108,18 +103,18 @@ func (a *Authorize) forceSyncSession(ctx context.Context, sessionID string) inte
|
|||
|
||||
state := a.state.Load()
|
||||
|
||||
s, ok := a.store.GetRecordData(sessionTypeURL, sessionID).(*session.Session)
|
||||
s, ok := a.store.GetRecordData(grpcutil.GetTypeURL(new(session.Session)), sessionID).(*session.Session)
|
||||
if ok {
|
||||
return s
|
||||
}
|
||||
|
||||
sa, ok := a.store.GetRecordData(serviceAccountTypeURL, sessionID).(*user.ServiceAccount)
|
||||
sa, ok := a.store.GetRecordData(grpcutil.GetTypeURL(new(user.ServiceAccount)), sessionID).(*user.ServiceAccount)
|
||||
if ok {
|
||||
return sa
|
||||
}
|
||||
|
||||
res, err := state.dataBrokerClient.Get(ctx, &databroker.GetRequest{
|
||||
Type: sessionTypeURL,
|
||||
Type: grpcutil.GetTypeURL(new(session.Session)),
|
||||
Id: sessionID,
|
||||
})
|
||||
if err != nil {
|
||||
|
@ -127,10 +122,10 @@ func (a *Authorize) forceSyncSession(ctx context.Context, sessionID string) inte
|
|||
return nil
|
||||
}
|
||||
|
||||
if current := a.store.GetRecordData(sessionTypeURL, sessionID); current == nil {
|
||||
if current := a.store.GetRecordData(grpcutil.GetTypeURL(new(session.Session)), sessionID); current == nil {
|
||||
a.store.UpdateRecord(res.GetRecord())
|
||||
}
|
||||
s, _ = a.store.GetRecordData(sessionTypeURL, sessionID).(*session.Session)
|
||||
s, _ = a.store.GetRecordData(grpcutil.GetTypeURL(new(session.Session)), sessionID).(*session.Session)
|
||||
|
||||
return s
|
||||
}
|
||||
|
@ -141,13 +136,13 @@ func (a *Authorize) forceSyncUser(ctx context.Context, userID string) *user.User
|
|||
|
||||
state := a.state.Load()
|
||||
|
||||
u, ok := a.store.GetRecordData(userTypeURL, userID).(*user.User)
|
||||
u, ok := a.store.GetRecordData(grpcutil.GetTypeURL(new(user.User)), userID).(*user.User)
|
||||
if ok {
|
||||
return u
|
||||
}
|
||||
|
||||
res, err := state.dataBrokerClient.Get(ctx, &databroker.GetRequest{
|
||||
Type: userTypeURL,
|
||||
Type: grpcutil.GetTypeURL(new(user.User)),
|
||||
Id: userID,
|
||||
})
|
||||
if err != nil {
|
||||
|
@ -155,10 +150,10 @@ func (a *Authorize) forceSyncUser(ctx context.Context, userID string) *user.User
|
|||
return nil
|
||||
}
|
||||
|
||||
if current := a.store.GetRecordData(userTypeURL, userID); current == nil {
|
||||
if current := a.store.GetRecordData(grpcutil.GetTypeURL(new(user.User)), userID); current == nil {
|
||||
a.store.UpdateRecord(res.GetRecord())
|
||||
}
|
||||
u, _ = a.store.GetRecordData(userTypeURL, userID).(*user.User)
|
||||
u, _ = a.store.GetRecordData(grpcutil.GetTypeURL(new(user.User)), userID).(*user.User)
|
||||
|
||||
return u
|
||||
}
|
||||
|
|
|
@ -325,7 +325,7 @@ func TestSync(t *testing.T) {
|
|||
})
|
||||
return &databroker.GetResponse{
|
||||
Record: &databroker.Record{
|
||||
Version: "0001",
|
||||
Version: 1,
|
||||
Type: data.GetTypeUrl(),
|
||||
Id: in.GetId(),
|
||||
Data: data,
|
||||
|
@ -336,7 +336,7 @@ func TestSync(t *testing.T) {
|
|||
data, _ := ptypes.MarshalAny(&user.User{Id: in.GetId()})
|
||||
return &databroker.GetResponse{
|
||||
Record: &databroker.Record{
|
||||
Version: "0001",
|
||||
Version: 1,
|
||||
Type: data.GetTypeUrl(),
|
||||
Id: in.GetId(),
|
||||
Data: data,
|
||||
|
@ -391,7 +391,7 @@ func TestSync(t *testing.T) {
|
|||
}
|
||||
return &databroker.GetResponse{
|
||||
Record: &databroker.Record{
|
||||
Version: "0001",
|
||||
Version: 1,
|
||||
Type: data.GetTypeUrl(),
|
||||
Id: in.GetId(),
|
||||
Data: data,
|
||||
|
@ -418,7 +418,7 @@ func TestSync(t *testing.T) {
|
|||
})
|
||||
return &databroker.GetResponse{
|
||||
Record: &databroker.Record{
|
||||
Version: "0001",
|
||||
Version: 1,
|
||||
Type: data.GetTypeUrl(),
|
||||
Id: in.GetId(),
|
||||
Data: data,
|
||||
|
|
198
authorize/run.go
198
authorize/run.go
|
@ -1,198 +0,0 @@
|
|||
package authorize
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"time"
|
||||
|
||||
backoff "github.com/cenkalti/backoff/v4"
|
||||
"golang.org/x/sync/errgroup"
|
||||
"google.golang.org/protobuf/types/known/emptypb"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/log"
|
||||
"github.com/pomerium/pomerium/internal/telemetry/trace"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||
)
|
||||
|
||||
// Run runs the authorize server.
|
||||
func (a *Authorize) Run(ctx context.Context) error {
|
||||
eg, ctx := errgroup.WithContext(ctx)
|
||||
|
||||
updateTypes := make(chan []string)
|
||||
eg.Go(func() error {
|
||||
return a.runTypesSyncer(ctx, updateTypes)
|
||||
})
|
||||
|
||||
eg.Go(func() error {
|
||||
return a.runDataSyncer(ctx, updateTypes)
|
||||
})
|
||||
|
||||
return eg.Wait()
|
||||
}
|
||||
|
||||
// WaitForInitialSync waits for the initial sync to complete.
|
||||
func (a *Authorize) WaitForInitialSync(ctx context.Context) error {
|
||||
for typeURL, ch := range a.dataBrokerInitialSync {
|
||||
log.Info().Str("type_url", typeURL).Msg("waiting for initial sync")
|
||||
select {
|
||||
case <-ch:
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
log.Info().Str("type_url", typeURL).Msg("initial sync complete")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *Authorize) runTypesSyncer(ctx context.Context, updateTypes chan<- []string) error {
|
||||
log.Info().Msg("starting type sync")
|
||||
return tryForever(ctx, func(backoff interface{ Reset() }) error {
|
||||
ctx, span := trace.StartSpan(ctx, "authorize.dataBrokerClient.Sync")
|
||||
defer span.End()
|
||||
stream, err := a.state.Load().dataBrokerClient.SyncTypes(ctx, new(emptypb.Empty))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for {
|
||||
res, err := stream.Recv()
|
||||
if err == io.EOF {
|
||||
return nil
|
||||
} else if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
backoff.Reset()
|
||||
|
||||
select {
|
||||
case <-stream.Context().Done():
|
||||
return stream.Context().Err()
|
||||
case updateTypes <- res.GetTypes():
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func (a *Authorize) runDataSyncer(ctx context.Context, updateTypes <-chan []string) error {
|
||||
eg, ctx := errgroup.WithContext(ctx)
|
||||
eg.Go(func() error {
|
||||
seen := map[string]struct{}{}
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case types := <-updateTypes:
|
||||
for _, dataType := range types {
|
||||
dataType := dataType
|
||||
if _, ok := seen[dataType]; !ok {
|
||||
eg.Go(func() error {
|
||||
return a.runDataTypeSyncer(ctx, dataType)
|
||||
})
|
||||
seen[dataType] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
return eg.Wait()
|
||||
}
|
||||
|
||||
func (a *Authorize) runDataTypeSyncer(ctx context.Context, typeURL string) error {
|
||||
var serverVersion, recordVersion string
|
||||
|
||||
log.Info().Str("type_url", typeURL).Msg("starting data initial load")
|
||||
ctx, span := trace.StartSpan(ctx, "authorize.dataBrokerClient.InitialSync")
|
||||
backoff := backoff.NewExponentialBackOff()
|
||||
backoff.MaxElapsedTime = 0
|
||||
for {
|
||||
res, err := databroker.InitialSync(ctx, a.state.Load().dataBrokerClient, &databroker.SyncRequest{
|
||||
Type: typeURL,
|
||||
})
|
||||
if err != nil {
|
||||
log.Warn().Err(err).Str("type_url", typeURL).Msg("error getting data")
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-time.After(backoff.NextBackOff()):
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
serverVersion = res.GetServerVersion()
|
||||
|
||||
for _, record := range res.GetRecords() {
|
||||
a.store.UpdateRecord(record)
|
||||
recordVersion = record.GetVersion()
|
||||
}
|
||||
|
||||
break
|
||||
}
|
||||
span.End()
|
||||
|
||||
if ch, ok := a.dataBrokerInitialSync[typeURL]; ok {
|
||||
select {
|
||||
case ch <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
log.Info().Str("type_url", typeURL).Msg("starting data syncer")
|
||||
return tryForever(ctx, func(backoff interface{ Reset() }) error {
|
||||
ctx, span := trace.StartSpan(ctx, "authorize.dataBrokerClient.Sync")
|
||||
defer span.End()
|
||||
stream, err := a.state.Load().dataBrokerClient.Sync(ctx, &databroker.SyncRequest{
|
||||
ServerVersion: serverVersion,
|
||||
RecordVersion: recordVersion,
|
||||
Type: typeURL,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for {
|
||||
res, err := stream.Recv()
|
||||
if err == io.EOF {
|
||||
return nil
|
||||
} else if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
backoff.Reset()
|
||||
if res.GetServerVersion() != serverVersion {
|
||||
log.Info().
|
||||
Str("old_version", serverVersion).
|
||||
Str("new_version", res.GetServerVersion()).
|
||||
Str("type_url", typeURL).
|
||||
Msg("detected new server version, clearing data")
|
||||
serverVersion = res.GetServerVersion()
|
||||
recordVersion = ""
|
||||
a.store.ClearRecords(typeURL)
|
||||
}
|
||||
for _, record := range res.GetRecords() {
|
||||
if record.GetVersion() > recordVersion {
|
||||
recordVersion = record.GetVersion()
|
||||
}
|
||||
}
|
||||
|
||||
for _, record := range res.GetRecords() {
|
||||
a.store.UpdateRecord(record)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func tryForever(ctx context.Context, callback func(onSuccess interface{ Reset() }) error) error {
|
||||
backoff := backoff.NewExponentialBackOff()
|
||||
backoff.MaxElapsedTime = 0
|
||||
for {
|
||||
err := callback(backoff)
|
||||
if err != nil {
|
||||
log.Warn().Err(err).Msg("sync error")
|
||||
}
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-time.After(backoff.NextBackOff()):
|
||||
}
|
||||
}
|
||||
}
|
41
authorize/sync.go
Normal file
41
authorize/sync.go
Normal file
|
@ -0,0 +1,41 @@
|
|||
package authorize
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
|
||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||
)
|
||||
|
||||
type dataBrokerSyncer struct {
|
||||
*databroker.Syncer
|
||||
authorize *Authorize
|
||||
signalOnce sync.Once
|
||||
}
|
||||
|
||||
func newDataBrokerSyncer(authorize *Authorize) *dataBrokerSyncer {
|
||||
syncer := &dataBrokerSyncer{
|
||||
authorize: authorize,
|
||||
}
|
||||
syncer.Syncer = databroker.NewSyncer(syncer)
|
||||
return syncer
|
||||
}
|
||||
|
||||
func (syncer *dataBrokerSyncer) GetDataBrokerServiceClient() databroker.DataBrokerServiceClient {
|
||||
return syncer.authorize.state.Load().dataBrokerClient
|
||||
}
|
||||
|
||||
func (syncer *dataBrokerSyncer) ClearRecords(ctx context.Context) {
|
||||
syncer.authorize.store.ClearRecords()
|
||||
}
|
||||
|
||||
func (syncer *dataBrokerSyncer) UpdateRecords(ctx context.Context, records []*databroker.Record) {
|
||||
for _, record := range records {
|
||||
syncer.authorize.store.UpdateRecord(record)
|
||||
}
|
||||
|
||||
// the first time we update records we signal the initial sync
|
||||
syncer.signalOnce.Do(func() {
|
||||
close(syncer.authorize.dataBrokerInitialSync)
|
||||
})
|
||||
}
|
|
@ -43,7 +43,7 @@ func loadCachedCredential(serverURL string) *ExecCredential {
|
|||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
defer f.Close()
|
||||
defer func() { _ = f.Close() }()
|
||||
|
||||
var creds ExecCredential
|
||||
err = json.NewDecoder(f).Decode(&creds)
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
// Package main implements the pomerium-cli.
|
||||
package main
|
||||
|
||||
import (
|
||||
|
|
|
@ -109,9 +109,9 @@ func decodeJWTClaimHeadersHookFunc() mapstructure.DecodeHookFunc {
|
|||
// A StringSlice is a slice of strings.
|
||||
type StringSlice []string
|
||||
|
||||
// NewStringSlice creatse a new StringSlice.
|
||||
// NewStringSlice creates a new StringSlice.
|
||||
func NewStringSlice(values ...string) StringSlice {
|
||||
return StringSlice(values)
|
||||
return values
|
||||
}
|
||||
|
||||
const (
|
||||
|
@ -197,7 +197,7 @@ type WeightedURL struct {
|
|||
LbWeight uint32
|
||||
}
|
||||
|
||||
// Validate validates the WeightedURL.
|
||||
// Validate validates that the WeightedURL is valid.
|
||||
func (u *WeightedURL) Validate() error {
|
||||
if u.URL.Hostname() == "" {
|
||||
return errHostnameMustBeSpecified
|
||||
|
@ -227,6 +227,7 @@ func ParseWeightedURL(dst string) (*WeightedURL, error) {
|
|||
return &WeightedURL{*u, w}, nil
|
||||
}
|
||||
|
||||
// String returns the WeightedURL as a string.
|
||||
func (u *WeightedURL) String() string {
|
||||
str := u.URL.String()
|
||||
if u.LbWeight == 0 {
|
||||
|
@ -235,7 +236,7 @@ func (u *WeightedURL) String() string {
|
|||
return fmt.Sprintf("{url=%s, weight=%d}", str, u.LbWeight)
|
||||
}
|
||||
|
||||
// WeightedURLs is a slice of WeightedURL.
|
||||
// WeightedURLs is a slice of WeightedURLs.
|
||||
type WeightedURLs []WeightedURL
|
||||
|
||||
// ParseWeightedUrls parses
|
||||
|
@ -285,9 +286,9 @@ func (urls WeightedURLs) Validate() (HasWeight, error) {
|
|||
}
|
||||
|
||||
if noWeight {
|
||||
return HasWeight(false), nil
|
||||
return false, nil
|
||||
}
|
||||
return HasWeight(true), nil
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// Flatten converts weighted url array into indidual arrays of urls and weights
|
||||
|
@ -311,7 +312,7 @@ func (urls WeightedURLs) Flatten() ([]string, []uint32, error) {
|
|||
return str, wghts, nil
|
||||
}
|
||||
|
||||
// DecodePolicyBase64Hook creates a mapstructure DecodeHookFunc.
|
||||
// DecodePolicyBase64Hook returns a mapstructure decode hook for base64 data.
|
||||
func DecodePolicyBase64Hook() mapstructure.DecodeHookFunc {
|
||||
return func(f, t reflect.Type, data interface{}) (interface{}, error) {
|
||||
if t != reflect.TypeOf([]Policy{}) {
|
||||
|
@ -332,7 +333,7 @@ func DecodePolicyBase64Hook() mapstructure.DecodeHookFunc {
|
|||
return nil, fmt.Errorf("base64 decoding policy data: %w", err)
|
||||
}
|
||||
|
||||
out := []map[interface{}]interface{}{}
|
||||
var out []map[interface{}]interface{}
|
||||
if err = yaml.Unmarshal(bytes, &out); err != nil {
|
||||
return nil, fmt.Errorf("parsing base64-encoded policy data as yaml: %w", err)
|
||||
}
|
||||
|
@ -341,7 +342,7 @@ func DecodePolicyBase64Hook() mapstructure.DecodeHookFunc {
|
|||
}
|
||||
}
|
||||
|
||||
// DecodePolicyHookFunc creates a mapstructure DecodeHookFunc.
|
||||
// DecodePolicyHookFunc returns a Decode Hook for mapstructure.
|
||||
func DecodePolicyHookFunc() mapstructure.DecodeHookFunc {
|
||||
return func(f, t reflect.Type, data interface{}) (interface{}, error) {
|
||||
if t != reflect.TypeOf(Policy{}) {
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
// Package databroker contains the databroker service.
|
||||
package databroker
|
||||
|
||||
import (
|
||||
|
@ -5,8 +6,6 @@ import (
|
|||
"encoding/base64"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/golang/protobuf/ptypes/empty"
|
||||
|
||||
"github.com/pomerium/pomerium/config"
|
||||
"github.com/pomerium/pomerium/internal/databroker"
|
||||
databrokerpb "github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||
|
@ -52,13 +51,6 @@ func (srv *dataBrokerServer) setKey(cfg *config.Config) {
|
|||
srv.sharedKey.Store(bs)
|
||||
}
|
||||
|
||||
func (srv *dataBrokerServer) Delete(ctx context.Context, req *databrokerpb.DeleteRequest) (*empty.Empty, error) {
|
||||
if err := grpcutil.RequireSignedJWT(ctx, srv.sharedKey.Load().([]byte)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return srv.server.Delete(ctx, req)
|
||||
}
|
||||
|
||||
func (srv *dataBrokerServer) Get(ctx context.Context, req *databrokerpb.GetRequest) (*databrokerpb.GetResponse, error) {
|
||||
if err := grpcutil.RequireSignedJWT(ctx, srv.sharedKey.Load().([]byte)); err != nil {
|
||||
return nil, err
|
||||
|
@ -66,13 +58,6 @@ func (srv *dataBrokerServer) Get(ctx context.Context, req *databrokerpb.GetReque
|
|||
return srv.server.Get(ctx, req)
|
||||
}
|
||||
|
||||
func (srv *dataBrokerServer) GetAll(ctx context.Context, req *databrokerpb.GetAllRequest) (*databrokerpb.GetAllResponse, error) {
|
||||
if err := grpcutil.RequireSignedJWT(ctx, srv.sharedKey.Load().([]byte)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return srv.server.GetAll(ctx, req)
|
||||
}
|
||||
|
||||
func (srv *dataBrokerServer) Query(ctx context.Context, req *databrokerpb.QueryRequest) (*databrokerpb.QueryResponse, error) {
|
||||
if err := grpcutil.RequireSignedJWT(ctx, srv.sharedKey.Load().([]byte)); err != nil {
|
||||
return nil, err
|
||||
|
@ -80,11 +65,11 @@ func (srv *dataBrokerServer) Query(ctx context.Context, req *databrokerpb.QueryR
|
|||
return srv.server.Query(ctx, req)
|
||||
}
|
||||
|
||||
func (srv *dataBrokerServer) Set(ctx context.Context, req *databrokerpb.SetRequest) (*databrokerpb.SetResponse, error) {
|
||||
func (srv *dataBrokerServer) Put(ctx context.Context, req *databrokerpb.PutRequest) (*databrokerpb.PutResponse, error) {
|
||||
if err := grpcutil.RequireSignedJWT(ctx, srv.sharedKey.Load().([]byte)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return srv.server.Set(ctx, req)
|
||||
return srv.server.Put(ctx, req)
|
||||
}
|
||||
|
||||
func (srv *dataBrokerServer) Sync(req *databrokerpb.SyncRequest, stream databrokerpb.DataBrokerService_SyncServer) error {
|
||||
|
@ -94,16 +79,9 @@ func (srv *dataBrokerServer) Sync(req *databrokerpb.SyncRequest, stream databrok
|
|||
return srv.server.Sync(req, stream)
|
||||
}
|
||||
|
||||
func (srv *dataBrokerServer) GetTypes(ctx context.Context, req *empty.Empty) (*databrokerpb.GetTypesResponse, error) {
|
||||
if err := grpcutil.RequireSignedJWT(ctx, srv.sharedKey.Load().([]byte)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return srv.server.GetTypes(ctx, req)
|
||||
}
|
||||
|
||||
func (srv *dataBrokerServer) SyncTypes(req *empty.Empty, stream databrokerpb.DataBrokerService_SyncTypesServer) error {
|
||||
func (srv *dataBrokerServer) SyncLatest(req *databrokerpb.SyncLatestRequest, stream databrokerpb.DataBrokerService_SyncLatestServer) error {
|
||||
if err := grpcutil.RequireSignedJWT(stream.Context(), srv.sharedKey.Load().([]byte)); err != nil {
|
||||
return err
|
||||
}
|
||||
return srv.server.SyncTypes(req, stream)
|
||||
return srv.server.SyncLatest(req, stream)
|
||||
}
|
||||
|
|
|
@ -7,9 +7,10 @@ import (
|
|||
"testing"
|
||||
|
||||
"github.com/golang/protobuf/ptypes"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
"google.golang.org/grpc/test/bufconn"
|
||||
|
||||
internal_databroker "github.com/pomerium/pomerium/internal/databroker"
|
||||
|
@ -51,44 +52,44 @@ func TestServerSync(t *testing.T) {
|
|||
any, _ := ptypes.MarshalAny(new(user.User))
|
||||
numRecords := 200
|
||||
|
||||
var serverVersion uint64
|
||||
|
||||
for i := 0; i < numRecords; i++ {
|
||||
c.Set(ctx, &databroker.SetRequest{Type: any.TypeUrl, Id: strconv.Itoa(i), Data: any})
|
||||
res, err := c.Put(ctx, &databroker.PutRequest{
|
||||
Record: &databroker.Record{
|
||||
Type: any.TypeUrl,
|
||||
Id: strconv.Itoa(i),
|
||||
Data: any,
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
serverVersion = res.GetServerVersion()
|
||||
}
|
||||
|
||||
t.Run("Sync ok", func(t *testing.T) {
|
||||
client, _ := c.Sync(ctx, &databroker.SyncRequest{Type: any.GetTypeUrl()})
|
||||
client, _ := c.Sync(ctx, &databroker.SyncRequest{
|
||||
ServerVersion: serverVersion,
|
||||
})
|
||||
count := 0
|
||||
for {
|
||||
res, err := client.Recv()
|
||||
_, err := client.Recv()
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
count += len(res.Records)
|
||||
count++
|
||||
if count == numRecords {
|
||||
break
|
||||
}
|
||||
}
|
||||
})
|
||||
t.Run("Error occurred while syncing", func(t *testing.T) {
|
||||
ctx, cancelFunc := context.WithCancel(ctx)
|
||||
defer cancelFunc()
|
||||
|
||||
client, _ := c.Sync(ctx, &databroker.SyncRequest{Type: any.GetTypeUrl()})
|
||||
count := 0
|
||||
numRecordsWanted := 100
|
||||
cancelFuncCalled := false
|
||||
for {
|
||||
res, err := client.Recv()
|
||||
if err != nil {
|
||||
assert.True(t, cancelFuncCalled)
|
||||
break
|
||||
}
|
||||
count += len(res.Records)
|
||||
if count == numRecordsWanted {
|
||||
cancelFunc()
|
||||
cancelFuncCalled = true
|
||||
}
|
||||
}
|
||||
t.Run("Aborted", func(t *testing.T) {
|
||||
client, err := c.Sync(ctx, &databroker.SyncRequest{
|
||||
ServerVersion: 0,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
_, err = client.Recv()
|
||||
require.Error(t, err)
|
||||
require.Equal(t, codes.Aborted, status.Code(err))
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -104,19 +105,25 @@ func BenchmarkSync(b *testing.B) {
|
|||
numRecords := 10000
|
||||
|
||||
for i := 0; i < numRecords; i++ {
|
||||
c.Set(ctx, &databroker.SetRequest{Type: any.TypeUrl, Id: strconv.Itoa(i), Data: any})
|
||||
_, _ = c.Put(ctx, &databroker.PutRequest{
|
||||
Record: &databroker.Record{
|
||||
Type: any.TypeUrl,
|
||||
Id: strconv.Itoa(i),
|
||||
Data: any,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
client, _ := c.Sync(ctx, &databroker.SyncRequest{Type: any.GetTypeUrl()})
|
||||
client, _ := c.Sync(ctx, &databroker.SyncRequest{})
|
||||
count := 0
|
||||
for {
|
||||
res, err := client.Recv()
|
||||
_, err := client.Recv()
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
count += len(res.Records)
|
||||
count++
|
||||
if count == numRecords {
|
||||
break
|
||||
}
|
||||
|
|
|
@ -31,10 +31,12 @@ func (c *DataBroker) RefreshUser(ctx context.Context, req *directory.RefreshUser
|
|||
return nil, err
|
||||
}
|
||||
|
||||
_, err = c.dataBrokerServer.Set(ctx, &databroker.SetRequest{
|
||||
Type: any.GetTypeUrl(),
|
||||
Id: u.GetId(),
|
||||
Data: any,
|
||||
_, err = c.dataBrokerServer.Put(ctx, &databroker.PutRequest{
|
||||
Record: &databroker.Record{
|
||||
Type: any.GetTypeUrl(),
|
||||
Id: u.GetId(),
|
||||
Data: any,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
|
|
@ -61,6 +61,7 @@ func main() {
|
|||
}
|
||||
}
|
||||
|
||||
// Result is a result used for rendering.
|
||||
type Result struct {
|
||||
Headers map[string]string `json:"headers"`
|
||||
Method string `json:"method"`
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
// Package filemgr defines a Manager for managing files for the controlplane.
|
||||
package filemgr
|
||||
|
||||
import (
|
||||
|
|
|
@ -13,8 +13,6 @@ var (
|
|||
// DefaultDeletePermanentlyAfter is the default amount of time to wait before deleting
|
||||
// a record permanently.
|
||||
DefaultDeletePermanentlyAfter = time.Hour
|
||||
// DefaultBTreeDegree is the default number of items to store in each node of the BTree.
|
||||
DefaultBTreeDegree = 8
|
||||
// DefaultStorageType is the default storage type that Server use
|
||||
DefaultStorageType = "memory"
|
||||
// DefaultGetAllPageSize is the default page size for GetAll calls.
|
||||
|
@ -23,7 +21,6 @@ var (
|
|||
|
||||
type serverConfig struct {
|
||||
deletePermanentlyAfter time.Duration
|
||||
btreeDegree int
|
||||
secret []byte
|
||||
storageType string
|
||||
storageConnectionString string
|
||||
|
@ -36,7 +33,6 @@ type serverConfig struct {
|
|||
func newServerConfig(options ...ServerOption) *serverConfig {
|
||||
cfg := new(serverConfig)
|
||||
WithDeletePermanentlyAfter(DefaultDeletePermanentlyAfter)(cfg)
|
||||
WithBTreeDegree(DefaultBTreeDegree)(cfg)
|
||||
WithStorageType(DefaultStorageType)(cfg)
|
||||
WithGetAllPageSize(DefaultGetAllPageSize)(cfg)
|
||||
for _, option := range options {
|
||||
|
@ -48,13 +44,6 @@ func newServerConfig(options ...ServerOption) *serverConfig {
|
|||
// A ServerOption customizes the server.
|
||||
type ServerOption func(*serverConfig)
|
||||
|
||||
// WithBTreeDegree sets the number of items to store in each node of the BTree.
|
||||
func WithBTreeDegree(degree int) ServerOption {
|
||||
return func(cfg *serverConfig) {
|
||||
cfg.btreeDegree = degree
|
||||
}
|
||||
}
|
||||
|
||||
// WithDeletePermanentlyAfter sets the deletePermanentlyAfter duration.
|
||||
// If a record is deleted via Delete, it will be permanently deleted after
|
||||
// the given duration.
|
||||
|
|
|
@ -3,12 +3,7 @@ package databroker
|
|||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/cenkalti/backoff/v4"
|
||||
"github.com/golang/protobuf/ptypes"
|
||||
|
||||
"github.com/pomerium/pomerium/config"
|
||||
"github.com/pomerium/pomerium/internal/hashutil"
|
||||
|
@ -17,15 +12,9 @@ import (
|
|||
"github.com/pomerium/pomerium/pkg/grpc"
|
||||
configpb "github.com/pomerium/pomerium/pkg/grpc/config"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||
"github.com/pomerium/pomerium/pkg/grpcutil"
|
||||
)
|
||||
|
||||
var configTypeURL string
|
||||
|
||||
func init() {
|
||||
any, _ := ptypes.MarshalAny(new(configpb.Config))
|
||||
configTypeURL = any.GetTypeUrl()
|
||||
}
|
||||
|
||||
// ConfigSource provides a new Config source that decorates an underlying config with
|
||||
// configuration derived from the data broker.
|
||||
type ConfigSource struct {
|
||||
|
@ -35,8 +24,6 @@ type ConfigSource struct {
|
|||
dbConfigs map[string]*configpb.Config
|
||||
updaterHash uint64
|
||||
cancel func()
|
||||
serverVersion string
|
||||
recordVersion string
|
||||
|
||||
config.ChangeDispatcher
|
||||
}
|
||||
|
@ -188,78 +175,51 @@ func (src *ConfigSource) runUpdater(cfg *config.Config) {
|
|||
ctx := context.Background()
|
||||
ctx, src.cancel = context.WithCancel(ctx)
|
||||
|
||||
go tryForever(ctx, func(onSuccess func()) error {
|
||||
src.mu.Lock()
|
||||
serverVersion, recordVersion := src.serverVersion, src.recordVersion
|
||||
src.mu.Unlock()
|
||||
|
||||
stream, err := client.Sync(ctx, &databroker.SyncRequest{
|
||||
Type: configTypeURL,
|
||||
ServerVersion: serverVersion,
|
||||
RecordVersion: recordVersion,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for {
|
||||
res, err := stream.Recv()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
onSuccess()
|
||||
|
||||
if len(res.GetRecords()) > 0 {
|
||||
src.onSync(res.GetRecords())
|
||||
for _, record := range res.GetRecords() {
|
||||
recordVersion = record.GetVersion()
|
||||
}
|
||||
}
|
||||
|
||||
src.mu.Lock()
|
||||
src.serverVersion, src.recordVersion = res.GetServerVersion(), recordVersion
|
||||
src.mu.Unlock()
|
||||
}
|
||||
})
|
||||
syncer := databroker.NewSyncer(&syncerHandler{
|
||||
client: client,
|
||||
src: src,
|
||||
}, databroker.WithTypeURL(grpcutil.GetTypeURL(new(configpb.Config))))
|
||||
go func() { _ = syncer.Run(ctx) }()
|
||||
}
|
||||
|
||||
func (src *ConfigSource) onSync(records []*databroker.Record) {
|
||||
src.mu.Lock()
|
||||
type syncerHandler struct {
|
||||
src *ConfigSource
|
||||
client databroker.DataBrokerServiceClient
|
||||
}
|
||||
|
||||
func (s *syncerHandler) GetDataBrokerServiceClient() databroker.DataBrokerServiceClient {
|
||||
return s.client
|
||||
}
|
||||
|
||||
func (s *syncerHandler) ClearRecords(ctx context.Context) {
|
||||
s.src.mu.Lock()
|
||||
s.src.dbConfigs = map[string]*configpb.Config{}
|
||||
s.src.mu.Unlock()
|
||||
}
|
||||
|
||||
func (s *syncerHandler) UpdateRecords(ctx context.Context, records []*databroker.Record) {
|
||||
if len(records) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
s.src.mu.Lock()
|
||||
for _, record := range records {
|
||||
if record.GetDeletedAt() != nil {
|
||||
delete(src.dbConfigs, record.GetId())
|
||||
delete(s.src.dbConfigs, record.GetId())
|
||||
continue
|
||||
}
|
||||
|
||||
var cfgpb configpb.Config
|
||||
err := ptypes.UnmarshalAny(record.GetData(), &cfgpb)
|
||||
err := record.GetData().UnmarshalTo(&cfgpb)
|
||||
if err != nil {
|
||||
log.Warn().Err(err).Msg("databroker: error decoding config")
|
||||
delete(src.dbConfigs, record.GetId())
|
||||
delete(s.src.dbConfigs, record.GetId())
|
||||
continue
|
||||
}
|
||||
|
||||
src.dbConfigs[record.GetId()] = &cfgpb
|
||||
s.src.dbConfigs[record.GetId()] = &cfgpb
|
||||
}
|
||||
src.mu.Unlock()
|
||||
s.src.mu.Unlock()
|
||||
|
||||
src.rebuild(false)
|
||||
}
|
||||
|
||||
func tryForever(ctx context.Context, callback func(onSuccess func()) error) {
|
||||
bo := backoff.NewExponentialBackOff()
|
||||
bo.MaxElapsedTime = 0
|
||||
for {
|
||||
err := callback(bo.Reset)
|
||||
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
|
||||
return
|
||||
} else if err != nil {
|
||||
log.Warn().Err(err).Msg("sync error")
|
||||
}
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-time.After(bo.NextBackOff()):
|
||||
}
|
||||
}
|
||||
s.src.rebuild(false)
|
||||
}
|
||||
|
|
|
@ -7,9 +7,9 @@ import (
|
|||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/golang/protobuf/ptypes"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/protobuf/types/known/anypb"
|
||||
|
||||
"github.com/pomerium/pomerium/config"
|
||||
configpb "github.com/pomerium/pomerium/pkg/grpc/config"
|
||||
|
@ -24,7 +24,7 @@ func TestConfigSource(t *testing.T) {
|
|||
if !assert.NoError(t, err) {
|
||||
return
|
||||
}
|
||||
defer li.Close()
|
||||
defer func() { _ = li.Close() }()
|
||||
|
||||
dataBrokerServer := New()
|
||||
srv := grpc.NewServer()
|
||||
|
@ -45,7 +45,7 @@ func TestConfigSource(t *testing.T) {
|
|||
})
|
||||
cfgs <- src.GetConfig()
|
||||
|
||||
data, _ := ptypes.MarshalAny(&configpb.Config{
|
||||
data, _ := anypb.New(&configpb.Config{
|
||||
Name: "config",
|
||||
Routes: []*configpb.Route{
|
||||
{
|
||||
|
@ -54,10 +54,12 @@ func TestConfigSource(t *testing.T) {
|
|||
},
|
||||
},
|
||||
})
|
||||
_, _ = dataBrokerServer.Set(ctx, &databroker.SetRequest{
|
||||
Type: configTypeURL,
|
||||
Id: "1",
|
||||
Data: data,
|
||||
_, _ = dataBrokerServer.Put(ctx, &databroker.PutRequest{
|
||||
Record: &databroker.Record{
|
||||
Type: data.TypeUrl,
|
||||
Id: "1",
|
||||
Data: data,
|
||||
},
|
||||
})
|
||||
|
||||
select {
|
||||
|
|
|
@ -4,28 +4,24 @@ package databroker
|
|||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/golang/protobuf/ptypes"
|
||||
"github.com/golang/protobuf/ptypes/empty"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/google/uuid"
|
||||
"github.com/rs/zerolog"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
"google.golang.org/protobuf/types/known/emptypb"
|
||||
"google.golang.org/protobuf/types/known/anypb"
|
||||
"google.golang.org/protobuf/types/known/wrapperspb"
|
||||
|
||||
"github.com/pomerium/pomerium/config"
|
||||
"github.com/pomerium/pomerium/internal/log"
|
||||
"github.com/pomerium/pomerium/internal/signal"
|
||||
"github.com/pomerium/pomerium/internal/telemetry/trace"
|
||||
"github.com/pomerium/pomerium/pkg/cryptutil"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||
"github.com/pomerium/pomerium/pkg/grpcutil"
|
||||
"github.com/pomerium/pomerium/pkg/storage"
|
||||
"github.com/pomerium/pomerium/pkg/storage/inmemory"
|
||||
"github.com/pomerium/pomerium/pkg/storage/redis"
|
||||
|
@ -34,81 +30,57 @@ import (
|
|||
const (
|
||||
recordTypeServerVersion = "server_version"
|
||||
serverVersionKey = "version"
|
||||
syncBatchSize = 100
|
||||
)
|
||||
|
||||
// newUUID returns a new UUID. This make it easy to stub out in tests.
|
||||
var newUUID = uuid.New
|
||||
|
||||
// Server implements the databroker service using an in memory database.
|
||||
type Server struct {
|
||||
cfg *serverConfig
|
||||
log zerolog.Logger
|
||||
|
||||
mu sync.RWMutex
|
||||
version string
|
||||
byType map[string]storage.Backend
|
||||
onTypechange *signal.Signal
|
||||
mu sync.RWMutex
|
||||
version uint64
|
||||
backend storage.Backend
|
||||
}
|
||||
|
||||
// New creates a new server.
|
||||
func New(options ...ServerOption) *Server {
|
||||
srv := &Server{
|
||||
log: log.With().Str("service", "databroker").Logger(),
|
||||
|
||||
byType: make(map[string]storage.Backend),
|
||||
onTypechange: signal.New(),
|
||||
}
|
||||
srv.UpdateConfig(options...)
|
||||
|
||||
go func() {
|
||||
ticker := time.NewTicker(time.Minute)
|
||||
defer ticker.Stop()
|
||||
|
||||
for range ticker.C {
|
||||
srv.mu.RLock()
|
||||
tm := time.Now().Add(-srv.cfg.deletePermanentlyAfter)
|
||||
srv.mu.RUnlock()
|
||||
|
||||
var recordTypes []string
|
||||
srv.mu.RLock()
|
||||
for recordType := range srv.byType {
|
||||
recordTypes = append(recordTypes, recordType)
|
||||
}
|
||||
srv.mu.RUnlock()
|
||||
|
||||
for _, recordType := range recordTypes {
|
||||
db, _, err := srv.getDB(recordType, true)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
db.ClearDeleted(context.Background(), tm)
|
||||
}
|
||||
}
|
||||
}()
|
||||
return srv
|
||||
}
|
||||
|
||||
func (srv *Server) initVersion() {
|
||||
dbServerVersion, _, err := srv.getDB(recordTypeServerVersion, false)
|
||||
db, _, err := srv.getBackendLocked()
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("failed to init server version")
|
||||
return
|
||||
}
|
||||
|
||||
// Get version from storage first.
|
||||
if r, _ := dbServerVersion.Get(context.Background(), serverVersionKey); r != nil {
|
||||
var sv databroker.ServerVersion
|
||||
if err := ptypes.UnmarshalAny(r.GetData(), &sv); err == nil {
|
||||
srv.log.Debug().Str("server_version", sv.Version).Msg("got db version from DB")
|
||||
srv.version = sv.Version
|
||||
r, err := db.Get(context.Background(), recordTypeServerVersion, serverVersionKey)
|
||||
switch {
|
||||
case err == nil:
|
||||
var sv wrapperspb.UInt64Value
|
||||
if err := r.GetData().UnmarshalTo(&sv); err == nil {
|
||||
srv.log.Debug().Uint64("server_version", sv.Value).Msg("got db version from Backend")
|
||||
srv.version = sv.Value
|
||||
}
|
||||
return
|
||||
case errors.Is(err, storage.ErrNotFound): // no server version, so we'll create a new one
|
||||
case err != nil:
|
||||
log.Error().Err(err).Msg("failed to retrieve server version")
|
||||
return
|
||||
}
|
||||
|
||||
srv.version = newUUID().String()
|
||||
data, _ := ptypes.MarshalAny(&databroker.ServerVersion{Version: srv.version})
|
||||
if err := dbServerVersion.Put(context.Background(), serverVersionKey, data); err != nil {
|
||||
srv.version = cryptutil.NewRandomUInt64()
|
||||
data, _ := anypb.New(wrapperspb.UInt64(srv.version))
|
||||
if err := db.Put(context.Background(), &databroker.Record{
|
||||
Type: recordTypeServerVersion,
|
||||
Id: serverVersionKey,
|
||||
Data: data,
|
||||
}); err != nil {
|
||||
srv.log.Warn().Err(err).Msg("failed to save server version.")
|
||||
}
|
||||
}
|
||||
|
@ -125,121 +97,49 @@ func (srv *Server) UpdateConfig(options ...ServerOption) {
|
|||
}
|
||||
srv.cfg = cfg
|
||||
|
||||
for t, db := range srv.byType {
|
||||
err := db.Close()
|
||||
if srv.backend != nil {
|
||||
err := srv.backend.Close()
|
||||
if err != nil {
|
||||
log.Warn().Err(err).Msg("databroker: error closing backend")
|
||||
log.Error().Err(err).Msg("databroker: error closing backend")
|
||||
}
|
||||
delete(srv.byType, t)
|
||||
srv.backend = nil
|
||||
}
|
||||
|
||||
srv.initVersion()
|
||||
}
|
||||
|
||||
// Delete deletes a record from the in-memory list.
|
||||
func (srv *Server) Delete(ctx context.Context, req *databroker.DeleteRequest) (*empty.Empty, error) {
|
||||
_, span := trace.StartSpan(ctx, "databroker.grpc.Delete")
|
||||
defer span.End()
|
||||
srv.log.Info().
|
||||
Str("type", req.GetType()).
|
||||
Str("id", req.GetId()).
|
||||
Msg("delete")
|
||||
|
||||
db, _, err := srv.getDB(req.GetType(), true)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := db.Delete(ctx, req.GetId()); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return new(empty.Empty), nil
|
||||
}
|
||||
|
||||
// Get gets a record from the in-memory list.
|
||||
func (srv *Server) Get(ctx context.Context, req *databroker.GetRequest) (*databroker.GetResponse, error) {
|
||||
_, span := trace.StartSpan(ctx, "databroker.grpc.Get")
|
||||
defer span.End()
|
||||
srv.log.Info().
|
||||
Str("peer", grpcutil.GetPeerAddr(ctx)).
|
||||
Str("type", req.GetType()).
|
||||
Str("id", req.GetId()).
|
||||
Msg("get")
|
||||
|
||||
db, _, err := srv.getDB(req.GetType(), true)
|
||||
db, _, err := srv.getBackend()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
record, err := db.Get(ctx, req.GetId())
|
||||
if err != nil {
|
||||
record, err := db.Get(ctx, req.GetType(), req.GetId())
|
||||
switch {
|
||||
case errors.Is(err, storage.ErrNotFound):
|
||||
return nil, status.Error(codes.NotFound, "record not found")
|
||||
}
|
||||
if record.DeletedAt != nil {
|
||||
case err != nil:
|
||||
return nil, status.Error(codes.Internal, err.Error())
|
||||
case record.DeletedAt != nil:
|
||||
return nil, status.Error(codes.NotFound, "record not found")
|
||||
}
|
||||
return &databroker.GetResponse{Record: record}, nil
|
||||
}
|
||||
|
||||
// GetAll gets all the records from the backend.
|
||||
func (srv *Server) GetAll(ctx context.Context, req *databroker.GetAllRequest) (*databroker.GetAllResponse, error) {
|
||||
_, span := trace.StartSpan(ctx, "databroker.grpc.GetAll")
|
||||
defer span.End()
|
||||
srv.log.Info().
|
||||
Str("type", req.GetType()).
|
||||
Msg("get all")
|
||||
|
||||
db, version, err := srv.getDB(req.GetType(), true)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
all, err := db.List(ctx, "")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// sort by record version
|
||||
sort.Slice(all, func(i, j int) bool {
|
||||
return all[i].Version < all[j].Version
|
||||
})
|
||||
|
||||
var recordVersion string
|
||||
records := make([]*databroker.Record, 0, len(all))
|
||||
for _, record := range all {
|
||||
// skip previous page records
|
||||
if record.GetVersion() <= req.PageToken {
|
||||
continue
|
||||
}
|
||||
|
||||
recordVersion = record.GetVersion()
|
||||
if record.DeletedAt == nil {
|
||||
records = append(records, record)
|
||||
}
|
||||
|
||||
// stop when we've hit the page size
|
||||
if len(records) >= srv.cfg.getAllPageSize {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
nextPageToken := recordVersion
|
||||
if len(records) < srv.cfg.getAllPageSize {
|
||||
nextPageToken = ""
|
||||
}
|
||||
|
||||
return &databroker.GetAllResponse{
|
||||
ServerVersion: version,
|
||||
RecordVersion: recordVersion,
|
||||
Records: records,
|
||||
NextPageToken: nextPageToken,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Query queries for records.
|
||||
func (srv *Server) Query(ctx context.Context, req *databroker.QueryRequest) (*databroker.QueryResponse, error) {
|
||||
_, span := trace.StartSpan(ctx, "databroker.grpc.Query")
|
||||
defer span.End()
|
||||
srv.log.Info().
|
||||
Str("peer", grpcutil.GetPeerAddr(ctx)).
|
||||
Str("type", req.GetType()).
|
||||
Str("query", req.GetQuery()).
|
||||
Int64("offset", req.GetOffset()).
|
||||
|
@ -248,21 +148,25 @@ func (srv *Server) Query(ctx context.Context, req *databroker.QueryRequest) (*da
|
|||
|
||||
query := strings.ToLower(req.GetQuery())
|
||||
|
||||
db, _, err := srv.getDB(req.GetType(), true)
|
||||
db, _, err := srv.getBackend()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
all, err := db.List(ctx, "")
|
||||
all, _, err := db.GetAll(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var filtered []*databroker.Record
|
||||
for _, record := range all {
|
||||
if record.DeletedAt == nil && storage.MatchAny(record.GetData(), query) {
|
||||
filtered = append(filtered, record)
|
||||
if record.GetType() != req.GetType() {
|
||||
continue
|
||||
}
|
||||
if query != "" && !storage.MatchAny(record.GetData(), query) {
|
||||
continue
|
||||
}
|
||||
filtered = append(filtered, record)
|
||||
}
|
||||
|
||||
records, totalCount := databroker.ApplyOffsetAndLimit(filtered, int(req.GetOffset()), int(req.GetLimit()))
|
||||
|
@ -272,197 +176,160 @@ func (srv *Server) Query(ctx context.Context, req *databroker.QueryRequest) (*da
|
|||
}, nil
|
||||
}
|
||||
|
||||
// Set updates a record in the in-memory list, or adds a new one.
|
||||
func (srv *Server) Set(ctx context.Context, req *databroker.SetRequest) (*databroker.SetResponse, error) {
|
||||
_, span := trace.StartSpan(ctx, "databroker.grpc.Set")
|
||||
// Put updates a record in the in-memory list, or adds a new one.
|
||||
func (srv *Server) Put(ctx context.Context, req *databroker.PutRequest) (*databroker.PutResponse, error) {
|
||||
_, span := trace.StartSpan(ctx, "databroker.grpc.Put")
|
||||
defer span.End()
|
||||
record := req.GetRecord()
|
||||
|
||||
srv.log.Info().
|
||||
Str("type", req.GetType()).
|
||||
Str("id", req.GetId()).
|
||||
Msg("set")
|
||||
Str("peer", grpcutil.GetPeerAddr(ctx)).
|
||||
Str("type", record.GetType()).
|
||||
Str("id", record.GetId()).
|
||||
Msg("put")
|
||||
|
||||
db, version, err := srv.getDB(req.GetType(), true)
|
||||
db, version, err := srv.getBackend()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := db.Put(ctx, req.GetId(), req.GetData()); err != nil {
|
||||
if err := db.Put(ctx, record); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
record, err := db.Get(ctx, req.GetId())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &databroker.SetResponse{
|
||||
Record: record,
|
||||
return &databroker.PutResponse{
|
||||
ServerVersion: version,
|
||||
Record: record,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (srv *Server) doSync(ctx context.Context,
|
||||
serverVersion string, recordVersion *string,
|
||||
db storage.Backend, stream databroker.DataBrokerService_SyncServer) error {
|
||||
updated, err := db.List(ctx, *recordVersion)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(updated) == 0 {
|
||||
return nil
|
||||
}
|
||||
*recordVersion = updated[len(updated)-1].Version
|
||||
for i := 0; i < len(updated); i += syncBatchSize {
|
||||
j := i + syncBatchSize
|
||||
if j > len(updated) {
|
||||
j = len(updated)
|
||||
}
|
||||
if err := stream.Send(&databroker.SyncResponse{
|
||||
ServerVersion: serverVersion,
|
||||
Records: updated[i:j],
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Sync streams updates for the given record type.
|
||||
func (srv *Server) Sync(req *databroker.SyncRequest, stream databroker.DataBrokerService_SyncServer) error {
|
||||
_, span := trace.StartSpan(stream.Context(), "databroker.grpc.Sync")
|
||||
defer span.End()
|
||||
srv.log.Info().
|
||||
Str("type", req.GetType()).
|
||||
Str("server_version", req.GetServerVersion()).
|
||||
Str("record_version", req.GetRecordVersion()).
|
||||
Str("peer", grpcutil.GetPeerAddr(stream.Context())).
|
||||
Uint64("server_version", req.GetServerVersion()).
|
||||
Uint64("record_version", req.GetRecordVersion()).
|
||||
Msg("sync")
|
||||
|
||||
db, serverVersion, err := srv.getDB(req.GetType(), true)
|
||||
backend, serverVersion, err := srv.getBackend()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
recordVersion := req.GetRecordVersion()
|
||||
// reset record version if the server versions don't match
|
||||
if req.GetServerVersion() != serverVersion {
|
||||
serverVersion = req.GetServerVersion()
|
||||
recordVersion = ""
|
||||
// send the new server version to the client
|
||||
err := stream.Send(&databroker.SyncResponse{
|
||||
ServerVersion: serverVersion,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return status.Errorf(codes.Aborted, "invalid server version, expected: %d", req.GetServerVersion())
|
||||
}
|
||||
|
||||
ctx := stream.Context()
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
|
||||
var ch <-chan struct{}
|
||||
if !req.GetNoWait() {
|
||||
ch = db.Watch(ctx)
|
||||
recordStream, err := backend.Sync(ctx, req.GetRecordVersion())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() { _ = recordStream.Close() }()
|
||||
|
||||
for recordStream.Next(true) {
|
||||
err = stream.Send(&databroker.SyncResponse{
|
||||
ServerVersion: serverVersion,
|
||||
Record: recordStream.Record(),
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Do first sync, so we won't miss anything.
|
||||
if err := srv.doSync(ctx, serverVersion, &recordVersion, db, stream); err != nil {
|
||||
return recordStream.Err()
|
||||
}
|
||||
|
||||
// SyncLatest returns the latest value of every record in the databroker as a stream of records.
|
||||
func (srv *Server) SyncLatest(req *databroker.SyncLatestRequest, stream databroker.DataBrokerService_SyncLatestServer) error {
|
||||
_, span := trace.StartSpan(stream.Context(), "databroker.grpc.SyncLatest")
|
||||
defer span.End()
|
||||
srv.log.Info().
|
||||
Str("peer", grpcutil.GetPeerAddr(stream.Context())).
|
||||
Str("type", req.GetType()).
|
||||
Msg("sync latest")
|
||||
|
||||
backend, serverVersion, err := srv.getBackend()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if req.GetNoWait() {
|
||||
return nil
|
||||
ctx := stream.Context()
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
|
||||
records, latestRecordVersion, err := backend.GetAll(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for range ch {
|
||||
if err := srv.doSync(ctx, serverVersion, &recordVersion, db, stream); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetTypes returns all the known record types.
|
||||
func (srv *Server) GetTypes(ctx context.Context, _ *emptypb.Empty) (*databroker.GetTypesResponse, error) {
|
||||
_, span := trace.StartSpan(ctx, "databroker.grpc.GetTypes")
|
||||
defer span.End()
|
||||
var recordTypes []string
|
||||
srv.mu.RLock()
|
||||
for recordType := range srv.byType {
|
||||
recordTypes = append(recordTypes, recordType)
|
||||
}
|
||||
srv.mu.RUnlock()
|
||||
|
||||
sort.Strings(recordTypes)
|
||||
return &databroker.GetTypesResponse{
|
||||
Types: recordTypes,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// SyncTypes synchronizes all the known record types.
|
||||
func (srv *Server) SyncTypes(req *emptypb.Empty, stream databroker.DataBrokerService_SyncTypesServer) error {
|
||||
_, span := trace.StartSpan(stream.Context(), "databroker.grpc.SyncTypes")
|
||||
defer span.End()
|
||||
srv.log.Info().
|
||||
Msg("sync types")
|
||||
|
||||
ch := srv.onTypechange.Bind()
|
||||
defer srv.onTypechange.Unbind(ch)
|
||||
|
||||
var prev []string
|
||||
for {
|
||||
res, err := srv.GetTypes(stream.Context(), req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if prev == nil || !reflect.DeepEqual(prev, res.Types) {
|
||||
err := stream.Send(res)
|
||||
for _, record := range records {
|
||||
if req.GetType() == "" || req.GetType() == record.GetType() {
|
||||
err = stream.Send(&databroker.SyncLatestResponse{
|
||||
Response: &databroker.SyncLatestResponse_Record{
|
||||
Record: record,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
prev = res.Types
|
||||
}
|
||||
|
||||
select {
|
||||
case <-stream.Context().Done():
|
||||
return stream.Context().Err()
|
||||
case <-ch:
|
||||
}
|
||||
}
|
||||
|
||||
// always send the server version last in case there are no records
|
||||
return stream.Send(&databroker.SyncLatestResponse{
|
||||
Response: &databroker.SyncLatestResponse_Versions{
|
||||
Versions: &databroker.Versions{
|
||||
ServerVersion: serverVersion,
|
||||
LatestRecordVersion: latestRecordVersion,
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func (srv *Server) getDB(recordType string, lock bool) (db storage.Backend, version string, err error) {
|
||||
func (srv *Server) getBackend() (backend storage.Backend, version uint64, err error) {
|
||||
// double-checked locking:
|
||||
// first try the read lock, then re-try with the write lock, and finally create a new db if nil
|
||||
if lock {
|
||||
srv.mu.RLock()
|
||||
}
|
||||
db = srv.byType[recordType]
|
||||
// first try the read lock, then re-try with the write lock, and finally create a new backend if nil
|
||||
srv.mu.RLock()
|
||||
backend = srv.backend
|
||||
version = srv.version
|
||||
if lock {
|
||||
srv.mu.RUnlock()
|
||||
}
|
||||
if db == nil {
|
||||
if lock {
|
||||
srv.mu.Lock()
|
||||
}
|
||||
db = srv.byType[recordType]
|
||||
srv.mu.RUnlock()
|
||||
if backend == nil {
|
||||
srv.mu.Lock()
|
||||
backend = srv.backend
|
||||
version = srv.version
|
||||
var err error
|
||||
if db == nil {
|
||||
db, err = srv.newDB(recordType)
|
||||
srv.byType[recordType] = db
|
||||
defer srv.onTypechange.Broadcast()
|
||||
}
|
||||
if lock {
|
||||
srv.mu.Unlock()
|
||||
if backend == nil {
|
||||
backend, err = srv.newBackendLocked()
|
||||
srv.backend = backend
|
||||
}
|
||||
srv.mu.Unlock()
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
return nil, 0, err
|
||||
}
|
||||
}
|
||||
return db, version, nil
|
||||
return backend, version, nil
|
||||
}
|
||||
|
||||
func (srv *Server) newDB(recordType string) (db storage.Backend, err error) {
|
||||
func (srv *Server) getBackendLocked() (backend storage.Backend, version uint64, err error) {
|
||||
backend = srv.backend
|
||||
version = srv.version
|
||||
if backend == nil {
|
||||
var err error
|
||||
backend, err = srv.newBackendLocked()
|
||||
srv.backend = backend
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
}
|
||||
return backend, version, nil
|
||||
}
|
||||
|
||||
func (srv *Server) newBackendLocked() (backend storage.Backend, err error) {
|
||||
caCertPool, err := cryptutil.GetCertPool("", srv.cfg.storageCAFile)
|
||||
if err != nil {
|
||||
log.Warn().Err(err).Msg("failed to read databroker CA file")
|
||||
|
@ -478,11 +345,12 @@ func (srv *Server) newDB(recordType string) (db storage.Backend, err error) {
|
|||
|
||||
switch srv.cfg.storageType {
|
||||
case config.StorageInMemoryName:
|
||||
return inmemory.NewDB(recordType, srv.cfg.btreeDegree), nil
|
||||
srv.log.Info().Msg("using in-memory store")
|
||||
return inmemory.New(), nil
|
||||
case config.StorageRedisName:
|
||||
db, err = redis.New(
|
||||
srv.log.Info().Msg("using redis store")
|
||||
backend, err = redis.New(
|
||||
srv.cfg.storageConnectionString,
|
||||
redis.WithRecordType(recordType),
|
||||
redis.WithTLSConfig(tlsConfig),
|
||||
)
|
||||
if err != nil {
|
||||
|
@ -492,10 +360,10 @@ func (srv *Server) newDB(recordType string) (db storage.Backend, err error) {
|
|||
return nil, fmt.Errorf("unsupported storage type: %s", srv.cfg.storageType)
|
||||
}
|
||||
if srv.cfg.secret != nil {
|
||||
db, err = storage.NewEncryptedBackend(srv.cfg.secret, db)
|
||||
backend, err = storage.NewEncryptedBackend(srv.cfg.secret, backend)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return db, nil
|
||||
return backend, nil
|
||||
}
|
||||
|
|
|
@ -2,101 +2,27 @@ package databroker
|
|||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/golang/protobuf/ptypes"
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
"google.golang.org/protobuf/types/known/anypb"
|
||||
"google.golang.org/protobuf/types/known/wrapperspb"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/log"
|
||||
"github.com/pomerium/pomerium/internal/signal"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/session"
|
||||
"github.com/pomerium/pomerium/pkg/storage"
|
||||
)
|
||||
|
||||
func newServer(cfg *serverConfig) *Server {
|
||||
return &Server{
|
||||
version: uuid.New().String(),
|
||||
version: 11,
|
||||
cfg: cfg,
|
||||
log: log.With().Str("service", "databroker").Logger(),
|
||||
|
||||
byType: make(map[string]storage.Backend),
|
||||
onTypechange: signal.New(),
|
||||
}
|
||||
}
|
||||
|
||||
func TestServer_initVersion(t *testing.T) {
|
||||
cfg := newServerConfig()
|
||||
t.Run("nil db", func(t *testing.T) {
|
||||
srvVersion := uuid.New()
|
||||
oldNewUUID := newUUID
|
||||
newUUID = func() uuid.UUID {
|
||||
return srvVersion
|
||||
}
|
||||
defer func() { newUUID = oldNewUUID }()
|
||||
|
||||
srv := newServer(cfg)
|
||||
srv.byType[recordTypeServerVersion] = nil
|
||||
srv.initVersion()
|
||||
assert.Equal(t, srvVersion.String(), srv.version)
|
||||
})
|
||||
t.Run("new server with random version", func(t *testing.T) {
|
||||
srvVersion := uuid.New()
|
||||
oldNewUUID := newUUID
|
||||
newUUID = func() uuid.UUID {
|
||||
return srvVersion
|
||||
}
|
||||
defer func() { newUUID = oldNewUUID }()
|
||||
|
||||
srv := newServer(cfg)
|
||||
ctx := context.Background()
|
||||
db, _, err := srv.getDB(recordTypeServerVersion, false)
|
||||
require.NoError(t, err)
|
||||
r, err := db.Get(ctx, serverVersionKey)
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, r)
|
||||
srv.initVersion()
|
||||
assert.Equal(t, srvVersion.String(), srv.version)
|
||||
r, err = db.Get(ctx, serverVersionKey)
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, r)
|
||||
var sv databroker.ServerVersion
|
||||
assert.NoError(t, ptypes.UnmarshalAny(r.GetData(), &sv))
|
||||
assert.Equal(t, srvVersion.String(), sv.Version)
|
||||
})
|
||||
t.Run("init version twice should get the same version", func(t *testing.T) {
|
||||
srv := newServer(cfg)
|
||||
ctx := context.Background()
|
||||
db, _, err := srv.getDB(recordTypeServerVersion, false)
|
||||
require.NoError(t, err)
|
||||
r, err := db.Get(ctx, serverVersionKey)
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, r)
|
||||
|
||||
srv.initVersion()
|
||||
srvVersion := srv.version
|
||||
|
||||
r, err = db.Get(ctx, serverVersionKey)
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, r)
|
||||
var sv databroker.ServerVersion
|
||||
assert.NoError(t, ptypes.UnmarshalAny(r.GetData(), &sv))
|
||||
assert.Equal(t, srvVersion, sv.Version)
|
||||
|
||||
// re-init version should get the same value as above
|
||||
srv.version = "foo"
|
||||
srv.initVersion()
|
||||
assert.Equal(t, srvVersion, srv.version)
|
||||
})
|
||||
}
|
||||
|
||||
func TestServer_Get(t *testing.T) {
|
||||
cfg := newServerConfig()
|
||||
t.Run("ignore deleted", func(t *testing.T) {
|
||||
|
@ -107,15 +33,22 @@ func TestServer_Get(t *testing.T) {
|
|||
any, err := anypb.New(s)
|
||||
assert.NoError(t, err)
|
||||
|
||||
srv.Set(context.Background(), &databroker.SetRequest{
|
||||
Type: any.TypeUrl,
|
||||
Id: s.Id,
|
||||
Data: any,
|
||||
_, err = srv.Put(context.Background(), &databroker.PutRequest{
|
||||
Record: &databroker.Record{
|
||||
Type: any.TypeUrl,
|
||||
Id: s.Id,
|
||||
Data: any,
|
||||
},
|
||||
})
|
||||
srv.Delete(context.Background(), &databroker.DeleteRequest{
|
||||
Type: any.TypeUrl,
|
||||
Id: s.Id,
|
||||
assert.NoError(t, err)
|
||||
_, err = srv.Put(context.Background(), &databroker.PutRequest{
|
||||
Record: &databroker.Record{
|
||||
Type: any.TypeUrl,
|
||||
Id: s.Id,
|
||||
DeletedAt: timestamppb.Now(),
|
||||
},
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
_, err = srv.Get(context.Background(), &databroker.GetRequest{
|
||||
Type: any.TypeUrl,
|
||||
Id: s.Id,
|
||||
|
@ -124,61 +57,3 @@ func TestServer_Get(t *testing.T) {
|
|||
assert.Equal(t, codes.NotFound, status.Code(err))
|
||||
})
|
||||
}
|
||||
|
||||
func TestServer_GetAll(t *testing.T) {
|
||||
cfg := newServerConfig(
|
||||
WithGetAllPageSize(5),
|
||||
)
|
||||
t.Run("ignore deleted", func(t *testing.T) {
|
||||
srv := newServer(cfg)
|
||||
|
||||
s := new(session.Session)
|
||||
s.Id = "1"
|
||||
any, err := anypb.New(s)
|
||||
assert.NoError(t, err)
|
||||
|
||||
srv.Set(context.Background(), &databroker.SetRequest{
|
||||
Type: any.TypeUrl,
|
||||
Id: s.Id,
|
||||
Data: any,
|
||||
})
|
||||
srv.Delete(context.Background(), &databroker.DeleteRequest{
|
||||
Type: any.TypeUrl,
|
||||
Id: s.Id,
|
||||
})
|
||||
res, err := srv.GetAll(context.Background(), &databroker.GetAllRequest{
|
||||
Type: any.TypeUrl,
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, res.GetRecords(), 0)
|
||||
})
|
||||
t.Run("paging", func(t *testing.T) {
|
||||
srv := newServer(cfg)
|
||||
|
||||
any, err := anypb.New(wrapperspb.String("TEST"))
|
||||
assert.NoError(t, err)
|
||||
|
||||
for i := 0; i < 7; i++ {
|
||||
srv.Set(context.Background(), &databroker.SetRequest{
|
||||
Type: any.TypeUrl,
|
||||
Id: fmt.Sprint(i),
|
||||
Data: any,
|
||||
})
|
||||
}
|
||||
|
||||
res, err := srv.GetAll(context.Background(), &databroker.GetAllRequest{
|
||||
Type: any.TypeUrl,
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, res.GetRecords(), 5)
|
||||
assert.Equal(t, res.GetNextPageToken(), "000000000005")
|
||||
|
||||
res, err = srv.GetAll(context.Background(), &databroker.GetAllRequest{
|
||||
Type: any.TypeUrl,
|
||||
PageToken: res.GetNextPageToken(),
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, res.GetRecords(), 2)
|
||||
assert.Equal(t, res.GetNextPageToken(), "")
|
||||
})
|
||||
}
|
||||
|
|
|
@ -7,14 +7,14 @@ import (
|
|||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/cenkalti/backoff/v4"
|
||||
"github.com/golang/protobuf/proto"
|
||||
"github.com/golang/protobuf/ptypes"
|
||||
"github.com/google/btree"
|
||||
"github.com/rs/zerolog"
|
||||
"golang.org/x/oauth2"
|
||||
"golang.org/x/sync/errgroup"
|
||||
"golang.org/x/sync/semaphore"
|
||||
"google.golang.org/protobuf/proto"
|
||||
"google.golang.org/protobuf/types/known/anypb"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/directory"
|
||||
"github.com/pomerium/pomerium/internal/identity/identity"
|
||||
|
@ -23,6 +23,7 @@ import (
|
|||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/session"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/user"
|
||||
"github.com/pomerium/pomerium/pkg/grpcutil"
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -37,13 +38,8 @@ type Authenticator interface {
|
|||
}
|
||||
|
||||
type (
|
||||
sessionMessage struct {
|
||||
record *databroker.Record
|
||||
session *session.Session
|
||||
}
|
||||
userMessage struct {
|
||||
record *databroker.Record
|
||||
user *user.User
|
||||
updateRecordsMessage struct {
|
||||
records []*databroker.Record
|
||||
}
|
||||
)
|
||||
|
||||
|
@ -52,19 +48,13 @@ type Manager struct {
|
|||
cfg *atomicConfig
|
||||
log zerolog.Logger
|
||||
|
||||
sessions sessionCollection
|
||||
sessionScheduler *scheduler.Scheduler
|
||||
userScheduler *scheduler.Scheduler
|
||||
|
||||
users userCollection
|
||||
userScheduler *scheduler.Scheduler
|
||||
|
||||
directoryUsers map[string]*directory.User
|
||||
directoryUsersServerVersion string
|
||||
directoryUsersRecordVersion string
|
||||
|
||||
directoryGroups map[string]*directory.Group
|
||||
directoryGroupsServerVersion string
|
||||
directoryGroupsRecordVersion string
|
||||
sessions sessionCollection
|
||||
users userCollection
|
||||
directoryUsers map[string]*directory.User
|
||||
directoryGroups map[string]*directory.Group
|
||||
|
||||
directoryNextRefresh time.Time
|
||||
|
||||
|
@ -79,14 +69,8 @@ func New(
|
|||
cfg: newAtomicConfig(newConfig()),
|
||||
log: log.With().Str("service", "identity_manager").Logger(),
|
||||
|
||||
sessions: sessionCollection{
|
||||
BTree: btree.New(8),
|
||||
},
|
||||
sessionScheduler: scheduler.New(),
|
||||
users: userCollection{
|
||||
BTree: btree.New(8),
|
||||
},
|
||||
userScheduler: scheduler.New(),
|
||||
userScheduler: scheduler.New(),
|
||||
|
||||
dataBrokerSemaphore: semaphore.NewWeighted(dataBrokerParallelism),
|
||||
}
|
||||
|
@ -101,52 +85,47 @@ func (mgr *Manager) UpdateConfig(options ...Option) {
|
|||
|
||||
// Run runs the manager. This method blocks until an error occurs or the given context is canceled.
|
||||
func (mgr *Manager) Run(ctx context.Context) error {
|
||||
err := mgr.initDirectoryGroups(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to initialize directory groups: %w", err)
|
||||
}
|
||||
update := make(chan updateRecordsMessage, 1)
|
||||
clear := make(chan struct{}, 1)
|
||||
|
||||
err = mgr.initDirectoryUsers(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to initialize directory users: %w", err)
|
||||
}
|
||||
syncer := newDataBrokerSyncer(mgr.cfg, mgr.log, update, clear)
|
||||
|
||||
eg, ctx := errgroup.WithContext(ctx)
|
||||
|
||||
updatedSession := make(chan sessionMessage, 1)
|
||||
eg.Go(func() error {
|
||||
return mgr.syncSessions(ctx, updatedSession)
|
||||
return syncer.Run(ctx)
|
||||
})
|
||||
|
||||
updatedUser := make(chan userMessage, 1)
|
||||
eg.Go(func() error {
|
||||
return mgr.syncUsers(ctx, updatedUser)
|
||||
})
|
||||
|
||||
updatedDirectoryGroup := make(chan *directory.Group, 1)
|
||||
eg.Go(func() error {
|
||||
return mgr.syncDirectoryGroups(ctx, updatedDirectoryGroup)
|
||||
})
|
||||
|
||||
updatedDirectoryUser := make(chan *directory.User, 1)
|
||||
eg.Go(func() error {
|
||||
return mgr.syncDirectoryUsers(ctx, updatedDirectoryUser)
|
||||
})
|
||||
|
||||
eg.Go(func() error {
|
||||
return mgr.refreshLoop(ctx, updatedSession, updatedUser, updatedDirectoryUser, updatedDirectoryGroup)
|
||||
return mgr.refreshLoop(ctx, update, clear)
|
||||
})
|
||||
|
||||
return eg.Wait()
|
||||
}
|
||||
|
||||
func (mgr *Manager) refreshLoop(
|
||||
ctx context.Context,
|
||||
updatedSession <-chan sessionMessage,
|
||||
updatedUser <-chan userMessage,
|
||||
updatedDirectoryUser <-chan *directory.User,
|
||||
updatedDirectoryGroup <-chan *directory.Group,
|
||||
) error {
|
||||
func (mgr *Manager) refreshLoop(ctx context.Context, update <-chan updateRecordsMessage, clear <-chan struct{}) error {
|
||||
// wait for initial sync
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-clear:
|
||||
mgr.directoryGroups = make(map[string]*directory.Group)
|
||||
mgr.directoryUsers = make(map[string]*directory.User)
|
||||
mgr.sessions = sessionCollection{BTree: btree.New(8)}
|
||||
mgr.users = userCollection{BTree: btree.New(8)}
|
||||
}
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
case msg := <-update:
|
||||
mgr.onUpdateRecords(ctx, msg)
|
||||
}
|
||||
|
||||
mgr.log.Info().
|
||||
Int("directory_groups", len(mgr.directoryGroups)).
|
||||
Int("directory_users", len(mgr.directoryUsers)).
|
||||
Int("sessions", mgr.sessions.Len()).
|
||||
Int("users", mgr.users.Len()).
|
||||
Msg("initial sync complete")
|
||||
|
||||
// start refreshing
|
||||
maxWait := time.Minute * 10
|
||||
nextTime := time.Now().Add(maxWait)
|
||||
if mgr.directoryNextRefresh.Before(nextTime) {
|
||||
|
@ -160,14 +139,13 @@ func (mgr *Manager) refreshLoop(
|
|||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case s := <-updatedSession:
|
||||
mgr.onUpdateSession(ctx, s)
|
||||
case u := <-updatedUser:
|
||||
mgr.onUpdateUser(ctx, u)
|
||||
case du := <-updatedDirectoryUser:
|
||||
mgr.onUpdateDirectoryUser(ctx, du)
|
||||
case dg := <-updatedDirectoryGroup:
|
||||
mgr.onUpdateDirectoryGroup(ctx, dg)
|
||||
case <-clear:
|
||||
mgr.directoryGroups = make(map[string]*directory.Group)
|
||||
mgr.directoryUsers = make(map[string]*directory.User)
|
||||
mgr.sessions = sessionCollection{BTree: btree.New(8)}
|
||||
mgr.users = userCollection{BTree: btree.New(8)}
|
||||
case msg := <-update:
|
||||
mgr.onUpdateRecords(ctx, msg)
|
||||
case <-timer.C:
|
||||
}
|
||||
|
||||
|
@ -249,7 +227,7 @@ func (mgr *Manager) mergeGroups(ctx context.Context, directoryGroups []*director
|
|||
curDG, ok := mgr.directoryGroups[groupID]
|
||||
if !ok || !proto.Equal(newDG, curDG) {
|
||||
id := newDG.GetId()
|
||||
any, err := ptypes.MarshalAny(newDG)
|
||||
any, err := anypb.New(newDG)
|
||||
if err != nil {
|
||||
mgr.log.Warn().Err(err).Msg("failed to marshal directory group")
|
||||
return
|
||||
|
@ -260,10 +238,12 @@ func (mgr *Manager) mergeGroups(ctx context.Context, directoryGroups []*director
|
|||
}
|
||||
defer mgr.dataBrokerSemaphore.Release(1)
|
||||
|
||||
_, err = mgr.cfg.Load().dataBrokerClient.Set(ctx, &databroker.SetRequest{
|
||||
Type: any.GetTypeUrl(),
|
||||
Id: id,
|
||||
Data: any,
|
||||
_, err = mgr.cfg.Load().dataBrokerClient.Put(ctx, &databroker.PutRequest{
|
||||
Record: &databroker.Record{
|
||||
Type: any.GetTypeUrl(),
|
||||
Id: id,
|
||||
Data: any,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to update directory group: %s", id)
|
||||
|
@ -277,7 +257,7 @@ func (mgr *Manager) mergeGroups(ctx context.Context, directoryGroups []*director
|
|||
_, ok := lookup[groupID]
|
||||
if !ok {
|
||||
id := curDG.GetId()
|
||||
any, err := ptypes.MarshalAny(curDG)
|
||||
any, err := anypb.New(curDG)
|
||||
if err != nil {
|
||||
mgr.log.Warn().Err(err).Msg("failed to marshal directory group")
|
||||
return
|
||||
|
@ -288,9 +268,12 @@ func (mgr *Manager) mergeGroups(ctx context.Context, directoryGroups []*director
|
|||
}
|
||||
defer mgr.dataBrokerSemaphore.Release(1)
|
||||
|
||||
_, err = mgr.cfg.Load().dataBrokerClient.Delete(ctx, &databroker.DeleteRequest{
|
||||
Type: any.GetTypeUrl(),
|
||||
Id: id,
|
||||
_, err = mgr.cfg.Load().dataBrokerClient.Put(ctx, &databroker.PutRequest{
|
||||
Record: &databroker.Record{
|
||||
Type: any.GetTypeUrl(),
|
||||
Id: id,
|
||||
DeletedAt: timestamppb.Now(),
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to delete directory group: %s", id)
|
||||
|
@ -299,6 +282,10 @@ func (mgr *Manager) mergeGroups(ctx context.Context, directoryGroups []*director
|
|||
})
|
||||
}
|
||||
}
|
||||
|
||||
if err := eg.Wait(); err != nil {
|
||||
mgr.log.Warn().Err(err).Msg("manager: failed to merge groups")
|
||||
}
|
||||
}
|
||||
|
||||
func (mgr *Manager) mergeUsers(ctx context.Context, directoryUsers []*directory.User) {
|
||||
|
@ -313,7 +300,7 @@ func (mgr *Manager) mergeUsers(ctx context.Context, directoryUsers []*directory.
|
|||
curDU, ok := mgr.directoryUsers[userID]
|
||||
if !ok || !proto.Equal(newDU, curDU) {
|
||||
id := newDU.GetId()
|
||||
any, err := ptypes.MarshalAny(newDU)
|
||||
any, err := anypb.New(newDU)
|
||||
if err != nil {
|
||||
mgr.log.Warn().Err(err).Msg("failed to marshal directory user")
|
||||
return
|
||||
|
@ -325,10 +312,12 @@ func (mgr *Manager) mergeUsers(ctx context.Context, directoryUsers []*directory.
|
|||
defer mgr.dataBrokerSemaphore.Release(1)
|
||||
|
||||
client := mgr.cfg.Load().dataBrokerClient
|
||||
if _, err := client.Set(ctx, &databroker.SetRequest{
|
||||
Type: any.GetTypeUrl(),
|
||||
Id: id,
|
||||
Data: any,
|
||||
if _, err := client.Put(ctx, &databroker.PutRequest{
|
||||
Record: &databroker.Record{
|
||||
Type: any.GetTypeUrl(),
|
||||
Id: id,
|
||||
Data: any,
|
||||
},
|
||||
}); err != nil {
|
||||
return fmt.Errorf("failed to update directory user: %s", id)
|
||||
}
|
||||
|
@ -341,7 +330,7 @@ func (mgr *Manager) mergeUsers(ctx context.Context, directoryUsers []*directory.
|
|||
_, ok := lookup[userID]
|
||||
if !ok {
|
||||
id := curDU.GetId()
|
||||
any, err := ptypes.MarshalAny(curDU)
|
||||
any, err := anypb.New(curDU)
|
||||
if err != nil {
|
||||
mgr.log.Warn().Err(err).Msg("failed to marshal directory user")
|
||||
return
|
||||
|
@ -353,9 +342,13 @@ func (mgr *Manager) mergeUsers(ctx context.Context, directoryUsers []*directory.
|
|||
defer mgr.dataBrokerSemaphore.Release(1)
|
||||
|
||||
client := mgr.cfg.Load().dataBrokerClient
|
||||
if _, err := client.Delete(ctx, &databroker.DeleteRequest{
|
||||
Type: any.GetTypeUrl(),
|
||||
Id: id,
|
||||
if _, err := client.Put(ctx, &databroker.PutRequest{
|
||||
Record: &databroker.Record{
|
||||
Type: any.GetTypeUrl(),
|
||||
Id: id,
|
||||
Data: any,
|
||||
DeletedAt: timestamppb.Now(),
|
||||
},
|
||||
}); err != nil {
|
||||
return fmt.Errorf("failed to delete directory user (%s): %w", id, err)
|
||||
}
|
||||
|
@ -384,8 +377,8 @@ func (mgr *Manager) refreshSession(ctx context.Context, userID, sessionID string
|
|||
return
|
||||
}
|
||||
|
||||
expiry, err := ptypes.Timestamp(s.GetExpiresAt())
|
||||
if err == nil && !expiry.After(time.Now()) {
|
||||
expiry := s.GetExpiresAt().AsTime()
|
||||
if !expiry.After(time.Now()) {
|
||||
mgr.log.Info().
|
||||
Str("user_id", userID).
|
||||
Str("session_id", sessionID).
|
||||
|
@ -435,7 +428,7 @@ func (mgr *Manager) refreshSession(ctx context.Context, userID, sessionID string
|
|||
return
|
||||
}
|
||||
|
||||
res, err := session.Set(ctx, mgr.cfg.Load().dataBrokerClient, s.Session)
|
||||
res, err := session.Put(ctx, mgr.cfg.Load().dataBrokerClient, s.Session)
|
||||
if err != nil {
|
||||
mgr.log.Error().Err(err).
|
||||
Str("user_id", s.GetUserId()).
|
||||
|
@ -444,7 +437,7 @@ func (mgr *Manager) refreshSession(ctx context.Context, userID, sessionID string
|
|||
return
|
||||
}
|
||||
|
||||
mgr.onUpdateSession(ctx, sessionMessage{record: res.GetRecord(), session: s.Session})
|
||||
mgr.onUpdateSession(ctx, res.GetRecord(), s.Session)
|
||||
}
|
||||
|
||||
func (mgr *Manager) refreshUser(ctx context.Context, userID string) {
|
||||
|
@ -486,7 +479,7 @@ func (mgr *Manager) refreshUser(ctx context.Context, userID string) {
|
|||
continue
|
||||
}
|
||||
|
||||
record, err := user.Set(ctx, mgr.cfg.Load().dataBrokerClient, u.User)
|
||||
record, err := user.Put(ctx, mgr.cfg.Load().dataBrokerClient, u.User)
|
||||
if err != nil {
|
||||
mgr.log.Error().Err(err).
|
||||
Str("user_id", s.GetUserId()).
|
||||
|
@ -495,257 +488,78 @@ func (mgr *Manager) refreshUser(ctx context.Context, userID string) {
|
|||
continue
|
||||
}
|
||||
|
||||
mgr.onUpdateUser(ctx, userMessage{record: record, user: u.User})
|
||||
mgr.onUpdateUser(ctx, record, u.User)
|
||||
}
|
||||
}
|
||||
|
||||
func (mgr *Manager) syncSessions(ctx context.Context, ch chan<- sessionMessage) error {
|
||||
mgr.log.Info().Msg("syncing sessions")
|
||||
|
||||
any, err := ptypes.MarshalAny(new(session.Session))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
client, err := mgr.cfg.Load().dataBrokerClient.Sync(ctx, &databroker.SyncRequest{
|
||||
Type: any.GetTypeUrl(),
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("error syncing sessions: %w", err)
|
||||
}
|
||||
for {
|
||||
res, err := client.Recv()
|
||||
if err != nil {
|
||||
return fmt.Errorf("error receiving sessions: %w", err)
|
||||
}
|
||||
|
||||
for _, record := range res.GetRecords() {
|
||||
func (mgr *Manager) onUpdateRecords(ctx context.Context, msg updateRecordsMessage) {
|
||||
for _, record := range msg.records {
|
||||
switch record.GetType() {
|
||||
case grpcutil.GetTypeURL(new(directory.Group)):
|
||||
var pbDirectoryGroup directory.Group
|
||||
err := record.GetData().UnmarshalTo(&pbDirectoryGroup)
|
||||
if err != nil {
|
||||
mgr.log.Warn().Msgf("error unmarshaling directory group: %s", err)
|
||||
continue
|
||||
}
|
||||
mgr.onUpdateDirectoryGroup(ctx, &pbDirectoryGroup)
|
||||
case grpcutil.GetTypeURL(new(directory.User)):
|
||||
var pbDirectoryUser directory.User
|
||||
err := record.GetData().UnmarshalTo(&pbDirectoryUser)
|
||||
if err != nil {
|
||||
mgr.log.Warn().Msgf("error unmarshaling directory user: %s", err)
|
||||
continue
|
||||
}
|
||||
mgr.onUpdateDirectoryUser(ctx, &pbDirectoryUser)
|
||||
case grpcutil.GetTypeURL(new(session.Session)):
|
||||
var pbSession session.Session
|
||||
err := ptypes.UnmarshalAny(record.GetData(), &pbSession)
|
||||
err := record.GetData().UnmarshalTo(&pbSession)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error unmarshaling session: %w", err)
|
||||
mgr.log.Warn().Msgf("error unmarshaling session: %s", err)
|
||||
continue
|
||||
}
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case ch <- sessionMessage{record: record, session: &pbSession}:
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (mgr *Manager) syncUsers(ctx context.Context, ch chan<- userMessage) error {
|
||||
mgr.log.Info().Msg("syncing users")
|
||||
|
||||
any, err := ptypes.MarshalAny(new(user.User))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
client, err := mgr.cfg.Load().dataBrokerClient.Sync(ctx, &databroker.SyncRequest{
|
||||
Type: any.GetTypeUrl(),
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("error syncing users: %w", err)
|
||||
}
|
||||
for {
|
||||
res, err := client.Recv()
|
||||
if err != nil {
|
||||
return fmt.Errorf("error receiving users: %w", err)
|
||||
}
|
||||
|
||||
for _, record := range res.GetRecords() {
|
||||
mgr.onUpdateSession(ctx, record, &pbSession)
|
||||
case grpcutil.GetTypeURL(new(user.User)):
|
||||
var pbUser user.User
|
||||
err := ptypes.UnmarshalAny(record.GetData(), &pbUser)
|
||||
err := record.GetData().UnmarshalTo(&pbUser)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error unmarshaling user: %w", err)
|
||||
}
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case ch <- userMessage{record: record, user: &pbUser}:
|
||||
mgr.log.Warn().Msgf("error unmarshaling user: %s", err)
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (mgr *Manager) initDirectoryUsers(ctx context.Context) error {
|
||||
mgr.log.Info().Msg("initializing directory users")
|
||||
func (mgr *Manager) onUpdateSession(_ context.Context, record *databroker.Record, session *session.Session) {
|
||||
mgr.sessionScheduler.Remove(toSessionSchedulerKey(session.GetUserId(), session.GetId()))
|
||||
|
||||
any, err := ptypes.MarshalAny(new(directory.User))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return exponentialTry(ctx, func() error {
|
||||
res, err := databroker.InitialSync(ctx, mgr.cfg.Load().dataBrokerClient, &databroker.SyncRequest{
|
||||
Type: any.GetTypeUrl(),
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("error getting all directory users: %w", err)
|
||||
}
|
||||
|
||||
mgr.directoryUsers = map[string]*directory.User{}
|
||||
for _, record := range res.GetRecords() {
|
||||
var pbDirectoryUser directory.User
|
||||
err := ptypes.UnmarshalAny(record.GetData(), &pbDirectoryUser)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error unmarshaling directory user: %w", err)
|
||||
}
|
||||
|
||||
mgr.directoryUsers[pbDirectoryUser.GetId()] = &pbDirectoryUser
|
||||
mgr.directoryUsersRecordVersion = record.GetVersion()
|
||||
}
|
||||
mgr.directoryUsersServerVersion = res.GetServerVersion()
|
||||
|
||||
mgr.log.Info().Int("count", len(mgr.directoryUsers)).Msg("initialized directory users")
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func (mgr *Manager) syncDirectoryUsers(ctx context.Context, ch chan<- *directory.User) error {
|
||||
mgr.log.Info().Msg("syncing directory users")
|
||||
|
||||
any, err := ptypes.MarshalAny(new(directory.User))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
client, err := mgr.cfg.Load().dataBrokerClient.Sync(ctx, &databroker.SyncRequest{
|
||||
Type: any.GetTypeUrl(),
|
||||
ServerVersion: mgr.directoryUsersServerVersion,
|
||||
RecordVersion: mgr.directoryUsersRecordVersion,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("error syncing directory users: %w", err)
|
||||
}
|
||||
for {
|
||||
res, err := client.Recv()
|
||||
if err != nil {
|
||||
return fmt.Errorf("error receiving directory users: %w", err)
|
||||
}
|
||||
|
||||
for _, record := range res.GetRecords() {
|
||||
var pbDirectoryUser directory.User
|
||||
err := ptypes.UnmarshalAny(record.GetData(), &pbDirectoryUser)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error unmarshaling directory user: %w", err)
|
||||
}
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case ch <- &pbDirectoryUser:
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (mgr *Manager) initDirectoryGroups(ctx context.Context) error {
|
||||
mgr.log.Info().Msg("initializing directory groups")
|
||||
|
||||
any, err := ptypes.MarshalAny(new(directory.Group))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return exponentialTry(ctx, func() error {
|
||||
res, err := databroker.InitialSync(ctx, mgr.cfg.Load().dataBrokerClient, &databroker.SyncRequest{
|
||||
Type: any.GetTypeUrl(),
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("error getting all directory groups: %w", err)
|
||||
}
|
||||
|
||||
mgr.directoryGroups = map[string]*directory.Group{}
|
||||
for _, record := range res.GetRecords() {
|
||||
var pbDirectoryGroup directory.Group
|
||||
err := ptypes.UnmarshalAny(record.GetData(), &pbDirectoryGroup)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error unmarshaling directory group: %w", err)
|
||||
}
|
||||
|
||||
mgr.directoryGroups[pbDirectoryGroup.GetId()] = &pbDirectoryGroup
|
||||
mgr.directoryGroupsRecordVersion = record.GetVersion()
|
||||
}
|
||||
mgr.directoryGroupsServerVersion = res.GetServerVersion()
|
||||
|
||||
mgr.log.Info().Int("count", len(mgr.directoryGroups)).Msg("initialized directory groups")
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func (mgr *Manager) syncDirectoryGroups(ctx context.Context, ch chan<- *directory.Group) error {
|
||||
mgr.log.Info().Msg("syncing directory groups")
|
||||
|
||||
any, err := ptypes.MarshalAny(new(directory.Group))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
client, err := mgr.cfg.Load().dataBrokerClient.Sync(ctx, &databroker.SyncRequest{
|
||||
Type: any.GetTypeUrl(),
|
||||
ServerVersion: mgr.directoryGroupsServerVersion,
|
||||
RecordVersion: mgr.directoryGroupsRecordVersion,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("error syncing directory groups: %w", err)
|
||||
}
|
||||
for {
|
||||
res, err := client.Recv()
|
||||
if err != nil {
|
||||
return fmt.Errorf("error receiving directory groups: %w", err)
|
||||
}
|
||||
|
||||
for _, record := range res.GetRecords() {
|
||||
var pbDirectoryGroup directory.Group
|
||||
err := ptypes.UnmarshalAny(record.GetData(), &pbDirectoryGroup)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error unmarshaling directory group: %w", err)
|
||||
}
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case ch <- &pbDirectoryGroup:
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (mgr *Manager) onUpdateSession(_ context.Context, msg sessionMessage) {
|
||||
mgr.sessionScheduler.Remove(toSessionSchedulerKey(msg.session.GetUserId(), msg.session.GetId()))
|
||||
|
||||
if msg.record.GetDeletedAt() != nil {
|
||||
mgr.sessions.Delete(msg.session.GetUserId(), msg.session.GetId())
|
||||
if record.GetDeletedAt() != nil {
|
||||
mgr.sessions.Delete(session.GetUserId(), session.GetId())
|
||||
return
|
||||
}
|
||||
|
||||
// update session
|
||||
s, _ := mgr.sessions.Get(msg.session.GetUserId(), msg.session.GetId())
|
||||
s, _ := mgr.sessions.Get(session.GetUserId(), session.GetId())
|
||||
s.lastRefresh = time.Now()
|
||||
s.gracePeriod = mgr.cfg.Load().sessionRefreshGracePeriod
|
||||
s.coolOffDuration = mgr.cfg.Load().sessionRefreshCoolOffDuration
|
||||
s.Session = msg.session
|
||||
s.Session = session
|
||||
mgr.sessions.ReplaceOrInsert(s)
|
||||
mgr.sessionScheduler.Add(s.NextRefresh(), toSessionSchedulerKey(msg.session.GetUserId(), msg.session.GetId()))
|
||||
mgr.sessionScheduler.Add(s.NextRefresh(), toSessionSchedulerKey(session.GetUserId(), session.GetId()))
|
||||
}
|
||||
|
||||
func (mgr *Manager) onUpdateUser(_ context.Context, msg userMessage) {
|
||||
mgr.userScheduler.Remove(msg.user.GetId())
|
||||
func (mgr *Manager) onUpdateUser(_ context.Context, record *databroker.Record, user *user.User) {
|
||||
mgr.userScheduler.Remove(user.GetId())
|
||||
|
||||
if msg.record.GetDeletedAt() != nil {
|
||||
mgr.users.Delete(msg.user.GetId())
|
||||
if record.GetDeletedAt() != nil {
|
||||
mgr.users.Delete(user.GetId())
|
||||
return
|
||||
}
|
||||
|
||||
u, _ := mgr.users.Get(msg.user.GetId())
|
||||
u, _ := mgr.users.Get(user.GetId())
|
||||
u.lastRefresh = time.Now()
|
||||
u.refreshInterval = mgr.cfg.Load().groupRefreshInterval
|
||||
u.User = msg.user
|
||||
u.User = user
|
||||
mgr.users.ReplaceOrInsert(u)
|
||||
mgr.userScheduler.Add(u.NextRefresh(), u.GetId())
|
||||
}
|
||||
|
@ -779,22 +593,3 @@ func isTemporaryError(err error) bool {
|
|||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// exponentialTry executes f until it succeeds or ctx is Done.
|
||||
func exponentialTry(ctx context.Context, f func() error) error {
|
||||
backoff := backoff.NewExponentialBackOff()
|
||||
backoff.MaxElapsedTime = 0
|
||||
for {
|
||||
err := f()
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-time.After(backoff.NextBackOff()):
|
||||
}
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
|
58
internal/identity/manager/sync.go
Normal file
58
internal/identity/manager/sync.go
Normal file
|
@ -0,0 +1,58 @@
|
|||
package manager
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
|
||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||
)
|
||||
|
||||
type dataBrokerSyncer struct {
|
||||
cfg *atomicConfig
|
||||
log zerolog.Logger
|
||||
|
||||
update chan<- updateRecordsMessage
|
||||
clear chan<- struct{}
|
||||
|
||||
syncer *databroker.Syncer
|
||||
}
|
||||
|
||||
func newDataBrokerSyncer(
|
||||
cfg *atomicConfig,
|
||||
log zerolog.Logger,
|
||||
update chan<- updateRecordsMessage,
|
||||
clear chan<- struct{},
|
||||
) *dataBrokerSyncer {
|
||||
syncer := &dataBrokerSyncer{
|
||||
cfg: cfg,
|
||||
log: log,
|
||||
|
||||
update: update,
|
||||
clear: clear,
|
||||
}
|
||||
syncer.syncer = databroker.NewSyncer(syncer)
|
||||
return syncer
|
||||
}
|
||||
|
||||
func (syncer *dataBrokerSyncer) Run(ctx context.Context) (err error) {
|
||||
return syncer.syncer.Run(ctx)
|
||||
}
|
||||
|
||||
func (syncer *dataBrokerSyncer) ClearRecords(ctx context.Context) {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
case syncer.clear <- struct{}{}:
|
||||
}
|
||||
}
|
||||
|
||||
func (syncer *dataBrokerSyncer) GetDataBrokerServiceClient() databroker.DataBrokerServiceClient {
|
||||
return syncer.cfg.Load().dataBrokerClient
|
||||
}
|
||||
|
||||
func (syncer *dataBrokerSyncer) UpdateRecords(ctx context.Context, records []*databroker.Record) {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
case syncer.update <- updateRecordsMessage{records: records}:
|
||||
}
|
||||
}
|
2
internal/registry/registry.go
Normal file
2
internal/registry/registry.go
Normal file
|
@ -0,0 +1,2 @@
|
|||
// Package registry implements a service registry server.
|
||||
package registry
|
|
@ -9,40 +9,23 @@ import (
|
|||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/stats"
|
||||
"google.golang.org/grpc/status"
|
||||
"google.golang.org/protobuf/types/known/wrapperspb"
|
||||
)
|
||||
|
||||
var statsHandler = &ocgrpc.ServerHandler{}
|
||||
|
||||
type testProto struct {
|
||||
message string
|
||||
}
|
||||
|
||||
func (t testProto) Reset() {}
|
||||
func (t testProto) ProtoMessage() {}
|
||||
func (t testProto) String() string {
|
||||
return t.message
|
||||
}
|
||||
|
||||
func (t testProto) XXX_Size() int {
|
||||
return len([]byte(t.message))
|
||||
}
|
||||
|
||||
func (t testProto) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
|
||||
return []byte(t.message), nil
|
||||
}
|
||||
|
||||
type testInvoker struct {
|
||||
invokeResult error
|
||||
statsHandler stats.Handler
|
||||
}
|
||||
|
||||
func (t testInvoker) UnaryInvoke(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, opts ...grpc.CallOption) error {
|
||||
r := reply.(*testProto)
|
||||
r.message = "hello"
|
||||
r := reply.(*wrapperspb.StringValue)
|
||||
r.Value = "hello"
|
||||
|
||||
ctx = t.statsHandler.TagRPC(ctx, &stats.RPCTagInfo{FullMethodName: method})
|
||||
t.statsHandler.HandleRPC(ctx, &stats.InPayload{Client: true, Length: len(r.message)})
|
||||
t.statsHandler.HandleRPC(ctx, &stats.OutPayload{Client: true, Length: len(r.message)})
|
||||
t.statsHandler.HandleRPC(ctx, &stats.InPayload{Client: true, Length: len(r.Value)})
|
||||
t.statsHandler.HandleRPC(ctx, &stats.OutPayload{Client: true, Length: len(r.Value)})
|
||||
t.statsHandler.HandleRPC(ctx, &stats.End{Client: true, Error: t.invokeResult})
|
||||
|
||||
return t.invokeResult
|
||||
|
@ -106,7 +89,7 @@ func Test_GRPCClientInterceptor(t *testing.T) {
|
|||
invokeResult: tt.errorCode,
|
||||
statsHandler: &ocgrpc.ClientHandler{},
|
||||
}
|
||||
var reply testProto
|
||||
var reply wrapperspb.StringValue
|
||||
|
||||
interceptor(context.Background(), tt.method, nil, &reply, newTestCC(t), invoker.UnaryInvoke)
|
||||
|
||||
|
|
105
internal/testutil/redis.go
Normal file
105
internal/testutil/redis.go
Normal file
|
@ -0,0 +1,105 @@
|
|||
package testutil
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
"github.com/go-redis/redis/v8"
|
||||
"github.com/ory/dockertest/v3"
|
||||
|
||||
"github.com/pomerium/pomerium/pkg/cryptutil"
|
||||
)
|
||||
|
||||
const maxWait = time.Minute
|
||||
|
||||
// WithTestRedis creates a test a test redis instance using docker.
|
||||
func WithTestRedis(useTLS bool, handler func(rawURL string) error) error {
|
||||
ctx, clearTimeout := context.WithTimeout(context.Background(), maxWait)
|
||||
defer clearTimeout()
|
||||
|
||||
// uses a sensible default on windows (tcp/http) and linux/osx (socket)
|
||||
pool, err := dockertest.NewPool("")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
opts := &dockertest.RunOptions{
|
||||
Repository: "redis",
|
||||
Tag: "6",
|
||||
}
|
||||
scheme := "redis"
|
||||
if useTLS {
|
||||
opts.Mounts = []string{
|
||||
filepath.Join(TestDataRoot(), "tls") + ":/tls",
|
||||
}
|
||||
opts.Cmd = []string{
|
||||
"--port", "0",
|
||||
"--tls-port", "6379",
|
||||
"--tls-cert-file", "/tls/redis.crt",
|
||||
"--tls-key-file", "/tls/redis.key",
|
||||
"--tls-ca-cert-file", "/tls/ca.crt",
|
||||
}
|
||||
scheme = "rediss"
|
||||
}
|
||||
|
||||
resource, err := pool.RunWithOptions(opts)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_ = resource.Expire(uint(maxWait.Seconds()))
|
||||
|
||||
redisURL := fmt.Sprintf("%s://%s/0", scheme, resource.GetHostPort("6379/tcp"))
|
||||
if err := pool.Retry(func() error {
|
||||
options, err := redis.ParseURL(redisURL)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if useTLS {
|
||||
options.TLSConfig = RedisTLSConfig()
|
||||
}
|
||||
|
||||
client := redis.NewClient(options)
|
||||
defer client.Close()
|
||||
|
||||
return client.Ping(ctx).Err()
|
||||
}); err != nil {
|
||||
_ = pool.Purge(resource)
|
||||
return err
|
||||
}
|
||||
|
||||
e := handler(redisURL)
|
||||
|
||||
if err := pool.Purge(resource); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return e
|
||||
}
|
||||
|
||||
// RedisTLSConfig returns the TLS Config to use with redis.
|
||||
func RedisTLSConfig() *tls.Config {
|
||||
cert, err := cryptutil.CertificateFromFile(
|
||||
filepath.Join(TestDataRoot(), "tls", "redis.crt"),
|
||||
filepath.Join(TestDataRoot(), "tls", "redis.key"),
|
||||
)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
caCertPool := x509.NewCertPool()
|
||||
caCert, err := ioutil.ReadFile(filepath.Join(TestDataRoot(), "tls", "ca.crt"))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
caCertPool.AppendCertsFromPEM(caCert)
|
||||
tlsConfig := &tls.Config{
|
||||
RootCAs: caCertPool,
|
||||
Certificates: []tls.Certificate{*cert},
|
||||
MinVersion: tls.VersionTLS12,
|
||||
}
|
||||
return tlsConfig
|
||||
}
|
|
@ -3,6 +3,8 @@ package testutil
|
|||
|
||||
import (
|
||||
"encoding/json"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
|
@ -32,3 +34,28 @@ func toProtoJSON(protoMsg interface{}) json.RawMessage {
|
|||
bs, _ := protojson.Marshal(v2)
|
||||
return bs
|
||||
}
|
||||
|
||||
// ModRoot returns the directory containing the go.mod file.
|
||||
func ModRoot() string {
|
||||
dir, err := os.Getwd()
|
||||
if err != nil {
|
||||
panic("error getting working directory")
|
||||
}
|
||||
|
||||
for {
|
||||
if fi, err := os.Stat(filepath.Join(dir, "go.mod")); err == nil && !fi.IsDir() {
|
||||
return dir
|
||||
}
|
||||
d := filepath.Dir(dir)
|
||||
if d == dir {
|
||||
break
|
||||
}
|
||||
dir = d
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// TestDataRoot returns the testdata directory.
|
||||
func TestDataRoot() string {
|
||||
return filepath.Join(ModRoot(), "internal", "testutil", "testdata")
|
||||
}
|
||||
|
|
|
@ -3,6 +3,7 @@ package cryptutil
|
|||
import (
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"encoding/binary"
|
||||
)
|
||||
|
||||
// DefaultKeySize is the default key size in bytes.
|
||||
|
@ -29,6 +30,13 @@ func NewRandomStringN(c int) string {
|
|||
return base64.StdEncoding.EncodeToString(randomBytes(c))
|
||||
}
|
||||
|
||||
// NewRandomUInt64 returns a random uint64.
|
||||
//
|
||||
// Panics if source of randomness fails.
|
||||
func NewRandomUInt64() uint64 {
|
||||
return binary.LittleEndian.Uint64(randomBytes(8))
|
||||
}
|
||||
|
||||
// randomBytes generates C number of random bytes suitable for cryptographic
|
||||
// operations.
|
||||
//
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
// Package config contains protobuf definitions for config.
|
||||
package config
|
||||
|
||||
// IsSet returns true if one of the route redirect options has been chosen.
|
||||
|
|
|
@ -3,10 +3,9 @@ package databroker
|
|||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
|
||||
"google.golang.org/protobuf/proto"
|
||||
)
|
||||
|
||||
// GetUserID gets the databroker user id from a provider user id.
|
||||
|
@ -37,19 +36,17 @@ func ApplyOffsetAndLimit(all []*Record, offset, limit int) (records []*Record, t
|
|||
return records, len(all)
|
||||
}
|
||||
|
||||
// InitialSync performs a sync with no_wait set to true and then returns all the results.
|
||||
func InitialSync(ctx context.Context, client DataBrokerServiceClient, in *SyncRequest) (*SyncResponse, error) {
|
||||
dup := new(SyncRequest)
|
||||
proto.Merge(dup, in)
|
||||
dup.NoWait = true
|
||||
|
||||
stream, err := client.Sync(ctx, dup)
|
||||
// InitialSync performs a sync latest and then returns all the results.
|
||||
func InitialSync(
|
||||
ctx context.Context,
|
||||
client DataBrokerServiceClient,
|
||||
req *SyncLatestRequest,
|
||||
) (records []*Record, recordVersion, serverVersion uint64, err error) {
|
||||
stream, err := client.SyncLatest(ctx, req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, 0, 0, err
|
||||
}
|
||||
|
||||
finalRes := &SyncResponse{}
|
||||
|
||||
loop:
|
||||
for {
|
||||
res, err := stream.Recv()
|
||||
|
@ -57,12 +54,19 @@ loop:
|
|||
case err == io.EOF:
|
||||
break loop
|
||||
case err != nil:
|
||||
return nil, err
|
||||
return nil, 0, 0, err
|
||||
}
|
||||
|
||||
finalRes.ServerVersion = res.GetServerVersion()
|
||||
finalRes.Records = append(finalRes.Records, res.GetRecords()...)
|
||||
switch res := res.GetResponse().(type) {
|
||||
case *SyncLatestResponse_Versions:
|
||||
recordVersion = res.Versions.GetLatestRecordVersion()
|
||||
serverVersion = res.Versions.GetServerVersion()
|
||||
case *SyncLatestResponse_Record:
|
||||
records = append(records, res.Record)
|
||||
default:
|
||||
panic(fmt.Sprintf("unexpected response: %T", res))
|
||||
}
|
||||
}
|
||||
|
||||
return finalRes, nil
|
||||
return records, recordVersion, serverVersion, nil
|
||||
}
|
||||
|
|
File diff suppressed because it is too large
Load diff
|
@ -4,46 +4,27 @@ package databroker;
|
|||
option go_package = "github.com/pomerium/pomerium/pkg/grpc/databroker";
|
||||
|
||||
import "google/protobuf/any.proto";
|
||||
import "google/protobuf/empty.proto";
|
||||
import "google/protobuf/timestamp.proto";
|
||||
|
||||
message ServerVersion {
|
||||
string version = 1;
|
||||
}
|
||||
|
||||
message Record {
|
||||
string version = 1;
|
||||
uint64 version = 1;
|
||||
string type = 2;
|
||||
string id = 3;
|
||||
google.protobuf.Any data = 4;
|
||||
google.protobuf.Timestamp created_at = 5;
|
||||
google.protobuf.Timestamp modified_at = 6;
|
||||
google.protobuf.Timestamp deleted_at = 7;
|
||||
google.protobuf.Timestamp modified_at = 5;
|
||||
google.protobuf.Timestamp deleted_at = 6;
|
||||
}
|
||||
|
||||
message DeleteRequest {
|
||||
string type = 1;
|
||||
string id = 2;
|
||||
message Versions {
|
||||
// the server version indicates the version of the server storing the data
|
||||
uint64 server_version = 1;
|
||||
uint64 latest_record_version = 2;
|
||||
}
|
||||
|
||||
message GetRequest {
|
||||
string type = 1;
|
||||
string id = 2;
|
||||
}
|
||||
message GetResponse {
|
||||
Record record = 1;
|
||||
}
|
||||
|
||||
message GetAllRequest {
|
||||
string type = 1;
|
||||
string page_token = 2;
|
||||
}
|
||||
message GetAllResponse {
|
||||
repeated Record records = 1;
|
||||
string server_version = 2;
|
||||
string record_version = 3;
|
||||
string next_page_token = 4;
|
||||
}
|
||||
message GetResponse { Record record = 1; }
|
||||
|
||||
message QueryRequest {
|
||||
string type = 1;
|
||||
|
@ -56,39 +37,39 @@ message QueryResponse {
|
|||
int64 total_count = 2;
|
||||
}
|
||||
|
||||
message SetRequest {
|
||||
string type = 1;
|
||||
string id = 2;
|
||||
google.protobuf.Any data = 3;
|
||||
}
|
||||
message SetResponse {
|
||||
Record record = 1;
|
||||
string server_version = 2;
|
||||
message PutRequest { Record record = 1; }
|
||||
message PutResponse {
|
||||
uint64 server_version = 1;
|
||||
Record record = 2;
|
||||
}
|
||||
|
||||
message SyncRequest {
|
||||
string server_version = 1;
|
||||
string record_version = 2;
|
||||
string type = 3;
|
||||
bool no_wait = 4;
|
||||
uint64 server_version = 1;
|
||||
uint64 record_version = 2;
|
||||
}
|
||||
message SyncResponse {
|
||||
string server_version = 1;
|
||||
repeated Record records = 2;
|
||||
uint64 server_version = 1;
|
||||
Record record = 2;
|
||||
}
|
||||
|
||||
message GetTypesResponse {
|
||||
repeated string types = 1;
|
||||
message SyncLatestRequest { string type = 1; }
|
||||
message SyncLatestResponse {
|
||||
oneof response {
|
||||
Record record = 1;
|
||||
Versions versions = 2;
|
||||
}
|
||||
}
|
||||
|
||||
// The DataBrokerService stores key-value data.
|
||||
service DataBrokerService {
|
||||
rpc Delete(DeleteRequest) returns (google.protobuf.Empty);
|
||||
// Get gets a record.
|
||||
rpc Get(GetRequest) returns (GetResponse);
|
||||
rpc GetAll(GetAllRequest) returns (GetAllResponse);
|
||||
// Put saves a record.
|
||||
rpc Put(PutRequest) returns (PutResponse);
|
||||
// Query queries for records.
|
||||
rpc Query(QueryRequest) returns (QueryResponse);
|
||||
rpc Set(SetRequest) returns (SetResponse);
|
||||
// Sync streams changes to records after the specified version.
|
||||
rpc Sync(SyncRequest) returns (stream SyncResponse);
|
||||
|
||||
rpc GetTypes(google.protobuf.Empty) returns (GetTypesResponse);
|
||||
rpc SyncTypes(google.protobuf.Empty) returns (stream GetTypesResponse);
|
||||
// SyncLatest streams the latest version of every record.
|
||||
rpc SyncLatest(SyncLatestRequest) returns (stream SyncLatestResponse);
|
||||
}
|
||||
|
|
|
@ -61,18 +61,26 @@ func TestInitialSync(t *testing.T) {
|
|||
|
||||
r1 := new(Record)
|
||||
r2 := new(Record)
|
||||
r3 := new(Record)
|
||||
|
||||
m := &mockServer{
|
||||
sync: func(req *SyncRequest, stream DataBrokerService_SyncServer) error {
|
||||
assert.Equal(t, true, req.GetNoWait())
|
||||
stream.Send(&SyncResponse{
|
||||
ServerVersion: "a",
|
||||
Records: []*Record{r1, r2},
|
||||
syncLatest: func(req *SyncLatestRequest, stream DataBrokerService_SyncLatestServer) error {
|
||||
stream.Send(&SyncLatestResponse{
|
||||
Response: &SyncLatestResponse_Record{
|
||||
Record: r1,
|
||||
},
|
||||
})
|
||||
stream.Send(&SyncResponse{
|
||||
ServerVersion: "b",
|
||||
Records: []*Record{r3},
|
||||
stream.Send(&SyncLatestResponse{
|
||||
Response: &SyncLatestResponse_Record{
|
||||
Record: r2,
|
||||
},
|
||||
})
|
||||
stream.Send(&SyncLatestResponse{
|
||||
Response: &SyncLatestResponse_Versions{
|
||||
Versions: &Versions{
|
||||
LatestRecordVersion: 2,
|
||||
ServerVersion: 1,
|
||||
},
|
||||
},
|
||||
})
|
||||
return nil
|
||||
},
|
||||
|
@ -90,20 +98,19 @@ func TestInitialSync(t *testing.T) {
|
|||
|
||||
c := NewDataBrokerServiceClient(cc)
|
||||
|
||||
res, err := InitialSync(ctx, c, &SyncRequest{
|
||||
Type: "TEST",
|
||||
})
|
||||
records, recordVersion, serverVersion, err := InitialSync(ctx, c, new(SyncLatestRequest))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "b", res.GetServerVersion())
|
||||
assert.Equal(t, []*Record{r1, r2, r3}, res.GetRecords())
|
||||
assert.Equal(t, uint64(2), recordVersion)
|
||||
assert.Equal(t, uint64(1), serverVersion)
|
||||
assert.Equal(t, []*Record{r1, r2}, records)
|
||||
}
|
||||
|
||||
type mockServer struct {
|
||||
DataBrokerServiceServer
|
||||
|
||||
sync func(*SyncRequest, DataBrokerService_SyncServer) error
|
||||
syncLatest func(empty *SyncLatestRequest, server DataBrokerService_SyncLatestServer) error
|
||||
}
|
||||
|
||||
func (m *mockServer) Sync(req *SyncRequest, stream DataBrokerService_SyncServer) error {
|
||||
return m.sync(req, stream)
|
||||
func (m *mockServer) SyncLatest(req *SyncLatestRequest, stream DataBrokerService_SyncLatestServer) error {
|
||||
return m.syncLatest(req, stream)
|
||||
}
|
||||
|
|
171
pkg/grpc/databroker/syncer.go
Normal file
171
pkg/grpc/databroker/syncer.go
Normal file
|
@ -0,0 +1,171 @@
|
|||
package databroker
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
backoff "github.com/cenkalti/backoff/v4"
|
||||
"github.com/rs/zerolog"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/log"
|
||||
)
|
||||
|
||||
type syncerConfig struct {
|
||||
typeURL string
|
||||
}
|
||||
|
||||
// A SyncerOption customizes the syncer configuration.
|
||||
type SyncerOption func(cfg *syncerConfig)
|
||||
|
||||
func getSyncerConfig(options ...SyncerOption) *syncerConfig {
|
||||
cfg := new(syncerConfig)
|
||||
for _, option := range options {
|
||||
option(cfg)
|
||||
}
|
||||
return cfg
|
||||
}
|
||||
|
||||
// WithTypeURL restricts the sync'd results to the given type.
|
||||
func WithTypeURL(typeURL string) SyncerOption {
|
||||
return func(cfg *syncerConfig) {
|
||||
cfg.typeURL = typeURL
|
||||
}
|
||||
}
|
||||
|
||||
// A SyncerHandler receives sync events from the Syncer.
|
||||
type SyncerHandler interface {
|
||||
GetDataBrokerServiceClient() DataBrokerServiceClient
|
||||
ClearRecords(ctx context.Context)
|
||||
UpdateRecords(ctx context.Context, records []*Record)
|
||||
}
|
||||
|
||||
// A Syncer is a helper type for working with Sync and SyncLatest. It will make a call to
|
||||
// SyncLatest to retrieve the latest version of the data, then begin syncing with a call
|
||||
// to Sync. If the server version changes `ClearRecords` will be called and the process
|
||||
// will start over.
|
||||
type Syncer struct {
|
||||
cfg *syncerConfig
|
||||
handler SyncerHandler
|
||||
backoff *backoff.ExponentialBackOff
|
||||
|
||||
recordVersion uint64
|
||||
serverVersion uint64
|
||||
|
||||
closeCtx context.Context
|
||||
closeCtxCancel func()
|
||||
}
|
||||
|
||||
// NewSyncer creates a new Syncer.
|
||||
func NewSyncer(handler SyncerHandler, options ...SyncerOption) *Syncer {
|
||||
closeCtx, closeCtxCancel := context.WithCancel(context.Background())
|
||||
|
||||
bo := backoff.NewExponentialBackOff()
|
||||
bo.MaxElapsedTime = 0
|
||||
return &Syncer{
|
||||
cfg: getSyncerConfig(options...),
|
||||
handler: handler,
|
||||
backoff: bo,
|
||||
|
||||
closeCtx: closeCtx,
|
||||
closeCtxCancel: closeCtxCancel,
|
||||
}
|
||||
}
|
||||
|
||||
// Close closes the Syncer.
|
||||
func (syncer *Syncer) Close() error {
|
||||
syncer.closeCtxCancel()
|
||||
return nil
|
||||
}
|
||||
|
||||
// Run runs the Syncer.
|
||||
func (syncer *Syncer) Run(ctx context.Context) error {
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
go func() {
|
||||
<-syncer.closeCtx.Done()
|
||||
cancel()
|
||||
}()
|
||||
|
||||
for {
|
||||
var err error
|
||||
if syncer.serverVersion == 0 {
|
||||
err = syncer.init(ctx)
|
||||
} else {
|
||||
err = syncer.sync(ctx)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-time.After(syncer.backoff.NextBackOff()):
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (syncer *Syncer) init(ctx context.Context) error {
|
||||
syncer.log().Info().Msg("syncing latest records")
|
||||
records, recordVersion, serverVersion, err := InitialSync(ctx, syncer.handler.GetDataBrokerServiceClient(), &SyncLatestRequest{
|
||||
Type: syncer.cfg.typeURL,
|
||||
})
|
||||
if err != nil {
|
||||
syncer.log().Error().Err(err).Msg("error during initial sync")
|
||||
return err
|
||||
}
|
||||
syncer.backoff.Reset()
|
||||
|
||||
// reset the records as we have to sync latest
|
||||
syncer.handler.ClearRecords(ctx)
|
||||
|
||||
syncer.recordVersion = recordVersion
|
||||
syncer.serverVersion = serverVersion
|
||||
syncer.handler.UpdateRecords(ctx, records)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (syncer *Syncer) sync(ctx context.Context) error {
|
||||
stream, err := syncer.handler.GetDataBrokerServiceClient().Sync(ctx, &SyncRequest{
|
||||
ServerVersion: syncer.serverVersion,
|
||||
RecordVersion: syncer.recordVersion,
|
||||
})
|
||||
if err != nil {
|
||||
syncer.log().Error().Err(err).Msg("error during sync")
|
||||
return err
|
||||
}
|
||||
|
||||
for {
|
||||
res, err := stream.Recv()
|
||||
if status.Code(err) == codes.Aborted {
|
||||
syncer.log().Error().Err(err).Msg("aborted sync due to mismatched server version")
|
||||
// server version changed, so re-init
|
||||
syncer.serverVersion = 0
|
||||
return nil
|
||||
} else if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if syncer.recordVersion != res.GetRecord().GetVersion()-1 {
|
||||
syncer.log().Error().Err(err).
|
||||
Uint64("received", res.GetRecord().GetVersion()).
|
||||
Msg("aborted sync due to missing record")
|
||||
syncer.serverVersion = 0
|
||||
return fmt.Errorf("missing record version")
|
||||
}
|
||||
syncer.recordVersion = res.GetRecord().GetVersion()
|
||||
if syncer.cfg.typeURL == "" || syncer.cfg.typeURL == res.GetRecord().GetType() {
|
||||
syncer.handler.UpdateRecords(ctx, []*Record{res.GetRecord()})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (syncer *Syncer) log() *zerolog.Logger {
|
||||
l := log.With().Str("service", "syncer").
|
||||
Str("type", syncer.cfg.typeURL).
|
||||
Uint64("server_version", syncer.serverVersion).
|
||||
Uint64("record_version", syncer.recordVersion).Logger()
|
||||
return &l
|
||||
}
|
222
pkg/grpc/databroker/syncer_test.go
Normal file
222
pkg/grpc/databroker/syncer_test.go
Normal file
|
@ -0,0 +1,222 @@
|
|||
package databroker
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
"google.golang.org/grpc/test/bufconn"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/testutil"
|
||||
)
|
||||
|
||||
type testSyncerHandler struct {
|
||||
getDataBrokerServiceClient func() DataBrokerServiceClient
|
||||
clearRecords func(ctx context.Context)
|
||||
updateRecords func(ctx context.Context, records []*Record)
|
||||
}
|
||||
|
||||
func (t testSyncerHandler) GetDataBrokerServiceClient() DataBrokerServiceClient {
|
||||
return t.getDataBrokerServiceClient()
|
||||
}
|
||||
|
||||
func (t testSyncerHandler) ClearRecords(ctx context.Context) {
|
||||
t.clearRecords(ctx)
|
||||
}
|
||||
|
||||
func (t testSyncerHandler) UpdateRecords(ctx context.Context, records []*Record) {
|
||||
t.updateRecords(ctx, records)
|
||||
}
|
||||
|
||||
type testServer struct {
|
||||
DataBrokerServiceServer
|
||||
sync func(request *SyncRequest, server DataBrokerService_SyncServer) error
|
||||
syncLatest func(req *SyncLatestRequest, server DataBrokerService_SyncLatestServer) error
|
||||
}
|
||||
|
||||
func (t testServer) Sync(request *SyncRequest, server DataBrokerService_SyncServer) error {
|
||||
return t.sync(request, server)
|
||||
}
|
||||
|
||||
func (t testServer) SyncLatest(req *SyncLatestRequest, server DataBrokerService_SyncLatestServer) error {
|
||||
return t.syncLatest(req, server)
|
||||
}
|
||||
|
||||
func TestSyncer(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
ctx, clearTimeout := context.WithTimeout(ctx, time.Second*10)
|
||||
defer clearTimeout()
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
|
||||
lis := bufconn.Listen(1)
|
||||
r1 := &Record{Version: 1000, Id: "r1"}
|
||||
r2 := &Record{Version: 1001, Id: "r2"}
|
||||
r3 := &Record{Version: 1002, Id: "r3"}
|
||||
r5 := &Record{Version: 1004, Id: "r5"}
|
||||
|
||||
syncCount := 0
|
||||
syncLatestCount := 0
|
||||
|
||||
gs := grpc.NewServer()
|
||||
RegisterDataBrokerServiceServer(gs, testServer{
|
||||
sync: func(request *SyncRequest, server DataBrokerService_SyncServer) error {
|
||||
syncCount++
|
||||
switch syncCount {
|
||||
case 1:
|
||||
return status.Error(codes.Internal, "SOME INTERNAL ERROR")
|
||||
case 2:
|
||||
return status.Error(codes.Aborted, "ABORTED")
|
||||
case 3:
|
||||
_ = server.Send(&SyncResponse{
|
||||
ServerVersion: 2001,
|
||||
Record: r3,
|
||||
})
|
||||
_ = server.Send(&SyncResponse{
|
||||
ServerVersion: 2001,
|
||||
Record: r5,
|
||||
})
|
||||
case 4:
|
||||
select {} // block forever
|
||||
default:
|
||||
t.Fatal("unexpected call to sync", request)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
syncLatest: func(req *SyncLatestRequest, server DataBrokerService_SyncLatestServer) error {
|
||||
syncLatestCount++
|
||||
switch syncLatestCount {
|
||||
case 1:
|
||||
_ = server.Send(&SyncLatestResponse{
|
||||
Response: &SyncLatestResponse_Record{
|
||||
Record: r1,
|
||||
},
|
||||
})
|
||||
_ = server.Send(&SyncLatestResponse{
|
||||
Response: &SyncLatestResponse_Versions{
|
||||
Versions: &Versions{
|
||||
LatestRecordVersion: r1.Version,
|
||||
ServerVersion: 2000,
|
||||
},
|
||||
},
|
||||
})
|
||||
case 2:
|
||||
_ = server.Send(&SyncLatestResponse{
|
||||
Response: &SyncLatestResponse_Record{
|
||||
Record: r2,
|
||||
},
|
||||
})
|
||||
_ = server.Send(&SyncLatestResponse{
|
||||
Response: &SyncLatestResponse_Versions{
|
||||
Versions: &Versions{
|
||||
LatestRecordVersion: r2.Version,
|
||||
ServerVersion: 2001,
|
||||
},
|
||||
},
|
||||
})
|
||||
case 3:
|
||||
return status.Error(codes.Internal, "SOME INTERNAL ERROR")
|
||||
case 4:
|
||||
_ = server.Send(&SyncLatestResponse{
|
||||
Response: &SyncLatestResponse_Record{
|
||||
Record: r3,
|
||||
},
|
||||
})
|
||||
_ = server.Send(&SyncLatestResponse{
|
||||
Response: &SyncLatestResponse_Record{
|
||||
Record: r5,
|
||||
},
|
||||
})
|
||||
_ = server.Send(&SyncLatestResponse{
|
||||
Response: &SyncLatestResponse_Versions{
|
||||
Versions: &Versions{
|
||||
LatestRecordVersion: r5.Version,
|
||||
ServerVersion: 2001,
|
||||
},
|
||||
},
|
||||
})
|
||||
default:
|
||||
t.Fatal("unexpected call to sync latest")
|
||||
}
|
||||
return nil
|
||||
},
|
||||
})
|
||||
go func() { _ = gs.Serve(lis) }()
|
||||
|
||||
gc, err := grpc.DialContext(ctx, "bufnet",
|
||||
grpc.WithContextDialer(func(context.Context, string) (net.Conn, error) {
|
||||
return lis.Dial()
|
||||
}),
|
||||
grpc.WithInsecure())
|
||||
require.NoError(t, err)
|
||||
defer func() { _ = gc.Close() }()
|
||||
|
||||
clearCh := make(chan struct{})
|
||||
updateCh := make(chan []*Record)
|
||||
syncer := NewSyncer(testSyncerHandler{
|
||||
getDataBrokerServiceClient: func() DataBrokerServiceClient {
|
||||
return NewDataBrokerServiceClient(gc)
|
||||
},
|
||||
clearRecords: func(ctx context.Context) {
|
||||
clearCh <- struct{}{}
|
||||
},
|
||||
updateRecords: func(ctx context.Context, records []*Record) {
|
||||
updateCh <- records
|
||||
},
|
||||
})
|
||||
go func() { _ = syncer.Run(ctx) }()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
t.Fatal("1. expected call to clear records")
|
||||
case <-clearCh:
|
||||
}
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
t.Fatal("2. expected call to update records")
|
||||
case records := <-updateCh:
|
||||
testutil.AssertProtoJSONEqual(t, `[{"id": "r1", "version": "1000"}]`, records)
|
||||
}
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
t.Fatal("3. expected call to clear records due to server version change")
|
||||
case <-clearCh:
|
||||
}
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
t.Fatal("4. expected call to update records")
|
||||
case records := <-updateCh:
|
||||
testutil.AssertProtoJSONEqual(t, `[{"id": "r2", "version": "1001"}]`, records)
|
||||
}
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
t.Fatal("5. expected call to update records from sync")
|
||||
case records := <-updateCh:
|
||||
testutil.AssertProtoJSONEqual(t, `[{"id": "r3", "version": "1002"}]`, records)
|
||||
}
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
t.Fatal("6. expected call to clear records due to skipped version")
|
||||
case <-clearCh:
|
||||
}
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
t.Fatal("7. expected call to update records")
|
||||
case records := <-updateCh:
|
||||
testutil.AssertProtoJSONEqual(t, `[{"id": "r3", "version": "1002"}, {"id": "r5", "version": "1004"}]`, records)
|
||||
}
|
||||
|
||||
assert.NoError(t, syncer.Close())
|
||||
}
|
|
@ -17,9 +17,13 @@ import (
|
|||
// Delete deletes a session from the databroker.
|
||||
func Delete(ctx context.Context, client databroker.DataBrokerServiceClient, sessionID string) error {
|
||||
any, _ := ptypes.MarshalAny(new(Session))
|
||||
_, err := client.Delete(ctx, &databroker.DeleteRequest{
|
||||
Type: any.GetTypeUrl(),
|
||||
Id: sessionID,
|
||||
_, err := client.Put(ctx, &databroker.PutRequest{
|
||||
Record: &databroker.Record{
|
||||
Type: any.GetTypeUrl(),
|
||||
Id: sessionID,
|
||||
Data: any,
|
||||
DeletedAt: timestamppb.Now(),
|
||||
},
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
@ -44,13 +48,15 @@ func Get(ctx context.Context, client databroker.DataBrokerServiceClient, session
|
|||
return &s, nil
|
||||
}
|
||||
|
||||
// Set sets a session in the databroker.
|
||||
func Set(ctx context.Context, client databroker.DataBrokerServiceClient, s *Session) (*databroker.SetResponse, error) {
|
||||
// Put sets a session in the databroker.
|
||||
func Put(ctx context.Context, client databroker.DataBrokerServiceClient, s *Session) (*databroker.PutResponse, error) {
|
||||
any, _ := anypb.New(s)
|
||||
res, err := client.Set(ctx, &databroker.SetRequest{
|
||||
Type: any.GetTypeUrl(),
|
||||
Id: s.Id,
|
||||
Data: any,
|
||||
res, err := client.Put(ctx, &databroker.PutRequest{
|
||||
Record: &databroker.Record{
|
||||
Type: any.GetTypeUrl(),
|
||||
Id: s.Id,
|
||||
Data: any,
|
||||
},
|
||||
})
|
||||
return res, err
|
||||
}
|
||||
|
|
|
@ -32,13 +32,15 @@ func Get(ctx context.Context, client databroker.DataBrokerServiceClient, userID
|
|||
return &u, nil
|
||||
}
|
||||
|
||||
// Set sets a user in the databroker.
|
||||
func Set(ctx context.Context, client databroker.DataBrokerServiceClient, u *User) (*databroker.Record, error) {
|
||||
// Put sets a user in the databroker.
|
||||
func Put(ctx context.Context, client databroker.DataBrokerServiceClient, u *User) (*databroker.Record, error) {
|
||||
any, _ := anypb.New(u)
|
||||
res, err := client.Set(ctx, &databroker.SetRequest{
|
||||
Type: any.GetTypeUrl(),
|
||||
Id: u.Id,
|
||||
Data: any,
|
||||
res, err := client.Put(ctx, &databroker.PutRequest{
|
||||
Record: &databroker.Record{
|
||||
Type: any.GetTypeUrl(),
|
||||
Id: u.Id,
|
||||
Data: any,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -46,13 +48,15 @@ func Set(ctx context.Context, client databroker.DataBrokerServiceClient, u *User
|
|||
return res.GetRecord(), nil
|
||||
}
|
||||
|
||||
// SetServiceAccount sets a service account in the databroker.
|
||||
func SetServiceAccount(ctx context.Context, client databroker.DataBrokerServiceClient, sa *ServiceAccount) (*databroker.Record, error) {
|
||||
// PutServiceAccount sets a service account in the databroker.
|
||||
func PutServiceAccount(ctx context.Context, client databroker.DataBrokerServiceClient, sa *ServiceAccount) (*databroker.Record, error) {
|
||||
any, _ := anypb.New(sa)
|
||||
res, err := client.Set(ctx, &databroker.SetRequest{
|
||||
Type: any.GetTypeUrl(),
|
||||
Id: sa.GetId(),
|
||||
Data: any,
|
||||
res, err := client.Put(ctx, &databroker.PutRequest{
|
||||
Record: &databroker.Record{
|
||||
Type: any.GetTypeUrl(),
|
||||
Id: sa.GetId(),
|
||||
Data: any,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
|
|
@ -5,6 +5,8 @@ import (
|
|||
"context"
|
||||
|
||||
"google.golang.org/grpc/metadata"
|
||||
"google.golang.org/grpc/peer"
|
||||
"google.golang.org/protobuf/proto"
|
||||
)
|
||||
|
||||
// SessionIDMetadataKey is the key in the metadata.
|
||||
|
@ -52,3 +54,19 @@ func JWTFromGRPCRequest(ctx context.Context) (rawjwt string, ok bool) {
|
|||
|
||||
return rawjwts[0], true
|
||||
}
|
||||
|
||||
// GetPeerAddr returns the peer address.
|
||||
func GetPeerAddr(ctx context.Context) string {
|
||||
p, ok := peer.FromContext(ctx)
|
||||
if ok {
|
||||
return p.Addr.String()
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// GetTypeURL gets the TypeURL for a protobuf message.
|
||||
func GetTypeURL(msg proto.Message) string {
|
||||
// taken from the anypb package
|
||||
const urlPrefix = "type.googleapis.com/"
|
||||
return urlPrefix + string(msg.ProtoReflect().Descriptor().FullName())
|
||||
}
|
||||
|
|
|
@ -12,9 +12,42 @@ import (
|
|||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||
)
|
||||
|
||||
type encryptedRecordStream struct {
|
||||
underlying RecordStream
|
||||
backend *encryptedBackend
|
||||
err error
|
||||
}
|
||||
|
||||
func (e *encryptedRecordStream) Close() error {
|
||||
return e.underlying.Close()
|
||||
}
|
||||
|
||||
func (e *encryptedRecordStream) Next(wait bool) bool {
|
||||
return e.underlying.Next(wait)
|
||||
}
|
||||
|
||||
func (e *encryptedRecordStream) Record() *databroker.Record {
|
||||
r := e.underlying.Record()
|
||||
if r != nil {
|
||||
var err error
|
||||
r, err = e.backend.decryptRecord(r)
|
||||
if err != nil {
|
||||
e.err = err
|
||||
}
|
||||
}
|
||||
return r
|
||||
}
|
||||
|
||||
func (e *encryptedRecordStream) Err() error {
|
||||
if e.err == nil {
|
||||
e.err = e.underlying.Err()
|
||||
}
|
||||
return e.err
|
||||
}
|
||||
|
||||
type encryptedBackend struct {
|
||||
Backend
|
||||
cipher cipher.AEAD
|
||||
underlying Backend
|
||||
cipher cipher.AEAD
|
||||
}
|
||||
|
||||
// NewEncryptedBackend creates a new encrypted backend.
|
||||
|
@ -25,21 +58,17 @@ func NewEncryptedBackend(secret []byte, underlying Backend) (Backend, error) {
|
|||
}
|
||||
|
||||
return &encryptedBackend{
|
||||
Backend: underlying,
|
||||
cipher: c,
|
||||
underlying: underlying,
|
||||
cipher: c,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (e *encryptedBackend) Put(ctx context.Context, id string, data *anypb.Any) error {
|
||||
encrypted, err := e.encrypt(data)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return e.Backend.Put(ctx, id, encrypted)
|
||||
func (e *encryptedBackend) Close() error {
|
||||
return e.underlying.Close()
|
||||
}
|
||||
|
||||
func (e *encryptedBackend) Get(ctx context.Context, id string) (*databroker.Record, error) {
|
||||
record, err := e.Backend.Get(ctx, id)
|
||||
func (e *encryptedBackend) Get(ctx context.Context, recordType, id string) (*databroker.Record, error) {
|
||||
record, err := e.underlying.Get(ctx, recordType, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -50,18 +79,41 @@ func (e *encryptedBackend) Get(ctx context.Context, id string) (*databroker.Reco
|
|||
return record, nil
|
||||
}
|
||||
|
||||
func (e *encryptedBackend) List(ctx context.Context, sinceVersion string) ([]*databroker.Record, error) {
|
||||
records, err := e.Backend.List(ctx, sinceVersion)
|
||||
func (e *encryptedBackend) GetAll(ctx context.Context) ([]*databroker.Record, uint64, error) {
|
||||
records, version, err := e.underlying.GetAll(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, 0, err
|
||||
}
|
||||
for i := range records {
|
||||
records[i], err = e.decryptRecord(records[i])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, 0, err
|
||||
}
|
||||
}
|
||||
return records, nil
|
||||
return records, version, nil
|
||||
}
|
||||
|
||||
func (e *encryptedBackend) Put(ctx context.Context, record *databroker.Record) error {
|
||||
encrypted, err := e.encrypt(record.GetData())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
newRecord := proto.Clone(record).(*databroker.Record)
|
||||
newRecord.Data = encrypted
|
||||
|
||||
return e.underlying.Put(ctx, newRecord)
|
||||
}
|
||||
|
||||
func (e *encryptedBackend) Sync(ctx context.Context, version uint64) (RecordStream, error) {
|
||||
stream, err := e.underlying.Sync(ctx, version)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &encryptedRecordStream{
|
||||
underlying: stream,
|
||||
backend: e,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (e *encryptedBackend) decryptRecord(in *databroker.Record) (out *databroker.Record, err error) {
|
||||
|
@ -75,13 +127,16 @@ func (e *encryptedBackend) decryptRecord(in *databroker.Record) (out *databroker
|
|||
Type: data.TypeUrl,
|
||||
Id: in.Id,
|
||||
Data: data,
|
||||
CreatedAt: in.CreatedAt,
|
||||
ModifiedAt: in.ModifiedAt,
|
||||
DeletedAt: in.DeletedAt,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (e *encryptedBackend) decrypt(in *anypb.Any) (out *anypb.Any, err error) {
|
||||
if in == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
var encrypted wrapperspb.BytesValue
|
||||
err = in.UnmarshalTo(&encrypted)
|
||||
if err != nil {
|
||||
|
|
|
@ -18,11 +18,11 @@ func TestEncryptedBackend(t *testing.T) {
|
|||
|
||||
m := map[string]*anypb.Any{}
|
||||
backend := &mockBackend{
|
||||
put: func(ctx context.Context, id string, data *anypb.Any) error {
|
||||
m[id] = data
|
||||
put: func(ctx context.Context, record *databroker.Record) error {
|
||||
m[record.GetId()] = record.GetData()
|
||||
return nil
|
||||
},
|
||||
get: func(ctx context.Context, id string) (*databroker.Record, error) {
|
||||
get: func(ctx context.Context, recordType, id string) (*databroker.Record, error) {
|
||||
data, ok := m[id]
|
||||
if !ok {
|
||||
return nil, errors.New("not found")
|
||||
|
@ -32,7 +32,7 @@ func TestEncryptedBackend(t *testing.T) {
|
|||
Data: data,
|
||||
}, nil
|
||||
},
|
||||
getAll: func(ctx context.Context) ([]*databroker.Record, error) {
|
||||
getAll: func(ctx context.Context) ([]*databroker.Record, uint64, error) {
|
||||
var records []*databroker.Record
|
||||
for id, data := range m {
|
||||
records = append(records, &databroker.Record{
|
||||
|
@ -40,17 +40,7 @@ func TestEncryptedBackend(t *testing.T) {
|
|||
Data: data,
|
||||
})
|
||||
}
|
||||
return records, nil
|
||||
},
|
||||
list: func(ctx context.Context, sinceVersion string) ([]*databroker.Record, error) {
|
||||
var records []*databroker.Record
|
||||
for id, data := range m {
|
||||
records = append(records, &databroker.Record{
|
||||
Id: id,
|
||||
Data: data,
|
||||
})
|
||||
}
|
||||
return records, nil
|
||||
return records, 0, nil
|
||||
},
|
||||
}
|
||||
|
||||
|
@ -61,7 +51,11 @@ func TestEncryptedBackend(t *testing.T) {
|
|||
|
||||
any, _ := anypb.New(wrapperspb.String("HELLO WORLD"))
|
||||
|
||||
err = e.Put(ctx, "TEST-1", any)
|
||||
err = e.Put(ctx, &databroker.Record{
|
||||
Type: "",
|
||||
Id: "TEST-1",
|
||||
Data: any,
|
||||
})
|
||||
if !assert.NoError(t, err) {
|
||||
return
|
||||
}
|
||||
|
@ -70,7 +64,7 @@ func TestEncryptedBackend(t *testing.T) {
|
|||
assert.NotEqual(t, any.Value, m["TEST-1"].Value, "value should be encrypted")
|
||||
}
|
||||
|
||||
record, err := e.Get(ctx, "TEST-1")
|
||||
record, err := e.Get(ctx, "", "TEST-1")
|
||||
if !assert.NoError(t, err) {
|
||||
return
|
||||
}
|
||||
|
@ -78,17 +72,7 @@ func TestEncryptedBackend(t *testing.T) {
|
|||
assert.Equal(t, any.Value, record.Data.Value, "value should be preserved")
|
||||
assert.Equal(t, any.TypeUrl, record.Type, "record type should be preserved")
|
||||
|
||||
records, err := e.List(ctx, "")
|
||||
if !assert.NoError(t, err) {
|
||||
return
|
||||
}
|
||||
if assert.Len(t, records, 1) {
|
||||
assert.Equal(t, any.TypeUrl, records[0].Data.TypeUrl, "type should be preserved")
|
||||
assert.Equal(t, any.Value, records[0].Data.Value, "value should be preserved")
|
||||
assert.Equal(t, any.TypeUrl, records[0].Type, "record type should be preserved")
|
||||
}
|
||||
|
||||
records, err = e.List(ctx, "")
|
||||
records, _, err := e.GetAll(ctx)
|
||||
if !assert.NoError(t, err) {
|
||||
return
|
||||
}
|
||||
|
|
199
pkg/storage/inmemory/backend.go
Normal file
199
pkg/storage/inmemory/backend.go
Normal file
|
@ -0,0 +1,199 @@
|
|||
// Package inmemory contains an in-memory implementation of the databroker backend.
|
||||
package inmemory
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/google/btree"
|
||||
"google.golang.org/protobuf/proto"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/signal"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||
"github.com/pomerium/pomerium/pkg/storage"
|
||||
)
|
||||
|
||||
type recordKey struct {
|
||||
Type string
|
||||
ID string
|
||||
}
|
||||
|
||||
type recordChange struct {
|
||||
record *databroker.Record
|
||||
}
|
||||
|
||||
func (change recordChange) Less(item btree.Item) bool {
|
||||
that, ok := item.(recordChange)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
return change.record.GetVersion() < that.record.GetVersion()
|
||||
}
|
||||
|
||||
// A Backend stores data in-memory.
|
||||
type Backend struct {
|
||||
cfg *config
|
||||
onChange *signal.Signal
|
||||
|
||||
lastVersion uint64
|
||||
closeOnce sync.Once
|
||||
closed chan struct{}
|
||||
|
||||
mu sync.RWMutex
|
||||
lookup map[recordKey]*databroker.Record
|
||||
changes *btree.BTree
|
||||
}
|
||||
|
||||
// New creates a new in-memory backend storage.
|
||||
func New(options ...Option) *Backend {
|
||||
cfg := getConfig(options...)
|
||||
backend := &Backend{
|
||||
cfg: cfg,
|
||||
onChange: signal.New(),
|
||||
closed: make(chan struct{}),
|
||||
lookup: make(map[recordKey]*databroker.Record),
|
||||
changes: btree.New(cfg.degree),
|
||||
}
|
||||
if cfg.expiry != 0 {
|
||||
go func() {
|
||||
ticker := time.NewTicker(time.Second)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-backend.closed:
|
||||
return
|
||||
case <-ticker.C:
|
||||
}
|
||||
|
||||
backend.removeChangesBefore(time.Now().Add(-cfg.expiry))
|
||||
}
|
||||
}()
|
||||
}
|
||||
return backend
|
||||
}
|
||||
|
||||
func (backend *Backend) removeChangesBefore(cutoff time.Time) {
|
||||
backend.mu.Lock()
|
||||
defer backend.mu.Unlock()
|
||||
|
||||
for {
|
||||
item := backend.changes.Min()
|
||||
if item == nil {
|
||||
break
|
||||
}
|
||||
change, ok := item.(recordChange)
|
||||
if !ok {
|
||||
panic(fmt.Sprintf("invalid type in changes btree: %T", item))
|
||||
}
|
||||
if change.record.GetModifiedAt().AsTime().Before(cutoff) {
|
||||
_ = backend.changes.DeleteMin()
|
||||
continue
|
||||
}
|
||||
|
||||
// nothing left to remove
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// Close closes the in-memory store and erases any stored data.
|
||||
func (backend *Backend) Close() error {
|
||||
backend.closeOnce.Do(func() {
|
||||
close(backend.closed)
|
||||
|
||||
backend.mu.Lock()
|
||||
defer backend.mu.Unlock()
|
||||
|
||||
backend.lookup = map[recordKey]*databroker.Record{}
|
||||
backend.changes = btree.New(backend.cfg.degree)
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
// Get gets a record from the in-memory store.
|
||||
func (backend *Backend) Get(_ context.Context, recordType, id string) (*databroker.Record, error) {
|
||||
backend.mu.RLock()
|
||||
defer backend.mu.RUnlock()
|
||||
|
||||
key := recordKey{Type: recordType, ID: id}
|
||||
record, ok := backend.lookup[key]
|
||||
if !ok {
|
||||
return nil, storage.ErrNotFound
|
||||
}
|
||||
|
||||
return dup(record), nil
|
||||
}
|
||||
|
||||
// GetAll gets all the records from the in-memory store.
|
||||
func (backend *Backend) GetAll(_ context.Context) ([]*databroker.Record, uint64, error) {
|
||||
backend.mu.RLock()
|
||||
defer backend.mu.RUnlock()
|
||||
|
||||
var records []*databroker.Record
|
||||
for _, record := range backend.lookup {
|
||||
records = append(records, dup(record))
|
||||
}
|
||||
return records, backend.lastVersion, nil
|
||||
}
|
||||
|
||||
// Put puts a record into the in-memory store.
|
||||
func (backend *Backend) Put(_ context.Context, record *databroker.Record) error {
|
||||
if record == nil {
|
||||
return fmt.Errorf("records cannot be nil")
|
||||
}
|
||||
|
||||
backend.mu.Lock()
|
||||
defer backend.mu.Unlock()
|
||||
defer backend.onChange.Broadcast()
|
||||
|
||||
record.ModifiedAt = timestamppb.Now()
|
||||
record.Version = backend.nextVersion()
|
||||
backend.changes.ReplaceOrInsert(recordChange{record: dup(record)})
|
||||
|
||||
key := recordKey{Type: record.GetType(), ID: record.GetId()}
|
||||
if record.GetDeletedAt() != nil {
|
||||
delete(backend.lookup, key)
|
||||
} else {
|
||||
backend.lookup[key] = dup(record)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Sync returns a record stream for any changes after version.
|
||||
func (backend *Backend) Sync(ctx context.Context, version uint64) (storage.RecordStream, error) {
|
||||
return newRecordStream(ctx, backend, version), nil
|
||||
}
|
||||
|
||||
func (backend *Backend) getSince(version uint64) []*databroker.Record {
|
||||
backend.mu.RLock()
|
||||
defer backend.mu.RUnlock()
|
||||
|
||||
var records []*databroker.Record
|
||||
pivot := recordChange{record: &databroker.Record{Version: version}}
|
||||
backend.changes.AscendGreaterOrEqual(pivot, func(item btree.Item) bool {
|
||||
change, ok := item.(recordChange)
|
||||
if !ok {
|
||||
panic(fmt.Sprintf("invalid type in changes btree: %T", item))
|
||||
}
|
||||
record := change.record
|
||||
// skip the pivoting version as we only want records after it
|
||||
if record.GetVersion() != version {
|
||||
records = append(records, dup(record))
|
||||
}
|
||||
return true
|
||||
})
|
||||
return records
|
||||
}
|
||||
|
||||
func (backend *Backend) nextVersion() uint64 {
|
||||
return atomic.AddUint64(&backend.lastVersion, 1)
|
||||
}
|
||||
|
||||
func dup(record *databroker.Record) *databroker.Record {
|
||||
return proto.Clone(record).(*databroker.Record)
|
||||
}
|
152
pkg/storage/inmemory/backend_test.go
Normal file
152
pkg/storage/inmemory/backend_test.go
Normal file
|
@ -0,0 +1,152 @@
|
|||
package inmemory
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/sync/errgroup"
|
||||
"google.golang.org/protobuf/types/known/anypb"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
|
||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||
)
|
||||
|
||||
func TestBackend(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
backend := New()
|
||||
defer func() { _ = backend.Close() }()
|
||||
t.Run("get missing record", func(t *testing.T) {
|
||||
record, err := backend.Get(ctx, "TYPE", "abcd")
|
||||
require.Error(t, err)
|
||||
assert.Nil(t, record)
|
||||
})
|
||||
t.Run("get record", func(t *testing.T) {
|
||||
data := new(anypb.Any)
|
||||
assert.NoError(t, backend.Put(ctx, &databroker.Record{
|
||||
Type: "TYPE",
|
||||
Id: "abcd",
|
||||
Data: data,
|
||||
}))
|
||||
record, err := backend.Get(ctx, "TYPE", "abcd")
|
||||
require.NoError(t, err)
|
||||
if assert.NotNil(t, record) {
|
||||
assert.Equal(t, data, record.Data)
|
||||
assert.Nil(t, record.DeletedAt)
|
||||
assert.Equal(t, "abcd", record.Id)
|
||||
assert.NotNil(t, record.ModifiedAt)
|
||||
assert.Equal(t, "TYPE", record.Type)
|
||||
assert.Equal(t, uint64(1), record.Version)
|
||||
}
|
||||
})
|
||||
t.Run("delete record", func(t *testing.T) {
|
||||
assert.NoError(t, backend.Put(ctx, &databroker.Record{
|
||||
Type: "TYPE",
|
||||
Id: "abcd",
|
||||
DeletedAt: timestamppb.Now(),
|
||||
}))
|
||||
record, err := backend.Get(ctx, "TYPE", "abcd")
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, record)
|
||||
})
|
||||
t.Run("get all records", func(t *testing.T) {
|
||||
for i := 0; i < 1000; i++ {
|
||||
assert.NoError(t, backend.Put(ctx, &databroker.Record{
|
||||
Type: "TYPE",
|
||||
Id: fmt.Sprint(i),
|
||||
}))
|
||||
}
|
||||
records, version, err := backend.GetAll(ctx)
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, records, 1000)
|
||||
assert.Equal(t, uint64(1002), version)
|
||||
})
|
||||
}
|
||||
|
||||
func TestExpiry(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
backend := New(WithExpiry(0))
|
||||
defer func() { _ = backend.Close() }()
|
||||
|
||||
for i := 0; i < 1000; i++ {
|
||||
assert.NoError(t, backend.Put(ctx, &databroker.Record{
|
||||
Type: "TYPE",
|
||||
Id: fmt.Sprint(i),
|
||||
}))
|
||||
}
|
||||
stream, err := backend.Sync(ctx, 0)
|
||||
require.NoError(t, err)
|
||||
var records []*databroker.Record
|
||||
for stream.Next(false) {
|
||||
records = append(records, stream.Record())
|
||||
}
|
||||
_ = stream.Close()
|
||||
require.Len(t, records, 1000)
|
||||
|
||||
backend.removeChangesBefore(time.Now().Add(time.Second))
|
||||
|
||||
stream, err = backend.Sync(ctx, 0)
|
||||
require.NoError(t, err)
|
||||
records = nil
|
||||
for stream.Next(false) {
|
||||
records = append(records, stream.Record())
|
||||
}
|
||||
_ = stream.Close()
|
||||
require.Len(t, records, 0)
|
||||
}
|
||||
|
||||
func TestConcurrency(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
backend := New()
|
||||
|
||||
eg, ctx := errgroup.WithContext(ctx)
|
||||
eg.Go(func() error {
|
||||
for i := 0; i < 1000; i++ {
|
||||
_, _, _ = backend.GetAll(ctx)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
eg.Go(func() error {
|
||||
for i := 0; i < 1000; i++ {
|
||||
_ = backend.Put(ctx, &databroker.Record{
|
||||
Id: fmt.Sprint(i),
|
||||
})
|
||||
}
|
||||
return nil
|
||||
})
|
||||
assert.NoError(t, eg.Wait())
|
||||
}
|
||||
|
||||
func TestStream(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
backend := New()
|
||||
defer func() { _ = backend.Close() }()
|
||||
|
||||
stream, err := backend.Sync(ctx, 0)
|
||||
require.NoError(t, err)
|
||||
defer func() { _ = stream.Close() }()
|
||||
|
||||
eg, ctx := errgroup.WithContext(ctx)
|
||||
eg.Go(func() error {
|
||||
for i := 0; i < 10000; i++ {
|
||||
assert.True(t, stream.Next(true))
|
||||
assert.Equal(t, "TYPE", stream.Record().GetType())
|
||||
assert.Equal(t, fmt.Sprint(i), stream.Record().GetId())
|
||||
assert.Equal(t, uint64(i+1), stream.Record().GetVersion())
|
||||
}
|
||||
return nil
|
||||
})
|
||||
eg.Go(func() error {
|
||||
for i := 0; i < 10000; i++ {
|
||||
assert.NoError(t, backend.Put(ctx, &databroker.Record{
|
||||
Type: "TYPE",
|
||||
Id: fmt.Sprint(i),
|
||||
}))
|
||||
}
|
||||
return nil
|
||||
})
|
||||
require.NoError(t, eg.Wait())
|
||||
}
|
36
pkg/storage/inmemory/config.go
Normal file
36
pkg/storage/inmemory/config.go
Normal file
|
@ -0,0 +1,36 @@
|
|||
package inmemory
|
||||
|
||||
import "time"
|
||||
|
||||
type config struct {
|
||||
degree int
|
||||
expiry time.Duration
|
||||
}
|
||||
|
||||
// An Option customizes the in-memory backend.
|
||||
type Option func(cfg *config)
|
||||
|
||||
func getConfig(options ...Option) *config {
|
||||
cfg := &config{
|
||||
degree: 16,
|
||||
expiry: time.Hour,
|
||||
}
|
||||
for _, option := range options {
|
||||
option(cfg)
|
||||
}
|
||||
return cfg
|
||||
}
|
||||
|
||||
// WithBTreeDegree sets the btree degree of the changes btree.
|
||||
func WithBTreeDegree(degree int) Option {
|
||||
return func(cfg *config) {
|
||||
cfg.degree = degree
|
||||
}
|
||||
}
|
||||
|
||||
// WithExpiry sets the expiry for changes.
|
||||
func WithExpiry(expiry time.Duration) Option {
|
||||
return func(cfg *config) {
|
||||
cfg.expiry = expiry
|
||||
}
|
||||
}
|
|
@ -1,200 +0,0 @@
|
|||
// Package inmemory is the in-memory database using b-trees.
|
||||
package inmemory
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/golang/protobuf/ptypes"
|
||||
"github.com/google/btree"
|
||||
"google.golang.org/protobuf/proto"
|
||||
"google.golang.org/protobuf/types/known/anypb"
|
||||
|
||||
"github.com/pomerium/pomerium/config"
|
||||
"github.com/pomerium/pomerium/internal/signal"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||
"github.com/pomerium/pomerium/pkg/storage"
|
||||
)
|
||||
|
||||
// Name is the storage type name for inmemory backend.
|
||||
const Name = config.StorageInMemoryName
|
||||
|
||||
var _ storage.Backend = (*DB)(nil)
|
||||
|
||||
type byIDRecord struct {
|
||||
*databroker.Record
|
||||
}
|
||||
|
||||
func (k byIDRecord) Less(than btree.Item) bool {
|
||||
k2, _ := than.(byIDRecord)
|
||||
return k.GetId() < k2.GetId()
|
||||
}
|
||||
|
||||
type byVersionRecord struct {
|
||||
*databroker.Record
|
||||
}
|
||||
|
||||
func (k byVersionRecord) Less(than btree.Item) bool {
|
||||
k2, _ := than.(byVersionRecord)
|
||||
return k.GetVersion() < k2.GetVersion()
|
||||
}
|
||||
|
||||
// DB is an in-memory database of records using b-trees.
|
||||
type DB struct {
|
||||
recordType string
|
||||
|
||||
lastVersion uint64
|
||||
|
||||
mu sync.RWMutex
|
||||
byID *btree.BTree
|
||||
byVersion *btree.BTree
|
||||
deletedIDs []string
|
||||
onchange *signal.Signal
|
||||
|
||||
closeOnce sync.Once
|
||||
closed chan struct{}
|
||||
}
|
||||
|
||||
// NewDB creates a new in-memory database for the given record type.
|
||||
func NewDB(recordType string, btreeDegree int) *DB {
|
||||
s := signal.New()
|
||||
return &DB{
|
||||
recordType: recordType,
|
||||
byID: btree.New(btreeDegree),
|
||||
byVersion: btree.New(btreeDegree),
|
||||
onchange: s,
|
||||
closed: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
// ClearDeleted clears all the currently deleted records older than the given cutoff.
|
||||
func (db *DB) ClearDeleted(_ context.Context, cutoff time.Time) {
|
||||
db.mu.Lock()
|
||||
defer db.mu.Unlock()
|
||||
|
||||
var remaining []string
|
||||
for _, id := range db.deletedIDs {
|
||||
record, _ := db.byID.Get(byIDRecord{Record: &databroker.Record{Id: id}}).(byIDRecord)
|
||||
ts := record.GetDeletedAt().AsTime()
|
||||
if ts.Before(cutoff) {
|
||||
db.byID.Delete(record)
|
||||
db.byVersion.Delete(byVersionRecord(record))
|
||||
} else {
|
||||
remaining = append(remaining, id)
|
||||
}
|
||||
}
|
||||
db.deletedIDs = remaining
|
||||
}
|
||||
|
||||
// Close closes the database. Any watchers will be closed.
|
||||
func (db *DB) Close() error {
|
||||
db.closeOnce.Do(func() {
|
||||
close(db.closed)
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
// Delete marks a record as deleted.
|
||||
func (db *DB) Delete(_ context.Context, id string) error {
|
||||
db.mu.Lock()
|
||||
defer db.mu.Unlock()
|
||||
|
||||
defer db.onchange.Broadcast()
|
||||
db.replaceOrInsert(id, func(record *databroker.Record) {
|
||||
record.DeletedAt = ptypes.TimestampNow()
|
||||
db.deletedIDs = append(db.deletedIDs, id)
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
// Get gets a record from the db.
|
||||
func (db *DB) Get(_ context.Context, id string) (*databroker.Record, error) {
|
||||
db.mu.RLock()
|
||||
defer db.mu.RUnlock()
|
||||
|
||||
record, ok := db.byID.Get(byIDRecord{Record: &databroker.Record{Id: id}}).(byIDRecord)
|
||||
if !ok {
|
||||
return nil, errors.New("not found")
|
||||
}
|
||||
return record.Record, nil
|
||||
}
|
||||
|
||||
// GetAll gets all the records in the db.
|
||||
func (db *DB) GetAll(_ context.Context) ([]*databroker.Record, error) {
|
||||
db.mu.RLock()
|
||||
defer db.mu.RUnlock()
|
||||
|
||||
var records []*databroker.Record
|
||||
db.byID.Ascend(func(item btree.Item) bool {
|
||||
records = append(records, item.(byIDRecord).Record)
|
||||
return true
|
||||
})
|
||||
return records, nil
|
||||
}
|
||||
|
||||
// List lists all the changes since the given version.
|
||||
func (db *DB) List(_ context.Context, sinceVersion string) ([]*databroker.Record, error) {
|
||||
db.mu.RLock()
|
||||
defer db.mu.RUnlock()
|
||||
|
||||
var records []*databroker.Record
|
||||
db.byVersion.AscendGreaterOrEqual(byVersionRecord{Record: &databroker.Record{Version: sinceVersion}}, func(i btree.Item) bool {
|
||||
record := i.(byVersionRecord)
|
||||
if record.Version > sinceVersion {
|
||||
records = append(records, record.Record)
|
||||
}
|
||||
return true
|
||||
})
|
||||
return records, nil
|
||||
}
|
||||
|
||||
// Put replaces or inserts a record in the db.
|
||||
func (db *DB) Put(_ context.Context, id string, data *anypb.Any) error {
|
||||
db.mu.Lock()
|
||||
defer db.mu.Unlock()
|
||||
|
||||
defer db.onchange.Broadcast()
|
||||
db.replaceOrInsert(id, func(record *databroker.Record) {
|
||||
record.Data = data
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
// Watch returns the underlying signal.Signal binding channel to the caller.
|
||||
// Then the caller can listen to the channel for detecting changes.
|
||||
func (db *DB) Watch(ctx context.Context) <-chan struct{} {
|
||||
ch := db.onchange.Bind()
|
||||
go func() {
|
||||
select {
|
||||
case <-db.closed:
|
||||
case <-ctx.Done():
|
||||
}
|
||||
close(ch)
|
||||
db.onchange.Unbind(ch)
|
||||
}()
|
||||
return ch
|
||||
}
|
||||
|
||||
func (db *DB) replaceOrInsert(id string, f func(record *databroker.Record)) {
|
||||
record, ok := db.byID.Get(byIDRecord{Record: &databroker.Record{Id: id}}).(byIDRecord)
|
||||
if ok {
|
||||
db.byVersion.Delete(byVersionRecord(record))
|
||||
record.Record = proto.Clone(record.Record).(*databroker.Record)
|
||||
} else {
|
||||
record.Record = new(databroker.Record)
|
||||
}
|
||||
f(record.Record)
|
||||
if record.CreatedAt == nil {
|
||||
record.CreatedAt = ptypes.TimestampNow()
|
||||
}
|
||||
record.ModifiedAt = ptypes.TimestampNow()
|
||||
record.Type = db.recordType
|
||||
record.Id = id
|
||||
record.Version = fmt.Sprintf("%012X", atomic.AddUint64(&db.lastVersion, 1))
|
||||
db.byID.ReplaceOrInsert(record)
|
||||
db.byVersion.ReplaceOrInsert(byVersionRecord(record))
|
||||
}
|
|
@ -1,108 +0,0 @@
|
|||
package inmemory
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/sync/errgroup"
|
||||
"google.golang.org/protobuf/types/known/anypb"
|
||||
)
|
||||
|
||||
func TestDB(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
db := NewDB("example", 2)
|
||||
t.Run("get missing record", func(t *testing.T) {
|
||||
record, err := db.Get(ctx, "abcd")
|
||||
require.Error(t, err)
|
||||
assert.Nil(t, record)
|
||||
})
|
||||
t.Run("get record", func(t *testing.T) {
|
||||
data := new(anypb.Any)
|
||||
assert.NoError(t, db.Put(ctx, "abcd", data))
|
||||
record, err := db.Get(ctx, "abcd")
|
||||
require.NoError(t, err)
|
||||
if assert.NotNil(t, record) {
|
||||
assert.NotNil(t, record.CreatedAt)
|
||||
assert.Equal(t, data, record.Data)
|
||||
assert.Nil(t, record.DeletedAt)
|
||||
assert.Equal(t, "abcd", record.Id)
|
||||
assert.NotNil(t, record.ModifiedAt)
|
||||
assert.Equal(t, "example", record.Type)
|
||||
assert.Equal(t, "000000000001", record.Version)
|
||||
}
|
||||
})
|
||||
t.Run("delete record", func(t *testing.T) {
|
||||
assert.NoError(t, db.Delete(ctx, "abcd"))
|
||||
record, err := db.Get(ctx, "abcd")
|
||||
require.NoError(t, err)
|
||||
if assert.NotNil(t, record) {
|
||||
assert.NotNil(t, record.DeletedAt)
|
||||
}
|
||||
})
|
||||
t.Run("clear deleted", func(t *testing.T) {
|
||||
db.ClearDeleted(ctx, time.Now().Add(time.Second))
|
||||
record, err := db.Get(ctx, "abcd")
|
||||
require.Error(t, err)
|
||||
assert.Nil(t, record)
|
||||
})
|
||||
t.Run("keep remaining", func(t *testing.T) {
|
||||
data := new(anypb.Any)
|
||||
assert.NoError(t, db.Put(ctx, "abcd", data))
|
||||
assert.NoError(t, db.Delete(ctx, "abcd"))
|
||||
db.ClearDeleted(ctx, time.Now().Add(-10*time.Second))
|
||||
record, err := db.Get(ctx, "abcd")
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, record)
|
||||
db.ClearDeleted(ctx, time.Now().Add(time.Second))
|
||||
})
|
||||
t.Run("list", func(t *testing.T) {
|
||||
for i := 0; i < 10; i++ {
|
||||
data := new(anypb.Any)
|
||||
assert.NoError(t, db.Put(ctx, fmt.Sprintf("%02d", i), data))
|
||||
}
|
||||
|
||||
records, err := db.List(ctx, "")
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, records, 10)
|
||||
records, err = db.List(ctx, "00000000000A")
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, records, 4)
|
||||
records, err = db.List(ctx, "00000000000F")
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, records, 0)
|
||||
})
|
||||
t.Run("delete twice", func(t *testing.T) {
|
||||
data := new(anypb.Any)
|
||||
assert.NoError(t, db.Put(ctx, "abcd", data))
|
||||
assert.NoError(t, db.Delete(ctx, "abcd"))
|
||||
assert.NoError(t, db.Delete(ctx, "abcd"))
|
||||
db.ClearDeleted(ctx, time.Now().Add(time.Second))
|
||||
record, err := db.Get(ctx, "abcd")
|
||||
require.Error(t, err)
|
||||
assert.Nil(t, record)
|
||||
})
|
||||
}
|
||||
|
||||
func TestConcurrency(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
db := NewDB("example", 2)
|
||||
|
||||
eg, ctx := errgroup.WithContext(ctx)
|
||||
eg.Go(func() error {
|
||||
for i := 0; i < 1000; i++ {
|
||||
_, _ = db.List(ctx, "")
|
||||
}
|
||||
return nil
|
||||
})
|
||||
eg.Go(func() error {
|
||||
for i := 0; i < 1000; i++ {
|
||||
db.Put(ctx, fmt.Sprint(i), new(anypb.Any))
|
||||
}
|
||||
return nil
|
||||
})
|
||||
eg.Wait()
|
||||
}
|
101
pkg/storage/inmemory/stream.go
Normal file
101
pkg/storage/inmemory/stream.go
Normal file
|
@ -0,0 +1,101 @@
|
|||
package inmemory
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
|
||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||
"github.com/pomerium/pomerium/pkg/storage"
|
||||
)
|
||||
|
||||
type recordStream struct {
|
||||
ctx context.Context
|
||||
backend *Backend
|
||||
|
||||
changed chan struct{}
|
||||
ready []*databroker.Record
|
||||
version uint64
|
||||
|
||||
closeOnce sync.Once
|
||||
closed chan struct{}
|
||||
}
|
||||
|
||||
func newRecordStream(ctx context.Context, backend *Backend, version uint64) *recordStream {
|
||||
stream := &recordStream{
|
||||
ctx: ctx,
|
||||
backend: backend,
|
||||
|
||||
changed: backend.onChange.Bind(),
|
||||
version: version,
|
||||
|
||||
closed: make(chan struct{}),
|
||||
}
|
||||
return stream
|
||||
}
|
||||
|
||||
func (stream *recordStream) fill() {
|
||||
stream.ready = stream.backend.getSince(stream.version)
|
||||
if len(stream.ready) > 0 {
|
||||
// records are sorted by version,
|
||||
// so update the local version to the last record
|
||||
stream.version = stream.ready[len(stream.ready)-1].GetVersion()
|
||||
}
|
||||
}
|
||||
|
||||
func (stream *recordStream) Close() error {
|
||||
stream.closeOnce.Do(func() {
|
||||
stream.backend.onChange.Unbind(stream.changed)
|
||||
close(stream.closed)
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
func (stream *recordStream) Next(wait bool) bool {
|
||||
if len(stream.ready) > 0 {
|
||||
stream.ready = stream.ready[1:]
|
||||
}
|
||||
if len(stream.ready) > 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
for {
|
||||
stream.fill()
|
||||
if len(stream.ready) > 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
if wait {
|
||||
select {
|
||||
case <-stream.ctx.Done():
|
||||
return false
|
||||
case <-stream.closed:
|
||||
return false
|
||||
case <-stream.changed:
|
||||
// query for records again
|
||||
}
|
||||
} else {
|
||||
return false
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (stream *recordStream) Record() *databroker.Record {
|
||||
var r *databroker.Record
|
||||
if len(stream.ready) > 0 {
|
||||
r = stream.ready[0]
|
||||
}
|
||||
return r
|
||||
}
|
||||
|
||||
func (stream *recordStream) Err() error {
|
||||
select {
|
||||
case <-stream.ctx.Done():
|
||||
return stream.ctx.Err()
|
||||
case <-stream.closed:
|
||||
return storage.ErrStreamClosed
|
||||
case <-stream.backend.closed:
|
||||
return storage.ErrStreamClosed
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
31
pkg/storage/redis/observe.go
Normal file
31
pkg/storage/redis/observe.go
Normal file
|
@ -0,0 +1,31 @@
|
|||
package redis
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/go-redis/redis/v8"
|
||||
|
||||
pomeriumconfig "github.com/pomerium/pomerium/config"
|
||||
"github.com/pomerium/pomerium/internal/log"
|
||||
"github.com/pomerium/pomerium/internal/telemetry/metrics"
|
||||
)
|
||||
|
||||
type logger struct {
|
||||
}
|
||||
|
||||
func (l logger) Printf(ctx context.Context, format string, v ...interface{}) {
|
||||
log.Info().Str("service", "redis").Msgf(format, v...)
|
||||
}
|
||||
|
||||
func init() {
|
||||
redis.SetLogger(logger{})
|
||||
}
|
||||
|
||||
func recordOperation(ctx context.Context, startTime time.Time, operation string, err error) {
|
||||
metrics.RecordStorageOperation(ctx, &metrics.StorageOperationTags{
|
||||
Operation: operation,
|
||||
Error: err,
|
||||
Backend: pomeriumconfig.StorageRedisName,
|
||||
}, time.Since(startTime))
|
||||
}
|
|
@ -2,32 +2,34 @@ package redis
|
|||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"time"
|
||||
)
|
||||
|
||||
type dbConfig struct {
|
||||
tls *tls.Config
|
||||
recordType string
|
||||
type config struct {
|
||||
tls *tls.Config
|
||||
expiry time.Duration
|
||||
}
|
||||
|
||||
// Option customizes a DB.
|
||||
type Option func(*dbConfig)
|
||||
// Option customizes a Backend.
|
||||
type Option func(*config)
|
||||
|
||||
// WithRecordType sets the record type in the config.
|
||||
func WithRecordType(recordType string) Option {
|
||||
return func(cfg *dbConfig) {
|
||||
cfg.recordType = recordType
|
||||
}
|
||||
}
|
||||
|
||||
// WithTLSConfig sets the tls.Config which DB uses.
|
||||
// WithTLSConfig sets the tls.Config which Backend uses.
|
||||
func WithTLSConfig(tlsConfig *tls.Config) Option {
|
||||
return func(cfg *dbConfig) {
|
||||
return func(cfg *config) {
|
||||
cfg.tls = tlsConfig
|
||||
}
|
||||
}
|
||||
|
||||
func getConfig(options ...Option) *dbConfig {
|
||||
cfg := new(dbConfig)
|
||||
// WithExpiry sets the expiry for changes.
|
||||
func WithExpiry(expiry time.Duration) Option {
|
||||
return func(cfg *config) {
|
||||
cfg.expiry = expiry
|
||||
}
|
||||
}
|
||||
|
||||
func getConfig(options ...Option) *config {
|
||||
cfg := new(config)
|
||||
WithExpiry(time.Hour * 24)(cfg)
|
||||
for _, o := range options {
|
||||
o(cfg)
|
||||
}
|
||||
|
|
|
@ -5,29 +5,30 @@ import (
|
|||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/cenkalti/backoff/v4"
|
||||
redis "github.com/go-redis/redis/v8"
|
||||
"github.com/golang/protobuf/proto"
|
||||
"google.golang.org/protobuf/types/known/anypb"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
|
||||
"github.com/pomerium/pomerium/config"
|
||||
"github.com/pomerium/pomerium/internal/log"
|
||||
"github.com/pomerium/pomerium/internal/signal"
|
||||
"github.com/pomerium/pomerium/internal/telemetry/metrics"
|
||||
"github.com/pomerium/pomerium/internal/telemetry/trace"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||
"github.com/pomerium/pomerium/pkg/storage"
|
||||
)
|
||||
|
||||
// Name of the storage backend.
|
||||
const Name = config.StorageRedisName
|
||||
|
||||
const (
|
||||
maxTransactionRetries = 100
|
||||
watchPollInterval = 30 * time.Second
|
||||
|
||||
lastVersionKey = "pomerium.last_version"
|
||||
lastVersionChKey = "pomerium.last_version_ch"
|
||||
recordHashKey = "pomerium.records"
|
||||
changesSetKey = "pomerium.changes"
|
||||
)
|
||||
|
||||
// custom errors
|
||||
|
@ -35,21 +36,24 @@ var (
|
|||
ErrExceededMaxRetries = errors.New("redis: transaction reached maximum number of retries")
|
||||
)
|
||||
|
||||
// DB implements the storage.Backend on top of redis.
|
||||
type DB struct {
|
||||
cfg *dbConfig
|
||||
// Backend implements the storage.Backend on top of redis.
|
||||
type Backend struct {
|
||||
cfg *config
|
||||
|
||||
client *redis.Client
|
||||
client *redis.Client
|
||||
onChange *signal.Signal
|
||||
|
||||
closeOnce sync.Once
|
||||
closed chan struct{}
|
||||
}
|
||||
|
||||
// New creates a new redis storage backend.
|
||||
func New(rawURL string, options ...Option) (*DB, error) {
|
||||
db := &DB{
|
||||
cfg: getConfig(options...),
|
||||
closed: make(chan struct{}),
|
||||
func New(rawURL string, options ...Option) (*Backend, error) {
|
||||
cfg := getConfig(options...)
|
||||
backend := &Backend{
|
||||
cfg: cfg,
|
||||
closed: make(chan struct{}),
|
||||
onChange: signal.New(),
|
||||
}
|
||||
opts, err := redis.ParseURL(rawURL)
|
||||
if err != nil {
|
||||
|
@ -57,194 +61,150 @@ func New(rawURL string, options ...Option) (*DB, error) {
|
|||
}
|
||||
// when using TLS, the TLS config will not be set to nil, in which case we replace it with our own
|
||||
if opts.TLSConfig != nil {
|
||||
opts.TLSConfig = db.cfg.tls
|
||||
opts.TLSConfig = backend.cfg.tls
|
||||
}
|
||||
db.client = redis.NewClient(opts)
|
||||
metrics.AddRedisMetrics(db.client.PoolStats)
|
||||
return db, nil
|
||||
}
|
||||
backend.client = redis.NewClient(opts)
|
||||
metrics.AddRedisMetrics(backend.client.PoolStats)
|
||||
go backend.listenForVersionChanges()
|
||||
if cfg.expiry != 0 {
|
||||
go func() {
|
||||
ticker := time.NewTicker(time.Minute)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-backend.closed:
|
||||
return
|
||||
case <-ticker.C:
|
||||
}
|
||||
|
||||
// ClearDeleted clears all the deleted records older than the cutoff time.
|
||||
func (db *DB) ClearDeleted(ctx context.Context, cutoff time.Time) {
|
||||
var err error
|
||||
|
||||
_, span := trace.StartSpan(ctx, "databroker.redis.ClearDeleted")
|
||||
defer span.End()
|
||||
defer func(start time.Time) { recordOperation(ctx, start, "clear_deleted", err) }(time.Now())
|
||||
|
||||
ids, _ := db.client.SMembers(ctx, formatDeletedSetKey(db.cfg.recordType)).Result()
|
||||
records, _ := redisGetRecords(ctx, db.client, db.cfg.recordType, ids)
|
||||
_, err = db.client.Pipelined(ctx, func(p redis.Pipeliner) error {
|
||||
for _, record := range records {
|
||||
if record.GetDeletedAt().AsTime().Before(cutoff) {
|
||||
p.HDel(ctx, formatRecordsKey(db.cfg.recordType), record.GetId())
|
||||
p.ZRem(ctx, formatVersionSetKey(db.cfg.recordType), record.GetId())
|
||||
p.SRem(ctx, formatDeletedSetKey(db.cfg.recordType), record.GetId())
|
||||
backend.removeChangesBefore(time.Now().Add(-cfg.expiry))
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}()
|
||||
}
|
||||
return backend, nil
|
||||
}
|
||||
|
||||
// Close closes the underlying redis connection and any watchers.
|
||||
func (db *DB) Close() error {
|
||||
func (backend *Backend) Close() error {
|
||||
var err error
|
||||
db.closeOnce.Do(func() {
|
||||
err = db.client.Close()
|
||||
close(db.closed)
|
||||
backend.closeOnce.Do(func() {
|
||||
err = backend.client.Close()
|
||||
close(backend.closed)
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
// Delete marks a record as deleted.
|
||||
func (db *DB) Delete(ctx context.Context, id string) (err error) {
|
||||
_, span := trace.StartSpan(ctx, "databroker.redis.Delete")
|
||||
defer span.End()
|
||||
defer func(start time.Time) { recordOperation(ctx, start, "delete", err) }(time.Now())
|
||||
|
||||
var record *databroker.Record
|
||||
err = db.incrementVersion(ctx,
|
||||
func(tx *redis.Tx, version int64) error {
|
||||
var err error
|
||||
record, err = redisGetRecord(ctx, tx, db.cfg.recordType, id)
|
||||
if errors.Is(err, redis.Nil) {
|
||||
// nothing to do, as the record doesn't exist
|
||||
return nil
|
||||
} else if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// mark it as deleted
|
||||
record.DeletedAt = timestamppb.Now()
|
||||
record.Version = formatVersion(version)
|
||||
|
||||
return nil
|
||||
},
|
||||
func(p redis.Pipeliner, version int64) error {
|
||||
err := redisSetRecord(ctx, p, db.cfg.recordType, record)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// add it to the collection of deleted entries
|
||||
p.SAdd(ctx, formatDeletedSetKey(db.cfg.recordType), record.GetId())
|
||||
return nil
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
// Get gets a record.
|
||||
func (db *DB) Get(ctx context.Context, id string) (record *databroker.Record, err error) {
|
||||
// Get gets a record from redis.
|
||||
func (backend *Backend) Get(ctx context.Context, recordType, id string) (_ *databroker.Record, err error) {
|
||||
_, span := trace.StartSpan(ctx, "databroker.redis.Get")
|
||||
defer span.End()
|
||||
defer func(start time.Time) { recordOperation(ctx, start, "get", err) }(time.Now())
|
||||
|
||||
record, err = redisGetRecord(ctx, db.client, db.cfg.recordType, id)
|
||||
return record, err
|
||||
}
|
||||
key, field := getHashKey(recordType, id)
|
||||
cmd := backend.client.HGet(ctx, key, field)
|
||||
raw, err := cmd.Result()
|
||||
if err == redis.Nil {
|
||||
return nil, storage.ErrNotFound
|
||||
} else if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// List lists all the records changed since the sinceVersion. Records are sorted in version order.
|
||||
func (db *DB) List(ctx context.Context, sinceVersion string) (records []*databroker.Record, err error) {
|
||||
_, span := trace.StartSpan(ctx, "databroker.redis.List")
|
||||
defer span.End()
|
||||
defer func(start time.Time) { recordOperation(ctx, start, "list", err) }(time.Now())
|
||||
|
||||
var ids []string
|
||||
ids, err = redisListIDsSince(ctx, db.client, db.cfg.recordType, sinceVersion)
|
||||
var record databroker.Record
|
||||
err = proto.Unmarshal([]byte(raw), &record)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
records, err = redisGetRecords(ctx, db.client, db.cfg.recordType, ids)
|
||||
return records, err
|
||||
|
||||
return &record, nil
|
||||
}
|
||||
|
||||
// Put updates a record.
|
||||
func (db *DB) Put(ctx context.Context, id string, data *anypb.Any) (err error) {
|
||||
// GetAll gets all the records from redis.
|
||||
func (backend *Backend) GetAll(ctx context.Context) (records []*databroker.Record, latestRecordVersion uint64, err error) {
|
||||
_, span := trace.StartSpan(ctx, "databroker.redis.GetAll")
|
||||
defer span.End()
|
||||
defer func(start time.Time) { recordOperation(ctx, start, "getall", err) }(time.Now())
|
||||
|
||||
p := backend.client.Pipeline()
|
||||
lastVersionCmd := p.Get(ctx, lastVersionKey)
|
||||
resultsCmd := p.HVals(ctx, recordHashKey)
|
||||
_, err = p.Exec(ctx)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
latestRecordVersion, err = lastVersionCmd.Uint64()
|
||||
if errors.Is(err, redis.Nil) {
|
||||
latestRecordVersion = 0
|
||||
} else if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
results, err := resultsCmd.Result()
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
for _, result := range results {
|
||||
var record databroker.Record
|
||||
err := proto.Unmarshal([]byte(result), &record)
|
||||
if err != nil {
|
||||
log.Warn().Err(err).Msg("redis: invalid record detected")
|
||||
continue
|
||||
}
|
||||
records = append(records, &record)
|
||||
}
|
||||
return records, latestRecordVersion, nil
|
||||
}
|
||||
|
||||
// Put puts a record into redis.
|
||||
func (backend *Backend) Put(ctx context.Context, record *databroker.Record) (err error) {
|
||||
_, span := trace.StartSpan(ctx, "databroker.redis.Put")
|
||||
defer span.End()
|
||||
defer func(start time.Time) { recordOperation(ctx, start, "put", err) }(time.Now())
|
||||
|
||||
var record *databroker.Record
|
||||
err = db.incrementVersion(ctx,
|
||||
func(tx *redis.Tx, version int64) error {
|
||||
var err error
|
||||
record, err = redisGetRecord(ctx, db.client, db.cfg.recordType, id)
|
||||
if errors.Is(err, redis.Nil) {
|
||||
record = new(databroker.Record)
|
||||
record.CreatedAt = timestamppb.Now()
|
||||
} else if err != nil {
|
||||
return backend.incrementVersion(ctx,
|
||||
func(tx *redis.Tx, version uint64) error {
|
||||
record.ModifiedAt = timestamppb.Now()
|
||||
record.Version = version
|
||||
return nil
|
||||
},
|
||||
func(p redis.Pipeliner, version uint64) error {
|
||||
bs, err := proto.Marshal(record)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
record.ModifiedAt = timestamppb.Now()
|
||||
record.Type = db.cfg.recordType
|
||||
record.Id = id
|
||||
record.Data = data
|
||||
record.Version = formatVersion(version)
|
||||
|
||||
key, field := getHashKey(record.GetType(), record.GetId())
|
||||
if record.DeletedAt != nil {
|
||||
p.HDel(ctx, key, field)
|
||||
} else {
|
||||
p.HSet(ctx, key, field, bs)
|
||||
}
|
||||
p.ZAdd(ctx, changesSetKey, &redis.Z{
|
||||
Score: float64(version),
|
||||
Member: bs,
|
||||
})
|
||||
return nil
|
||||
},
|
||||
func(p redis.Pipeliner, version int64) error {
|
||||
return redisSetRecord(ctx, p, db.cfg.recordType, record)
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
// Watch returns a channel that is signaled any time the last version is incremented (ie on Put/Delete).
|
||||
func (db *DB) Watch(ctx context.Context) <-chan struct{} {
|
||||
s := signal.New()
|
||||
ch := s.Bind()
|
||||
go func() {
|
||||
defer s.Unbind(ch)
|
||||
defer close(ch)
|
||||
|
||||
// force a check
|
||||
poll := time.NewTicker(watchPollInterval)
|
||||
defer poll.Stop()
|
||||
|
||||
// use pub/sub for quicker notify
|
||||
pubsub := db.client.Subscribe(ctx, formatLastVersionChannelKey(db.cfg.recordType))
|
||||
defer func() { _ = pubsub.Close() }()
|
||||
pubsubCh := pubsub.Channel()
|
||||
|
||||
var lastVersion int64
|
||||
|
||||
for {
|
||||
v, err := redisGetLastVersion(ctx, db.client, db.cfg.recordType)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("redis: error retrieving last version")
|
||||
} else if v != lastVersion {
|
||||
// don't broadcast the first time
|
||||
if lastVersion != 0 {
|
||||
s.Broadcast()
|
||||
}
|
||||
lastVersion = v
|
||||
}
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-db.closed:
|
||||
return
|
||||
case <-poll.C:
|
||||
case <-pubsubCh:
|
||||
// re-check
|
||||
}
|
||||
}
|
||||
}()
|
||||
return ch
|
||||
// Sync returns a record stream of any records changed after the specified version.
|
||||
func (backend *Backend) Sync(ctx context.Context, version uint64) (storage.RecordStream, error) {
|
||||
return newRecordStream(ctx, backend, version), nil
|
||||
}
|
||||
|
||||
// incrementVersion increments the last version key, runs the code in `query`, then attempts to commit the code in
|
||||
// `commit`. If the last version changes in the interim, we will retry the transaction.
|
||||
func (db *DB) incrementVersion(ctx context.Context,
|
||||
query func(tx *redis.Tx, version int64) error,
|
||||
commit func(p redis.Pipeliner, version int64) error,
|
||||
func (backend *Backend) incrementVersion(ctx context.Context,
|
||||
query func(tx *redis.Tx, version uint64) error,
|
||||
commit func(p redis.Pipeliner, version uint64) error,
|
||||
) error {
|
||||
// code is modeled on https://pkg.go.dev/github.com/go-redis/redis/v8#example-Client.Watch
|
||||
txf := func(tx *redis.Tx) error {
|
||||
version, err := redisGetLastVersion(ctx, tx, db.cfg.recordType)
|
||||
if err != nil {
|
||||
version, err := tx.Get(ctx, lastVersionKey).Uint64()
|
||||
if errors.Is(err, redis.Nil) {
|
||||
version = 0
|
||||
} else if err != nil {
|
||||
return err
|
||||
}
|
||||
version++
|
||||
|
@ -260,16 +220,23 @@ func (db *DB) incrementVersion(ctx context.Context,
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
p.Set(ctx, formatLastVersionKey(db.cfg.recordType), version, 0)
|
||||
p.Publish(ctx, formatLastVersionChannelKey(db.cfg.recordType), version)
|
||||
p.Set(ctx, lastVersionKey, version, 0)
|
||||
p.Publish(ctx, lastVersionChKey, version)
|
||||
return nil
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
bo := backoff.NewExponentialBackOff()
|
||||
bo.MaxElapsedTime = 0
|
||||
for i := 0; i < maxTransactionRetries; i++ {
|
||||
err := db.client.Watch(ctx, txf, formatLastVersionKey(db.cfg.recordType))
|
||||
err := backend.client.Watch(ctx, txf, lastVersionKey)
|
||||
if errors.Is(err, redis.TxFailedErr) {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-time.After(bo.NextBackOff()):
|
||||
}
|
||||
continue // retry
|
||||
} else if err != nil {
|
||||
return err
|
||||
|
@ -281,121 +248,81 @@ func (db *DB) incrementVersion(ctx context.Context,
|
|||
return ErrExceededMaxRetries
|
||||
}
|
||||
|
||||
func redisGetLastVersion(ctx context.Context, c redis.Cmdable, recordType string) (int64, error) {
|
||||
version, err := c.Get(ctx, formatLastVersionKey(recordType)).Int64()
|
||||
if errors.Is(err, redis.Nil) {
|
||||
version = 0
|
||||
} else if err != nil {
|
||||
return 0, err
|
||||
func (backend *Backend) listenForVersionChanges() {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
go func() {
|
||||
<-backend.closed
|
||||
cancel()
|
||||
}()
|
||||
|
||||
bo := backoff.NewExponentialBackOff()
|
||||
bo.MaxElapsedTime = 0
|
||||
|
||||
outer:
|
||||
for {
|
||||
pubsub := backend.client.Subscribe(ctx, lastVersionChKey)
|
||||
for {
|
||||
msg, err := pubsub.Receive(ctx)
|
||||
if err != nil {
|
||||
_ = pubsub.Close()
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-time.After(bo.NextBackOff()):
|
||||
}
|
||||
continue outer
|
||||
}
|
||||
bo.Reset()
|
||||
|
||||
switch msg.(type) {
|
||||
case *redis.Message:
|
||||
backend.onChange.Broadcast()
|
||||
}
|
||||
}
|
||||
}
|
||||
return version, nil
|
||||
}
|
||||
|
||||
func redisGetRecord(ctx context.Context, c redis.Cmdable, recordType string, id string) (*databroker.Record, error) {
|
||||
records, err := redisGetRecords(ctx, c, recordType, []string{id})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
} else if len(records) < 1 {
|
||||
return nil, redis.Nil
|
||||
}
|
||||
return records[0], nil
|
||||
}
|
||||
|
||||
func redisGetRecords(ctx context.Context, c redis.Cmdable, recordType string, ids []string) ([]*databroker.Record, error) {
|
||||
if len(ids) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
results, err := c.HMGet(ctx, formatRecordsKey(recordType), ids...).Result()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
records := make([]*databroker.Record, 0, len(results))
|
||||
for _, result := range results {
|
||||
// results are returned as either nil or a string
|
||||
if result == nil {
|
||||
continue
|
||||
}
|
||||
rawstr, ok := result.(string)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
var record databroker.Record
|
||||
err := proto.Unmarshal([]byte(rawstr), &record)
|
||||
func (backend *Backend) removeChangesBefore(cutoff time.Time) {
|
||||
ctx := context.Background()
|
||||
for {
|
||||
cmd := backend.client.ZRangeByScore(ctx, changesSetKey, &redis.ZRangeBy{
|
||||
Min: "-inf",
|
||||
Max: "+inf",
|
||||
Offset: 0,
|
||||
Count: 1,
|
||||
})
|
||||
results, err := cmd.Result()
|
||||
if err != nil {
|
||||
continue
|
||||
log.Error().Err(err).Msg("redis: error retrieving changes for expiration")
|
||||
return
|
||||
}
|
||||
|
||||
// nothing left to do
|
||||
if len(results) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
var record databroker.Record
|
||||
err = proto.Unmarshal([]byte(results[0]), &record)
|
||||
if err != nil {
|
||||
log.Warn().Err(err).Msg("redis: invalid record detected")
|
||||
record.ModifiedAt = timestamppb.New(cutoff.Add(-time.Second)) // set the modified so will delete it
|
||||
}
|
||||
|
||||
// if the record's modified timestamp is after the cutoff, we're all done, so break
|
||||
if record.GetModifiedAt().AsTime().After(cutoff) {
|
||||
break
|
||||
}
|
||||
|
||||
// remove the record
|
||||
err = backend.client.ZRem(ctx, changesSetKey, results[0]).Err()
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("redis: error removing member")
|
||||
return
|
||||
}
|
||||
records = append(records, &record)
|
||||
}
|
||||
return records, nil
|
||||
}
|
||||
|
||||
func redisListIDsSince(ctx context.Context,
|
||||
c redis.Cmdable, recordType string,
|
||||
sinceVersion string,
|
||||
) ([]string, error) {
|
||||
v, err := strconv.ParseInt(sinceVersion, 16, 64)
|
||||
if err != nil {
|
||||
v = 0
|
||||
}
|
||||
rng := &redis.ZRangeBy{
|
||||
Min: fmt.Sprintf("(%d", v),
|
||||
Max: "+inf",
|
||||
}
|
||||
return c.ZRangeByScore(ctx, formatVersionSetKey(recordType), rng).Result()
|
||||
}
|
||||
|
||||
func redisSetRecord(ctx context.Context, p redis.Pipeliner, recordType string, record *databroker.Record) error {
|
||||
v, err := strconv.ParseInt(record.GetVersion(), 16, 64)
|
||||
if err != nil {
|
||||
v = 0
|
||||
}
|
||||
|
||||
raw, err := proto.Marshal(record)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// store the record in the hash
|
||||
p.HSet(ctx, formatRecordsKey(recordType), record.GetId(), string(raw))
|
||||
// set its score for sorting by version
|
||||
p.ZAdd(ctx, formatVersionSetKey(recordType), &redis.Z{
|
||||
Score: float64(v),
|
||||
Member: record.GetId(),
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func formatDeletedSetKey(recordType string) string {
|
||||
return fmt.Sprintf("%s_deleted_set", recordType)
|
||||
}
|
||||
|
||||
func formatLastVersionChannelKey(recordType string) string {
|
||||
return fmt.Sprintf("%s_last_version_ch", recordType)
|
||||
}
|
||||
|
||||
func formatLastVersionKey(recordType string) string {
|
||||
return fmt.Sprintf("%s_last_version", recordType)
|
||||
}
|
||||
|
||||
func formatRecordsKey(recordType string) string {
|
||||
return recordType
|
||||
}
|
||||
|
||||
func formatVersion(version int64) string {
|
||||
return fmt.Sprintf("%012d", version)
|
||||
}
|
||||
|
||||
func formatVersionSetKey(recordType string) string {
|
||||
return fmt.Sprintf("%s_version_set", recordType)
|
||||
}
|
||||
|
||||
func recordOperation(ctx context.Context, startTime time.Time, operation string, err error) {
|
||||
metrics.RecordStorageOperation(ctx, &metrics.StorageOperationTags{
|
||||
Operation: operation,
|
||||
Error: err,
|
||||
Backend: Name,
|
||||
}, time.Since(startTime))
|
||||
func getHashKey(recordType, id string) (key, field string) {
|
||||
return recordHashKey, fmt.Sprintf("%s/%s", recordType, id)
|
||||
}
|
||||
|
|
|
@ -2,192 +2,172 @@ package redis
|
|||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"runtime"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/ory/dockertest/v3"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"google.golang.org/protobuf/types/known/anypb"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
|
||||
"github.com/pomerium/pomerium/pkg/cryptutil"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/directory"
|
||||
"github.com/pomerium/pomerium/internal/testutil"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||
)
|
||||
|
||||
var db *DB
|
||||
|
||||
func cleanup(ctx context.Context, db *DB, t *testing.T) {
|
||||
require.NoError(t, db.client.FlushAll(ctx).Err())
|
||||
}
|
||||
|
||||
func tlsConfig(rawURL string, t *testing.T) *tls.Config {
|
||||
if !strings.HasPrefix(rawURL, "rediss") {
|
||||
return nil
|
||||
}
|
||||
cert, err := cryptutil.CertificateFromFile("./testdata/tls/redis.crt", "./testdata/tls/redis.key")
|
||||
require.NoError(t, err)
|
||||
caCertPool := x509.NewCertPool()
|
||||
caCert, err := ioutil.ReadFile("./testdata/tls/ca.crt")
|
||||
require.NoError(t, err)
|
||||
caCertPool.AppendCertsFromPEM(caCert)
|
||||
tlsConfig := &tls.Config{
|
||||
RootCAs: caCertPool,
|
||||
Certificates: []tls.Certificate{*cert},
|
||||
}
|
||||
return tlsConfig
|
||||
}
|
||||
|
||||
func runWithRedisDockerImage(t *testing.T, runOpts *dockertest.RunOptions, withTLS bool, testFunc func(t *testing.T)) {
|
||||
pool, err := dockertest.NewPool("")
|
||||
if err != nil {
|
||||
t.Fatalf("Could not connect to docker: %s", err)
|
||||
}
|
||||
resource, err := pool.RunWithOptions(runOpts)
|
||||
if err != nil {
|
||||
t.Fatalf("Could not start resource: %s", err)
|
||||
}
|
||||
resource.Expire(30)
|
||||
|
||||
defer func() {
|
||||
if err := pool.Purge(resource); err != nil {
|
||||
t.Fatalf("Could not purge resource: %s", err)
|
||||
}
|
||||
}()
|
||||
|
||||
scheme := "redis"
|
||||
if withTLS {
|
||||
scheme = "rediss"
|
||||
}
|
||||
address := fmt.Sprintf(scheme+"://localhost:%s/0", resource.GetPort("6379/tcp"))
|
||||
if err := pool.Retry(func() error {
|
||||
var err error
|
||||
db, err = New(address, WithRecordType("record_type"), WithTLSConfig(tlsConfig(address, t)))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = db.client.Ping(context.Background()).Err()
|
||||
return err
|
||||
}); err != nil {
|
||||
t.Fatalf("Could not connect to docker: %s", err)
|
||||
}
|
||||
|
||||
testFunc(t)
|
||||
}
|
||||
|
||||
func TestDB(t *testing.T) {
|
||||
func TestBackend(t *testing.T) {
|
||||
if os.Getenv("GITHUB_ACTION") != "" && runtime.GOOS == "darwin" {
|
||||
t.Skip("Github action can not run docker on MacOS")
|
||||
}
|
||||
|
||||
cwd, err := os.Getwd()
|
||||
assert.NoError(t, err)
|
||||
|
||||
tlsCmd := []string{
|
||||
"--port", "0",
|
||||
"--tls-port", "6379",
|
||||
"--tls-cert-file", "/tls/redis.crt",
|
||||
"--tls-key-file", "/tls/redis.key",
|
||||
"--tls-ca-cert-file", "/tls/ca.crt",
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
withTLS bool
|
||||
runOpts *dockertest.RunOptions
|
||||
}{
|
||||
{"redis", false, &dockertest.RunOptions{Repository: "redis", Tag: "latest"}},
|
||||
{"redis TLS", true, &dockertest.RunOptions{Repository: "redis", Tag: "latest", Cmd: tlsCmd, Mounts: []string{cwd + "/testdata/tls:/tls"}}},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
runWithRedisDockerImage(t, tc.runOpts, tc.withTLS, testDB)
|
||||
})
|
||||
for _, useTLS := range []bool{true, false} {
|
||||
require.NoError(t, testutil.WithTestRedis(useTLS, func(rawURL string) error {
|
||||
ctx := context.Background()
|
||||
var opts []Option
|
||||
if useTLS {
|
||||
opts = append(opts, WithTLSConfig(testutil.RedisTLSConfig()))
|
||||
}
|
||||
backend, err := New(rawURL, opts...)
|
||||
require.NoError(t, err)
|
||||
defer func() { _ = backend.Close() }()
|
||||
t.Run("get missing record", func(t *testing.T) {
|
||||
record, err := backend.Get(ctx, "TYPE", "abcd")
|
||||
require.Error(t, err)
|
||||
assert.Nil(t, record)
|
||||
})
|
||||
t.Run("get record", func(t *testing.T) {
|
||||
data := new(anypb.Any)
|
||||
assert.NoError(t, backend.Put(ctx, &databroker.Record{
|
||||
Type: "TYPE",
|
||||
Id: "abcd",
|
||||
Data: data,
|
||||
}))
|
||||
record, err := backend.Get(ctx, "TYPE", "abcd")
|
||||
require.NoError(t, err)
|
||||
if assert.NotNil(t, record) {
|
||||
assert.Equal(t, data, record.Data)
|
||||
assert.Nil(t, record.DeletedAt)
|
||||
assert.Equal(t, "abcd", record.Id)
|
||||
assert.NotNil(t, record.ModifiedAt)
|
||||
assert.Equal(t, "TYPE", record.Type)
|
||||
assert.Equal(t, uint64(1), record.Version)
|
||||
}
|
||||
})
|
||||
t.Run("delete record", func(t *testing.T) {
|
||||
assert.NoError(t, backend.Put(ctx, &databroker.Record{
|
||||
Type: "TYPE",
|
||||
Id: "abcd",
|
||||
DeletedAt: timestamppb.Now(),
|
||||
}))
|
||||
record, err := backend.Get(ctx, "TYPE", "abcd")
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, record)
|
||||
})
|
||||
t.Run("get all records", func(t *testing.T) {
|
||||
for i := 0; i < 1000; i++ {
|
||||
assert.NoError(t, backend.Put(ctx, &databroker.Record{
|
||||
Type: "TYPE",
|
||||
Id: fmt.Sprint(i),
|
||||
}))
|
||||
}
|
||||
records, version, err := backend.GetAll(ctx)
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, records, 1000)
|
||||
assert.Equal(t, uint64(1002), version)
|
||||
})
|
||||
return nil
|
||||
}))
|
||||
}
|
||||
}
|
||||
|
||||
func testDB(t *testing.T) {
|
||||
ctx, cancelFunc := context.WithCancel(context.Background())
|
||||
defer cancelFunc()
|
||||
|
||||
users := []*directory.User{
|
||||
{Id: "u1", GroupIds: []string{"test", "admin"}},
|
||||
{Id: "u2"},
|
||||
{Id: "u3", GroupIds: []string{"test"}},
|
||||
func TestChangeSignal(t *testing.T) {
|
||||
if os.Getenv("GITHUB_ACTION") != "" && runtime.GOOS == "darwin" {
|
||||
t.Skip("Github action can not run docker on MacOS")
|
||||
}
|
||||
ids := []string{"a", "b", "c"}
|
||||
id := ids[0]
|
||||
|
||||
t.Run("get missing record", func(t *testing.T) {
|
||||
record, err := db.Get(ctx, id)
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, record)
|
||||
})
|
||||
t.Run("get record", func(t *testing.T) {
|
||||
data, _ := anypb.New(users[0])
|
||||
assert.NoError(t, db.Put(ctx, id, data))
|
||||
record, err := db.Get(ctx, id)
|
||||
ctx := context.Background()
|
||||
ctx, clearTimeout := context.WithTimeout(ctx, time.Second*10)
|
||||
defer clearTimeout()
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
|
||||
require.NoError(t, testutil.WithTestRedis(false, func(rawURL string) error {
|
||||
backend1, err := New(rawURL)
|
||||
require.NoError(t, err)
|
||||
if assert.NotNil(t, record) {
|
||||
assert.NotNil(t, record.CreatedAt)
|
||||
assert.NotEmpty(t, record.Data)
|
||||
assert.Nil(t, record.DeletedAt)
|
||||
assert.Equal(t, "a", record.Id)
|
||||
assert.NotNil(t, record.ModifiedAt)
|
||||
assert.Equal(t, "000000000001", record.Version)
|
||||
}
|
||||
})
|
||||
t.Run("delete record", func(t *testing.T) {
|
||||
original, err := db.Get(ctx, id)
|
||||
defer func() { _ = backend1.Close() }()
|
||||
|
||||
backend2, err := New(rawURL)
|
||||
require.NoError(t, err)
|
||||
assert.NoError(t, db.Delete(ctx, id))
|
||||
record, err := db.Get(ctx, id)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, record)
|
||||
assert.NotNil(t, record.DeletedAt)
|
||||
assert.NotEqual(t, original.GetVersion(), record.GetVersion())
|
||||
})
|
||||
t.Run("clear deleted", func(t *testing.T) {
|
||||
db.ClearDeleted(ctx, time.Now().Add(time.Second))
|
||||
record, err := db.Get(ctx, id)
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, record)
|
||||
})
|
||||
t.Run("list", func(t *testing.T) {
|
||||
cleanup(ctx, db, t)
|
||||
defer func() { _ = backend2.Close() }()
|
||||
|
||||
for i := 0; i < 10; i++ {
|
||||
id := fmt.Sprintf("%02d", i)
|
||||
data := new(anypb.Any)
|
||||
assert.NoError(t, db.Put(ctx, id, data))
|
||||
}
|
||||
ch := backend1.onChange.Bind()
|
||||
defer backend1.onChange.Unbind(ch)
|
||||
|
||||
records, err := db.List(ctx, "")
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, records, 10)
|
||||
records, err = db.List(ctx, "000000000005")
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, records, 5)
|
||||
records, err = db.List(ctx, "000000000010")
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, records, 0)
|
||||
})
|
||||
t.Run("watch", func(t *testing.T) {
|
||||
ch := db.Watch(ctx)
|
||||
time.Sleep(time.Second)
|
||||
|
||||
go db.Put(ctx, "WATCH", new(anypb.Any))
|
||||
go func() {
|
||||
ticker := time.NewTicker(time.Millisecond * 100)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
_ = backend2.Put(ctx, &databroker.Record{
|
||||
Type: "TYPE",
|
||||
Id: "ID",
|
||||
})
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-ch:
|
||||
case <-time.After(time.Second * 10):
|
||||
t.Error("expected watch signal on put")
|
||||
case <-ctx.Done():
|
||||
t.Fatal("expected signal to be fired when another backend triggers a change")
|
||||
}
|
||||
})
|
||||
|
||||
return nil
|
||||
}))
|
||||
}
|
||||
|
||||
func TestExpiry(t *testing.T) {
|
||||
if os.Getenv("GITHUB_ACTION") != "" && runtime.GOOS == "darwin" {
|
||||
t.Skip("Github action can not run docker on MacOS")
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
require.NoError(t, testutil.WithTestRedis(false, func(rawURL string) error {
|
||||
backend, err := New(rawURL, WithExpiry(0))
|
||||
require.NoError(t, err)
|
||||
defer func() { _ = backend.Close() }()
|
||||
|
||||
for i := 0; i < 1000; i++ {
|
||||
assert.NoError(t, backend.Put(ctx, &databroker.Record{
|
||||
Type: "TYPE",
|
||||
Id: fmt.Sprint(i),
|
||||
}))
|
||||
}
|
||||
stream, err := backend.Sync(ctx, 0)
|
||||
require.NoError(t, err)
|
||||
var records []*databroker.Record
|
||||
for stream.Next(false) {
|
||||
records = append(records, stream.Record())
|
||||
}
|
||||
_ = stream.Close()
|
||||
require.Len(t, records, 1000)
|
||||
|
||||
backend.removeChangesBefore(time.Now().Add(time.Second))
|
||||
|
||||
stream, err = backend.Sync(ctx, 0)
|
||||
require.NoError(t, err)
|
||||
records = nil
|
||||
for stream.Next(false) {
|
||||
records = append(records, stream.Record())
|
||||
}
|
||||
_ = stream.Close()
|
||||
require.Len(t, records, 0)
|
||||
|
||||
return nil
|
||||
}))
|
||||
}
|
||||
|
|
105
pkg/storage/redis/stream.go
Normal file
105
pkg/storage/redis/stream.go
Normal file
|
@ -0,0 +1,105 @@
|
|||
package redis
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/go-redis/redis/v8"
|
||||
"github.com/golang/protobuf/proto"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/log"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||
)
|
||||
|
||||
type recordStream struct {
|
||||
ctx context.Context
|
||||
backend *Backend
|
||||
|
||||
changed chan struct{}
|
||||
version uint64
|
||||
record *databroker.Record
|
||||
err error
|
||||
|
||||
closeOnce sync.Once
|
||||
closed chan struct{}
|
||||
}
|
||||
|
||||
func newRecordStream(ctx context.Context, backend *Backend, version uint64) *recordStream {
|
||||
return &recordStream{
|
||||
ctx: ctx,
|
||||
backend: backend,
|
||||
|
||||
changed: backend.onChange.Bind(),
|
||||
version: version,
|
||||
|
||||
closed: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
func (stream *recordStream) Close() error {
|
||||
stream.closeOnce.Do(func() {
|
||||
stream.backend.onChange.Unbind(stream.changed)
|
||||
close(stream.closed)
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
func (stream *recordStream) Next(block bool) bool {
|
||||
if stream.err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
ticker := time.NewTicker(watchPollInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
cmd := stream.backend.client.ZRangeByScore(stream.ctx, changesSetKey, &redis.ZRangeBy{
|
||||
Min: fmt.Sprintf("(%d", stream.version),
|
||||
Max: "+inf",
|
||||
Offset: 0,
|
||||
Count: 1,
|
||||
})
|
||||
results, err := cmd.Result()
|
||||
if err != nil {
|
||||
stream.err = err
|
||||
return false
|
||||
}
|
||||
|
||||
if len(results) > 0 {
|
||||
result := results[0]
|
||||
var record databroker.Record
|
||||
err = proto.Unmarshal([]byte(result), &record)
|
||||
if err != nil {
|
||||
log.Warn().Err(err).Msg("redis: invalid record detected")
|
||||
} else {
|
||||
stream.record = &record
|
||||
}
|
||||
stream.version++
|
||||
return true
|
||||
}
|
||||
|
||||
if block {
|
||||
select {
|
||||
case <-stream.ctx.Done():
|
||||
stream.err = stream.ctx.Err()
|
||||
return false
|
||||
case <-stream.closed:
|
||||
return false
|
||||
case <-ticker.C: // check again
|
||||
case <-stream.changed: // check again
|
||||
}
|
||||
} else {
|
||||
return false
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (stream *recordStream) Record() *databroker.Record {
|
||||
return stream.record
|
||||
}
|
||||
|
||||
func (stream *recordStream) Err() error {
|
||||
return stream.err
|
||||
}
|
|
@ -3,8 +3,8 @@ package storage
|
|||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"google.golang.org/protobuf/reflect/protoreflect"
|
||||
"google.golang.org/protobuf/types/known/anypb"
|
||||
|
@ -13,30 +13,39 @@ import (
|
|||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||
)
|
||||
|
||||
// Errors
|
||||
var (
|
||||
ErrNotFound = errors.New("record not found")
|
||||
ErrStreamClosed = errors.New("record stream closed")
|
||||
)
|
||||
|
||||
// A RecordStream is a stream of records.
|
||||
type RecordStream interface {
|
||||
// Close closes the record stream and releases any underlying resources.
|
||||
Close() error
|
||||
// Next is called to retrieve the next record. If one is available it will
|
||||
// be returned immediately. If none is available and block is true, the method
|
||||
// will block until one is available or an error occurs. The error should be
|
||||
// checked with a call to `.Err()`.
|
||||
Next(block bool) bool
|
||||
// Record returns the current record.
|
||||
Record() *databroker.Record
|
||||
// Err returns any error that occurred while streaming.
|
||||
Err() error
|
||||
}
|
||||
|
||||
// Backend is the interface required for a storage backend.
|
||||
type Backend interface {
|
||||
// Close closes the backend.
|
||||
Close() error
|
||||
|
||||
// Put is used to insert or update a record.
|
||||
Put(ctx context.Context, id string, data *anypb.Any) error
|
||||
|
||||
// Get is used to retrieve a record.
|
||||
Get(ctx context.Context, id string) (*databroker.Record, error)
|
||||
|
||||
// List is used to retrieve all the records since a version.
|
||||
List(ctx context.Context, sinceVersion string) ([]*databroker.Record, error)
|
||||
|
||||
// Delete is used to mark a record as deleted.
|
||||
Delete(ctx context.Context, id string) error
|
||||
|
||||
// ClearDeleted is used clear marked delete records.
|
||||
ClearDeleted(ctx context.Context, cutoff time.Time)
|
||||
|
||||
// Watch returns a channel to the caller. The channel is used to notify
|
||||
// about changes that happen in storage. When ctx is finished, Watch will close
|
||||
// the channel.
|
||||
Watch(ctx context.Context) <-chan struct{}
|
||||
Get(ctx context.Context, recordType, id string) (*databroker.Record, error)
|
||||
// GetAll gets all the records.
|
||||
GetAll(ctx context.Context) (records []*databroker.Record, version uint64, err error)
|
||||
// Put is used to insert or update a record.
|
||||
Put(ctx context.Context, record *databroker.Record) error
|
||||
// Sync syncs record changes after the specified version.
|
||||
Sync(ctx context.Context, version uint64) (RecordStream, error)
|
||||
}
|
||||
|
||||
// MatchAny searches any data with a query.
|
||||
|
|
|
@ -3,7 +3,6 @@ package storage
|
|||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"google.golang.org/protobuf/types/known/anypb"
|
||||
|
@ -13,50 +12,29 @@ import (
|
|||
)
|
||||
|
||||
type mockBackend struct {
|
||||
put func(ctx context.Context, id string, data *anypb.Any) error
|
||||
get func(ctx context.Context, id string) (*databroker.Record, error)
|
||||
getAll func(ctx context.Context) ([]*databroker.Record, error)
|
||||
list func(ctx context.Context, sinceVersion string) ([]*databroker.Record, error)
|
||||
delete func(ctx context.Context, id string) error
|
||||
clearDeleted func(ctx context.Context, cutoff time.Time)
|
||||
query func(ctx context.Context, query string, offset, limit int) ([]*databroker.Record, int, error)
|
||||
watch func(ctx context.Context) <-chan struct{}
|
||||
put func(ctx context.Context, record *databroker.Record) error
|
||||
get func(ctx context.Context, recordType, id string) (*databroker.Record, error)
|
||||
getAll func(ctx context.Context) ([]*databroker.Record, uint64, error)
|
||||
}
|
||||
|
||||
func (m *mockBackend) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockBackend) Put(ctx context.Context, id string, data *anypb.Any) error {
|
||||
return m.put(ctx, id, data)
|
||||
func (m *mockBackend) Put(ctx context.Context, record *databroker.Record) error {
|
||||
return m.put(ctx, record)
|
||||
}
|
||||
|
||||
func (m *mockBackend) Get(ctx context.Context, id string) (*databroker.Record, error) {
|
||||
return m.get(ctx, id)
|
||||
func (m *mockBackend) Get(ctx context.Context, recordType, id string) (*databroker.Record, error) {
|
||||
return m.get(ctx, recordType, id)
|
||||
}
|
||||
|
||||
func (m *mockBackend) GetAll(ctx context.Context) ([]*databroker.Record, error) {
|
||||
func (m *mockBackend) GetAll(ctx context.Context) ([]*databroker.Record, uint64, error) {
|
||||
return m.getAll(ctx)
|
||||
}
|
||||
|
||||
func (m *mockBackend) List(ctx context.Context, sinceVersion string) ([]*databroker.Record, error) {
|
||||
return m.list(ctx, sinceVersion)
|
||||
}
|
||||
|
||||
func (m *mockBackend) Delete(ctx context.Context, id string) error {
|
||||
return m.delete(ctx, id)
|
||||
}
|
||||
|
||||
func (m *mockBackend) ClearDeleted(ctx context.Context, cutoff time.Time) {
|
||||
m.clearDeleted(ctx, cutoff)
|
||||
}
|
||||
|
||||
func (m *mockBackend) Query(ctx context.Context, query string, offset, limit int) ([]*databroker.Record, int, error) {
|
||||
return m.query(ctx, query, offset, limit)
|
||||
}
|
||||
|
||||
func (m *mockBackend) Watch(ctx context.Context) <-chan struct{} {
|
||||
return m.watch(ctx)
|
||||
func (m *mockBackend) Sync(ctx context.Context, version uint64) (RecordStream, error) {
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
func TestMatchAny(t *testing.T) {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue