From 7b7ed2add182c351c6f40c2f5a6d22d001e6dfd0 Mon Sep 17 00:00:00 2001 From: Caleb Doxsey Date: Fri, 24 Jan 2025 15:31:54 -0700 Subject: [PATCH] proxy: add logo discovery --- proxy/handlers_portal.go | 29 +++- proxy/portal/logo_provider.go | 230 +++++++++++++++++++++++++++++ proxy/portal/logo_provider_test.go | 21 +++ proxy/proxy.go | 3 + 4 files changed, 279 insertions(+), 4 deletions(-) create mode 100644 proxy/portal/logo_provider.go create mode 100644 proxy/portal/logo_provider_test.go diff --git a/proxy/handlers_portal.go b/proxy/handlers_portal.go index cfe94593e..66b016577 100644 --- a/proxy/handlers_portal.go +++ b/proxy/handlers_portal.go @@ -1,19 +1,22 @@ package proxy import ( + "context" "encoding/json" + "errors" "net/http" "github.com/pomerium/pomerium/config" "github.com/pomerium/pomerium/internal/handlers" "github.com/pomerium/pomerium/internal/httputil" + "github.com/pomerium/pomerium/internal/log" "github.com/pomerium/pomerium/proxy/portal" "github.com/pomerium/pomerium/ui" ) func (p *Proxy) routesPortalHTML(w http.ResponseWriter, r *http.Request) error { u := p.getUserInfoData(r) - rs := p.getPortalRoutes(u) + rs := p.getPortalRoutes(r.Context(), u) m := u.ToJSON() m["routes"] = rs return ui.ServePage(w, r, "Routes", "Routes Portal", m) @@ -21,7 +24,7 @@ func (p *Proxy) routesPortalHTML(w http.ResponseWriter, r *http.Request) error { func (p *Proxy) routesPortalJSON(w http.ResponseWriter, r *http.Request) error { u := p.getUserInfoData(r) - rs := p.getPortalRoutes(u) + rs := p.getPortalRoutes(r.Context(), u) m := map[string]any{} m["routes"] = rs @@ -36,7 +39,7 @@ func (p *Proxy) routesPortalJSON(w http.ResponseWriter, r *http.Request) error { return nil } -func (p *Proxy) getPortalRoutes(u handlers.UserInfoData) []portal.Route { +func (p *Proxy) getPortalRoutes(ctx context.Context, u handlers.UserInfoData) []portal.Route { options := p.currentOptions.Load() pu := p.getPortalUser(u) var routes []*config.Policy @@ -45,7 +48,25 @@ func (p *Proxy) getPortalRoutes(u handlers.UserInfoData) []portal.Route { routes = append(routes, route) } } - return portal.RoutesFromConfigRoutes(routes) + portalRoutes := portal.RoutesFromConfigRoutes(routes) + for i, pr := range portalRoutes { + r := routes[i] + for _, to := range r.To { + if pr.LogoURL == "" { + var err error + pr.LogoURL, err = p.logoProvider.GetLogoURL(ctx, pr.From, to.URL.String()) + if err != nil && !errors.Is(err, portal.ErrLogoNotFound) { + log.Ctx(ctx).Error(). + Err(err). + Str("from", pr.From). + Str("to", to.URL.String()). + Msg("error retrieving logo for route") + } + } + } + portalRoutes[i] = pr + } + return portalRoutes } func (p *Proxy) getPortalUser(u handlers.UserInfoData) portal.User { diff --git a/proxy/portal/logo_provider.go b/proxy/portal/logo_provider.go new file mode 100644 index 000000000..9cf63e360 --- /dev/null +++ b/proxy/portal/logo_provider.go @@ -0,0 +1,230 @@ +package portal + +import ( + "context" + "encoding/base64" + "errors" + "io" + "iter" + "mime" + "net/http" + "net/url" + "sync" + "time" + + "golang.org/x/net/html" + "golang.org/x/sync/semaphore" + + "github.com/pomerium/pomerium/internal/httputil" + "github.com/pomerium/pomerium/internal/log" + "github.com/pomerium/pomerium/internal/urlutil" +) + +// errors +var ErrLogoNotFound = errors.New("logo not found") + +// A LogoProvider gets logo urls for routes. +type LogoProvider interface { + GetLogoURL(ctx context.Context, from, to string) (string, error) +} + +// NewLogoProvider creates a new LogoProvider. +func NewLogoProvider() LogoProvider { + return newFaviconDiscoveryLogoProvider() +} + +type faviconCacheValue struct { + sem *semaphore.Weighted + url string + err error + expiry time.Time +} + +type faviconDiscoveryLogoProvider struct { + mu sync.Mutex + cache map[string]*faviconCacheValue + successTTL time.Duration + failureTTL time.Duration +} + +func newFaviconDiscoveryLogoProvider() *faviconDiscoveryLogoProvider { + return &faviconDiscoveryLogoProvider{ + cache: make(map[string]*faviconCacheValue), + successTTL: time.Hour, + failureTTL: 10 * time.Minute, + } +} + +func (p *faviconDiscoveryLogoProvider) GetLogoURL(ctx context.Context, _, to string) (string, error) { + p.mu.Lock() + v, ok := p.cache[to] + if !ok { + v = &faviconCacheValue{ + sem: semaphore.NewWeighted(1), + } + p.cache[to] = v + } + p.mu.Unlock() + + // take the semaphore + err := v.sem.Acquire(ctx, 1) + if err != nil { + return "", err + } + defer v.sem.Release(1) + + // if we have a valid cached url or error, return it + if v.expiry.After(time.Now()) { + return v.url, v.err + } + + // attempt to discover the logo url and save the url or the error + v.url, v.err = p.discoverLogoURL(ctx, to) + if v.err == nil { + v.expiry = time.Now().Add(p.successTTL) + } else { + v.expiry = time.Now().Add(p.failureTTL) + } + + return v.url, v.err +} + +func (p *faviconDiscoveryLogoProvider) discoverLogoURL(ctx context.Context, rawURL string) (string, error) { + u, err := urlutil.ParseAndValidateURL(rawURL) + if err != nil { + return "", ErrLogoNotFound + } + + if !(u.Scheme == "http" || u.Scheme == "https") { + return "", ErrLogoNotFound + } + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, rawURL, nil) + if err != nil { + return "", err + } + + t := httputil.GetInsecureTransport() + c := &http.Client{ + Transport: t, + } + + res, err := c.Do(req) + if err != nil { + return "", err + } + defer res.Body.Close() + + // look for any logos in the html + r := io.LimitReader(res.Body, 10*1024) + for link := range findIconLinksInHTML(r) { + linkURL, err := urlutil.ParseAndValidateURL(link) + if err != nil { + continue + } + + logoURL := p.fetchLogoURL(ctx, c, u.ResolveReference(linkURL)) + if logoURL != "" { + return logoURL, nil + } + } + + // try just the /favicon.ico + logoURL := p.fetchLogoURL(ctx, c, u.ResolveReference(&url.URL{Path: "/favicon.ico"})) + if logoURL != "" { + return logoURL, nil + } + + return "", ErrLogoNotFound +} + +func (p *faviconDiscoveryLogoProvider) fetchLogoURL(ctx context.Context, client *http.Client, logoURL *url.URL) string { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, logoURL.String(), nil) + if err != nil { + return "" + } + + res, err := client.Do(req) + if err != nil { + log.Ctx(ctx).Debug().Str("url", logoURL.String()).Err(err).Msg("error fetching logo contents") + return "" + } + defer res.Body.Close() + + if res.StatusCode/100 != 2 { + log.Ctx(ctx).Debug().Int("status-code", res.StatusCode).Str("url", logoURL.String()).Msg("error fetching logo contents") + return "" + } + + const maxImageSize = 1024 * 1024 + bs, err := io.ReadAll(io.LimitReader(res.Body, maxImageSize)) + if err != nil { + log.Ctx(ctx).Debug().Str("url", logoURL.String()).Err(err).Msg("error reading logo contents") + return "" + } + + // first use the Content-Type header to determine the format + if mtype, _, err := mime.ParseMediaType(res.Header.Get("Content-Type")); err == nil { + if isSupportedImageType(mtype) { + return "data:" + mtype + ";base64," + base64.StdEncoding.EncodeToString(bs) + } + log.Ctx(ctx).Debug().Str("mime-type", mtype).Str("url", logoURL.String()).Msg("rejecting logo") + return "" + } + + // next try to use mimetype sniffing + mtype := http.DetectContentType(bs) + if isSupportedImageType(mtype) { + return "data:" + mtype + ";base64," + base64.StdEncoding.EncodeToString(bs) + } + + log.Ctx(ctx).Debug().Str("mime-type", mtype).Str("url", logoURL.String()).Msg("rejecting logo") + return "" +} + +func isSupportedImageType(mtype string) bool { + return mtype == "image/vnd.microsoft.icon" || + mtype == "image/png" || + mtype == "image/svg+xml" || + mtype == "image/jpeg" || + mtype == "image/gif" +} + +func findIconLinksInHTML(r io.Reader) iter.Seq[string] { + return func(yield func(string) bool) { + z := html.NewTokenizer(r) + for { + tt := z.Next() + if tt == html.ErrorToken { + return + } + + switch tt { + case html.StartTagToken: + name, attr := parseTag(z) + if name == "link" && attr["href"] != "" && (attr["rel"] == "shortcut icon" || attr["rel"] == "icon") { + if !yield(attr["href"]) { + return + } + } + } + } + } +} + +func parseTag(z *html.Tokenizer) (name string, attributes map[string]string) { + n, hasAttr := z.TagName() + name = string(n) + if !hasAttr { + return name, attributes + } + attributes = make(map[string]string) + for { + k, v, m := z.TagAttr() + attributes[string(k)] = string(v) + if !m { + break + } + } + return name, attributes +} diff --git a/proxy/portal/logo_provider_test.go b/proxy/portal/logo_provider_test.go new file mode 100644 index 000000000..28d35723d --- /dev/null +++ b/proxy/portal/logo_provider_test.go @@ -0,0 +1,21 @@ +package portal_test + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" + + "github.com/pomerium/pomerium/internal/testutil" + "github.com/pomerium/pomerium/proxy/portal" +) + +func TestLogoProvider(t *testing.T) { + t.Parallel() + + ctx := testutil.GetContext(t, time.Minute) + p := portal.NewLogoProvider() + u, err := p.GetLogoURL(ctx, "", "https://www.wikipedia.org") + assert.NoError(t, err) + assert.Equal(t, "data:image/vnd.microsoft.icon;base64,AAABAAMAMDAQAAEABABoBgAANgAAACAgEAABAAQA6AIAAJ4GAAAQEBAAAQAEACgBAACGCQAAKAAAADAAAABgAAAAAQAEAAAAAAAABgAAAAAAAAAAAAAQAAAAAAAAAAEBAQAXFxcAMDAwAEdHRwBYWFgAZ2dnAHZ2dgCHh4cAlZWVAKmpqQC3t7cAx8fHANfX1wDo6OgA_v7-AAAAAAD____-7u7u7u7u7u7u7u7u7u7u7u_______-7u7u7u7u7u7u7u7u7u7u7u7u7u_____u7u7u7u7u7u7u7u7u7u7u7u7u7u7___7u7u7u7u7u7u7u7u7u7u7u7u7u7u7v_-7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u_-7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u_-7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u_u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7sa-7u7u7u1b7u7u7u7u7u7u7u7u7u7u7p9u7u7u7ugG7u7u7u7u7u7u7u7u7u7u7TAa7u7u7tQBzu7u7u7u7u7u7u7u7u7u6wAF7u7u7pAAju7u7u7u7u7u7u7u7u7u1AAAru7u7U__Le7u7u7u7u7u7u7u7u7uz_8RPe7u6gAB-e7u7u7u7u7u7u7u7u7ubw94Ce7u1QAIIu7u7u7u7u7u7u7u7u7tH_G-Mt7usAAtcL7u7u7u7u7u7u7u7u7n8ATun47uQACO0T7u7u7u7u7u7u7u7u7hDxnu4x3sAPLO5Qzu7u7u7u7u7u7u7u6P_z7u6wXk_wfu7ATu7u7u7u7u7u7u7u4QAY7u7kCQADzu7kDO7u7u7u7u7u7u7uoA8u7u7sAAAG7u7r9e7u7u7u7u7u7u7uIPB-7u7uUAAs7u7uMd7u7u7u7u7u7u7rEAHe7u7uQABu7u7un37u7u7u7u7u7u7kAAXu7u7sAPHe7u7u4S3u7u7u7u7u7u7BAA3u7u7k8AHO7u7u6Aju7u7u7u7u7u5g_07u7u7B8BBe7u7u7RLu7u7u7u7u7u0v_87u7u5QAGQa7u7u7nCe7u7u7u7u7ugAA-7u7uwQ8dsE7u7u7rBO7u7u7u7u7tP_--7u7uYAB-5Qnu7u7tQa7u7u7u7u7pH_Lu7u7sLwHe6xPe7u7ur27u7u7u7u7V__ru7u7mAAju7n-e7u7u0yvu7u7u7u6h8C3u7u6yAB3u7rEs7u7u6Pfu7u7u7u1AAE7u7u5g_27u7tQG3u7u6QHO7u7u7tbwAB3u7ukfAH7u7sIAju7u5wA97u7utiAAAAF76lAA_wWeyDAA84zqUAABfO7uMiNERDIm4iNERDIrkiNEQybiI0RDJO7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7-7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u_-7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u_-7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u__7u7u7u7u7u7u7u7u7u7u7u7u7u7u7v___u7u7u7u7u7u7u7u7u7u7u7u7u7u7____-7u7u7u7u7u7u7u7u7u7u7u7u7u_______-7u7u7u7u7u7u7u7u7u7u7u_____-AAAAAH8AAPAAAAAADwAA4AAAAAAHAADAAAAAAAMAAIAAAAAAAQAAgAAAAAABAACAAAAAAAEAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAgAAAAAABAACAAAAAAAEAAIAAAAAAAQAAwAAAAAADAADgAAAAAAcAAPAAAAAADwAA_gAAAAB_AAAoAAAAIAAAAEAAAAABAAQAAAAAAIACAAAAAAAAAAAAABAAAAAAAAAAAQEBABYWFgAnJycANTU1AEdHRwBZWVkAZWVlAHh4eACIiIgAmZmZAK6urgDMzMwA19fXAOnp6QD-_v4AAAAAAP__7u7u7u7u7u7u7u7u____7u7u7u7u7u7u7u7u7u7__u7u7u7u7u7u7u7u7u7u7_7u7u7u7u7u7u7u7u7u7u_u7u7u7u7u7u7u7u7u7u7u7u7u7u7X3u7u7I7u7u7u7u7u7u7uYF7u7uIK7u7u7u7u7u7u7QAM7u6vBO7u7u7u7u7u7ucABe7uMA_O7u7u7u7u7u7R8q_O6gCEbu7u7u7u7u7ukAnibuTx6g3u7u7u7u7u7hAe6gzP-O4Y7u7u7u7u7urwju4mXx7uge7u7u7u7u7jAd7uoACO7tCe7u7u7u7uoPfu7uEB3u7mPu7u7u7u7k8N7u7QBu7u6wru7u7u7uwAXu7ufwbu7u407u7u7u7lAM7u7RBQzu7ur87u7u7u0ATu7ucA0l7u7uFu7u7u7n_67u7RB-oL7u7nHe7u7u0fPu7ucA3uJO7u7Qju7u7o_67u7Q9u7q-u7u5R3u7u0Q_e7ub_vu7PLO7uX13u4w__Be4v_xnoH_-ekv__Xu7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7-7u7u7u7u7u7u7u7u7u7v_u7u7u7u7u7u7u7u7u7u7__u7u7u7u7u7u7u7u7u7v___-7u7u7u7u7u7u7u7v__8AAAD8AAAAOAAAABgAAAAQAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAIAAAAGAAAABwAAAA_AAAA8oAAAAEAAAACAAAAABAAQAAAAAAMAAAAAAAAAAAAAAABAAAAAAAAAAAQEBABcXFwAnJycAOzs7AElJSQBpaWkAeXl5AIaGhgCVlZUApqamALOzswDMzMwA2dnZAObm5gD-_v4AAAAAAP_u7u7u7u7__u7u7u7u7u_u7uzu7t7u7u7u4Y7lTu7u7u6QTtA77u7u7iaoctXu7u7qDOQZ5d7u7uRO5R7rbu7uv77iLu5O7u5D7pGn7pju7QrtKOTe4-6z-OT40z2RTO7u7u7u7u7u7u7u7u7u7u7-7u7u7u7u7__u7u7u7u7_wAMAD4ABAA8AAAAPAAAADwAAAA8AAAAPAAAADwAAAA8AAAAPAAAADwAAAA8AAAAPAAAADwAAAA-AAQAPwAMADw==", u) +} diff --git a/proxy/proxy.go b/proxy/proxy.go index ea58c63a6..07050271f 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -21,6 +21,7 @@ import ( "github.com/pomerium/pomerium/internal/telemetry/metrics" "github.com/pomerium/pomerium/internal/telemetry/trace" "github.com/pomerium/pomerium/pkg/cryptutil" + "github.com/pomerium/pomerium/proxy/portal" ) const ( @@ -60,6 +61,7 @@ type Proxy struct { currentRouter *atomicutil.Value[*mux.Router] webauthn *webauthn.Handler tracerProvider oteltrace.TracerProvider + logoProvider portal.LogoProvider } // New takes a Proxy service from options and a validation function. @@ -76,6 +78,7 @@ func New(ctx context.Context, cfg *config.Config) (*Proxy, error) { state: atomicutil.NewValue(state), currentOptions: config.NewAtomicOptions(), currentRouter: atomicutil.NewValue(httputil.NewRouter()), + logoProvider: portal.NewLogoProvider(), } p.OnConfigChange(ctx, cfg) p.webauthn = webauthn.New(p.getWebauthnState)