package authorize import ( "context" "sync/atomic" "time" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" "google.golang.org/protobuf/types/known/timestamppb" "github.com/pomerium/pomerium/internal/log" "github.com/pomerium/pomerium/internal/sets" "github.com/pomerium/pomerium/pkg/grpc/databroker" "github.com/pomerium/pomerium/pkg/grpc/session" "github.com/pomerium/pomerium/pkg/grpc/user" ) const ( accessTrackerMaxSize = 1_000 accessTrackerDebouncePeriod = 10 * time.Second accessTrackerUpdateTimeout = 3 * time.Second ) // A AccessTrackerProvider provides the databroker service client for tracking session access. type AccessTrackerProvider interface { GetDataBrokerServiceClient() databroker.DataBrokerServiceClient } // A AccessTracker tracks accesses to sessions type AccessTracker struct { provider AccessTrackerProvider sessionAccesses chan string serviceAccountAccesses chan string maxSize int debouncePeriod time.Duration droppedAccesses int64 } // NewAccessTracker creates a new SessionAccessTracker. func NewAccessTracker( provider AccessTrackerProvider, maxSize int, debouncePeriod time.Duration, ) *AccessTracker { return &AccessTracker{ provider: provider, sessionAccesses: make(chan string, maxSize), serviceAccountAccesses: make(chan string, maxSize), maxSize: maxSize, debouncePeriod: debouncePeriod, } } // Run runs the access tracker. func (tracker *AccessTracker) Run(ctx context.Context) { ticker := time.NewTicker(tracker.debouncePeriod) defer ticker.Stop() sessionAccesses := sets.NewSizeLimitedStringSet(tracker.maxSize) serviceAccountAccesses := sets.NewSizeLimitedStringSet(tracker.maxSize) runTrackSessionAccess := func(sessionID string) { sessionAccesses.Add(sessionID) } runTrackServiceAccountAccess := func(serviceAccountID string) { serviceAccountAccesses.Add(serviceAccountID) } runSubmit := func() { if dropped := atomic.SwapInt64(&tracker.droppedAccesses, 0); dropped > 0 { log.Error(ctx). Int64("dropped", dropped). Msg("authorize: failed to track all session accesses") } client := tracker.provider.GetDataBrokerServiceClient() var err error sessionAccesses.ForEach(func(sessionID string) bool { err = tracker.updateSession(ctx, client, sessionID) return err == nil }) if err != nil { log.Error(ctx).Err(err).Msg("authorize: error updating session last access timestamp") return } serviceAccountAccesses.ForEach(func(serviceAccountID string) bool { err = tracker.updateServiceAccount(ctx, client, serviceAccountID) return err == nil }) if err != nil { log.Error(ctx).Err(err).Msg("authorize: error updating service account last access timestamp") return } sessionAccesses = sets.NewSizeLimitedStringSet(tracker.maxSize) serviceAccountAccesses = sets.NewSizeLimitedStringSet(tracker.maxSize) } for { select { case <-ctx.Done(): return case id := <-tracker.sessionAccesses: runTrackSessionAccess(id) case id := <-tracker.serviceAccountAccesses: runTrackServiceAccountAccess(id) case <-ticker.C: runSubmit() } } } // TrackServiceAccountAccess tracks a service account access. func (tracker *AccessTracker) TrackServiceAccountAccess(serviceAccountID string) { select { case tracker.serviceAccountAccesses <- serviceAccountID: default: atomic.AddInt64(&tracker.droppedAccesses, 1) } } // TrackSessionAccess tracks a session access. func (tracker *AccessTracker) TrackSessionAccess(sessionID string) { select { case tracker.sessionAccesses <- sessionID: default: atomic.AddInt64(&tracker.droppedAccesses, 1) } } func (tracker *AccessTracker) updateServiceAccount( ctx context.Context, client databroker.DataBrokerServiceClient, serviceAccountID string, ) error { ctx, clearTimeout := context.WithTimeout(ctx, accessTrackerUpdateTimeout) defer clearTimeout() sa, err := user.GetServiceAccount(ctx, client, serviceAccountID) if status.Code(err) == codes.NotFound { return nil } else if err != nil { return err } sa.AccessedAt = timestamppb.Now() _, err = user.PutServiceAccount(ctx, client, sa) return err } func (tracker *AccessTracker) updateSession( ctx context.Context, client databroker.DataBrokerServiceClient, sessionID string, ) error { ctx, clearTimeout := context.WithTimeout(ctx, accessTrackerUpdateTimeout) defer clearTimeout() s, err := session.Get(ctx, client, sessionID) if status.Code(err) == codes.NotFound { return nil } else if err != nil { return err } s.AccessedAt = timestamppb.Now() _, err = session.Put(ctx, client, s) return err }