authenticate: support webauthn redirects to non-pomerium domains (#2936)

* authenticate: support webauthn redirects to non-pomerium domains

* add test

* remove dead code
This commit is contained in:
Caleb Doxsey 2022-01-19 15:10:57 -07:00 committed by GitHub
parent 6b26f58e4f
commit 95d6d97143
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 191 additions and 93 deletions

View file

@ -686,9 +686,15 @@ func (a *Authenticate) getWebauthnState(ctx context.Context) (*webauthn.State, e
return nil, err
}
pomeriumDomains, err := a.options.Load().GetAllRouteableHTTPDomains()
if err != nil {
return nil, err
}
return &webauthn.State{
SharedKey: state.sharedKey,
Client: state.dataBrokerClient,
PomeriumDomains: pomeriumDomains,
Session: s,
SessionState: ss,
SessionStore: state.sessionStore,

View file

@ -52,6 +52,7 @@ var (
type State struct {
SharedKey []byte
Client databroker.DataBrokerServiceClient
PomeriumDomains []string
Session *session.Session
SessionState *sessions.State
SessionStore sessions.SessionStore
@ -392,6 +393,12 @@ func (h *Handler) handleView(w http.ResponseWriter, r *http.Request, state *Stat
}
func (h *Handler) saveSessionAndRedirect(w http.ResponseWriter, r *http.Request, state *State, rawRedirectURI string) error {
// if the redirect URL is for a URL we don't control, just do a plain redirect
if !isURLForPomerium(state.PomeriumDomains, rawRedirectURI) {
httputil.Redirect(w, r, rawRedirectURI, http.StatusFound)
return nil
}
// save the session to the databroker
res, err := session.Put(r.Context(), state.Client, state.Session)
if err != nil {
@ -513,3 +520,18 @@ func getOrCreateDeviceEnrollment(
}
return deviceEnrollment, nil
}
func isURLForPomerium(pomeriumDomains []string, rawURI string) bool {
uri, err := urlutil.ParseAndValidateURL(rawURI)
if err != nil {
return false
}
for _, domain := range pomeriumDomains {
if urlutil.StripPort(domain) == urlutil.StripPort(uri.Host) {
return true
}
}
return false
}

View file

@ -734,89 +734,15 @@ func getRouteableDomainsForTLSDomain(options *config.Options, addr string, tlsDo
}
func getAllRouteableDomains(options *config.Options, addr string) ([]string, error) {
forwardAuthURL, err := options.GetForwardAuthURL()
if err != nil {
return nil, err
switch addr {
case options.Addr:
return options.GetAllRouteableHTTPDomains()
case options.GetGRPCAddr():
return options.GetAllRouteableGRPCDomains()
}
lookup := map[string]struct{}{}
if config.IsAuthenticate(options.Services) && addr == options.Addr {
authenticateURL, err := options.GetInternalAuthenticateURL()
if err != nil {
return nil, err
}
for _, h := range urlutil.GetDomainsForURL(*authenticateURL) {
lookup[h] = struct{}{}
}
}
// authorize urls
if config.IsAll(options.Services) && addr == options.GetGRPCAddr() {
authorizeURLs, err := options.GetAuthorizeURLs()
if err != nil {
return nil, err
}
for _, u := range authorizeURLs {
for _, h := range urlutil.GetDomainsForURL(*u) {
lookup[h] = struct{}{}
}
}
} else if config.IsAuthorize(options.Services) && addr == options.GetGRPCAddr() {
authorizeURLs, err := options.GetInternalAuthorizeURLs()
if err != nil {
return nil, err
}
for _, u := range authorizeURLs {
for _, h := range urlutil.GetDomainsForURL(*u) {
lookup[h] = struct{}{}
}
}
}
// databroker urls
if config.IsAll(options.Services) && addr == options.GetGRPCAddr() {
dataBrokerURLs, err := options.GetDataBrokerURLs()
if err != nil {
return nil, err
}
for _, u := range dataBrokerURLs {
for _, h := range urlutil.GetDomainsForURL(*u) {
lookup[h] = struct{}{}
}
}
} else if config.IsDataBroker(options.Services) && addr == options.GetGRPCAddr() {
dataBrokerURLs, err := options.GetInternalDataBrokerURLs()
if err != nil {
return nil, err
}
for _, u := range dataBrokerURLs {
for _, h := range urlutil.GetDomainsForURL(*u) {
lookup[h] = struct{}{}
}
}
}
// policy urls
if config.IsProxy(options.Services) && addr == options.Addr {
for _, policy := range options.GetAllPolicies() {
for _, h := range urlutil.GetDomainsForURL(*policy.Source.URL) {
lookup[h] = struct{}{}
}
}
if forwardAuthURL != nil {
for _, h := range urlutil.GetDomainsForURL(*forwardAuthURL) {
lookup[h] = struct{}{}
}
}
}
domains := make([]string, 0, len(lookup))
for domain := range lookup {
domains = append(domains, domain)
}
sort.Strings(domains)
return domains, nil
// no other domains supported
return nil, nil
}
func getAllTLSDomains(options *config.Options, addr string) ([]string, error) {

View file

@ -11,6 +11,7 @@ import (
"os"
"path/filepath"
"reflect"
"sort"
"strings"
"sync/atomic"
"time"
@ -1041,6 +1042,106 @@ func (o *Options) GetCodecType() CodecType {
return o.CodecType
}
// GetAllRouteableGRPCDomains returns all the possible gRPC domains handled by the Pomerium options.
func (o *Options) GetAllRouteableGRPCDomains() ([]string, error) {
lookup := map[string]struct{}{}
// authorize urls
if IsAll(o.Services) {
authorizeURLs, err := o.GetAuthorizeURLs()
if err != nil {
return nil, err
}
for _, u := range authorizeURLs {
for _, h := range urlutil.GetDomainsForURL(*u) {
lookup[h] = struct{}{}
}
}
} else if IsAuthorize(o.Services) {
authorizeURLs, err := o.GetInternalAuthorizeURLs()
if err != nil {
return nil, err
}
for _, u := range authorizeURLs {
for _, h := range urlutil.GetDomainsForURL(*u) {
lookup[h] = struct{}{}
}
}
}
// databroker urls
if IsAll(o.Services) {
dataBrokerURLs, err := o.GetDataBrokerURLs()
if err != nil {
return nil, err
}
for _, u := range dataBrokerURLs {
for _, h := range urlutil.GetDomainsForURL(*u) {
lookup[h] = struct{}{}
}
}
} else if IsDataBroker(o.Services) {
dataBrokerURLs, err := o.GetInternalDataBrokerURLs()
if err != nil {
return nil, err
}
for _, u := range dataBrokerURLs {
for _, h := range urlutil.GetDomainsForURL(*u) {
lookup[h] = struct{}{}
}
}
}
domains := make([]string, 0, len(lookup))
for domain := range lookup {
domains = append(domains, domain)
}
sort.Strings(domains)
return domains, nil
}
// GetAllRouteableHTTPDomains returns all the possible HTTP domains handled by the Pomerium options.
func (o *Options) GetAllRouteableHTTPDomains() ([]string, error) {
forwardAuthURL, err := o.GetForwardAuthURL()
if err != nil {
return nil, err
}
lookup := map[string]struct{}{}
if IsAuthenticate(o.Services) {
authenticateURL, err := o.GetInternalAuthenticateURL()
if err != nil {
return nil, err
}
for _, h := range urlutil.GetDomainsForURL(*authenticateURL) {
lookup[h] = struct{}{}
}
}
// policy urls
if IsProxy(o.Services) {
for _, policy := range o.GetAllPolicies() {
for _, h := range urlutil.GetDomainsForURL(*policy.Source.URL) {
lookup[h] = struct{}{}
}
}
if forwardAuthURL != nil {
for _, h := range urlutil.GetDomainsForURL(*forwardAuthURL) {
lookup[h] = struct{}{}
}
}
}
domains := make([]string, 0, len(lookup))
for domain := range lookup {
domains = append(domains, domain)
}
sort.Strings(domains)
return domains, nil
}
// Checksum returns the checksum of the current options struct
func (o *Options) Checksum() uint64 {
return hashutil.MustHash(o)

View file

@ -684,6 +684,49 @@ func TestOptions_GetOauthOptions(t *testing.T) {
require.NoError(t, err)
assert.Equal(t, u.Hostname(), oauthOptions.RedirectURL.Hostname())
}
func TestOptions_GetAllRouteableGRPCDomains(t *testing.T) {
opts := &Options{
AuthenticateURLString: "https://authenticate.example.com",
AuthorizeURLString: "https://authorize.example.com",
DataBrokerURLString: "https://databroker.example.com",
Services: "all",
}
domains, err := opts.GetAllRouteableGRPCDomains()
assert.NoError(t, err)
assert.Equal(t, []string{
"authorize.example.com",
"authorize.example.com:443",
"databroker.example.com",
"databroker.example.com:443",
}, domains)
}
func TestOptions_GetAllRouteableHTTPDomains(t *testing.T) {
p1 := Policy{From: "https://from1.example.com"}
p1.Validate()
p2 := Policy{From: "https://from2.example.com"}
p2.Validate()
opts := &Options{
AuthenticateURLString: "https://authenticate.example.com",
AuthorizeURLString: "https://authorize.example.com",
DataBrokerURLString: "https://databroker.example.com",
Policies: []Policy{p1, p2},
Services: "all",
}
domains, err := opts.GetAllRouteableHTTPDomains()
assert.NoError(t, err)
assert.Equal(t, []string{
"authenticate.example.com",
"authenticate.example.com:443",
"from1.example.com",
"from1.example.com:443",
"from2.example.com",
"from2.example.com:443",
}, domains)
}
func mustParseWeightedURLs(t *testing.T, urls ...string) []WeightedURL {
wu, err := ParseWeightedUrls(urls...)