mirror of
https://github.com/pomerium/pomerium.git
synced 2025-04-29 02:16:28 +02:00
128 lines
2.9 KiB
Go
128 lines
2.9 KiB
Go
package mux
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"strings"
|
|
|
|
"google.golang.org/protobuf/encoding/protojson"
|
|
|
|
"github.com/pomerium/pomerium/internal/zero/apierror"
|
|
"github.com/pomerium/pomerium/pkg/zero/connect"
|
|
"github.com/rs/zerolog/log"
|
|
)
|
|
|
|
// Watch watches for changes to the config until either context is canceled,
|
|
// or an error occurs while muxing
|
|
func (svc *Mux) Watch(ctx context.Context, opts ...WatchOption) error {
|
|
select {
|
|
case <-ctx.Done():
|
|
return ctx.Err()
|
|
case <-svc.ready:
|
|
}
|
|
|
|
cfg := newConfig(opts...)
|
|
|
|
connected := svc.connected.Load()
|
|
if connected {
|
|
cfg.onConnected(ctx)
|
|
} else {
|
|
cfg.onDisconnected(ctx)
|
|
}
|
|
|
|
return svc.mux.Receive(ctx, func(ctx context.Context, msg message) error {
|
|
return dispatch(ctx, cfg, msg)
|
|
})
|
|
}
|
|
|
|
func dispatch(ctx context.Context, cfg *config, msg message) error {
|
|
switch {
|
|
case msg.stateChange != nil:
|
|
switch *msg.stateChange {
|
|
case connected:
|
|
cfg.onConnected(ctx)
|
|
case disconnected:
|
|
cfg.onDisconnected(ctx)
|
|
default:
|
|
return fmt.Errorf("unknown state change")
|
|
}
|
|
case msg.Message != nil:
|
|
switch msg.Message.Message.(type) {
|
|
case *connect.Message_ConfigUpdated:
|
|
cfg.onBundleUpdated(ctx, "config")
|
|
case *connect.Message_BootstrapConfigUpdated:
|
|
cfg.onBootstrapConfigUpdated(ctx)
|
|
case *connect.Message_TelemetryRequest:
|
|
cfg.onTelemetryRequested(ctx, msg.Message.GetTelemetryRequest())
|
|
default:
|
|
log.Ctx(ctx).Debug().Msg("unknown message type, ignored")
|
|
}
|
|
default:
|
|
return fmt.Errorf("unknown message payload")
|
|
}
|
|
return nil
|
|
}
|
|
|
|
type message struct {
|
|
*stateChange
|
|
*connect.Message
|
|
}
|
|
|
|
func (msg message) String() string {
|
|
var b strings.Builder
|
|
if msg.stateChange != nil {
|
|
b.WriteString("stateChange: ")
|
|
b.WriteString(string(*msg.stateChange))
|
|
}
|
|
if msg.Message != nil {
|
|
b.WriteString("message: ")
|
|
b.WriteString(protojson.Format(msg.Message))
|
|
}
|
|
return b.String()
|
|
}
|
|
|
|
// stateChange names a connection state transition carried in a message.
type stateChange string

// The state transitions that dispatch knows how to handle.
const (
	connected    = stateChange("connected")
	disconnected = stateChange("disconnected")
)
|
|
|
|
// Publish publishes a message to the fanout
|
|
// we treat errors returned from the fanout as terminal,
|
|
// as they are generally non recoverable
|
|
func (svc *Mux) publish(ctx context.Context, msg message) error {
|
|
err := svc.mux.Publish(ctx, msg)
|
|
if err != nil {
|
|
return apierror.NewTerminalError(err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (svc *Mux) onConnected(ctx context.Context) error {
|
|
s := connected
|
|
svc.connected.Store(true)
|
|
err := svc.publish(ctx, message{stateChange: &s})
|
|
if err != nil {
|
|
return fmt.Errorf("onConnected: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (svc *Mux) onDisconnected(ctx context.Context) error {
|
|
s := disconnected
|
|
svc.connected.Store(false)
|
|
err := svc.publish(ctx, message{stateChange: &s})
|
|
if err != nil {
|
|
return fmt.Errorf("onDisconnected: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (svc *Mux) onMessage(ctx context.Context, msg *connect.Message) error {
|
|
err := svc.publish(ctx, message{Message: msg})
|
|
if err != nil {
|
|
return fmt.Errorf("onMessage: %w", err)
|
|
}
|
|
return nil
|
|
}
|