// Package postgres contains an implementation of the storage.Backend backed by postgres.
package postgres

import (
	"context"
	"errors"
	"fmt"
	"sort"
	"strings"
	"time"

	"github.com/jackc/pgx/v5"
	"github.com/jackc/pgx/v5/pgconn"
	"github.com/jackc/pgx/v5/pgtype"
	"github.com/jackc/pgx/v5/pgxpool"
	"google.golang.org/protobuf/encoding/protojson"
	"google.golang.org/protobuf/proto"
	"google.golang.org/protobuf/reflect/protoregistry"
	"google.golang.org/protobuf/types/known/anypb"
	"google.golang.org/protobuf/types/known/fieldmaskpb"
	"google.golang.org/protobuf/types/known/timestamppb"

	"github.com/pomerium/pomerium/pkg/grpc/databroker"
	"github.com/pomerium/pomerium/pkg/grpc/registry"
	"github.com/pomerium/pomerium/pkg/protoutil"
	"github.com/pomerium/pomerium/pkg/storage"
)

// Database object names used by the postgres storage backend. They are
// declared as variables rather than constants — NOTE(review): presumably so
// they can be overridden elsewhere (e.g. in tests); confirm before converting.
var (
	schemaName             = "pomerium"
	migrationInfoTableName = "migration_info"
	recordsTableName       = "records"
	recordChangesTableName = "record_changes"
	recordOptionsTableName = "record_options"
	leasesTableName        = "leases"
	servicesTableName      = "services"
)

// Channel names used with NOTIFY when records or services change.
var (
	recordChangeNotifyName  = "pomerium_record_change"
	serviceChangeNotifyName = "pomerium_service_change"
)

// querier is the subset of the pgx query API that the helpers in this file
// need. Accepting an interface lets the same helper run inside an open
// transaction (patchRecord passes the pgx.Tx returned by p.Begin) or,
// presumably, directly against a pool elsewhere in the package.
type querier interface {
	// Exec runs a statement that returns no rows.
	Exec(ctx context.Context, query string, args ...any) (pgconn.CommandTag, error)
	// Query runs a statement and returns its result rows.
	Query(ctx context.Context, query string, args ...any) (pgx.Rows, error)
	// QueryRow runs a statement expected to return at most one row.
	QueryRow(ctx context.Context, query string, args ...any) pgx.Row
}

// deleteChangesBefore removes every row from the record changes table whose
// modified_at is strictly earlier than cutoff.
func deleteChangesBefore(ctx context.Context, q querier, cutoff time.Time) error {
	query := `
		DELETE FROM ` + schemaName + `.` + recordChangesTableName + `
		WHERE modified_at < $1
	`
	if _, err := q.Exec(ctx, query, cutoff); err != nil {
		return err
	}
	return nil
}

// deleteExpiredServices removes every service registration whose expires_at
// is strictly earlier than cutoff and reports how many rows were deleted.
// On error the row count is 0.
func deleteExpiredServices(ctx context.Context, q querier, cutoff time.Time) (rowCount int64, err error) {
	query := `
		DELETE FROM ` + schemaName + `.` + servicesTableName + `
		WHERE expires_at < $1
	`
	tag, execErr := q.Exec(ctx, query, cutoff)
	if execErr != nil {
		return 0, execErr
	}
	return tag.RowsAffected(), nil
}

// dup returns a deep copy of record made with proto.Clone, so the copy can be
// modified without touching the caller's message.
func dup(record *databroker.Record) *databroker.Record {
	cloned := proto.Clone(record)
	return cloned.(*databroker.Record)
}

// enforceOptions applies the capacity limit in options to records of
// recordType. It keeps only the highest-versioned `capacity` rows of that type
// and deletes the rest. It does nothing when options is nil or has no capacity
// set.
func enforceOptions(ctx context.Context, q querier, recordType string, options *databroker.Options) error {
	hasCapacity := options != nil && options.Capacity != nil
	if !hasCapacity {
		return nil
	}

	query := `
		DELETE FROM ` + schemaName + `.` + recordsTableName + `
		WHERE type=$1
		  AND id NOT IN (
			SELECT id
			FROM ` + schemaName + `.` + recordsTableName + `
			WHERE type=$1
			ORDER BY version DESC
			LIMIT $2
		)
	`
	_, err := q.Exec(ctx, query, recordType, options.GetCapacity())
	return err
}

// getLatestRecordVersion returns the highest version recorded in the record
// changes table. An empty table is not an error; the zero version is returned.
func getLatestRecordVersion(ctx context.Context, q querier) (recordVersion uint64, err error) {
	query := `
		SELECT version
		FROM ` + schemaName + `.` + recordChangesTableName + `
		ORDER BY version DESC
		LIMIT 1
	`
	scanErr := q.QueryRow(ctx, query).Scan(&recordVersion)
	if scanErr != nil && !isNotFound(scanErr) {
		return recordVersion, scanErr
	}
	return recordVersion, nil
}

// getNextChangedRecord returns the lowest-versioned entry in the record
// changes table whose version is strictly greater than afterRecordVersion.
// A non-empty recordType restricts the search to that type; an empty one
// matches every type.
//
// ModifiedAt and DeletedAt are taken from the modified_at and deleted_at
// columns (nil when NULL). Data is nil when the stored data is empty. If the
// stored data references a message type that cannot be resolved, Data is
// replaced by a struct holding only the record id. Returns storage.ErrNotFound
// when no later change exists.
func getNextChangedRecord(ctx context.Context, q querier, recordType string, afterRecordVersion uint64) (*databroker.Record, error) {
	query := `
			SELECT type, id, version, data, modified_at, deleted_at
			FROM ` + schemaName + `.` + recordChangesTableName + `
			WHERE version > $1
		`
	args := []any{afterRecordVersion}
	if recordType != "" {
		query += ` AND type = $2`
		args = append(args, recordType)
	}
	query += `
			ORDER BY version ASC
			LIMIT 1
		`

	var (
		changeType, changeID  string
		changeVersion         uint64
		rawData               []byte
		modifiedAt, deletedAt pgtype.Timestamptz
	)
	err := q.QueryRow(ctx, query, args...).Scan(&changeType, &changeID, &changeVersion, &rawData, &modifiedAt, &deletedAt)
	switch {
	case isNotFound(err):
		return nil, storage.ErrNotFound
	case err != nil:
		return nil, fmt.Errorf("error querying next changed record: %w", err)
	}

	// Deleted records may have been stored without data; Data stays nil then.
	var payload *anypb.Any
	if len(rawData) > 0 {
		decoded, decodeErr := protoutil.UnmarshalAnyJSON(rawData)
		switch {
		case isUnknownType(decodeErr):
			payload = protoutil.ToAny(protoutil.ToStruct(map[string]string{
				"id": changeID,
			}))
		case decodeErr != nil:
			return nil, fmt.Errorf("error unmarshaling changed record data: %w", decodeErr)
		default:
			payload = decoded
		}
	}

	return &databroker.Record{
		Version:    changeVersion,
		Type:       changeType,
		Id:         changeID,
		Data:       payload,
		ModifiedAt: timestamppbFromTimestamptz(modifiedAt),
		DeletedAt:  timestamppbFromTimestamptz(deletedAt),
	}, nil
}

// getOptions loads the stored options for recordType. A missing row or a NULL
// capacity yields an Options value with Capacity unset rather than an error.
func getOptions(ctx context.Context, q querier, recordType string) (*databroker.Options, error) {
	query := `
		SELECT capacity
		FROM ` + schemaName + `.` + recordOptionsTableName + `
		WHERE type=$1
	`
	var capacity pgtype.Int8
	if err := q.QueryRow(ctx, query, recordType).Scan(&capacity); err != nil && !isNotFound(err) {
		return nil, err
	}

	result := &databroker.Options{}
	if !capacity.Valid {
		return result, nil
	}
	result.Capacity = proto.Uint64(uint64(capacity.Int64))
	return result, nil
}

// lockMode is a row-locking clause that getRecord appends verbatim to the end
// of its SELECT statement.
type lockMode string

// Supported lock modes.
const (
	// lockModeNone appends nothing, so no row lock is taken.
	lockModeNone = lockMode("")
	// lockModeUpdate appends FOR UPDATE, locking the selected row.
	lockModeUpdate = lockMode("FOR UPDATE")
)

