package manager

import (
	"context"
	"fmt"
	"testing"
	"time"

	"github.com/stretchr/testify/assert"
	"google.golang.org/protobuf/proto"

	"github.com/pomerium/pomerium/internal/directory"
	"github.com/pomerium/pomerium/pkg/grpc/databroker"
	"github.com/pomerium/pomerium/pkg/grpc/session"
	"github.com/pomerium/pomerium/pkg/grpc/user"
	"github.com/pomerium/pomerium/pkg/protoutil"
)

// mockProvider is a test stub passed to WithDirectoryProvider. Each method
// simply forwards to the function field of the matching (lower-case) name;
// a field left nil will cause a nil-function panic if its method is invoked.
type mockProvider struct {
	user       func(context.Context, string, string) (*directory.User, error)
	userGroups func(context.Context) ([]*directory.Group, []*directory.User, error)
}

// User forwards its arguments unchanged to the mock's user callback and
// returns whatever that callback returns. It panics if user was left nil.
func (mock mockProvider) User(ctx context.Context, userID, accessToken string) (*directory.User, error) {
	u, err := mock.user(ctx, userID, accessToken)
	return u, err
}

// UserGroups forwards ctx to the mock's userGroups callback and returns its
// groups, users and error unchanged. It panics if userGroups was left nil.
func (mock mockProvider) UserGroups(ctx context.Context) ([]*directory.Group, []*directory.User, error) {
	groups, users, err := mock.userGroups(ctx)
	return groups, users, err
}

// TestManager_onUpdateRecords feeds one batch of databroker records (a
// directory group, a directory user, a session and a user) through
// onUpdateRecords and checks that each lands in the manager's corresponding
// store. It also checks that the user's next scheduled entry is user1 at
// now+1h (the manager's clock is pinned to now via WithNow).
func TestManager_onUpdateRecords(t *testing.T) {
	ctx, clearTimeout := context.WithTimeout(context.Background(), time.Second*10)
	defer clearTimeout()

	now := time.Now()

	mgr := New(
		WithDirectoryProvider(mockProvider{}),
		WithGroupRefreshInterval(time.Hour),
		WithNow(func() time.Time {
			return now
		}),
	)
	mgr.directoryBackoff.RandomizationFactor = 0 // disable randomization for deterministic testing

	mgr.onUpdateRecords(ctx, updateRecordsMessage{
		records: []*databroker.Record{
			mkRecord(&directory.Group{Id: "group1", Name: "group 1", Email: "group1@example.com"}),
			// The user belongs to the group defined just above.
			mkRecord(&directory.User{Id: "user1", DisplayName: "user 1", Email: "user1@example.com", GroupIds: []string{"group1"}}),
			mkRecord(&session.Session{Id: "session1", UserId: "user1"}),
			mkRecord(&user.User{Id: "user1", Name: "user 1", Email: "user1@example.com"}),
		},
	})

	assert.NotNil(t, mgr.directoryGroups["group1"])
	assert.NotNil(t, mgr.directoryUsers["user1"])

	_, hasSession := mgr.sessions.Get("user1", "session1")
	assert.True(t, hasSession, "expected session1 to be tracked for user1")

	if _, hasUser := mgr.users.Get("user1"); assert.True(t, hasUser, "expected user1 to be tracked") {
		tm, id := mgr.userScheduler.Next()
		assert.Equal(t, now.Add(time.Hour), tm)
		assert.Equal(t, "user1", id)
	}
}

func TestManager_refreshDirectoryUserGroups(t *testing.T) {
	ctx, clearTimeout := context.WithTimeout(context.Background(), time.Second*10)
	defer clearTimeout()

	t.Run("backoff", func(t *testing.T) {
		cnt := 0
		mgr := New(
			WithDirectoryProvider(mockProvider{
				userGroups: func(ctx context.Context) ([]*directory.Group, []*directory.User, error) {
					cnt++
					switch cnt {
					case 1:
						return nil, nil, fmt.Errorf("error 1")
					case 2:
						return nil, nil, fmt.Errorf("error 2")
					}
					return nil, nil, nil
				},
			}),
			WithGroupRefreshInterval(time.Hour),
		)
		mgr.directoryBackoff.RandomizationFactor = 0 // disable randomization for deterministic testing

		dur1 := mgr.refreshDirectoryUserGroups(ctx)
		dur2 := mgr.refreshDirectoryUserGroups(ctx)
		dur3 := mgr.refreshDirectoryUserGroups(ctx)

		assert.Greater(t, dur2, dur1)
		assert.Greater(t, dur3, dur2)
		assert.Equal(t, time.Hour, dur3)
	})
}

// mkRecord wraps msg in a databroker.Record. It packs msg with
// protoutil.NewAny, then uses the packed value's type URL as the record Type,
// msg.GetId() as the record Id, and the packed value itself as Data.
func mkRecord(msg recordable) *databroker.Record {
	// Named "data" rather than "any" to avoid shadowing Go's predeclared any.
	data := protoutil.NewAny(msg)
	return &databroker.Record{
		Type: data.GetTypeUrl(),
		Id:   msg.GetId(),
		Data: data,
	}
}

// recordable is a protobuf message that also exposes its own ID. These are
// exactly the two capabilities mkRecord relies on to build a record.
type recordable interface {
	GetId() string
	proto.Message
}