proxy: add logo discovery

This commit is contained in:
Caleb Doxsey 2025-01-24 15:31:54 -07:00
parent 332d3dc334
commit 7b7ed2add1
4 changed files with 279 additions and 4 deletions

View file

@ -1,19 +1,22 @@
package proxy package proxy
import ( import (
"context"
"encoding/json" "encoding/json"
"errors"
"net/http" "net/http"
"github.com/pomerium/pomerium/config" "github.com/pomerium/pomerium/config"
"github.com/pomerium/pomerium/internal/handlers" "github.com/pomerium/pomerium/internal/handlers"
"github.com/pomerium/pomerium/internal/httputil" "github.com/pomerium/pomerium/internal/httputil"
"github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/proxy/portal" "github.com/pomerium/pomerium/proxy/portal"
"github.com/pomerium/pomerium/ui" "github.com/pomerium/pomerium/ui"
) )
func (p *Proxy) routesPortalHTML(w http.ResponseWriter, r *http.Request) error { func (p *Proxy) routesPortalHTML(w http.ResponseWriter, r *http.Request) error {
u := p.getUserInfoData(r) u := p.getUserInfoData(r)
rs := p.getPortalRoutes(u) rs := p.getPortalRoutes(r.Context(), u)
m := u.ToJSON() m := u.ToJSON()
m["routes"] = rs m["routes"] = rs
return ui.ServePage(w, r, "Routes", "Routes Portal", m) return ui.ServePage(w, r, "Routes", "Routes Portal", m)
@ -21,7 +24,7 @@ func (p *Proxy) routesPortalHTML(w http.ResponseWriter, r *http.Request) error {
func (p *Proxy) routesPortalJSON(w http.ResponseWriter, r *http.Request) error { func (p *Proxy) routesPortalJSON(w http.ResponseWriter, r *http.Request) error {
u := p.getUserInfoData(r) u := p.getUserInfoData(r)
rs := p.getPortalRoutes(u) rs := p.getPortalRoutes(r.Context(), u)
m := map[string]any{} m := map[string]any{}
m["routes"] = rs m["routes"] = rs
@ -36,7 +39,7 @@ func (p *Proxy) routesPortalJSON(w http.ResponseWriter, r *http.Request) error {
return nil return nil
} }
func (p *Proxy) getPortalRoutes(u handlers.UserInfoData) []portal.Route { func (p *Proxy) getPortalRoutes(ctx context.Context, u handlers.UserInfoData) []portal.Route {
options := p.currentOptions.Load() options := p.currentOptions.Load()
pu := p.getPortalUser(u) pu := p.getPortalUser(u)
var routes []*config.Policy var routes []*config.Policy
@ -45,7 +48,25 @@ func (p *Proxy) getPortalRoutes(u handlers.UserInfoData) []portal.Route {
routes = append(routes, route) routes = append(routes, route)
} }
} }
return portal.RoutesFromConfigRoutes(routes) portalRoutes := portal.RoutesFromConfigRoutes(routes)
for i, pr := range portalRoutes {
r := routes[i]
for _, to := range r.To {
if pr.LogoURL == "" {
var err error
pr.LogoURL, err = p.logoProvider.GetLogoURL(ctx, pr.From, to.URL.String())
if err != nil && !errors.Is(err, portal.ErrLogoNotFound) {
log.Ctx(ctx).Error().
Err(err).
Str("from", pr.From).
Str("to", to.URL.String()).
Msg("error retrieving logo for route")
}
}
}
portalRoutes[i] = pr
}
return portalRoutes
} }
func (p *Proxy) getPortalUser(u handlers.UserInfoData) portal.User { func (p *Proxy) getPortalUser(u handlers.UserInfoData) portal.User {

View file

@ -0,0 +1,230 @@
package portal
import (
"context"
"encoding/base64"
"errors"
"io"
"iter"
"mime"
"net/http"
"net/url"
"sync"
"time"
"golang.org/x/net/html"
"golang.org/x/sync/semaphore"
"github.com/pomerium/pomerium/internal/httputil"
"github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/urlutil"
)
// errors
var ErrLogoNotFound = errors.New("logo not found")
// A LogoProvider gets logo urls for routes.
type LogoProvider interface {
GetLogoURL(ctx context.Context, from, to string) (string, error)
}
// NewLogoProvider creates a new LogoProvider.
func NewLogoProvider() LogoProvider {
return newFaviconDiscoveryLogoProvider()
}
type faviconCacheValue struct {
sem *semaphore.Weighted
url string
err error
expiry time.Time
}
type faviconDiscoveryLogoProvider struct {
mu sync.Mutex
cache map[string]*faviconCacheValue
successTTL time.Duration
failureTTL time.Duration
}
func newFaviconDiscoveryLogoProvider() *faviconDiscoveryLogoProvider {
return &faviconDiscoveryLogoProvider{
cache: make(map[string]*faviconCacheValue),
successTTL: time.Hour,
failureTTL: 10 * time.Minute,
}
}
func (p *faviconDiscoveryLogoProvider) GetLogoURL(ctx context.Context, _, to string) (string, error) {
p.mu.Lock()
v, ok := p.cache[to]
if !ok {
v = &faviconCacheValue{
sem: semaphore.NewWeighted(1),
}
p.cache[to] = v
}
p.mu.Unlock()
// take the semaphore
err := v.sem.Acquire(ctx, 1)
if err != nil {
return "", err
}
defer v.sem.Release(1)
// if we have a valid cached url or error, return it
if v.expiry.After(time.Now()) {
return v.url, v.err
}
// attempt to discover the logo url and save the url or the error
v.url, v.err = p.discoverLogoURL(ctx, to)
if v.err == nil {
v.expiry = time.Now().Add(p.successTTL)
} else {
v.expiry = time.Now().Add(p.failureTTL)
}
return v.url, v.err
}
func (p *faviconDiscoveryLogoProvider) discoverLogoURL(ctx context.Context, rawURL string) (string, error) {
u, err := urlutil.ParseAndValidateURL(rawURL)
if err != nil {
return "", ErrLogoNotFound
}
if !(u.Scheme == "http" || u.Scheme == "https") {
return "", ErrLogoNotFound
}
req, err := http.NewRequestWithContext(ctx, http.MethodGet, rawURL, nil)
if err != nil {
return "", err
}
t := httputil.GetInsecureTransport()
c := &http.Client{
Transport: t,
}
res, err := c.Do(req)
if err != nil {
return "", err
}
defer res.Body.Close()
// look for any logos in the html
r := io.LimitReader(res.Body, 10*1024)
for link := range findIconLinksInHTML(r) {
linkURL, err := urlutil.ParseAndValidateURL(link)
if err != nil {
continue
}
logoURL := p.fetchLogoURL(ctx, c, u.ResolveReference(linkURL))
if logoURL != "" {
return logoURL, nil
}
}
// try just the /favicon.ico
logoURL := p.fetchLogoURL(ctx, c, u.ResolveReference(&url.URL{Path: "/favicon.ico"}))
if logoURL != "" {
return logoURL, nil
}
return "", ErrLogoNotFound
}
func (p *faviconDiscoveryLogoProvider) fetchLogoURL(ctx context.Context, client *http.Client, logoURL *url.URL) string {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, logoURL.String(), nil)
if err != nil {
return ""
}
res, err := client.Do(req)
if err != nil {
log.Ctx(ctx).Debug().Str("url", logoURL.String()).Err(err).Msg("error fetching logo contents")
return ""
}
defer res.Body.Close()
if res.StatusCode/100 != 2 {
log.Ctx(ctx).Debug().Int("status-code", res.StatusCode).Str("url", logoURL.String()).Msg("error fetching logo contents")
return ""
}
const maxImageSize = 1024 * 1024
bs, err := io.ReadAll(io.LimitReader(res.Body, maxImageSize))
if err != nil {
log.Ctx(ctx).Debug().Str("url", logoURL.String()).Err(err).Msg("error reading logo contents")
return ""
}
// first use the Content-Type header to determine the format
if mtype, _, err := mime.ParseMediaType(res.Header.Get("Content-Type")); err == nil {
if isSupportedImageType(mtype) {
return "data:" + mtype + ";base64," + base64.StdEncoding.EncodeToString(bs)
}
log.Ctx(ctx).Debug().Str("mime-type", mtype).Str("url", logoURL.String()).Msg("rejecting logo")
return ""
}
// next try to use mimetype sniffing
mtype := http.DetectContentType(bs)
if isSupportedImageType(mtype) {
return "data:" + mtype + ";base64," + base64.StdEncoding.EncodeToString(bs)
}
log.Ctx(ctx).Debug().Str("mime-type", mtype).Str("url", logoURL.String()).Msg("rejecting logo")
return ""
}
func isSupportedImageType(mtype string) bool {
return mtype == "image/vnd.microsoft.icon" ||
mtype == "image/png" ||
mtype == "image/svg+xml" ||
mtype == "image/jpeg" ||
mtype == "image/gif"
}
func findIconLinksInHTML(r io.Reader) iter.Seq[string] {
return func(yield func(string) bool) {
z := html.NewTokenizer(r)
for {
tt := z.Next()
if tt == html.ErrorToken {
return
}
switch tt {
case html.StartTagToken:
name, attr := parseTag(z)
if name == "link" && attr["href"] != "" && (attr["rel"] == "shortcut icon" || attr["rel"] == "icon") {
if !yield(attr["href"]) {
return
}
}
}
}
}
}
func parseTag(z *html.Tokenizer) (name string, attributes map[string]string) {
n, hasAttr := z.TagName()
name = string(n)
if !hasAttr {
return name, attributes
}
attributes = make(map[string]string)
for {
k, v, m := z.TagAttr()
attributes[string(k)] = string(v)
if !m {
break
}
}
return name, attributes
}

View file

@ -0,0 +1,21 @@
package portal_test
import (
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/pomerium/pomerium/internal/testutil"
"github.com/pomerium/pomerium/proxy/portal"
)
func TestLogoProvider(t *testing.T) {
t.Parallel()
ctx := testutil.GetContext(t, time.Minute)
p := portal.NewLogoProvider()
u, err := p.GetLogoURL(ctx, "", "https://www.wikipedia.org")
assert.NoError(t, err)
assert.Equal(t, "_v7-AAAAAAD____-7u7u7u7u7u7u7u7u7u7u7u_______-7u7u7u7u7u7u7u7u7u7u7u7u7u_____u7u7u7u7u7u7u7u7u7u7u7u7u7u7___7u7u7u7u7u7u7u7u7u7u7u7u7u7u7v_-7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u_-7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u_-7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u_u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7sa-7u7u7u1b7u7u7u7u7u7u7u7u7u7u7p9u7u7u7ugG7u7u7u7u7u7u7u7u7u7u7TAa7u7u7tQBzu7u7u7u7u7u7u7u7u7u6wAF7u7u7pAAju7u7u7u7u7u7u7u7u7u1AAAru7u7U__Le7u7u7u7u7u7u7u7u7uz_8RPe7u6gAB-e7u7u7u7u7u7u7u7u7ubw94Ce7u1QAIIu7u7u7u7u7u7u7u7u7tH_G-Mt7usAAtcL7u7u7u7u7u7u7u7u7n8ATun47uQACO0T7u7u7u7u7u7u7u7u7hDxnu4x3sAPLO5Qzu7u7u7u7u7u7u7u6P_z7u6wXk_wfu7ATu7u7u7u7u7u7u7u4QAY7u7kCQADzu7kDO7u7u7u7u7u7u7uoA8u7u7sAAAG7u7r9e7u7u7u7u7u7u7uIPB-7u7uUAAs7u7uMd7u7u7u7u7u7u7rEAHe7u7uQABu7u7un37u7u7u7u7u7u7kAAXu7u7sAPHe7u7u4S3u7u7u7u7u7u7BAA3u7u7k8AHO7u7u6Aju7u7u7u7u7u5g_07u7u7B8BBe7u7u7RLu7u7u7u7u7u0v_87u7u5QAGQa7u7u7nCe7u7u7u7u7ugAA-7u7uwQ8dsE7u7u7rBO7u7u7u7u7tP_--7u7uYAB-5Qnu7u7tQa7u7u7u7u7pH_Lu7u7sLwHe6xPe7u7ur27u7u7u7u7V__ru7u7mAAju7n-e7u7u0yvu7u7u7u6h8C3u7u6yAB3u7rEs7u7u6Pfu7u7u7u1AAE7u7u5g_27u7tQG3u7u6QHO7u7u7tbwAB3u7ukfAH7u7sIAju7u5wA97u7utiAAAAF76lAA_wWeyDAA84zqUAABfO7uMiNERDIm4iNERDIrkiNEQybiI0RDJO7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7-7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u_-7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u_-7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u__7u7u7u7u7u7u7u7u7u7u7u7u7u7u7v___u7u7u7u7u7u7u7u7u7u7u7u7u7u7____-7u7u7u7u7u7u7u7u7u7u7u7u7u_______-7u7u7u7u7u7u7u7u7u7u7u_____-AAAAAH8AAPAAAAAADwAA4AAAAAAHAADAAAAAAAMAAIAAAAAAAQAAgAAAAAABAACAAAAAAAEAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAgAAAAAABAACAAAAAAAEAAIAAAAAAAQAAwAAAAAADAADgAAAAAAcAAPAAAAAADwAA_gAAAAB_AAAoAAAAIAAAAEAAAAABAAQAAAAAAIACAAAAAAAAAAAAABAAAAAAAAAAAQEBABYWFgAnJycANTU1AEdHRwBZWVkAZWVlAHh4eACIiIgAmZmZAK6urgDMzMwA19fXAOnp6QD-_v4AAAAAAP__7u7u7u7u7u7u7u7u____7u7u7u7u7u7u7u7u7u7__u7u7u7u7u7u7u7u7u7u7_7u7u7u7u7u7u7u7u7u7u_u7u7u7u7u7u7u7u7u7u7u7u7u7u7X3u7u7I7u7u7u7u7u7u7uYF7u7uIK7u7u7u7u7u7u7QAM7u6vBO7u7u7u7u7u7ucABe7uMA_O7u7u7u7u7u7R8q_O6gCEbu7u7u7u7u7ukAnibuTx6g3u7u7u7u7u7hAe6gzP-O4Y7u7u7u7u7urwju4mXx7uge7u7u7u7u7jAd7uoACO7tCe7u7u7u7uoPfu7uEB3u7mPu7u7u7u7k8N7u7QBu7u6wru7u7u7uwAXu7ufwbu7u407u7u7u7lAM7u7RBQzu7ur87u7u7u0ATu7ucA0l7u7uFu7u7u7n_67u7RB-oL7u7nHe7u7u0fPu7ucA3uJO7u7Qju7u7o_67u7Q9u7q-u7u5R3u7u0Q_e7ub_vu7PLO7uX13u4w__Be4v_xnoH_-ekv__Xu7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7-7u7u7u7u7u7u7u7u7u7v_u7u7u7u7u7u7u7u7u7u7__u7u7u7u7u7u7u7u7u7v___-7u7u7u7u7u7u7u7v__8AAAD8AAAAOAAAABgAAAAQAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAIAAAAGAAAABwAAAA_AAAA8oAAAAEAAAACAAAAABAAQAAAAAAMAAAAAAAAAAAAAAABAAAAAAAAAAAQEBABcXFwAnJycAOzs7AElJSQBpaWkAeXl5AIaGhgCVlZUApqamALOzswDMzMwA2dnZAObm5gD-_v4AAAAAAP_u7u7u7u7__u7u7u7u7u_u7uzu7t7u7u7u4Y7lTu7u7u6QTtA77u7u7iaoctXu7u7qDOQZ5d7u7uRO5R7rbu7uv77iLu5O7u5D7pGn7pju7QrtKOTe4-6z-OT40z2RTO7u7u7u7u7u7u7u7u7u7u7-7u7u7u7u7__u7u7u7u7_wAMAD4ABAA8AAAAPAAAADwAAAA8AAAAPAAAADwAAAA8AAAAPAAAADwAAAA8AAAAPAAAADwAAAA-AAQAPwAMADw==", u)
}

View file

@ -21,6 +21,7 @@ import (
"github.com/pomerium/pomerium/internal/telemetry/metrics" "github.com/pomerium/pomerium/internal/telemetry/metrics"
"github.com/pomerium/pomerium/internal/telemetry/trace" "github.com/pomerium/pomerium/internal/telemetry/trace"
"github.com/pomerium/pomerium/pkg/cryptutil" "github.com/pomerium/pomerium/pkg/cryptutil"
"github.com/pomerium/pomerium/proxy/portal"
) )
const ( const (
@ -60,6 +61,7 @@ type Proxy struct {
currentRouter *atomicutil.Value[*mux.Router] currentRouter *atomicutil.Value[*mux.Router]
webauthn *webauthn.Handler webauthn *webauthn.Handler
tracerProvider oteltrace.TracerProvider tracerProvider oteltrace.TracerProvider
logoProvider portal.LogoProvider
} }
// New takes a Proxy service from options and a validation function. // New takes a Proxy service from options and a validation function.
@ -76,6 +78,7 @@ func New(ctx context.Context, cfg *config.Config) (*Proxy, error) {
state: atomicutil.NewValue(state), state: atomicutil.NewValue(state),
currentOptions: config.NewAtomicOptions(), currentOptions: config.NewAtomicOptions(),
currentRouter: atomicutil.NewValue(httputil.NewRouter()), currentRouter: atomicutil.NewValue(httputil.NewRouter()),
logoProvider: portal.NewLogoProvider(),
} }
p.OnConfigChange(ctx, cfg) p.OnConfigChange(ctx, cfg)
p.webauthn = webauthn.New(p.getWebauthnState) p.webauthn = webauthn.New(p.getWebauthnState)