package controller_test

import (
	"context"
	"encoding/base64"
	"errors"
	"testing"
	"time"

	"github.com/stretchr/testify/mock"
	"github.com/stretchr/testify/require"

	"github.com/pomerium/pomerium/config"
	"github.com/pomerium/pomerium/internal/zero/controller"
	"github.com/pomerium/pomerium/pkg/cryptutil"
	"github.com/pomerium/pomerium/pkg/grpc/databroker"
)

type mockConfigSource struct {
	mock.Mock
	config.Source
}

// GetConfig records the call on the embedded mock and returns the first value
// registered via On("GetConfig").Return(...). The type assertion panics if that
// value is not a *config.Config (including an untyped nil).
func (s *mockConfigSource) GetConfig() *config.Config {
	ret := s.Mock.Called()
	cfg := ret.Get(0).(*config.Config)
	return cfg
}

// OnConfigChange records the call, with ctx and cl as arguments, on the
// embedded mock. A .Run hook attached to the matching expectation receives
// those arguments, which lets tests capture the registered listener.
func (s *mockConfigSource) OnConfigChange(ctx context.Context, cl config.ChangeListener) {
	_ = s.Mock.Called(ctx, cl)
}

// TestDatabrokerRestart exercises controller.DatabrokerRestartRunner against a
// mocked config.Source. Each subtest registers exactly one OnConfigChange and
// one GetConfig expectation and verifies them when the subtest ends. This
// checks how the runner consumes the source, not only what Run returns.
func TestDatabrokerRestart(t *testing.T) {
	t.Parallel()

	// newConfig builds a minimal bootstrap config. Its shared key is freshly
	// generated on every call, so successive configs differ.
	newConfig := func() *config.Config {
		return &config.Config{
			Options: &config.Options{
				SharedKey: base64.StdEncoding.EncodeToString(cryptutil.NewKey()),
			},
			GRPCPort: ":12345",
		}
	}

	t.Run("no error", func(t *testing.T) {
		t.Parallel()

		src := new(mockConfigSource)
		src.On("OnConfigChange", mock.Anything, mock.Anything).Once()
		src.On("GetConfig").Once().Return(newConfig())
		// Deferred first, so it runs last (after r.Close).
		defer src.AssertExpectations(t)

		ctx := context.Background()
		r := controller.NewDatabrokerRestartRunner(ctx, src)
		defer r.Close()

		err := r.Run(ctx, func(_ context.Context, _ databroker.DataBrokerServiceClient) error {
			return nil
		})
		require.NoError(t, err)
	})
	t.Run("error, retry", func(t *testing.T) {
		t.Parallel()

		src := new(mockConfigSource)
		src.On("OnConfigChange", mock.Anything, mock.Anything).Once()
		src.On("GetConfig").Once().Return(newConfig())
		defer src.AssertExpectations(t)

		ctx := context.Background()
		r := controller.NewDatabrokerRestartRunner(ctx, src)
		defer r.Close()

		count := 0
		err := r.Run(ctx, func(_ context.Context, _ databroker.DataBrokerServiceClient) error {
			count++
			if count == 1 {
				return errors.New("simulated error")
			}
			return nil
		})
		require.NoError(t, err)
		require.Equal(t, 2, count)
	})
	t.Run("config changed, execution restarted", func(t *testing.T) {
		t.Parallel()

		src := new(mockConfigSource)
		var cl config.ChangeListener
		src.On("OnConfigChange", mock.Anything, mock.Anything).Once().Run(func(args mock.Arguments) {
			cl = args.Get(1).(config.ChangeListener)
		})
		src.On("GetConfig").Once().Return(newConfig())
		defer src.AssertExpectations(t)

		// Bound the run. If the runner never cancels the first attempt, the
		// callback's <-ctx.Done() would otherwise block until the global
		// `go test` timeout.
		ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
		defer cancel()
		r := controller.NewDatabrokerRestartRunner(ctx, src)
		defer r.Close()

		// Observations are recorded inside the callback and asserted only
		// after Run returns. The runner may invoke the callback on a goroutine
		// other than the test's, and require/t.FailNow must only be called
		// from the test goroutine. Appending (rather than indexing a fixed
		// array) also avoids an index-out-of-range panic if the callback runs
		// more times than expected.
		var clients []databroker.DataBrokerServiceClient
		var firstCause error
		err := r.Run(ctx, func(ctx context.Context, client databroker.DataBrokerServiceClient) error {
			clients = append(clients, client)
			if len(clients) == 1 {
				// Deliver a new bootstrap config, then wait for the runner to
				// cancel this attempt in response.
				cl(context.Background(), newConfig())
				<-ctx.Done()
				firstCause = context.Cause(ctx)
				return ctx.Err()
			}
			return nil
		})
		require.NoError(t, err)
		require.ErrorIs(t, firstCause, controller.ErrBootstrapConfigurationChanged)
		require.Len(t, clients, 2)
		require.NotEqual(t, clients[0], clients[1])
	})
}