package testutil

import (
	"context"
	"crypto/tls"
	"crypto/x509"
	"fmt"
	"io/ioutil"
	"path/filepath"
	"time"

	"github.com/go-redis/redis/v8"
	"github.com/ory/dockertest/v3"

	"github.com/pomerium/pomerium/pkg/cryptutil"
)

// maxWait is the timeout applied to the context used when pinging the test
// redis instance, and the expiry (in seconds) set on the docker resource.
const maxWait = time.Minute

// WithTestRedis creates a test redis instance using docker, waits until it
// answers a PING, and then calls handler with the URL of the instance.
//
// When useTLS is true the container is started with the plaintext port
// disabled and TLS enabled on port 6379, using the certificate files mounted
// from the "tls" directory under TestDataRoot(); the URL passed to handler then
// uses the "rediss" scheme instead of "redis".
//
// The container is purged before WithTestRedis returns, including when the
// readiness check fails or handler panics. An error from the readiness check
// or from handler takes precedence over a purge error; a purge error is only
// returned when nothing else failed.
func WithTestRedis(useTLS bool, handler func(rawURL string) error) (err error) {
	ctx, clearTimeout := context.WithTimeout(context.Background(), maxWait)
	defer clearTimeout()

	// uses a sensible default on windows (tcp/http) and linux/osx (socket)
	pool, err := dockertest.NewPool("")
	if err != nil {
		return err
	}

	opts := &dockertest.RunOptions{
		Repository: "redis",
		Tag:        "6",
	}
	scheme := "redis"
	if useTLS {
		opts.Mounts = []string{
			filepath.Join(TestDataRoot(), "tls") + ":/tls",
		}
		opts.Cmd = []string{
			"--port", "0",
			"--tls-port", "6379",
			"--tls-cert-file", "/tls/redis.crt",
			"--tls-key-file", "/tls/redis.key",
			"--tls-ca-cert-file", "/tls/ca.crt",
		}
		scheme = "rediss"
	}

	resource, err := pool.RunWithOptions(opts)
	if err != nil {
		return err
	}
	// Always remove the container, even if handler panics. The purge error is
	// only surfaced when there is no earlier error to report, so a failing
	// handler is never masked by a failing cleanup.
	defer func() {
		if purgeErr := pool.Purge(resource); purgeErr != nil && err == nil {
			err = purgeErr
		}
	}()
	// Best effort: the expiry is only a safety net in case the purge above
	// never runs (e.g. the test process is killed), so its error is ignored.
	_ = resource.Expire(uint(maxWait.Seconds()))

	redisURL := fmt.Sprintf("%s://%s/0", scheme, resource.GetHostPort("6379/tcp"))
	err = pool.Retry(func() error {
		options, parseErr := redis.ParseURL(redisURL)
		if parseErr != nil {
			return parseErr
		}
		if useTLS {
			options.TLSConfig = RedisTLSConfig()
		}

		client := redis.NewClient(options)
		defer client.Close()

		return client.Ping(ctx).Err()
	})
	if err != nil {
		return err
	}

	return handler(redisURL)
}

// RedisTLSConfig returns the TLS Config to use with redis.
func RedisTLSConfig() *tls.Config {
	cert, err := cryptutil.CertificateFromFile(
		filepath.Join(TestDataRoot(), "tls", "redis.crt"),
		filepath.Join(TestDataRoot(), "tls", "redis.key"),
	)
	if err != nil {
		panic(err)
	}
	caCertPool := x509.NewCertPool()
	caCert, err := ioutil.ReadFile(filepath.Join(TestDataRoot(), "tls", "ca.crt"))
	if err != nil {
		panic(err)
	}
	caCertPool.AppendCertsFromPEM(caCert)
	tlsConfig := &tls.Config{
		RootCAs:      caCertPool,
		Certificates: []tls.Certificate{*cert},
		MinVersion:   tls.VersionTLS12,
	}
	return tlsConfig
}