package postgres

import (
	"context"
	"errors"
	"fmt"
	"net"
	"sync"
	"time"

	"github.com/cenkalti/backoff/v4"
	"github.com/jackc/pgx/v5/pgxpool"
	"google.golang.org/protobuf/types/known/fieldmaskpb"
	"google.golang.org/protobuf/types/known/timestamppb"

	"github.com/pomerium/pomerium/internal/log"
	"github.com/pomerium/pomerium/internal/signal"
	"github.com/pomerium/pomerium/pkg/contextutil"
	"github.com/pomerium/pomerium/pkg/grpc/databroker"
	"github.com/pomerium/pomerium/pkg/storage"
)

// Backend is a storage Backend implemented with Postgres.
type Backend struct {
	cfg             *config
	dsn             string
	onRecordChange  *signal.Signal
	onServiceChange *signal.Signal

	closeCtx context.Context
	close    context.CancelFunc

	mu            sync.RWMutex
	pool          *pgxpool.Pool
	serverVersion uint64
}

// New creates a new Backend.
func New(dsn string, options ...Option) *Backend {
	backend := &Backend{
		cfg:             getConfig(options...),
		dsn:             dsn,
		onRecordChange:  signal.New(),
		onServiceChange: signal.New(),
	}
	backend.closeCtx, backend.close = context.WithCancel(context.Background())

	go backend.doPeriodically(func(ctx context.Context) error {
		_, pool, err := backend.init(ctx)
		if err != nil {
			return err
		}

		return deleteChangesBefore(ctx, pool, time.Now().Add(-backend.cfg.expiry))
	}, time.Minute)

	go backend.doPeriodically(func(ctx context.Context) error {
		_, pool, err := backend.init(ctx)
		if err != nil {
			return err
		}

		rowCount, err := deleteExpiredServices(ctx, pool, time.Now())
		if err != nil {
			return err
		}
		if rowCount > 0 {
			err = signalServiceChange(ctx, pool)
			if err != nil {
				return err
			}
		}

		return nil
	}, backend.cfg.registryTTL/2)

	go backend.doPeriodically(func(ctx context.Context) error {
		return backend.listenForNotifications(ctx)
	}, time.Millisecond*100)

	return backend
}

// Close closes the underlying database connection.
func (backend *Backend) Close() error {
	backend.mu.Lock()
	defer backend.mu.Unlock()

	backend.close()

	if backend.pool != nil {
		backend.pool.Close()
		backend.pool = nil
	}
	return nil
}

// Get gets a record from the database.
func (backend *Backend) Get(
	ctx context.Context,
	recordType, recordID string,
) (*databroker.Record, error) {
	ctx, cancel := contextutil.Merge(ctx, backend.closeCtx)
	defer cancel()

	_, conn, err := backend.init(ctx)
	if err != nil {
		return nil, err
	}

	return getRecord(ctx, conn, recordType, recordID, lockModeNone)
}

// GetOptions returns the options for the given record type.
func (backend *Backend) GetOptions(
	ctx context.Context,
	recordType string,
) (*databroker.Options, error) {
	ctx, cancel := contextutil.Merge(ctx, backend.closeCtx)
	defer cancel()

	_, conn, err := backend.init(ctx)
	if err != nil {
		return nil, err
	}

	return getOptions(ctx, conn, recordType)
}

// Lease attempts to acquire a lease for the given name.
func (backend *Backend) Lease(
	ctx context.Context,
	leaseName, leaseID string,
	ttl time.Duration,
) (acquired bool, err error) {
	ctx, cancel := contextutil.Merge(ctx, backend.closeCtx)
	defer cancel()

	_, conn, err := backend.init(ctx)
	if err != nil {
		return false, err
	}

	leaseHolderID, err := maybeAcquireLease(ctx, conn, leaseName, leaseID, ttl)
	if err != nil {
		return false, err
	}

	return leaseHolderID == leaseID, nil
}

// ListTypes lists the record types.
func (backend *Backend) ListTypes(ctx context.Context) ([]string, error) {
	ctx, cancel := contextutil.Merge(ctx, backend.closeCtx)
	defer cancel()

	_, conn, err := backend.init(ctx)
	if err != nil {
		return nil, err
	}

	return listTypes(ctx, conn)
}

