diff --git a/integration/forward_auth_test.go b/integration/forward_auth_test.go new file mode 100644 index 000000000..9a9eb6bad --- /dev/null +++ b/integration/forward_auth_test.go @@ -0,0 +1,27 @@ +package main + +import ( + "context" + "net/http" + "testing" + "time" + + "github.com/stretchr/testify/assert" + + "github.com/pomerium/pomerium/integration/internal/flows" +) + +func TestForwardAuth(t *testing.T) { + ctx := mainCtx + ctx, clearTimeout := context.WithTimeout(ctx, time.Second*30) + defer clearTimeout() + + client := testcluster.NewHTTPClient() + res, err := flows.Authenticate(ctx, client, mustParseURL("https://fa-httpdetails.localhost.pomerium.io/by-user"), + flows.WithForwardAuth(true), flows.WithEmail("bob@dogs.test"), flows.WithGroups("user")) + if !assert.NoError(t, err, "unexpected http error") { + return + } + defer res.Body.Close() + assert.Equal(t, http.StatusOK, res.StatusCode) +} diff --git a/integration/internal/flows/flows.go b/integration/internal/flows/flows.go index dbe1809f0..798e9f1ca 100644 --- a/integration/internal/flows/flows.go +++ b/integration/internal/flows/flows.go @@ -27,6 +27,7 @@ type authenticateConfig struct { groups []string tokenExpiration time.Duration apiPath string + forwardAuth bool } // An AuthenticateOption is an option for authentication. @@ -44,6 +45,13 @@ func getAuthenticateConfig(options ...AuthenticateOption) *authenticateConfig { return cfg } +// WithForwardAuth enables/disables forward auth. +func WithForwardAuth(fa bool) AuthenticateOption { + return func(cfg *authenticateConfig) { + cfg.forwardAuth = fa + } +} + // WithEmail sets the email to use. func WithEmail(email string) AuthenticateOption { return func(cfg *authenticateConfig) { @@ -184,10 +192,28 @@ func Authenticate(ctx context.Context, client *http.Client, url *url.URL, option } // (5) finally to callback - if req.URL.Path != pomeriumCallbackPath { + if !cfg.forwardAuth && req.URL.Path != pomeriumCallbackPath { return nil, fmt.Errorf("expected to redirect back to %s, but got %s", pomeriumCallbackPath, req.URL.String()) } + if cfg.forwardAuth { + for { + res, err = client.Do(req) + if err != nil { + return nil, err + } + defer res.Body.Close() + if res.StatusCode != 302 { + break + } + req, err = requestFromRedirectResponse(ctx, res, req) + if err != nil { + return nil, fmt.Errorf("expected redirect to %s: %w", originalHostname, err) + } + } + return res, err + } + res, err = client.Do(req) if err != nil { return nil, err