mirror of
https://github.com/pomerium/pomerium.git
synced 2025-06-03 03:12:50 +02:00
proxy: use querier cache for user info (#5532)
This commit is contained in:
parent
08623ef346
commit
bc263e3ee5
12 changed files with 259 additions and 156 deletions
|
@ -15,6 +15,7 @@ import (
|
|||
|
||||
"github.com/pomerium/pomerium/pkg/cryptutil"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||
"github.com/pomerium/pomerium/pkg/grpcutil"
|
||||
"github.com/pomerium/pomerium/pkg/protoutil"
|
||||
)
|
||||
|
||||
|
@ -222,3 +223,114 @@ func MarshalQueryResponse(res *databroker.QueryResponse) ([]byte, error) {
|
|||
Deterministic: true,
|
||||
}).Marshal(res)
|
||||
}
|
||||
|
||||
// GetDataBrokerRecord uses a querier to get a databroker record.
|
||||
func GetDataBrokerRecord(
|
||||
ctx context.Context,
|
||||
recordType string,
|
||||
recordID string,
|
||||
lowestRecordVersion uint64,
|
||||
) (*databroker.Record, error) {
|
||||
q := GetQuerier(ctx)
|
||||
|
||||
req := &databroker.QueryRequest{
|
||||
Type: recordType,
|
||||
Limit: 1,
|
||||
}
|
||||
req.SetFilterByIDOrIndex(recordID)
|
||||
|
||||
res, err := q.Query(ctx, req, grpc.WaitForReady(true))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(res.GetRecords()) == 0 {
|
||||
return nil, ErrNotFound
|
||||
}
|
||||
|
||||
// if the current record version is less than the lowest we'll accept, invalidate the cache
|
||||
if res.GetRecords()[0].GetVersion() < lowestRecordVersion {
|
||||
q.InvalidateCache(ctx, req)
|
||||
} else {
|
||||
return res.GetRecords()[0], nil
|
||||
}
|
||||
|
||||
// retry with an up to date cache
|
||||
res, err = q.Query(ctx, req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(res.GetRecords()) == 0 {
|
||||
return nil, ErrNotFound
|
||||
}
|
||||
|
||||
return res.GetRecords()[0], nil
|
||||
}
|
||||
|
||||
// GetDataBrokerMessage gets a databroker record and converts it into the message type.
|
||||
func GetDataBrokerMessage[T any, TMessage interface {
|
||||
*T
|
||||
proto.Message
|
||||
}](
|
||||
ctx context.Context,
|
||||
recordID string,
|
||||
lowestRecordVersion uint64,
|
||||
) (TMessage, error) {
|
||||
var msg T
|
||||
|
||||
record, err := GetDataBrokerRecord(ctx, grpcutil.GetTypeURL(TMessage(&msg)), recordID, lowestRecordVersion)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = record.GetData().UnmarshalTo(TMessage(&msg))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return TMessage(&msg), nil
|
||||
}
|
||||
|
||||
// GetDataBrokerObjectViaJSON gets a databroker record and converts it into the object type by going through protojson.
|
||||
func GetDataBrokerObjectViaJSON[T any](
|
||||
ctx context.Context,
|
||||
recordType string,
|
||||
recordID string,
|
||||
lowestRecordVersion uint64,
|
||||
) (*T, error) {
|
||||
record, err := GetDataBrokerRecord(ctx, recordType, recordID, lowestRecordVersion)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
msg, err := record.GetData().UnmarshalNew()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
bs, err := protojson.Marshal(msg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var obj T
|
||||
err = json.Unmarshal(bs, &obj)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &obj, nil
|
||||
}
|
||||
|
||||
// InvalidateCacheForDataBrokerRecords invalidates the cache of the querier for the databroker records.
|
||||
func InvalidateCacheForDataBrokerRecords(
|
||||
ctx context.Context,
|
||||
records ...*databroker.Record,
|
||||
) {
|
||||
for _, record := range records {
|
||||
q := &databroker.QueryRequest{
|
||||
Type: record.GetType(),
|
||||
Limit: 1,
|
||||
}
|
||||
q.SetFilterByIDOrIndex(record.GetId())
|
||||
GetQuerier(ctx).InvalidateCache(ctx, q)
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue