package authorize

import (
	"context"
	"testing"
	"time"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
	"google.golang.org/grpc"
	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/status"
	"google.golang.org/protobuf/proto"

	"github.com/pomerium/pomerium/config"
	"github.com/pomerium/pomerium/pkg/grpc/databroker"
	"github.com/pomerium/pomerium/pkg/grpc/session"
	"github.com/pomerium/pomerium/pkg/grpcutil"
	"github.com/pomerium/pomerium/pkg/protoutil"
)

// TestAuthorize_forceSyncToVersion exercises forceSyncToVersion against a
// store that has been seeded with a record at version 1.
func TestAuthorize_forceSyncToVersion(t *testing.T) {
	opts := &config.Options{
		AuthenticateURLString: "https://authN.example.com",
		DataBrokerURLString:   "https://databroker.example.com",
		SharedKey:             "gXK6ggrlIW2HyKyUF9rUO4azrDgxhDPWqw9y+lJU7B8=",
		Policies:              testPolicies(t),
	}
	a, err := New(&config.Config{Options: opts})
	require.NoError(t, err)

	// Seed the store so that version 1 is already present.
	a.store.UpdateRecord(1, &databroker.Record{Version: 1})

	// The context is canceled up front, so a true result here means the call
	// succeeded without needing to wait.
	t.Run("ready", func(t *testing.T) {
		canceled, cancel := context.WithCancel(context.Background())
		cancel()

		synced := a.forceSyncToVersion(canceled, 1, 1)
		assert.True(t, synced)
	})

	// Version 2 has not been stored yet and the context is already canceled,
	// so the call is expected to report false.
	t.Run("not ready", func(t *testing.T) {
		canceled, cancel := context.WithCancel(context.Background())
		cancel()

		synced := a.forceSyncToVersion(canceled, 1, 2)
		assert.False(t, synced)
	})

	// Version 2 is stored from a background goroutine shortly after the call
	// starts; the call is expected to report true within the 10s deadline.
	t.Run("becomes ready", func(t *testing.T) {
		deadlineCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
		defer cancel()

		go func() {
			time.Sleep(100 * time.Millisecond)
			a.store.UpdateRecord(1, &databroker.Record{Version: 2})
		}()

		synced := a.forceSyncToVersion(deadlineCtx, 1, 2)
		assert.True(t, synced)
	})
}

// TestAuthorize_waitForRecordSync drives waitForRecordSync with a mocked
// databroker client and checks how often the mock's get function is invoked
// in each scenario.
func TestAuthorize_waitForRecordSync(t *testing.T) {
	const sessionID = "SESSION_ID"

	// Upper bound shared by every subtest.
	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
	defer cancel()

	opts := &config.Options{
		AuthenticateURLString: "https://authN.example.com",
		DataBrokerURLString:   "https://databroker.example.com",
		SharedKey:             "gXK6ggrlIW2HyKyUF9rUO4azrDgxhDPWqw9y+lJU7B8=",
		Policies:              testPolicies(t),
	}

	// The session is already in the store, so the databroker must not be
	// queried at all; the mock panics if it is.
	t.Run("skip if exists", func(t *testing.T) {
		a, err := New(&config.Config{Options: opts})
		require.NoError(t, err)

		a.store.UpdateRecord(0, newRecord(&session.Session{Id: sessionID}))
		a.state.Load().dataBrokerClient = mockDataBrokerServiceClient{
			get: func(_ context.Context, _ *databroker.GetRequest, _ ...grpc.CallOption) (*databroker.GetResponse, error) {
				panic("should never be called")
			},
		}

		typeURL := grpcutil.GetTypeURL(new(session.Session))
		a.waitForRecordSync(ctx, typeURL, sessionID)
	})

	// The databroker answers NotFound; exactly one lookup is expected.
	t.Run("skip if not found", func(t *testing.T) {
		a, err := New(&config.Config{Options: opts})
		require.NoError(t, err)

		calls := 0
		a.state.Load().dataBrokerClient = mockDataBrokerServiceClient{
			get: func(_ context.Context, _ *databroker.GetRequest, _ ...grpc.CallOption) (*databroker.GetResponse, error) {
				calls++
				return nil, status.Error(codes.NotFound, "not found")
			},
		}

		typeURL := grpcutil.GetTypeURL(new(session.Session))
		a.waitForRecordSync(ctx, typeURL, sessionID)
		assert.Equal(t, 1, calls, "should be called once")
	})

	// The first lookup both returns the record and places it in the store;
	// any further lookup panics.
	t.Run("poll", func(t *testing.T) {
		a, err := New(&config.Config{Options: opts})
		require.NoError(t, err)

		calls := 0
		a.state.Load().dataBrokerClient = mockDataBrokerServiceClient{
			get: func(_ context.Context, _ *databroker.GetRequest, _ ...grpc.CallOption) (*databroker.GetResponse, error) {
				calls++
				if calls != 1 {
					panic("should never be called")
				}
				sess := &session.Session{Id: sessionID}
				a.store.UpdateRecord(0, newRecord(sess))
				return &databroker.GetResponse{Record: newRecord(sess)}, nil
			},
		}

		typeURL := grpcutil.GetTypeURL(new(session.Session))
		a.waitForRecordSync(ctx, typeURL, sessionID)
	})

	// The databroker always returns the record but the store is never
	// updated, so lookups repeat until the 100ms deadline expires.
	t.Run("timeout", func(t *testing.T) {
		a, err := New(&config.Config{Options: opts})
		require.NoError(t, err)

		shortCtx, cancelShort := context.WithTimeout(ctx, 100*time.Millisecond)
		defer cancelShort()

		calls := 0
		a.state.Load().dataBrokerClient = mockDataBrokerServiceClient{
			get: func(_ context.Context, _ *databroker.GetRequest, _ ...grpc.CallOption) (*databroker.GetResponse, error) {
				calls++
				sess := &session.Session{Id: sessionID}
				return &databroker.GetResponse{Record: newRecord(sess)}, nil
			},
		}

		typeURL := grpcutil.GetTypeURL(new(session.Session))
		a.waitForRecordSync(shortCtx, typeURL, sessionID)
		assert.Greater(t, calls, 5) // should be ~ 20, but allow for non-determinism
	})
}

// storableMessage is a protobuf message that also exposes an identifier via
// GetId; newRecord uses it to build databroker records in tests.
type storableMessage interface {
	GetId() string
	proto.Message
}

// newRecord builds a version-1 databroker.Record for msg. The message is
// packed with protoutil.NewAny; the record's Type is the packed value's type
// URL, its Id is msg.GetId(), and its Data is the packed value itself.
func newRecord(msg storableMessage) *databroker.Record {
	// Named "data" rather than "any" to avoid shadowing the predeclared
	// identifier any.
	data := protoutil.NewAny(msg)
	return &databroker.Record{
		Version: 1,
		Type:    data.GetTypeUrl(),
		Id:      msg.GetId(),
		Data:    data,
	}
}