mirror of
https://github.com/pomerium/pomerium.git
synced 2025-08-03 16:59:22 +02:00
authorize: use query instead of sync for databroker data (#3377)
This commit is contained in:
parent
fd82cc7870
commit
f61e7efe73
24 changed files with 661 additions and 1008 deletions
|
@ -1,196 +0,0 @@
|
|||
package store
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
||||
"github.com/kentik/patricia"
|
||||
"github.com/kentik/patricia/string_tree"
|
||||
"google.golang.org/protobuf/proto"
|
||||
"google.golang.org/protobuf/types/known/structpb"
|
||||
)
|
||||
|
||||
const (
|
||||
indexField = "$index"
|
||||
cidrField = "cidr"
|
||||
)
|
||||
|
||||
type index struct {
|
||||
mu sync.RWMutex
|
||||
byType map[string]*recordIndex
|
||||
}
|
||||
|
||||
func newIndex() *index {
|
||||
idx := new(index)
|
||||
idx.clear()
|
||||
return idx
|
||||
}
|
||||
|
||||
func (idx *index) clear() {
|
||||
idx.mu.Lock()
|
||||
defer idx.mu.Unlock()
|
||||
idx.byType = map[string]*recordIndex{}
|
||||
}
|
||||
|
||||
func (idx *index) delete(typeURL, id string) {
|
||||
idx.mu.Lock()
|
||||
defer idx.mu.Unlock()
|
||||
|
||||
ridx, ok := idx.byType[typeURL]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
ridx.delete(id)
|
||||
|
||||
if len(ridx.byID) == 0 {
|
||||
delete(idx.byType, typeURL)
|
||||
}
|
||||
}
|
||||
|
||||
func (idx *index) find(typeURL, id string) proto.Message {
|
||||
idx.mu.RLock()
|
||||
defer idx.mu.RUnlock()
|
||||
|
||||
ridx, ok := idx.byType[typeURL]
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
return ridx.find(id)
|
||||
}
|
||||
|
||||
func (idx *index) get(typeURL, id string) proto.Message {
|
||||
idx.mu.RLock()
|
||||
defer idx.mu.RUnlock()
|
||||
|
||||
ridx, ok := idx.byType[typeURL]
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
return ridx.get(id)
|
||||
}
|
||||
|
||||
func (idx *index) set(typeURL, id string, msg proto.Message) {
|
||||
idx.mu.Lock()
|
||||
defer idx.mu.Unlock()
|
||||
|
||||
ridx, ok := idx.byType[typeURL]
|
||||
if !ok {
|
||||
ridx = newRecordIndex()
|
||||
idx.byType[typeURL] = ridx
|
||||
}
|
||||
ridx.set(id, msg)
|
||||
}
|
||||
|
||||
// a recordIndex indexes records for of a specific type
|
||||
type recordIndex struct {
|
||||
byID map[string]proto.Message
|
||||
byCIDRV4 *string_tree.TreeV4
|
||||
byCIDRV6 *string_tree.TreeV6
|
||||
}
|
||||
|
||||
// newRecordIndex creates a new record index.
|
||||
func newRecordIndex() *recordIndex {
|
||||
return &recordIndex{
|
||||
byID: map[string]proto.Message{},
|
||||
byCIDRV4: string_tree.NewTreeV4(),
|
||||
byCIDRV6: string_tree.NewTreeV6(),
|
||||
}
|
||||
}
|
||||
|
||||
func (idx *recordIndex) delete(id string) {
|
||||
r, ok := idx.byID[id]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
delete(idx.byID, id)
|
||||
|
||||
addr4, addr6 := getIndexCIDR(r)
|
||||
if addr4 != nil {
|
||||
idx.byCIDRV4.Delete(*addr4, func(payload, val string) bool {
|
||||
return payload == val
|
||||
}, id)
|
||||
}
|
||||
if addr6 != nil {
|
||||
idx.byCIDRV6.Delete(*addr6, func(payload, val string) bool {
|
||||
return payload == val
|
||||
}, id)
|
||||
}
|
||||
}
|
||||
|
||||
func (idx *recordIndex) find(idOrString string) proto.Message {
|
||||
r, ok := idx.byID[idOrString]
|
||||
if ok {
|
||||
return r
|
||||
}
|
||||
|
||||
addrv4, addrv6, _ := patricia.ParseIPFromString(idOrString)
|
||||
if addrv4 != nil {
|
||||
found, id := idx.byCIDRV4.FindDeepestTag(*addrv4)
|
||||
if found {
|
||||
return idx.byID[id]
|
||||
}
|
||||
}
|
||||
if addrv6 != nil {
|
||||
found, id := idx.byCIDRV6.FindDeepestTag(*addrv6)
|
||||
if found {
|
||||
return idx.byID[id]
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (idx *recordIndex) get(id string) proto.Message {
|
||||
return idx.byID[id]
|
||||
}
|
||||
|
||||
func (idx *recordIndex) set(id string, msg proto.Message) {
|
||||
_, ok := idx.byID[id]
|
||||
if ok {
|
||||
idx.delete(id)
|
||||
}
|
||||
|
||||
idx.byID[id] = msg
|
||||
addr4, addr6 := getIndexCIDR(msg)
|
||||
if addr4 != nil {
|
||||
idx.byCIDRV4.Set(*addr4, id)
|
||||
}
|
||||
if addr6 != nil {
|
||||
idx.byCIDRV6.Set(*addr6, id)
|
||||
}
|
||||
}
|
||||
|
||||
func getIndexCIDR(msg proto.Message) (*patricia.IPv4Address, *patricia.IPv6Address) {
|
||||
var s *structpb.Struct
|
||||
if sv, ok := msg.(*structpb.Value); ok {
|
||||
s = sv.GetStructValue()
|
||||
} else {
|
||||
s, _ = msg.(*structpb.Struct)
|
||||
}
|
||||
if s == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
f, ok := s.Fields[indexField]
|
||||
if !ok {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
obj := f.GetStructValue()
|
||||
if obj == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
cf, ok := obj.Fields[cidrField]
|
||||
if !ok {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
c := cf.GetStringValue()
|
||||
if c == "" {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
addr4, addr6, _ := patricia.ParseIPFromString(c)
|
||||
return addr4, addr6
|
||||
}
|
|
@ -1,74 +0,0 @@
|
|||
package store
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"google.golang.org/protobuf/types/known/structpb"
|
||||
)
|
||||
|
||||
func TestByID(t *testing.T) {
|
||||
idx := newIndex()
|
||||
|
||||
r1 := &structpb.Struct{Fields: map[string]*structpb.Value{
|
||||
"id": structpb.NewStringValue("r1"),
|
||||
}}
|
||||
|
||||
idx.set("example.com/record", "r1", r1)
|
||||
assert.Equal(t, r1, idx.get("example.com/record", "r1"))
|
||||
idx.delete("example.com/record", "r1")
|
||||
assert.Nil(t, idx.get("example.com/record", "r1"))
|
||||
}
|
||||
|
||||
func TestByCIDR(t *testing.T) {
|
||||
t.Run("ipv4", func(t *testing.T) {
|
||||
idx := newIndex()
|
||||
|
||||
r1 := &structpb.Struct{Fields: map[string]*structpb.Value{
|
||||
"$index": structpb.NewStructValue(&structpb.Struct{Fields: map[string]*structpb.Value{
|
||||
"cidr": structpb.NewStringValue("192.168.0.0/16"),
|
||||
}}),
|
||||
"id": structpb.NewStringValue("r1"),
|
||||
}}
|
||||
idx.set("example.com/record", "r1", r1)
|
||||
|
||||
r2 := &structpb.Struct{Fields: map[string]*structpb.Value{
|
||||
"$index": structpb.NewStructValue(&structpb.Struct{Fields: map[string]*structpb.Value{
|
||||
"cidr": structpb.NewStringValue("192.168.0.0/24"),
|
||||
}}),
|
||||
"id": structpb.NewStringValue("r2"),
|
||||
}}
|
||||
idx.set("example.com/record", "r2", r2)
|
||||
|
||||
assert.Equal(t, r2, idx.find("example.com/record", "192.168.0.7"))
|
||||
idx.delete("example.com/record", "r2")
|
||||
assert.Equal(t, r1, idx.find("example.com/record", "192.168.0.7"))
|
||||
idx.delete("example.com/record", "r1")
|
||||
assert.Nil(t, idx.find("example.com/record", "192.168.0.7"))
|
||||
})
|
||||
t.Run("ipv6", func(t *testing.T) {
|
||||
idx := newIndex()
|
||||
|
||||
r1 := &structpb.Struct{Fields: map[string]*structpb.Value{
|
||||
"$index": structpb.NewStructValue(&structpb.Struct{Fields: map[string]*structpb.Value{
|
||||
"cidr": structpb.NewStringValue("2001:db8::/32"),
|
||||
}}),
|
||||
"id": structpb.NewStringValue("r1"),
|
||||
}}
|
||||
idx.set("example.com/record", "r1", r1)
|
||||
|
||||
r2 := &structpb.Struct{Fields: map[string]*structpb.Value{
|
||||
"$index": structpb.NewStructValue(&structpb.Struct{Fields: map[string]*structpb.Value{
|
||||
"cidr": structpb.NewStringValue("2001:db8::/48"),
|
||||
}}),
|
||||
"id": structpb.NewStringValue("r2"),
|
||||
}}
|
||||
idx.set("example.com/record", "r2", r2)
|
||||
|
||||
assert.Equal(t, r2, idx.find("example.com/record", "2001:db8::"))
|
||||
idx.delete("example.com/record", "r2")
|
||||
assert.Equal(t, r1, idx.find("example.com/record", "2001:db8::"))
|
||||
idx.delete("example.com/record", "r1")
|
||||
assert.Nil(t, idx.find("example.com/record", "2001:db8::"))
|
||||
})
|
||||
}
|
|
@ -5,78 +5,33 @@ import (
|
|||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/go-jose/go-jose/v3"
|
||||
"github.com/google/uuid"
|
||||
"github.com/open-policy-agent/opa/ast"
|
||||
"github.com/open-policy-agent/opa/rego"
|
||||
"github.com/open-policy-agent/opa/storage"
|
||||
opastorage "github.com/open-policy-agent/opa/storage"
|
||||
"github.com/open-policy-agent/opa/storage/inmem"
|
||||
"github.com/open-policy-agent/opa/types"
|
||||
"google.golang.org/protobuf/proto"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
|
||||
"github.com/pomerium/pomerium/config"
|
||||
"github.com/pomerium/pomerium/internal/log"
|
||||
"github.com/pomerium/pomerium/pkg/cryptutil"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||
"github.com/pomerium/pomerium/pkg/protoutil"
|
||||
"github.com/pomerium/pomerium/pkg/storage"
|
||||
)
|
||||
|
||||
// A Store stores data for the OPA rego policy evaluation.
|
||||
type Store struct {
|
||||
storage.Store
|
||||
index *index
|
||||
|
||||
dataBrokerServerVersion, dataBrokerRecordVersion uint64
|
||||
opastorage.Store
|
||||
}
|
||||
|
||||
// New creates a new Store.
|
||||
func New() *Store {
|
||||
return &Store{
|
||||
Store: inmem.New(),
|
||||
index: newIndex(),
|
||||
}
|
||||
}
|
||||
|
||||
// NewFromProtos creates a new Store from an existing set of protobuf messages.
|
||||
func NewFromProtos(serverVersion uint64, msgs ...proto.Message) *Store {
|
||||
s := New()
|
||||
for _, msg := range msgs {
|
||||
any := protoutil.NewAny(msg)
|
||||
record := new(databroker.Record)
|
||||
record.ModifiedAt = timestamppb.Now()
|
||||
record.Version = cryptutil.NewRandomUInt64()
|
||||
record.Id = uuid.New().String()
|
||||
record.Data = any
|
||||
record.Type = any.TypeUrl
|
||||
if hasID, ok := msg.(interface{ GetId() string }); ok {
|
||||
record.Id = hasID.GetId()
|
||||
}
|
||||
|
||||
s.UpdateRecord(serverVersion, record)
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
// ClearRecords removes all the records from the store.
|
||||
func (s *Store) ClearRecords() {
|
||||
s.index.clear()
|
||||
}
|
||||
|
||||
// GetDataBrokerVersions gets the databroker versions.
|
||||
func (s *Store) GetDataBrokerVersions() (serverVersion, recordVersion uint64) {
|
||||
return atomic.LoadUint64(&s.dataBrokerServerVersion),
|
||||
atomic.LoadUint64(&s.dataBrokerRecordVersion)
|
||||
}
|
||||
|
||||
// GetRecordData gets a record's data from the store. `nil` is returned
|
||||
// if no record exists for the given type and id.
|
||||
func (s *Store) GetRecordData(typeURL, idOrValue string) proto.Message {
|
||||
return s.index.find(typeURL, idOrValue)
|
||||
}
|
||||
|
||||
// UpdateIssuer updates the issuer in the store. The issuer is used as part of JWT construction.
|
||||
func (s *Store) UpdateIssuer(issuer string) {
|
||||
s.write("/issuer", issuer)
|
||||
|
@ -98,20 +53,6 @@ func (s *Store) UpdateRoutePolicies(routePolicies []config.Policy) {
|
|||
s.write("/route_policies", routePolicies)
|
||||
}
|
||||
|
||||
// UpdateRecord updates a record in the store.
|
||||
func (s *Store) UpdateRecord(serverVersion uint64, record *databroker.Record) {
|
||||
if record.GetDeletedAt() != nil {
|
||||
s.index.delete(record.GetType(), record.GetId())
|
||||
} else {
|
||||
msg, _ := record.GetData().UnmarshalNew()
|
||||
s.index.set(record.GetType(), record.GetId(), msg)
|
||||
}
|
||||
s.write("/databroker_server_version", fmt.Sprint(serverVersion))
|
||||
s.write("/databroker_record_version", fmt.Sprint(record.GetVersion()))
|
||||
atomic.StoreUint64(&s.dataBrokerServerVersion, serverVersion)
|
||||
atomic.StoreUint64(&s.dataBrokerRecordVersion, record.GetVersion())
|
||||
}
|
||||
|
||||
// UpdateSigningKey updates the signing key stored in the database. Signing operations
|
||||
// in rego use JWKs, so we take in that format.
|
||||
func (s *Store) UpdateSigningKey(signingKey *jose.JSONWebKey) {
|
||||
|
@ -120,7 +61,7 @@ func (s *Store) UpdateSigningKey(signingKey *jose.JSONWebKey) {
|
|||
|
||||
func (s *Store) write(rawPath string, value interface{}) {
|
||||
ctx := context.TODO()
|
||||
err := storage.Txn(ctx, s.Store, storage.WriteParams, func(txn storage.Transaction) error {
|
||||
err := opastorage.Txn(ctx, s.Store, opastorage.WriteParams, func(txn opastorage.Transaction) error {
|
||||
return s.writeTxn(txn, rawPath, value)
|
||||
})
|
||||
if err != nil {
|
||||
|
@ -129,23 +70,23 @@ func (s *Store) write(rawPath string, value interface{}) {
|
|||
}
|
||||
}
|
||||
|
||||
func (s *Store) writeTxn(txn storage.Transaction, rawPath string, value interface{}) error {
|
||||
p, ok := storage.ParsePath(rawPath)
|
||||
func (s *Store) writeTxn(txn opastorage.Transaction, rawPath string, value interface{}) error {
|
||||
p, ok := opastorage.ParsePath(rawPath)
|
||||
if !ok {
|
||||
return fmt.Errorf("invalid path")
|
||||
}
|
||||
|
||||
if len(p) > 1 {
|
||||
err := storage.MakeDir(context.Background(), s, txn, p[:len(p)-1])
|
||||
err := opastorage.MakeDir(context.Background(), s, txn, p[:len(p)-1])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
var op storage.PatchOp = storage.ReplaceOp
|
||||
var op opastorage.PatchOp = opastorage.ReplaceOp
|
||||
_, err := s.Read(context.Background(), txn, p)
|
||||
if storage.IsNotFound(err) {
|
||||
op = storage.AddOp
|
||||
if opastorage.IsNotFound(err) {
|
||||
op = opastorage.AddOp
|
||||
} else if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -167,23 +108,42 @@ func (s *Store) GetDataBrokerRecordOption() func(*rego.Rego) {
|
|||
return nil, fmt.Errorf("invalid record type: %T", op1)
|
||||
}
|
||||
|
||||
recordID, ok := op2.Value.(ast.String)
|
||||
value, ok := op2.Value.(ast.String)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid record id: %T", op2)
|
||||
}
|
||||
|
||||
msg := s.GetRecordData(string(recordType), string(recordID))
|
||||
if msg == nil {
|
||||
req := &databroker.QueryRequest{
|
||||
Type: string(recordType),
|
||||
Limit: 1,
|
||||
}
|
||||
req.SetFilterByIDOrIndex(string(value))
|
||||
|
||||
res, err := storage.GetQuerier(bctx.Context).Query(bctx.Context, req)
|
||||
if err != nil {
|
||||
log.Error(bctx.Context).Err(err).Msg("authorize/store: error retrieving record")
|
||||
return ast.NullTerm(), nil
|
||||
}
|
||||
|
||||
if len(res.GetRecords()) == 0 {
|
||||
return ast.NullTerm(), nil
|
||||
}
|
||||
|
||||
msg, _ := res.GetRecords()[0].GetData().UnmarshalNew()
|
||||
if msg == nil {
|
||||
if msg == nil {
|
||||
return ast.NullTerm(), nil
|
||||
}
|
||||
}
|
||||
obj := toMap(msg)
|
||||
|
||||
value, err := ast.InterfaceToValue(obj)
|
||||
regoValue, err := ast.InterfaceToValue(obj)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
log.Error(bctx.Context).Err(err).Msg("authorize/store: error converting object to rego")
|
||||
return ast.NullTerm(), nil
|
||||
}
|
||||
|
||||
return ast.NewTerm(value), nil
|
||||
return ast.NewTerm(regoValue), nil
|
||||
})
|
||||
}
|
||||
|
||||
|
|
|
@ -1,83 +0,0 @@
|
|||
package store
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"google.golang.org/protobuf/types/known/structpb"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
|
||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/user"
|
||||
"github.com/pomerium/pomerium/pkg/protoutil"
|
||||
)
|
||||
|
||||
func TestStore(t *testing.T) {
|
||||
t.Run("records", func(t *testing.T) {
|
||||
s := New()
|
||||
u := &user.User{
|
||||
Version: "v1",
|
||||
Id: "u1",
|
||||
Name: "name",
|
||||
Email: "name@example.com",
|
||||
}
|
||||
any := protoutil.NewAny(u)
|
||||
s.UpdateRecord(0, &databroker.Record{
|
||||
Version: 1,
|
||||
Type: any.GetTypeUrl(),
|
||||
Id: u.GetId(),
|
||||
Data: any,
|
||||
})
|
||||
|
||||
v := s.GetRecordData(any.GetTypeUrl(), u.GetId())
|
||||
assert.Equal(t, map[string]interface{}{
|
||||
"version": "v1",
|
||||
"id": "u1",
|
||||
"name": "name",
|
||||
"email": "name@example.com",
|
||||
}, toMap(v))
|
||||
|
||||
s.UpdateRecord(0, &databroker.Record{
|
||||
Version: 2,
|
||||
Type: any.GetTypeUrl(),
|
||||
Id: u.GetId(),
|
||||
Data: any,
|
||||
DeletedAt: timestamppb.Now(),
|
||||
})
|
||||
|
||||
v = s.GetRecordData(any.GetTypeUrl(), u.GetId())
|
||||
assert.Nil(t, v)
|
||||
|
||||
s.UpdateRecord(0, &databroker.Record{
|
||||
Version: 3,
|
||||
Type: any.GetTypeUrl(),
|
||||
Id: u.GetId(),
|
||||
Data: any,
|
||||
})
|
||||
|
||||
v = s.GetRecordData(any.GetTypeUrl(), u.GetId())
|
||||
assert.NotNil(t, v)
|
||||
|
||||
s.ClearRecords()
|
||||
v = s.GetRecordData(any.GetTypeUrl(), u.GetId())
|
||||
assert.Nil(t, v)
|
||||
})
|
||||
t.Run("cidr", func(t *testing.T) {
|
||||
s := New()
|
||||
any := protoutil.NewAny(&structpb.Struct{Fields: map[string]*structpb.Value{
|
||||
"$index": structpb.NewStructValue(&structpb.Struct{Fields: map[string]*structpb.Value{
|
||||
"cidr": structpb.NewStringValue("192.168.0.0/16"),
|
||||
}}),
|
||||
"id": structpb.NewStringValue("r1"),
|
||||
}})
|
||||
s.UpdateRecord(0, &databroker.Record{
|
||||
Version: 1,
|
||||
Type: any.GetTypeUrl(),
|
||||
Id: "r1",
|
||||
Data: any,
|
||||
})
|
||||
|
||||
v := s.GetRecordData(any.GetTypeUrl(), "192.168.0.7")
|
||||
assert.NotNil(t, v)
|
||||
})
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue