mirror of
https://github.com/pomerium/pomerium.git
synced 2025-08-02 16:30:17 +02:00
zero: only leave public packages in pkg/zero (#4854)
This commit is contained in:
parent
a6ae9d3f2d
commit
b66634d1e6
24 changed files with 22 additions and 22 deletions
57
internal/zero/connect-mux/config.go
Normal file
57
internal/zero/connect-mux/config.go
Normal file
|
@ -0,0 +1,57 @@
|
|||
package mux
|
||||
|
||||
import "context"
|
||||
|
||||
type config struct {
|
||||
onConnected func(ctx context.Context)
|
||||
onDisconnected func(ctx context.Context)
|
||||
onBundleUpdated func(ctx context.Context, key string)
|
||||
onBootstrapConfigUpdated func(ctx context.Context)
|
||||
}
|
||||
|
||||
// WatchOption allows to specify callbacks for various events
|
||||
type WatchOption func(*config)
|
||||
|
||||
// WithOnConnected sets the callback for when the connection is established
|
||||
func WithOnConnected(onConnected func(context.Context)) WatchOption {
|
||||
return func(cfg *config) {
|
||||
cfg.onConnected = onConnected
|
||||
}
|
||||
}
|
||||
|
||||
// WithOnDisconnected sets the callback for when the connection is lost
|
||||
func WithOnDisconnected(onDisconnected func(context.Context)) WatchOption {
|
||||
return func(cfg *config) {
|
||||
cfg.onDisconnected = onDisconnected
|
||||
}
|
||||
}
|
||||
|
||||
// WithOnBundleUpdated sets the callback for when the bundle is updated
|
||||
func WithOnBundleUpdated(onBundleUpdated func(ctx context.Context, key string)) WatchOption {
|
||||
return func(cfg *config) {
|
||||
cfg.onBundleUpdated = onBundleUpdated
|
||||
}
|
||||
}
|
||||
|
||||
// WithOnBootstrapConfigUpdated sets the callback for when the bootstrap config is updated
|
||||
func WithOnBootstrapConfigUpdated(onBootstrapConfigUpdated func(context.Context)) WatchOption {
|
||||
return func(cfg *config) {
|
||||
cfg.onBootstrapConfigUpdated = onBootstrapConfigUpdated
|
||||
}
|
||||
}
|
||||
|
||||
func newConfig(opts ...WatchOption) *config {
|
||||
cfg := &config{}
|
||||
for _, opt := range []WatchOption{
|
||||
WithOnConnected(func(_ context.Context) {}),
|
||||
WithOnDisconnected(func(_ context.Context) {}),
|
||||
WithOnBundleUpdated(func(_ context.Context, key string) {}),
|
||||
WithOnBootstrapConfigUpdated(func(_ context.Context) {}),
|
||||
} {
|
||||
opt(cfg)
|
||||
}
|
||||
for _, opt := range opts {
|
||||
opt(cfg)
|
||||
}
|
||||
return cfg
|
||||
}
|
109
internal/zero/connect-mux/messages.go
Normal file
109
internal/zero/connect-mux/messages.go
Normal file
|
@ -0,0 +1,109 @@
|
|||
package mux
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/zero/apierror"
|
||||
"github.com/pomerium/pomerium/pkg/zero/connect"
|
||||
)
|
||||
|
||||
// Watch watches for changes to the config until either context is canceled,
|
||||
// or an error occurs while muxing
|
||||
func (svc *Mux) Watch(ctx context.Context, opts ...WatchOption) error {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-svc.ready:
|
||||
}
|
||||
|
||||
cfg := newConfig(opts...)
|
||||
|
||||
connected := svc.connected.Load()
|
||||
if connected {
|
||||
cfg.onConnected(ctx)
|
||||
} else {
|
||||
cfg.onDisconnected(ctx)
|
||||
}
|
||||
|
||||
return svc.mux.Receive(ctx, func(ctx context.Context, msg message) error {
|
||||
return dispatch(ctx, cfg, msg)
|
||||
})
|
||||
}
|
||||
|
||||
func dispatch(ctx context.Context, cfg *config, msg message) error {
|
||||
switch {
|
||||
case msg.stateChange != nil:
|
||||
switch *msg.stateChange {
|
||||
case connected:
|
||||
cfg.onConnected(ctx)
|
||||
case disconnected:
|
||||
cfg.onDisconnected(ctx)
|
||||
default:
|
||||
return fmt.Errorf("unknown state change")
|
||||
}
|
||||
case msg.Message != nil:
|
||||
switch msg.Message.Message.(type) {
|
||||
case *connect.Message_ConfigUpdated:
|
||||
cfg.onBundleUpdated(ctx, "config")
|
||||
case *connect.Message_BootstrapConfigUpdated:
|
||||
cfg.onBootstrapConfigUpdated(ctx)
|
||||
default:
|
||||
return fmt.Errorf("unknown message type")
|
||||
}
|
||||
default:
|
||||
return fmt.Errorf("unknown message payload")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type message struct {
|
||||
*stateChange
|
||||
*connect.Message
|
||||
}
|
||||
|
||||
type stateChange string
|
||||
|
||||
const (
|
||||
connected stateChange = "connected"
|
||||
disconnected stateChange = "disconnected"
|
||||
)
|
||||
|
||||
// Publish publishes a message to the fanout
|
||||
// we treat errors returned from the fanout as terminal,
|
||||
// as they are generally non recoverable
|
||||
func (svc *Mux) publish(ctx context.Context, msg message) error {
|
||||
err := svc.mux.Publish(ctx, msg)
|
||||
if err != nil {
|
||||
return apierror.NewTerminalError(err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (svc *Mux) onConnected(ctx context.Context) error {
|
||||
s := connected
|
||||
svc.connected.Store(true)
|
||||
err := svc.publish(ctx, message{stateChange: &s})
|
||||
if err != nil {
|
||||
return fmt.Errorf("onConnected: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (svc *Mux) onDisconnected(ctx context.Context) error {
|
||||
s := disconnected
|
||||
svc.connected.Store(false)
|
||||
err := svc.publish(ctx, message{stateChange: &s})
|
||||
if err != nil {
|
||||
return fmt.Errorf("onDisconnected: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (svc *Mux) onMessage(ctx context.Context, msg *connect.Message) error {
|
||||
err := svc.publish(ctx, message{Message: msg})
|
||||
if err != nil {
|
||||
return fmt.Errorf("onMessage: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
108
internal/zero/connect-mux/service.go
Normal file
108
internal/zero/connect-mux/service.go
Normal file
|
@ -0,0 +1,108 @@
|
|||
// Package mux provides the way to listen for updates from the cloud
|
||||
package mux
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/cenkalti/backoff/v4"
|
||||
"github.com/hashicorp/go-multierror"
|
||||
"github.com/rs/zerolog/log"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/zero/apierror"
|
||||
"github.com/pomerium/pomerium/pkg/fanout"
|
||||
"github.com/pomerium/pomerium/pkg/zero/connect"
|
||||
)
|
||||
|
||||
// Mux is the service that listens for updates from the cloud
|
||||
type Mux struct {
|
||||
client connect.ConnectClient
|
||||
mux *fanout.FanOut[message]
|
||||
|
||||
ready chan struct{}
|
||||
|
||||
connected atomic.Bool
|
||||
}
|
||||
|
||||
// New creates a new mux service that listens for updates from the cloud
|
||||
func New(client connect.ConnectClient) *Mux {
|
||||
svc := &Mux{
|
||||
client: client,
|
||||
ready: make(chan struct{}),
|
||||
}
|
||||
return svc
|
||||
}
|
||||
|
||||
// Run starts the service
|
||||
func (svc *Mux) Run(ctx context.Context, opts ...fanout.Option) error {
|
||||
ctx, cancel := context.WithCancelCause(ctx)
|
||||
defer func() { cancel(ctx.Err()) }()
|
||||
|
||||
svc.mux = fanout.Start[message](ctx, opts...)
|
||||
close(svc.ready)
|
||||
|
||||
err := svc.run(ctx)
|
||||
if err != nil {
|
||||
cancel(err)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (svc *Mux) run(ctx context.Context) error {
|
||||
bo := backoff.NewExponentialBackOff()
|
||||
bo.MaxElapsedTime = 0
|
||||
|
||||
ticker := time.NewTicker(time.Microsecond)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-ticker.C:
|
||||
}
|
||||
|
||||
err := svc.subscribeAndDispatch(ctx, bo.Reset)
|
||||
if err != nil {
|
||||
ticker.Reset(bo.NextBackOff())
|
||||
}
|
||||
|
||||
if apierror.IsTerminalError(err) {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (svc *Mux) subscribeAndDispatch(ctx context.Context, onConnected func()) (err error) {
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
|
||||
stream, err := svc.client.Subscribe(ctx, &connect.SubscribeRequest{})
|
||||
if err != nil {
|
||||
return fmt.Errorf("subscribe: %w", err)
|
||||
}
|
||||
onConnected()
|
||||
|
||||
if err = svc.onConnected(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() {
|
||||
err = multierror.Append(err, svc.onDisconnected(ctx)).ErrorOrNil()
|
||||
}()
|
||||
|
||||
log.Ctx(ctx).Info().Msg("subscribed to connect service")
|
||||
for {
|
||||
msg, err := stream.Recv()
|
||||
log.Ctx(ctx).Info().Interface("msg", msg).Err(err).Msg("receive")
|
||||
if err != nil {
|
||||
return fmt.Errorf("receive: %w", err)
|
||||
}
|
||||
err = svc.onMessage(ctx, msg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue