// Package enabler contains a component that can be enabled and disabled dynamically.
package enabler

import (
	"context"
	"errors"
	"sync"

	"github.com/pomerium/pomerium/internal/log"
)

// errCauseEnabler is the cancellation cause passed to the run context's cancel
// function by Enable and Disable. Run compares against it to tell a state
// toggle (keep looping) apart from every other result (stop and return it).
var errCauseEnabler = errors.New("enabler")

// A Handler is a component with a RunEnabled function.
type Handler interface {
	// RunEnabled is invoked by the enabler while it is enabled. The enabler
	// cancels ctx when it is disabled, and may call RunEnabled again after
	// being re-enabled.
	RunEnabled(ctx context.Context) error
}

// HandlerFunc is a function run by the enabler. Its RunEnabled method lets a
// plain function be used wherever a Handler is expected.
type HandlerFunc func(ctx context.Context) error

// RunEnabled calls f(ctx) and returns its result.
func (f HandlerFunc) RunEnabled(ctx context.Context) error {
	return f(ctx)
}

// An Enabler enables or disables a component dynamically.
// When the Enabler is enabled, the Handler's RunEnabled will be called.
// If the Enabler is subsequently disabled, the context passed to RunEnabled will be canceled.
// If the Enabler is subsequently enabled again, RunEnabled will be called again.
// Handlers should obey the context lifetime and be tolerant of RunEnabled
// being called multiple times (though not concurrently).
type Enabler interface {
	// Run calls the Handler's RunEnabled if enabled, otherwise it waits until
	// enabled. In the implementation returned by New it keeps looping across
	// Enable/Disable toggles and returns the first result that was not caused
	// by such a toggle.
	Run(ctx context.Context) error
	// Enable switches the Enabler to the enabled state.
	Enable()
	// Disable switches the Enabler to the disabled state.
	Disable()
}

// enabler is the Enabler implementation returned by New.
//
// name is used in the "enabled"/"disabled" log messages and handler is the
// component that gets run. mu guards the two mutable fields: cancel, which
// cancels the context of the current run-or-wait iteration, and enabled, the
// desired state as last set by New, Enable, or Disable.
type enabler struct {
	name    string
	handler Handler

	mu      sync.Mutex
	cancel  context.CancelCauseFunc
	enabled bool
}

// New creates a new Enabler that runs handler and starts out in the given
// enabled state. The name is only used when logging state changes.
//
// The cancel function is initialized to a no-op so that Enable and Disable can
// be called safely before Run has set up a real context.
func New(name string, handler Handler, enabled bool) Enabler {
	return &enabler{
		name:    name,
		handler: handler,
		cancel:  func(error) {},
		enabled: enabled,
	}
}

// Run calls RunEnabled if enabled, otherwise it waits until enabled.
//
// Each iteration either runs the handler or waits for a state change. As long
// as an iteration ends with errCauseEnabler (i.e. Enable or Disable triggered
// it), another iteration is started; any other result, including nil, is
// returned to the caller.
func (d *enabler) Run(ctx context.Context) error {
	result := d.runOrWaitForEnabled(ctx)
	for errors.Is(result, errCauseEnabler) {
		result = d.runOrWaitForEnabled(ctx)
	}
	return result
}

// runOrWaitForEnabled performs a single iteration of the Run loop.
//
// It derives a cancelable context from ctx and publishes its cancel function
// as d.cancel so that Enable and Disable can interrupt the iteration with
// errCauseEnabler. If the enabler was enabled when the iteration started, the
// handler's RunEnabled is called with that context; otherwise the call blocks
// until the context is canceled, either by Enable or through the parent ctx.
//
// It returns errCauseEnabler when the iteration was ended by a state toggle.
// Otherwise it returns the handler's result, or the cancellation cause of the
// context when it was waiting.
func (d *enabler) runOrWaitForEnabled(ctx context.Context) error {
	d.mu.Lock()
	enabled := d.enabled
	ctx, cancel := context.WithCancelCause(ctx)
	d.cancel = cancel
	d.mu.Unlock()
	// Release the derived context when this iteration ends. Without this, a
	// handler that returns on its own would leave the context registered with
	// its parent until the parent is canceled. Canceling an already-canceled
	// context is a no-op, and the return values below are evaluated before
	// this deferred call runs, so the reported cause is unaffected.
	defer cancel(nil)

	// we're enabled so call RunEnabled. If Disable is called it will cancel ctx.
	if enabled {
		log.Ctx(ctx).Info().Msgf("enabled %s", d.name)
		err := d.handler.RunEnabled(ctx)
		// if RunEnabled stopped because we canceled the context
		if errors.Is(err, context.Canceled) && errors.Is(context.Cause(ctx), errCauseEnabler) {
			log.Ctx(ctx).Info().Msgf("disabled %s", d.name)
			return errCauseEnabler
		}
		return err
	}

	// wait until Enable is called (or the parent context is canceled)
	<-ctx.Done()
	return context.Cause(ctx)
}

// Enable marks the enabler as enabled. If it was previously disabled, the
// context of the current iteration is canceled with errCauseEnabler, which
// wakes up a waiting Run so it can start the handler. Calling Enable while
// already enabled does nothing.
func (d *enabler) Enable() {
	d.mu.Lock()
	defer d.mu.Unlock()

	if d.enabled {
		return
	}
	d.enabled = true
	d.cancel(errCauseEnabler)
}

// Disable marks the enabler as disabled. If it was previously enabled, the
// context of the current iteration is canceled with errCauseEnabler, which
// asks a running handler to stop. Calling Disable while already disabled does
// nothing.
func (d *enabler) Disable() {
	d.mu.Lock()
	defer d.mu.Unlock()

	if !d.enabled {
		return
	}
	d.enabled = false
	d.cancel(errCauseEnabler)
}