mirror of
https://github.com/pomerium/pomerium.git
synced 2025-08-03 00:40:25 +02:00
registry: implement redis backend (#2179)
This commit is contained in:
parent
28155314e9
commit
a54d43b937
21 changed files with 772 additions and 64 deletions
254
internal/registry/redis/redis.go
Normal file
254
internal/registry/redis/redis.go
Normal file
|
@ -0,0 +1,254 @@
|
|||
// Package redis implements a registry in redis.
|
||||
package redis
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/cenkalti/backoff/v4"
|
||||
"github.com/go-redis/redis/v8"
|
||||
"google.golang.org/protobuf/proto"
|
||||
"google.golang.org/protobuf/types/known/durationpb"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/log"
|
||||
"github.com/pomerium/pomerium/internal/redisutil"
|
||||
"github.com/pomerium/pomerium/internal/registry"
|
||||
"github.com/pomerium/pomerium/internal/registry/redis/lua"
|
||||
"github.com/pomerium/pomerium/internal/signal"
|
||||
registrypb "github.com/pomerium/pomerium/pkg/grpc/registry"
|
||||
)
|
||||
|
||||
const (
|
||||
registryKey = redisutil.KeyPrefix + "registry"
|
||||
registryUpdateKey = redisutil.KeyPrefix + "registry_changed_ch"
|
||||
|
||||
pollInterval = time.Second * 30
|
||||
)
|
||||
|
||||
type impl struct {
|
||||
cfg *config
|
||||
|
||||
client redis.UniversalClient
|
||||
onChange *signal.Signal
|
||||
|
||||
closeOnce sync.Once
|
||||
closed chan struct{}
|
||||
}
|
||||
|
||||
// New creates a new registry implementation backend by redis.
|
||||
func New(rawURL string, options ...Option) (registry.Interface, error) {
|
||||
cfg := getConfig(options...)
|
||||
|
||||
client, err := redisutil.NewClientFromURL(rawURL, cfg.tls)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
i := &impl{
|
||||
cfg: cfg,
|
||||
client: client,
|
||||
onChange: signal.New(),
|
||||
closed: make(chan struct{}),
|
||||
}
|
||||
go i.listenForChanges(context.Background())
|
||||
return i, nil
|
||||
}
|
||||
|
||||
func (i *impl) Report(ctx context.Context, req *registrypb.RegisterRequest) (*registrypb.RegisterResponse, error) {
|
||||
_, err := i.runReport(ctx, req.GetServices())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return ®istrypb.RegisterResponse{
|
||||
CallBackAfter: durationpb.New(i.cfg.ttl / 2),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (i *impl) List(ctx context.Context, req *registrypb.ListRequest) (*registrypb.ServiceList, error) {
|
||||
all, err := i.runReport(ctx, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
include := map[registrypb.ServiceKind]struct{}{}
|
||||
for _, kind := range req.GetKinds() {
|
||||
include[kind] = struct{}{}
|
||||
}
|
||||
|
||||
filtered := make([]*registrypb.Service, 0, len(all))
|
||||
for _, svc := range all {
|
||||
if _, ok := include[svc.GetKind()]; !ok {
|
||||
continue
|
||||
}
|
||||
filtered = append(filtered, svc)
|
||||
}
|
||||
|
||||
sort.Slice(filtered, func(i, j int) bool {
|
||||
{
|
||||
iv, jv := filtered[i].GetKind(), filtered[j].GetKind()
|
||||
switch {
|
||||
case iv < jv:
|
||||
return true
|
||||
case jv < iv:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
iv, jv := filtered[i].GetEndpoint(), filtered[j].GetEndpoint()
|
||||
switch {
|
||||
case iv < jv:
|
||||
return true
|
||||
case jv < iv:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
})
|
||||
|
||||
return ®istrypb.ServiceList{
|
||||
Services: filtered,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (i *impl) Watch(req *registrypb.ListRequest, stream registrypb.Registry_WatchServer) error {
|
||||
// listen for changes
|
||||
ch := i.onChange.Bind()
|
||||
defer i.onChange.Unbind(ch)
|
||||
|
||||
// force a check periodically
|
||||
poll := time.NewTicker(pollInterval)
|
||||
defer poll.Stop()
|
||||
|
||||
var prev *registrypb.ServiceList
|
||||
for {
|
||||
// retrieve the most recent list of services
|
||||
lst, err := i.List(stream.Context(), req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// only send a new list if something changed
|
||||
if !proto.Equal(prev, lst) {
|
||||
err = stream.Send(lst)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
prev = lst
|
||||
|
||||
// wait for an update
|
||||
select {
|
||||
case <-i.closed:
|
||||
return nil
|
||||
case <-stream.Context().Done():
|
||||
return stream.Context().Err()
|
||||
case <-ch:
|
||||
case <-poll.C:
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (i *impl) Close() error {
|
||||
var err error
|
||||
i.closeOnce.Do(func() {
|
||||
err = i.client.Close()
|
||||
close(i.closed)
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
func (i *impl) listenForChanges(ctx context.Context) {
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
go func() {
|
||||
<-i.closed
|
||||
cancel()
|
||||
}()
|
||||
|
||||
bo := backoff.NewExponentialBackOff()
|
||||
bo.MaxElapsedTime = 0
|
||||
|
||||
outer:
|
||||
for {
|
||||
pubsub := i.client.Subscribe(ctx, registryUpdateKey)
|
||||
for {
|
||||
msg, err := pubsub.Receive(ctx)
|
||||
if err != nil {
|
||||
_ = pubsub.Close()
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-time.After(bo.NextBackOff()):
|
||||
}
|
||||
continue outer
|
||||
}
|
||||
bo.Reset()
|
||||
|
||||
switch msg.(type) {
|
||||
case *redis.Message:
|
||||
i.onChange.Broadcast(ctx)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (i *impl) runReport(ctx context.Context, updates []*registrypb.Service) ([]*registrypb.Service, error) {
|
||||
args := []interface{}{
|
||||
i.cfg.getNow().UnixNano() / int64(time.Millisecond), // current_time
|
||||
i.cfg.ttl.Milliseconds(), // ttl
|
||||
}
|
||||
for _, svc := range updates {
|
||||
args = append(args, i.getRegistryHashKey(svc))
|
||||
}
|
||||
res, err := i.client.Eval(ctx, lua.Registry, []string{registryKey}, args...).Result()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
_, err = i.client.Publish(ctx, registryUpdateKey, time.Now().Format(time.RFC3339Nano)).Result()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if values, ok := res.([]interface{}); ok {
|
||||
var all []*registrypb.Service
|
||||
for _, value := range values {
|
||||
svc, err := i.getServiceFromRegistryHashKey(fmt.Sprint(value))
|
||||
if err != nil {
|
||||
log.Warn(ctx).Err(err).Msg("redis: invalid service")
|
||||
continue
|
||||
}
|
||||
all = append(all, svc)
|
||||
}
|
||||
return all, nil
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (i *impl) getServiceFromRegistryHashKey(key string) (*registrypb.Service, error) {
|
||||
idx := strings.Index(key, "|")
|
||||
if idx == -1 {
|
||||
return nil, fmt.Errorf("redis: invalid service entry in hash: %s", key)
|
||||
}
|
||||
|
||||
svcKindStr := key[:idx]
|
||||
svcEndpointStr := key[idx+1:]
|
||||
|
||||
svcKind, ok := registrypb.ServiceKind_value[svcKindStr]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("redis: unknown service kind: %s", svcKindStr)
|
||||
}
|
||||
|
||||
svc := ®istrypb.Service{
|
||||
Kind: registrypb.ServiceKind(svcKind),
|
||||
Endpoint: svcEndpointStr,
|
||||
}
|
||||
return svc, nil
|
||||
}
|
||||
|
||||
func (i *impl) getRegistryHashKey(svc *registrypb.Service) string {
|
||||
return svc.GetKind().String() + "|" + svc.GetEndpoint()
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue