zero: only leave public packages in pkg/zero (#4854)

This commit is contained in:
Denis Mishin 2023-12-12 14:24:37 -05:00 committed by GitHub
parent a6ae9d3f2d
commit b66634d1e6
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
24 changed files with 22 additions and 22 deletions

View file

@ -0,0 +1,57 @@
package mux
import "context"
type config struct {
onConnected func(ctx context.Context)
onDisconnected func(ctx context.Context)
onBundleUpdated func(ctx context.Context, key string)
onBootstrapConfigUpdated func(ctx context.Context)
}
// WatchOption allows to specify callbacks for various events
type WatchOption func(*config)
// WithOnConnected sets the callback for when the connection is established
func WithOnConnected(onConnected func(context.Context)) WatchOption {
return func(cfg *config) {
cfg.onConnected = onConnected
}
}
// WithOnDisconnected sets the callback for when the connection is lost
func WithOnDisconnected(onDisconnected func(context.Context)) WatchOption {
return func(cfg *config) {
cfg.onDisconnected = onDisconnected
}
}
// WithOnBundleUpdated sets the callback for when the bundle is updated
func WithOnBundleUpdated(onBundleUpdated func(ctx context.Context, key string)) WatchOption {
return func(cfg *config) {
cfg.onBundleUpdated = onBundleUpdated
}
}
// WithOnBootstrapConfigUpdated sets the callback for when the bootstrap config is updated
func WithOnBootstrapConfigUpdated(onBootstrapConfigUpdated func(context.Context)) WatchOption {
return func(cfg *config) {
cfg.onBootstrapConfigUpdated = onBootstrapConfigUpdated
}
}
func newConfig(opts ...WatchOption) *config {
cfg := &config{}
for _, opt := range []WatchOption{
WithOnConnected(func(_ context.Context) {}),
WithOnDisconnected(func(_ context.Context) {}),
WithOnBundleUpdated(func(_ context.Context, key string) {}),
WithOnBootstrapConfigUpdated(func(_ context.Context) {}),
} {
opt(cfg)
}
for _, opt := range opts {
opt(cfg)
}
return cfg
}

View file

@ -0,0 +1,109 @@
package mux
import (
"context"
"fmt"
"github.com/pomerium/pomerium/internal/zero/apierror"
"github.com/pomerium/pomerium/pkg/zero/connect"
)
// Watch watches for changes to the config until either context is canceled,
// or an error occurs while muxing
func (svc *Mux) Watch(ctx context.Context, opts ...WatchOption) error {
select {
case <-ctx.Done():
return ctx.Err()
case <-svc.ready:
}
cfg := newConfig(opts...)
connected := svc.connected.Load()
if connected {
cfg.onConnected(ctx)
} else {
cfg.onDisconnected(ctx)
}
return svc.mux.Receive(ctx, func(ctx context.Context, msg message) error {
return dispatch(ctx, cfg, msg)
})
}
func dispatch(ctx context.Context, cfg *config, msg message) error {
switch {
case msg.stateChange != nil:
switch *msg.stateChange {
case connected:
cfg.onConnected(ctx)
case disconnected:
cfg.onDisconnected(ctx)
default:
return fmt.Errorf("unknown state change")
}
case msg.Message != nil:
switch msg.Message.Message.(type) {
case *connect.Message_ConfigUpdated:
cfg.onBundleUpdated(ctx, "config")
case *connect.Message_BootstrapConfigUpdated:
cfg.onBootstrapConfigUpdated(ctx)
default:
return fmt.Errorf("unknown message type")
}
default:
return fmt.Errorf("unknown message payload")
}
return nil
}
type message struct {
*stateChange
*connect.Message
}
type stateChange string
const (
connected stateChange = "connected"
disconnected stateChange = "disconnected"
)
// Publish publishes a message to the fanout
// we treat errors returned from the fanout as terminal,
// as they are generally non recoverable
func (svc *Mux) publish(ctx context.Context, msg message) error {
err := svc.mux.Publish(ctx, msg)
if err != nil {
return apierror.NewTerminalError(err)
}
return nil
}
func (svc *Mux) onConnected(ctx context.Context) error {
s := connected
svc.connected.Store(true)
err := svc.publish(ctx, message{stateChange: &s})
if err != nil {
return fmt.Errorf("onConnected: %w", err)
}
return nil
}
func (svc *Mux) onDisconnected(ctx context.Context) error {
s := disconnected
svc.connected.Store(false)
err := svc.publish(ctx, message{stateChange: &s})
if err != nil {
return fmt.Errorf("onDisconnected: %w", err)
}
return nil
}
func (svc *Mux) onMessage(ctx context.Context, msg *connect.Message) error {
err := svc.publish(ctx, message{Message: msg})
if err != nil {
return fmt.Errorf("onMessage: %w", err)
}
return nil
}

View file

@ -0,0 +1,108 @@
// Package mux provides the way to listen for updates from the cloud
package mux
import (
"context"
"fmt"
"sync/atomic"
"time"
"github.com/cenkalti/backoff/v4"
"github.com/hashicorp/go-multierror"
"github.com/rs/zerolog/log"
"github.com/pomerium/pomerium/internal/zero/apierror"
"github.com/pomerium/pomerium/pkg/fanout"
"github.com/pomerium/pomerium/pkg/zero/connect"
)
// Mux is the service that listens for updates from the cloud
type Mux struct {
client connect.ConnectClient
mux *fanout.FanOut[message]
ready chan struct{}
connected atomic.Bool
}
// New creates a new mux service that listens for updates from the cloud
func New(client connect.ConnectClient) *Mux {
svc := &Mux{
client: client,
ready: make(chan struct{}),
}
return svc
}
// Run starts the service
func (svc *Mux) Run(ctx context.Context, opts ...fanout.Option) error {
ctx, cancel := context.WithCancelCause(ctx)
defer func() { cancel(ctx.Err()) }()
svc.mux = fanout.Start[message](ctx, opts...)
close(svc.ready)
err := svc.run(ctx)
if err != nil {
cancel(err)
return err
}
return nil
}
func (svc *Mux) run(ctx context.Context) error {
bo := backoff.NewExponentialBackOff()
bo.MaxElapsedTime = 0
ticker := time.NewTicker(time.Microsecond)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return ctx.Err()
case <-ticker.C:
}
err := svc.subscribeAndDispatch(ctx, bo.Reset)
if err != nil {
ticker.Reset(bo.NextBackOff())
}
if apierror.IsTerminalError(err) {
return err
}
}
}
func (svc *Mux) subscribeAndDispatch(ctx context.Context, onConnected func()) (err error) {
ctx, cancel := context.WithCancel(ctx)
defer cancel()
stream, err := svc.client.Subscribe(ctx, &connect.SubscribeRequest{})
if err != nil {
return fmt.Errorf("subscribe: %w", err)
}
onConnected()
if err = svc.onConnected(ctx); err != nil {
return err
}
defer func() {
err = multierror.Append(err, svc.onDisconnected(ctx)).ErrorOrNil()
}()
log.Ctx(ctx).Info().Msg("subscribed to connect service")
for {
msg, err := stream.Recv()
log.Ctx(ctx).Info().Interface("msg", msg).Err(err).Msg("receive")
if err != nil {
return fmt.Errorf("receive: %w", err)
}
err = svc.onMessage(ctx, msg)
if err != nil {
return err
}
}
}