proxy: add short timeout for logo discovery (#5506)

This commit is contained in:
Caleb Doxsey 2025-02-28 09:59:03 -07:00 committed by GitHub
parent 624c8f0cea
commit fb06cd3c73
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 56 additions and 23 deletions

View file

@ -5,6 +5,7 @@ import (
"encoding/json"
"errors"
"net/http"
"sync"
"github.com/pomerium/pomerium/config"
"github.com/pomerium/pomerium/internal/handlers"
@ -49,23 +50,31 @@ func (p *Proxy) getPortalRoutes(ctx context.Context, u handlers.UserInfoData) []
}
}
portalRoutes := portal.RoutesFromConfigRoutes(routes)
var wg sync.WaitGroup
for i, pr := range portalRoutes {
r := routes[i]
for _, to := range r.To {
if pr.LogoURL == "" {
var err error
pr.LogoURL, err = p.logoProvider.GetLogoURL(ctx, pr.From, to.URL.String())
if err != nil && !errors.Is(err, portal.ErrLogoNotFound) {
log.Ctx(ctx).Error().
Err(err).
Str("from", pr.From).
Str("to", to.URL.String()).
Msg("error retrieving logo for route")
wg.Add(1)
go func() {
defer wg.Done()
r := routes[i]
for _, to := range r.To {
if pr.LogoURL == "" {
var err error
pr.LogoURL, err = p.logoProvider.GetLogoURL(ctx, pr.From, to.URL.String())
if err != nil && !errors.Is(err, portal.ErrLogoNotFound) {
log.Ctx(ctx).Error().
Err(err).
Str("from", pr.From).
Str("to", to.URL.String()).
Msg("error retrieving logo for route")
}
}
}
}
portalRoutes[i] = pr
portalRoutes[i] = pr
}()
}
wg.Wait()
return portalRoutes
}

View file

@ -27,17 +27,19 @@ type faviconCacheValue struct {
}
type faviconDiscoveryLogoProvider struct {
mu sync.Mutex
cache map[string]*faviconCacheValue
successTTL time.Duration
failureTTL time.Duration
mu sync.Mutex
cache map[string]*faviconCacheValue
successTTL time.Duration
failureTTL time.Duration
discoveryTimeout time.Duration
}
func newFaviconDiscoveryLogoProvider() *faviconDiscoveryLogoProvider {
return &faviconDiscoveryLogoProvider{
cache: make(map[string]*faviconCacheValue),
successTTL: time.Hour,
failureTTL: 10 * time.Minute,
cache: make(map[string]*faviconCacheValue),
successTTL: time.Hour,
failureTTL: 10 * time.Minute,
discoveryTimeout: 500 * time.Millisecond,
}
}
@ -85,6 +87,9 @@ func (p *faviconDiscoveryLogoProvider) discoverLogoURL(ctx context.Context, rawU
return "", ErrLogoNotFound
}
ctx, clearTimeout := context.WithTimeout(ctx, p.discoveryTimeout)
defer clearTimeout()
req, err := http.NewRequestWithContext(ctx, http.MethodGet, rawURL, nil)
if err != nil {
return "", err

View file

@ -1,6 +1,7 @@
package portal_test
package portal
import (
"context"
"io"
"net/http"
"net/http/httptest"
@ -10,7 +11,6 @@ import (
"github.com/stretchr/testify/assert"
"github.com/pomerium/pomerium/internal/testutil"
"github.com/pomerium/pomerium/proxy/portal"
)
func TestLogoProvider(t *testing.T) {
@ -30,8 +30,27 @@ func TestLogoProvider(t *testing.T) {
t.Cleanup(srv.Close)
ctx := testutil.GetContext(t, time.Minute)
p := portal.NewLogoProvider()
p := NewLogoProvider()
u, err := p.GetLogoURL(ctx, "", srv.URL)
assert.NoError(t, err)
assert.Equal(t, "", u)
}
func TestLogoProvider_Timeout(t *testing.T) {
t.Parallel()
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
select {
case <-r.Context().Done():
case <-time.After(time.Second):
}
http.NotFound(w, r)
}))
t.Cleanup(srv.Close)
ctx := testutil.GetContext(t, time.Minute)
p := newFaviconDiscoveryLogoProvider()
p.discoveryTimeout = time.Millisecond
_, err := p.GetLogoURL(ctx, "", srv.URL)
assert.ErrorIs(t, err, context.DeadlineExceeded)
}