authenticate: always update user record on login (#2719)

* authenticate: always update user record on login

* identity: fix user refresh

* add test for manager update

* fix time
This commit is contained in:
Caleb Doxsey 2021-11-01 14:18:18 -06:00 committed by GitHub
parent 90f2b00bb6
commit b0f8c055ec
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 91 additions and 24 deletions

View file

@ -539,23 +539,22 @@ func (a *Authenticate) saveSessionToDataBroker(
s.SetRawIDToken(claims.RawIDToken)
s.AddClaims(claims.Flatten())
// if no user exists yet, create a new one
currentUser, _ := user.Get(ctx, state.dataBrokerClient, s.GetUserId())
if currentUser == nil {
mu := manager.User{
User: &user.User{
Id: s.GetUserId(),
},
}
err := a.provider.Load().UpdateUserInfo(ctx, accessToken, &mu)
if err != nil {
return fmt.Errorf("authenticate: error retrieving user info: %w", err)
}
_, err = user.Put(ctx, state.dataBrokerClient, mu.User)
if err != nil {
return fmt.Errorf("authenticate: error saving user: %w", err)
var managerUser manager.User
managerUser.User, _ = user.Get(ctx, state.dataBrokerClient, s.GetUserId())
if managerUser.User == nil {
// if no user exists yet, create a new one
managerUser.User = &user.User{
Id: s.GetUserId(),
}
}
err := a.provider.Load().UpdateUserInfo(ctx, accessToken, &managerUser)
if err != nil {
return fmt.Errorf("authenticate: error retrieving user info: %w", err)
}
_, err = user.Put(ctx, state.dataBrokerClient, managerUser.User)
if err != nil {
return fmt.Errorf("authenticate: error saving user: %w", err)
}
res, err := session.Put(ctx, state.dataBrokerClient, s)
if err != nil {

View file

@ -23,6 +23,7 @@ type config struct {
groupRefreshTimeout time.Duration
sessionRefreshGracePeriod time.Duration
sessionRefreshCoolOffDuration time.Duration
now func() time.Time
}
func newConfig(options ...Option) *config {
@ -31,6 +32,7 @@ func newConfig(options ...Option) *config {
WithGroupRefreshTimeout(defaultGroupRefreshTimeout)(cfg)
WithSessionRefreshGracePeriod(defaultSessionRefreshGracePeriod)(cfg)
WithSessionRefreshCoolOffDuration(defaultSessionRefreshCoolOffDuration)(cfg)
WithNow(time.Now)(cfg)
for _, option := range options {
option(cfg)
}
@ -89,6 +91,13 @@ func WithSessionRefreshCoolOffDuration(dur time.Duration) Option {
}
}
// WithNow customizes the time.Now function used by the manager.
func WithNow(now func() time.Time) Option {
return func(cfg *config) {
cfg.now = now
}
}
type atomicConfig struct {
value atomic.Value
}

View file

@ -77,6 +77,7 @@ func New(
}
mgr.directoryBackoff = backoff.NewExponentialBackOff()
mgr.directoryBackoff.MaxElapsedTime = 0
mgr.reset()
mgr.UpdateConfig(options...)
return mgr
}
@ -128,10 +129,7 @@ func (mgr *Manager) refreshLoop(ctx context.Context, update <-chan updateRecords
case <-ctx.Done():
return ctx.Err()
case <-clear:
mgr.directoryGroups = make(map[string]*directory.Group)
mgr.directoryUsers = make(map[string]*directory.User)
mgr.sessions = sessionCollection{BTree: btree.New(8)}
mgr.users = userCollection{BTree: btree.New(8)}
mgr.reset()
}
select {
case <-ctx.Done():
@ -161,10 +159,7 @@ func (mgr *Manager) refreshLoop(ctx context.Context, update <-chan updateRecords
case <-ctx.Done():
return ctx.Err()
case <-clear:
mgr.directoryGroups = make(map[string]*directory.Group)
mgr.directoryUsers = make(map[string]*directory.User)
mgr.sessions = sessionCollection{BTree: btree.New(8)}
mgr.users = userCollection{BTree: btree.New(8)}
mgr.reset()
case msg := <-update:
mgr.onUpdateRecords(ctx, msg)
case <-timer.C:
@ -571,7 +566,7 @@ func (mgr *Manager) onUpdateUser(_ context.Context, record *databroker.Record, u
}
u, _ := mgr.users.Get(user.GetId())
u.lastRefresh = time.Now()
u.lastRefresh = mgr.cfg.Load().now()
u.refreshInterval = mgr.cfg.Load().groupRefreshInterval
u.User = user
mgr.users.ReplaceOrInsert(u)
@ -595,6 +590,14 @@ func (mgr *Manager) deleteSession(ctx context.Context, pbSession *session.Sessio
}
}
// reset resets all the manager datastructures to their initial state
func (mgr *Manager) reset() {
mgr.directoryGroups = make(map[string]*directory.Group)
mgr.directoryUsers = make(map[string]*directory.User)
mgr.sessions = sessionCollection{BTree: btree.New(8)}
mgr.users = userCollection{BTree: btree.New(8)}
}
func isTemporaryError(err error) bool {
if err == nil {
return false

View file

@ -7,8 +7,13 @@ import (
"time"
"github.com/stretchr/testify/assert"
"google.golang.org/protobuf/proto"
"github.com/pomerium/pomerium/internal/directory"
"github.com/pomerium/pomerium/pkg/grpc/databroker"
"github.com/pomerium/pomerium/pkg/grpc/session"
"github.com/pomerium/pomerium/pkg/grpc/user"
"github.com/pomerium/pomerium/pkg/protoutil"
)
type mockProvider struct {
@ -24,6 +29,43 @@ func (mock mockProvider) UserGroups(ctx context.Context) ([]*directory.Group, []
return mock.userGroups(ctx)
}
func TestManager_onUpdateRecords(t *testing.T) {
ctx, clearTimeout := context.WithTimeout(context.Background(), time.Second*10)
defer clearTimeout()
now := time.Now()
mgr := New(
WithDirectoryProvider(mockProvider{}),
WithGroupRefreshInterval(time.Hour),
WithNow(func() time.Time {
return now
}),
)
mgr.directoryBackoff.RandomizationFactor = 0 // disable randomization for deterministic testing
mgr.onUpdateRecords(ctx, updateRecordsMessage{
records: []*databroker.Record{
mkRecord(&directory.Group{Id: "group1", Name: "group 1", Email: "group1@example.com"}),
mkRecord(&directory.User{Id: "user1", DisplayName: "user 1", Email: "user1@example.com", GroupIds: []string{"group1s"}}),
mkRecord(&session.Session{Id: "session1", UserId: "user1"}),
mkRecord(&user.User{Id: "user1", Name: "user 1", Email: "user1@example.com"}),
},
})
assert.NotNil(t, mgr.directoryGroups["group1"])
assert.NotNil(t, mgr.directoryUsers["user1"])
if _, ok := mgr.sessions.Get("user1", "session1"); assert.True(t, ok) {
}
if _, ok := mgr.users.Get("user1"); assert.True(t, ok) {
tm, id := mgr.userScheduler.Next()
assert.Equal(t, now.Add(time.Hour), tm)
assert.Equal(t, "user1", id)
}
}
func TestManager_refreshDirectoryUserGroups(t *testing.T) {
ctx, clearTimeout := context.WithTimeout(context.Background(), time.Second*10)
defer clearTimeout()
@ -56,3 +98,17 @@ func TestManager_refreshDirectoryUserGroups(t *testing.T) {
assert.Equal(t, time.Hour, dur3)
})
}
func mkRecord(msg recordable) *databroker.Record {
any := protoutil.NewAny(msg)
return &databroker.Record{
Type: any.GetTypeUrl(),
Id: msg.GetId(),
Data: any,
}
}
type recordable interface {
proto.Message
GetId() string
}