mirror of
https://github.com/pomerium/pomerium.git
synced 2025-04-29 10:26:29 +02:00
201 lines
5.2 KiB
Go
201 lines
5.2 KiB
Go
package registry
|
|
|
|
import (
|
|
"context"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/pomerium/pomerium/internal/signal"
|
|
pb "github.com/pomerium/pomerium/pkg/grpc/registry"
|
|
|
|
"github.com/golang/protobuf/ptypes"
|
|
"google.golang.org/grpc/codes"
|
|
"google.golang.org/grpc/status"
|
|
"google.golang.org/protobuf/types/known/durationpb"
|
|
"google.golang.org/protobuf/types/known/timestamppb"
|
|
)
|
|
|
|
type inMemoryServer struct {
|
|
ttl time.Duration
|
|
// onchange is used to broadcast changes to listeners
|
|
onchange *signal.Signal
|
|
|
|
// mu holds lock for regs
|
|
mu sync.RWMutex
|
|
// regs is {service,endpoint} -> expiration time mapping
|
|
regs map[inMemoryKey]*timestamppb.Timestamp
|
|
}
|
|
|
|
type inMemoryKey struct {
|
|
kind pb.ServiceKind
|
|
endpoint string
|
|
}
|
|
|
|
// NewInMemoryServer constructs a new registry tracking service that operates in RAM
|
|
// as such, it is not usable for multi-node deployment where REDIS or other alternative should be used
|
|
func NewInMemoryServer(ctx context.Context, ttl time.Duration) pb.RegistryServer {
|
|
srv := &inMemoryServer{
|
|
ttl: ttl,
|
|
regs: make(map[inMemoryKey]*timestamppb.Timestamp),
|
|
onchange: signal.New(),
|
|
}
|
|
go srv.periodicCheck(ctx)
|
|
return srv
|
|
}
|
|
|
|
func (s *inMemoryServer) periodicCheck(ctx context.Context) {
|
|
after := s.ttl * purgeAfterTTLFactor
|
|
for {
|
|
select {
|
|
case <-ctx.Done():
|
|
return
|
|
case <-time.After(after):
|
|
if s.lockAndRmExpired() {
|
|
s.onchange.Broadcast()
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// Report is periodically sent by each service to confirm it is still serving with the registry
|
|
// data is persisted with a certain TTL
|
|
func (s *inMemoryServer) Report(ctx context.Context, req *pb.RegisterRequest) (*pb.RegisterResponse, error) {
|
|
if err := req.Validate(); err != nil {
|
|
return nil, status.Error(codes.InvalidArgument, err.Error())
|
|
}
|
|
|
|
updated, err := s.lockAndReport(req.Services)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if updated {
|
|
s.onchange.Broadcast()
|
|
}
|
|
|
|
return &pb.RegisterResponse{
|
|
CallBackAfter: durationpb.New(s.ttl / callAfterTTLFactor),
|
|
}, nil
|
|
}
|
|
|
|
func (s *inMemoryServer) lockAndRmExpired() bool {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
|
|
return s.rmExpiredLocked()
|
|
}
|
|
|
|
func (s *inMemoryServer) rmExpiredLocked() bool {
|
|
now := time.Now()
|
|
removed := false
|
|
|
|
for k, expires := range s.regs {
|
|
if expires.AsTime().Before(now) {
|
|
delete(s.regs, k)
|
|
removed = true
|
|
}
|
|
}
|
|
|
|
return removed
|
|
}
|
|
|
|
// lockAndReport acquires lock, performs an update and returns current state of services
|
|
func (s *inMemoryServer) lockAndReport(services []*pb.Service) (bool, error) {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
|
|
return s.reportLocked(services)
|
|
}
|
|
|
|
// reportLocked updates registration and also returns an indication whether service list was updated
|
|
func (s *inMemoryServer) reportLocked(services []*pb.Service) (bool, error) {
|
|
expires, err := ptypes.TimestampProto(time.Now().Add(s.ttl))
|
|
if err != nil {
|
|
return false, err
|
|
}
|
|
|
|
inserted := false
|
|
for _, svc := range services {
|
|
k := inMemoryKey{kind: svc.Kind, endpoint: svc.Endpoint}
|
|
if _, present := s.regs[k]; !present {
|
|
inserted = true
|
|
}
|
|
s.regs[k] = expires
|
|
}
|
|
|
|
removed := s.rmExpiredLocked()
|
|
return inserted || removed, nil
|
|
}
|
|
|
|
// List returns current snapshot of the services known to the registry
|
|
func (s *inMemoryServer) List(ctx context.Context, req *pb.ListRequest) (*pb.ServiceList, error) {
|
|
if err := req.Validate(); err != nil {
|
|
return nil, status.Error(codes.InvalidArgument, err.Error())
|
|
}
|
|
|
|
return &pb.ServiceList{Services: s.getServices(kindsMap(req.Kinds))}, nil
|
|
}
|
|
|
|
func kindsMap(kinds []pb.ServiceKind) map[pb.ServiceKind]bool {
|
|
out := make(map[pb.ServiceKind]bool, len(kinds))
|
|
for _, k := range kinds {
|
|
out[k] = true
|
|
}
|
|
return out
|
|
}
|
|
|
|
// Watch returns a stream of updates as full snapshots
|
|
func (s *inMemoryServer) Watch(req *pb.ListRequest, srv pb.Registry_WatchServer) error {
|
|
if err := req.Validate(); err != nil {
|
|
return status.Error(codes.InvalidArgument, err.Error())
|
|
}
|
|
|
|
kinds := kindsMap(req.Kinds)
|
|
ctx := srv.Context()
|
|
|
|
updates := s.onchange.Bind()
|
|
defer s.onchange.Unbind(updates)
|
|
|
|
if err := srv.Send(&pb.ServiceList{Services: s.getServices(kinds)}); err != nil {
|
|
return status.Errorf(codes.Internal, "sending initial snapshot: %v", err)
|
|
}
|
|
|
|
for {
|
|
services, err := s.getServiceUpdates(ctx, kinds, updates)
|
|
if err != nil {
|
|
return status.Errorf(codes.Internal, "obtaining service registrations: %v", err)
|
|
}
|
|
if err := srv.Send(&pb.ServiceList{Services: services}); err != nil {
|
|
return status.Errorf(codes.Internal, "sending registration snapshot: %v", err)
|
|
}
|
|
}
|
|
}
|
|
|
|
func (s *inMemoryServer) getServiceUpdates(ctx context.Context, kinds map[pb.ServiceKind]bool, updates chan struct{}) ([]*pb.Service, error) {
|
|
select {
|
|
case <-ctx.Done():
|
|
return nil, ctx.Err()
|
|
case <-updates:
|
|
return s.getServices(kinds), nil
|
|
}
|
|
}
|
|
|
|
func (s *inMemoryServer) getServices(kinds map[pb.ServiceKind]bool) []*pb.Service {
|
|
s.mu.RLock()
|
|
defer s.mu.RUnlock()
|
|
|
|
return s.getServicesLocked(kinds)
|
|
}
|
|
|
|
func (s *inMemoryServer) getServicesLocked(kinds map[pb.ServiceKind]bool) []*pb.Service {
|
|
out := make([]*pb.Service, 0, len(s.regs))
|
|
for k := range s.regs {
|
|
if len(kinds) == 0 {
|
|
// all catch empty filter
|
|
} else if _, exists := kinds[k.kind]; !exists {
|
|
continue
|
|
}
|
|
out = append(out, &pb.Service{Kind: k.kind, Endpoint: k.endpoint})
|
|
}
|
|
return out
|
|
}
|