pomerium/authorize/access_tracker.go
Caleb Doxsey 36f73fa6c7
authorize: track session and service account access date (#3220)
* session: add accessed at date

* authorize: track session and service account access times

* Revert "databroker: add support for field masks on Put (#3210)"

This reverts commit 2dc778035d.

* add test

* fix data race in test

* add deadline for update

* track dropped accesses
2022-03-31 09:19:04 -06:00

170 lines
4.6 KiB
Go

package authorize
import (
"context"
"sync/atomic"
"time"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/types/known/timestamppb"
"github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/sets"
"github.com/pomerium/pomerium/pkg/grpc/databroker"
"github.com/pomerium/pomerium/pkg/grpc/session"
"github.com/pomerium/pomerium/pkg/grpc/user"
)
// Tuning values for the access tracker.
const (
	// accessTrackerMaxSize is a size bound for the tracker.
	// NOTE(review): presumably passed as maxSize to NewAccessTracker by the
	// caller; it is not referenced elsewhere in this file - confirm.
	accessTrackerMaxSize = 1000

	// accessTrackerDebouncePeriod is a flush interval for the tracker.
	// NOTE(review): presumably passed as debouncePeriod to NewAccessTracker by
	// the caller; it is not referenced elsewhere in this file - confirm.
	accessTrackerDebouncePeriod = 10 * time.Second

	// accessTrackerUpdateTimeout bounds each individual session or service
	// account update (the get followed by the put).
	accessTrackerUpdateTimeout = 3 * time.Second
)
// An AccessTrackerProvider provides the databroker service client used for
// tracking session and service account access.
type AccessTrackerProvider interface {
// GetDataBrokerServiceClient returns the databroker client to use for
// updates. The tracker calls it once per submission rather than caching
// the result.
GetDataBrokerServiceClient() databroker.DataBrokerServiceClient
}
// An AccessTracker tracks accesses to sessions and service accounts.
//
// Accesses are queued on buffered channels by TrackSessionAccess and
// TrackServiceAccountAccess and drained by Run, which periodically writes the
// access time of each queued ID through the provider's databroker client.
type AccessTracker struct {
	// droppedAccesses counts accesses discarded because a queue was full.
	//
	// It is read and written with the sync/atomic functions, so it must stay
	// the first field of the struct: on 32-bit platforms (386, ARM, 32-bit
	// MIPS) only the first word of an allocated struct is guaranteed to be
	// 64-bit aligned, and a 64-bit atomic operation on a misaligned address
	// panics.
	droppedAccesses int64

	// provider supplies the databroker client used when submitting updates.
	provider AccessTrackerProvider
	// sessionAccesses queues session IDs waiting to be picked up by Run.
	sessionAccesses chan string
	// serviceAccountAccesses queues service account IDs waiting to be picked
	// up by Run.
	serviceAccountAccesses chan string
	// maxSize is the capacity of each queue and the size passed to the
	// size-limited ID sets created by Run.
	maxSize int
	// debouncePeriod is the interval between submissions in Run.
	debouncePeriod time.Duration
}
// NewAccessTracker creates a new AccessTracker.
//
// provider supplies the databroker client used when updates are submitted.
// maxSize is the buffer capacity of each access queue and is also the size
// handed to the size-limited ID sets used by Run. debouncePeriod is the
// interval at which Run submits the collected accesses.
func NewAccessTracker(
	provider AccessTrackerProvider,
	maxSize int,
	debouncePeriod time.Duration,
) *AccessTracker {
	tracker := new(AccessTracker)
	tracker.provider = provider
	tracker.maxSize = maxSize
	tracker.debouncePeriod = debouncePeriod
	tracker.sessionAccesses = make(chan string, maxSize)
	tracker.serviceAccountAccesses = make(chan string, maxSize)
	return tracker
}
// Run drains the access queues into pending ID sets and, once per debounce
// period, writes the access time of every pending ID through the databroker
// client. It blocks until ctx is done.
func (tracker *AccessTracker) Run(ctx context.Context) {
	flushTicker := time.NewTicker(tracker.debouncePeriod)
	defer flushTicker.Stop()

	pendingSessions := sets.NewSizeLimitedStringSet(tracker.maxSize)
	pendingServiceAccounts := sets.NewSizeLimitedStringSet(tracker.maxSize)

	// flush reports dropped accesses, then submits the pending sessions
	// followed by the pending service accounts. The pending sets are only
	// replaced when every update succeeded; after a failure they are kept so
	// the next tick tries again.
	flush := func() {
		dropped := atomic.SwapInt64(&tracker.droppedAccesses, 0)
		if dropped > 0 {
			log.Error(ctx).
				Int64("dropped", dropped).
				Msg("authorize: failed to track all session accesses")
		}

		client := tracker.provider.GetDataBrokerServiceClient()

		// Returning false from the callback stops the iteration at the first
		// failed update.
		var sessionErr error
		pendingSessions.ForEach(func(id string) bool {
			sessionErr = tracker.updateSession(ctx, client, id)
			return sessionErr == nil
		})
		if sessionErr != nil {
			log.Error(ctx).Err(sessionErr).Msg("authorize: error updating session last access timestamp")
			return
		}

		var serviceAccountErr error
		pendingServiceAccounts.ForEach(func(id string) bool {
			serviceAccountErr = tracker.updateServiceAccount(ctx, client, id)
			return serviceAccountErr == nil
		})
		if serviceAccountErr != nil {
			log.Error(ctx).Err(serviceAccountErr).Msg("authorize: error updating service account last access timestamp")
			return
		}

		pendingSessions = sets.NewSizeLimitedStringSet(tracker.maxSize)
		pendingServiceAccounts = sets.NewSizeLimitedStringSet(tracker.maxSize)
	}

	for {
		select {
		case <-ctx.Done():
			return
		case sessionID := <-tracker.sessionAccesses:
			pendingSessions.Add(sessionID)
		case serviceAccountID := <-tracker.serviceAccountAccesses:
			pendingServiceAccounts.Add(serviceAccountID)
		case <-flushTicker.C:
			flush()
		}
	}
}
// TrackServiceAccountAccess records an access to the given service account.
// It never blocks: when the queue cannot accept the ID immediately, the access
// is discarded and counted as dropped.
func (tracker *AccessTracker) TrackServiceAccountAccess(serviceAccountID string) {
	select {
	case tracker.serviceAccountAccesses <- serviceAccountID:
		return
	default:
	}

	// The queue was not ready; account for the lost access.
	atomic.AddInt64(&tracker.droppedAccesses, 1)
}
// TrackSessionAccess records an access to the given session. It never blocks:
// when the queue cannot accept the ID immediately, the access is discarded and
// counted as dropped.
func (tracker *AccessTracker) TrackSessionAccess(sessionID string) {
	select {
	case tracker.sessionAccesses <- sessionID:
		return
	default:
	}

	// The queue was not ready; account for the lost access.
	atomic.AddInt64(&tracker.droppedAccesses, 1)
}
// updateServiceAccount fetches the service account with the given ID, sets its
// AccessedAt field to the current time and stores it again. A NotFound result
// from the fetch is treated as success with nothing to update. The fetch and
// the store share a single accessTrackerUpdateTimeout deadline.
func (tracker *AccessTracker) updateServiceAccount(
	ctx context.Context,
	client databroker.DataBrokerServiceClient,
	serviceAccountID string,
) error {
	updateCtx, cancel := context.WithTimeout(ctx, accessTrackerUpdateTimeout)
	defer cancel()

	sa, getErr := user.GetServiceAccount(updateCtx, client, serviceAccountID)
	switch {
	case status.Code(getErr) == codes.NotFound:
		// Nothing to update for a missing service account.
		return nil
	case getErr != nil:
		return getErr
	}

	sa.AccessedAt = timestamppb.Now()
	_, putErr := user.PutServiceAccount(updateCtx, client, sa)
	return putErr
}
// updateSession fetches the session with the given ID, sets its AccessedAt
// field to the current time and stores it again. A NotFound result from the
// fetch is treated as success with nothing to update. The fetch and the store
// share a single accessTrackerUpdateTimeout deadline.
func (tracker *AccessTracker) updateSession(
	ctx context.Context,
	client databroker.DataBrokerServiceClient,
	sessionID string,
) error {
	updateCtx, cancel := context.WithTimeout(ctx, accessTrackerUpdateTimeout)
	defer cancel()

	s, getErr := session.Get(updateCtx, client, sessionID)
	switch {
	case status.Code(getErr) == codes.NotFound:
		// Nothing to update for a missing session.
		return nil
	case getErr != nil:
		return getErr
	}

	s.AccessedAt = timestamppb.Now()
	_, putErr := session.Put(updateCtx, client, s)
	return putErr
}