// getRecord loads the live record identified by (recordType, recordID) from
// the records table. The lockMode clause is appended to the query, so
// lockModeUpdate locks the row for the surrounding transaction.
//
// Returns storage.ErrNotFound when the row does not exist, and also when its
// data references a message type that cannot be resolved.
func getRecord(
	ctx context.Context, q querier, recordType, recordID string, lockMode lockMode,
) (*databroker.Record, error) {
	query := `
		SELECT version, data, modified_at
		  FROM ` + schemaName + `.` + recordsTableName + `
		 WHERE type=$1 AND id=$2 ` + string(lockMode)

	var (
		version    uint64
		rawData    []byte
		modifiedAt pgtype.Timestamptz
	)
	err := q.QueryRow(ctx, query, recordType, recordID).Scan(&version, &rawData, &modifiedAt)
	switch {
	case isNotFound(err):
		return nil, storage.ErrNotFound
	case err != nil:
		return nil, fmt.Errorf("postgres: failed to execute query: %w", err)
	}

	payload, err := protoutil.UnmarshalAnyJSON(rawData)
	switch {
	case isUnknownType(err):
		return nil, storage.ErrNotFound
	case err != nil:
		return nil, fmt.Errorf("postgres: failed to unmarshal data: %w", err)
	}

	return &databroker.Record{
		Version:    version,
		Type:       recordType,
		Id:         recordID,
		Data:       payload,
		ModifiedAt: timestamppbFromTimestamptz(modifiedAt),
	}, nil
}

// listRecords returns one page of live records ordered by (type, id),
// skipping offset rows and returning at most limit rows. When expr is
// non-nil it is rendered into a WHERE clause.
//
// A record whose data references an unresolvable message type is still
// listed, with Data replaced by a struct containing only its id.
func listRecords(ctx context.Context, q querier, expr storage.FilterExpression, offset, limit int) ([]*databroker.Record, error) {
	query := `
		SELECT type, id, version, data, modified_at
		FROM ` + schemaName + `.` + recordsTableName + `
	`
	// $1 is OFFSET and $2 is LIMIT. addFilterExpressionToQuery receives
	// &args, presumably so it can append its own parameters after these.
	args := []any{offset, limit}
	if expr != nil {
		query += "WHERE "
		if err := addFilterExpressionToQuery(&query, &args, expr); err != nil {
			return nil, fmt.Errorf("postgres: failed to add filter to query: %w", err)
		}
	}
	query += `
		ORDER BY type, id
		LIMIT $2
		OFFSET $1
	`

	rows, err := q.Query(ctx, query, args...)
	if err != nil {
		return nil, fmt.Errorf("postgres: failed to execute query: %w", err)
	}
	defer rows.Close()

	var records []*databroker.Record
	for rows.Next() {
		var (
			rowType, rowID string
			version        uint64
			rawData        []byte
			modifiedAt     pgtype.Timestamptz
		)
		if err := rows.Scan(&rowType, &rowID, &version, &rawData, &modifiedAt); err != nil {
			return nil, fmt.Errorf("postgres: failed to scan row: %w", err)
		}

		payload, err := protoutil.UnmarshalAnyJSON(rawData)
		switch {
		case isUnknownType(err):
			payload = protoutil.ToAny(protoutil.ToStruct(map[string]string{
				"id": rowID,
			}))
		case err != nil:
			return nil, fmt.Errorf("postgres: failed to unmarshal data: %w", err)
		}

		records = append(records, &databroker.Record{
			Version:    version,
			Type:       rowType,
			Id:         rowID,
			Data:       payload,
			ModifiedAt: timestamppbFromTimestamptz(modifiedAt),
		})
	}
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("postgres: error iterating over rows: %w", err)
	}

	return records, nil
}