// Put puts a record into Postgres.
func (backend *Backend) Put(
	ctx context.Context,
	records []*databroker.Record,
) (serverVersion uint64, err error) {
	ctx, cancel := contextutil.Merge(ctx, backend.closeCtx)
	defer cancel()

	serverVersion, pool, err := backend.init(ctx)
	if err != nil {
		return 0, err
	}

	now := timestamppb.Now()

	// add all the records
	recordTypes := map[string]struct{}{}
	for i, record := range records {
		recordTypes[record.GetType()] = struct{}{}

		record = dup(record)
		record.ModifiedAt = now
		err := putRecordAndChange(ctx, pool, record)
		if err != nil {
			return serverVersion, fmt.Errorf("storage/postgres: error saving record: %w", err)
		}
		records[i] = record
	}

	// enforce options for each record type
	for recordType := range recordTypes {
		options, err := getOptions(ctx, pool, recordType)
		if err != nil {
			return serverVersion, fmt.Errorf("storage/postgres: error getting options: %w", err)
		}
		err = enforceOptions(ctx, pool, recordType, options)
		if err != nil {
			return serverVersion, fmt.Errorf("storage/postgres: error enforcing options: %w", err)
		}
	}

	err = signalRecordChange(ctx, pool)
	return serverVersion, err
}

// Patch updates specific fields of existing records in Postgres.
func (backend *Backend) Patch(
	ctx context.Context,
	records []*databroker.Record,
	fields *fieldmaskpb.FieldMask,
) (uint64, []*databroker.Record, error) {
	ctx, cancel := contextutil.Merge(ctx, backend.closeCtx)
	defer cancel()

	serverVersion, pool, err := backend.init(ctx)
	if err != nil {
		return serverVersion, nil, err
	}

	patchedRecords := make([]*databroker.Record, 0, len(records))

	now := timestamppb.Now()

	for _, record := range records {
		record = dup(record)
		record.ModifiedAt = now
		err := patchRecord(ctx, pool, record, fields)
		if storage.IsNotFound(err) {
			continue
		} else if err != nil {
			err = fmt.Errorf("storage/postgres: error patching record %q of type %q: %w",
				record.GetId(), record.GetType(), err)
			return serverVersion, patchedRecords, err
		}
		patchedRecords = append(patchedRecords, record)
	}

	err = signalRecordChange(ctx, pool)
	return serverVersion, patchedRecords, err
}

// SetOptions sets the options for the given record type.
func (backend *Backend) SetOptions(
	ctx context.Context,
	recordType string,
	options *databroker.Options,
) error {
	ctx, cancel := contextutil.Merge(ctx, backend.closeCtx)
	defer cancel()

	_, conn, err := backend.init(ctx)
	if err != nil {
		return err
	}

	return setOptions(ctx, conn, recordType, options)
}

// Sync syncs the records.
func (backend *Backend) Sync(
	ctx context.Context,
	recordType string,
	serverVersion, recordVersion uint64,
) (storage.RecordStream, error) {
	// the original ctx will be used for the stream, this ctx used for pre-stream calls
	callCtx, cancel := contextutil.Merge(ctx, backend.closeCtx)
	defer cancel()

	currentServerVersion, _, err := backend.init(callCtx)
	if err != nil {
		return nil, err
	}
	if currentServerVersion != serverVersion {
		return nil, storage.ErrInvalidServerVersion
	}

	return newChangedRecordStream(ctx, backend, recordType, recordVersion), nil
}

// SyncLatest syncs the latest version of each record.
func (backend *Backend) SyncLatest(
	ctx context.Context,
	recordType string,
	expr storage.FilterExpression,
) (serverVersion, recordVersion uint64, stream storage.RecordStream, err error) {
	// the original ctx will be used for the stream, this ctx used for pre-stream calls
	callCtx, cancel := contextutil.Merge(ctx, backend.closeCtx)
	defer cancel()

	serverVersion, pool, err := backend.init(callCtx)
	if err != nil {
		return 0, 0, nil, err
	}

	recordVersion, err = getLatestRecordVersion(callCtx, pool)
	if err != nil {
		return 0, 0, nil, err
	}

	if recordType != "" {
		f := storage.EqualsFilterExpression{
			Fields: []string{"type"},
			Value:  recordType,
		}
		if expr != nil {
			expr = storage.AndFilterExpression{expr, f}
		} else {
			expr = f
		}
	}

	stream = newRecordStream(ctx, backend, expr)
	return serverVersion, recordVersion, stream, nil
}

func (backend *Backend) init(ctx context.Context) (serverVersion uint64, pool *pgxpool.Pool, err error) {
	backend.mu.RLock()
	serverVersion = backend.serverVersion
	pool = backend.pool
	backend.mu.RUnlock()

	if pool != nil {
		return serverVersion, pool, nil
	}

	backend.mu.Lock()
	defer backend.mu.Unlock()

	// double-checked locking, might have already initialized, so just return
	serverVersion = backend.serverVersion
	pool = backend.pool
	if pool != nil {
		return serverVersion, pool, nil
	}

	config, err := ParseConfig(backend.dsn)
	if err != nil {
		return serverVersion, nil, err
	}

	pool, err = pgxpool.NewWithConfig(context.Background(), config)
	if err != nil {
		return serverVersion, nil, err
	}

	tx, err := pool.Begin(ctx)
	if err != nil {
		return serverVersion, nil, err
	}

	serverVersion, err = migrate(ctx, tx)
	if err != nil {
		_ = tx.Rollback(ctx)
		return serverVersion, nil, err
	}

	err = tx.Commit(ctx)
	if err != nil {
		_ = tx.Rollback(ctx)
		return serverVersion, nil, err
	}

	backend.serverVersion = serverVersion
	backend.pool = pool
	return serverVersion, pool, nil
}

func (backend *Backend) doPeriodically(f func(ctx context.Context) error, dur time.Duration) {
	ctx := backend.closeCtx

	ticker := time.NewTicker(dur)
	defer ticker.Stop()

	bo := backoff.NewExponentialBackOff()
	bo.MaxElapsedTime = 0

	for {
		err := f(ctx)
		if err == nil {
			bo.Reset()
			select {
			case <-backend.closeCtx.Done():
				return
			case <-ticker.C:
			}
		} else {
			if !errors.Is(err, context.Canceled) {
				log.Error(ctx).Err(err).Msg("storage/postgres")
			}
			select {
			case <-backend.closeCtx.Done():
				return
			case <-time.After(bo.NextBackOff()):
			}
		}
	}
}

func (backend *Backend) listenForNotifications(ctx context.Context) error {
	_, pool, err := backend.init(ctx)
	if err != nil {
		return fmt.Errorf("error initializing pool for notifications: %w", err)
	}

	poolConn, err := pool.Acquire(ctx)
	if err != nil {
		return fmt.Errorf("error acquiring connection from pool for notifications: %w", err)
	}

	// hijack the connection so the pool can be left for short-lived queries
	// and so that LISTENs don't leak to other queries
	conn := poolConn.Hijack()
	defer conn.Close(ctx)

	for _, ch := range []string{recordChangeNotifyName, serviceChangeNotifyName} {
		_, err = conn.Exec(ctx, `LISTEN `+ch)
		if err != nil {
			return fmt.Errorf("error listening on channel %s for notifications: %w", ch, err)
		}
	}

	// for each notification broadcast the signal
	for {
		n, err := conn.WaitForNotification(ctx)
		if err != nil {
			// on error we'll close the connection to stop listening
			return fmt.Errorf("error receiving notification: %w", err)
		}

		switch n.Channel {
		case recordChangeNotifyName:
			backend.onRecordChange.Broadcast(ctx)
		case serviceChangeNotifyName:
			backend.onServiceChange.Broadcast(ctx)
		}
	}
}

// ParseConfig parses a DSN into a pgxpool.Config.
func ParseConfig(dsn string) (*pgxpool.Config, error) {
	config, err := pgxpool.ParseConfig(dsn)
	if err != nil {
		return nil, err
	}
	config.ConnConfig.LookupFunc = lookup
	return config, nil
}

func lookup(ctx context.Context, host string) (addrs []string, err error) {
	addrs, err = net.DefaultResolver.LookupHost(ctx, host)
	// ignore no such host errors
	if e := new(net.DNSError); errors.As(err, &e) && e.IsNotFound {
		addrs = nil
		err = nil
	}
	return addrs, err
}