pomerium/internal/testutil/redis.go
Caleb Doxsey 0b48da1e2f
databroker: support rotating shared secret (#3502)
* databroker: support rotating shared secret

* fix test

* run tests on linux

* fix tests

* fix typo

* increase timeout
2022-07-26 10:59:54 -06:00

388 lines
8.7 KiB
Go

package testutil
import (
"context"
"crypto/tls"
"crypto/x509"
"fmt"
"net"
"os"
"path/filepath"
"strings"
"time"
"github.com/go-redis/redis/v8"
"github.com/ory/dockertest/v3"
"github.com/ory/dockertest/v3/docker"
"github.com/pomerium/pomerium/pkg/cryptutil"
)
const maxWait = 20 * time.Minute
// WithTestRedis creates a test a test redis instance using docker.
func WithTestRedis(useTLS bool, handler func(rawURL string) error) error {
ctx, clearTimeout := context.WithTimeout(context.Background(), maxWait)
defer clearTimeout()
// uses a sensible default on windows (tcp/http) and linux/osx (socket)
pool, err := dockertest.NewPool("")
if err != nil {
return err
}
opts := &dockertest.RunOptions{
Repository: "redis",
Tag: "6",
}
scheme := "redis"
if useTLS {
opts.Mounts = []string{
filepath.Join(TestDataRoot(), "tls") + ":/tls",
}
opts.Cmd = []string{
"--port", "0",
"--tls-port", "6379",
"--tls-cert-file", "/tls/redis.crt",
"--tls-key-file", "/tls/redis.key",
"--tls-ca-cert-file", "/tls/ca.crt",
}
scheme = "rediss"
}
resource, err := pool.RunWithOptions(opts)
if err != nil {
return err
}
_ = resource.Expire(uint(maxWait.Seconds()))
redisURL := fmt.Sprintf("%s://%s/0", scheme, resource.GetHostPort("6379/tcp"))
if err := pool.Retry(func() error {
options, err := redis.ParseURL(redisURL)
if err != nil {
return err
}
if useTLS {
options.TLSConfig = RedisTLSConfig()
}
client := redis.NewClient(options)
defer client.Close()
return client.Ping(ctx).Err()
}); err != nil {
_ = pool.Purge(resource)
return err
}
e := handler(redisURL)
if err := pool.Purge(resource); err != nil {
return err
}
return e
}
// WithTestRedisCluster creates a new redis cluster 3 node cluster.
func WithTestRedisCluster(handler func(rawURL string) error) error {
ctx, clearTimeout := context.WithTimeout(context.Background(), maxWait)
defer clearTimeout()
// uses a sensible default on windows (tcp/http) and linux/osx (socket)
pool, err := dockertest.NewPool("")
if err != nil {
return err
}
redises := make([]*dockertest.Resource, 3)
for i := range redises {
conf := "cluster-enabled yes\ncluster-config-file nodes.conf"
r, err := pool.RunWithOptions(&dockertest.RunOptions{
Hostname: fmt.Sprintf("redis%d", i),
Repository: "redis",
Tag: "6",
Entrypoint: []string{
"/bin/bash", "-c",
`echo "` + conf + `" >/tmp/redis.conf && chmod 0777 /tmp/redis.conf && exec docker-entrypoint.sh /tmp/redis.conf`,
},
ExposedPorts: []string{
"6379/tcp",
"26379/tcp",
},
})
if err != nil {
return err
}
defer r.Close()
_ = r.Expire(uint(maxWait.Seconds()))
go func() {
_ = pool.Client.Logs(docker.LogsOptions{
Context: ctx,
Stderr: true,
Stdout: true,
Follow: true,
Timestamps: true,
Container: r.Container.ID,
OutputStream: os.Stderr,
ErrorStream: os.Stderr,
})
}()
redises[i] = r
}
addrs := make([]string, 3)
for i, r := range redises {
addrs[i] = net.JoinHostPort(
r.Container.NetworkSettings.IPAddress,
"6379",
)
}
for _, addr := range addrs {
err := pool.Retry(func() error {
options, err := redis.ParseURL(fmt.Sprintf("redis://%s/0", addr))
if err != nil {
return err
}
client := redis.NewClient(options)
defer client.Close()
return client.Ping(ctx).Err()
})
if err != nil {
return err
}
}
// join the nodes to the cluster
err = bootstrapRedisCluster(ctx, redises)
if err != nil {
return err
}
e := handler(fmt.Sprintf("redis+cluster://%s", strings.Join(addrs, ",")))
for _, r := range redises {
if err := pool.Purge(r); err != nil {
return err
}
}
return e
}
// WithTestRedisSentinel creates a new redis sentinel 3 node cluster.
func WithTestRedisSentinel(handler func(rawURL string) error) error {
ctx, clearTimeout := context.WithTimeout(context.Background(), maxWait)
defer clearTimeout()
// uses a sensible default on windows (tcp/http) and linux/osx (socket)
pool, err := dockertest.NewPool("")
if err != nil {
return err
}
redises := make([]*dockertest.Resource, 3)
for i := range redises {
r, err := pool.RunWithOptions(&dockertest.RunOptions{
Hostname: fmt.Sprintf("redis%d", i),
Repository: "redis",
Tag: "6",
ExposedPorts: []string{
"6379/tcp",
"26379/tcp",
},
})
if err != nil {
return err
}
defer r.Close()
_ = r.Expire(uint(maxWait.Seconds()))
redises[i] = r
}
sentinels := make([]*dockertest.Resource, len(redises))
for i := range sentinels {
conf := fmt.Sprintf("sentinel monitor master %s 6379 %d\n",
redises[0].Container.NetworkSettings.IPAddress, len(redises))
if i > 0 {
conf += fmt.Sprintf("sentinel known-slave master %s 6379\n",
redises[i].Container.NetworkSettings.IPAddress)
}
r, err := pool.RunWithOptions(&dockertest.RunOptions{
Hostname: fmt.Sprintf("sentineld%d", i),
Repository: "redis",
Tag: "6",
Entrypoint: []string{
"/bin/bash", "-c",
`echo "` + conf + `" >/tmp/sentinel.conf && chmod 0777 /tmp/sentinel.conf && exec docker-entrypoint.sh /tmp/sentinel.conf --sentinel`,
},
ExposedPorts: []string{
"6379/tcp",
"26379/tcp",
},
})
if err != nil {
return err
}
defer r.Close()
_ = r.Expire(uint(maxWait.Seconds()))
go func() {
_ = pool.Client.Logs(docker.LogsOptions{
Context: ctx,
Stderr: true,
Stdout: true,
Follow: true,
Timestamps: true,
Container: r.Container.ID,
OutputStream: os.Stderr,
ErrorStream: os.Stderr,
})
}()
sentinels[i] = r
}
addrs := make([]string, len(sentinels))
for i, r := range sentinels {
addrs[i] = net.JoinHostPort(
r.Container.NetworkSettings.IPAddress,
"26379",
)
}
redisURL := fmt.Sprintf("redis+sentinel://%s/master/0", strings.Join(addrs, ","))
for _, r := range redises {
addr := net.JoinHostPort(
r.Container.NetworkSettings.IPAddress,
"6379",
)
if err := pool.Retry(func() error {
options, err := redis.ParseURL(fmt.Sprintf("redis://%s/0", addr))
if err != nil {
return err
}
client := redis.NewClient(options)
defer client.Close()
return client.Ping(ctx).Err()
}); err != nil {
_ = pool.Purge(r)
return err
}
}
for _, r := range sentinels {
if err := pool.Retry(func() error {
options, err := redis.ParseURL(fmt.Sprintf("redis://%s/0", r.GetHostPort("26379/tcp")))
if err != nil {
return err
}
client := redis.NewClient(options)
defer client.Close()
return client.Ping(ctx).Err()
}); err != nil {
_ = pool.Purge(r)
return err
}
}
e := handler(redisURL)
for _, r := range append(redises, sentinels...) {
if err := pool.Purge(r); err != nil {
return err
}
}
return e
}
// RedisTLSConfig returns the TLS Config to use with redis.
func RedisTLSConfig() *tls.Config {
cert, err := cryptutil.CertificateFromFile(
filepath.Join(TestDataRoot(), "tls", "redis.crt"),
filepath.Join(TestDataRoot(), "tls", "redis.key"),
)
if err != nil {
panic(err)
}
caCertPool := x509.NewCertPool()
caCert, err := os.ReadFile(filepath.Join(TestDataRoot(), "tls", "ca.crt"))
if err != nil {
panic(err)
}
caCertPool.AppendCertsFromPEM(caCert)
tlsConfig := &tls.Config{
RootCAs: caCertPool,
Certificates: []tls.Certificate{*cert},
MinVersion: tls.VersionTLS12,
}
return tlsConfig
}
func bootstrapRedisCluster(ctx context.Context, resources []*dockertest.Resource) error {
clients := make([]redis.UniversalClient, len(resources))
for i, r := range resources {
addr := net.JoinHostPort(r.Container.NetworkSettings.IPAddress, "6379")
options, err := redis.ParseURL(fmt.Sprintf("redis://%s/0", addr))
if err != nil {
return err
}
clients[i] = redis.NewClient(options)
defer func() { _ = clients[i].Close() }()
if i > 0 {
err := clients[i].ClusterMeet(ctx, resources[0].Container.NetworkSettings.IPAddress, "6379").Err()
if err != nil {
return err
}
}
}
// set slots
const redisSlotCount = 16384
assignments := make([][]int, len(resources))
for i := 0; i < redisSlotCount; i++ {
assignments[i%len(assignments)] = append(assignments[i%len(assignments)], i)
}
for i, c := range clients {
err := c.ClusterAddSlots(ctx, assignments[i]...).Err()
if err != nil {
return err
}
}
// wait for ready
ticker := time.NewTicker(time.Millisecond * 50)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return ctx.Err()
case <-ticker.C:
}
ready := 0
for _, c := range clients {
str, err := c.ClusterInfo(ctx).Result()
if err != nil {
return err
}
if strings.Contains(str, "cluster_state:ok") {
ready++
}
}
if ready == len(clients) {
return nil
}
}
}