// listServices returns every registered service, ordered by kind and then
// endpoint.
func listServices(ctx context.Context, q querier) ([]*registry.Service, error) {
	stmt := `
		SELECT kind, endpoint
		FROM  ` + schemaName + `.` + servicesTableName + `
		ORDER BY kind, endpoint
	`
	rows, err := q.Query(ctx, stmt)
	if err != nil {
		return nil, fmt.Errorf("postgres: failed to execute query: %w", err)
	}
	defer rows.Close()

	var services []*registry.Service
	for rows.Next() {
		var kindName, endpoint string
		if err := rows.Scan(&kindName, &endpoint); err != nil {
			return nil, fmt.Errorf("postgres: failed to scan row: %w", err)
		}
		// Kinds are stored by enum name. A name missing from the enum map
		// yields the zero ServiceKind value.
		kind := registry.ServiceKind(registry.ServiceKind_value[kindName])
		services = append(services, &registry.Service{
			Kind:     kind,
			Endpoint: endpoint,
		})
	}
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("postgres: error iterating over rows: %w", err)
	}

	return services, nil
}

// listTypes returns the distinct record types present in the records table,
// sorted byte-wise with sort.Strings. The result is nil when the table is
// empty.
func listTypes(ctx context.Context, q querier) ([]string, error) {
	stmt := `
		SELECT DISTINCT type
		FROM ` + schemaName + `.` + recordsTableName + `
	`
	rows, err := q.Query(ctx, stmt)
	if err != nil {
		return nil, fmt.Errorf("postgres: failed to execute query: %w", err)
	}
	defer rows.Close()

	var types []string
	for rows.Next() {
		var name string
		if err := rows.Scan(&name); err != nil {
			return nil, fmt.Errorf("postgres: failed to scan row: %w", err)
		}
		types = append(types, name)
	}
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("postgres: error iterating over rows: %w", err)
	}

	// Ordering is applied in Go (byte-wise) rather than with SQL ORDER BY,
	// whose result would depend on the column collation.
	sort.Strings(types)
	return types, nil
}

// maybeAcquireLease tries to acquire, or renew, the lease named leaseName for
// leaseID. It returns the id of whoever holds the lease after the statement
// runs, so the caller holds the lease exactly when leaseHolderID == leaseID.
//
// If no lease row exists, one is created for leaseID. If a row exists, it is
// taken over (and its expiry extended to now+ttl) only when it has already
// expired or is already held by leaseID. Otherwise the current holder and
// expiry are kept.
//
// The expiry and the "now" used for the expiration check come from a single
// clock read, so they are mutually consistent.
// NOTE(review): both come from this process's clock, so clock skew between
// nodes affects lease handoff timing.
func maybeAcquireLease(ctx context.Context, q querier, leaseName, leaseID string, ttl time.Duration) (leaseHolderID string, err error) {
	tbl := schemaName + "." + leasesTableName
	current := time.Now()
	now := timestamptzFromTimestamppb(timestamppb.New(current))
	expiresAt := timestamptzFromTimestamppb(timestamppb.New(current.Add(ttl)))
	err = q.QueryRow(ctx, `
		INSERT INTO `+tbl+` (name, id, expires_at)
		VALUES ($1, $2, $3)
		ON CONFLICT (name) DO UPDATE
		SET id=CASE WHEN `+tbl+`.expires_at<$4 OR `+tbl+`.id=$2 THEN $2 ELSE `+tbl+`.id END,
		    expires_at=CASE WHEN `+tbl+`.expires_at<$4 OR `+tbl+`.id=$2 THEN $3 ELSE `+tbl+`.expires_at END
		RETURNING `+tbl+`.id
	`, leaseName, leaseID, expiresAt, now).Scan(&leaseHolderID)
	return leaseHolderID, err
}

