From 3c5c7fbd31ddff743e562be254b7093b7d72c4ec Mon Sep 17 00:00:00 2001 From: Caleb Doxsey Date: Wed, 29 Jan 2025 12:57:26 -0700 Subject: [PATCH] proxy: add logo discovery (#5448) * proxy: add logo discovery * use a static url for testing --- proxy/handlers_portal.go | 29 +++- proxy/portal/logo_provider.go | 230 +++++++++++++++++++++++++++++ proxy/portal/logo_provider_test.go | 37 +++++ proxy/proxy.go | 3 + 4 files changed, 295 insertions(+), 4 deletions(-) create mode 100644 proxy/portal/logo_provider.go create mode 100644 proxy/portal/logo_provider_test.go diff --git a/proxy/handlers_portal.go b/proxy/handlers_portal.go index cfe94593e..66b016577 100644 --- a/proxy/handlers_portal.go +++ b/proxy/handlers_portal.go @@ -1,19 +1,22 @@ package proxy import ( + "context" "encoding/json" + "errors" "net/http" "github.com/pomerium/pomerium/config" "github.com/pomerium/pomerium/internal/handlers" "github.com/pomerium/pomerium/internal/httputil" + "github.com/pomerium/pomerium/internal/log" "github.com/pomerium/pomerium/proxy/portal" "github.com/pomerium/pomerium/ui" ) func (p *Proxy) routesPortalHTML(w http.ResponseWriter, r *http.Request) error { u := p.getUserInfoData(r) - rs := p.getPortalRoutes(u) + rs := p.getPortalRoutes(r.Context(), u) m := u.ToJSON() m["routes"] = rs return ui.ServePage(w, r, "Routes", "Routes Portal", m) @@ -21,7 +24,7 @@ func (p *Proxy) routesPortalHTML(w http.ResponseWriter, r *http.Request) error { func (p *Proxy) routesPortalJSON(w http.ResponseWriter, r *http.Request) error { u := p.getUserInfoData(r) - rs := p.getPortalRoutes(u) + rs := p.getPortalRoutes(r.Context(), u) m := map[string]any{} m["routes"] = rs @@ -36,7 +39,7 @@ func (p *Proxy) routesPortalJSON(w http.ResponseWriter, r *http.Request) error { return nil } -func (p *Proxy) getPortalRoutes(u handlers.UserInfoData) []portal.Route { +func (p *Proxy) getPortalRoutes(ctx context.Context, u handlers.UserInfoData) []portal.Route { options := p.currentOptions.Load() pu := p.getPortalUser(u) var routes []*config.Policy @@ -45,7 +48,25 @@ func (p *Proxy) getPortalRoutes(u handlers.UserInfoData) []portal.Route { routes = append(routes, route) } } - return portal.RoutesFromConfigRoutes(routes) + portalRoutes := portal.RoutesFromConfigRoutes(routes) + for i, pr := range portalRoutes { + r := routes[i] + for _, to := range r.To { + if pr.LogoURL == "" { + var err error + pr.LogoURL, err = p.logoProvider.GetLogoURL(ctx, pr.From, to.URL.String()) + if err != nil && !errors.Is(err, portal.ErrLogoNotFound) { + log.Ctx(ctx).Error(). + Err(err). + Str("from", pr.From). + Str("to", to.URL.String()). + Msg("error retrieving logo for route") + } + } + } + portalRoutes[i] = pr + } + return portalRoutes } func (p *Proxy) getPortalUser(u handlers.UserInfoData) portal.User { diff --git a/proxy/portal/logo_provider.go b/proxy/portal/logo_provider.go new file mode 100644 index 000000000..d42b7d090 --- /dev/null +++ b/proxy/portal/logo_provider.go @@ -0,0 +1,230 @@ +package portal + +import ( + "context" + "encoding/base64" + "errors" + "io" + "iter" + "mime" + "net/http" + "net/url" + "sync" + "time" + + "golang.org/x/net/html" + "golang.org/x/sync/semaphore" + + "github.com/pomerium/pomerium/internal/httputil" + "github.com/pomerium/pomerium/internal/log" + "github.com/pomerium/pomerium/internal/urlutil" +) + +// errors +var ErrLogoNotFound = errors.New("logo not found") + +// A LogoProvider gets logo urls for routes. +type LogoProvider interface { + GetLogoURL(ctx context.Context, from, to string) (string, error) +} + +// NewLogoProvider creates a new LogoProvider. +func NewLogoProvider() LogoProvider { + return newFaviconDiscoveryLogoProvider() +} + +type faviconCacheValue struct { + sem *semaphore.Weighted + url string + err error + expiry time.Time +} + +type faviconDiscoveryLogoProvider struct { + mu sync.Mutex + cache map[string]*faviconCacheValue + successTTL time.Duration + failureTTL time.Duration +} + +func newFaviconDiscoveryLogoProvider() *faviconDiscoveryLogoProvider { + return &faviconDiscoveryLogoProvider{ + cache: make(map[string]*faviconCacheValue), + successTTL: time.Hour, + failureTTL: 10 * time.Minute, + } +} + +func (p *faviconDiscoveryLogoProvider) GetLogoURL(ctx context.Context, _, to string) (string, error) { + p.mu.Lock() + v, ok := p.cache[to] + if !ok { + v = &faviconCacheValue{ + sem: semaphore.NewWeighted(1), + } + p.cache[to] = v + } + p.mu.Unlock() + + // take the semaphore + err := v.sem.Acquire(ctx, 1) + if err != nil { + return "", err + } + defer v.sem.Release(1) + + // if we have a valid cached url or error, return it + if v.expiry.After(time.Now()) { + return v.url, v.err + } + + // attempt to discover the logo url and save the url or the error + v.url, v.err = p.discoverLogoURL(ctx, to) + if v.err == nil { + v.expiry = time.Now().Add(p.successTTL) + } else { + v.expiry = time.Now().Add(p.failureTTL) + } + + return v.url, v.err +} + +func (p *faviconDiscoveryLogoProvider) discoverLogoURL(ctx context.Context, rawURL string) (string, error) { + u, err := urlutil.ParseAndValidateURL(rawURL) + if err != nil { + return "", ErrLogoNotFound + } + + if !(u.Scheme == "http" || u.Scheme == "https") { + return "", ErrLogoNotFound + } + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, rawURL, nil) + if err != nil { + return "", err + } + + t := httputil.GetInsecureTransport() + c := &http.Client{ + Transport: t, + } + + res, err := c.Do(req) + if err != nil { + return "", err + } + defer res.Body.Close() + + // look for any logos in the html + r := io.LimitReader(res.Body, 10*1024) + for link := range findIconLinksInHTML(r) { + linkURL, err := u.Parse(link) + if err != nil { + continue + } + + logoURL := p.fetchLogoURL(ctx, c, linkURL) + if logoURL != "" { + return logoURL, nil + } + } + + // try just the /favicon.ico + logoURL := p.fetchLogoURL(ctx, c, u.ResolveReference(&url.URL{Path: "/favicon.ico"})) + if logoURL != "" { + return logoURL, nil + } + + return "", ErrLogoNotFound +} + +func (p *faviconDiscoveryLogoProvider) fetchLogoURL(ctx context.Context, client *http.Client, logoURL *url.URL) string { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, logoURL.String(), nil) + if err != nil { + return "" + } + + res, err := client.Do(req) + if err != nil { + log.Ctx(ctx).Debug().Str("url", logoURL.String()).Err(err).Msg("error fetching logo contents") + return "" + } + defer res.Body.Close() + + if res.StatusCode/100 != 2 { + log.Ctx(ctx).Debug().Int("status-code", res.StatusCode).Str("url", logoURL.String()).Msg("error fetching logo contents") + return "" + } + + const maxImageSize = 1024 * 1024 + bs, err := io.ReadAll(io.LimitReader(res.Body, maxImageSize)) + if err != nil { + log.Ctx(ctx).Debug().Str("url", logoURL.String()).Err(err).Msg("error reading logo contents") + return "" + } + + // first use the Content-Type header to determine the format + if mtype, _, err := mime.ParseMediaType(res.Header.Get("Content-Type")); err == nil { + if isSupportedImageType(mtype) { + return "data:" + mtype + ";base64," + base64.StdEncoding.EncodeToString(bs) + } + log.Ctx(ctx).Debug().Str("mime-type", mtype).Str("url", logoURL.String()).Msg("rejecting logo") + return "" + } + + // next try to use mimetype sniffing + mtype := http.DetectContentType(bs) + if isSupportedImageType(mtype) { + return "data:" + mtype + ";base64," + base64.StdEncoding.EncodeToString(bs) + } + + log.Ctx(ctx).Debug().Str("mime-type", mtype).Str("url", logoURL.String()).Msg("rejecting logo") + return "" +} + +func isSupportedImageType(mtype string) bool { + return mtype == "image/vnd.microsoft.icon" || + mtype == "image/png" || + mtype == "image/svg+xml" || + mtype == "image/jpeg" || + mtype == "image/gif" +} + +func findIconLinksInHTML(r io.Reader) iter.Seq[string] { + return func(yield func(string) bool) { + z := html.NewTokenizer(r) + for { + tt := z.Next() + if tt == html.ErrorToken { + return + } + + switch tt { + case html.StartTagToken, html.SelfClosingTagToken: + name, attr := parseTag(z) + if name == "link" && attr["href"] != "" && (attr["rel"] == "shortcut icon" || attr["rel"] == "icon") { + if !yield(attr["href"]) { + return + } + } + } + } + } +} + +func parseTag(z *html.Tokenizer) (name string, attributes map[string]string) { + n, hasAttr := z.TagName() + name = string(n) + if !hasAttr { + return name, attributes + } + attributes = make(map[string]string) + for { + k, v, m := z.TagAttr() + attributes[string(k)] = string(v) + if !m { + break + } + } + return name, attributes +} diff --git a/proxy/portal/logo_provider_test.go b/proxy/portal/logo_provider_test.go new file mode 100644 index 000000000..a14244b06 --- /dev/null +++ b/proxy/portal/logo_provider_test.go @@ -0,0 +1,37 @@ +package portal_test + +import ( + "io" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/stretchr/testify/assert" + + "github.com/pomerium/pomerium/internal/testutil" + "github.com/pomerium/pomerium/proxy/portal" +) + +func TestLogoProvider(t *testing.T) { + t.Parallel() + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/icon": + w.Header().Set("Content-Type", "image/vnd.microsoft.icon") + io.WriteString(w, "NOT ACTUALLY AN ICON") + case "/": + io.WriteString(w, ``) + default: + http.NotFound(w, r) + } + })) + t.Cleanup(srv.Close) + + ctx := testutil.GetContext(t, time.Minute) + p := portal.NewLogoProvider() + u, err := p.GetLogoURL(ctx, "", srv.URL) + assert.NoError(t, err) + assert.Equal(t, "", u) +} diff --git a/proxy/proxy.go b/proxy/proxy.go index ea58c63a6..07050271f 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -21,6 +21,7 @@ import ( "github.com/pomerium/pomerium/internal/telemetry/metrics" "github.com/pomerium/pomerium/internal/telemetry/trace" "github.com/pomerium/pomerium/pkg/cryptutil" + "github.com/pomerium/pomerium/proxy/portal" ) const ( @@ -60,6 +61,7 @@ type Proxy struct { currentRouter *atomicutil.Value[*mux.Router] webauthn *webauthn.Handler tracerProvider oteltrace.TracerProvider + logoProvider portal.LogoProvider } // New takes a Proxy service from options and a validation function. @@ -76,6 +78,7 @@ func New(ctx context.Context, cfg *config.Config) (*Proxy, error) { state: atomicutil.NewValue(state), currentOptions: config.NewAtomicOptions(), currentRouter: atomicutil.NewValue(httputil.NewRouter()), + logoProvider: portal.NewLogoProvider(), } p.OnConfigChange(ctx, cfg) p.webauthn = webauthn.New(p.getWebauthnState)