diff --git a/authorize/check_response.go b/authorize/check_response.go index 873acc754..79fc5d0d5 100644 --- a/authorize/check_response.go +++ b/authorize/check_response.go @@ -235,8 +235,12 @@ func (a *Authorize) requireLoginResponse( signInURLQuery = url.Values{} signInURLQuery.Add("pomerium_traceparent", id) } + var additionalHosts []string + if request.Policy != nil { + additionalHosts = request.Policy.DependsOn + } redirectTo, err := state.authenticateFlow.AuthenticateSignInURL( - ctx, signInURLQuery, &checkRequestURL, idp.GetId()) + ctx, signInURLQuery, &checkRequestURL, idp.GetId(), additionalHosts) if err != nil { return nil, err } diff --git a/authorize/state.go b/authorize/state.go index 10cfb4195..56fae8dac 100644 --- a/authorize/state.go +++ b/authorize/state.go @@ -20,7 +20,7 @@ import ( var outboundGRPCConnection = new(grpc.CachedOutboundGRPClientConn) type authenticateFlow interface { - AuthenticateSignInURL(ctx context.Context, queryParams url.Values, redirectURL *url.URL, idpID string) (string, error) + AuthenticateSignInURL(ctx context.Context, queryParams url.Values, redirectURL *url.URL, idpID string, additionalLoginHosts []string) (string, error) } type authorizeState struct { diff --git a/config/policy.go b/config/policy.go index f73fd5965..517e22c34 100644 --- a/config/policy.go +++ b/config/policy.go @@ -200,6 +200,8 @@ type Policy struct { ShowErrorDetails bool `mapstructure:"show_error_details" yaml:"show_error_details" json:"show_error_details"` Policy *PPLPolicy `mapstructure:"policy" yaml:"policy,omitempty" json:"policy,omitempty"` + + DependsOn []string `mapstructure:"depends_on" yaml:"depends_on,omitempty" json:"depends_on,omitempty"` } // RewriteHeader is a policy configuration option to rewrite an HTTP header. @@ -690,6 +692,10 @@ func (p *Policy) Validate() error { return fmt.Errorf("config: unsupported jwt_issuer_format value %q", p.JWTIssuerFormat) } + if len(p.DependsOn) > 5 { + return fmt.Errorf("config: depends_on is limited to 5 additional redirect hosts, got %v", p.DependsOn) + } + return nil } diff --git a/config/policy_test.go b/config/policy_test.go index 628f5fdd7..025fbb9d0 100644 --- a/config/policy_test.go +++ b/config/policy_test.go @@ -56,6 +56,7 @@ func Test_PolicyValidate(t *testing.T) { {"TCP To URLs", Policy{From: "tcp+https://httpbin.corp.example:4000", To: mustParseWeightedURLs(t, "tcp://one.example.com:5000", "tcp://two.example.com:5000")}, false}, {"mix of TCP and non-TCP To URLs", Policy{From: "tcp+https://httpbin.corp.example:4000", To: mustParseWeightedURLs(t, "https://example.com", "tcp://example.com:5000")}, true}, {"UDP To URLs", Policy{From: "udp+https://httpbin.corp.example:4000", To: mustParseWeightedURLs(t, "udp://one.example.com:5000", "udp://two.example.com:5000")}, false}, + {"too many depends_on hosts", Policy{From: "https://httpbin.corp.example", To: mustParseWeightedURLs(t, "https://httpbin.corp.notatld"), DependsOn: []string{"a", "b", "c", "d", "e", "f"}}, true}, } for _, tt := range tests { diff --git a/internal/authenticateflow/authenticateflow_int_test.go b/internal/authenticateflow/authenticateflow_int_test.go new file mode 100644 index 000000000..7e70d3734 --- /dev/null +++ b/internal/authenticateflow/authenticateflow_int_test.go @@ -0,0 +1,98 @@ +package authenticateflow_test + +import ( + "fmt" + "io" + "net/http" + "net/http/cookiejar" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/pomerium/pomerium/config" + "github.com/pomerium/pomerium/internal/testenv" + "github.com/pomerium/pomerium/internal/testenv/scenarios" + "github.com/pomerium/pomerium/internal/testenv/snippets" + "github.com/pomerium/pomerium/internal/testenv/upstreams" + "github.com/pomerium/pomerium/internal/testenv/values" +) + +func newHTTPUpstream( + env testenv.Environment, subdomain string, +) (upstreams.HTTPUpstream, testenv.Route) { + up := upstreams.HTTP(nil) + up.Handle("/", func(w http.ResponseWriter, _ *http.Request) { fmt.Fprintln(w, "hello world") }) + route := up.Route(). + From(env.SubdomainURL(subdomain)). + To(values.Bind(up.Addr(), func(addr string) string { + // override the target protocol to use http:// + return fmt.Sprintf("http://%s", addr) + })). + Policy(func(p *config.Policy) { p.AllowAnyAuthenticatedUser = true }) + env.AddUpstream(up) + return up, route +} + +func TestMultiDomainLogin(t *testing.T) { + env := testenv.New(t) + + env.Add(scenarios.NewIDP([]*scenarios.User{{Email: "test@example.com"}})) + + // Create three routes to be linked via multi-domain login... + upstreamA, routeA := newHTTPUpstream(env, "a") + upstreamB, routeB := newHTTPUpstream(env, "b") + upstreamC, routeC := newHTTPUpstream(env, "c") + // ...and one route that will not be involved. + upstreamD, routeD := newHTTPUpstream(env, "d") + + // Configure route A to redirect through routes B and C on login. + routeA.Policy(func(p *config.Policy) { + p.DependsOn = []string{ + strings.TrimPrefix(routeB.URL().Value(), "https://"), + strings.TrimPrefix(routeC.URL().Value(), "https://"), + } + }) + + env.Start() + snippets.WaitStartupComplete(env) + + // By default the testenv code will use a separate http.Client for each + // separate route. Instead we specifically want to test the cross-route + // behavior for a single client. + cj, err := cookiejar.New(nil) + require.NoError(t, err) + sharedClient := http.Client{Jar: cj} + withSharedClient := upstreams.ClientHook( + func(_ *http.Client) *http.Client { return &sharedClient }) + + assertResponseOK := func(resp *http.Response, err error) { + t.Helper() + require.NoError(t, err) + assert.Equal(t, http.StatusOK, resp.StatusCode) + io.ReadAll(resp.Body) + resp.Body.Close() + } + assertRedirect := func(resp *http.Response, err error) { + t.Helper() + require.NoError(t, err) + assert.Equal(t, http.StatusFound, resp.StatusCode) + io.ReadAll(resp.Body) + resp.Body.Close() + } + + // Log in to the first route. + assertResponseOK(upstreamA.Get(routeA, withSharedClient, upstreams.AuthenticateAs("test@example.com"))) + + // The client should also be authenticated for routes B and C without any + // additional login redirects. + sharedClient.CheckRedirect = func(_ *http.Request, _ []*http.Request) error { + return http.ErrUseLastResponse + } + assertResponseOK(upstreamB.Get(routeB, withSharedClient)) + assertResponseOK(upstreamC.Get(routeC, withSharedClient)) + + // The client should not be authenticated for route D. + assertRedirect(upstreamD.Get(routeD, withSharedClient)) +} diff --git a/internal/authenticateflow/stateful.go b/internal/authenticateflow/stateful.go index cf7c4bb2c..17bc37a7c 100644 --- a/internal/authenticateflow/stateful.go +++ b/internal/authenticateflow/stateful.go @@ -7,6 +7,7 @@ import ( "fmt" "net/http" "net/url" + "strings" "time" "go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc" @@ -163,7 +164,9 @@ func (s *Stateful) SignIn( // base64 our encrypted payload for URL-friendlyness encodedJWT := base64.URLEncoding.EncodeToString(encryptedJWT) - callbackURL, err := urlutil.GetCallbackURL(r, encodedJWT) + additionalHosts := strings.Split(r.FormValue(urlutil.QueryAdditionalHosts), ",") + + callbackURL, err := urlutil.GetCallbackURL(r, encodedJWT, additionalHosts) if err != nil { return httputil.NewError(http.StatusBadRequest, err) } @@ -324,7 +327,11 @@ func (s *Stateful) LogAuthenticateEvent(*http.Request) {} // AuthenticateSignInURL returns a URL to redirect the user to the authenticate // domain. func (s *Stateful) AuthenticateSignInURL( - ctx context.Context, queryParams url.Values, redirectURL *url.URL, idpID string, + ctx context.Context, + queryParams url.Values, + redirectURL *url.URL, + idpID string, + additionalHosts []string, ) (string, error) { signinURL := s.authenticateURL.ResolveReference(&url.URL{ Path: "/.pomerium/sign_in", @@ -335,6 +342,9 @@ func (s *Stateful) AuthenticateSignInURL( } queryParams.Set(urlutil.QueryRedirectURI, redirectURL.String()) queryParams.Set(urlutil.QueryIdentityProviderID, idpID) + if len(additionalHosts) > 0 { + queryParams.Set(urlutil.QueryAdditionalHosts, strings.Join(additionalHosts, ",")) + } otel.GetTextMapPropagator().Inject(ctx, trace.PomeriumURLQueryCarrier(queryParams)) signinURL.RawQuery = queryParams.Encode() redirectTo := urlutil.NewSignedURL(s.sharedKey, signinURL).String() @@ -387,6 +397,23 @@ func (s *Stateful) Callback(w http.ResponseWriter, r *http.Request) error { redirectURL.RawQuery = q.Encode() } + // Redirect chaining for multi-domain login. + additionalHosts := r.URL.Query().Get(urlutil.QueryAdditionalHosts) + if additionalHosts != "" { + nextHops := strings.Split(additionalHosts, ",") + log.Ctx(r.Context()).Debug().Strs("next-hops", nextHops).Msg("multi-domain login callback") + + callbackURL, err := urlutil.GetCallbackURL(r, encryptedSession, nextHops[1:]) + if err != nil { + return httputil.NewError(http.StatusInternalServerError, + fmt.Errorf("proxy: couldn't get next hop callback URL: %w", err)) + } + callbackURL.Host = nextHops[0] + signedCallbackURL := urlutil.NewSignedURL(s.sharedKey, callbackURL) + httputil.Redirect(w, r, signedCallbackURL.String(), http.StatusFound) + return nil + } + // redirect httputil.Redirect(w, r, redirectURL.String(), http.StatusFound) return nil diff --git a/internal/authenticateflow/stateful_test.go b/internal/authenticateflow/stateful_test.go index 06d14fdd9..ddfff9fbd 100644 --- a/internal/authenticateflow/stateful_test.go +++ b/internal/authenticateflow/stateful_test.go @@ -128,7 +128,7 @@ func TestStatefulAuthenticateSignInURL(t *testing.T) { t.Run("NilQueryParams", func(t *testing.T) { redirectURL := &url.URL{Scheme: "https", Host: "example.com"} - u, err := flow.AuthenticateSignInURL(context.Background(), nil, redirectURL, "fake-idp-id") + u, err := flow.AuthenticateSignInURL(context.Background(), nil, redirectURL, "fake-idp-id", nil) assert.NoError(t, err) parsed, _ := url.Parse(u) assert.NoError(t, urlutil.NewSignedURL(key, parsed).Validate()) @@ -143,7 +143,7 @@ func TestStatefulAuthenticateSignInURL(t *testing.T) { redirectURL := &url.URL{Scheme: "https", Host: "example.com"} q := url.Values{} q.Set("foo", "bar") - u, err := flow.AuthenticateSignInURL(context.Background(), q, redirectURL, "fake-idp-id") + u, err := flow.AuthenticateSignInURL(context.Background(), q, redirectURL, "fake-idp-id", nil) assert.NoError(t, err) parsed, _ := url.Parse(u) assert.NoError(t, urlutil.NewSignedURL(key, parsed).Validate()) @@ -155,6 +155,21 @@ func TestStatefulAuthenticateSignInURL(t *testing.T) { assert.Equal(t, "fake-idp-id", q.Get("pomerium_idp_id")) assert.Equal(t, "bar", q.Get("foo")) }) + t.Run("AdditionalHosts", func(t *testing.T) { + redirectURL := &url.URL{Scheme: "https", Host: "example.com"} + additionalHosts := []string{"foo.example.com", "bar.example.com:1234"} + u, err := flow.AuthenticateSignInURL(context.Background(), nil, redirectURL, "fake-idp-id", additionalHosts) + assert.NoError(t, err) + parsed, _ := url.Parse(u) + assert.NoError(t, urlutil.NewSignedURL(key, parsed).Validate()) + assert.Equal(t, "https", parsed.Scheme) + assert.Equal(t, "authenticate.example.com", parsed.Host) + assert.Equal(t, "/.pomerium/sign_in", parsed.Path) + q := parsed.Query() + assert.Equal(t, "https://example.com", parsed.Query().Get("pomerium_redirect_uri")) + assert.Equal(t, "fake-idp-id", q.Get("pomerium_idp_id")) + assert.Equal(t, "foo.example.com,bar.example.com:1234", q.Get("pomerium_additional_hosts")) + }) } func TestStatefulGetIdentityProviderIDForURLValues(t *testing.T) { @@ -277,6 +292,7 @@ func TestStatefulCallback(t *testing.T) { } location, _ := url.Parse(w.Result().Header.Get("Location")) assert.Equal(t, "example.com", location.Host) + assert.Equal(t, "/", location.Path) assert.Equal(t, "ok", location.Query().Get("pomerium_callback_uri")) } else { if err == nil || !strings.Contains(err.Error(), tt.wantErrorMsg) { @@ -287,6 +303,60 @@ func TestStatefulCallback(t *testing.T) { } } +func TestStatefulCallback_AdditionalHosts(t *testing.T) { + opts := config.NewDefaultOptions() + opts.SharedKey = "80ldlrU2d7w+wVpKNfevk6fmb8otEx6CqOfshj2LwhQ=" + sharedKey, _ := opts.GetSharedKey() + + flow, err := NewStateful( + context.Background(), + trace.NewNoopTracerProvider(), + &config.Config{Options: opts}, + &mstore.Store{Session: &sessions.State{}}, + ) + require.NoError(t, err) + + redirectURI := "https://route.example.com/" + callbackURI := &url.URL{ + Scheme: "https", + Host: "route.example.com", + Path: "/.pomerium/callback", + RawQuery: url.Values{ + urlutil.QuerySessionEncrypted: []string{goodEncryptionString}, + urlutil.QueryRedirectURI: []string{redirectURI}, + urlutil.QueryAdditionalHosts: []string{"foo.example.com,bar.example.com"}, + }.Encode(), + } + signedCallbackURI := urlutil.NewSignedURL(sharedKey, callbackURI) + + doCallback := func(uri string) *http.Response { + t.Helper() + r := httptest.NewRequest(http.MethodGet, uri, nil) + r.Host = r.URL.Host + + w := httptest.NewRecorder() + err = flow.Callback(w, r) + require.NoError(t, err) + return w.Result() + } + + // Callback() should serve redirects to the additional hosts before the final redirect URI. + res := doCallback(signedCallbackURI.String()) + location, _ := url.Parse(res.Header.Get("Location")) + assert.Equal(t, "foo.example.com", location.Host) + assert.Equal(t, "/.pomerium/callback/", location.Path) + + res = doCallback(location.String()) + location, _ = url.Parse(res.Header.Get("Location")) + assert.Equal(t, "bar.example.com", location.Host) + assert.Equal(t, "/.pomerium/callback/", location.Path) + + res = doCallback(location.String()) + location, _ = url.Parse(res.Header.Get("Location")) + assert.Equal(t, "route.example.com", location.Host) + assert.Equal(t, "/", location.Path) +} + func TestStatefulRevokeSession(t *testing.T) { opts := config.NewDefaultOptions() flow, err := NewStateful(context.Background(), trace.NewNoopTracerProvider(), &config.Config{Options: opts}, nil) diff --git a/internal/authenticateflow/stateless.go b/internal/authenticateflow/stateless.go index 0c2c189d9..6060581d3 100644 --- a/internal/authenticateflow/stateless.go +++ b/internal/authenticateflow/stateless.go @@ -355,7 +355,11 @@ func getUserClaim(profile *identitypb.Profile, field string) *string { // AuthenticateSignInURL returns a URL to redirect the user to the authenticate // domain. func (s *Stateless) AuthenticateSignInURL( - ctx context.Context, queryParams url.Values, redirectURL *url.URL, idpID string, + ctx context.Context, + queryParams url.Values, + redirectURL *url.URL, + idpID string, + _ []string, ) (string, error) { authenticateHPKEPublicKey, err := s.authenticateKeyFetcher.FetchPublicKey(ctx) if err != nil { diff --git a/internal/urlutil/proxy.go b/internal/urlutil/proxy.go index 7386b8e5d..0ea96ab6c 100644 --- a/internal/urlutil/proxy.go +++ b/internal/urlutil/proxy.go @@ -4,19 +4,15 @@ import ( "errors" "net/http" "net/url" + "strings" ) // ErrMissingRedirectURI indicates the pomerium_redirect_uri was missing from the query string. var ErrMissingRedirectURI = errors.New("missing " + QueryRedirectURI) // GetCallbackURL gets the proxy's callback URL from a request and a base64url encoded + encrypted session state JWT. -func GetCallbackURL(r *http.Request, encodedSessionJWT string) (*url.URL, error) { - return GetCallbackURLForRedirectURI(r, encodedSessionJWT, r.FormValue(QueryRedirectURI)) -} - -// GetCallbackURLForRedirectURI gets the proxy's callback URL from a request and a base64url encoded + encrypted session -// state JWT. -func GetCallbackURLForRedirectURI(r *http.Request, encodedSessionJWT, rawRedirectURI string) (*url.URL, error) { +func GetCallbackURL(r *http.Request, encodedSessionJWT string, additionalHosts []string) (*url.URL, error) { + rawRedirectURI := r.FormValue(QueryRedirectURI) if rawRedirectURI == "" { return nil, ErrMissingRedirectURI } @@ -55,6 +51,10 @@ func GetCallbackURLForRedirectURI(r *http.Request, encodedSessionJWT, rawRedirec callbackParams.Set(QueryTracestate, tracestate) } + if len(additionalHosts) > 0 { + callbackParams.Set(QueryAdditionalHosts, strings.Join(additionalHosts, ",")) + } + // add our encoded and encrypted route-session JWT to a query param callbackParams.Set(QuerySessionEncrypted, encodedSessionJWT) callbackParams.Set(QueryRedirectURI, redirectURI.String()) diff --git a/internal/urlutil/query_params.go b/internal/urlutil/query_params.go index 3ddee3cc6..746c888f8 100644 --- a/internal/urlutil/query_params.go +++ b/internal/urlutil/query_params.go @@ -4,6 +4,7 @@ package urlutil // services over HTTP calls and redirects. They are typically used in // conjunction with a HMAC to ensure authenticity. const ( + QueryAdditionalHosts = "pomerium_additional_hosts" QueryCallbackURI = "pomerium_callback_uri" QueryDeviceCredentialID = "pomerium_device_credential_id" QueryDeviceType = "pomerium_device_type" diff --git a/proxy/handlers.go b/proxy/handlers.go index 2013df2a4..de51ae37e 100644 --- a/proxy/handlers.go +++ b/proxy/handlers.go @@ -149,7 +149,7 @@ func (p *Proxy) ProgrammaticLogin(w http.ResponseWriter, r *http.Request) error q.Set(urlutil.QueryIsProgrammatic, "true") rawURL, err := state.authenticateFlow.AuthenticateSignInURL( - r.Context(), q, redirectURI, idp.GetId()) + r.Context(), q, redirectURI, idp.GetId(), nil) if err != nil { return httputil.NewError(http.StatusInternalServerError, err) } diff --git a/proxy/state.go b/proxy/state.go index 110459293..6386149cc 100644 --- a/proxy/state.go +++ b/proxy/state.go @@ -19,7 +19,7 @@ import ( var outboundGRPCConnection = new(grpc.CachedOutboundGRPClientConn) type authenticateFlow interface { - AuthenticateSignInURL(ctx context.Context, queryParams url.Values, redirectURL *url.URL, idpID string) (string, error) + AuthenticateSignInURL(ctx context.Context, queryParams url.Values, redirectURL *url.URL, idpID string, additionalHosts []string) (string, error) Callback(w http.ResponseWriter, r *http.Request) error }