mirror of
https://github.com/pomerium/pomerium.git
synced 2025-07-24 20:18:13 +02:00
New tracing system (#5388)
* update tracing config definitions * new tracing system * performance improvements * only configure tracing in envoy if it is enabled in pomerium * [tracing] refactor to use custom extension for trace id editing (#5420) refactor to use custom extension for trace id editing * set default tracing sample rate to 1.0 * fix proxy service http middleware * improve some existing auth related traces * test fixes * bump envoyproxy/go-control-plane * code cleanup * test fixes * Fix missing spans for well-known endpoints * import extension apis from pomerium/envoy-custom
This commit is contained in:
parent
832742648d
commit
396c35b6b4
121 changed files with 6096 additions and 1946 deletions
|
@ -17,7 +17,6 @@ import (
|
|||
"fmt"
|
||||
"io"
|
||||
"math/big"
|
||||
"math/bits"
|
||||
"net"
|
||||
"net/url"
|
||||
"os"
|
||||
|
@ -26,25 +25,34 @@ import (
|
|||
"path/filepath"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"syscall"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/pomerium/pomerium/config"
|
||||
"github.com/pomerium/pomerium/config/envoyconfig/filemgr"
|
||||
databroker_service "github.com/pomerium/pomerium/databroker"
|
||||
"github.com/pomerium/pomerium/internal/log"
|
||||
"github.com/pomerium/pomerium/internal/telemetry/trace"
|
||||
"github.com/pomerium/pomerium/internal/testenv/envutil"
|
||||
"github.com/pomerium/pomerium/internal/testenv/values"
|
||||
"github.com/pomerium/pomerium/pkg/cmd/pomerium"
|
||||
"github.com/pomerium/pomerium/pkg/envoy"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||
"github.com/pomerium/pomerium/pkg/health"
|
||||
"github.com/pomerium/pomerium/pkg/identity/legacymanager"
|
||||
"github.com/pomerium/pomerium/pkg/identity/manager"
|
||||
"github.com/pomerium/pomerium/pkg/netutil"
|
||||
"github.com/pomerium/pomerium/pkg/slices"
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.opentelemetry.io/otel/attribute"
|
||||
"go.opentelemetry.io/otel/exporters/otlp/otlptrace"
|
||||
oteltrace "go.opentelemetry.io/otel/trace"
|
||||
"golang.org/x/sync/errgroup"
|
||||
"google.golang.org/grpc/grpclog"
|
||||
)
|
||||
|
@ -56,6 +64,7 @@ type Environment interface {
|
|||
// top-level logger scoped to this environment. It will be canceled when
|
||||
// Stop() is called, or during test cleanup.
|
||||
Context() context.Context
|
||||
Tracer() oteltrace.Tracer
|
||||
|
||||
Assert() *assert.Assertions
|
||||
Require() *require.Assertions
|
||||
|
@ -133,10 +142,29 @@ type Environment interface {
|
|||
// the Pomerium server and Envoy.
|
||||
NewLogRecorder(opts ...LogRecorderOption) *LogRecorder
|
||||
|
||||
// GetState returns the current state of the test environment.
|
||||
GetState() EnvironmentState
|
||||
|
||||
// OnStateChanged registers a callback to be invoked when the environment's
|
||||
// state changes to the given state. The callback is invoked in a separate
|
||||
// goroutine.
|
||||
OnStateChanged(state EnvironmentState, callback func())
|
||||
// state changes to the given state. Each callback is invoked in a separate
|
||||
// goroutine, but the test environment will wait for all callbacks to return
|
||||
// before continuing, after triggering the state change.
|
||||
//
|
||||
// Calling the returned stop function will prevent the callback from being
|
||||
// run. Returns true if it stopped the callback from being run, or false if
|
||||
// it already ran or is currently running.
|
||||
//
|
||||
// If the environment is already in the given state, the callback will be run
|
||||
// in a separate goroutine immediately and the returned stop function will
|
||||
// have no effect. A callback run in this way will prevent the state from
|
||||
// advancing until the callback returns.
|
||||
//
|
||||
// State changes are triggered in the following places:
|
||||
// - NotRunning->Starting: in Start(), as the first operation
|
||||
// - Starting->Running: in Start(), just before returning
|
||||
// - Running->Stopping: in Stop(), just before the env context is canceled
|
||||
// - Stopping->Stopped: in Stop(), after all tasks have completed
|
||||
OnStateChanged(state EnvironmentState, callback func()) (stop func() bool)
|
||||
}
|
||||
|
||||
type Certificate tls.Certificate
|
||||
|
@ -153,10 +181,9 @@ func (c *Certificate) SPKIHash() string {
|
|||
|
||||
type EnvironmentState uint32
|
||||
|
||||
const NotRunning EnvironmentState = 0
|
||||
|
||||
const (
|
||||
Starting EnvironmentState = 1 << iota
|
||||
NotRunning EnvironmentState = iota
|
||||
Starting
|
||||
Running
|
||||
Stopping
|
||||
Stopped
|
||||
|
@ -192,10 +219,13 @@ type environment struct {
|
|||
workspaceFolder string
|
||||
silent bool
|
||||
|
||||
ctx context.Context
|
||||
cancel context.CancelCauseFunc
|
||||
cleanupOnce sync.Once
|
||||
logWriter *log.MultiWriter
|
||||
ctx context.Context
|
||||
cancel context.CancelCauseFunc
|
||||
cleanupOnce sync.Once
|
||||
logWriter *log.MultiWriter
|
||||
tracerProvider oteltrace.TracerProvider
|
||||
tracer oteltrace.Tracer
|
||||
rootSpan oteltrace.Span
|
||||
|
||||
mods []WithCaller[Modifier]
|
||||
tasks []WithCaller[Task]
|
||||
|
@ -204,14 +234,17 @@ type environment struct {
|
|||
stateMu sync.Mutex
|
||||
state EnvironmentState
|
||||
stateChangeListeners map[EnvironmentState][]func()
|
||||
stateChangeBlockers sync.WaitGroup
|
||||
|
||||
src *configSource
|
||||
}
|
||||
|
||||
type EnvironmentOptions struct {
|
||||
debug bool
|
||||
pauseOnFailure bool
|
||||
forceSilent bool
|
||||
debug bool
|
||||
pauseOnFailure bool
|
||||
forceSilent bool
|
||||
traceDebugFlags trace.DebugFlags
|
||||
traceClient otlptrace.Client
|
||||
}
|
||||
|
||||
type EnvironmentOption func(*EnvironmentOptions)
|
||||
|
@ -249,28 +282,57 @@ func Silent(silent ...bool) EnvironmentOption {
|
|||
}
|
||||
}
|
||||
|
||||
const StandardTraceDebugFlags = trace.TrackSpanCallers |
|
||||
trace.WarnOnIncompleteSpans |
|
||||
trace.WarnOnIncompleteTraces |
|
||||
trace.WarnOnUnresolvedReferences |
|
||||
trace.LogTraceIDsOnWarn |
|
||||
trace.LogAllSpansOnWarn
|
||||
|
||||
func WithTraceDebugFlags(flags trace.DebugFlags) EnvironmentOption {
|
||||
return func(o *EnvironmentOptions) {
|
||||
o.traceDebugFlags = flags
|
||||
}
|
||||
}
|
||||
|
||||
func WithTraceClient(traceClient otlptrace.Client) EnvironmentOption {
|
||||
return func(o *EnvironmentOptions) {
|
||||
o.traceClient = traceClient
|
||||
}
|
||||
}
|
||||
|
||||
var setGrpcLoggerOnce sync.Once
|
||||
|
||||
const defaultTraceDebugFlags = trace.TrackSpanCallers | trace.TrackSpanReferences
|
||||
|
||||
var (
|
||||
flagDebug = flag.Bool("env.debug", false, "enables test environment debug logging (equivalent to Debug() option)")
|
||||
flagPauseOnFailure = flag.Bool("env.pause-on-failure", false, "enables pausing the test environment on failure (equivalent to PauseOnFailure() option)")
|
||||
flagSilent = flag.Bool("env.silent", false, "suppresses all test environment output (equivalent to Silent() option)")
|
||||
flagDebug = flag.Bool("env.debug", false, "enables test environment debug logging (equivalent to Debug() option)")
|
||||
flagPauseOnFailure = flag.Bool("env.pause-on-failure", false, "enables pausing the test environment on failure (equivalent to PauseOnFailure() option)")
|
||||
flagSilent = flag.Bool("env.silent", false, "suppresses all test environment output (equivalent to Silent() option)")
|
||||
flagTraceDebugFlags = flag.String("env.trace-debug-flags", strconv.Itoa(defaultTraceDebugFlags), "trace debug flags (equivalent to TraceDebugFlags() option)")
|
||||
)
|
||||
|
||||
func New(t testing.TB, opts ...EnvironmentOption) Environment {
|
||||
if runtime.GOOS != "linux" {
|
||||
t.Skip("test environment only supported on linux")
|
||||
addTraceDebugFlags := strings.HasPrefix(*flagTraceDebugFlags, "+")
|
||||
defaultTraceDebugFlags, err := strconv.Atoi(strings.TrimPrefix(*flagTraceDebugFlags, "+"))
|
||||
if err != nil {
|
||||
panic("malformed value for --env.trace-debug-flags: " + err.Error())
|
||||
}
|
||||
options := EnvironmentOptions{
|
||||
debug: *flagDebug,
|
||||
pauseOnFailure: *flagPauseOnFailure,
|
||||
forceSilent: *flagSilent,
|
||||
debug: *flagDebug,
|
||||
pauseOnFailure: *flagPauseOnFailure,
|
||||
forceSilent: *flagSilent,
|
||||
traceDebugFlags: trace.DebugFlags(defaultTraceDebugFlags),
|
||||
}
|
||||
options.apply(opts...)
|
||||
if testing.Short() {
|
||||
t.Helper()
|
||||
t.Skip("test environment disabled in short mode")
|
||||
}
|
||||
if addTraceDebugFlags {
|
||||
options.traceDebugFlags |= trace.DebugFlags(defaultTraceDebugFlags)
|
||||
}
|
||||
trace.UseGlobalPanicTracer()
|
||||
databroker.DebugUseFasterBackoff.Store(true)
|
||||
workspaceFolder, err := os.Getwd()
|
||||
require.NoError(t, err)
|
||||
|
@ -305,7 +367,16 @@ func New(t testing.TB, opts ...EnvironmentOption) Environment {
|
|||
})
|
||||
logger := zerolog.New(writer).With().Timestamp().Logger().Level(zerolog.DebugLevel)
|
||||
|
||||
ctx, cancel := context.WithCancelCause(logger.WithContext(context.Background()))
|
||||
ctx := trace.Options{
|
||||
DebugFlags: options.traceDebugFlags,
|
||||
RemoteClient: options.traceClient,
|
||||
}.NewContext(logger.WithContext(context.Background()))
|
||||
tracerProvider := trace.NewTracerProvider(ctx, "Test Environment")
|
||||
tracer := tracerProvider.Tracer(trace.PomeriumCoreTracer)
|
||||
ctx, span := tracer.Start(ctx, t.Name(), oteltrace.WithNewRoot())
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx, cancel := context.WithCancelCause(ctx)
|
||||
taskErrGroup, ctx := errgroup.WithContext(ctx)
|
||||
|
||||
e := &environment{
|
||||
|
@ -313,7 +384,7 @@ func New(t testing.TB, opts ...EnvironmentOption) Environment {
|
|||
t: t,
|
||||
assert: assert.New(t),
|
||||
require: require.New(t),
|
||||
tempDir: t.TempDir(),
|
||||
tempDir: tempDir(t),
|
||||
ports: Ports{
|
||||
ProxyHTTP: values.Deferred[int](),
|
||||
ProxyGRPC: values.Deferred[int](),
|
||||
|
@ -325,13 +396,18 @@ func New(t testing.TB, opts ...EnvironmentOption) Environment {
|
|||
Debug: values.Deferred[int](),
|
||||
ALPN: values.Deferred[int](),
|
||||
},
|
||||
workspaceFolder: workspaceFolder,
|
||||
silent: silent,
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
logWriter: writer,
|
||||
taskErrGroup: taskErrGroup,
|
||||
workspaceFolder: workspaceFolder,
|
||||
silent: silent,
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
tracerProvider: tracerProvider,
|
||||
tracer: tracer,
|
||||
logWriter: writer,
|
||||
taskErrGroup: taskErrGroup,
|
||||
stateChangeListeners: make(map[EnvironmentState][]func()),
|
||||
rootSpan: span,
|
||||
}
|
||||
|
||||
_, err = rand.Read(e.sharedSecret[:])
|
||||
require.NoError(t, err)
|
||||
_, err = rand.Read(e.cookieSecret[:])
|
||||
|
@ -362,11 +438,13 @@ func New(t testing.TB, opts ...EnvironmentOption) Environment {
|
|||
|
||||
func (e *environment) debugf(format string, args ...any) {
|
||||
e.t.Helper()
|
||||
if e.rootSpan.IsRecording() {
|
||||
e.rootSpan.AddEvent(fmt.Sprintf(format, args...))
|
||||
}
|
||||
if !e.debug {
|
||||
return
|
||||
}
|
||||
|
||||
e.t.Logf("\x1b[34m[debug] "+format+"\x1b[0m", args...)
|
||||
e.t.Logf("\x1b[34mDEBUG ["+e.t.Name()+"] "+format+"\x1b[0m", args...)
|
||||
}
|
||||
|
||||
type WithCaller[T any] struct {
|
||||
|
@ -394,6 +472,10 @@ func (e *environment) Context() context.Context {
|
|||
return ContextWithEnv(e.ctx, e)
|
||||
}
|
||||
|
||||
func (e *environment) Tracer() oteltrace.Tracer {
|
||||
return e.tracer
|
||||
}
|
||||
|
||||
func (e *environment) Assert() *assert.Assertions {
|
||||
return e.assert
|
||||
}
|
||||
|
@ -455,9 +537,11 @@ var ErrCauseTestCleanup = errors.New("test cleanup")
|
|||
var ErrCauseManualStop = errors.New("Stop() called")
|
||||
|
||||
func (e *environment) Start() {
|
||||
_, span := e.tracer.Start(e.Context(), "Start")
|
||||
defer span.End()
|
||||
e.debugf("Start()")
|
||||
e.advanceState(Starting)
|
||||
e.t.Cleanup(e.cleanup)
|
||||
e.t.Cleanup(e.onTestCleanup)
|
||||
e.t.Setenv("TMPDIR", e.TempDir())
|
||||
e.debugf("temp dir: %s", e.TempDir())
|
||||
|
||||
|
@ -524,8 +608,13 @@ func (e *environment) Start() {
|
|||
require.NoError(e.t, cfg.Options.Validate(), "invoking modifier resulted in an invalid configuration:\nadded by: "+mod.Caller)
|
||||
}
|
||||
|
||||
opts := []pomerium.RunOption{
|
||||
opts := []pomerium.Option{
|
||||
pomerium.WithOverrideFileManager(fileMgr),
|
||||
pomerium.WithEnvoyServerOptions(envoy.WithExitGracePeriod(30 * time.Second)),
|
||||
pomerium.WithDataBrokerServerOptions(
|
||||
databroker_service.WithManagerOptions(manager.WithLeaseTTL(1*time.Second)),
|
||||
databroker_service.WithLegacyManagerOptions(legacymanager.WithLeaseTTL(1*time.Second)),
|
||||
),
|
||||
}
|
||||
envoyBinaryPath := filepath.Join(e.workspaceFolder, fmt.Sprintf("pkg/envoy/files/envoy-%s-%s", runtime.GOOS, runtime.GOARCH))
|
||||
if envutil.EnvoyProfilerAvailable(envoyBinaryPath) {
|
||||
|
@ -556,23 +645,29 @@ func (e *environment) Start() {
|
|||
}
|
||||
if len(envVars) > 0 {
|
||||
e.debugf("adding envoy env vars: %v\n", envVars)
|
||||
opts = append(opts, pomerium.WithEnvoyServerOptions(
|
||||
envoy.WithExtraEnvVars(envVars...),
|
||||
envoy.WithExitGracePeriod(10*time.Second), // allow envoy time to flush pprof data to disk
|
||||
))
|
||||
opts = append(opts, pomerium.WithEnvoyServerOptions(envoy.WithExtraEnvVars(envVars...)))
|
||||
}
|
||||
} else {
|
||||
e.debugf("envoy profiling not available")
|
||||
}
|
||||
|
||||
return pomerium.Run(ctx, e.src, opts...)
|
||||
pom := pomerium.New(opts...)
|
||||
e.OnStateChanged(Stopping, func() {
|
||||
if err := pom.Shutdown(ctx); err != nil {
|
||||
log.Ctx(ctx).Err(err).Msg("error shutting down pomerium server")
|
||||
} else {
|
||||
e.debugf("pomerium server shut down without error")
|
||||
}
|
||||
})
|
||||
require.NoError(e.t, pom.Start(ctx, e.tracerProvider, e.src))
|
||||
return pom.Wait()
|
||||
}))
|
||||
|
||||
for i, task := range e.tasks {
|
||||
log.Ctx(e.ctx).Debug().Str("caller", task.Caller).Msgf("starting task %d", i)
|
||||
log.Ctx(e.Context()).Debug().Str("caller", task.Caller).Msgf("starting task %d", i)
|
||||
e.taskErrGroup.Go(func() error {
|
||||
defer log.Ctx(e.ctx).Debug().Str("caller", task.Caller).Msgf("task %d exited", i)
|
||||
return task.Value.Run(e.ctx)
|
||||
defer log.Ctx(e.Context()).Debug().Str("caller", task.Caller).Msgf("task %d exited", i)
|
||||
return task.Value.Run(e.Context())
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -695,14 +790,9 @@ func (e *environment) Stop() {
|
|||
b.StopTimer()
|
||||
defer b.StartTimer()
|
||||
}
|
||||
_, file, line, _ := runtime.Caller(1)
|
||||
e.cleanupOnce.Do(func() {
|
||||
e.debugf("stop: Stop() called manually")
|
||||
e.advanceState(Stopping)
|
||||
e.cancel(ErrCauseManualStop)
|
||||
err := e.taskErrGroup.Wait()
|
||||
e.advanceState(Stopped)
|
||||
e.debugf("stop: done waiting")
|
||||
assert.ErrorIs(e.t, err, ErrCauseManualStop)
|
||||
e.cleanup(fmt.Errorf("%w (caller: %s:%d)", ErrCauseManualStop, file, line))
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -714,33 +804,51 @@ func (e *environment) Pause() {
|
|||
e.t.Log("\x1b[31mctrl+c received, continuing\x1b[0m")
|
||||
}
|
||||
|
||||
func (e *environment) cleanup() {
|
||||
func (e *environment) onTestCleanup() {
|
||||
e.cleanupOnce.Do(func() {
|
||||
e.debugf("stop: test cleanup")
|
||||
if e.t.Failed() {
|
||||
if e.pauseOnFailure {
|
||||
e.t.Log("\x1b[31m*** pausing on test failure; continue with ctrl+c ***\x1b[0m")
|
||||
c := make(chan os.Signal, 1)
|
||||
signal.Notify(c, syscall.SIGINT)
|
||||
<-c
|
||||
e.t.Log("\x1b[31mctrl+c received, continuing\x1b[0m")
|
||||
signal.Stop(c)
|
||||
}
|
||||
}
|
||||
e.advanceState(Stopping)
|
||||
e.cancel(ErrCauseTestCleanup)
|
||||
err := e.taskErrGroup.Wait()
|
||||
e.advanceState(Stopped)
|
||||
e.debugf("stop: done waiting")
|
||||
assert.ErrorIs(e.t, err, ErrCauseTestCleanup)
|
||||
e.cleanup(ErrCauseTestCleanup)
|
||||
})
|
||||
}
|
||||
|
||||
func (e *environment) cleanup(cancelCause error) {
|
||||
e.debugf("stop: %s", cancelCause.Error())
|
||||
if e.t.Failed() {
|
||||
if e.pauseOnFailure {
|
||||
e.t.Log("\x1b[31m*** pausing on test failure; continue with ctrl+c ***\x1b[0m")
|
||||
c := make(chan os.Signal, 1)
|
||||
signal.Notify(c, syscall.SIGINT)
|
||||
<-c
|
||||
e.t.Log("\x1b[31mctrl+c received, continuing\x1b[0m")
|
||||
signal.Stop(c)
|
||||
}
|
||||
}
|
||||
e.advanceState(Stopping)
|
||||
e.cancel(cancelCause)
|
||||
errs := []error{}
|
||||
if err := e.taskErrGroup.Wait(); err != nil {
|
||||
errs = append(errs, fmt.Errorf("error waiting for tasks: %w", err))
|
||||
}
|
||||
e.rootSpan.End()
|
||||
if err := trace.ShutdownContext(e.Context()); err != nil {
|
||||
errs = append(errs, fmt.Errorf("error shutting down trace context: %w", err))
|
||||
}
|
||||
e.advanceState(Stopped)
|
||||
// Wait for any additional callbacks created during stopped callbacks
|
||||
// (for consistency, we consider the stopped state to "end" here)
|
||||
e.stateChangeBlockers.Wait()
|
||||
e.debugf("stop: done")
|
||||
// err can be nil if e.g. the only task is the internal pomerium task, which
|
||||
// returns a nil error if it exits cleanly
|
||||
if err := errors.Join(errs...); err != nil {
|
||||
assert.ErrorIs(e.t, err, cancelCause)
|
||||
}
|
||||
}
|
||||
|
||||
func (e *environment) Add(m Modifier) {
|
||||
e.t.Helper()
|
||||
caller := getCaller()
|
||||
e.debugf("Add: %T from %s", m, caller)
|
||||
switch e.getState() {
|
||||
switch e.GetState() {
|
||||
case NotRunning:
|
||||
for _, mod := range e.mods {
|
||||
if mod.Value == m {
|
||||
|
@ -757,11 +865,11 @@ func (e *environment) Add(m Modifier) {
|
|||
panic("test bug: cannot call Add() before Start() has returned")
|
||||
case Running:
|
||||
e.debugf("Add: state=Running; calling ModifyConfig")
|
||||
e.src.ModifyConfig(e.ctx, m)
|
||||
e.src.ModifyConfig(e.Context(), m)
|
||||
case Stopped, Stopping:
|
||||
panic("test bug: cannot call Add() after Stop()")
|
||||
default:
|
||||
panic(fmt.Sprintf("unexpected environment state: %s", e.getState()))
|
||||
panic(fmt.Sprintf("unexpected environment state: %s", e.GetState()))
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -803,34 +911,63 @@ func (e *environment) advanceState(newState EnvironmentState) {
|
|||
if newState <= e.state {
|
||||
panic(fmt.Sprintf("internal test environment bug: changed state to <= current: newState=%s, current=%s", newState, e.state))
|
||||
}
|
||||
e.stateChangeBlockers.Wait()
|
||||
e.debugf("state %s -> %s", e.state.String(), newState.String())
|
||||
e.state = newState
|
||||
e.debugf("notifying %d listeners of state change", len(e.stateChangeListeners[newState]))
|
||||
for _, listener := range e.stateChangeListeners[newState] {
|
||||
go listener()
|
||||
if len(e.stateChangeListeners[newState]) > 0 {
|
||||
e.debugf("notifying %d listeners of state change", len(e.stateChangeListeners[newState]))
|
||||
var wg sync.WaitGroup
|
||||
for _, listener := range e.stateChangeListeners[newState] {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
_, span := e.tracer.Start(e.Context(), "State Change Callback")
|
||||
span.SetAttributes(attribute.String("state", newState.String()))
|
||||
defer span.End()
|
||||
defer wg.Done()
|
||||
listener()
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
e.debugf("done notifying state change listeners")
|
||||
}
|
||||
}
|
||||
|
||||
func (e *environment) getState() EnvironmentState {
|
||||
func (e *environment) GetState() EnvironmentState {
|
||||
e.stateMu.Lock()
|
||||
defer e.stateMu.Unlock()
|
||||
return e.state
|
||||
}
|
||||
|
||||
func (e *environment) OnStateChanged(state EnvironmentState, callback func()) {
|
||||
func (e *environment) OnStateChanged(state EnvironmentState, callback func()) (cancel func() bool) {
|
||||
e.stateMu.Lock()
|
||||
defer e.stateMu.Unlock()
|
||||
|
||||
if e.state&state != 0 {
|
||||
go callback()
|
||||
return
|
||||
}
|
||||
|
||||
// add change listeners for all states, if there are multiple bits set
|
||||
for state > 0 {
|
||||
stateBit := EnvironmentState(bits.TrailingZeros32(uint32(state)))
|
||||
state &= (state - 1)
|
||||
e.stateChangeListeners[stateBit] = append(e.stateChangeListeners[stateBit], callback)
|
||||
_, file, line, _ := runtime.Caller(1)
|
||||
switch {
|
||||
case state < e.state:
|
||||
panic(fmt.Sprintf("test bug: OnStateChanged called with state %s which is < current state (%s)", state, e.sharedSecret))
|
||||
case state == e.state:
|
||||
e.stateChangeBlockers.Add(1)
|
||||
e.debugf("invoking callback for current state (state: %s, caller: %s:%d)", state.String(), file, line)
|
||||
go func() {
|
||||
defer func() {
|
||||
e.stateChangeBlockers.Done()
|
||||
}()
|
||||
callback()
|
||||
}()
|
||||
return func() bool { return false }
|
||||
default:
|
||||
canceled := &atomic.Bool{}
|
||||
e.stateChangeListeners[state] = append(e.stateChangeListeners[state], func() {
|
||||
if canceled.CompareAndSwap(false, true) {
|
||||
e.debugf("invoking state change callback (caller: %s:%d)", file, line)
|
||||
callback()
|
||||
}
|
||||
})
|
||||
return func() bool {
|
||||
e.debugf("stopped state change callback (state: %s, caller: %s:%d)", state.String(), file, line)
|
||||
return canceled.CompareAndSwap(false, true)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue