diff --git a/authenticate/handlers.go b/authenticate/handlers.go index 9c536a6cc..7e110b937 100644 --- a/authenticate/handlers.go +++ b/authenticate/handlers.go @@ -731,20 +731,26 @@ func (a *Authenticate) getWebauthnState(ctx context.Context) (*webauthn.State, e return nil, err } + internalAuthenticateURL, err := a.options.Load().GetInternalAuthenticateURL() + if err != nil { + return nil, err + } + pomeriumDomains, err := a.options.Load().GetAllRouteableHTTPDomains() if err != nil { return nil, err } return &webauthn.State{ - AuthenticateURL: authenticateURL, - SharedKey: state.sharedKey, - Client: state.dataBrokerClient, - PomeriumDomains: pomeriumDomains, - Session: s, - SessionState: ss, - SessionStore: state.sessionStore, - RelyingParty: state.webauthnRelyingParty, + AuthenticateURL: authenticateURL, + InternalAuthenticateURL: internalAuthenticateURL, + SharedKey: state.sharedKey, + Client: state.dataBrokerClient, + PomeriumDomains: pomeriumDomains, + Session: s, + SessionState: ss, + SessionStore: state.sessionStore, + RelyingParty: state.webauthnRelyingParty, }, nil } diff --git a/authenticate/handlers/webauthn/webauthn.go b/authenticate/handlers/webauthn/webauthn.go index 37cea8ae1..846c082bc 100644 --- a/authenticate/handlers/webauthn/webauthn.go +++ b/authenticate/handlers/webauthn/webauthn.go @@ -47,14 +47,15 @@ var ( // State is the state needed by the Handler to handle requests. type State struct { - AuthenticateURL *url.URL - Client databroker.DataBrokerServiceClient - PomeriumDomains []string - RelyingParty *webauthn.RelyingParty - Session *session.Session - SessionState *sessions.State - SessionStore sessions.SessionStore - SharedKey []byte + AuthenticateURL *url.URL + InternalAuthenticateURL *url.URL + Client databroker.DataBrokerServiceClient + PomeriumDomains []string + RelyingParty *webauthn.RelyingParty + Session *session.Session + SessionState *sessions.State + SessionStore sessions.SessionStore + SharedKey []byte } // A StateProvider provides state for the handler. @@ -122,7 +123,10 @@ func (h *Handler) handle(w http.ResponseWriter, r *http.Request) error { return err } - err = middleware.ValidateRequestURL(r, s.SharedKey) + err = middleware.ValidateRequestURL( + urlutil.GetExternalRequest(s.InternalAuthenticateURL, s.AuthenticateURL, r), + s.SharedKey, + ) if err != nil { return err } diff --git a/authenticate/middleware.go b/authenticate/middleware.go index 5f4d91464..86cc3cbdb 100644 --- a/authenticate/middleware.go +++ b/authenticate/middleware.go @@ -46,18 +46,5 @@ func (a *Authenticate) getExternalRequest(r *http.Request) *http.Request { return r } - // if we're not using a different internal URL there's nothing to do - if externalURL.String() == internalURL.String() { - return r - } - - // replace the internal host with the external host - er := r.Clone(r.Context()) - if er.URL.Host == internalURL.Host { - er.URL.Host = externalURL.Host - } - if er.Host == internalURL.Host { - er.Host = externalURL.Host - } - return er + return urlutil.GetExternalRequest(internalURL, externalURL, r) } diff --git a/internal/urlutil/url.go b/internal/urlutil/url.go index 0edecd794..9c8fa443b 100644 --- a/internal/urlutil/url.go +++ b/internal/urlutil/url.go @@ -141,3 +141,22 @@ func Join(elements ...string) string { } return builder.String() } + +// GetExternalRequest modifies a request so that it appears to be for an external URL instead of +// an internal URL. +func GetExternalRequest(internalURL, externalURL *url.URL, r *http.Request) *http.Request { + // if we're not using a different internal URL there's nothing to do + if externalURL.String() == internalURL.String() { + return r + } + + // replace the internal host with the external host + er := r.Clone(r.Context()) + if er.URL.Host == internalURL.Host { + er.URL.Host = externalURL.Host + } + if er.Host == internalURL.Host { + er.Host = externalURL.Host + } + return er +}