authenticate: support webauthn redirects to non-pomerium domains (#2936)

* authenticate: support webauthn redirects to non-pomerium domains

* add test

* remove dead code
This commit is contained in:
Caleb Doxsey 2022-01-19 15:10:57 -07:00 committed by GitHub
parent 6b26f58e4f
commit 95d6d97143
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 191 additions and 93 deletions

View file

@ -686,13 +686,19 @@ func (a *Authenticate) getWebauthnState(ctx context.Context) (*webauthn.State, e
return nil, err return nil, err
} }
pomeriumDomains, err := a.options.Load().GetAllRouteableHTTPDomains()
if err != nil {
return nil, err
}
return &webauthn.State{ return &webauthn.State{
SharedKey: state.sharedKey, SharedKey: state.sharedKey,
Client: state.dataBrokerClient, Client: state.dataBrokerClient,
Session: s, PomeriumDomains: pomeriumDomains,
SessionState: ss, Session: s,
SessionStore: state.sessionStore, SessionState: ss,
RelyingParty: state.webauthnRelyingParty, SessionStore: state.sessionStore,
RelyingParty: state.webauthnRelyingParty,
}, nil }, nil
} }

View file

@ -50,12 +50,13 @@ var (
// State is the state needed by the Handler to handle requests. // State is the state needed by the Handler to handle requests.
type State struct { type State struct {
SharedKey []byte SharedKey []byte
Client databroker.DataBrokerServiceClient Client databroker.DataBrokerServiceClient
Session *session.Session PomeriumDomains []string
SessionState *sessions.State Session *session.Session
SessionStore sessions.SessionStore SessionState *sessions.State
RelyingParty *webauthn.RelyingParty SessionStore sessions.SessionStore
RelyingParty *webauthn.RelyingParty
} }
// A StateProvider provides state for the handler. // A StateProvider provides state for the handler.
@ -392,6 +393,12 @@ func (h *Handler) handleView(w http.ResponseWriter, r *http.Request, state *Stat
} }
func (h *Handler) saveSessionAndRedirect(w http.ResponseWriter, r *http.Request, state *State, rawRedirectURI string) error { func (h *Handler) saveSessionAndRedirect(w http.ResponseWriter, r *http.Request, state *State, rawRedirectURI string) error {
// if the redirect URL is for a URL we don't control, just do a plain redirect
if !isURLForPomerium(state.PomeriumDomains, rawRedirectURI) {
httputil.Redirect(w, r, rawRedirectURI, http.StatusFound)
return nil
}
// save the session to the databroker // save the session to the databroker
res, err := session.Put(r.Context(), state.Client, state.Session) res, err := session.Put(r.Context(), state.Client, state.Session)
if err != nil { if err != nil {
@ -513,3 +520,18 @@ func getOrCreateDeviceEnrollment(
} }
return deviceEnrollment, nil return deviceEnrollment, nil
} }
func isURLForPomerium(pomeriumDomains []string, rawURI string) bool {
uri, err := urlutil.ParseAndValidateURL(rawURI)
if err != nil {
return false
}
for _, domain := range pomeriumDomains {
if urlutil.StripPort(domain) == urlutil.StripPort(uri.Host) {
return true
}
}
return false
}

View file

@ -734,89 +734,15 @@ func getRouteableDomainsForTLSDomain(options *config.Options, addr string, tlsDo
} }
func getAllRouteableDomains(options *config.Options, addr string) ([]string, error) { func getAllRouteableDomains(options *config.Options, addr string) ([]string, error) {
forwardAuthURL, err := options.GetForwardAuthURL() switch addr {
if err != nil { case options.Addr:
return nil, err return options.GetAllRouteableHTTPDomains()
case options.GetGRPCAddr():
return options.GetAllRouteableGRPCDomains()
} }
lookup := map[string]struct{}{} // no other domains supported
if config.IsAuthenticate(options.Services) && addr == options.Addr { return nil, nil
authenticateURL, err := options.GetInternalAuthenticateURL()
if err != nil {
return nil, err
}
for _, h := range urlutil.GetDomainsForURL(*authenticateURL) {
lookup[h] = struct{}{}
}
}
// authorize urls
if config.IsAll(options.Services) && addr == options.GetGRPCAddr() {
authorizeURLs, err := options.GetAuthorizeURLs()
if err != nil {
return nil, err
}
for _, u := range authorizeURLs {
for _, h := range urlutil.GetDomainsForURL(*u) {
lookup[h] = struct{}{}
}
}
} else if config.IsAuthorize(options.Services) && addr == options.GetGRPCAddr() {
authorizeURLs, err := options.GetInternalAuthorizeURLs()
if err != nil {
return nil, err
}
for _, u := range authorizeURLs {
for _, h := range urlutil.GetDomainsForURL(*u) {
lookup[h] = struct{}{}
}
}
}
// databroker urls
if config.IsAll(options.Services) && addr == options.GetGRPCAddr() {
dataBrokerURLs, err := options.GetDataBrokerURLs()
if err != nil {
return nil, err
}
for _, u := range dataBrokerURLs {
for _, h := range urlutil.GetDomainsForURL(*u) {
lookup[h] = struct{}{}
}
}
} else if config.IsDataBroker(options.Services) && addr == options.GetGRPCAddr() {
dataBrokerURLs, err := options.GetInternalDataBrokerURLs()
if err != nil {
return nil, err
}
for _, u := range dataBrokerURLs {
for _, h := range urlutil.GetDomainsForURL(*u) {
lookup[h] = struct{}{}
}
}
}
// policy urls
if config.IsProxy(options.Services) && addr == options.Addr {
for _, policy := range options.GetAllPolicies() {
for _, h := range urlutil.GetDomainsForURL(*policy.Source.URL) {
lookup[h] = struct{}{}
}
}
if forwardAuthURL != nil {
for _, h := range urlutil.GetDomainsForURL(*forwardAuthURL) {
lookup[h] = struct{}{}
}
}
}
domains := make([]string, 0, len(lookup))
for domain := range lookup {
domains = append(domains, domain)
}
sort.Strings(domains)
return domains, nil
} }
func getAllTLSDomains(options *config.Options, addr string) ([]string, error) { func getAllTLSDomains(options *config.Options, addr string) ([]string, error) {

View file

@ -11,6 +11,7 @@ import (
"os" "os"
"path/filepath" "path/filepath"
"reflect" "reflect"
"sort"
"strings" "strings"
"sync/atomic" "sync/atomic"
"time" "time"
@ -1041,6 +1042,106 @@ func (o *Options) GetCodecType() CodecType {
return o.CodecType return o.CodecType
} }
// GetAllRouteableGRPCDomains returns all the possible gRPC domains handled by the Pomerium options.
func (o *Options) GetAllRouteableGRPCDomains() ([]string, error) {
lookup := map[string]struct{}{}
// authorize urls
if IsAll(o.Services) {
authorizeURLs, err := o.GetAuthorizeURLs()
if err != nil {
return nil, err
}
for _, u := range authorizeURLs {
for _, h := range urlutil.GetDomainsForURL(*u) {
lookup[h] = struct{}{}
}
}
} else if IsAuthorize(o.Services) {
authorizeURLs, err := o.GetInternalAuthorizeURLs()
if err != nil {
return nil, err
}
for _, u := range authorizeURLs {
for _, h := range urlutil.GetDomainsForURL(*u) {
lookup[h] = struct{}{}
}
}
}
// databroker urls
if IsAll(o.Services) {
dataBrokerURLs, err := o.GetDataBrokerURLs()
if err != nil {
return nil, err
}
for _, u := range dataBrokerURLs {
for _, h := range urlutil.GetDomainsForURL(*u) {
lookup[h] = struct{}{}
}
}
} else if IsDataBroker(o.Services) {
dataBrokerURLs, err := o.GetInternalDataBrokerURLs()
if err != nil {
return nil, err
}
for _, u := range dataBrokerURLs {
for _, h := range urlutil.GetDomainsForURL(*u) {
lookup[h] = struct{}{}
}
}
}
domains := make([]string, 0, len(lookup))
for domain := range lookup {
domains = append(domains, domain)
}
sort.Strings(domains)
return domains, nil
}
// GetAllRouteableHTTPDomains returns all the possible HTTP domains handled by the Pomerium options.
func (o *Options) GetAllRouteableHTTPDomains() ([]string, error) {
forwardAuthURL, err := o.GetForwardAuthURL()
if err != nil {
return nil, err
}
lookup := map[string]struct{}{}
if IsAuthenticate(o.Services) {
authenticateURL, err := o.GetInternalAuthenticateURL()
if err != nil {
return nil, err
}
for _, h := range urlutil.GetDomainsForURL(*authenticateURL) {
lookup[h] = struct{}{}
}
}
// policy urls
if IsProxy(o.Services) {
for _, policy := range o.GetAllPolicies() {
for _, h := range urlutil.GetDomainsForURL(*policy.Source.URL) {
lookup[h] = struct{}{}
}
}
if forwardAuthURL != nil {
for _, h := range urlutil.GetDomainsForURL(*forwardAuthURL) {
lookup[h] = struct{}{}
}
}
}
domains := make([]string, 0, len(lookup))
for domain := range lookup {
domains = append(domains, domain)
}
sort.Strings(domains)
return domains, nil
}
// Checksum returns the checksum of the current options struct // Checksum returns the checksum of the current options struct
func (o *Options) Checksum() uint64 { func (o *Options) Checksum() uint64 {
return hashutil.MustHash(o) return hashutil.MustHash(o)

View file

@ -684,6 +684,49 @@ func TestOptions_GetOauthOptions(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, u.Hostname(), oauthOptions.RedirectURL.Hostname()) assert.Equal(t, u.Hostname(), oauthOptions.RedirectURL.Hostname())
} }
func TestOptions_GetAllRouteableGRPCDomains(t *testing.T) {
opts := &Options{
AuthenticateURLString: "https://authenticate.example.com",
AuthorizeURLString: "https://authorize.example.com",
DataBrokerURLString: "https://databroker.example.com",
Services: "all",
}
domains, err := opts.GetAllRouteableGRPCDomains()
assert.NoError(t, err)
assert.Equal(t, []string{
"authorize.example.com",
"authorize.example.com:443",
"databroker.example.com",
"databroker.example.com:443",
}, domains)
}
func TestOptions_GetAllRouteableHTTPDomains(t *testing.T) {
p1 := Policy{From: "https://from1.example.com"}
p1.Validate()
p2 := Policy{From: "https://from2.example.com"}
p2.Validate()
opts := &Options{
AuthenticateURLString: "https://authenticate.example.com",
AuthorizeURLString: "https://authorize.example.com",
DataBrokerURLString: "https://databroker.example.com",
Policies: []Policy{p1, p2},
Services: "all",
}
domains, err := opts.GetAllRouteableHTTPDomains()
assert.NoError(t, err)
assert.Equal(t, []string{
"authenticate.example.com",
"authenticate.example.com:443",
"from1.example.com",
"from1.example.com:443",
"from2.example.com",
"from2.example.com:443",
}, domains)
}
func mustParseWeightedURLs(t *testing.T, urls ...string) []WeightedURL { func mustParseWeightedURLs(t *testing.T, urls ...string) []WeightedURL {
wu, err := ParseWeightedUrls(urls...) wu, err := ParseWeightedUrls(urls...)