mirror of
https://github.com/pomerium/pomerium.git
synced 2025-06-03 11:22:45 +02:00
authenticate: support webauthn redirects to non-pomerium domains (#2936)
* authenticate: support webauthn redirects to non-pomerium domains * add test * remove dead code
This commit is contained in:
parent
6b26f58e4f
commit
95d6d97143
5 changed files with 191 additions and 93 deletions
|
@ -686,13 +686,19 @@ func (a *Authenticate) getWebauthnState(ctx context.Context) (*webauthn.State, e
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pomeriumDomains, err := a.options.Load().GetAllRouteableHTTPDomains()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
return &webauthn.State{
|
return &webauthn.State{
|
||||||
SharedKey: state.sharedKey,
|
SharedKey: state.sharedKey,
|
||||||
Client: state.dataBrokerClient,
|
Client: state.dataBrokerClient,
|
||||||
Session: s,
|
PomeriumDomains: pomeriumDomains,
|
||||||
SessionState: ss,
|
Session: s,
|
||||||
SessionStore: state.sessionStore,
|
SessionState: ss,
|
||||||
RelyingParty: state.webauthnRelyingParty,
|
SessionStore: state.sessionStore,
|
||||||
|
RelyingParty: state.webauthnRelyingParty,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -50,12 +50,13 @@ var (
|
||||||
|
|
||||||
// State is the state needed by the Handler to handle requests.
|
// State is the state needed by the Handler to handle requests.
|
||||||
type State struct {
|
type State struct {
|
||||||
SharedKey []byte
|
SharedKey []byte
|
||||||
Client databroker.DataBrokerServiceClient
|
Client databroker.DataBrokerServiceClient
|
||||||
Session *session.Session
|
PomeriumDomains []string
|
||||||
SessionState *sessions.State
|
Session *session.Session
|
||||||
SessionStore sessions.SessionStore
|
SessionState *sessions.State
|
||||||
RelyingParty *webauthn.RelyingParty
|
SessionStore sessions.SessionStore
|
||||||
|
RelyingParty *webauthn.RelyingParty
|
||||||
}
|
}
|
||||||
|
|
||||||
// A StateProvider provides state for the handler.
|
// A StateProvider provides state for the handler.
|
||||||
|
@ -392,6 +393,12 @@ func (h *Handler) handleView(w http.ResponseWriter, r *http.Request, state *Stat
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Handler) saveSessionAndRedirect(w http.ResponseWriter, r *http.Request, state *State, rawRedirectURI string) error {
|
func (h *Handler) saveSessionAndRedirect(w http.ResponseWriter, r *http.Request, state *State, rawRedirectURI string) error {
|
||||||
|
// if the redirect URL is for a URL we don't control, just do a plain redirect
|
||||||
|
if !isURLForPomerium(state.PomeriumDomains, rawRedirectURI) {
|
||||||
|
httputil.Redirect(w, r, rawRedirectURI, http.StatusFound)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// save the session to the databroker
|
// save the session to the databroker
|
||||||
res, err := session.Put(r.Context(), state.Client, state.Session)
|
res, err := session.Put(r.Context(), state.Client, state.Session)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -513,3 +520,18 @@ func getOrCreateDeviceEnrollment(
|
||||||
}
|
}
|
||||||
return deviceEnrollment, nil
|
return deviceEnrollment, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func isURLForPomerium(pomeriumDomains []string, rawURI string) bool {
|
||||||
|
uri, err := urlutil.ParseAndValidateURL(rawURI)
|
||||||
|
if err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, domain := range pomeriumDomains {
|
||||||
|
if urlutil.StripPort(domain) == urlutil.StripPort(uri.Host) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
|
@ -734,89 +734,15 @@ func getRouteableDomainsForTLSDomain(options *config.Options, addr string, tlsDo
|
||||||
}
|
}
|
||||||
|
|
||||||
func getAllRouteableDomains(options *config.Options, addr string) ([]string, error) {
|
func getAllRouteableDomains(options *config.Options, addr string) ([]string, error) {
|
||||||
forwardAuthURL, err := options.GetForwardAuthURL()
|
switch addr {
|
||||||
if err != nil {
|
case options.Addr:
|
||||||
return nil, err
|
return options.GetAllRouteableHTTPDomains()
|
||||||
|
case options.GetGRPCAddr():
|
||||||
|
return options.GetAllRouteableGRPCDomains()
|
||||||
}
|
}
|
||||||
|
|
||||||
lookup := map[string]struct{}{}
|
// no other domains supported
|
||||||
if config.IsAuthenticate(options.Services) && addr == options.Addr {
|
return nil, nil
|
||||||
authenticateURL, err := options.GetInternalAuthenticateURL()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
for _, h := range urlutil.GetDomainsForURL(*authenticateURL) {
|
|
||||||
lookup[h] = struct{}{}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// authorize urls
|
|
||||||
if config.IsAll(options.Services) && addr == options.GetGRPCAddr() {
|
|
||||||
authorizeURLs, err := options.GetAuthorizeURLs()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
for _, u := range authorizeURLs {
|
|
||||||
for _, h := range urlutil.GetDomainsForURL(*u) {
|
|
||||||
lookup[h] = struct{}{}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else if config.IsAuthorize(options.Services) && addr == options.GetGRPCAddr() {
|
|
||||||
authorizeURLs, err := options.GetInternalAuthorizeURLs()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
for _, u := range authorizeURLs {
|
|
||||||
for _, h := range urlutil.GetDomainsForURL(*u) {
|
|
||||||
lookup[h] = struct{}{}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// databroker urls
|
|
||||||
if config.IsAll(options.Services) && addr == options.GetGRPCAddr() {
|
|
||||||
dataBrokerURLs, err := options.GetDataBrokerURLs()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
for _, u := range dataBrokerURLs {
|
|
||||||
for _, h := range urlutil.GetDomainsForURL(*u) {
|
|
||||||
lookup[h] = struct{}{}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else if config.IsDataBroker(options.Services) && addr == options.GetGRPCAddr() {
|
|
||||||
dataBrokerURLs, err := options.GetInternalDataBrokerURLs()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
for _, u := range dataBrokerURLs {
|
|
||||||
for _, h := range urlutil.GetDomainsForURL(*u) {
|
|
||||||
lookup[h] = struct{}{}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// policy urls
|
|
||||||
if config.IsProxy(options.Services) && addr == options.Addr {
|
|
||||||
for _, policy := range options.GetAllPolicies() {
|
|
||||||
for _, h := range urlutil.GetDomainsForURL(*policy.Source.URL) {
|
|
||||||
lookup[h] = struct{}{}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if forwardAuthURL != nil {
|
|
||||||
for _, h := range urlutil.GetDomainsForURL(*forwardAuthURL) {
|
|
||||||
lookup[h] = struct{}{}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
domains := make([]string, 0, len(lookup))
|
|
||||||
for domain := range lookup {
|
|
||||||
domains = append(domains, domain)
|
|
||||||
}
|
|
||||||
sort.Strings(domains)
|
|
||||||
|
|
||||||
return domains, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func getAllTLSDomains(options *config.Options, addr string) ([]string, error) {
|
func getAllTLSDomains(options *config.Options, addr string) ([]string, error) {
|
||||||
|
|
|
@ -11,6 +11,7 @@ import (
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"reflect"
|
"reflect"
|
||||||
|
"sort"
|
||||||
"strings"
|
"strings"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
@ -1041,6 +1042,106 @@ func (o *Options) GetCodecType() CodecType {
|
||||||
return o.CodecType
|
return o.CodecType
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetAllRouteableGRPCDomains returns all the possible gRPC domains handled by the Pomerium options.
|
||||||
|
func (o *Options) GetAllRouteableGRPCDomains() ([]string, error) {
|
||||||
|
lookup := map[string]struct{}{}
|
||||||
|
|
||||||
|
// authorize urls
|
||||||
|
if IsAll(o.Services) {
|
||||||
|
authorizeURLs, err := o.GetAuthorizeURLs()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
for _, u := range authorizeURLs {
|
||||||
|
for _, h := range urlutil.GetDomainsForURL(*u) {
|
||||||
|
lookup[h] = struct{}{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else if IsAuthorize(o.Services) {
|
||||||
|
authorizeURLs, err := o.GetInternalAuthorizeURLs()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
for _, u := range authorizeURLs {
|
||||||
|
for _, h := range urlutil.GetDomainsForURL(*u) {
|
||||||
|
lookup[h] = struct{}{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// databroker urls
|
||||||
|
if IsAll(o.Services) {
|
||||||
|
dataBrokerURLs, err := o.GetDataBrokerURLs()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
for _, u := range dataBrokerURLs {
|
||||||
|
for _, h := range urlutil.GetDomainsForURL(*u) {
|
||||||
|
lookup[h] = struct{}{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else if IsDataBroker(o.Services) {
|
||||||
|
dataBrokerURLs, err := o.GetInternalDataBrokerURLs()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
for _, u := range dataBrokerURLs {
|
||||||
|
for _, h := range urlutil.GetDomainsForURL(*u) {
|
||||||
|
lookup[h] = struct{}{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
domains := make([]string, 0, len(lookup))
|
||||||
|
for domain := range lookup {
|
||||||
|
domains = append(domains, domain)
|
||||||
|
}
|
||||||
|
sort.Strings(domains)
|
||||||
|
|
||||||
|
return domains, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAllRouteableHTTPDomains returns all the possible HTTP domains handled by the Pomerium options.
|
||||||
|
func (o *Options) GetAllRouteableHTTPDomains() ([]string, error) {
|
||||||
|
forwardAuthURL, err := o.GetForwardAuthURL()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
lookup := map[string]struct{}{}
|
||||||
|
if IsAuthenticate(o.Services) {
|
||||||
|
authenticateURL, err := o.GetInternalAuthenticateURL()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
for _, h := range urlutil.GetDomainsForURL(*authenticateURL) {
|
||||||
|
lookup[h] = struct{}{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// policy urls
|
||||||
|
if IsProxy(o.Services) {
|
||||||
|
for _, policy := range o.GetAllPolicies() {
|
||||||
|
for _, h := range urlutil.GetDomainsForURL(*policy.Source.URL) {
|
||||||
|
lookup[h] = struct{}{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if forwardAuthURL != nil {
|
||||||
|
for _, h := range urlutil.GetDomainsForURL(*forwardAuthURL) {
|
||||||
|
lookup[h] = struct{}{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
domains := make([]string, 0, len(lookup))
|
||||||
|
for domain := range lookup {
|
||||||
|
domains = append(domains, domain)
|
||||||
|
}
|
||||||
|
sort.Strings(domains)
|
||||||
|
|
||||||
|
return domains, nil
|
||||||
|
}
|
||||||
|
|
||||||
// Checksum returns the checksum of the current options struct
|
// Checksum returns the checksum of the current options struct
|
||||||
func (o *Options) Checksum() uint64 {
|
func (o *Options) Checksum() uint64 {
|
||||||
return hashutil.MustHash(o)
|
return hashutil.MustHash(o)
|
||||||
|
|
|
@ -684,6 +684,49 @@ func TestOptions_GetOauthOptions(t *testing.T) {
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, u.Hostname(), oauthOptions.RedirectURL.Hostname())
|
assert.Equal(t, u.Hostname(), oauthOptions.RedirectURL.Hostname())
|
||||||
}
|
}
|
||||||
|
func TestOptions_GetAllRouteableGRPCDomains(t *testing.T) {
|
||||||
|
opts := &Options{
|
||||||
|
AuthenticateURLString: "https://authenticate.example.com",
|
||||||
|
AuthorizeURLString: "https://authorize.example.com",
|
||||||
|
DataBrokerURLString: "https://databroker.example.com",
|
||||||
|
Services: "all",
|
||||||
|
}
|
||||||
|
domains, err := opts.GetAllRouteableGRPCDomains()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
assert.Equal(t, []string{
|
||||||
|
"authorize.example.com",
|
||||||
|
"authorize.example.com:443",
|
||||||
|
"databroker.example.com",
|
||||||
|
"databroker.example.com:443",
|
||||||
|
}, domains)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOptions_GetAllRouteableHTTPDomains(t *testing.T) {
|
||||||
|
p1 := Policy{From: "https://from1.example.com"}
|
||||||
|
p1.Validate()
|
||||||
|
p2 := Policy{From: "https://from2.example.com"}
|
||||||
|
p2.Validate()
|
||||||
|
|
||||||
|
opts := &Options{
|
||||||
|
AuthenticateURLString: "https://authenticate.example.com",
|
||||||
|
AuthorizeURLString: "https://authorize.example.com",
|
||||||
|
DataBrokerURLString: "https://databroker.example.com",
|
||||||
|
Policies: []Policy{p1, p2},
|
||||||
|
Services: "all",
|
||||||
|
}
|
||||||
|
domains, err := opts.GetAllRouteableHTTPDomains()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
assert.Equal(t, []string{
|
||||||
|
"authenticate.example.com",
|
||||||
|
"authenticate.example.com:443",
|
||||||
|
"from1.example.com",
|
||||||
|
"from1.example.com:443",
|
||||||
|
"from2.example.com",
|
||||||
|
"from2.example.com:443",
|
||||||
|
}, domains)
|
||||||
|
}
|
||||||
|
|
||||||
func mustParseWeightedURLs(t *testing.T, urls ...string) []WeightedURL {
|
func mustParseWeightedURLs(t *testing.T, urls ...string) []WeightedURL {
|
||||||
wu, err := ParseWeightedUrls(urls...)
|
wu, err := ParseWeightedUrls(urls...)
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue