pomerium/pkg/storage/postgres/registry.go
Caleb Doxsey 0cfb1025db
core/proto: update protoc dependencies (#5218)
* core/proto: update protoc dependencies

* cleanup

* disable unimplemented forward compatibility check

* fix mock

* add generate make command

* add .0
2024-08-15 11:12:05 -06:00

110 lines
2.2 KiB
Go

package postgres
import (
"context"
"time"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/known/durationpb"
"github.com/pomerium/pomerium/internal/sets"
"github.com/pomerium/pomerium/pkg/grpc/registry"
)
// registryServer serves the registry gRPC API backed by a postgres Backend.
//
// Embedding registry.UnimplementedRegistryServer provides fallback methods
// for any RPCs not defined in this file. Presumably this is the generated
// stub that answers them with codes.Unimplemented; confirm against the
// generated code.
// Embedding *Backend gives the methods below direct access to backend.init,
// backend.cfg and backend.onServiceChange.
type registryServer struct {
	registry.UnimplementedRegistryServer
	*Backend
}
// RegistryServer wraps the backend in a registryServer value and returns it
// as a registry.RegistryServer.
func (backend *Backend) RegistryServer() registry.RegistryServer {
	var srv registryServer
	srv.Backend = backend
	return srv
}
// List returns the stored services whose kind appears in req.GetKinds().
// When req names no kinds, every stored service is returned. If nothing
// matches, the returned list has a nil Services slice.
func (backend registryServer) List(
	ctx context.Context,
	req *registry.ListRequest,
) (*registry.ServiceList, error) {
	_, pool, err := backend.init(ctx)
	if err != nil {
		return nil, err
	}

	services, err := listServices(ctx, pool)
	if err != nil {
		return nil, err
	}

	wanted := sets.NewHash[registry.ServiceKind]()
	wanted.Add(req.GetKinds()...)
	// The filter set is never modified below, so decide "match everything" once.
	matchAll := wanted.Size() == 0

	out := new(registry.ServiceList)
	for _, svc := range services {
		if !matchAll && !wanted.Has(svc.GetKind()) {
			continue
		}
		out.Services = append(out.Services, svc)
	}
	return out, nil
}
// Report stores each service in req via putService with an expiry of
// registryTTL from the moment it is written. It then calls
// signalServiceChange and tells the caller to report again after half the
// TTL. The first error from init, putService or signalServiceChange is
// returned unchanged.
func (backend registryServer) Report(
	ctx context.Context,
	req *registry.RegisterRequest,
) (*registry.RegisterResponse, error) {
	_, pool, err := backend.init(ctx)
	if err != nil {
		return nil, err
	}

	ttl := backend.cfg.registryTTL
	for _, svc := range req.GetServices() {
		// Each service's expiry is taken from the clock at the time it is stored.
		expiresAt := time.Now().Add(ttl)
		if err := putService(ctx, pool, svc, expiresAt); err != nil {
			return nil, err
		}
	}

	if err := signalServiceChange(ctx, pool); err != nil {
		return nil, err
	}

	return &registry.RegisterResponse{
		CallBackAfter: durationpb.New(ttl / 2),
	}, nil
}
// Watch streams the service list matching req to srv.
//
// The current list is sent immediately. After that, the list is re-queried
// whenever onServiceChange signals or watchPollInterval elapses. It is sent
// again only when it differs (per proto.Equal) from the last list sent.
// Watch returns the stream context's error once that context is done, or the
// first error from List or Send.
func (backend registryServer) Watch(
	req *registry.ListRequest,
	srv registry.Registry_WatchServer,
) error {
	changed := backend.onServiceChange.Bind()
	defer backend.onServiceChange.Unbind(changed)

	poll := time.NewTicker(watchPollInterval)
	defer poll.Stop()

	var last *registry.ServiceList
	sent := false
	for {
		current, err := backend.List(srv.Context(), req)
		if err != nil {
			return err
		}

		// Always send the first snapshot; afterwards only send real changes.
		if !sent || !proto.Equal(current, last) {
			if err := srv.Send(current); err != nil {
				return err
			}
			last, sent = current, true
		}

		select {
		case <-srv.Context().Done():
			return srv.Context().Err()
		case <-changed:
		case <-poll.C:
		}
	}
}