authorize: use query instead of sync for databroker data (#3377)

This commit is contained in:
Caleb Doxsey 2022-06-01 15:40:07 -06:00 committed by GitHub
parent fd82cc7870
commit f61e7efe73
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
24 changed files with 661 additions and 1008 deletions

View file

@ -1,196 +0,0 @@
package store
import (
"sync"
"github.com/kentik/patricia"
"github.com/kentik/patricia/string_tree"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/known/structpb"
)
const (
indexField = "$index"
cidrField = "cidr"
)
type index struct {
mu sync.RWMutex
byType map[string]*recordIndex
}
func newIndex() *index {
idx := new(index)
idx.clear()
return idx
}
func (idx *index) clear() {
idx.mu.Lock()
defer idx.mu.Unlock()
idx.byType = map[string]*recordIndex{}
}
func (idx *index) delete(typeURL, id string) {
idx.mu.Lock()
defer idx.mu.Unlock()
ridx, ok := idx.byType[typeURL]
if !ok {
return
}
ridx.delete(id)
if len(ridx.byID) == 0 {
delete(idx.byType, typeURL)
}
}
func (idx *index) find(typeURL, id string) proto.Message {
idx.mu.RLock()
defer idx.mu.RUnlock()
ridx, ok := idx.byType[typeURL]
if !ok {
return nil
}
return ridx.find(id)
}
func (idx *index) get(typeURL, id string) proto.Message {
idx.mu.RLock()
defer idx.mu.RUnlock()
ridx, ok := idx.byType[typeURL]
if !ok {
return nil
}
return ridx.get(id)
}
func (idx *index) set(typeURL, id string, msg proto.Message) {
idx.mu.Lock()
defer idx.mu.Unlock()
ridx, ok := idx.byType[typeURL]
if !ok {
ridx = newRecordIndex()
idx.byType[typeURL] = ridx
}
ridx.set(id, msg)
}
// a recordIndex indexes records for of a specific type
type recordIndex struct {
byID map[string]proto.Message
byCIDRV4 *string_tree.TreeV4
byCIDRV6 *string_tree.TreeV6
}
// newRecordIndex creates a new record index.
func newRecordIndex() *recordIndex {
return &recordIndex{
byID: map[string]proto.Message{},
byCIDRV4: string_tree.NewTreeV4(),
byCIDRV6: string_tree.NewTreeV6(),
}
}
func (idx *recordIndex) delete(id string) {
r, ok := idx.byID[id]
if !ok {
return
}
delete(idx.byID, id)
addr4, addr6 := getIndexCIDR(r)
if addr4 != nil {
idx.byCIDRV4.Delete(*addr4, func(payload, val string) bool {
return payload == val
}, id)
}
if addr6 != nil {
idx.byCIDRV6.Delete(*addr6, func(payload, val string) bool {
return payload == val
}, id)
}
}
func (idx *recordIndex) find(idOrString string) proto.Message {
r, ok := idx.byID[idOrString]
if ok {
return r
}
addrv4, addrv6, _ := patricia.ParseIPFromString(idOrString)
if addrv4 != nil {
found, id := idx.byCIDRV4.FindDeepestTag(*addrv4)
if found {
return idx.byID[id]
}
}
if addrv6 != nil {
found, id := idx.byCIDRV6.FindDeepestTag(*addrv6)
if found {
return idx.byID[id]
}
}
return nil
}
func (idx *recordIndex) get(id string) proto.Message {
return idx.byID[id]
}
func (idx *recordIndex) set(id string, msg proto.Message) {
_, ok := idx.byID[id]
if ok {
idx.delete(id)
}
idx.byID[id] = msg
addr4, addr6 := getIndexCIDR(msg)
if addr4 != nil {
idx.byCIDRV4.Set(*addr4, id)
}
if addr6 != nil {
idx.byCIDRV6.Set(*addr6, id)
}
}
func getIndexCIDR(msg proto.Message) (*patricia.IPv4Address, *patricia.IPv6Address) {
var s *structpb.Struct
if sv, ok := msg.(*structpb.Value); ok {
s = sv.GetStructValue()
} else {
s, _ = msg.(*structpb.Struct)
}
if s == nil {
return nil, nil
}
f, ok := s.Fields[indexField]
if !ok {
return nil, nil
}
obj := f.GetStructValue()
if obj == nil {
return nil, nil
}
cf, ok := obj.Fields[cidrField]
if !ok {
return nil, nil
}
c := cf.GetStringValue()
if c == "" {
return nil, nil
}
addr4, addr6, _ := patricia.ParseIPFromString(c)
return addr4, addr6
}

View file

@ -1,74 +0,0 @@
package store
import (
"testing"
"github.com/stretchr/testify/assert"
"google.golang.org/protobuf/types/known/structpb"
)
func TestByID(t *testing.T) {
idx := newIndex()
r1 := &structpb.Struct{Fields: map[string]*structpb.Value{
"id": structpb.NewStringValue("r1"),
}}
idx.set("example.com/record", "r1", r1)
assert.Equal(t, r1, idx.get("example.com/record", "r1"))
idx.delete("example.com/record", "r1")
assert.Nil(t, idx.get("example.com/record", "r1"))
}
func TestByCIDR(t *testing.T) {
t.Run("ipv4", func(t *testing.T) {
idx := newIndex()
r1 := &structpb.Struct{Fields: map[string]*structpb.Value{
"$index": structpb.NewStructValue(&structpb.Struct{Fields: map[string]*structpb.Value{
"cidr": structpb.NewStringValue("192.168.0.0/16"),
}}),
"id": structpb.NewStringValue("r1"),
}}
idx.set("example.com/record", "r1", r1)
r2 := &structpb.Struct{Fields: map[string]*structpb.Value{
"$index": structpb.NewStructValue(&structpb.Struct{Fields: map[string]*structpb.Value{
"cidr": structpb.NewStringValue("192.168.0.0/24"),
}}),
"id": structpb.NewStringValue("r2"),
}}
idx.set("example.com/record", "r2", r2)
assert.Equal(t, r2, idx.find("example.com/record", "192.168.0.7"))
idx.delete("example.com/record", "r2")
assert.Equal(t, r1, idx.find("example.com/record", "192.168.0.7"))
idx.delete("example.com/record", "r1")
assert.Nil(t, idx.find("example.com/record", "192.168.0.7"))
})
t.Run("ipv6", func(t *testing.T) {
idx := newIndex()
r1 := &structpb.Struct{Fields: map[string]*structpb.Value{
"$index": structpb.NewStructValue(&structpb.Struct{Fields: map[string]*structpb.Value{
"cidr": structpb.NewStringValue("2001:db8::/32"),
}}),
"id": structpb.NewStringValue("r1"),
}}
idx.set("example.com/record", "r1", r1)
r2 := &structpb.Struct{Fields: map[string]*structpb.Value{
"$index": structpb.NewStructValue(&structpb.Struct{Fields: map[string]*structpb.Value{
"cidr": structpb.NewStringValue("2001:db8::/48"),
}}),
"id": structpb.NewStringValue("r2"),
}}
idx.set("example.com/record", "r2", r2)
assert.Equal(t, r2, idx.find("example.com/record", "2001:db8::"))
idx.delete("example.com/record", "r2")
assert.Equal(t, r1, idx.find("example.com/record", "2001:db8::"))
idx.delete("example.com/record", "r1")
assert.Nil(t, idx.find("example.com/record", "2001:db8::"))
})
}

View file

@ -5,78 +5,33 @@ import (
"context"
"encoding/json"
"fmt"
"sync/atomic"
"github.com/go-jose/go-jose/v3"
"github.com/google/uuid"
"github.com/open-policy-agent/opa/ast"
"github.com/open-policy-agent/opa/rego"
"github.com/open-policy-agent/opa/storage"
opastorage "github.com/open-policy-agent/opa/storage"
"github.com/open-policy-agent/opa/storage/inmem"
"github.com/open-policy-agent/opa/types"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/known/timestamppb"
"github.com/pomerium/pomerium/config"
"github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/pkg/cryptutil"
"github.com/pomerium/pomerium/pkg/grpc/databroker"
"github.com/pomerium/pomerium/pkg/protoutil"
"github.com/pomerium/pomerium/pkg/storage"
)
// A Store stores data for the OPA rego policy evaluation.
type Store struct {
storage.Store
index *index
dataBrokerServerVersion, dataBrokerRecordVersion uint64
opastorage.Store
}
// New creates a new Store.
func New() *Store {
return &Store{
Store: inmem.New(),
index: newIndex(),
}
}
// NewFromProtos creates a new Store from an existing set of protobuf messages.
func NewFromProtos(serverVersion uint64, msgs ...proto.Message) *Store {
s := New()
for _, msg := range msgs {
any := protoutil.NewAny(msg)
record := new(databroker.Record)
record.ModifiedAt = timestamppb.Now()
record.Version = cryptutil.NewRandomUInt64()
record.Id = uuid.New().String()
record.Data = any
record.Type = any.TypeUrl
if hasID, ok := msg.(interface{ GetId() string }); ok {
record.Id = hasID.GetId()
}
s.UpdateRecord(serverVersion, record)
}
return s
}
// ClearRecords removes all the records from the store.
func (s *Store) ClearRecords() {
s.index.clear()
}
// GetDataBrokerVersions gets the databroker versions.
func (s *Store) GetDataBrokerVersions() (serverVersion, recordVersion uint64) {
return atomic.LoadUint64(&s.dataBrokerServerVersion),
atomic.LoadUint64(&s.dataBrokerRecordVersion)
}
// GetRecordData gets a record's data from the store. `nil` is returned
// if no record exists for the given type and id.
func (s *Store) GetRecordData(typeURL, idOrValue string) proto.Message {
return s.index.find(typeURL, idOrValue)
}
// UpdateIssuer updates the issuer in the store. The issuer is used as part of JWT construction.
func (s *Store) UpdateIssuer(issuer string) {
s.write("/issuer", issuer)
@ -98,20 +53,6 @@ func (s *Store) UpdateRoutePolicies(routePolicies []config.Policy) {
s.write("/route_policies", routePolicies)
}
// UpdateRecord updates a record in the store.
func (s *Store) UpdateRecord(serverVersion uint64, record *databroker.Record) {
if record.GetDeletedAt() != nil {
s.index.delete(record.GetType(), record.GetId())
} else {
msg, _ := record.GetData().UnmarshalNew()
s.index.set(record.GetType(), record.GetId(), msg)
}
s.write("/databroker_server_version", fmt.Sprint(serverVersion))
s.write("/databroker_record_version", fmt.Sprint(record.GetVersion()))
atomic.StoreUint64(&s.dataBrokerServerVersion, serverVersion)
atomic.StoreUint64(&s.dataBrokerRecordVersion, record.GetVersion())
}
// UpdateSigningKey updates the signing key stored in the database. Signing operations
// in rego use JWKs, so we take in that format.
func (s *Store) UpdateSigningKey(signingKey *jose.JSONWebKey) {
@ -120,7 +61,7 @@ func (s *Store) UpdateSigningKey(signingKey *jose.JSONWebKey) {
func (s *Store) write(rawPath string, value interface{}) {
ctx := context.TODO()
err := storage.Txn(ctx, s.Store, storage.WriteParams, func(txn storage.Transaction) error {
err := opastorage.Txn(ctx, s.Store, opastorage.WriteParams, func(txn opastorage.Transaction) error {
return s.writeTxn(txn, rawPath, value)
})
if err != nil {
@ -129,23 +70,23 @@ func (s *Store) write(rawPath string, value interface{}) {
}
}
func (s *Store) writeTxn(txn storage.Transaction, rawPath string, value interface{}) error {
p, ok := storage.ParsePath(rawPath)
func (s *Store) writeTxn(txn opastorage.Transaction, rawPath string, value interface{}) error {
p, ok := opastorage.ParsePath(rawPath)
if !ok {
return fmt.Errorf("invalid path")
}
if len(p) > 1 {
err := storage.MakeDir(context.Background(), s, txn, p[:len(p)-1])
err := opastorage.MakeDir(context.Background(), s, txn, p[:len(p)-1])
if err != nil {
return err
}
}
var op storage.PatchOp = storage.ReplaceOp
var op opastorage.PatchOp = opastorage.ReplaceOp
_, err := s.Read(context.Background(), txn, p)
if storage.IsNotFound(err) {
op = storage.AddOp
if opastorage.IsNotFound(err) {
op = opastorage.AddOp
} else if err != nil {
return err
}
@ -167,23 +108,42 @@ func (s *Store) GetDataBrokerRecordOption() func(*rego.Rego) {
return nil, fmt.Errorf("invalid record type: %T", op1)
}
recordID, ok := op2.Value.(ast.String)
value, ok := op2.Value.(ast.String)
if !ok {
return nil, fmt.Errorf("invalid record id: %T", op2)
}
msg := s.GetRecordData(string(recordType), string(recordID))
if msg == nil {
req := &databroker.QueryRequest{
Type: string(recordType),
Limit: 1,
}
req.SetFilterByIDOrIndex(string(value))
res, err := storage.GetQuerier(bctx.Context).Query(bctx.Context, req)
if err != nil {
log.Error(bctx.Context).Err(err).Msg("authorize/store: error retrieving record")
return ast.NullTerm(), nil
}
if len(res.GetRecords()) == 0 {
return ast.NullTerm(), nil
}
msg, _ := res.GetRecords()[0].GetData().UnmarshalNew()
if msg == nil {
if msg == nil {
return ast.NullTerm(), nil
}
}
obj := toMap(msg)
value, err := ast.InterfaceToValue(obj)
regoValue, err := ast.InterfaceToValue(obj)
if err != nil {
return nil, err
log.Error(bctx.Context).Err(err).Msg("authorize/store: error converting object to rego")
return ast.NullTerm(), nil
}
return ast.NewTerm(value), nil
return ast.NewTerm(regoValue), nil
})
}

View file

@ -1,83 +0,0 @@
package store
import (
"testing"
"github.com/stretchr/testify/assert"
"google.golang.org/protobuf/types/known/structpb"
"google.golang.org/protobuf/types/known/timestamppb"
"github.com/pomerium/pomerium/pkg/grpc/databroker"
"github.com/pomerium/pomerium/pkg/grpc/user"
"github.com/pomerium/pomerium/pkg/protoutil"
)
func TestStore(t *testing.T) {
t.Run("records", func(t *testing.T) {
s := New()
u := &user.User{
Version: "v1",
Id: "u1",
Name: "name",
Email: "name@example.com",
}
any := protoutil.NewAny(u)
s.UpdateRecord(0, &databroker.Record{
Version: 1,
Type: any.GetTypeUrl(),
Id: u.GetId(),
Data: any,
})
v := s.GetRecordData(any.GetTypeUrl(), u.GetId())
assert.Equal(t, map[string]interface{}{
"version": "v1",
"id": "u1",
"name": "name",
"email": "name@example.com",
}, toMap(v))
s.UpdateRecord(0, &databroker.Record{
Version: 2,
Type: any.GetTypeUrl(),
Id: u.GetId(),
Data: any,
DeletedAt: timestamppb.Now(),
})
v = s.GetRecordData(any.GetTypeUrl(), u.GetId())
assert.Nil(t, v)
s.UpdateRecord(0, &databroker.Record{
Version: 3,
Type: any.GetTypeUrl(),
Id: u.GetId(),
Data: any,
})
v = s.GetRecordData(any.GetTypeUrl(), u.GetId())
assert.NotNil(t, v)
s.ClearRecords()
v = s.GetRecordData(any.GetTypeUrl(), u.GetId())
assert.Nil(t, v)
})
t.Run("cidr", func(t *testing.T) {
s := New()
any := protoutil.NewAny(&structpb.Struct{Fields: map[string]*structpb.Value{
"$index": structpb.NewStructValue(&structpb.Struct{Fields: map[string]*structpb.Value{
"cidr": structpb.NewStringValue("192.168.0.0/16"),
}}),
"id": structpb.NewStringValue("r1"),
}})
s.UpdateRecord(0, &databroker.Record{
Version: 1,
Type: any.GetTypeUrl(),
Id: "r1",
Data: any,
})
v := s.GetRecordData(any.GetTypeUrl(), "192.168.0.7")
assert.NotNil(t, v)
})
}