// putRecordAndChange writes record to the record changes table and applies
// the same change to the records table, all in one SQL statement.
//
// The change row is inserted in the CTE t1 without an explicit version.
// NOTE(review): the version is presumably assigned by a column default or
// sequence; confirm against the schema migration.
// - When record.DeletedAt is nil, the record is upserted into the records
//   table using t1's version. index_cidr is set from
//   storage.GetRecordIndexCIDR, or NULL when that returns nil.
// - Otherwise, the row is deleted from the records table. t1 is not
//   referenced in this branch. Per PostgreSQL's documented semantics for
//   data-modifying statements in WITH, the change row is still inserted.
//
// On success record.Version is overwritten with the version returned by the
// query. If no row is returned (deleting a record that is not in the records
// table), the error is ignored and record.Version is left unchanged.
// NOTE(review): in the delete branch, RETURNING yields the version of the row
// being removed (its previous version), not t1's new change version. Confirm
// that this is intended.
func putRecordAndChange(ctx context.Context, q querier, record *databroker.Record) error {
	// A nil Data is passed through as nil bytes (presumably stored as NULL).
	data, err := jsonbFromAny(record.GetData())
	if err != nil {
		return fmt.Errorf("postgres: failed to convert any to json: %w", err)
	}

	// Invalid or nil timestamps become NULL via timestamptzFromTimestamppb.
	// NOTE(review): the branch below checks GetDeletedAt() == nil, so a
	// non-nil but invalid DeletedAt takes the delete branch while deleted_at
	// is stored as NULL.
	modifiedAt := timestamptzFromTimestamppb(record.GetModifiedAt())
	deletedAt := timestamptzFromTimestamppb(record.GetDeletedAt())
	indexCIDR := &pgtype.Text{Valid: false}
	if cidr := storage.GetRecordIndexCIDR(record.GetData()); cidr != nil {
		indexCIDR.String = cidr.String()
		indexCIDR.Valid = true
	}

	query := `
		WITH t1 AS (
			INSERT INTO ` + schemaName + `.` + recordChangesTableName + ` (type, id, data, modified_at, deleted_at)
			VALUES ($1, $2, $3, $4, $5)
			RETURNING *
		)
	`
	// $1-$5 are shared by both branches. $6 (index_cidr) is appended only
	// for the upsert branch, the only one that references it.
	args := []any{
		record.GetType(), record.GetId(), data, modifiedAt, deletedAt,
	}
	if record.GetDeletedAt() == nil {
		query += `
			INSERT INTO ` + schemaName + `.` + recordsTableName + ` (type, id, version, data, modified_at, index_cidr)
			VALUES ($1, $2, (SELECT version FROM t1), $3, $4, $6)
			ON CONFLICT (type, id) DO UPDATE
			SET version=(SELECT version FROM t1), data=$3, modified_at=$4, index_cidr=$6
			RETURNING ` + schemaName + `.` + recordsTableName + `.version
		`
		args = append(args, indexCIDR)
	} else {
		query += `
			DELETE FROM ` + schemaName + `.` + recordsTableName + `
			WHERE type=$1 AND id=$2
			RETURNING ` + schemaName + `.` + recordsTableName + `.version
		`
	}
	// No returned row (nothing deleted) is not treated as an error.
	err = q.QueryRow(ctx, query, args...).Scan(&record.Version)
	if err != nil && !isNotFound(err) {
		return fmt.Errorf("postgres: failed to execute query: %w", err)
	}

	return nil
}

// patchRecord updates the fields of an existing record selected by the
// field mask, inside a single transaction.
//
// The current row is read with lockModeUpdate (SELECT ... FOR UPDATE), so
// other writers to the same row wait until this transaction ends.
// storage.PatchRecord then combines existing and record according to fields.
// NOTE(review): presumably it writes the merged result into record; confirm.
// record is then written back with putRecordAndChange, which stores the
// resulting version in record.Version. Returns storage.ErrNotFound when the
// record does not exist.
func patchRecord(
	ctx context.Context, p *pgxpool.Pool, record *databroker.Record, fields *fieldmaskpb.FieldMask,
) error {
	tx, err := p.Begin(ctx)
	if err != nil {
		return err
	}
	// pgx documents a deferred Rollback as safe even after Commit. In that
	// case it returns ErrTxClosed, which is deliberately ignored.
	defer func() { _ = tx.Rollback(ctx) }()

	existing, err := getRecord(ctx, tx, record.GetType(), record.GetId(), lockModeUpdate)
	switch {
	case isNotFound(err):
		return storage.ErrNotFound
	case err != nil:
		return err
	}

	if err = storage.PatchRecord(existing, record, fields); err != nil {
		return err
	}
	if err = putRecordAndChange(ctx, tx, record); err != nil {
		return err
	}
	return tx.Commit(ctx)
}

