pomerium/authorize/internal/store/index.go
2022-04-19 16:18:34 -06:00

196 lines
3.4 KiB
Go

package store
import (
"sync"
"github.com/kentik/patricia"
"github.com/kentik/patricia/string_tree"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/known/structpb"
)
const (
indexField = "$index"
cidrField = "cidr"
)
type index struct {
mu sync.RWMutex
byType map[string]*recordIndex
}
func newIndex() *index {
idx := new(index)
idx.clear()
return idx
}
func (idx *index) clear() {
idx.mu.Lock()
defer idx.mu.Unlock()
idx.byType = map[string]*recordIndex{}
}
func (idx *index) delete(typeURL, id string) {
idx.mu.Lock()
defer idx.mu.Unlock()
ridx, ok := idx.byType[typeURL]
if !ok {
return
}
ridx.delete(id)
if len(ridx.byID) == 0 {
delete(idx.byType, typeURL)
}
}
func (idx *index) find(typeURL, id string) proto.Message {
idx.mu.RLock()
defer idx.mu.RUnlock()
ridx, ok := idx.byType[typeURL]
if !ok {
return nil
}
return ridx.find(id)
}
func (idx *index) get(typeURL, id string) proto.Message {
idx.mu.RLock()
defer idx.mu.RUnlock()
ridx, ok := idx.byType[typeURL]
if !ok {
return nil
}
return ridx.get(id)
}
func (idx *index) set(typeURL, id string, msg proto.Message) {
idx.mu.Lock()
defer idx.mu.Unlock()
ridx, ok := idx.byType[typeURL]
if !ok {
ridx = newRecordIndex()
idx.byType[typeURL] = ridx
}
ridx.set(id, msg)
}
// a recordIndex indexes records for of a specific type
type recordIndex struct {
byID map[string]proto.Message
byCIDRV4 *string_tree.TreeV4
byCIDRV6 *string_tree.TreeV6
}
// newRecordIndex creates a new record index.
func newRecordIndex() *recordIndex {
return &recordIndex{
byID: map[string]proto.Message{},
byCIDRV4: string_tree.NewTreeV4(),
byCIDRV6: string_tree.NewTreeV6(),
}
}
func (idx *recordIndex) delete(id string) {
r, ok := idx.byID[id]
if !ok {
return
}
delete(idx.byID, id)
addr4, addr6 := getIndexCIDR(r)
if addr4 != nil {
idx.byCIDRV4.Delete(*addr4, func(payload, val string) bool {
return payload == val
}, id)
}
if addr6 != nil {
idx.byCIDRV6.Delete(*addr6, func(payload, val string) bool {
return payload == val
}, id)
}
}
func (idx *recordIndex) find(idOrString string) proto.Message {
r, ok := idx.byID[idOrString]
if ok {
return r
}
addrv4, addrv6, _ := patricia.ParseIPFromString(idOrString)
if addrv4 != nil {
found, id := idx.byCIDRV4.FindDeepestTag(*addrv4)
if found {
return idx.byID[id]
}
}
if addrv6 != nil {
found, id := idx.byCIDRV6.FindDeepestTag(*addrv6)
if found {
return idx.byID[id]
}
}
return nil
}
func (idx *recordIndex) get(id string) proto.Message {
return idx.byID[id]
}
func (idx *recordIndex) set(id string, msg proto.Message) {
_, ok := idx.byID[id]
if ok {
idx.delete(id)
}
idx.byID[id] = msg
addr4, addr6 := getIndexCIDR(msg)
if addr4 != nil {
idx.byCIDRV4.Set(*addr4, id)
}
if addr6 != nil {
idx.byCIDRV6.Set(*addr6, id)
}
}
func getIndexCIDR(msg proto.Message) (*patricia.IPv4Address, *patricia.IPv6Address) {
var s *structpb.Struct
if sv, ok := msg.(*structpb.Value); ok {
s = sv.GetStructValue()
} else {
s, _ = msg.(*structpb.Struct)
}
if s == nil {
return nil, nil
}
f, ok := s.Fields[indexField]
if !ok {
return nil, nil
}
obj := f.GetStructValue()
if obj == nil {
return nil, nil
}
cf, ok := obj.Fields[cidrField]
if !ok {
return nil, nil
}
c := cf.GetStringValue()
if c == "" {
return nil, nil
}
addr4, addr6, _ := patricia.ParseIPFromString(c)
return addr4, addr6
}