mirror of
https://github.com/pomerium/pomerium.git
synced 2025-04-29 10:26:29 +02:00
* storage: add fallback querier * storage: add sync querier * storage: add typed querier * use synced querier
164 lines
4.1 KiB
Go
164 lines
4.1 KiB
Go
package storage
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"errors"
|
|
|
|
grpc "google.golang.org/grpc"
|
|
"google.golang.org/grpc/codes"
|
|
status "google.golang.org/grpc/status"
|
|
"google.golang.org/protobuf/encoding/protojson"
|
|
"google.golang.org/protobuf/proto"
|
|
|
|
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
|
"github.com/pomerium/pomerium/pkg/grpcutil"
|
|
)
|
|
|
|
var (
	// ErrUnavailable reports that no usable Querier is available. The
	// fallback querier returned by GetQuerier joins it with a gRPC NotFound
	// status error, so callers should match it with errors.Is rather than ==.
	ErrUnavailable = errors.New("unavailable")
)
|
|
|
|
// A Querier is a read-only subset of the client methods
|
|
type Querier interface {
|
|
InvalidateCache(ctx context.Context, in *databroker.QueryRequest)
|
|
Query(ctx context.Context, in *databroker.QueryRequest, opts ...grpc.CallOption) (*databroker.QueryResponse, error)
|
|
Stop()
|
|
}
|
|
|
|
// nilQuerier always returns NotFound.
|
|
type nilQuerier struct{}
|
|
|
|
func (nilQuerier) InvalidateCache(_ context.Context, _ *databroker.QueryRequest) {}
|
|
|
|
func (nilQuerier) Query(_ context.Context, _ *databroker.QueryRequest, _ ...grpc.CallOption) (*databroker.QueryResponse, error) {
|
|
return nil, errors.Join(ErrUnavailable, status.Error(codes.NotFound, "not found"))
|
|
}
|
|
|
|
func (nilQuerier) Stop() {}
|
|
|
|
type querierKey struct{}
|
|
|
|
// GetQuerier gets the databroker Querier from the context.
|
|
func GetQuerier(ctx context.Context) Querier {
|
|
q, ok := ctx.Value(querierKey{}).(Querier)
|
|
if !ok {
|
|
q = nilQuerier{}
|
|
}
|
|
return q
|
|
}
|
|
|
|
// WithQuerier sets the databroker Querier on a context.
|
|
func WithQuerier(ctx context.Context, querier Querier) context.Context {
|
|
return context.WithValue(ctx, querierKey{}, querier)
|
|
}
|
|
|
|
// MarshalQueryRequest marshales the query request.
|
|
func MarshalQueryRequest(req *databroker.QueryRequest) ([]byte, error) {
|
|
return (&proto.MarshalOptions{
|
|
Deterministic: true,
|
|
}).Marshal(req)
|
|
}
|
|
|
|
// MarshalQueryResponse marshals the query response.
|
|
func MarshalQueryResponse(res *databroker.QueryResponse) ([]byte, error) {
|
|
return (&proto.MarshalOptions{
|
|
Deterministic: true,
|
|
}).Marshal(res)
|
|
}
|
|
|
|
// GetDataBrokerRecord uses a querier to get a databroker record.
|
|
func GetDataBrokerRecord(
|
|
ctx context.Context,
|
|
recordType string,
|
|
recordID string,
|
|
lowestRecordVersion uint64,
|
|
) (*databroker.Record, error) {
|
|
q := GetQuerier(ctx)
|
|
|
|
req := &databroker.QueryRequest{
|
|
Type: recordType,
|
|
Limit: 1,
|
|
}
|
|
if lowestRecordVersion > 0 {
|
|
req.MinimumRecordVersionHint = proto.Uint64(lowestRecordVersion)
|
|
}
|
|
req.SetFilterByIDOrIndex(recordID)
|
|
|
|
res, err := q.Query(ctx, req, grpc.WaitForReady(true))
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if len(res.GetRecords()) == 0 {
|
|
return nil, ErrNotFound
|
|
}
|
|
return res.GetRecords()[0], nil
|
|
}
|
|
|
|
// GetDataBrokerMessage gets a databroker record and converts it into the message type.
|
|
func GetDataBrokerMessage[T any, TMessage interface {
|
|
*T
|
|
proto.Message
|
|
}](
|
|
ctx context.Context,
|
|
recordID string,
|
|
lowestRecordVersion uint64,
|
|
) (TMessage, error) {
|
|
var msg T
|
|
|
|
record, err := GetDataBrokerRecord(ctx, grpcutil.GetTypeURL(TMessage(&msg)), recordID, lowestRecordVersion)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
err = record.GetData().UnmarshalTo(TMessage(&msg))
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return TMessage(&msg), nil
|
|
}
|
|
|
|
// GetDataBrokerObjectViaJSON gets a databroker record and converts it into the object type by going through protojson.
|
|
func GetDataBrokerObjectViaJSON[T any](
|
|
ctx context.Context,
|
|
recordType string,
|
|
recordID string,
|
|
lowestRecordVersion uint64,
|
|
) (*T, error) {
|
|
record, err := GetDataBrokerRecord(ctx, recordType, recordID, lowestRecordVersion)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
msg, err := record.GetData().UnmarshalNew()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
bs, err := protojson.Marshal(msg)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
var obj T
|
|
err = json.Unmarshal(bs, &obj)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return &obj, nil
|
|
}
|
|
|
|
// InvalidateCacheForDataBrokerRecords invalidates the cache of the querier for the databroker records.
|
|
func InvalidateCacheForDataBrokerRecords(
|
|
ctx context.Context,
|
|
records ...*databroker.Record,
|
|
) {
|
|
for _, record := range records {
|
|
q := &databroker.QueryRequest{
|
|
Type: record.GetType(),
|
|
Limit: 1,
|
|
}
|
|
q.SetFilterByIDOrIndex(record.GetId())
|
|
GetQuerier(ctx).InvalidateCache(ctx, q)
|
|
}
|
|
}
|