package testutil import ( "context" "crypto/tls" "crypto/x509" "fmt" "net" "os" "path/filepath" "strings" "time" "github.com/go-redis/redis/v8" "github.com/ory/dockertest/v3" "github.com/ory/dockertest/v3/docker" "github.com/pomerium/pomerium/pkg/cryptutil" ) const maxWait = 20 * time.Minute // WithTestRedis creates a test a test redis instance using docker. func WithTestRedis(useTLS bool, handler func(rawURL string) error) error { ctx, clearTimeout := context.WithTimeout(context.Background(), maxWait) defer clearTimeout() // uses a sensible default on windows (tcp/http) and linux/osx (socket) pool, err := dockertest.NewPool("") if err != nil { return err } opts := &dockertest.RunOptions{ Repository: "redis", Tag: "6", } scheme := "redis" if useTLS { opts.Mounts = []string{ filepath.Join(TestDataRoot(), "tls") + ":/tls", } opts.Cmd = []string{ "--port", "0", "--tls-port", "6379", "--tls-cert-file", "/tls/redis.crt", "--tls-key-file", "/tls/redis.key", "--tls-ca-cert-file", "/tls/ca.crt", } scheme = "rediss" } resource, err := pool.RunWithOptions(opts) if err != nil { return err } _ = resource.Expire(uint(maxWait.Seconds())) redisURL := fmt.Sprintf("%s://%s/0", scheme, resource.GetHostPort("6379/tcp")) if err := pool.Retry(func() error { options, err := redis.ParseURL(redisURL) if err != nil { return err } if useTLS { options.TLSConfig = RedisTLSConfig() } client := redis.NewClient(options) defer client.Close() return client.Ping(ctx).Err() }); err != nil { _ = pool.Purge(resource) return err } e := handler(redisURL) if err := pool.Purge(resource); err != nil { return err } return e } // WithTestRedisCluster creates a new redis cluster 3 node cluster. func WithTestRedisCluster(handler func(rawURL string) error) error { ctx, clearTimeout := context.WithTimeout(context.Background(), maxWait) defer clearTimeout() // uses a sensible default on windows (tcp/http) and linux/osx (socket) pool, err := dockertest.NewPool("") if err != nil { return err } redises := make([]*dockertest.Resource, 3) for i := range redises { conf := "cluster-enabled yes\ncluster-config-file nodes.conf" r, err := pool.RunWithOptions(&dockertest.RunOptions{ Hostname: fmt.Sprintf("redis%d", i), Repository: "redis", Tag: "6", Entrypoint: []string{ "/bin/bash", "-c", `echo "` + conf + `" >/tmp/redis.conf && chmod 0777 /tmp/redis.conf && exec docker-entrypoint.sh /tmp/redis.conf`, }, ExposedPorts: []string{ "6379/tcp", "26379/tcp", }, }) if err != nil { return err } defer r.Close() _ = r.Expire(uint(maxWait.Seconds())) go func() { _ = pool.Client.Logs(docker.LogsOptions{ Context: ctx, Stderr: true, Stdout: true, Follow: true, Timestamps: true, Container: r.Container.ID, OutputStream: os.Stderr, ErrorStream: os.Stderr, }) }() redises[i] = r } addrs := make([]string, 3) for i, r := range redises { addrs[i] = net.JoinHostPort( r.Container.NetworkSettings.IPAddress, "6379", ) } for _, addr := range addrs { err := pool.Retry(func() error { options, err := redis.ParseURL(fmt.Sprintf("redis://%s/0", addr)) if err != nil { return err } client := redis.NewClient(options) defer client.Close() return client.Ping(ctx).Err() }) if err != nil { return err } } // join the nodes to the cluster err = bootstrapRedisCluster(ctx, redises) if err != nil { return err } e := handler(fmt.Sprintf("redis+cluster://%s", strings.Join(addrs, ","))) for _, r := range redises { if err := pool.Purge(r); err != nil { return err } } return e } // WithTestRedisSentinel creates a new redis sentinel 3 node cluster. func WithTestRedisSentinel(handler func(rawURL string) error) error { ctx, clearTimeout := context.WithTimeout(context.Background(), maxWait) defer clearTimeout() // uses a sensible default on windows (tcp/http) and linux/osx (socket) pool, err := dockertest.NewPool("") if err != nil { return err } redises := make([]*dockertest.Resource, 3) for i := range redises { r, err := pool.RunWithOptions(&dockertest.RunOptions{ Hostname: fmt.Sprintf("redis%d", i), Repository: "redis", Tag: "6", ExposedPorts: []string{ "6379/tcp", "26379/tcp", }, }) if err != nil { return err } defer r.Close() _ = r.Expire(uint(maxWait.Seconds())) redises[i] = r } sentinels := make([]*dockertest.Resource, len(redises)) for i := range sentinels { conf := fmt.Sprintf("sentinel monitor master %s 6379 %d\n", redises[0].Container.NetworkSettings.IPAddress, len(redises)) if i > 0 { conf += fmt.Sprintf("sentinel known-slave master %s 6379\n", redises[i].Container.NetworkSettings.IPAddress) } r, err := pool.RunWithOptions(&dockertest.RunOptions{ Hostname: fmt.Sprintf("sentineld%d", i), Repository: "redis", Tag: "6", Entrypoint: []string{ "/bin/bash", "-c", `echo "` + conf + `" >/tmp/sentinel.conf && chmod 0777 /tmp/sentinel.conf && exec docker-entrypoint.sh /tmp/sentinel.conf --sentinel`, }, ExposedPorts: []string{ "6379/tcp", "26379/tcp", }, }) if err != nil { return err } defer r.Close() _ = r.Expire(uint(maxWait.Seconds())) go func() { _ = pool.Client.Logs(docker.LogsOptions{ Context: ctx, Stderr: true, Stdout: true, Follow: true, Timestamps: true, Container: r.Container.ID, OutputStream: os.Stderr, ErrorStream: os.Stderr, }) }() sentinels[i] = r } addrs := make([]string, len(sentinels)) for i, r := range sentinels { addrs[i] = net.JoinHostPort( r.Container.NetworkSettings.IPAddress, "26379", ) } redisURL := fmt.Sprintf("redis+sentinel://%s/master/0", strings.Join(addrs, ",")) for _, r := range redises { addr := net.JoinHostPort( r.Container.NetworkSettings.IPAddress, "6379", ) if err := pool.Retry(func() error { options, err := redis.ParseURL(fmt.Sprintf("redis://%s/0", addr)) if err != nil { return err } client := redis.NewClient(options) defer client.Close() return client.Ping(ctx).Err() }); err != nil { _ = pool.Purge(r) return err } } for _, r := range sentinels { if err := pool.Retry(func() error { options, err := redis.ParseURL(fmt.Sprintf("redis://%s/0", r.GetHostPort("26379/tcp"))) if err != nil { return err } client := redis.NewClient(options) defer client.Close() return client.Ping(ctx).Err() }); err != nil { _ = pool.Purge(r) return err } } e := handler(redisURL) for _, r := range append(redises, sentinels...) { if err := pool.Purge(r); err != nil { return err } } return e } // RedisTLSConfig returns the TLS Config to use with redis. func RedisTLSConfig() *tls.Config { cert, err := cryptutil.CertificateFromFile( filepath.Join(TestDataRoot(), "tls", "redis.crt"), filepath.Join(TestDataRoot(), "tls", "redis.key"), ) if err != nil { panic(err) } caCertPool := x509.NewCertPool() caCert, err := os.ReadFile(filepath.Join(TestDataRoot(), "tls", "ca.crt")) if err != nil { panic(err) } caCertPool.AppendCertsFromPEM(caCert) tlsConfig := &tls.Config{ RootCAs: caCertPool, Certificates: []tls.Certificate{*cert}, MinVersion: tls.VersionTLS12, } return tlsConfig } func bootstrapRedisCluster(ctx context.Context, resources []*dockertest.Resource) error { clients := make([]redis.UniversalClient, len(resources)) for i, r := range resources { addr := net.JoinHostPort(r.Container.NetworkSettings.IPAddress, "6379") options, err := redis.ParseURL(fmt.Sprintf("redis://%s/0", addr)) if err != nil { return err } clients[i] = redis.NewClient(options) defer func() { _ = clients[i].Close() }() if i > 0 { err := clients[i].ClusterMeet(ctx, resources[0].Container.NetworkSettings.IPAddress, "6379").Err() if err != nil { return err } } } // set slots const redisSlotCount = 16384 assignments := make([][]int, len(resources)) for i := 0; i < redisSlotCount; i++ { assignments[i%len(assignments)] = append(assignments[i%len(assignments)], i) } for i, c := range clients { err := c.ClusterAddSlots(ctx, assignments[i]...).Err() if err != nil { return err } } // wait for ready ticker := time.NewTicker(time.Millisecond * 50) defer ticker.Stop() for { select { case <-ctx.Done(): return ctx.Err() case <-ticker.C: } ready := 0 for _, c := range clients { str, err := c.ClusterInfo(ctx).Result() if err != nil { return err } if strings.Contains(str, "cluster_state:ok") { ready++ } } if ready == len(clients) { return nil } } }