package retry_test

import (
	"context"
	"errors"
	"fmt"
	"testing"
	"time"

	"github.com/stretchr/testify/require"

	"github.com/pomerium/pomerium/internal/retry"
)

// TestRetry exercises retry.Retry for immediate success, success after
// transient errors, a terminal error, context cancellation (before and during
// the retry loop), and interactions with a watch hook registered via
// retry.WithWatch.
func TestRetry(t *testing.T) {
	t.Parallel()

	ctx := context.Background()
	// Shared option capping the backoff interval; passed to every subtest whose
	// callback returns transient errors so their waits are bounded consistently.
	limit := retry.WithMaxInterval(time.Second * 5)

	t.Run("no error", func(t *testing.T) {
		t.Parallel()

		err := retry.Retry(ctx, "test", func(_ context.Context) error {
			return nil
		}, limit)
		require.NoError(t, err)
	})

	t.Run("eventually succeeds", func(t *testing.T) {
		t.Parallel()
		// The callback fails on calls 1 and 2 and succeeds on call 3.
		i := 0
		err := retry.Retry(ctx, "test", func(_ context.Context) error {
			if i++; i > 2 {
				return nil
			}
			return fmt.Errorf("transient %d", i)
		}, limit)
		require.NoError(t, err)
	})

	t.Run("eventually fails", func(t *testing.T) {
		t.Parallel()
		// The callback fails transiently on calls 1 and 2, then returns a
		// terminal error on call 3; Retry must surface an error.
		i := 0
		err := retry.Retry(ctx, "test", func(_ context.Context) error {
			if i++; i > 2 {
				return retry.NewTerminalError(errors.New("the end"))
			}
			return fmt.Errorf("transient %d", i)
		}, limit)
		require.Error(t, err)
	})

	t.Run("context canceled", func(t *testing.T) {
		t.Parallel()
		// The context is canceled before Retry is even called.
		ctx, cancel := context.WithCancel(ctx)
		cancel()
		err := retry.Retry(ctx, "test", func(_ context.Context) error {
			return errors.New("retry")
		})
		require.Error(t, err)
	})

	t.Run("context canceled after retry", func(t *testing.T) {
		t.Parallel()
		// The context is canceled from inside the callback, after Retry has
		// started; Cleanup also calls cancel so it always runs.
		ctx, cancel := context.WithCancel(ctx)
		t.Cleanup(cancel)
		err := retry.Retry(ctx, "test", func(_ context.Context) error {
			cancel()
			return errors.New("retry")
		})
		require.Error(t, err)
	})

	t.Run("success after watch hook", func(t *testing.T) {
		t.Parallel()
		// ch is buffered and pre-loaded with one value for WithWatch.
		ch := make(chan struct{}, 1)
		ch <- struct{}{}
		// The main callback only succeeds once the watch callback has set ok.
		ok := false
		err := retry.Retry(ctx, "test", func(_ context.Context) error {
			if ok {
				return nil
			}
			return errors.New("retry")
		}, retry.WithWatch("watch", ch, func(_ context.Context) error {
			ok = true
			return nil
		}), limit)
		require.NoError(t, err)
	})

	t.Run("success after watch hook retried", func(t *testing.T) {
		t.Parallel()
		ch := make(chan struct{}, 1)
		ch <- struct{}{}
		// The watch callback fails on its first call and sets ok on its second;
		// the main callback only succeeds once ok is set.
		ok := false
		i := 0
		err := retry.Retry(ctx, "test", func(_ context.Context) error {
			if ok {
				return nil
			}
			return errors.New("retry test")
		}, retry.WithWatch("watch", ch, func(_ context.Context) error {
			if i++; i > 1 {
				ok = true
				return nil
			}
			return errors.New("retry watch")
		}), limit)
		require.NoError(t, err)
	})

	t.Run("watch hook fails", func(t *testing.T) {
		t.Parallel()
		ch := make(chan struct{}, 1)
		ch <- struct{}{}
		// The main callback never succeeds and the watch callback returns a
		// terminal error; Retry must surface an error.
		err := retry.Retry(ctx, "test", func(_ context.Context) error {
			return errors.New("retry")
		}, retry.WithWatch("watch", ch, func(_ context.Context) error {
			return retry.NewTerminalError(errors.New("watch"))
		}), limit)
		require.Error(t, err)
	})
}