mirror of
https://github.com/pomerium/pomerium.git
synced 2025-04-28 09:56:31 +02:00
* storage: add fallback querier * storage: add sync querier * storage: add typed querier * use synced querier
85 lines
2.3 KiB
Go
85 lines
2.3 KiB
Go
package storage
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"strconv"
|
|
|
|
"github.com/google/uuid"
|
|
grpc "google.golang.org/grpc"
|
|
"google.golang.org/protobuf/encoding/protojson"
|
|
"google.golang.org/protobuf/proto"
|
|
timestamppb "google.golang.org/protobuf/types/known/timestamppb"
|
|
|
|
"github.com/pomerium/pomerium/pkg/cryptutil"
|
|
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
|
"github.com/pomerium/pomerium/pkg/protoutil"
|
|
)
|
|
|
|
// staticQuerier is a Querier backed by an in-memory set of records that is
// populated once by NewStaticQuerier.
type staticQuerier struct {
	// records groups the static records by their record type; each value is
	// the collection holding every record of that type.
	records map[string]RecordCollection
}
|
|
|
|
// NewStaticQuerier creates a Querier that returns statically defined protobuf records.
|
|
func NewStaticQuerier(msgs ...proto.Message) Querier {
|
|
getter := &staticQuerier{records: make(map[string]RecordCollection)}
|
|
for _, msg := range msgs {
|
|
record, ok := msg.(*databroker.Record)
|
|
if !ok {
|
|
record = NewStaticRecord(protoutil.NewAny(msg).TypeUrl, msg)
|
|
}
|
|
c, ok := getter.records[record.Type]
|
|
if !ok {
|
|
c = NewRecordCollection()
|
|
getter.records[record.Type] = c
|
|
}
|
|
c.Put(record)
|
|
}
|
|
return getter
|
|
}
|
|
|
|
// NewStaticRecord creates a new databroker Record from a protobuf message.
|
|
func NewStaticRecord(typeURL string, msg proto.Message) *databroker.Record {
|
|
data := protoutil.NewAny(msg)
|
|
record := new(databroker.Record)
|
|
record.ModifiedAt = timestamppb.Now()
|
|
record.Version = cryptutil.NewRandomUInt64()
|
|
record.Id = uuid.New().String()
|
|
record.Data = data
|
|
record.Type = typeURL
|
|
if hasID, ok := msg.(interface{ GetId() string }); ok {
|
|
record.Id = hasID.GetId()
|
|
}
|
|
if hasVersion, ok := msg.(interface{ GetVersion() string }); ok {
|
|
if v, err := strconv.ParseUint(hasVersion.GetVersion(), 10, 64); err == nil {
|
|
record.Version = v
|
|
}
|
|
}
|
|
|
|
var jsonData struct {
|
|
ID string `json:"id"`
|
|
Version string `json:"version"`
|
|
}
|
|
bs, _ := protojson.Marshal(msg)
|
|
_ = json.Unmarshal(bs, &jsonData)
|
|
|
|
if jsonData.ID != "" {
|
|
record.Id = jsonData.ID
|
|
}
|
|
if jsonData.Version != "" {
|
|
if v, err := strconv.ParseUint(jsonData.Version, 10, 64); err == nil {
|
|
record.Version = v
|
|
}
|
|
}
|
|
|
|
return record
|
|
}
|
|
|
|
// InvalidateCache is a no-op: a static querier keeps no cache, so there is
// nothing to drop for any request.
func (*staticQuerier) InvalidateCache(context.Context, *databroker.QueryRequest) {
	// Intentionally empty.
}
|
|
|
|
// Query queries for records.
|
|
func (q *staticQuerier) Query(_ context.Context, req *databroker.QueryRequest, _ ...grpc.CallOption) (*databroker.QueryResponse, error) {
|
|
return QueryRecordCollections(q.records, req)
|
|
}
|
|
|
|
// Stop is a no-op; a static querier holds only an in-memory map and has
// nothing to shut down.
func (*staticQuerier) Stop() {
	// Nothing to release.
}
|