mirror of
https://github.com/pomerium/pomerium.git
synced 2025-08-02 08:19:23 +02:00
proxy: use querier cache for user info (#5532)
This commit is contained in:
parent
08623ef346
commit
bc263e3ee5
12 changed files with 259 additions and 156 deletions
|
@ -107,3 +107,6 @@ func (cache *globalCache) set(expiry time.Time, key, value []byte) {
|
|||
cache.fastcache.Set(key, item)
|
||||
cache.mu.Unlock()
|
||||
}
|
||||
|
||||
// GlobalCache is a global cache with a TTL of one minute.
|
||||
var GlobalCache = NewGlobalCache(time.Minute)
|
||||
|
|
|
@ -15,6 +15,7 @@ import (
|
|||
|
||||
"github.com/pomerium/pomerium/pkg/cryptutil"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||
"github.com/pomerium/pomerium/pkg/grpcutil"
|
||||
"github.com/pomerium/pomerium/pkg/protoutil"
|
||||
)
|
||||
|
||||
|
@ -222,3 +223,114 @@ func MarshalQueryResponse(res *databroker.QueryResponse) ([]byte, error) {
|
|||
Deterministic: true,
|
||||
}).Marshal(res)
|
||||
}
|
||||
|
||||
// GetDataBrokerRecord uses a querier to get a databroker record.
|
||||
func GetDataBrokerRecord(
|
||||
ctx context.Context,
|
||||
recordType string,
|
||||
recordID string,
|
||||
lowestRecordVersion uint64,
|
||||
) (*databroker.Record, error) {
|
||||
q := GetQuerier(ctx)
|
||||
|
||||
req := &databroker.QueryRequest{
|
||||
Type: recordType,
|
||||
Limit: 1,
|
||||
}
|
||||
req.SetFilterByIDOrIndex(recordID)
|
||||
|
||||
res, err := q.Query(ctx, req, grpc.WaitForReady(true))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(res.GetRecords()) == 0 {
|
||||
return nil, ErrNotFound
|
||||
}
|
||||
|
||||
// if the current record version is less than the lowest we'll accept, invalidate the cache
|
||||
if res.GetRecords()[0].GetVersion() < lowestRecordVersion {
|
||||
q.InvalidateCache(ctx, req)
|
||||
} else {
|
||||
return res.GetRecords()[0], nil
|
||||
}
|
||||
|
||||
// retry with an up to date cache
|
||||
res, err = q.Query(ctx, req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(res.GetRecords()) == 0 {
|
||||
return nil, ErrNotFound
|
||||
}
|
||||
|
||||
return res.GetRecords()[0], nil
|
||||
}
|
||||
|
||||
// GetDataBrokerMessage gets a databroker record and converts it into the message type.
|
||||
func GetDataBrokerMessage[T any, TMessage interface {
|
||||
*T
|
||||
proto.Message
|
||||
}](
|
||||
ctx context.Context,
|
||||
recordID string,
|
||||
lowestRecordVersion uint64,
|
||||
) (TMessage, error) {
|
||||
var msg T
|
||||
|
||||
record, err := GetDataBrokerRecord(ctx, grpcutil.GetTypeURL(TMessage(&msg)), recordID, lowestRecordVersion)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = record.GetData().UnmarshalTo(TMessage(&msg))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return TMessage(&msg), nil
|
||||
}
|
||||
|
||||
// GetDataBrokerObjectViaJSON gets a databroker record and converts it into the object type by going through protojson.
|
||||
func GetDataBrokerObjectViaJSON[T any](
|
||||
ctx context.Context,
|
||||
recordType string,
|
||||
recordID string,
|
||||
lowestRecordVersion uint64,
|
||||
) (*T, error) {
|
||||
record, err := GetDataBrokerRecord(ctx, recordType, recordID, lowestRecordVersion)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
msg, err := record.GetData().UnmarshalNew()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
bs, err := protojson.Marshal(msg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var obj T
|
||||
err = json.Unmarshal(bs, &obj)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &obj, nil
|
||||
}
|
||||
|
||||
// InvalidateCacheForDataBrokerRecords invalidates the cache of the querier for the databroker records.
|
||||
func InvalidateCacheForDataBrokerRecords(
|
||||
ctx context.Context,
|
||||
records ...*databroker.Record,
|
||||
) {
|
||||
for _, record := range records {
|
||||
q := &databroker.QueryRequest{
|
||||
Type: record.GetType(),
|
||||
Limit: 1,
|
||||
}
|
||||
q.SetFilterByIDOrIndex(record.GetId())
|
||||
GetQuerier(ctx).InvalidateCache(ctx, q)
|
||||
}
|
||||
}
|
||||
|
|
101
pkg/storage/querier_test.go
Normal file
101
pkg/storage/querier_test.go
Normal file
|
@ -0,0 +1,101 @@
|
|||
package storage_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"google.golang.org/protobuf/testing/protocmp"
|
||||
"google.golang.org/protobuf/types/known/structpb"
|
||||
|
||||
"github.com/pomerium/datasource/pkg/directory"
|
||||
"github.com/pomerium/pomerium/internal/testutil"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/session"
|
||||
"github.com/pomerium/pomerium/pkg/grpcutil"
|
||||
"github.com/pomerium/pomerium/pkg/storage"
|
||||
)
|
||||
|
||||
func TestGetDataBrokerRecord(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx, clearTimeout := context.WithTimeout(context.Background(), time.Second*10)
|
||||
t.Cleanup(clearTimeout)
|
||||
|
||||
for _, tc := range []struct {
|
||||
name string
|
||||
recordVersion, queryVersion uint64
|
||||
underlyingQueryCount, cachedQueryCount int
|
||||
}{
|
||||
{"cached", 1, 1, 1, 2},
|
||||
{"invalidated", 1, 2, 3, 4},
|
||||
} {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
s1 := &session.Session{Id: "s1", Version: fmt.Sprint(tc.recordVersion)}
|
||||
|
||||
sq := storage.NewStaticQuerier(s1)
|
||||
cq := storage.NewCachingQuerier(sq, storage.NewGlobalCache(time.Minute))
|
||||
qctx := storage.WithQuerier(ctx, cq)
|
||||
|
||||
s, err := storage.GetDataBrokerRecord(qctx, grpcutil.GetTypeURL(s1), s1.GetId(), tc.queryVersion)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, s)
|
||||
|
||||
s, err = storage.GetDataBrokerRecord(qctx, grpcutil.GetTypeURL(s1), s1.GetId(), tc.queryVersion)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, s)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetDataBrokerMessage(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.GetContext(t, time.Minute)
|
||||
|
||||
s1 := &session.Session{Id: "s1"}
|
||||
sq := storage.NewStaticQuerier(s1)
|
||||
cq := storage.NewCachingQuerier(sq, storage.NewGlobalCache(time.Minute))
|
||||
qctx := storage.WithQuerier(ctx, cq)
|
||||
|
||||
s2, err := storage.GetDataBrokerMessage[session.Session](qctx, "s1", 0)
|
||||
assert.NoError(t, err)
|
||||
assert.Empty(t, cmp.Diff(s1, s2, protocmp.Transform()))
|
||||
|
||||
_, err = storage.GetDataBrokerMessage[session.Session](qctx, "s2", 0)
|
||||
assert.ErrorIs(t, err, storage.ErrNotFound)
|
||||
}
|
||||
|
||||
func TestGetDataBrokerObjectViaJSON(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.GetContext(t, time.Minute)
|
||||
|
||||
du1 := &directory.User{
|
||||
ID: "u1",
|
||||
Email: "u1@example.com",
|
||||
DisplayName: "User 1!",
|
||||
}
|
||||
sq := storage.NewStaticQuerier(newDirectoryUserRecord(du1))
|
||||
cq := storage.NewCachingQuerier(sq, storage.NewGlobalCache(time.Minute))
|
||||
qctx := storage.WithQuerier(ctx, cq)
|
||||
|
||||
du2, err := storage.GetDataBrokerObjectViaJSON[directory.User](qctx, directory.UserRecordType, "u1", 0)
|
||||
assert.NoError(t, err)
|
||||
assert.Empty(t, cmp.Diff(du1, du2, protocmp.Transform()))
|
||||
}
|
||||
|
||||
func newDirectoryUserRecord(directoryUser *directory.User) *databroker.Record {
|
||||
m := map[string]any{}
|
||||
bs, _ := json.Marshal(directoryUser)
|
||||
_ = json.Unmarshal(bs, &m)
|
||||
s, _ := structpb.NewStruct(m)
|
||||
return storage.NewStaticRecord(directory.UserRecordType, s)
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue