// Package main contains the pomerium integration tests package main import ( "context" "crypto/tls" "crypto/x509" "flag" "fmt" "net/http" "net/http/cookiejar" "net/url" "os" "path/filepath" "regexp" "testing" "time" "github.com/docker/docker/api/types/container" "github.com/docker/docker/client" "github.com/quic-go/quic-go/http3" "github.com/rs/zerolog" "github.com/rs/zerolog/log" "golang.org/x/net/publicsuffix" ) var IDP, ClusterType, AuthenticateFlow string func TestMain(m *testing.M) { log.Logger = log.Output(zerolog.ConsoleWriter{Out: os.Stderr}) flag.Parse() if testing.Verbose() { log.Logger = log.Logger.Level(zerolog.DebugLevel) } else { log.Logger = log.Logger.Level(zerolog.InfoLevel) } logger := log.With().Logger() ctx := logger.WithContext(context.Background()) if err := waitForHealthy(ctx); err != nil { _, _ = fmt.Fprintf(os.Stderr, "services not healthy") os.Exit(1) return } setClusterInfo(ctx) status := m.Run() os.Exit(status) } type loggingRoundTripper struct { t testing.TB transport http.RoundTripper } func (l loggingRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { if l.t != nil { l.t.Logf("%s %s", req.Method, req.URL.String()) } return l.transport.RoundTrip(req) } func getTransport(t testing.TB, useHTTP3 bool) http.RoundTripper { if t != nil { t.Helper() } rootCAs, err := x509.SystemCertPool() if err != nil { panic(err) } bs, err := os.ReadFile(filepath.Join(".", "tpl", "files", "ca.pem")) if err != nil { panic(err) } _ = rootCAs.AppendCertsFromPEM(bs) var transport http.RoundTripper if useHTTP3 { transport = &http3.Transport{ TLSClientConfig: &tls.Config{ RootCAs: rootCAs, }, } } else { transport = &http.Transport{ DisableKeepAlives: true, TLSClientConfig: &tls.Config{ RootCAs: rootCAs, }, } } return loggingRoundTripper{t, transport} } func getClient(t testing.TB, useHTTP3 bool) *http.Client { if t != nil { t.Helper() } jar, err := cookiejar.New(&cookiejar.Options{PublicSuffixList: publicsuffix.List}) if err != nil { panic(err) } return &http.Client{ CheckRedirect: func(_ *http.Request, _ []*http.Request) error { return http.ErrUseLastResponse }, Transport: getTransport(t, useHTTP3), Jar: jar, } } // Returns a new http.Client configured with the same settings as getClient(), // as well as a pointer to the wrapped http.Transport, so that the // http.Transport can be easily customized. func getClientWithTransport(t testing.TB) (*http.Client, *http.Transport) { client := getClient(t, false) return client, client.Transport.(loggingRoundTripper).transport.(*http.Transport) } func waitForHealthy(ctx context.Context) error { client := getClient(nil, false) check := func(endpoint string) error { reqCtx, clearTimeout := context.WithTimeout(ctx, time.Second) defer clearTimeout() req, err := http.NewRequestWithContext(reqCtx, http.MethodGet, endpoint, nil) if err != nil { return err } res, err := client.Do(req) if err != nil { return err } defer res.Body.Close() if res.StatusCode/100 != 2 { return fmt.Errorf("%s unavailable: %s", endpoint, res.Status) } log.Info().Int("status", res.StatusCode).Msgf("%s healthy", endpoint) return nil } ticker := time.NewTicker(time.Second * 3) defer ticker.Stop() endpoints := []string{ "https://authenticate.localhost.pomerium.io/.well-known/pomerium/jwks.json", "https://mock-idp.localhost.pomerium.io/.well-known/jwks.json", } for { var err error for _, endpoint := range endpoints { err = check(endpoint) if err != nil { break } } if err == nil { return nil } log.Ctx(ctx).Info().Err(err).Msg("waiting for healthy") select { case <-ctx.Done(): return context.Cause(ctx) case <-ticker.C: } } } func setClusterInfo(ctx context.Context) { IDP = "oidc" ClusterType = "single" AuthenticateFlow = "stateful" cli, err := client.NewClientWithOpts(client.FromEnv, client.WithAPIVersionNegotiation()) if err != nil { log.Error().Err(err).Msg("failed to create docker client") return } containers, err := cli.ContainerList(ctx, container.ListOptions{}) if err != nil { log.Error().Err(err).Msg("failed to retrieve docker containers") } for _, container := range containers { for _, name := range container.Names { parts := regexp.MustCompile(`^/(\w+?)-(\w+?)[-_]pomerium.*$`).FindStringSubmatch(name) if len(parts) == 3 { ClusterType = parts[1] AuthenticateFlow = parts[2] } } } log.Info(). Str("idp", IDP). Str("cluster-type", ClusterType). Str("authenticate-flow", AuthenticateFlow). Send() } func mustParseURL(str string) *url.URL { u, err := url.Parse(str) if err != nil { panic(err) } return u } func loadCertificate(t *testing.T, certName string) tls.Certificate { t.Helper() certFile := filepath.Join(".", "tpl", "files", certName+".pem") keyFile := filepath.Join(".", "tpl", "files", certName+"-key.pem") cert, err := tls.LoadX509KeyPair(certFile, keyFile) if err != nil { t.Fatal(err) } return cert } func testHTTPClient(t *testing.T, f func(t *testing.T, client *http.Client)) { t.Helper() t.Run("http2", func(t *testing.T) { f(t, getClient(t, false)) }) t.Run("http3", func(t *testing.T) { f(t, getClient(t, true)) }) }