// Package redis implements a registry in redis. package redis import ( "context" "fmt" "sort" "strings" "sync" "time" "github.com/cenkalti/backoff/v4" "github.com/go-redis/redis/v8" "google.golang.org/protobuf/proto" "google.golang.org/protobuf/types/known/durationpb" "github.com/pomerium/pomerium/internal/log" "github.com/pomerium/pomerium/internal/redisutil" "github.com/pomerium/pomerium/internal/registry" "github.com/pomerium/pomerium/internal/registry/redis/lua" "github.com/pomerium/pomerium/internal/signal" registrypb "github.com/pomerium/pomerium/pkg/grpc/registry" ) const ( registryKey = redisutil.KeyPrefix + "registry" registryUpdateKey = redisutil.KeyPrefix + "registry_changed_ch" pollInterval = time.Second * 30 ) type impl struct { cfg *config client redis.UniversalClient onChange *signal.Signal closeOnce sync.Once closed chan struct{} } // New creates a new registry implementation backend by redis. func New(rawURL string, options ...Option) (registry.Interface, error) { cfg := getConfig(options...) client, err := redisutil.NewClientFromURL(rawURL, cfg.tls) if err != nil { return nil, err } i := &impl{ cfg: cfg, client: client, onChange: signal.New(), closed: make(chan struct{}), } go i.listenForChanges(context.Background()) return i, nil } func (i *impl) Report(ctx context.Context, req *registrypb.RegisterRequest) (*registrypb.RegisterResponse, error) { _, err := i.runReport(ctx, req.GetServices()) if err != nil { return nil, err } return ®istrypb.RegisterResponse{ CallBackAfter: durationpb.New(i.cfg.ttl / 2), }, nil } func (i *impl) List(ctx context.Context, req *registrypb.ListRequest) (*registrypb.ServiceList, error) { all, err := i.runReport(ctx, nil) if err != nil { return nil, err } include := map[registrypb.ServiceKind]struct{}{} for _, kind := range req.GetKinds() { include[kind] = struct{}{} } filtered := make([]*registrypb.Service, 0, len(all)) for _, svc := range all { if _, ok := include[svc.GetKind()]; !ok { continue } filtered = append(filtered, svc) } sort.Slice(filtered, func(i, j int) bool { { iv, jv := filtered[i].GetKind(), filtered[j].GetKind() switch { case iv < jv: return true case jv < iv: return false } } { iv, jv := filtered[i].GetEndpoint(), filtered[j].GetEndpoint() switch { case iv < jv: return true case jv < iv: return false } } return false }) return ®istrypb.ServiceList{ Services: filtered, }, nil } func (i *impl) Watch(req *registrypb.ListRequest, stream registrypb.Registry_WatchServer) error { // listen for changes ch := i.onChange.Bind() defer i.onChange.Unbind(ch) // force a check periodically poll := time.NewTicker(pollInterval) defer poll.Stop() var prev *registrypb.ServiceList for { // retrieve the most recent list of services lst, err := i.List(stream.Context(), req) if err != nil { return err } // only send a new list if something changed if !proto.Equal(prev, lst) { err = stream.Send(lst) if err != nil { return err } } prev = lst // wait for an update select { case <-i.closed: return nil case <-stream.Context().Done(): return stream.Context().Err() case <-ch: case <-poll.C: } } } func (i *impl) Close() error { var err error i.closeOnce.Do(func() { err = i.client.Close() close(i.closed) }) return err } func (i *impl) listenForChanges(ctx context.Context) { ctx, cancel := context.WithCancel(ctx) go func() { <-i.closed cancel() }() bo := backoff.NewExponentialBackOff() bo.MaxElapsedTime = 0 outer: for { pubsub := i.client.Subscribe(ctx, registryUpdateKey) for { msg, err := pubsub.Receive(ctx) if err != nil { _ = pubsub.Close() select { case <-ctx.Done(): return case <-time.After(bo.NextBackOff()): } continue outer } bo.Reset() switch msg.(type) { case *redis.Message: i.onChange.Broadcast(ctx) } } } } func (i *impl) runReport(ctx context.Context, updates []*registrypb.Service) ([]*registrypb.Service, error) { args := []interface{}{ i.cfg.getNow().UnixNano() / int64(time.Millisecond), // current_time i.cfg.ttl.Milliseconds(), // ttl } for _, svc := range updates { args = append(args, i.getRegistryHashKey(svc)) } res, err := i.client.Eval(ctx, lua.Registry, []string{registryKey, registryUpdateKey}, args...).Result() if err != nil { return nil, err } if values, ok := res.([]interface{}); ok { var all []*registrypb.Service for _, value := range values { svc, err := i.getServiceFromRegistryHashKey(fmt.Sprint(value)) if err != nil { log.Warn(ctx).Err(err).Msg("redis: invalid service") continue } all = append(all, svc) } return all, nil } return nil, nil } func (i *impl) getServiceFromRegistryHashKey(key string) (*registrypb.Service, error) { idx := strings.Index(key, "|") if idx == -1 { return nil, fmt.Errorf("redis: invalid service entry in hash: %s", key) } svcKindStr := key[:idx] svcEndpointStr := key[idx+1:] svcKind, ok := registrypb.ServiceKind_value[svcKindStr] if !ok { return nil, fmt.Errorf("redis: unknown service kind: %s", svcKindStr) } svc := ®istrypb.Service{ Kind: registrypb.ServiceKind(svcKind), Endpoint: svcEndpointStr, } return svc, nil } func (i *impl) getRegistryHashKey(svc *registrypb.Service) string { return svc.GetKind().String() + "|" + svc.GetEndpoint() }