mirror of
https://github.com/pomerium/pomerium.git
synced 2025-04-29 10:26:29 +02:00
205 lines
5 KiB
Go
205 lines
5 KiB
Go
package redis
|
|
|
|
import (
|
|
"context"
|
|
"crypto/tls"
|
|
"crypto/x509"
|
|
"fmt"
|
|
"io/ioutil"
|
|
"os"
|
|
"runtime"
|
|
"strings"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/gomodule/redigo/redis"
|
|
"github.com/ory/dockertest/v3"
|
|
"github.com/pomerium/pomerium/pkg/cryptutil"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
"google.golang.org/protobuf/types/known/anypb"
|
|
)
|
|
|
|
// db is the storage backend under test. It is (re)assigned by
// runWithRedisDockerImage once a redis container is reachable and is then
// read by testDB.
var db *DB
|
|
|
|
func cleanup(c redis.Conn, db *DB, t *testing.T) {
|
|
require.NoError(t, c.Send("MULTI"))
|
|
require.NoError(t, c.Send("DEL", db.recordType))
|
|
require.NoError(t, c.Send("DEL", db.versionSet))
|
|
require.NoError(t, c.Send("DEL", db.deletedSet))
|
|
_, err := c.Do("EXEC")
|
|
require.NoError(t, err)
|
|
}
|
|
|
|
func tlsConfig(rawURL string, t *testing.T) *tls.Config {
|
|
if !strings.HasPrefix(rawURL, "rediss") {
|
|
return nil
|
|
}
|
|
cert, err := cryptutil.CertificateFromFile("./testdata/tls/redis.crt", "./testdata/tls/redis.key")
|
|
require.NoError(t, err)
|
|
caCertPool := x509.NewCertPool()
|
|
caCert, err := ioutil.ReadFile("./testdata/tls/ca.crt")
|
|
require.NoError(t, err)
|
|
caCertPool.AppendCertsFromPEM(caCert)
|
|
tlsConfig := &tls.Config{
|
|
RootCAs: caCertPool,
|
|
Certificates: []tls.Certificate{*cert},
|
|
}
|
|
return tlsConfig
|
|
}
|
|
|
|
func runWithRedisDockerImage(t *testing.T, runOpts *dockertest.RunOptions, withTLS bool, testFunc func(t *testing.T)) {
|
|
pool, err := dockertest.NewPool("")
|
|
if err != nil {
|
|
t.Fatalf("Could not connect to docker: %s", err)
|
|
}
|
|
resource, err := pool.RunWithOptions(runOpts)
|
|
if err != nil {
|
|
t.Fatalf("Could not start resource: %s", err)
|
|
}
|
|
|
|
defer func() {
|
|
if err := pool.Purge(resource); err != nil {
|
|
t.Fatalf("Could not purge resource: %s", err)
|
|
}
|
|
}()
|
|
|
|
scheme := "redis"
|
|
if withTLS {
|
|
scheme = "rediss"
|
|
}
|
|
address := fmt.Sprintf(scheme+"://localhost:%s/0", resource.GetPort("6379/tcp"))
|
|
if err := pool.Retry(func() error {
|
|
var err error
|
|
db, err = New(address, "record_type", WithTLSConfig(tlsConfig(address, t)))
|
|
if err != nil {
|
|
return err
|
|
}
|
|
_, err = db.pool.Get().Do("PING")
|
|
return err
|
|
}); err != nil {
|
|
t.Fatalf("Could not connect to docker: %s", err)
|
|
}
|
|
|
|
testFunc(t)
|
|
}
|
|
|
|
func TestDB(t *testing.T) {
|
|
if os.Getenv("GITHUB_ACTION") != "" && runtime.GOOS == "darwin" {
|
|
t.Skip("Github action can not run docker on MacOS")
|
|
}
|
|
|
|
cwd, err := os.Getwd()
|
|
assert.NoError(t, err)
|
|
|
|
tlsCmd := []string{
|
|
"--port", "0",
|
|
"--tls-port", "6379",
|
|
"--tls-cert-file", "/tls/redis.crt",
|
|
"--tls-key-file", "/tls/redis.key",
|
|
"--tls-ca-cert-file", "/tls/ca.crt",
|
|
}
|
|
tests := []struct {
|
|
name string
|
|
withTLS bool
|
|
runOpts *dockertest.RunOptions
|
|
}{
|
|
{"redis", false, &dockertest.RunOptions{Repository: "redis", Tag: "latest"}},
|
|
{"redis TLS", true, &dockertest.RunOptions{Repository: "redis", Tag: "latest", Cmd: tlsCmd, Mounts: []string{cwd + "/testdata/tls:/tls"}}},
|
|
}
|
|
|
|
for _, tc := range tests {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
runWithRedisDockerImage(t, tc.runOpts, tc.withTLS, testDB)
|
|
})
|
|
}
|
|
}
|
|
|
|
func testDB(t *testing.T) {
|
|
ctx, cancelFunc := context.WithCancel(context.Background())
|
|
defer cancelFunc()
|
|
|
|
ids := []string{"a", "b", "c"}
|
|
id := ids[0]
|
|
c := db.pool.Get()
|
|
defer c.Close()
|
|
|
|
ch := db.Watch(ctx)
|
|
|
|
t.Run("get missing record", func(t *testing.T) {
|
|
record, err := db.Get(ctx, id)
|
|
assert.Error(t, err)
|
|
assert.Nil(t, record)
|
|
})
|
|
t.Run("get record", func(t *testing.T) {
|
|
data := new(anypb.Any)
|
|
assert.NoError(t, db.Put(ctx, id, data))
|
|
record, err := db.Get(ctx, id)
|
|
require.NoError(t, err)
|
|
if assert.NotNil(t, record) {
|
|
assert.NotNil(t, record.CreatedAt)
|
|
assert.Equal(t, data, record.Data)
|
|
assert.Nil(t, record.DeletedAt)
|
|
assert.Equal(t, "a", record.Id)
|
|
assert.NotNil(t, record.ModifiedAt)
|
|
assert.Equal(t, "000000000001", record.Version)
|
|
}
|
|
})
|
|
t.Run("delete record", func(t *testing.T) {
|
|
assert.NoError(t, db.Delete(ctx, id))
|
|
record, err := db.Get(ctx, id)
|
|
require.NoError(t, err)
|
|
require.NotNil(t, record)
|
|
assert.NotNil(t, record.DeletedAt)
|
|
})
|
|
t.Run("clear deleted", func(t *testing.T) {
|
|
db.ClearDeleted(ctx, time.Now().Add(time.Second))
|
|
record, err := db.Get(ctx, id)
|
|
assert.Error(t, err)
|
|
assert.Nil(t, record)
|
|
})
|
|
t.Run("get all", func(t *testing.T) {
|
|
records, err := db.GetAll(ctx)
|
|
assert.NoError(t, err)
|
|
assert.Len(t, records, 0)
|
|
data := new(anypb.Any)
|
|
|
|
for _, id := range ids {
|
|
assert.NoError(t, db.Put(ctx, id, data))
|
|
}
|
|
records, err = db.GetAll(ctx)
|
|
assert.NoError(t, err)
|
|
assert.Len(t, records, len(ids))
|
|
for _, id := range ids {
|
|
_, _ = c.Do("DEL", id)
|
|
}
|
|
})
|
|
t.Run("list", func(t *testing.T) {
|
|
cleanup(c, db, t)
|
|
|
|
for i := 0; i < 10; i++ {
|
|
id := fmt.Sprintf("%02d", i)
|
|
data := new(anypb.Any)
|
|
assert.NoError(t, db.Put(ctx, id, data))
|
|
}
|
|
|
|
records, err := db.List(ctx, "")
|
|
assert.NoError(t, err)
|
|
assert.Len(t, records, 10)
|
|
records, err = db.List(ctx, "00000000000A")
|
|
assert.NoError(t, err)
|
|
assert.Len(t, records, 5)
|
|
records, err = db.List(ctx, "00000000000F")
|
|
assert.NoError(t, err)
|
|
assert.Len(t, records, 0)
|
|
})
|
|
|
|
expectedNumEvents := 14
|
|
actualNumEvents := 0
|
|
for range ch {
|
|
actualNumEvents++
|
|
if actualNumEvents == expectedNumEvents {
|
|
cancelFunc()
|
|
}
|
|
}
|
|
}
|