diff --git a/authorize/state.go b/authorize/state.go index 352544674..c39ee617f 100644 --- a/authorize/state.go +++ b/authorize/state.go @@ -3,14 +3,12 @@ package authorize import ( "context" "fmt" - "net/url" googlegrpc "google.golang.org/grpc" "github.com/pomerium/pomerium/authorize/evaluator" "github.com/pomerium/pomerium/authorize/internal/store" "github.com/pomerium/pomerium/config" - "github.com/pomerium/pomerium/internal/httputil" "github.com/pomerium/pomerium/pkg/grpc" "github.com/pomerium/pomerium/pkg/grpc/databroker" "github.com/pomerium/pomerium/pkg/hpke" @@ -79,28 +77,11 @@ func newAuthorizeStateFromConfig(cfg *config.Config, store *store.Store) (*autho return nil, fmt.Errorf("authorize: invalid session store: %w", err) } - authenticateURL, err := cfg.Options.GetAuthenticateURL() - if err != nil { - return nil, fmt.Errorf("authorize: invalid authenticate service url: %w", err) - } - state.hpkePrivateKey = hpke.DerivePrivateKey(sharedKey) - - jwksURL := authenticateURL.ResolveReference(&url.URL{ - Path: "/.well-known/pomerium/jwks.json", - }).String() - transport := httputil.GetInsecureTransport() - ok, err := cfg.WillHaveCertificateForServerName(authenticateURL.Hostname()) + state.authenticateKeyFetcher, err = cfg.GetAuthenticateKeyFetcher() if err != nil { - return nil, fmt.Errorf("authorize: error determining if authenticate service will have a certificate name: %w", err) + return nil, fmt.Errorf("authorize: get authenticate JWKS key fetcher: %w", err) } - if ok { - transport, err = config.GetTLSClientTransport(cfg) - if err != nil { - return nil, fmt.Errorf("authorize: get tls client config: %w", err) - } - } - state.authenticateKeyFetcher = hpke.NewKeyFetcher(jwksURL, transport) return state, nil } diff --git a/config/config.go b/config/config.go index 5e3b56b17..38d220253 100644 --- a/config/config.go +++ b/config/config.go @@ -6,13 +6,17 @@ import ( "crypto/x509" "encoding/base64" "fmt" + "net/http" + "net/url" "github.com/pomerium/pomerium/internal/fileutil" "github.com/pomerium/pomerium/internal/hashutil" + "github.com/pomerium/pomerium/internal/httputil" "github.com/pomerium/pomerium/internal/log" "github.com/pomerium/pomerium/internal/telemetry/metrics" "github.com/pomerium/pomerium/pkg/cryptutil" "github.com/pomerium/pomerium/pkg/derivecert" + "github.com/pomerium/pomerium/pkg/hpke" ) // MetricsScrapeEndpoint defines additional metrics endpoints that would be scraped and exposed by pomerium @@ -236,3 +240,36 @@ func (cfg *Config) GetCertificatePool() (*x509.CertPool, error) { return pool, nil } + +// GetAuthenticateKeyFetcher returns a key fetcher for the authenticate service +func (cfg *Config) GetAuthenticateKeyFetcher() (hpke.KeyFetcher, error) { + authenticateURL, transport, err := cfg.resolveAuthenticateURL() + if err != nil { + return nil, err + } + jwksURL := authenticateURL.ResolveReference(&url.URL{ + Path: "/.well-known/pomerium/jwks.json", + }).String() + return hpke.NewKeyFetcher(jwksURL, transport), nil +} + +func (cfg *Config) resolveAuthenticateURL() (*url.URL, *http.Transport, error) { + authenticateURL, err := cfg.Options.GetInternalAuthenticateURL() + if err != nil { + return nil, nil, fmt.Errorf("invalid authenticate service url: %w", err) + } + ok, err := cfg.WillHaveCertificateForServerName(authenticateURL.Hostname()) + if err != nil { + return nil, nil, fmt.Errorf("error determining if authenticate service will have a certificate name: %w", err) + } + if !ok { + return authenticateURL, httputil.GetInsecureTransport(), nil + } + + transport, err := GetTLSClientTransport(cfg) + if err != nil { + return nil, nil, fmt.Errorf("get tls client config: %w", err) + } + + return authenticateURL, transport, nil +} diff --git a/config/envoyconfig/routes.go b/config/envoyconfig/routes.go index 96758cf2c..b638ca9d2 100644 --- a/config/envoyconfig/routes.go +++ b/config/envoyconfig/routes.go @@ -75,7 +75,8 @@ func (b *Builder) buildPomeriumHTTPRoutes(options *config.Options, host string) } } // if we're handling authentication, add the oauth2 callback url - authenticateURL, err := options.GetInternalAuthenticateURL() + // as the callback url is from the IdP, it is expected only on the public authenticate URL endpoint + authenticateURL, err := options.GetAuthenticateURL() if err != nil { return nil, err } diff --git a/internal/controlplane/server.go b/internal/controlplane/server.go index b501cdc73..bcba8e862 100644 --- a/internal/controlplane/server.go +++ b/internal/controlplane/server.go @@ -5,6 +5,7 @@ import ( "net" "net/http" "net/http/pprof" + "net/url" "time" envoy_service_discovery_v3 "github.com/envoyproxy/go-control-plane/envoy/service/discovery/v3" @@ -289,12 +290,20 @@ func (srv *Server) updateRouter(cfg *config.Config) error { return err } if srv.authenticateSvc != nil { - authenticateURL, err := cfg.Options.GetInternalAuthenticateURL() - if err != nil { - return err + seen := make(map[string]struct{}) + // mount auth handler for both internal and external endpoints + for _, fn := range []func() (*url.URL, error){cfg.Options.GetAuthenticateURL, cfg.Options.GetInternalAuthenticateURL} { + authenticateURL, err := fn() + if err != nil { + return err + } + authenticateHost := urlutil.StripPort(authenticateURL.Host) + if _, ok := seen[authenticateHost]; ok { + continue + } + seen[authenticateHost] = struct{}{} + srv.authenticateSvc.Mount(httpRouter.Host(authenticateHost).Subrouter()) } - authenticateHost := urlutil.StripPort(authenticateURL.Host) - srv.authenticateSvc.Mount(httpRouter.Host(authenticateHost).Subrouter()) } if srv.proxySvc != nil { srv.proxySvc.Mount(httpRouter) diff --git a/proxy/state.go b/proxy/state.go index 0bcdc8016..eacdc43fc 100644 --- a/proxy/state.go +++ b/proxy/state.go @@ -9,7 +9,6 @@ import ( "github.com/pomerium/pomerium/config" "github.com/pomerium/pomerium/internal/encoding" "github.com/pomerium/pomerium/internal/encoding/jws" - "github.com/pomerium/pomerium/internal/httputil" "github.com/pomerium/pomerium/internal/sessions" "github.com/pomerium/pomerium/internal/sessions/cookie" "github.com/pomerium/pomerium/pkg/cryptutil" @@ -49,11 +48,6 @@ func newProxyStateFromConfig(cfg *config.Config) (*proxyState, error) { state := new(proxyState) - authenticateURL, err := cfg.Options.GetAuthenticateURL() - if err != nil { - return nil, err - } - state.sharedKey, err = cfg.Options.GetSharedKey() if err != nil { return nil, err @@ -64,21 +58,10 @@ func newProxyStateFromConfig(cfg *config.Config) (*proxyState, error) { return nil, err } - jwksURL := authenticateURL.ResolveReference(&url.URL{ - Path: "/.well-known/pomerium/jwks.json", - }).String() - transport := httputil.GetInsecureTransport() - ok, err := cfg.WillHaveCertificateForServerName(authenticateURL.Hostname()) + state.authenticateKeyFetcher, err = cfg.GetAuthenticateKeyFetcher() if err != nil { - return nil, fmt.Errorf("proxy: error determining if authenticate service will have a certificate name: %w", err) + return nil, fmt.Errorf("authorize: get authenticate JWKS key fetcher: %w", err) } - if ok { - transport, err = config.GetTLSClientTransport(cfg) - if err != nil { - return nil, fmt.Errorf("proxy: get tls client config: %w", err) - } - } - state.authenticateKeyFetcher = hpke.NewKeyFetcher(jwksURL, transport) state.sharedCipher, err = cryptutil.NewAEADCipher(state.sharedKey) if err != nil {