pomerium/authorize/access_tracker_test.go
Caleb Doxsey 1a5b8b606f
core/lint: upgrade golangci-lint, replace interface{} with any (#5099)
* core/lint: upgrade golangci-lint, replace interface{} with any

* regen proto
2024-05-02 14:33:52 -06:00

142 lines
4 KiB
Go

package authorize
import (
"context"
"fmt"
"sync"
"testing"
"time"
"github.com/stretchr/testify/assert"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"github.com/pomerium/pomerium/pkg/grpc/databroker"
"github.com/pomerium/pomerium/pkg/grpc/session"
"github.com/pomerium/pomerium/pkg/grpc/user"
"github.com/pomerium/pomerium/pkg/protoutil"
)
// testAccessTrackerProvider is a test double that hands the access tracker a
// fixed databroker client.
type testAccessTrackerProvider struct {
	// dataBrokerServiceClient is returned as-is by GetDataBrokerServiceClient.
	dataBrokerServiceClient databroker.DataBrokerServiceClient
}

// GetDataBrokerServiceClient returns the client stored on the provider.
func (p *testAccessTrackerProvider) GetDataBrokerServiceClient() databroker.DataBrokerServiceClient {
	client := p.dataBrokerServiceClient
	return client
}
// TestAccessTracker checks that, after many TrackSessionAccess and
// TrackServiceAccountAccess calls, every tracked session and service account
// eventually has a valid AccessedAt timestamp.
//
// The databroker client is an in-memory mock backed by two maps. Get serves
// records out of the maps; Put stores records written back by the tracker.
func TestAccessTracker(t *testing.T) {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	// mu guards sessions and serviceAccounts. They are touched by the mock
	// client (invoked from the tracker's Run goroutine) and by the assertion
	// at the end of the test.
	var mu sync.Mutex
	sessions := map[string]*session.Session{
		"session-0": {
			Id: "session-0",
		},
		"session-1": {
			Id: "session-1",
		},
		"session-2": {
			Id: "session-2",
		},
	}
	serviceAccounts := map[string]*user.ServiceAccount{
		"service-account-0": {
			Id: "service-account-0",
		},
		"service-account-1": {
			Id: "service-account-1",
		},
		"service-account-2": {
			Id: "service-account-2",
		},
	}

	tracker := NewAccessTracker(&testAccessTrackerProvider{
		dataBrokerServiceClient: &mockDataBrokerServiceClient{
			get: func(_ context.Context, in *databroker.GetRequest, _ ...grpc.CallOption) (*databroker.GetResponse, error) {
				mu.Lock()
				defer mu.Unlock()

				switch in.GetType() {
				case "type.googleapis.com/session.Session":
					s, ok := sessions[in.GetId()]
					if !ok {
						return nil, status.Errorf(codes.NotFound, "unknown session")
					}
					return &databroker.GetResponse{
						Record: &databroker.Record{
							Type: in.GetType(),
							Id:   in.GetId(),
							Data: protoutil.NewAny(s),
						},
					}, nil
				case "type.googleapis.com/user.ServiceAccount":
					sa, ok := serviceAccounts[in.GetId()]
					if !ok {
						return nil, status.Errorf(codes.NotFound, "unknown service account")
					}
					return &databroker.GetResponse{
						Record: &databroker.Record{
							Type: in.GetType(),
							Id:   in.GetId(),
							Data: protoutil.NewAny(sa),
						},
					}, nil
				default:
					return nil, status.Errorf(codes.InvalidArgument, "unknown type: %s", in.GetType())
				}
			},
			put: func(_ context.Context, in *databroker.PutRequest, _ ...grpc.CallOption) (*databroker.PutResponse, error) {
				mu.Lock()
				defer mu.Unlock()

				res := new(databroker.PutResponse)
				for _, record := range in.GetRecords() {
					switch record.GetType() {
					case "type.googleapis.com/session.Session":
						// Surface malformed payloads as gRPC errors rather than
						// panicking inside the tracker's goroutine.
						data, err := record.GetData().UnmarshalNew()
						if err != nil {
							return nil, status.Errorf(codes.InvalidArgument, "invalid session data: %v", err)
						}
						s, ok := data.(*session.Session)
						if !ok {
							return nil, status.Errorf(codes.InvalidArgument, "unexpected session data type: %T", data)
						}
						sessions[record.GetId()] = s
						res.Records = append(res.Records, &databroker.Record{
							Type: record.GetType(),
							Id:   record.GetId(),
							Data: protoutil.NewAny(s),
						})
					case "type.googleapis.com/user.ServiceAccount":
						data, err := record.GetData().UnmarshalNew()
						if err != nil {
							return nil, status.Errorf(codes.InvalidArgument, "invalid service account data: %v", err)
						}
						sa, ok := data.(*user.ServiceAccount)
						if !ok {
							return nil, status.Errorf(codes.InvalidArgument, "unexpected service account data type: %T", data)
						}
						serviceAccounts[record.GetId()] = sa
						res.Records = append(res.Records, &databroker.Record{
							Type: record.GetType(),
							Id:   record.GetId(),
							Data: protoutil.NewAny(sa),
						})
					default:
						return nil, status.Errorf(codes.InvalidArgument, "unknown type: %s", record.GetType())
					}
				}
				return res, nil
			},
		},
	}, 200, time.Second)

	// Run stops consuming work once ctx is cancelled by the deferred cancel.
	// NOTE(review): Run is not waited on here; confirm it returns promptly on
	// cancellation if goroutine leaks between tests become a concern.
	go tracker.Run(ctx)

	// Repeated accesses to the same three IDs of each kind.
	for i := 0; i < 100; i++ {
		tracker.TrackSessionAccess(fmt.Sprintf("session-%d", i%3))
	}
	for i := 0; i < 100; i++ {
		tracker.TrackServiceAccountAccess(fmt.Sprintf("service-account-%d", i%3))
	}

	assert.Eventually(t, func() bool {
		mu.Lock()
		defer mu.Unlock()
		return sessions["session-0"].GetAccessedAt().IsValid() &&
			sessions["session-1"].GetAccessedAt().IsValid() &&
			sessions["session-2"].GetAccessedAt().IsValid() &&
			serviceAccounts["service-account-0"].GetAccessedAt().IsValid() &&
			serviceAccounts["service-account-1"].GetAccessedAt().IsValid() &&
			serviceAccounts["service-account-2"].GetAccessedAt().IsValid()
	}, time.Second*10, time.Millisecond*100)
}