diff --git a/authorize/authorize.go b/authorize/authorize.go index 188ab6a5c..0ac75435a 100644 --- a/authorize/authorize.go +++ b/authorize/authorize.go @@ -6,6 +6,7 @@ import ( "context" "fmt" "html/template" + "sync" "github.com/pomerium/pomerium/authorize/evaluator" "github.com/pomerium/pomerium/config" @@ -24,6 +25,11 @@ type Authorize struct { templates *template.Template dataBrokerInitialSync chan struct{} + + // The stateLock prevents updating the evaluator store simultaneously with an evaluation. + // This should provide a consistent view of the data at a given server/record version and + // avoid partial updates. + stateLock sync.RWMutex } // New validates and creates a new Authorize service from a set of config options. diff --git a/authorize/evaluator/store.go b/authorize/evaluator/store.go index 9ef7eec0c..cbf7dd3f2 100644 --- a/authorize/evaluator/store.go +++ b/authorize/evaluator/store.go @@ -23,19 +23,75 @@ import ( "github.com/pomerium/pomerium/pkg/grpc/databroker" ) +type dataBrokerData struct { + mu sync.RWMutex + m map[string]map[string]proto.Message +} + +func newDataBrokerData() *dataBrokerData { + return &dataBrokerData{ + m: map[string]map[string]proto.Message{}, + } +} + +func (dbd *dataBrokerData) clear() { + dbd.mu.Lock() + defer dbd.mu.Unlock() + + dbd.m = map[string]map[string]proto.Message{} +} + +func (dbd *dataBrokerData) delete(typeURL, id string) { + dbd.mu.Lock() + defer dbd.mu.Unlock() + + m, ok := dbd.m[typeURL] + if !ok { + return + } + + delete(m, id) + + if len(m) == 0 { + delete(dbd.m, typeURL) + } +} + +func (dbd *dataBrokerData) get(typeURL, id string) proto.Message { + dbd.mu.RLock() + defer dbd.mu.RUnlock() + + m, ok := dbd.m[typeURL] + if !ok { + return nil + } + return m[id] +} + +func (dbd *dataBrokerData) set(typeURL, id string, msg proto.Message) { + dbd.mu.Lock() + defer dbd.mu.Unlock() + + m, ok := dbd.m[typeURL] + if !ok { + m = map[string]proto.Message{} + dbd.m[typeURL] = m + } + m[id] = msg +} + // A Store stores data for the OPA rego policy evaluation. type Store struct { storage.Store - mu sync.RWMutex - dataBrokerData map[string]map[string]proto.Message + dataBrokerData *dataBrokerData } // NewStore creates a new Store. func NewStore() *Store { return &Store{ Store: inmem.New(), - dataBrokerData: make(map[string]map[string]proto.Message), + dataBrokerData: newDataBrokerData(), } } @@ -63,49 +119,15 @@ func NewStoreFromProtos(serverVersion uint64, msgs ...proto.Message) *Store { return s } -// NewTransaction calls the underlying store NewTransaction and takes the transaction lock. -func (s *Store) NewTransaction(ctx context.Context, params ...storage.TransactionParams) (storage.Transaction, error) { - txn, err := s.Store.NewTransaction(ctx, params...) - if err != nil { - return nil, err - } - s.mu.RLock() - return txn, err -} - -// Commit calls the underlying store Commit and releases the transaction lock. -func (s *Store) Commit(ctx context.Context, txn storage.Transaction) error { - err := s.Store.Commit(ctx, txn) - s.mu.RUnlock() - return err -} - -// Abort calls the underlying store Abort and releases the transaction lock. -func (s *Store) Abort(ctx context.Context, txn storage.Transaction) { - s.Store.Abort(ctx, txn) - s.mu.RUnlock() -} - // ClearRecords removes all the records from the store. func (s *Store) ClearRecords() { - s.mu.Lock() - defer s.mu.Unlock() - - s.dataBrokerData = make(map[string]map[string]proto.Message) + s.dataBrokerData.clear() } // GetRecordData gets a record's data from the store. `nil` is returned // if no record exists for the given type and id. func (s *Store) GetRecordData(typeURL, id string) proto.Message { - s.mu.RLock() - defer s.mu.RUnlock() - - m, ok := s.dataBrokerData[typeURL] - if !ok { - return nil - } - - return m[id] + return s.dataBrokerData.get(typeURL, id) } // UpdateIssuer updates the issuer in the store. The issuer is used as part of JWT construction. @@ -131,22 +153,14 @@ func (s *Store) UpdateRoutePolicies(routePolicies []config.Policy) { // UpdateRecord updates a record in the store. func (s *Store) UpdateRecord(serverVersion uint64, record *databroker.Record) { - s.mu.Lock() - defer s.mu.Unlock() - + if record.GetDeletedAt() != nil { + s.dataBrokerData.delete(record.GetType(), record.GetId()) + } else { + msg, _ := record.GetData().UnmarshalNew() + s.dataBrokerData.set(record.GetType(), record.GetId(), msg) + } s.write("/databroker_server_version", fmt.Sprint(serverVersion)) s.write("/databroker_record_version", fmt.Sprint(record.GetVersion())) - - m, ok := s.dataBrokerData[record.GetType()] - if !ok { - m = make(map[string]proto.Message) - s.dataBrokerData[record.GetType()] = m - } - if record.GetDeletedAt() != nil { - delete(m, record.GetId()) - } else { - m[record.GetId()], _ = record.GetData().UnmarshalNew() - } } // UpdateSigningKey updates the signing key stored in the database. Signing operations diff --git a/authorize/grpc.go b/authorize/grpc.go index 06ead6e46..f113e6efe 100644 --- a/authorize/grpc.go +++ b/authorize/grpc.go @@ -56,7 +56,10 @@ func (a *Authorize) Check(ctx context.Context, in *envoy_service_auth_v3.CheckRe return nil, err } + // take the state lock here so we don't update while evaluating + a.stateLock.RLock() reply, err := state.evaluator.Evaluate(ctx, req) + a.stateLock.RUnlock() if err != nil { log.Error(ctx).Err(err).Msg("error during OPA evaluation") return nil, err diff --git a/authorize/sync.go b/authorize/sync.go index 3d4af286a..1326c57bc 100644 --- a/authorize/sync.go +++ b/authorize/sync.go @@ -47,13 +47,17 @@ func (syncer *dataBrokerSyncer) GetDataBrokerServiceClient() databroker.DataBrok } func (syncer *dataBrokerSyncer) ClearRecords(ctx context.Context) { + syncer.authorize.stateLock.Lock() syncer.authorize.store.ClearRecords() + syncer.authorize.stateLock.Unlock() } func (syncer *dataBrokerSyncer) UpdateRecords(ctx context.Context, serverVersion uint64, records []*databroker.Record) { + syncer.authorize.stateLock.Lock() for _, record := range records { syncer.authorize.store.UpdateRecord(serverVersion, record) } + syncer.authorize.stateLock.Unlock() // the first time we update records we signal the initial sync syncer.signalOnce.Do(func() { diff --git a/pkg/grpc/databroker/syncer.go b/pkg/grpc/databroker/syncer.go index 8594f4bcd..c9d72a6bf 100644 --- a/pkg/grpc/databroker/syncer.go +++ b/pkg/grpc/databroker/syncer.go @@ -156,8 +156,8 @@ func (syncer *Syncer) sync(ctx context.Context) error { } log.Debug(syncer.logCtx(ctx)). - Uint("version", uint(res.Record.GetVersion())). - Str("id", res.Record.Id). + Uint("version", uint(res.GetRecord().GetVersion())). + Str("id", res.GetRecord().GetId()). Msg("syncer got record") if syncer.recordVersion != res.GetRecord().GetVersion()-1 {