mirror of
https://github.com/pomerium/pomerium.git
synced 2025-04-29 10:26:29 +02:00
Add an environment variable to allow forcing either the stateful or the stateless authenticate flow. Split the existing integration test clusters "single" and "multi" into four new clusters: "single-stateful", "single-stateless", "multi-stateful", and "multi-stateless", so that the integration tests will run for both the stateful and the stateless authenticate flows. (The "kubernetes" cluster is not currently being run, so I've left it alone for now.)
221 lines
4.8 KiB
Go
221 lines
4.8 KiB
Go
// Package main contains the pomerium integration tests
|
|
package main
|
|
|
|
import (
|
|
"context"
|
|
"crypto/tls"
|
|
"crypto/x509"
|
|
"flag"
|
|
"fmt"
|
|
"net/http"
|
|
"net/http/cookiejar"
|
|
"net/url"
|
|
"os"
|
|
"path/filepath"
|
|
"regexp"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/docker/docker/api/types"
|
|
"github.com/docker/docker/client"
|
|
"github.com/rs/zerolog"
|
|
"github.com/rs/zerolog/log"
|
|
"golang.org/x/net/publicsuffix"
|
|
)
|
|
|
|
var IDP, ClusterType, AuthenticateFlow string
|
|
|
|
func TestMain(m *testing.M) {
|
|
log.Logger = log.Output(zerolog.ConsoleWriter{Out: os.Stderr})
|
|
|
|
flag.Parse()
|
|
if testing.Verbose() {
|
|
log.Logger = log.Logger.Level(zerolog.DebugLevel)
|
|
} else {
|
|
log.Logger = log.Logger.Level(zerolog.InfoLevel)
|
|
}
|
|
|
|
logger := log.With().Logger()
|
|
ctx := logger.WithContext(context.Background())
|
|
|
|
if err := waitForHealthy(ctx); err != nil {
|
|
_, _ = fmt.Fprintf(os.Stderr, "services not healthy")
|
|
os.Exit(1)
|
|
return
|
|
}
|
|
|
|
setClusterInfo(ctx)
|
|
|
|
status := m.Run()
|
|
os.Exit(status)
|
|
}
|
|
|
|
type loggingRoundTripper struct {
|
|
t testing.TB
|
|
transport http.RoundTripper
|
|
}
|
|
|
|
func (l loggingRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
|
|
if l.t != nil {
|
|
l.t.Logf("%s %s", req.Method, req.URL.String())
|
|
}
|
|
return l.transport.RoundTrip(req)
|
|
}
|
|
|
|
func getTransport(t testing.TB) http.RoundTripper {
|
|
if t != nil {
|
|
t.Helper()
|
|
}
|
|
|
|
rootCAs, err := x509.SystemCertPool()
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
|
|
bs, err := os.ReadFile(filepath.Join(".", "tpl", "files", "ca.pem"))
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
_ = rootCAs.AppendCertsFromPEM(bs)
|
|
transport := &http.Transport{
|
|
DisableKeepAlives: true,
|
|
TLSClientConfig: &tls.Config{
|
|
RootCAs: rootCAs,
|
|
},
|
|
}
|
|
return loggingRoundTripper{t, transport}
|
|
}
|
|
|
|
func getClient(t testing.TB) *http.Client {
|
|
if t != nil {
|
|
t.Helper()
|
|
}
|
|
|
|
jar, err := cookiejar.New(&cookiejar.Options{PublicSuffixList: publicsuffix.List})
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
|
|
return &http.Client{
|
|
CheckRedirect: func(req *http.Request, via []*http.Request) error {
|
|
return http.ErrUseLastResponse
|
|
},
|
|
Transport: getTransport(t),
|
|
Jar: jar,
|
|
}
|
|
}
|
|
|
|
// Returns a new http.Client configured with the same settings as getClient(),
|
|
// as well as a pointer to the wrapped http.Transport, so that the
|
|
// http.Transport can be easily customized.
|
|
func getClientWithTransport(t testing.TB) (*http.Client, *http.Transport) {
|
|
client := getClient(t)
|
|
return client, client.Transport.(loggingRoundTripper).transport.(*http.Transport)
|
|
}
|
|
|
|
func waitForHealthy(ctx context.Context) error {
|
|
client := getClient(nil)
|
|
check := func(endpoint string) error {
|
|
reqCtx, clearTimeout := context.WithTimeout(ctx, time.Second)
|
|
defer clearTimeout()
|
|
|
|
req, err := http.NewRequestWithContext(reqCtx, http.MethodGet, endpoint, nil)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
res, err := client.Do(req)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer res.Body.Close()
|
|
|
|
if res.StatusCode/100 != 2 {
|
|
return fmt.Errorf("%s unavailable: %s", endpoint, res.Status)
|
|
}
|
|
|
|
log.Info().Int("status", res.StatusCode).Msgf("%s healthy", endpoint)
|
|
|
|
return nil
|
|
}
|
|
|
|
ticker := time.NewTicker(time.Second * 3)
|
|
defer ticker.Stop()
|
|
|
|
endpoints := []string{
|
|
"https://authenticate.localhost.pomerium.io/.well-known/pomerium/jwks.json",
|
|
"https://mock-idp.localhost.pomerium.io/.well-known/jwks.json",
|
|
}
|
|
|
|
for {
|
|
var err error
|
|
for _, endpoint := range endpoints {
|
|
err = check(endpoint)
|
|
if err != nil {
|
|
break
|
|
}
|
|
}
|
|
if err == nil {
|
|
return nil
|
|
}
|
|
|
|
log.Ctx(ctx).Info().Err(err).Msg("waiting for healthy")
|
|
|
|
select {
|
|
case <-ctx.Done():
|
|
return ctx.Err()
|
|
case <-ticker.C:
|
|
}
|
|
}
|
|
}
|
|
|
|
func setClusterInfo(ctx context.Context) {
|
|
IDP = "oidc"
|
|
ClusterType = "single"
|
|
AuthenticateFlow = "stateful"
|
|
|
|
cli, err := client.NewClientWithOpts(client.FromEnv, client.WithAPIVersionNegotiation())
|
|
if err != nil {
|
|
log.Error().Err(err).Msg("failed to create docker client")
|
|
return
|
|
}
|
|
|
|
containers, err := cli.ContainerList(ctx, types.ContainerListOptions{})
|
|
if err != nil {
|
|
log.Error().Err(err).Msg("failed to retrieve docker containers")
|
|
}
|
|
for _, container := range containers {
|
|
for _, name := range container.Names {
|
|
parts := regexp.MustCompile(`^/(\w+?)-(\w+?)[-_]pomerium.*$`).FindStringSubmatch(name)
|
|
if len(parts) == 3 {
|
|
ClusterType = parts[1]
|
|
AuthenticateFlow = parts[2]
|
|
}
|
|
}
|
|
}
|
|
|
|
log.Info().
|
|
Str("idp", IDP).
|
|
Str("cluster-type", ClusterType).
|
|
Str("authenticate-flow", AuthenticateFlow).
|
|
Send()
|
|
}
|
|
|
|
func mustParseURL(str string) *url.URL {
|
|
u, err := url.Parse(str)
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
return u
|
|
}
|
|
|
|
func loadCertificate(t *testing.T, certName string) tls.Certificate {
|
|
t.Helper()
|
|
certFile := filepath.Join(".", "tpl", "files", certName+".pem")
|
|
keyFile := filepath.Join(".", "tpl", "files", certName+"-key.pem")
|
|
cert, err := tls.LoadX509KeyPair(certFile, keyFile)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
return cert
|
|
}
|