pomerium/authorize/access_tracker.go
2024-10-03 12:59:11 -06:00

167 lines
4.6 KiB
Go

package authorize
import (
"context"
"fmt"
"sync/atomic"
"time"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/types/known/fieldmaskpb"
"google.golang.org/protobuf/types/known/timestamppb"
"github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/sets"
"github.com/pomerium/pomerium/pkg/grpc/databroker"
"github.com/pomerium/pomerium/pkg/grpc/session"
"github.com/pomerium/pomerium/pkg/grpc/user"
)
// Tuning constants for the access tracker.
const (
	// accessTrackerMaxSize is a size bound for the tracker
	// (presumably passed to NewAccessTracker as maxSize; confirm at the call site).
	accessTrackerMaxSize = 1000
	// accessTrackerDebouncePeriod is an interval between submissions
	// (presumably passed to NewAccessTracker as debouncePeriod; confirm at the call site).
	accessTrackerDebouncePeriod = time.Second * 10
	// accessTrackerUpdateTimeout bounds each individual session or service
	// account update call.
	accessTrackerUpdateTimeout = time.Second * 3
)
// An AccessTrackerProvider provides the databroker service client used for
// tracking session and service account accesses.
type AccessTrackerProvider interface {
// GetDataBrokerServiceClient returns the databroker client. The tracker
// calls it once per submission, so the returned client may change over time.
GetDataBrokerServiceClient() databroker.DataBrokerServiceClient
}
// An AccessTracker tracks accesses to sessions and service accounts.
//
// Accesses are queued on buffered channels by the Track* methods and consumed
// by Run, which periodically writes them to the databroker.
type AccessTracker struct {
	// droppedAccesses counts accesses discarded because a channel was full.
	// It is read and written with the sync/atomic functions, so it must stay
	// the first field: only the first word of an allocated struct is
	// guaranteed to be 64-bit aligned on 32-bit platforms.
	droppedAccesses int64

	// provider supplies the databroker client used when submitting updates.
	provider AccessTrackerProvider
	// sessionAccesses queues session IDs awaiting processing by Run.
	sessionAccesses chan string
	// serviceAccountAccesses queues service account IDs awaiting processing by Run.
	serviceAccountAccesses chan string
	// maxSize is the channel buffer size and the size limit of the pending sets.
	maxSize int
	// debouncePeriod is the interval between submissions to the databroker.
	debouncePeriod time.Duration
}
// NewAccessTracker creates a new AccessTracker.
//
// maxSize is used as the buffer capacity of both access channels and is
// stored for later use by Run; debouncePeriod is the interval at which Run
// submits the collected accesses.
func NewAccessTracker(
	provider AccessTrackerProvider,
	maxSize int,
	debouncePeriod time.Duration,
) *AccessTracker {
	tracker := new(AccessTracker)
	tracker.provider = provider
	tracker.sessionAccesses = make(chan string, maxSize)
	tracker.serviceAccountAccesses = make(chan string, maxSize)
	tracker.maxSize = maxSize
	tracker.debouncePeriod = debouncePeriod
	return tracker
}
// Run consumes tracked accesses until ctx is done.
//
// IDs received from the access channels are collected into size-limited sets.
// On every tick of the debounce period the collected IDs are written to the
// databroker. If any update fails, the error is logged and the submission
// stops without clearing the sets, so the same IDs are attempted again on the
// next tick.
func (tracker *AccessTracker) Run(ctx context.Context) {
	flushTicker := time.NewTicker(tracker.debouncePeriod)
	defer flushTicker.Stop()

	pendingSessions := sets.NewSizeLimited[string](tracker.maxSize)
	pendingServiceAccounts := sets.NewSizeLimited[string](tracker.maxSize)

	// flush reports dropped accesses and pushes all pending IDs to the databroker.
	flush := func() {
		// Read and reset the drop counter in one atomic step.
		dropped := atomic.SwapInt64(&tracker.droppedAccesses, 0)
		if dropped > 0 {
			log.Ctx(ctx).Error().
				Int64("dropped", dropped).
				Msg("authorize: failed to track all session accesses")
		}

		client := tracker.provider.GetDataBrokerServiceClient()

		for id := range pendingSessions.Items() {
			if err := tracker.updateSession(ctx, client, id); err != nil {
				log.Ctx(ctx).Error().Err(err).Msg("authorize: error updating session last access timestamp")
				return
			}
		}

		for id := range pendingServiceAccounts.Items() {
			if err := tracker.updateServiceAccount(ctx, client, id); err != nil {
				log.Ctx(ctx).Error().Err(err).Msg("authorize: error updating service account last access timestamp")
				return
			}
		}

		// Everything was submitted; start over with empty sets.
		pendingSessions = sets.NewSizeLimited[string](tracker.maxSize)
		pendingServiceAccounts = sets.NewSizeLimited[string](tracker.maxSize)
	}

	for {
		select {
		case <-ctx.Done():
			return
		case <-flushTicker.C:
			flush()
		case sessionID := <-tracker.sessionAccesses:
			pendingSessions.Insert(sessionID)
		case serviceAccountID := <-tracker.serviceAccountAccesses:
			pendingServiceAccounts.Insert(serviceAccountID)
		}
	}
}
// TrackServiceAccountAccess records an access for the given service account.
//
// It never blocks: when the service account channel is full the access is
// discarded and the dropped-access counter is incremented instead.
func (tracker *AccessTracker) TrackServiceAccountAccess(serviceAccountID string) {
	select {
	case tracker.serviceAccountAccesses <- serviceAccountID:
		return
	default:
	}

	// The channel buffer is full; count the access as dropped.
	atomic.AddInt64(&tracker.droppedAccesses, 1)
}
// TrackSessionAccess records an access for the given session.
//
// It never blocks: when the session channel is full the access is discarded
// and the dropped-access counter is incremented instead.
func (tracker *AccessTracker) TrackSessionAccess(sessionID string) {
	select {
	case tracker.sessionAccesses <- sessionID:
		return
	default:
	}

	// The channel buffer is full; count the access as dropped.
	atomic.AddInt64(&tracker.droppedAccesses, 1)
}
// updateServiceAccount sets the AccessedAt timestamp of the given service
// account to the current time and stores it via the databroker client.
//
// The whole operation is bounded by accessTrackerUpdateTimeout. A service
// account that cannot be found is not treated as an error.
func (tracker *AccessTracker) updateServiceAccount(
	ctx context.Context,
	client databroker.DataBrokerServiceClient,
	serviceAccountID string,
) error {
	ctx, cancel := context.WithTimeout(ctx, accessTrackerUpdateTimeout)
	defer cancel()

	account, err := user.GetServiceAccount(ctx, client, serviceAccountID)
	switch {
	case status.Code(err) == codes.NotFound:
		// Nothing to update for a service account that does not exist.
		return nil
	case err != nil:
		return err
	}

	account.AccessedAt = timestamppb.Now()
	if _, err := user.PutServiceAccount(ctx, client, account); err != nil {
		return err
	}
	return nil
}
// updateSession patches the AccessedAt timestamp of the given session to the
// current time via the databroker client.
//
// Only the "accessed_at" field is included in the patch's field mask. The
// call is bounded by accessTrackerUpdateTimeout.
func (tracker *AccessTracker) updateSession(
	ctx context.Context,
	client databroker.DataBrokerServiceClient,
	sessionID string,
) error {
	ctx, cancel := context.WithTimeout(ctx, accessTrackerUpdateTimeout)
	defer cancel()

	patch := &session.Session{
		Id:         sessionID,
		AccessedAt: timestamppb.Now(),
	}

	mask, err := fieldmaskpb.New(patch, "accessed_at")
	if err != nil {
		return fmt.Errorf("internal error: %w", err)
	}

	if _, err := session.Patch(ctx, client, patch, mask); err != nil {
		return err
	}
	return nil
}