pomerium/integration/main_test.go
Kenneth Jenkins 08c186a72e
integration: test with both authentication flows (#4817)
Add an environment variable to allow forcing either the stateful or the
stateless authenticate flow.

Split the existing integration test clusters "single" and "multi" into
four new clusters: "single-stateful", "single-stateless",
"multi-stateful", and "multi-stateless", so that the integration tests
will run for both the stateful and the stateless authenticate flows.

(The "kubernetes" cluster is not currently being run, so I've left it
alone for now.)
2023-12-07 16:06:41 -08:00

221 lines
4.8 KiB
Go

// Package main contains the pomerium integration tests
package main
import (
"context"
"crypto/tls"
"crypto/x509"
"flag"
"fmt"
"net/http"
"net/http/cookiejar"
"net/url"
"os"
"path/filepath"
"regexp"
"testing"
"time"
"github.com/docker/docker/api/types"
"github.com/docker/docker/client"
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
"golang.org/x/net/publicsuffix"
)
// IDP, ClusterType and AuthenticateFlow describe the cluster the integration
// tests are running against. They are assigned by setClusterInfo, which
// TestMain calls before running any test: IDP is always "oidc", while
// ClusterType and AuthenticateFlow start as "single" and "stateful" and are
// overridden when a running docker container name matches the expected
// pattern.
var IDP, ClusterType, AuthenticateFlow string
// TestMain is the entry point for the integration test binary.
//
// It configures the global zerolog logger to write console output to stderr
// (debug level when tests run verbosely, info level otherwise), blocks until
// waitForHealthy reports that the services are reachable, records the cluster
// details via setClusterInfo, and finally runs the tests, exiting with their
// status code. If the health check fails the process exits with status 1.
func TestMain(m *testing.M) {
	log.Logger = log.Output(zerolog.ConsoleWriter{Out: os.Stderr})
	flag.Parse()

	level := zerolog.InfoLevel
	if testing.Verbose() {
		level = zerolog.DebugLevel
	}
	log.Logger = log.Logger.Level(level)

	logger := log.With().Logger()
	ctx := logger.WithContext(context.Background())
	if err := waitForHealthy(ctx); err != nil {
		// Report the underlying cause and terminate the line so the message
		// does not run into whatever is printed next on stderr.
		_, _ = fmt.Fprintf(os.Stderr, "services not healthy: %v\n", err)
		os.Exit(1)
	}

	setClusterInfo(ctx)

	os.Exit(m.Run())
}
// loggingRoundTripper is an http.RoundTripper that records each outgoing
// request in the test log before handing it to the wrapped transport.
type loggingRoundTripper struct {
	// t receives one log line per request; it may be nil, in which case
	// nothing is logged.
	t testing.TB
	// transport performs the actual round trip.
	transport http.RoundTripper
}

// RoundTrip logs the request method and URL (when a test handle is set) and
// then delegates to the wrapped transport, returning its result unchanged.
func (rt loggingRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
	if tb := rt.t; tb != nil {
		tb.Logf("%s %s", req.Method, req.URL.String())
	}
	return rt.transport.RoundTrip(req)
}
// getTransport returns a request-logging http.RoundTripper that trusts the
// system certificate pool plus the CA stored in ./tpl/files/ca.pem.
//
// t may be nil; when it is not, requests are logged through it. Keep-alives
// are disabled on the underlying transport. The function panics if the
// system pool cannot be loaded or the CA file cannot be read.
func getTransport(t testing.TB) http.RoundTripper {
	if t != nil {
		t.Helper()
	}

	pool, err := x509.SystemCertPool()
	if err != nil {
		panic(err)
	}

	caPath := filepath.Join(".", "tpl", "files", "ca.pem")
	pemBytes, err := os.ReadFile(caPath)
	if err != nil {
		panic(err)
	}
	// The result is intentionally not checked, matching the existing
	// behaviour: a CA file without usable certificates is silently ignored.
	_ = pool.AppendCertsFromPEM(pemBytes)

	return loggingRoundTripper{
		t: t,
		transport: &http.Transport{
			DisableKeepAlives: true,
			TLSClientConfig: &tls.Config{
				RootCAs: pool,
			},
		},
	}
}
// getClient returns an http.Client for talking to the test cluster.
//
// The client keeps cookies in a fresh jar that uses the public suffix list,
// sends requests through getTransport(t), and never follows redirects: the
// redirect response itself is returned to the caller. t may be nil. The
// function panics if the cookie jar cannot be created.
func getClient(t testing.TB) *http.Client {
	if t != nil {
		t.Helper()
	}

	cookies, err := cookiejar.New(&cookiejar.Options{PublicSuffixList: publicsuffix.List})
	if err != nil {
		panic(err)
	}

	// Returning ErrUseLastResponse makes the client hand back the 3xx
	// response instead of following it.
	stopAtRedirect := func(*http.Request, []*http.Request) error {
		return http.ErrUseLastResponse
	}

	return &http.Client{
		CheckRedirect: stopAtRedirect,
		Transport:     getTransport(t),
		Jar:           cookies,
	}
}
// getClientWithTransport returns a new http.Client configured exactly like
// the one from getClient, together with a pointer to the http.Transport that
// sits underneath its logging wrapper, so that callers can customize the
// transport directly.
func getClientWithTransport(t testing.TB) (*http.Client, *http.Transport) {
	c := getClient(t)
	// Unwrap the logging round tripper to reach the concrete transport.
	wrapper := c.Transport.(loggingRoundTripper)
	inner := wrapper.transport.(*http.Transport)
	return c, inner
}
// waitForHealthy blocks until every health endpoint answers with a 2xx
// status, or until ctx is done.
//
// The endpoints are probed in order, each with a one-second timeout; the
// first failure aborts the current round. After a failed round the function
// waits for the next tick of a three-second ticker before trying again. It
// returns nil once a full round succeeds and ctx.Err() if ctx is cancelled
// while waiting.
func waitForHealthy(ctx context.Context) error {
	httpClient := getClient(nil)

	// probe performs a single GET against target and reports whether it
	// answered with a 2xx status.
	probe := func(target string) error {
		probeCtx, cancel := context.WithTimeout(ctx, time.Second)
		defer cancel()

		req, err := http.NewRequestWithContext(probeCtx, http.MethodGet, target, nil)
		if err != nil {
			return err
		}

		resp, err := httpClient.Do(req)
		if err != nil {
			return err
		}
		defer resp.Body.Close()

		if resp.StatusCode < 200 || resp.StatusCode >= 300 {
			return fmt.Errorf("%s unavailable: %s", target, resp.Status)
		}

		log.Info().Int("status", resp.StatusCode).Msgf("%s healthy", target)
		return nil
	}

	// probeAll runs one round over all endpoints, stopping at the first
	// failure.
	probeAll := func() error {
		targets := [...]string{
			"https://authenticate.localhost.pomerium.io/.well-known/pomerium/jwks.json",
			"https://mock-idp.localhost.pomerium.io/.well-known/jwks.json",
		}
		for _, target := range targets {
			if err := probe(target); err != nil {
				return err
			}
		}
		return nil
	}

	retry := time.NewTicker(3 * time.Second)
	defer retry.Stop()

	for {
		err := probeAll()
		if err == nil {
			return nil
		}

		log.Ctx(ctx).Info().Err(err).Msg("waiting for healthy")

		select {
		case <-retry.C:
		case <-ctx.Done():
			return ctx.Err()
		}
	}
}
// setClusterInfo populates the package-level IDP, ClusterType and
// AuthenticateFlow variables.
//
// The defaults are "oidc", "single" and "stateful". The function then lists
// the running docker containers and, for every container name of the form
// "/<cluster>-<flow>-pomerium..." (or with "_" before "pomerium"), takes the
// cluster type and authenticate flow from the name; if several names match,
// the last one seen wins. The resulting values are logged at info level.
//
// Discovery is best effort: if the docker client cannot be created the
// defaults are kept and the function returns after logging the error; if the
// container listing fails the error is logged and the defaults are reported.
func setClusterInfo(ctx context.Context) {
	IDP = "oidc"
	ClusterType = "single"
	AuthenticateFlow = "stateful"

	cli, err := client.NewClientWithOpts(client.FromEnv, client.WithAPIVersionNegotiation())
	if err != nil {
		log.Error().Err(err).Msg("failed to create docker client")
		return
	}
	// Release the client's transport once discovery is finished; the close
	// error carries no useful information for the tests.
	defer func() { _ = cli.Close() }()

	containers, err := cli.ContainerList(ctx, types.ContainerListOptions{})
	if err != nil {
		// Not fatal: the loop below simply sees no containers and the
		// defaults stay in place.
		log.Error().Err(err).Msg("failed to retrieve docker containers")
	}

	// Compile once rather than for every container name.
	namePattern := regexp.MustCompile(`^/(\w+?)-(\w+?)[-_]pomerium.*$`)
	for _, container := range containers {
		for _, name := range container.Names {
			parts := namePattern.FindStringSubmatch(name)
			if len(parts) == 3 {
				ClusterType = parts[1]
				AuthenticateFlow = parts[2]
			}
		}
	}

	log.Info().
		Str("idp", IDP).
		Str("cluster-type", ClusterType).
		Str("authenticate-flow", AuthenticateFlow).
		Send()
}
// mustParseURL parses str with url.Parse and returns the result. It panics
// with the parse error if str is not a valid URL.
func mustParseURL(str string) *url.URL {
	parsed, err := url.Parse(str)
	if err == nil {
		return parsed
	}
	panic(err)
}
// loadCertificate loads the X.509 key pair named certName from ./tpl/files,
// reading the certificate from "<certName>.pem" and the private key from
// "<certName>-key.pem". The test is failed immediately if the pair cannot be
// loaded.
func loadCertificate(t *testing.T, certName string) tls.Certificate {
	t.Helper()

	dir := filepath.Join(".", "tpl", "files")
	pair, err := tls.LoadX509KeyPair(
		filepath.Join(dir, certName+".pem"),
		filepath.Join(dir, certName+"-key.pem"),
	)
	if err != nil {
		t.Fatal(err)
	}
	return pair
}