mirror of
https://github.com/pomerium/pomerium.git
synced 2025-06-03 11:22:45 +02:00
authenticate: support webauthn redirects to non-pomerium domains (#2936)
* authenticate: support webauthn redirects to non-pomerium domains * add test * remove dead code
This commit is contained in:
parent
6b26f58e4f
commit
95d6d97143
5 changed files with 191 additions and 93 deletions
|
@ -686,9 +686,15 @@ func (a *Authenticate) getWebauthnState(ctx context.Context) (*webauthn.State, e
|
|||
return nil, err
|
||||
}
|
||||
|
||||
pomeriumDomains, err := a.options.Load().GetAllRouteableHTTPDomains()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &webauthn.State{
|
||||
SharedKey: state.sharedKey,
|
||||
Client: state.dataBrokerClient,
|
||||
PomeriumDomains: pomeriumDomains,
|
||||
Session: s,
|
||||
SessionState: ss,
|
||||
SessionStore: state.sessionStore,
|
||||
|
|
|
@ -52,6 +52,7 @@ var (
|
|||
type State struct {
|
||||
SharedKey []byte
|
||||
Client databroker.DataBrokerServiceClient
|
||||
PomeriumDomains []string
|
||||
Session *session.Session
|
||||
SessionState *sessions.State
|
||||
SessionStore sessions.SessionStore
|
||||
|
@ -392,6 +393,12 @@ func (h *Handler) handleView(w http.ResponseWriter, r *http.Request, state *Stat
|
|||
}
|
||||
|
||||
func (h *Handler) saveSessionAndRedirect(w http.ResponseWriter, r *http.Request, state *State, rawRedirectURI string) error {
|
||||
// if the redirect URL is for a URL we don't control, just do a plain redirect
|
||||
if !isURLForPomerium(state.PomeriumDomains, rawRedirectURI) {
|
||||
httputil.Redirect(w, r, rawRedirectURI, http.StatusFound)
|
||||
return nil
|
||||
}
|
||||
|
||||
// save the session to the databroker
|
||||
res, err := session.Put(r.Context(), state.Client, state.Session)
|
||||
if err != nil {
|
||||
|
@ -513,3 +520,18 @@ func getOrCreateDeviceEnrollment(
|
|||
}
|
||||
return deviceEnrollment, nil
|
||||
}
|
||||
|
||||
func isURLForPomerium(pomeriumDomains []string, rawURI string) bool {
|
||||
uri, err := urlutil.ParseAndValidateURL(rawURI)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
for _, domain := range pomeriumDomains {
|
||||
if urlutil.StripPort(domain) == urlutil.StripPort(uri.Host) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
|
|
@ -734,89 +734,15 @@ func getRouteableDomainsForTLSDomain(options *config.Options, addr string, tlsDo
|
|||
}
|
||||
|
||||
func getAllRouteableDomains(options *config.Options, addr string) ([]string, error) {
|
||||
forwardAuthURL, err := options.GetForwardAuthURL()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
switch addr {
|
||||
case options.Addr:
|
||||
return options.GetAllRouteableHTTPDomains()
|
||||
case options.GetGRPCAddr():
|
||||
return options.GetAllRouteableGRPCDomains()
|
||||
}
|
||||
|
||||
lookup := map[string]struct{}{}
|
||||
if config.IsAuthenticate(options.Services) && addr == options.Addr {
|
||||
authenticateURL, err := options.GetInternalAuthenticateURL()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, h := range urlutil.GetDomainsForURL(*authenticateURL) {
|
||||
lookup[h] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
// authorize urls
|
||||
if config.IsAll(options.Services) && addr == options.GetGRPCAddr() {
|
||||
authorizeURLs, err := options.GetAuthorizeURLs()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, u := range authorizeURLs {
|
||||
for _, h := range urlutil.GetDomainsForURL(*u) {
|
||||
lookup[h] = struct{}{}
|
||||
}
|
||||
}
|
||||
} else if config.IsAuthorize(options.Services) && addr == options.GetGRPCAddr() {
|
||||
authorizeURLs, err := options.GetInternalAuthorizeURLs()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, u := range authorizeURLs {
|
||||
for _, h := range urlutil.GetDomainsForURL(*u) {
|
||||
lookup[h] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// databroker urls
|
||||
if config.IsAll(options.Services) && addr == options.GetGRPCAddr() {
|
||||
dataBrokerURLs, err := options.GetDataBrokerURLs()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, u := range dataBrokerURLs {
|
||||
for _, h := range urlutil.GetDomainsForURL(*u) {
|
||||
lookup[h] = struct{}{}
|
||||
}
|
||||
}
|
||||
} else if config.IsDataBroker(options.Services) && addr == options.GetGRPCAddr() {
|
||||
dataBrokerURLs, err := options.GetInternalDataBrokerURLs()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, u := range dataBrokerURLs {
|
||||
for _, h := range urlutil.GetDomainsForURL(*u) {
|
||||
lookup[h] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// policy urls
|
||||
if config.IsProxy(options.Services) && addr == options.Addr {
|
||||
for _, policy := range options.GetAllPolicies() {
|
||||
for _, h := range urlutil.GetDomainsForURL(*policy.Source.URL) {
|
||||
lookup[h] = struct{}{}
|
||||
}
|
||||
}
|
||||
if forwardAuthURL != nil {
|
||||
for _, h := range urlutil.GetDomainsForURL(*forwardAuthURL) {
|
||||
lookup[h] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
domains := make([]string, 0, len(lookup))
|
||||
for domain := range lookup {
|
||||
domains = append(domains, domain)
|
||||
}
|
||||
sort.Strings(domains)
|
||||
|
||||
return domains, nil
|
||||
// no other domains supported
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func getAllTLSDomains(options *config.Options, addr string) ([]string, error) {
|
||||
|
|
|
@ -11,6 +11,7 @@ import (
|
|||
"os"
|
||||
"path/filepath"
|
||||
"reflect"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
@ -1041,6 +1042,106 @@ func (o *Options) GetCodecType() CodecType {
|
|||
return o.CodecType
|
||||
}
|
||||
|
||||
// GetAllRouteableGRPCDomains returns all the possible gRPC domains handled by the Pomerium options.
|
||||
func (o *Options) GetAllRouteableGRPCDomains() ([]string, error) {
|
||||
lookup := map[string]struct{}{}
|
||||
|
||||
// authorize urls
|
||||
if IsAll(o.Services) {
|
||||
authorizeURLs, err := o.GetAuthorizeURLs()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, u := range authorizeURLs {
|
||||
for _, h := range urlutil.GetDomainsForURL(*u) {
|
||||
lookup[h] = struct{}{}
|
||||
}
|
||||
}
|
||||
} else if IsAuthorize(o.Services) {
|
||||
authorizeURLs, err := o.GetInternalAuthorizeURLs()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, u := range authorizeURLs {
|
||||
for _, h := range urlutil.GetDomainsForURL(*u) {
|
||||
lookup[h] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// databroker urls
|
||||
if IsAll(o.Services) {
|
||||
dataBrokerURLs, err := o.GetDataBrokerURLs()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, u := range dataBrokerURLs {
|
||||
for _, h := range urlutil.GetDomainsForURL(*u) {
|
||||
lookup[h] = struct{}{}
|
||||
}
|
||||
}
|
||||
} else if IsDataBroker(o.Services) {
|
||||
dataBrokerURLs, err := o.GetInternalDataBrokerURLs()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, u := range dataBrokerURLs {
|
||||
for _, h := range urlutil.GetDomainsForURL(*u) {
|
||||
lookup[h] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
domains := make([]string, 0, len(lookup))
|
||||
for domain := range lookup {
|
||||
domains = append(domains, domain)
|
||||
}
|
||||
sort.Strings(domains)
|
||||
|
||||
return domains, nil
|
||||
}
|
||||
|
||||
// GetAllRouteableHTTPDomains returns all the possible HTTP domains handled by the Pomerium options.
|
||||
func (o *Options) GetAllRouteableHTTPDomains() ([]string, error) {
|
||||
forwardAuthURL, err := o.GetForwardAuthURL()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
lookup := map[string]struct{}{}
|
||||
if IsAuthenticate(o.Services) {
|
||||
authenticateURL, err := o.GetInternalAuthenticateURL()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, h := range urlutil.GetDomainsForURL(*authenticateURL) {
|
||||
lookup[h] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
// policy urls
|
||||
if IsProxy(o.Services) {
|
||||
for _, policy := range o.GetAllPolicies() {
|
||||
for _, h := range urlutil.GetDomainsForURL(*policy.Source.URL) {
|
||||
lookup[h] = struct{}{}
|
||||
}
|
||||
}
|
||||
if forwardAuthURL != nil {
|
||||
for _, h := range urlutil.GetDomainsForURL(*forwardAuthURL) {
|
||||
lookup[h] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
domains := make([]string, 0, len(lookup))
|
||||
for domain := range lookup {
|
||||
domains = append(domains, domain)
|
||||
}
|
||||
sort.Strings(domains)
|
||||
|
||||
return domains, nil
|
||||
}
|
||||
|
||||
// Checksum returns the checksum of the current options struct
|
||||
func (o *Options) Checksum() uint64 {
|
||||
return hashutil.MustHash(o)
|
||||
|
|
|
@ -684,6 +684,49 @@ func TestOptions_GetOauthOptions(t *testing.T) {
|
|||
require.NoError(t, err)
|
||||
assert.Equal(t, u.Hostname(), oauthOptions.RedirectURL.Hostname())
|
||||
}
|
||||
func TestOptions_GetAllRouteableGRPCDomains(t *testing.T) {
|
||||
opts := &Options{
|
||||
AuthenticateURLString: "https://authenticate.example.com",
|
||||
AuthorizeURLString: "https://authorize.example.com",
|
||||
DataBrokerURLString: "https://databroker.example.com",
|
||||
Services: "all",
|
||||
}
|
||||
domains, err := opts.GetAllRouteableGRPCDomains()
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.Equal(t, []string{
|
||||
"authorize.example.com",
|
||||
"authorize.example.com:443",
|
||||
"databroker.example.com",
|
||||
"databroker.example.com:443",
|
||||
}, domains)
|
||||
}
|
||||
|
||||
func TestOptions_GetAllRouteableHTTPDomains(t *testing.T) {
|
||||
p1 := Policy{From: "https://from1.example.com"}
|
||||
p1.Validate()
|
||||
p2 := Policy{From: "https://from2.example.com"}
|
||||
p2.Validate()
|
||||
|
||||
opts := &Options{
|
||||
AuthenticateURLString: "https://authenticate.example.com",
|
||||
AuthorizeURLString: "https://authorize.example.com",
|
||||
DataBrokerURLString: "https://databroker.example.com",
|
||||
Policies: []Policy{p1, p2},
|
||||
Services: "all",
|
||||
}
|
||||
domains, err := opts.GetAllRouteableHTTPDomains()
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.Equal(t, []string{
|
||||
"authenticate.example.com",
|
||||
"authenticate.example.com:443",
|
||||
"from1.example.com",
|
||||
"from1.example.com:443",
|
||||
"from2.example.com",
|
||||
"from2.example.com:443",
|
||||
}, domains)
|
||||
}
|
||||
|
||||
func mustParseWeightedURLs(t *testing.T, urls ...string) []WeightedURL {
|
||||
wu, err := ParseWeightedUrls(urls...)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue