diff --git a/authorize/databroker.go b/authorize/databroker.go index 06ee9219f..4da2f4b5c 100644 --- a/authorize/databroker.go +++ b/authorize/databroker.go @@ -90,12 +90,11 @@ func (a *Authorize) getDataBrokerSessionOrServiceAccount( func (a *Authorize) getDataBrokerUser( ctx context.Context, userID string, - dataBrokerRecordVersion uint64, ) (*user.User, error) { ctx, span := trace.StartSpan(ctx, "authorize.getDataBrokerUser") defer span.End() - record, err := getDataBrokerRecord(ctx, grpcutil.GetTypeURL(new(user.User)), userID, dataBrokerRecordVersion) + record, err := getDataBrokerRecord(ctx, grpcutil.GetTypeURL(new(user.User)), userID, 0) if err != nil { return nil, err } diff --git a/authorize/databroker_test.go b/authorize/databroker_test.go new file mode 100644 index 000000000..d90803d9e --- /dev/null +++ b/authorize/databroker_test.go @@ -0,0 +1,56 @@ +package authorize + +import ( + "context" + "fmt" + "testing" + "time" + + "github.com/stretchr/testify/assert" + + "github.com/pomerium/pomerium/pkg/grpc/session" + "github.com/pomerium/pomerium/pkg/grpcutil" + "github.com/pomerium/pomerium/pkg/storage" +) + +func Test_getDataBrokerRecord(t *testing.T) { + t.Parallel() + + ctx, clearTimeout := context.WithTimeout(context.Background(), time.Second*10) + t.Cleanup(clearTimeout) + + for _, tc := range []struct { + name string + recordVersion, queryVersion uint64 + underlyingQueryCount, cachedQueryCount int + }{ + {"cached", 1, 1, 1, 2}, + {"invalidated", 1, 2, 3, 4}, + } { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + s1 := &session.Session{Id: "s1", Version: fmt.Sprint(tc.recordVersion)} + + sq := storage.NewStaticQuerier(s1) + tsq := storage.NewTracingQuerier(sq) + cq := storage.NewCachingQuerier(tsq, storage.NewLocalCache()) + tcq := storage.NewTracingQuerier(cq) + qctx := storage.WithQuerier(ctx, tcq) + + s, err := getDataBrokerRecord(qctx, grpcutil.GetTypeURL(s1), s1.GetId(), tc.queryVersion) + assert.NoError(t, err) + assert.NotNil(t, s) + + s, err = getDataBrokerRecord(qctx, grpcutil.GetTypeURL(s1), s1.GetId(), tc.queryVersion) + assert.NoError(t, err) + assert.NotNil(t, s) + + assert.Len(t, tsq.Traces(), tc.underlyingQueryCount, + "should have %d traces to the underlying querier", tc.underlyingQueryCount) + assert.Len(t, tcq.Traces(), tc.cachedQueryCount, + "should have %d traces to the cached querier", tc.cachedQueryCount) + }) + } +} diff --git a/authorize/grpc.go b/authorize/grpc.go index 4ea3c66ff..d8c448bfe 100644 --- a/authorize/grpc.go +++ b/authorize/grpc.go @@ -67,7 +67,7 @@ func (a *Authorize) Check(ctx context.Context, in *envoy_service_auth_v3.CheckRe } } if sessionState != nil && s != nil { - u, _ = a.getDataBrokerUser(ctx, s.GetUserId(), sessionState.DatabrokerRecordVersion) // ignore any missing user error + u, _ = a.getDataBrokerUser(ctx, s.GetUserId()) // ignore any missing user error } req, err := a.getEvaluatorRequestFromCheckRequest(in, sessionState)