package leaser_test

import (
	"context"
	"errors"
	"fmt"
	"testing"

	"github.com/stretchr/testify/assert"

	"github.com/pomerium/pomerium/internal/zero/leaser"
)

// TestRestart exercises leaser.RunWithRestart across four scenarios:
//   - execFn fails: the error is returned and execFn is not re-run;
//   - execFn succeeds: nil is returned and execFn is not re-run;
//   - the parent context is canceled: context.Canceled is returned;
//   - restartFn returns an error: execFn is run a second time, and that
//     second run's error is returned.
//
// Every scenario is repeated several times as independent parallel
// subtests so that scheduling-dependent failures have a chance to show up.
func TestRestart(t *testing.T) {
	t.Parallel()

	// blockUntilCanceled is a restart function that never asks for a
	// restart by itself: it waits for its context to finish and reports why.
	// It captures no state, so sharing it across parallel subtests is safe.
	blockUntilCanceled := func(ctx context.Context) error {
		<-ctx.Done()
		return ctx.Err()
	}

	const iterations = 20
	for iter := 0; iter < iterations; iter++ {
		t.Run(fmt.Sprintf("quit on error %d", iter), func(t *testing.T) {
			t.Parallel()

			errExpected := errors.New("execFn error")
			calls := 0
			execFn := func(context.Context) error {
				calls++
				if calls != 1 {
					return errors.New("execFn should not be called more than once")
				}
				return errExpected
			}

			err := leaser.RunWithRestart(context.Background(), execFn, blockUntilCanceled)
			assert.ErrorIs(t, err, errExpected)
		})

		t.Run(fmt.Sprintf("quit on no error %d", iter), func(t *testing.T) {
			t.Parallel()

			calls := 0
			execFn := func(context.Context) error {
				calls++
				if calls > 1 {
					return errors.New("execFn should not be called more than once")
				}
				return nil
			}

			err := leaser.RunWithRestart(context.Background(), execFn, blockUntilCanceled)
			assert.NoError(t, err)
		})

		t.Run(fmt.Sprintf("parent context canceled %d", iter), func(t *testing.T) {
			t.Parallel()

			ctx, cancel := context.WithCancel(context.Background())
			t.Cleanup(cancel)

			// restartEntered is closed once restartFn is running, so execFn
			// cancels the parent only after both functions have started.
			restartEntered := make(chan struct{})
			execFn := func(context.Context) error {
				<-restartEntered
				cancel()
				// Deliberately report the parent (outer) context's error.
				return ctx.Err()
			}
			restartFn := func(context.Context) error {
				close(restartEntered)
				<-ctx.Done()
				return ctx.Err()
			}

			err := leaser.RunWithRestart(ctx, execFn, restartFn)
			assert.ErrorIs(t, err, context.Canceled)
		})

		t.Run(fmt.Sprintf("triggers restart %d", iter), func(t *testing.T) {
			t.Parallel()

			ctx, cancel := context.WithCancel(context.Background())
			t.Cleanup(cancel)

			errExpected := errors.New("execFn error")
			execCalls := 0
			firstRunStarted := make(chan struct{})
			execFn := func(ctx context.Context) error {
				execCalls++
				switch execCalls {
				case 1:
					// First run: let restartFn proceed, then block until this
					// run's context ends (the test expects the restart to end it).
					close(firstRunStarted)
					<-ctx.Done()
					return ctx.Err()
				case 2:
					// Second run, after the restart: exit straight away.
					return errExpected
				default:
					return errors.New("execFn should not be called more than twice")
				}
			}
			restartFn := func(context.Context) error {
				<-firstRunStarted
				return errors.New("restart required")
			}

			err := leaser.RunWithRestart(ctx, execFn, restartFn)
			assert.ErrorIs(t, err, errExpected)
		})
	}
}