mirror of
https://github.com/pomerium/pomerium.git
synced 2025-06-05 20:32:57 +02:00
fix redirect loop, remove user/session services, remove duplicate deleted_at fields (#1162)
* fix redirect loop, remove user/session services, remove duplicate deleted_at fields * change loop * reuse err variable * wrap errors, use cookie timeout * wrap error, duplicate if
This commit is contained in:
parent
714363fb07
commit
97f85481f8
16 changed files with 288 additions and 918 deletions
|
@ -140,6 +140,9 @@ func (srv *Server) Get(ctx context.Context, req *databroker.GetRequest) (*databr
|
|||
if err != nil {
|
||||
return nil, status.Error(codes.NotFound, "record not found")
|
||||
}
|
||||
if record.DeletedAt != nil {
|
||||
return nil, status.Error(codes.NotFound, "record not found")
|
||||
}
|
||||
return &databroker.GetResponse{Record: record}, nil
|
||||
}
|
||||
|
||||
|
@ -155,16 +158,27 @@ func (srv *Server) GetAll(ctx context.Context, req *databroker.GetAllRequest) (*
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
records, err := db.GetAll(ctx)
|
||||
|
||||
all, err := db.GetAll(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(all) == 0 {
|
||||
return &databroker.GetAllResponse{ServerVersion: srv.version}, nil
|
||||
}
|
||||
|
||||
var recordVersion string
|
||||
for _, record := range records {
|
||||
records := make([]*databroker.Record, 0, len(all))
|
||||
for _, record := range all {
|
||||
if record.GetVersion() > recordVersion {
|
||||
recordVersion = record.GetVersion()
|
||||
}
|
||||
if record.DeletedAt == nil {
|
||||
records = append(records, record)
|
||||
}
|
||||
}
|
||||
|
||||
return &databroker.GetAllResponse{
|
||||
ServerVersion: srv.version,
|
||||
RecordVersion: recordVersion,
|
||||
|
|
|
@ -8,10 +8,14 @@ import (
|
|||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
"google.golang.org/protobuf/types/known/anypb"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/log"
|
||||
"github.com/pomerium/pomerium/internal/signal"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/session"
|
||||
"github.com/pomerium/pomerium/pkg/storage"
|
||||
)
|
||||
|
||||
|
@ -80,3 +84,58 @@ func TestServer_initVersion(t *testing.T) {
|
|||
assert.Equal(t, srvVersion, srv.version)
|
||||
})
|
||||
}
|
||||
|
||||
func TestServer_Get(t *testing.T) {
|
||||
cfg := newServerConfig()
|
||||
t.Run("ignore deleted", func(t *testing.T) {
|
||||
srv := newServer(cfg)
|
||||
|
||||
s := new(session.Session)
|
||||
s.Id = "1"
|
||||
any, err := anypb.New(s)
|
||||
assert.NoError(t, err)
|
||||
|
||||
srv.Set(context.Background(), &databroker.SetRequest{
|
||||
Type: any.TypeUrl,
|
||||
Id: s.Id,
|
||||
Data: any,
|
||||
})
|
||||
srv.Delete(context.Background(), &databroker.DeleteRequest{
|
||||
Type: any.TypeUrl,
|
||||
Id: s.Id,
|
||||
})
|
||||
_, err = srv.Get(context.Background(), &databroker.GetRequest{
|
||||
Type: any.TypeUrl,
|
||||
Id: s.Id,
|
||||
})
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, codes.NotFound, status.Code(err))
|
||||
})
|
||||
}
|
||||
|
||||
func TestServer_GetAll(t *testing.T) {
|
||||
cfg := newServerConfig()
|
||||
t.Run("ignore deleted", func(t *testing.T) {
|
||||
srv := newServer(cfg)
|
||||
|
||||
s := new(session.Session)
|
||||
s.Id = "1"
|
||||
any, err := anypb.New(s)
|
||||
assert.NoError(t, err)
|
||||
|
||||
srv.Set(context.Background(), &databroker.SetRequest{
|
||||
Type: any.TypeUrl,
|
||||
Id: s.Id,
|
||||
Data: any,
|
||||
})
|
||||
srv.Delete(context.Background(), &databroker.DeleteRequest{
|
||||
Type: any.TypeUrl,
|
||||
Id: s.Id,
|
||||
})
|
||||
res, err := srv.GetAll(context.Background(), &databroker.GetAllRequest{
|
||||
Type: any.TypeUrl,
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, res.GetRecords(), 0)
|
||||
})
|
||||
}
|
||||
|
|
|
@ -29,13 +29,22 @@ type Authenticator interface {
|
|||
UpdateUserInfo(ctx context.Context, t *oauth2.Token, v interface{}) error
|
||||
}
|
||||
|
||||
type (
|
||||
sessionMessage struct {
|
||||
record *databroker.Record
|
||||
session *session.Session
|
||||
}
|
||||
userMessage struct {
|
||||
record *databroker.Record
|
||||
user *user.User
|
||||
}
|
||||
)
|
||||
|
||||
// A Manager refreshes identity information using session and user data.
|
||||
type Manager struct {
|
||||
cfg *config
|
||||
authenticator Authenticator
|
||||
directory directory.Provider
|
||||
sessionClient session.SessionServiceClient
|
||||
userClient user.UserServiceClient
|
||||
dataBrokerClient databroker.DataBrokerServiceClient
|
||||
log zerolog.Logger
|
||||
|
||||
|
@ -60,8 +69,6 @@ type Manager struct {
|
|||
func New(
|
||||
authenticator Authenticator,
|
||||
directoryProvider directory.Provider,
|
||||
sessionClient session.SessionServiceClient,
|
||||
userClient user.UserServiceClient,
|
||||
dataBrokerClient databroker.DataBrokerServiceClient,
|
||||
options ...Option,
|
||||
) *Manager {
|
||||
|
@ -69,8 +76,6 @@ func New(
|
|||
cfg: newConfig(options...),
|
||||
authenticator: authenticator,
|
||||
directory: directoryProvider,
|
||||
sessionClient: sessionClient,
|
||||
userClient: userClient,
|
||||
dataBrokerClient: dataBrokerClient,
|
||||
log: log.With().Str("service", "identity_manager").Logger(),
|
||||
|
||||
|
@ -100,12 +105,12 @@ func (mgr *Manager) Run(ctx context.Context) error {
|
|||
|
||||
t, ctx := tomb.WithContext(ctx)
|
||||
|
||||
updatedSession := make(chan *session.Session, 1)
|
||||
updatedSession := make(chan sessionMessage, 1)
|
||||
t.Go(func() error {
|
||||
return mgr.syncSessions(ctx, updatedSession)
|
||||
})
|
||||
|
||||
updatedUser := make(chan *user.User, 1)
|
||||
updatedUser := make(chan userMessage, 1)
|
||||
t.Go(func() error {
|
||||
return mgr.syncUsers(ctx, updatedUser)
|
||||
})
|
||||
|
@ -129,8 +134,8 @@ func (mgr *Manager) Run(ctx context.Context) error {
|
|||
|
||||
func (mgr *Manager) refreshLoop(
|
||||
ctx context.Context,
|
||||
updatedSession <-chan *session.Session,
|
||||
updatedUser <-chan *user.User,
|
||||
updatedSession <-chan sessionMessage,
|
||||
updatedUser <-chan userMessage,
|
||||
updatedDirectoryUser <-chan *directory.User,
|
||||
updatedDirectoryGroup <-chan *directory.Group,
|
||||
) error {
|
||||
|
@ -361,7 +366,7 @@ func (mgr *Manager) refreshSession(ctx context.Context, userID, sessionID string
|
|||
}
|
||||
s.OauthToken = ToOAuthToken(newToken)
|
||||
|
||||
_, err = mgr.sessionClient.Add(ctx, &session.AddRequest{Session: s.Session})
|
||||
res, err := session.Set(ctx, mgr.dataBrokerClient, s.Session)
|
||||
if err != nil {
|
||||
mgr.log.Error().Err(err).
|
||||
Str("user_id", s.GetUserId()).
|
||||
|
@ -370,7 +375,7 @@ func (mgr *Manager) refreshSession(ctx context.Context, userID, sessionID string
|
|||
return
|
||||
}
|
||||
|
||||
mgr.onUpdateSession(ctx, s.Session)
|
||||
mgr.onUpdateSession(ctx, sessionMessage{record: res.GetRecord(), session: s.Session})
|
||||
}
|
||||
|
||||
func (mgr *Manager) refreshUser(ctx context.Context, userID string) {
|
||||
|
@ -412,7 +417,7 @@ func (mgr *Manager) refreshUser(ctx context.Context, userID string) {
|
|||
continue
|
||||
}
|
||||
|
||||
_, err = mgr.userClient.Add(ctx, &user.AddRequest{User: u.User})
|
||||
record, err := user.Set(ctx, mgr.dataBrokerClient, u.User)
|
||||
if err != nil {
|
||||
mgr.log.Error().Err(err).
|
||||
Str("user_id", s.GetUserId()).
|
||||
|
@ -421,11 +426,11 @@ func (mgr *Manager) refreshUser(ctx context.Context, userID string) {
|
|||
continue
|
||||
}
|
||||
|
||||
mgr.onUpdateUser(ctx, u.User)
|
||||
mgr.onUpdateUser(ctx, userMessage{record: record, user: u.User})
|
||||
}
|
||||
}
|
||||
|
||||
func (mgr *Manager) syncSessions(ctx context.Context, ch chan<- *session.Session) error {
|
||||
func (mgr *Manager) syncSessions(ctx context.Context, ch chan<- sessionMessage) error {
|
||||
mgr.log.Info().Msg("syncing sessions")
|
||||
|
||||
any, err := ptypes.MarshalAny(new(session.Session))
|
||||
|
@ -455,13 +460,13 @@ func (mgr *Manager) syncSessions(ctx context.Context, ch chan<- *session.Session
|
|||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case ch <- &pbSession:
|
||||
case ch <- sessionMessage{record: record, session: &pbSession}:
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (mgr *Manager) syncUsers(ctx context.Context, ch chan<- *user.User) error {
|
||||
func (mgr *Manager) syncUsers(ctx context.Context, ch chan<- userMessage) error {
|
||||
mgr.log.Info().Msg("syncing users")
|
||||
|
||||
any, err := ptypes.MarshalAny(new(user.User))
|
||||
|
@ -491,7 +496,7 @@ func (mgr *Manager) syncUsers(ctx context.Context, ch chan<- *user.User) error {
|
|||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case ch <- &pbUser:
|
||||
case ch <- userMessage{record: record, user: &pbUser}:
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -635,44 +640,44 @@ func (mgr *Manager) syncDirectoryGroups(ctx context.Context, ch chan<- *director
|
|||
}
|
||||
}
|
||||
|
||||
func (mgr *Manager) onUpdateSession(ctx context.Context, pbSession *session.Session) {
|
||||
mgr.sessionScheduler.Remove(toSessionSchedulerKey(pbSession.GetUserId(), pbSession.GetId()))
|
||||
func (mgr *Manager) onUpdateSession(ctx context.Context, msg sessionMessage) {
|
||||
mgr.sessionScheduler.Remove(toSessionSchedulerKey(msg.session.GetUserId(), msg.session.GetId()))
|
||||
|
||||
if pbSession.GetDeletedAt() != nil {
|
||||
if msg.record.GetDeletedAt() != nil {
|
||||
// remove from local store
|
||||
mgr.sessions.Delete(pbSession.GetUserId(), pbSession.GetId())
|
||||
mgr.sessions.Delete(msg.session.GetUserId(), msg.session.GetId())
|
||||
return
|
||||
}
|
||||
|
||||
// update session
|
||||
s, _ := mgr.sessions.Get(pbSession.GetUserId(), pbSession.GetId())
|
||||
s, _ := mgr.sessions.Get(msg.session.GetUserId(), msg.session.GetId())
|
||||
s.lastRefresh = time.Now()
|
||||
s.gracePeriod = mgr.cfg.sessionRefreshGracePeriod
|
||||
s.coolOffDuration = mgr.cfg.sessionRefreshCoolOffDuration
|
||||
s.Session = pbSession
|
||||
s.Session = msg.session
|
||||
mgr.sessions.ReplaceOrInsert(s)
|
||||
mgr.sessionScheduler.Add(s.NextRefresh(), toSessionSchedulerKey(pbSession.GetUserId(), pbSession.GetId()))
|
||||
mgr.sessionScheduler.Add(s.NextRefresh(), toSessionSchedulerKey(msg.session.GetUserId(), msg.session.GetId()))
|
||||
|
||||
// create the user if it doesn't exist yet
|
||||
if _, ok := mgr.users.Get(pbSession.GetUserId()); !ok {
|
||||
mgr.createUser(ctx, pbSession)
|
||||
if _, ok := mgr.users.Get(msg.session.GetUserId()); !ok {
|
||||
mgr.createUser(ctx, msg.session)
|
||||
}
|
||||
}
|
||||
|
||||
func (mgr *Manager) onUpdateUser(_ context.Context, pbUser *user.User) {
|
||||
if pbUser.DeletedAt != nil {
|
||||
mgr.users.Delete(pbUser.GetId())
|
||||
mgr.userScheduler.Remove(pbUser.GetId())
|
||||
func (mgr *Manager) onUpdateUser(_ context.Context, msg userMessage) {
|
||||
if msg.record.DeletedAt != nil {
|
||||
mgr.users.Delete(msg.user.GetId())
|
||||
mgr.userScheduler.Remove(msg.user.GetId())
|
||||
return
|
||||
}
|
||||
|
||||
u, ok := mgr.users.Get(pbUser.GetId())
|
||||
u, ok := mgr.users.Get(msg.user.GetId())
|
||||
if ok {
|
||||
// only reset the refresh time if this is an existing user
|
||||
u.lastRefresh = time.Now()
|
||||
}
|
||||
u.refreshInterval = mgr.cfg.groupRefreshInterval
|
||||
u.User = pbUser
|
||||
u.User = msg.user
|
||||
mgr.users.ReplaceOrInsert(u)
|
||||
mgr.userScheduler.Add(u.NextRefresh(), u.GetId())
|
||||
}
|
||||
|
@ -692,7 +697,7 @@ func (mgr *Manager) createUser(ctx context.Context, pbSession *session.Session)
|
|||
},
|
||||
}
|
||||
|
||||
_, err := mgr.userClient.Add(ctx, &user.AddRequest{User: u.User})
|
||||
_, err := user.Set(ctx, mgr.dataBrokerClient, u.User)
|
||||
if err != nil {
|
||||
mgr.log.Error().Err(err).
|
||||
Str("user_id", pbSession.GetUserId()).
|
||||
|
@ -702,8 +707,7 @@ func (mgr *Manager) createUser(ctx context.Context, pbSession *session.Session)
|
|||
}
|
||||
|
||||
func (mgr *Manager) deleteSession(ctx context.Context, pbSession *session.Session) {
|
||||
pbSession.DeletedAt = ptypes.TimestampNow()
|
||||
_, err := mgr.sessionClient.Add(ctx, &session.AddRequest{Session: pbSession})
|
||||
err := session.Delete(ctx, mgr.dataBrokerClient, pbSession.GetId())
|
||||
if err != nil {
|
||||
mgr.log.Error().Err(err).
|
||||
Str("session_id", pbSession.GetId()).
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue