feature/databroker: user data and session refactor project (#926)

* databroker: add databroker, identity manager, update cache (#864)

* databroker: add databroker, identity manager, update cache

* fix cache tests

* directory service (#885)

* directory: add google and okta

* add onelogin

* add directory provider

* initialize before sync, update google provider, remove dead code

* add azure provider

* fix azure provider

* fix gitlab

* add gitlab test, fix azure test

* hook up okta

* remove dead code

* fix tests

* fix flaky test

* authorize: use databroker data for rego policy (#904)

* wip

* add directory provider

* initialize before sync, update google provider, remove dead code

* fix flaky test

* update authorize to use databroker data

* implement signed jwt

* wait for session and user to appear

* fix test

* directory service (#885)

* directory: add google and okta

* add onelogin

* add directory provider

* initialize before sync, update google provider, remove dead code

* add azure provider

* fix azure provider

* fix gitlab

* add gitlab test, fix azure test

* hook up okta

* remove dead code

* fix tests

* fix flaky test

* remove log line

* only redirect when no session id exists

* prepare rego query as part of create

* return on ctx done

* retry on disconnect for sync

* move jwt signing

* use !=

* use parent ctx for wait

* remove session state, remove logs

* rename function

* add log message

* pre-allocate slice

* use errgroup

* return nil on eof for sync

* move check

* disable timeout on gRPC requests in envoy

* fix gitlab test

* use v4 backoff

* authenticate: databroker changes (#914)

* wip

* add directory provider

* initialize before sync, update google provider, remove dead code

* fix flaky test

* update authorize to use databroker data

* implement signed jwt

* wait for session and user to appear

* fix test

* directory service (#885)

* directory: add google and okta

* add onelogin

* add directory provider

* initialize before sync, update google provider, remove dead code

* add azure provider

* fix azure provider

* fix gitlab

* add gitlab test, fix azure test

* hook up okta

* remove dead code

* fix tests

* fix flaky test

* remove log line

* only redirect when no session id exists

* prepare rego query as part of create

* return on ctx done

* retry on disconnect for sync

* move jwt signing

* use !=

* use parent ctx for wait

* remove session state, remove logs

* rename function

* add log message

* pre-allocate slice

* use errgroup

* return nil on eof for sync

* move check

* disable timeout on gRPC requests in envoy

* fix dashboard

* delete session on logout

* permanently delete sessions once they are marked as deleted

* remove permanent delete

* fix tests

* remove groups and refresh test

* databroker: remove dead code, rename cache url, move dashboard (#925)

* wip

* add directory provider

* initialize before sync, update google provider, remove dead code

* fix flaky test

* update authorize to use databroker data

* implement signed jwt

* wait for session and user to appear

* fix test

* directory service (#885)

* directory: add google and okta

* add onelogin

* add directory provider

* initialize before sync, update google provider, remove dead code

* add azure provider

* fix azure provider

* fix gitlab

* add gitlab test, fix azure test

* hook up okta

* remove dead code

* fix tests

* fix flaky test

* remove log line

* only redirect when no session id exists

* prepare rego query as part of create

* return on ctx done

* retry on disconnect for sync

* move jwt signing

* use !=

* use parent ctx for wait

* remove session state, remove logs

* rename function

* add log message

* pre-allocate slice

* use errgroup

* return nil on eof for sync

* move check

* disable timeout on gRPC requests in envoy

* fix dashboard

* delete session on logout

* permanently delete sessions once they are marked as deleted

* remove permanent delete

* fix tests

* remove cache service

* remove kv

* remove refresh docs

* remove obsolete cache docs

* add databroker url option

* cache: use memberlist to detect multiple instances

* add databroker service url

* remove cache service

* remove kv

* remove refresh docs

* remove obsolete cache docs

* add databroker url option

* cache: use memberlist to detect multiple instances

* add databroker service url

* wip

* remove groups and refresh test

* fix redirect, signout

* remove databroker client from proxy

* remove unused method

* remove user dashboard test

* handle missing session ids

* session: reject sessions with no id

* sessions: invalidate old sessions via databroker server version (#930)

* session: add a version field tied to the databroker server version that can be used to invalidate sessions

* fix tests

* add log

* authenticate: create user record immediately, call "get" directly in authorize (#931)
This commit is contained in:
Caleb Doxsey 2020-06-19 07:52:44 -06:00 committed by GitHub
parent 39cdb31170
commit dbd7f55b20
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
115 changed files with 8479 additions and 3584 deletions

View file

@ -0,0 +1,45 @@
package memory
import "time"
var (
	// DefaultDeletePermanentlyAfter is the default amount of time to wait before deleting
	// a record permanently.
	DefaultDeletePermanentlyAfter = time.Hour
	// DefaultBTreeDegree is the default number of items to store in each node of the BTree.
	DefaultBTreeDegree = 8
)

// serverConfig holds the tunable settings for the in-memory databroker server.
type serverConfig struct {
	deletePermanentlyAfter time.Duration
	btreeDegree            int
}

// newServerConfig builds a config by applying the defaults first and then any
// caller-supplied options, so later options win over the defaults.
func newServerConfig(options ...ServerOption) *serverConfig {
	cfg := new(serverConfig)
	defaults := []ServerOption{
		WithDeletePermanentlyAfter(DefaultDeletePermanentlyAfter),
		WithBTreeDegree(DefaultBTreeDegree),
	}
	for _, apply := range append(defaults, options...) {
		apply(cfg)
	}
	return cfg
}

// A ServerOption customizes the server.
type ServerOption func(*serverConfig)

// WithBTreeDegree sets the number of items to store in each node of the BTree.
func WithBTreeDegree(degree int) ServerOption {
	return func(cfg *serverConfig) { cfg.btreeDegree = degree }
}

// WithDeletePermanentlyAfter sets the deletePermanentlyAfter duration.
// If a record is deleted via Delete, it will be permanently deleted after
// the given duration.
func WithDeletePermanentlyAfter(dur time.Duration) ServerOption {
	return func(cfg *serverConfig) { cfg.deletePermanentlyAfter = dur }
}

View file

@ -0,0 +1,141 @@
package memory
import (
"fmt"
"sync"
"sync/atomic"
"time"
"github.com/golang/protobuf/ptypes"
"github.com/google/btree"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/known/anypb"
"github.com/pomerium/pomerium/internal/grpc/databroker"
)
// byIDRecord adapts a databroker.Record to btree.Item, ordered by record id.
type byIDRecord struct {
	*databroker.Record
}

// Less orders records lexicographically by id. The type assertion assumes the
// tree only ever contains byIDRecord items.
func (k byIDRecord) Less(than btree.Item) bool {
	return k.Id < than.(byIDRecord).Id
}

// byVersionRecord adapts a databroker.Record to btree.Item, ordered by version.
type byVersionRecord struct {
	*databroker.Record
}

// Less orders records lexicographically by version. Versions are fixed-width
// uppercase hex strings (see replaceOrInsert), so string order matches
// numeric order.
func (k byVersionRecord) Less(than btree.Item) bool {
	return k.Version < than.(byVersionRecord).Version
}
// DB is an in-memory database of records using b-trees.
type DB struct {
	recordType string
	// lastVersion is incremented atomically to produce monotonically
	// increasing record versions.
	lastVersion uint64

	// mu guards the indexes and deletedIDs below.
	mu        sync.Mutex
	byID      *btree.BTree // records indexed by id
	byVersion *btree.BTree // records indexed by version
	// deletedIDs queues ids marked deleted until ClearDeleted purges them.
	deletedIDs []string
}

// NewDB creates a new in-memory database for the given record type.
// btreeDegree controls how many items are stored per btree node.
func NewDB(recordType string, btreeDegree int) *DB {
	return &DB{
		recordType: recordType,
		byID:       btree.New(btreeDegree),
		byVersion:  btree.New(btreeDegree),
	}
}
// ClearDeleted permanently removes deleted records whose DeletedAt timestamp
// is before the given cutoff. IDs deleted at or after the cutoff stay queued
// for a later pass.
func (db *DB) ClearDeleted(cutoff time.Time) {
	db.mu.Lock()
	defer db.mu.Unlock()

	var remaining []string
	for _, id := range db.deletedIDs {
		record, ok := db.byID.Get(byIDRecord{Record: &databroker.Record{Id: id}}).(byIDRecord)
		if !ok {
			// already purged (e.g. the id was queued more than once by
			// repeated Delete calls); skip to avoid a nil dereference
			continue
		}
		ts, _ := ptypes.Timestamp(record.DeletedAt)
		if ts.Before(cutoff) {
			db.byID.Delete(record)
			db.byVersion.Delete(byVersionRecord(record))
		} else {
			remaining = append(remaining, id)
		}
	}
	db.deletedIDs = remaining
}
// Delete marks a record as deleted. The record is not removed immediately;
// its id is queued so ClearDeleted can purge it later.
func (db *DB) Delete(id string) {
	markDeleted := func(record *databroker.Record) {
		record.DeletedAt = ptypes.TimestampNow()
		db.deletedIDs = append(db.deletedIDs, id)
	}
	db.replaceOrInsert(id, markDeleted)
}
// Get gets a record from the db, or nil if no record has the given id.
// Records marked deleted but not yet purged are still returned.
func (db *DB) Get(id string) *databroker.Record {
	// take the lock: replaceOrInsert and ClearDeleted mutate byID under
	// db.mu, so an unsynchronized read here is a data race
	db.mu.Lock()
	defer db.mu.Unlock()

	record, ok := db.byID.Get(byIDRecord{Record: &databroker.Record{Id: id}}).(byIDRecord)
	if !ok {
		return nil
	}
	return record.Record
}
// GetAll gets all the records in the db, in id order.
func (db *DB) GetAll() []*databroker.Record {
	// take the lock: replaceOrInsert and ClearDeleted mutate byID under
	// db.mu, so an unsynchronized traversal here is a data race
	db.mu.Lock()
	defer db.mu.Unlock()

	records := make([]*databroker.Record, 0, db.byID.Len())
	db.byID.Ascend(func(item btree.Item) bool {
		records = append(records, item.(byIDRecord).Record)
		return true
	})
	return records
}
// List lists all the changes since the given version, in version order.
// Pass "" to list every record.
func (db *DB) List(sinceVersion string) []*databroker.Record {
	// take the lock: replaceOrInsert mutates byVersion under db.mu, so an
	// unsynchronized traversal here is a data race
	db.mu.Lock()
	defer db.mu.Unlock()

	var records []*databroker.Record
	// versions are fixed-width hex strings, so lexicographic order matches
	// numeric order and the btree can seek straight to sinceVersion
	db.byVersion.AscendGreaterOrEqual(byVersionRecord{Record: &databroker.Record{Version: sinceVersion}}, func(i btree.Item) bool {
		record := i.(byVersionRecord)
		if record.Version > sinceVersion {
			records = append(records, record.Record)
		}
		return true
	})
	return records
}
// Set replaces or inserts a record in the db, storing the given data payload
// under the given id.
func (db *DB) Set(id string, data *anypb.Any) {
	setData := func(record *databroker.Record) {
		record.Data = data
	}
	db.replaceOrInsert(id, setData)
}
// replaceOrInsert looks up (or creates) the record with the given id, applies
// f to mutate it, stamps timestamps and a fresh version, and reindexes it in
// both btrees.
func (db *DB) replaceOrInsert(id string, f func(record *databroker.Record)) {
	db.mu.Lock()
	defer db.mu.Unlock()

	record, ok := db.byID.Get(byIDRecord{Record: &databroker.Record{Id: id}}).(byIDRecord)
	if ok {
		// drop the stale version-index entry, and clone the record so callers
		// holding pointers previously returned by Get/GetAll/List never
		// observe this mutation (copy-on-write)
		db.byVersion.Delete(byVersionRecord(record))
		record.Record = proto.Clone(record.Record).(*databroker.Record)
	} else {
		record.Record = new(databroker.Record)
	}
	f(record.Record)
	if record.CreatedAt == nil {
		record.CreatedAt = ptypes.TimestampNow()
	}
	record.ModifiedAt = ptypes.TimestampNow()
	record.Type = db.recordType
	record.Id = id
	// fixed-width uppercase hex keeps lexicographic order equal to numeric
	// order, which the byVersion index and List rely on
	record.Version = fmt.Sprintf("%012X", atomic.AddUint64(&db.lastVersion, 1))
	db.byID.ReplaceOrInsert(record)
	db.byVersion.ReplaceOrInsert(byVersionRecord(record))
}

View file

@ -0,0 +1,60 @@
package memory
import (
"fmt"
"testing"
"time"
"github.com/stretchr/testify/assert"
"google.golang.org/protobuf/types/known/anypb"
)
// TestDB exercises the DB type end to end. The subtests share a single db and
// depend on running in order: later subtests build on state (and version
// numbers) left behind by earlier ones.
func TestDB(t *testing.T) {
	db := NewDB("example", 2)
	t.Run("get missing record", func(t *testing.T) {
		assert.Nil(t, db.Get("abcd"))
	})
	t.Run("get record", func(t *testing.T) {
		data := new(anypb.Any)
		db.Set("abcd", data)
		record := db.Get("abcd")
		if assert.NotNil(t, record) {
			assert.NotNil(t, record.CreatedAt)
			assert.Equal(t, data, record.Data)
			assert.Nil(t, record.DeletedAt)
			assert.Equal(t, "abcd", record.Id)
			assert.NotNil(t, record.ModifiedAt)
			assert.Equal(t, "example", record.Type)
			// first write ever, so the version counter is at 1
			assert.Equal(t, "000000000001", record.Version)
		}
	})
	t.Run("delete record", func(t *testing.T) {
		// Delete only marks the record; it is still readable afterwards
		db.Delete("abcd")
		record := db.Get("abcd")
		if assert.NotNil(t, record) {
			assert.NotNil(t, record.DeletedAt)
		}
	})
	t.Run("clear deleted", func(t *testing.T) {
		// a cutoff in the future purges the marked record permanently
		db.ClearDeleted(time.Now().Add(time.Second))
		assert.Nil(t, db.Get("abcd"))
	})
	t.Run("keep remaining", func(t *testing.T) {
		data := new(anypb.Any)
		db.Set("abcd", data)
		db.Delete("abcd")
		// a cutoff in the past keeps the record queued for a later pass
		db.ClearDeleted(time.Now().Add(-10 * time.Second))
		assert.NotNil(t, db.Get("abcd"))
		db.ClearDeleted(time.Now().Add(time.Second))
	})
	t.Run("list", func(t *testing.T) {
		for i := 0; i < 10; i++ {
			data := new(anypb.Any)
			db.Set(fmt.Sprintf("%02d", i), data)
		}
		// the version counter continued from the earlier subtests, so these
		// 10 records carry versions 0x5 through 0xE
		assert.Len(t, db.List(""), 10)
		assert.Len(t, db.List("00000000000A"), 4)
		assert.Len(t, db.List("00000000000F"), 0)
	})
}

View file

@ -0,0 +1,236 @@
// Package memory contains an in-memory data broker implementation.
package memory
import (
"context"
"reflect"
"sort"
"sync"
"time"
"github.com/golang/protobuf/ptypes/empty"
"github.com/google/uuid"
"github.com/rs/zerolog"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/types/known/emptypb"
"github.com/pomerium/pomerium/internal/grpc/databroker"
"github.com/pomerium/pomerium/internal/log"
)
// Server implements the databroker service using an in memory database.
type Server struct {
	// version is a random id generated at startup; Sync uses it to detect
	// that record versions came from a different server instance.
	version string
	cfg     *serverConfig
	log     zerolog.Logger

	// mu guards byType.
	mu     sync.RWMutex
	byType map[string]*DB
	// onchange wakes blocked Sync/SyncTypes streams whenever a record changes.
	onchange *Signal
}
// New creates a new server and starts a background janitor that periodically
// purges records which have been marked deleted for longer than the
// configured deletePermanentlyAfter duration.
func New(options ...ServerOption) *Server {
	cfg := newServerConfig(options...)
	srv := &Server{
		version:  uuid.New().String(),
		cfg:      cfg,
		log:      log.With().Str("service", "databroker").Logger(),
		byType:   make(map[string]*DB),
		onchange: NewSignal(),
	}
	// janitor: ticking at half the retention period bounds how long a
	// purgeable record can linger.
	// NOTE(review): there is no way to stop this goroutine; it runs for the
	// life of the process even if the server is no longer used.
	go func() {
		ticker := time.NewTicker(cfg.deletePermanentlyAfter / 2)
		defer ticker.Stop()
		for range ticker.C {
			// snapshot the record types under the read lock, then clear each
			// db without holding the server lock
			var recordTypes []string
			srv.mu.RLock()
			for recordType := range srv.byType {
				recordTypes = append(recordTypes, recordType)
			}
			srv.mu.RUnlock()
			for _, recordType := range recordTypes {
				srv.getDB(recordType).ClearDeleted(time.Now().Add(-cfg.deletePermanentlyAfter))
			}
		}
	}()
	return srv
}
// Delete deletes a record from the in-memory list. The deletion is logged and
// any blocked sync streams are woken once it completes.
func (srv *Server) Delete(ctx context.Context, req *databroker.DeleteRequest) (*empty.Empty, error) {
	recordType, id := req.GetType(), req.GetId()
	srv.log.Info().
		Str("type", recordType).
		Str("id", id).
		Msg("delete")

	defer srv.onchange.Broadcast()
	srv.getDB(recordType).Delete(id)
	return new(empty.Empty), nil
}
// Get gets a record from the in-memory list. It returns a gRPC NotFound
// error when no record has the requested id.
func (srv *Server) Get(ctx context.Context, req *databroker.GetRequest) (*databroker.GetResponse, error) {
	recordType, id := req.GetType(), req.GetId()
	srv.log.Info().
		Str("type", recordType).
		Str("id", id).
		Msg("get")

	record := srv.getDB(recordType).Get(id)
	if record == nil {
		return nil, status.Error(codes.NotFound, "record not found")
	}
	return &databroker.GetResponse{Record: record}, nil
}
// GetAll gets all the records from the in-memory list, along with the server
// version and the highest record version seen.
func (srv *Server) GetAll(ctx context.Context, req *databroker.GetAllRequest) (*databroker.GetAllResponse, error) {
	recordType := req.GetType()
	srv.log.Info().
		Str("type", recordType).
		Msg("get all")

	records := srv.getDB(recordType).GetAll()

	// the response record version is the maximum version across all records
	var recordVersion string
	for _, r := range records {
		if v := r.GetVersion(); v > recordVersion {
			recordVersion = v
		}
	}

	return &databroker.GetAllResponse{
		ServerVersion: srv.version,
		RecordVersion: recordVersion,
		Records:       records,
	}, nil
}
// Set updates a record in the in-memory list, or adds a new one. The stored
// record (with its freshly assigned version) is returned, and any blocked
// sync streams are woken once the handler returns.
func (srv *Server) Set(ctx context.Context, req *databroker.SetRequest) (*databroker.SetResponse, error) {
	recordType, id := req.GetType(), req.GetId()
	srv.log.Info().
		Str("type", recordType).
		Str("id", id).
		Msg("set")

	defer srv.onchange.Broadcast()
	db := srv.getDB(recordType)
	db.Set(id, req.GetData())
	return &databroker.SetResponse{
		Record:        db.Get(id),
		ServerVersion: srv.version,
	}, nil
}
// Sync streams updates for the given record type. Each loop iteration sends
// every record newer than the client's record version, then blocks until a
// change broadcast (or stream cancellation) wakes it.
func (srv *Server) Sync(req *databroker.SyncRequest, stream databroker.DataBrokerService_SyncServer) error {
	srv.log.Info().
		Str("type", req.GetType()).
		Str("server_version", req.GetServerVersion()).
		Str("record_version", req.GetRecordVersion()).
		Msg("sync")

	recordVersion := req.GetRecordVersion()
	// reset record version if the server versions don't match: record
	// versions are only meaningful for the server instance that issued them
	if req.GetServerVersion() != srv.version {
		recordVersion = ""
	}
	db := srv.getDB(req.GetType())
	ch := srv.onchange.Bind()
	defer srv.onchange.Unbind(ch)
	for {
		updated := db.List(recordVersion)
		if len(updated) > 0 {
			// send in version order and remember the highest version so the
			// next iteration only picks up newer changes
			sort.Slice(updated, func(i, j int) bool {
				return updated[i].Version < updated[j].Version
			})
			recordVersion = updated[len(updated)-1].Version
			err := stream.Send(&databroker.SyncResponse{
				ServerVersion: srv.version,
				Records:       updated,
			})
			if err != nil {
				return err
			}
		}
		// wait for cancellation or the next change notification
		select {
		case <-stream.Context().Done():
			return stream.Context().Err()
		case <-ch:
		}
	}
}
// GetTypes returns all the known record types, sorted alphabetically.
func (srv *Server) GetTypes(_ context.Context, _ *emptypb.Empty) (*databroker.GetTypesResponse, error) {
	var recordTypes []string

	srv.mu.RLock()
	for recordType := range srv.byType {
		recordTypes = append(recordTypes, recordType)
	}
	srv.mu.RUnlock()

	// map iteration order is random, so sort for a stable response
	sort.Strings(recordTypes)

	return &databroker.GetTypesResponse{
		Types: recordTypes,
	}, nil
}
// SyncTypes synchronizes all the known record types: the stream sends the
// current sorted type list once initially, then again whenever it changes,
// until the stream's context is canceled.
func (srv *Server) SyncTypes(req *emptypb.Empty, stream databroker.DataBrokerService_SyncTypesServer) error {
	srv.log.Info().
		Msg("sync types")

	ch := srv.onchange.Bind()
	defer srv.onchange.Unbind(ch)

	var prev []string
	for {
		res, err := srv.GetTypes(stream.Context(), req)
		if err != nil {
			return err
		}
		// only send when the type list actually changed; prev == nil forces
		// the initial send
		if prev == nil || !reflect.DeepEqual(prev, res.Types) {
			err := stream.Send(res)
			if err != nil {
				return err
			}
			prev = res.Types
		}
		// wait for cancellation or the next change notification
		select {
		case <-stream.Context().Done():
			return stream.Context().Err()
		case <-ch:
		}
	}
}
// getDB returns the DB for the given record type, lazily creating it on
// first use.
func (srv *Server) getDB(recordType string) *DB {
	// double-checked locking:
	// first try the read lock, then re-try with the write lock, and finally create a new db if nil
	srv.mu.RLock()
	db := srv.byType[recordType]
	srv.mu.RUnlock()
	if db == nil {
		srv.mu.Lock()
		// re-check: another goroutine may have created the db while we were
		// upgrading from the read lock to the write lock
		db = srv.byType[recordType]
		if db == nil {
			db = NewDB(recordType, srv.cfg.btreeDegree)
			srv.byType[recordType] = db
		}
		srv.mu.Unlock()
	}
	return db
}

View file

@ -0,0 +1,45 @@
package memory
import "sync"
// A Signal is used to let multiple listeners know when something happened.
type Signal struct {
	mu  sync.Mutex
	chs map[chan struct{}]struct{}
}

// NewSignal creates a new Signal with no listeners.
func NewSignal() *Signal {
	return &Signal{
		chs: make(map[chan struct{}]struct{}),
	}
}

// Broadcast signals all the listeners. Broadcast never blocks: a listener
// whose 1-slot buffer is already full simply keeps its pending event.
func (s *Signal) Broadcast() {
	s.mu.Lock()
	defer s.mu.Unlock()
	for listener := range s.chs {
		select {
		case listener <- struct{}{}:
		default:
			// listener already has an event queued; dropping this one is
			// safe since a single event means "something changed"
		}
	}
}

// Bind creates a new listening channel bound to the signal. The channel used has a size of 1
// and any given broadcast will signal at least one event, but may signal more than one.
func (s *Signal) Bind() chan struct{} {
	listener := make(chan struct{}, 1)
	s.mu.Lock()
	defer s.mu.Unlock()
	s.chs[listener] = struct{}{}
	return listener
}

// Unbind stops the listening channel bound to the signal.
func (s *Signal) Unbind(ch chan struct{}) {
	s.mu.Lock()
	defer s.mu.Unlock()
	delete(s.chs, ch)
}