mirror of
https://github.com/pomerium/pomerium.git
synced 2025-05-11 16:17:39 +02:00
multi-domain login redirects (#5564)
Add a new 'depends_on' route configuration option taking a list of additional hosts to redirect through on login. Update the authorize service and proxy service to support a chain of /.pomerium/callback redirects. Add an integration test for this feature.
This commit is contained in:
parent
c47055bece
commit
c848c225e8
12 changed files with 227 additions and 16 deletions
|
@ -235,8 +235,12 @@ func (a *Authorize) requireLoginResponse(
|
||||||
signInURLQuery = url.Values{}
|
signInURLQuery = url.Values{}
|
||||||
signInURLQuery.Add("pomerium_traceparent", id)
|
signInURLQuery.Add("pomerium_traceparent", id)
|
||||||
}
|
}
|
||||||
|
var additionalHosts []string
|
||||||
|
if request.Policy != nil {
|
||||||
|
additionalHosts = request.Policy.DependsOn
|
||||||
|
}
|
||||||
redirectTo, err := state.authenticateFlow.AuthenticateSignInURL(
|
redirectTo, err := state.authenticateFlow.AuthenticateSignInURL(
|
||||||
ctx, signInURLQuery, &checkRequestURL, idp.GetId())
|
ctx, signInURLQuery, &checkRequestURL, idp.GetId(), additionalHosts)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
|
@ -20,7 +20,7 @@ import (
|
||||||
var outboundGRPCConnection = new(grpc.CachedOutboundGRPClientConn)
|
var outboundGRPCConnection = new(grpc.CachedOutboundGRPClientConn)
|
||||||
|
|
||||||
type authenticateFlow interface {
|
type authenticateFlow interface {
|
||||||
AuthenticateSignInURL(ctx context.Context, queryParams url.Values, redirectURL *url.URL, idpID string) (string, error)
|
AuthenticateSignInURL(ctx context.Context, queryParams url.Values, redirectURL *url.URL, idpID string, additionalLoginHosts []string) (string, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type authorizeState struct {
|
type authorizeState struct {
|
||||||
|
|
|
@ -200,6 +200,8 @@ type Policy struct {
|
||||||
ShowErrorDetails bool `mapstructure:"show_error_details" yaml:"show_error_details" json:"show_error_details"`
|
ShowErrorDetails bool `mapstructure:"show_error_details" yaml:"show_error_details" json:"show_error_details"`
|
||||||
|
|
||||||
Policy *PPLPolicy `mapstructure:"policy" yaml:"policy,omitempty" json:"policy,omitempty"`
|
Policy *PPLPolicy `mapstructure:"policy" yaml:"policy,omitempty" json:"policy,omitempty"`
|
||||||
|
|
||||||
|
DependsOn []string `mapstructure:"depends_on" yaml:"depends_on,omitempty" json:"depends_on,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// RewriteHeader is a policy configuration option to rewrite an HTTP header.
|
// RewriteHeader is a policy configuration option to rewrite an HTTP header.
|
||||||
|
@ -690,6 +692,10 @@ func (p *Policy) Validate() error {
|
||||||
return fmt.Errorf("config: unsupported jwt_issuer_format value %q", p.JWTIssuerFormat)
|
return fmt.Errorf("config: unsupported jwt_issuer_format value %q", p.JWTIssuerFormat)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if len(p.DependsOn) > 5 {
|
||||||
|
return fmt.Errorf("config: depends_on is limited to 5 additional redirect hosts, got %v", p.DependsOn)
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -56,6 +56,7 @@ func Test_PolicyValidate(t *testing.T) {
|
||||||
{"TCP To URLs", Policy{From: "tcp+https://httpbin.corp.example:4000", To: mustParseWeightedURLs(t, "tcp://one.example.com:5000", "tcp://two.example.com:5000")}, false},
|
{"TCP To URLs", Policy{From: "tcp+https://httpbin.corp.example:4000", To: mustParseWeightedURLs(t, "tcp://one.example.com:5000", "tcp://two.example.com:5000")}, false},
|
||||||
{"mix of TCP and non-TCP To URLs", Policy{From: "tcp+https://httpbin.corp.example:4000", To: mustParseWeightedURLs(t, "https://example.com", "tcp://example.com:5000")}, true},
|
{"mix of TCP and non-TCP To URLs", Policy{From: "tcp+https://httpbin.corp.example:4000", To: mustParseWeightedURLs(t, "https://example.com", "tcp://example.com:5000")}, true},
|
||||||
{"UDP To URLs", Policy{From: "udp+https://httpbin.corp.example:4000", To: mustParseWeightedURLs(t, "udp://one.example.com:5000", "udp://two.example.com:5000")}, false},
|
{"UDP To URLs", Policy{From: "udp+https://httpbin.corp.example:4000", To: mustParseWeightedURLs(t, "udp://one.example.com:5000", "udp://two.example.com:5000")}, false},
|
||||||
|
{"too many depends_on hosts", Policy{From: "https://httpbin.corp.example", To: mustParseWeightedURLs(t, "https://httpbin.corp.notatld"), DependsOn: []string{"a", "b", "c", "d", "e", "f"}}, true},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
|
|
98
internal/authenticateflow/authenticateflow_int_test.go
Normal file
98
internal/authenticateflow/authenticateflow_int_test.go
Normal file
|
@ -0,0 +1,98 @@
|
||||||
|
package authenticateflow_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/http/cookiejar"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"github.com/pomerium/pomerium/config"
|
||||||
|
"github.com/pomerium/pomerium/internal/testenv"
|
||||||
|
"github.com/pomerium/pomerium/internal/testenv/scenarios"
|
||||||
|
"github.com/pomerium/pomerium/internal/testenv/snippets"
|
||||||
|
"github.com/pomerium/pomerium/internal/testenv/upstreams"
|
||||||
|
"github.com/pomerium/pomerium/internal/testenv/values"
|
||||||
|
)
|
||||||
|
|
||||||
|
func newHTTPUpstream(
|
||||||
|
env testenv.Environment, subdomain string,
|
||||||
|
) (upstreams.HTTPUpstream, testenv.Route) {
|
||||||
|
up := upstreams.HTTP(nil)
|
||||||
|
up.Handle("/", func(w http.ResponseWriter, _ *http.Request) { fmt.Fprintln(w, "hello world") })
|
||||||
|
route := up.Route().
|
||||||
|
From(env.SubdomainURL(subdomain)).
|
||||||
|
To(values.Bind(up.Addr(), func(addr string) string {
|
||||||
|
// override the target protocol to use http://
|
||||||
|
return fmt.Sprintf("http://%s", addr)
|
||||||
|
})).
|
||||||
|
Policy(func(p *config.Policy) { p.AllowAnyAuthenticatedUser = true })
|
||||||
|
env.AddUpstream(up)
|
||||||
|
return up, route
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMultiDomainLogin(t *testing.T) {
|
||||||
|
env := testenv.New(t)
|
||||||
|
|
||||||
|
env.Add(scenarios.NewIDP([]*scenarios.User{{Email: "test@example.com"}}))
|
||||||
|
|
||||||
|
// Create three routes to be linked via multi-domain login...
|
||||||
|
upstreamA, routeA := newHTTPUpstream(env, "a")
|
||||||
|
upstreamB, routeB := newHTTPUpstream(env, "b")
|
||||||
|
upstreamC, routeC := newHTTPUpstream(env, "c")
|
||||||
|
// ...and one route that will not be involved.
|
||||||
|
upstreamD, routeD := newHTTPUpstream(env, "d")
|
||||||
|
|
||||||
|
// Configure route A to redirect through routes B and C on login.
|
||||||
|
routeA.Policy(func(p *config.Policy) {
|
||||||
|
p.DependsOn = []string{
|
||||||
|
strings.TrimPrefix(routeB.URL().Value(), "https://"),
|
||||||
|
strings.TrimPrefix(routeC.URL().Value(), "https://"),
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
env.Start()
|
||||||
|
snippets.WaitStartupComplete(env)
|
||||||
|
|
||||||
|
// By default the testenv code will use a separate http.Client for each
|
||||||
|
// separate route. Instead we specifically want to test the cross-route
|
||||||
|
// behavior for a single client.
|
||||||
|
cj, err := cookiejar.New(nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
sharedClient := http.Client{Jar: cj}
|
||||||
|
withSharedClient := upstreams.ClientHook(
|
||||||
|
func(_ *http.Client) *http.Client { return &sharedClient })
|
||||||
|
|
||||||
|
assertResponseOK := func(resp *http.Response, err error) {
|
||||||
|
t.Helper()
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||||
|
io.ReadAll(resp.Body)
|
||||||
|
resp.Body.Close()
|
||||||
|
}
|
||||||
|
assertRedirect := func(resp *http.Response, err error) {
|
||||||
|
t.Helper()
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, http.StatusFound, resp.StatusCode)
|
||||||
|
io.ReadAll(resp.Body)
|
||||||
|
resp.Body.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Log in to the first route.
|
||||||
|
assertResponseOK(upstreamA.Get(routeA, withSharedClient, upstreams.AuthenticateAs("test@example.com")))
|
||||||
|
|
||||||
|
// The client should also be authenticated for routes B and C without any
|
||||||
|
// additional login redirects.
|
||||||
|
sharedClient.CheckRedirect = func(_ *http.Request, _ []*http.Request) error {
|
||||||
|
return http.ErrUseLastResponse
|
||||||
|
}
|
||||||
|
assertResponseOK(upstreamB.Get(routeB, withSharedClient))
|
||||||
|
assertResponseOK(upstreamC.Get(routeC, withSharedClient))
|
||||||
|
|
||||||
|
// The client should not be authenticated for route D.
|
||||||
|
assertRedirect(upstreamD.Get(routeD, withSharedClient))
|
||||||
|
}
|
|
@ -7,6 +7,7 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc"
|
"go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc"
|
||||||
|
@ -163,7 +164,9 @@ func (s *Stateful) SignIn(
|
||||||
// base64 our encrypted payload for URL-friendlyness
|
// base64 our encrypted payload for URL-friendlyness
|
||||||
encodedJWT := base64.URLEncoding.EncodeToString(encryptedJWT)
|
encodedJWT := base64.URLEncoding.EncodeToString(encryptedJWT)
|
||||||
|
|
||||||
callbackURL, err := urlutil.GetCallbackURL(r, encodedJWT)
|
additionalHosts := strings.Split(r.FormValue(urlutil.QueryAdditionalHosts), ",")
|
||||||
|
|
||||||
|
callbackURL, err := urlutil.GetCallbackURL(r, encodedJWT, additionalHosts)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return httputil.NewError(http.StatusBadRequest, err)
|
return httputil.NewError(http.StatusBadRequest, err)
|
||||||
}
|
}
|
||||||
|
@ -324,7 +327,11 @@ func (s *Stateful) LogAuthenticateEvent(*http.Request) {}
|
||||||
// AuthenticateSignInURL returns a URL to redirect the user to the authenticate
|
// AuthenticateSignInURL returns a URL to redirect the user to the authenticate
|
||||||
// domain.
|
// domain.
|
||||||
func (s *Stateful) AuthenticateSignInURL(
|
func (s *Stateful) AuthenticateSignInURL(
|
||||||
ctx context.Context, queryParams url.Values, redirectURL *url.URL, idpID string,
|
ctx context.Context,
|
||||||
|
queryParams url.Values,
|
||||||
|
redirectURL *url.URL,
|
||||||
|
idpID string,
|
||||||
|
additionalHosts []string,
|
||||||
) (string, error) {
|
) (string, error) {
|
||||||
signinURL := s.authenticateURL.ResolveReference(&url.URL{
|
signinURL := s.authenticateURL.ResolveReference(&url.URL{
|
||||||
Path: "/.pomerium/sign_in",
|
Path: "/.pomerium/sign_in",
|
||||||
|
@ -335,6 +342,9 @@ func (s *Stateful) AuthenticateSignInURL(
|
||||||
}
|
}
|
||||||
queryParams.Set(urlutil.QueryRedirectURI, redirectURL.String())
|
queryParams.Set(urlutil.QueryRedirectURI, redirectURL.String())
|
||||||
queryParams.Set(urlutil.QueryIdentityProviderID, idpID)
|
queryParams.Set(urlutil.QueryIdentityProviderID, idpID)
|
||||||
|
if len(additionalHosts) > 0 {
|
||||||
|
queryParams.Set(urlutil.QueryAdditionalHosts, strings.Join(additionalHosts, ","))
|
||||||
|
}
|
||||||
otel.GetTextMapPropagator().Inject(ctx, trace.PomeriumURLQueryCarrier(queryParams))
|
otel.GetTextMapPropagator().Inject(ctx, trace.PomeriumURLQueryCarrier(queryParams))
|
||||||
signinURL.RawQuery = queryParams.Encode()
|
signinURL.RawQuery = queryParams.Encode()
|
||||||
redirectTo := urlutil.NewSignedURL(s.sharedKey, signinURL).String()
|
redirectTo := urlutil.NewSignedURL(s.sharedKey, signinURL).String()
|
||||||
|
@ -387,6 +397,23 @@ func (s *Stateful) Callback(w http.ResponseWriter, r *http.Request) error {
|
||||||
redirectURL.RawQuery = q.Encode()
|
redirectURL.RawQuery = q.Encode()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Redirect chaining for multi-domain login.
|
||||||
|
additionalHosts := r.URL.Query().Get(urlutil.QueryAdditionalHosts)
|
||||||
|
if additionalHosts != "" {
|
||||||
|
nextHops := strings.Split(additionalHosts, ",")
|
||||||
|
log.Ctx(r.Context()).Debug().Strs("next-hops", nextHops).Msg("multi-domain login callback")
|
||||||
|
|
||||||
|
callbackURL, err := urlutil.GetCallbackURL(r, encryptedSession, nextHops[1:])
|
||||||
|
if err != nil {
|
||||||
|
return httputil.NewError(http.StatusInternalServerError,
|
||||||
|
fmt.Errorf("proxy: couldn't get next hop callback URL: %w", err))
|
||||||
|
}
|
||||||
|
callbackURL.Host = nextHops[0]
|
||||||
|
signedCallbackURL := urlutil.NewSignedURL(s.sharedKey, callbackURL)
|
||||||
|
httputil.Redirect(w, r, signedCallbackURL.String(), http.StatusFound)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// redirect
|
// redirect
|
||||||
httputil.Redirect(w, r, redirectURL.String(), http.StatusFound)
|
httputil.Redirect(w, r, redirectURL.String(), http.StatusFound)
|
||||||
return nil
|
return nil
|
||||||
|
|
|
@ -128,7 +128,7 @@ func TestStatefulAuthenticateSignInURL(t *testing.T) {
|
||||||
|
|
||||||
t.Run("NilQueryParams", func(t *testing.T) {
|
t.Run("NilQueryParams", func(t *testing.T) {
|
||||||
redirectURL := &url.URL{Scheme: "https", Host: "example.com"}
|
redirectURL := &url.URL{Scheme: "https", Host: "example.com"}
|
||||||
u, err := flow.AuthenticateSignInURL(context.Background(), nil, redirectURL, "fake-idp-id")
|
u, err := flow.AuthenticateSignInURL(context.Background(), nil, redirectURL, "fake-idp-id", nil)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
parsed, _ := url.Parse(u)
|
parsed, _ := url.Parse(u)
|
||||||
assert.NoError(t, urlutil.NewSignedURL(key, parsed).Validate())
|
assert.NoError(t, urlutil.NewSignedURL(key, parsed).Validate())
|
||||||
|
@ -143,7 +143,7 @@ func TestStatefulAuthenticateSignInURL(t *testing.T) {
|
||||||
redirectURL := &url.URL{Scheme: "https", Host: "example.com"}
|
redirectURL := &url.URL{Scheme: "https", Host: "example.com"}
|
||||||
q := url.Values{}
|
q := url.Values{}
|
||||||
q.Set("foo", "bar")
|
q.Set("foo", "bar")
|
||||||
u, err := flow.AuthenticateSignInURL(context.Background(), q, redirectURL, "fake-idp-id")
|
u, err := flow.AuthenticateSignInURL(context.Background(), q, redirectURL, "fake-idp-id", nil)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
parsed, _ := url.Parse(u)
|
parsed, _ := url.Parse(u)
|
||||||
assert.NoError(t, urlutil.NewSignedURL(key, parsed).Validate())
|
assert.NoError(t, urlutil.NewSignedURL(key, parsed).Validate())
|
||||||
|
@ -155,6 +155,21 @@ func TestStatefulAuthenticateSignInURL(t *testing.T) {
|
||||||
assert.Equal(t, "fake-idp-id", q.Get("pomerium_idp_id"))
|
assert.Equal(t, "fake-idp-id", q.Get("pomerium_idp_id"))
|
||||||
assert.Equal(t, "bar", q.Get("foo"))
|
assert.Equal(t, "bar", q.Get("foo"))
|
||||||
})
|
})
|
||||||
|
t.Run("AdditionalHosts", func(t *testing.T) {
|
||||||
|
redirectURL := &url.URL{Scheme: "https", Host: "example.com"}
|
||||||
|
additionalHosts := []string{"foo.example.com", "bar.example.com:1234"}
|
||||||
|
u, err := flow.AuthenticateSignInURL(context.Background(), nil, redirectURL, "fake-idp-id", additionalHosts)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
parsed, _ := url.Parse(u)
|
||||||
|
assert.NoError(t, urlutil.NewSignedURL(key, parsed).Validate())
|
||||||
|
assert.Equal(t, "https", parsed.Scheme)
|
||||||
|
assert.Equal(t, "authenticate.example.com", parsed.Host)
|
||||||
|
assert.Equal(t, "/.pomerium/sign_in", parsed.Path)
|
||||||
|
q := parsed.Query()
|
||||||
|
assert.Equal(t, "https://example.com", parsed.Query().Get("pomerium_redirect_uri"))
|
||||||
|
assert.Equal(t, "fake-idp-id", q.Get("pomerium_idp_id"))
|
||||||
|
assert.Equal(t, "foo.example.com,bar.example.com:1234", q.Get("pomerium_additional_hosts"))
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestStatefulGetIdentityProviderIDForURLValues(t *testing.T) {
|
func TestStatefulGetIdentityProviderIDForURLValues(t *testing.T) {
|
||||||
|
@ -277,6 +292,7 @@ func TestStatefulCallback(t *testing.T) {
|
||||||
}
|
}
|
||||||
location, _ := url.Parse(w.Result().Header.Get("Location"))
|
location, _ := url.Parse(w.Result().Header.Get("Location"))
|
||||||
assert.Equal(t, "example.com", location.Host)
|
assert.Equal(t, "example.com", location.Host)
|
||||||
|
assert.Equal(t, "/", location.Path)
|
||||||
assert.Equal(t, "ok", location.Query().Get("pomerium_callback_uri"))
|
assert.Equal(t, "ok", location.Query().Get("pomerium_callback_uri"))
|
||||||
} else {
|
} else {
|
||||||
if err == nil || !strings.Contains(err.Error(), tt.wantErrorMsg) {
|
if err == nil || !strings.Contains(err.Error(), tt.wantErrorMsg) {
|
||||||
|
@ -287,6 +303,60 @@ func TestStatefulCallback(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestStatefulCallback_AdditionalHosts(t *testing.T) {
|
||||||
|
opts := config.NewDefaultOptions()
|
||||||
|
opts.SharedKey = "80ldlrU2d7w+wVpKNfevk6fmb8otEx6CqOfshj2LwhQ="
|
||||||
|
sharedKey, _ := opts.GetSharedKey()
|
||||||
|
|
||||||
|
flow, err := NewStateful(
|
||||||
|
context.Background(),
|
||||||
|
trace.NewNoopTracerProvider(),
|
||||||
|
&config.Config{Options: opts},
|
||||||
|
&mstore.Store{Session: &sessions.State{}},
|
||||||
|
)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
redirectURI := "https://route.example.com/"
|
||||||
|
callbackURI := &url.URL{
|
||||||
|
Scheme: "https",
|
||||||
|
Host: "route.example.com",
|
||||||
|
Path: "/.pomerium/callback",
|
||||||
|
RawQuery: url.Values{
|
||||||
|
urlutil.QuerySessionEncrypted: []string{goodEncryptionString},
|
||||||
|
urlutil.QueryRedirectURI: []string{redirectURI},
|
||||||
|
urlutil.QueryAdditionalHosts: []string{"foo.example.com,bar.example.com"},
|
||||||
|
}.Encode(),
|
||||||
|
}
|
||||||
|
signedCallbackURI := urlutil.NewSignedURL(sharedKey, callbackURI)
|
||||||
|
|
||||||
|
doCallback := func(uri string) *http.Response {
|
||||||
|
t.Helper()
|
||||||
|
r := httptest.NewRequest(http.MethodGet, uri, nil)
|
||||||
|
r.Host = r.URL.Host
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
err = flow.Callback(w, r)
|
||||||
|
require.NoError(t, err)
|
||||||
|
return w.Result()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Callback() should serve redirects to the additional hosts before the final redirect URI.
|
||||||
|
res := doCallback(signedCallbackURI.String())
|
||||||
|
location, _ := url.Parse(res.Header.Get("Location"))
|
||||||
|
assert.Equal(t, "foo.example.com", location.Host)
|
||||||
|
assert.Equal(t, "/.pomerium/callback/", location.Path)
|
||||||
|
|
||||||
|
res = doCallback(location.String())
|
||||||
|
location, _ = url.Parse(res.Header.Get("Location"))
|
||||||
|
assert.Equal(t, "bar.example.com", location.Host)
|
||||||
|
assert.Equal(t, "/.pomerium/callback/", location.Path)
|
||||||
|
|
||||||
|
res = doCallback(location.String())
|
||||||
|
location, _ = url.Parse(res.Header.Get("Location"))
|
||||||
|
assert.Equal(t, "route.example.com", location.Host)
|
||||||
|
assert.Equal(t, "/", location.Path)
|
||||||
|
}
|
||||||
|
|
||||||
func TestStatefulRevokeSession(t *testing.T) {
|
func TestStatefulRevokeSession(t *testing.T) {
|
||||||
opts := config.NewDefaultOptions()
|
opts := config.NewDefaultOptions()
|
||||||
flow, err := NewStateful(context.Background(), trace.NewNoopTracerProvider(), &config.Config{Options: opts}, nil)
|
flow, err := NewStateful(context.Background(), trace.NewNoopTracerProvider(), &config.Config{Options: opts}, nil)
|
||||||
|
|
|
@ -355,7 +355,11 @@ func getUserClaim(profile *identitypb.Profile, field string) *string {
|
||||||
// AuthenticateSignInURL returns a URL to redirect the user to the authenticate
|
// AuthenticateSignInURL returns a URL to redirect the user to the authenticate
|
||||||
// domain.
|
// domain.
|
||||||
func (s *Stateless) AuthenticateSignInURL(
|
func (s *Stateless) AuthenticateSignInURL(
|
||||||
ctx context.Context, queryParams url.Values, redirectURL *url.URL, idpID string,
|
ctx context.Context,
|
||||||
|
queryParams url.Values,
|
||||||
|
redirectURL *url.URL,
|
||||||
|
idpID string,
|
||||||
|
_ []string,
|
||||||
) (string, error) {
|
) (string, error) {
|
||||||
authenticateHPKEPublicKey, err := s.authenticateKeyFetcher.FetchPublicKey(ctx)
|
authenticateHPKEPublicKey, err := s.authenticateKeyFetcher.FetchPublicKey(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
@ -4,19 +4,15 @@ import (
|
||||||
"errors"
|
"errors"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
// ErrMissingRedirectURI indicates the pomerium_redirect_uri was missing from the query string.
|
// ErrMissingRedirectURI indicates the pomerium_redirect_uri was missing from the query string.
|
||||||
var ErrMissingRedirectURI = errors.New("missing " + QueryRedirectURI)
|
var ErrMissingRedirectURI = errors.New("missing " + QueryRedirectURI)
|
||||||
|
|
||||||
// GetCallbackURL gets the proxy's callback URL from a request and a base64url encoded + encrypted session state JWT.
|
// GetCallbackURL gets the proxy's callback URL from a request and a base64url encoded + encrypted session state JWT.
|
||||||
func GetCallbackURL(r *http.Request, encodedSessionJWT string) (*url.URL, error) {
|
func GetCallbackURL(r *http.Request, encodedSessionJWT string, additionalHosts []string) (*url.URL, error) {
|
||||||
return GetCallbackURLForRedirectURI(r, encodedSessionJWT, r.FormValue(QueryRedirectURI))
|
rawRedirectURI := r.FormValue(QueryRedirectURI)
|
||||||
}
|
|
||||||
|
|
||||||
// GetCallbackURLForRedirectURI gets the proxy's callback URL from a request and a base64url encoded + encrypted session
|
|
||||||
// state JWT.
|
|
||||||
func GetCallbackURLForRedirectURI(r *http.Request, encodedSessionJWT, rawRedirectURI string) (*url.URL, error) {
|
|
||||||
if rawRedirectURI == "" {
|
if rawRedirectURI == "" {
|
||||||
return nil, ErrMissingRedirectURI
|
return nil, ErrMissingRedirectURI
|
||||||
}
|
}
|
||||||
|
@ -55,6 +51,10 @@ func GetCallbackURLForRedirectURI(r *http.Request, encodedSessionJWT, rawRedirec
|
||||||
callbackParams.Set(QueryTracestate, tracestate)
|
callbackParams.Set(QueryTracestate, tracestate)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if len(additionalHosts) > 0 {
|
||||||
|
callbackParams.Set(QueryAdditionalHosts, strings.Join(additionalHosts, ","))
|
||||||
|
}
|
||||||
|
|
||||||
// add our encoded and encrypted route-session JWT to a query param
|
// add our encoded and encrypted route-session JWT to a query param
|
||||||
callbackParams.Set(QuerySessionEncrypted, encodedSessionJWT)
|
callbackParams.Set(QuerySessionEncrypted, encodedSessionJWT)
|
||||||
callbackParams.Set(QueryRedirectURI, redirectURI.String())
|
callbackParams.Set(QueryRedirectURI, redirectURI.String())
|
||||||
|
|
|
@ -4,6 +4,7 @@ package urlutil
|
||||||
// services over HTTP calls and redirects. They are typically used in
|
// services over HTTP calls and redirects. They are typically used in
|
||||||
// conjunction with a HMAC to ensure authenticity.
|
// conjunction with a HMAC to ensure authenticity.
|
||||||
const (
|
const (
|
||||||
|
QueryAdditionalHosts = "pomerium_additional_hosts"
|
||||||
QueryCallbackURI = "pomerium_callback_uri"
|
QueryCallbackURI = "pomerium_callback_uri"
|
||||||
QueryDeviceCredentialID = "pomerium_device_credential_id"
|
QueryDeviceCredentialID = "pomerium_device_credential_id"
|
||||||
QueryDeviceType = "pomerium_device_type"
|
QueryDeviceType = "pomerium_device_type"
|
||||||
|
|
|
@ -149,7 +149,7 @@ func (p *Proxy) ProgrammaticLogin(w http.ResponseWriter, r *http.Request) error
|
||||||
q.Set(urlutil.QueryIsProgrammatic, "true")
|
q.Set(urlutil.QueryIsProgrammatic, "true")
|
||||||
|
|
||||||
rawURL, err := state.authenticateFlow.AuthenticateSignInURL(
|
rawURL, err := state.authenticateFlow.AuthenticateSignInURL(
|
||||||
r.Context(), q, redirectURI, idp.GetId())
|
r.Context(), q, redirectURI, idp.GetId(), nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return httputil.NewError(http.StatusInternalServerError, err)
|
return httputil.NewError(http.StatusInternalServerError, err)
|
||||||
}
|
}
|
||||||
|
|
|
@ -19,7 +19,7 @@ import (
|
||||||
var outboundGRPCConnection = new(grpc.CachedOutboundGRPClientConn)
|
var outboundGRPCConnection = new(grpc.CachedOutboundGRPClientConn)
|
||||||
|
|
||||||
type authenticateFlow interface {
|
type authenticateFlow interface {
|
||||||
AuthenticateSignInURL(ctx context.Context, queryParams url.Values, redirectURL *url.URL, idpID string) (string, error)
|
AuthenticateSignInURL(ctx context.Context, queryParams url.Values, redirectURL *url.URL, idpID string, additionalHosts []string) (string, error)
|
||||||
Callback(w http.ResponseWriter, r *http.Request) error
|
Callback(w http.ResponseWriter, r *http.Request) error
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue