mirror of
https://github.com/pomerium/pomerium.git
synced 2025-04-28 18:06:34 +02:00
142 lines
4 KiB
Go
142 lines
4 KiB
Go
package authorize
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"sync"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/stretchr/testify/assert"
|
|
"google.golang.org/grpc"
|
|
"google.golang.org/grpc/codes"
|
|
"google.golang.org/grpc/status"
|
|
|
|
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
|
"github.com/pomerium/pomerium/pkg/grpc/session"
|
|
"github.com/pomerium/pomerium/pkg/grpc/user"
|
|
"github.com/pomerium/pomerium/pkg/protoutil"
|
|
)
|
|
|
|
type testAccessTrackerProvider struct {
|
|
dataBrokerServiceClient databroker.DataBrokerServiceClient
|
|
}
|
|
|
|
func (provider *testAccessTrackerProvider) GetDataBrokerServiceClient() databroker.DataBrokerServiceClient {
|
|
return provider.dataBrokerServiceClient
|
|
}
|
|
|
|
func TestAccessTracker(t *testing.T) {
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
defer cancel()
|
|
|
|
var mu sync.Mutex
|
|
sessions := map[string]*session.Session{
|
|
"session-0": {
|
|
Id: "session-0",
|
|
},
|
|
"session-1": {
|
|
Id: "session-1",
|
|
},
|
|
"session-2": {
|
|
Id: "session-2",
|
|
},
|
|
}
|
|
serviceAccounts := map[string]*user.ServiceAccount{
|
|
"service-account-0": {
|
|
Id: "service-account-0",
|
|
},
|
|
"service-account-1": {
|
|
Id: "service-account-1",
|
|
},
|
|
"service-account-2": {
|
|
Id: "service-account-2",
|
|
},
|
|
}
|
|
tracker := NewAccessTracker(&testAccessTrackerProvider{
|
|
dataBrokerServiceClient: &mockDataBrokerServiceClient{
|
|
get: func(_ context.Context, in *databroker.GetRequest, _ ...grpc.CallOption) (*databroker.GetResponse, error) {
|
|
mu.Lock()
|
|
defer mu.Unlock()
|
|
|
|
switch in.GetType() {
|
|
case "type.googleapis.com/session.Session":
|
|
s, ok := sessions[in.GetId()]
|
|
if !ok {
|
|
return nil, status.Errorf(codes.NotFound, "unknown session")
|
|
}
|
|
return &databroker.GetResponse{
|
|
Record: &databroker.Record{
|
|
Type: in.GetType(),
|
|
Id: in.GetId(),
|
|
Data: protoutil.NewAny(s),
|
|
},
|
|
}, nil
|
|
case "type.googleapis.com/user.ServiceAccount":
|
|
sa, ok := serviceAccounts[in.GetId()]
|
|
if !ok {
|
|
return nil, status.Errorf(codes.NotFound, "unknown service account")
|
|
}
|
|
return &databroker.GetResponse{
|
|
Record: &databroker.Record{
|
|
Type: in.GetType(),
|
|
Id: in.GetId(),
|
|
Data: protoutil.NewAny(sa),
|
|
},
|
|
}, nil
|
|
default:
|
|
return nil, status.Errorf(codes.InvalidArgument, "unknown type: %s", in.GetType())
|
|
}
|
|
},
|
|
put: func(_ context.Context, in *databroker.PutRequest, _ ...grpc.CallOption) (*databroker.PutResponse, error) {
|
|
mu.Lock()
|
|
defer mu.Unlock()
|
|
|
|
res := new(databroker.PutResponse)
|
|
for _, record := range in.GetRecords() {
|
|
switch record.GetType() {
|
|
case "type.googleapis.com/session.Session":
|
|
data, _ := record.GetData().UnmarshalNew()
|
|
sessions[record.GetId()] = data.(*session.Session)
|
|
res.Records = append(res.Records, &databroker.Record{
|
|
Type: record.GetType(),
|
|
Id: record.GetId(),
|
|
Data: protoutil.NewAny(data),
|
|
})
|
|
case "type.googleapis.com/user.ServiceAccount":
|
|
data, _ := record.GetData().UnmarshalNew()
|
|
serviceAccounts[record.GetId()] = data.(*user.ServiceAccount)
|
|
res.Records = append(res.Records, &databroker.Record{
|
|
Type: record.GetType(),
|
|
Id: record.GetId(),
|
|
Data: protoutil.NewAny(data),
|
|
})
|
|
default:
|
|
return nil, status.Errorf(codes.InvalidArgument, "unknown type: %s", record.GetType())
|
|
}
|
|
}
|
|
return res, nil
|
|
},
|
|
},
|
|
}, 200, time.Second)
|
|
go tracker.Run(ctx)
|
|
|
|
for i := 0; i < 100; i++ {
|
|
tracker.TrackSessionAccess(fmt.Sprintf("session-%d", i%3))
|
|
}
|
|
for i := 0; i < 100; i++ {
|
|
tracker.TrackServiceAccountAccess(fmt.Sprintf("service-account-%d", i%3))
|
|
}
|
|
|
|
assert.Eventually(t, func() bool {
|
|
mu.Lock()
|
|
defer mu.Unlock()
|
|
|
|
return sessions["session-0"].GetAccessedAt().IsValid() &&
|
|
sessions["session-1"].GetAccessedAt().IsValid() &&
|
|
sessions["session-2"].GetAccessedAt().IsValid() &&
|
|
serviceAccounts["service-account-0"].GetAccessedAt().IsValid() &&
|
|
serviceAccounts["service-account-1"].GetAccessedAt().IsValid() &&
|
|
serviceAccounts["service-account-2"].GetAccessedAt().IsValid()
|
|
}, time.Second*10, time.Millisecond*100)
|
|
}
|