// putService upserts svc into the services table, keyed by (kind, endpoint),
// and sets or refreshes its expires_at to expiresAt. The kind is stored by
// its enum name.
func putService(ctx context.Context, q querier, svc *registry.Service, expiresAt time.Time) error {
	stmt := `
		INSERT INTO ` + schemaName + `.` + servicesTableName + ` (kind, endpoint, expires_at)
		VALUES ($1, $2, $3)
		ON CONFLICT (kind, endpoint) DO UPDATE
		SET expires_at=$3
	`
	kindName, endpoint := svc.GetKind().String(), svc.GetEndpoint()
	_, err := q.Exec(ctx, stmt, kindName, endpoint, expiresAt)
	return err
}

// setOptions upserts the options row for recordType. A nil options value, or
// one with no capacity, stores a NULL capacity.
// NOTE(review): capacities above math.MaxInt64 wrap to negative values in
// the int64 conversion.
func setOptions(ctx context.Context, q querier, recordType string, options *databroker.Options) error {
	var capacity pgtype.Int8
	if options != nil && options.Capacity != nil {
		capacity = pgtype.Int8{Int64: int64(options.GetCapacity()), Valid: true}
	}

	stmt := `
		INSERT INTO ` + schemaName + `.` + recordOptionsTableName + ` (type, capacity)
		VALUES ($1, $2)
		ON CONFLICT (type) DO UPDATE
		SET capacity=$2
	`
	_, err := q.Exec(ctx, stmt, recordType, capacity)
	return err
}

// signalRecordChange sends a NOTIFY on the record-change channel.
// NOTE(review): presumably consumed by a LISTEN on the same channel elsewhere.
func signalRecordChange(ctx context.Context, q querier) error {
	if _, err := q.Exec(ctx, "NOTIFY "+recordChangeNotifyName); err != nil {
		return err
	}
	return nil
}

// signalServiceChange sends a NOTIFY on the service-change channel.
// NOTE(review): presumably consumed by a LISTEN on the same channel elsewhere.
func signalServiceChange(ctx context.Context, q querier) error {
	if _, err := q.Exec(ctx, "NOTIFY "+serviceChangeNotifyName); err != nil {
		return err
	}
	return nil
}

// jsonbFromAny encodes msg with protojson for storage in a jsonb column.
// A nil msg returns nil bytes and no error (presumably stored as SQL NULL).
//
// The parameter was previously named `any`, which shadowed the predeclared
// `any` type inside this function. It is renamed so the builtin remains
// usable here and shadowing linters stay quiet.
func jsonbFromAny(msg *anypb.Any) ([]byte, error) {
	if msg == nil {
		return nil, nil
	}

	return protojson.Marshal(msg)
}

// timestamppbFromTimestamptz converts a pgtype.Timestamptz to a protobuf
// Timestamp. An invalid value (SQL NULL) maps to nil.
func timestamppbFromTimestamptz(ts pgtype.Timestamptz) *timestamppb.Timestamp {
	if ts.Valid {
		return timestamppb.New(ts.Time)
	}
	return nil
}

// timestamptzFromTimestamppb converts a protobuf Timestamp to a
// pgtype.Timestamptz. A nil or out-of-range timestamp (per IsValid) yields
// the zero, invalid value, which represents SQL NULL.
func timestamptzFromTimestamppb(ts *timestamppb.Timestamp) pgtype.Timestamptz {
	var out pgtype.Timestamptz
	if ts.IsValid() {
		out.Time = ts.AsTime()
		out.Valid = true
	}
	return out
}

// isNotFound reports whether err's chain contains either pgx.ErrNoRows or
// storage.ErrNotFound.
func isNotFound(err error) bool {
	switch {
	case errors.Is(err, pgx.ErrNoRows), errors.Is(err, storage.ErrNotFound):
		return true
	default:
		return false
	}
}

// isUnknownType reports whether err means a protobuf message type could not
// be resolved. That is the case when the chain contains protoregistry.NotFound
// or the message contains "unable to resolve". A nil err is never an
// unknown-type error.
func isUnknownType(err error) bool {
	switch {
	case err == nil:
		return false
	case errors.Is(err, protoregistry.NotFound):
		return true
	default:
		// protojson doesn't wrap errors, so fall back to the message text.
		return strings.Contains(err.Error(), "unable to resolve")
	}
}