From fb06cd3c732ff4566c735f0aa35b6e7defcf7734 Mon Sep 17 00:00:00 2001 From: Caleb Doxsey Date: Fri, 28 Feb 2025 09:59:03 -0700 Subject: [PATCH] proxy: add short timeout for logo discovery (#5506) --- proxy/handlers_portal.go | 35 +++++++++++++++++---------- proxy/portal/logo_provider_favicon.go | 19 +++++++++------ proxy/portal/logo_provider_test.go | 25 ++++++++++++++++--- 3 files changed, 56 insertions(+), 23 deletions(-) diff --git a/proxy/handlers_portal.go b/proxy/handlers_portal.go index 34523f542..8cf007f07 100644 --- a/proxy/handlers_portal.go +++ b/proxy/handlers_portal.go @@ -5,6 +5,7 @@ import ( "encoding/json" "errors" "net/http" + "sync" "github.com/pomerium/pomerium/config" "github.com/pomerium/pomerium/internal/handlers" @@ -49,23 +50,31 @@ func (p *Proxy) getPortalRoutes(ctx context.Context, u handlers.UserInfoData) [] } } portalRoutes := portal.RoutesFromConfigRoutes(routes) + + var wg sync.WaitGroup for i, pr := range portalRoutes { - r := routes[i] - for _, to := range r.To { - if pr.LogoURL == "" { - var err error - pr.LogoURL, err = p.logoProvider.GetLogoURL(ctx, pr.From, to.URL.String()) - if err != nil && !errors.Is(err, portal.ErrLogoNotFound) { - log.Ctx(ctx).Error(). - Err(err). - Str("from", pr.From). - Str("to", to.URL.String()). - Msg("error retrieving logo for route") + wg.Add(1) + go func() { + defer wg.Done() + + r := routes[i] + for _, to := range r.To { + if pr.LogoURL == "" { + var err error + pr.LogoURL, err = p.logoProvider.GetLogoURL(ctx, pr.From, to.URL.String()) + if err != nil && !errors.Is(err, portal.ErrLogoNotFound) { + log.Ctx(ctx).Error(). + Err(err). + Str("from", pr.From). + Str("to", to.URL.String()). + Msg("error retrieving logo for route") + } } } - } - portalRoutes[i] = pr + portalRoutes[i] = pr + }() } + wg.Wait() return portalRoutes } diff --git a/proxy/portal/logo_provider_favicon.go b/proxy/portal/logo_provider_favicon.go index 0d19e7142..2027abf95 100644 --- a/proxy/portal/logo_provider_favicon.go +++ b/proxy/portal/logo_provider_favicon.go @@ -27,17 +27,19 @@ type faviconCacheValue struct { } type faviconDiscoveryLogoProvider struct { - mu sync.Mutex - cache map[string]*faviconCacheValue - successTTL time.Duration - failureTTL time.Duration + mu sync.Mutex + cache map[string]*faviconCacheValue + successTTL time.Duration + failureTTL time.Duration + discoveryTimeout time.Duration } func newFaviconDiscoveryLogoProvider() *faviconDiscoveryLogoProvider { return &faviconDiscoveryLogoProvider{ - cache: make(map[string]*faviconCacheValue), - successTTL: time.Hour, - failureTTL: 10 * time.Minute, + cache: make(map[string]*faviconCacheValue), + successTTL: time.Hour, + failureTTL: 10 * time.Minute, + discoveryTimeout: 500 * time.Millisecond, } } @@ -85,6 +87,9 @@ func (p *faviconDiscoveryLogoProvider) discoverLogoURL(ctx context.Context, rawU return "", ErrLogoNotFound } + ctx, clearTimeout := context.WithTimeout(ctx, p.discoveryTimeout) + defer clearTimeout() + req, err := http.NewRequestWithContext(ctx, http.MethodGet, rawURL, nil) if err != nil { return "", err diff --git a/proxy/portal/logo_provider_test.go b/proxy/portal/logo_provider_test.go index a14244b06..4dddc3afa 100644 --- a/proxy/portal/logo_provider_test.go +++ b/proxy/portal/logo_provider_test.go @@ -1,6 +1,7 @@ -package portal_test +package portal import ( + "context" "io" "net/http" "net/http/httptest" @@ -10,7 +11,6 @@ import ( "github.com/stretchr/testify/assert" "github.com/pomerium/pomerium/internal/testutil" - "github.com/pomerium/pomerium/proxy/portal" ) func TestLogoProvider(t *testing.T) { @@ -30,8 +30,27 @@ func TestLogoProvider(t *testing.T) { t.Cleanup(srv.Close) ctx := testutil.GetContext(t, time.Minute) - p := portal.NewLogoProvider() + p := NewLogoProvider() u, err := p.GetLogoURL(ctx, "", srv.URL) assert.NoError(t, err) assert.Equal(t, "data:image/vnd.microsoft.icon;base64,Tk9UIEFDVFVBTExZIEFOIElDT04=", u) } + +func TestLogoProvider_Timeout(t *testing.T) { + t.Parallel() + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + select { + case <-r.Context().Done(): + case <-time.After(time.Second): + } + http.NotFound(w, r) + })) + t.Cleanup(srv.Close) + + ctx := testutil.GetContext(t, time.Minute) + p := newFaviconDiscoveryLogoProvider() + p.discoveryTimeout = time.Millisecond + _, err := p.GetLogoURL(ctx, "", srv.URL) + assert.ErrorIs(t, err, context.DeadlineExceeded) +}