mirror of
https://github.com/pomerium/pomerium.git
synced 2025-06-01 02:12:50 +02:00
proxy: add logo discovery (#5448)
* proxy: add logo discovery * use a static url for testing
This commit is contained in:
parent
936bd28ae4
commit
3c5c7fbd31
4 changed files with 295 additions and 4 deletions
|
@ -1,19 +1,22 @@
|
||||||
package proxy
|
package proxy
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
"github.com/pomerium/pomerium/config"
|
"github.com/pomerium/pomerium/config"
|
||||||
"github.com/pomerium/pomerium/internal/handlers"
|
"github.com/pomerium/pomerium/internal/handlers"
|
||||||
"github.com/pomerium/pomerium/internal/httputil"
|
"github.com/pomerium/pomerium/internal/httputil"
|
||||||
|
"github.com/pomerium/pomerium/internal/log"
|
||||||
"github.com/pomerium/pomerium/proxy/portal"
|
"github.com/pomerium/pomerium/proxy/portal"
|
||||||
"github.com/pomerium/pomerium/ui"
|
"github.com/pomerium/pomerium/ui"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (p *Proxy) routesPortalHTML(w http.ResponseWriter, r *http.Request) error {
|
func (p *Proxy) routesPortalHTML(w http.ResponseWriter, r *http.Request) error {
|
||||||
u := p.getUserInfoData(r)
|
u := p.getUserInfoData(r)
|
||||||
rs := p.getPortalRoutes(u)
|
rs := p.getPortalRoutes(r.Context(), u)
|
||||||
m := u.ToJSON()
|
m := u.ToJSON()
|
||||||
m["routes"] = rs
|
m["routes"] = rs
|
||||||
return ui.ServePage(w, r, "Routes", "Routes Portal", m)
|
return ui.ServePage(w, r, "Routes", "Routes Portal", m)
|
||||||
|
@ -21,7 +24,7 @@ func (p *Proxy) routesPortalHTML(w http.ResponseWriter, r *http.Request) error {
|
||||||
|
|
||||||
func (p *Proxy) routesPortalJSON(w http.ResponseWriter, r *http.Request) error {
|
func (p *Proxy) routesPortalJSON(w http.ResponseWriter, r *http.Request) error {
|
||||||
u := p.getUserInfoData(r)
|
u := p.getUserInfoData(r)
|
||||||
rs := p.getPortalRoutes(u)
|
rs := p.getPortalRoutes(r.Context(), u)
|
||||||
m := map[string]any{}
|
m := map[string]any{}
|
||||||
m["routes"] = rs
|
m["routes"] = rs
|
||||||
|
|
||||||
|
@ -36,7 +39,7 @@ func (p *Proxy) routesPortalJSON(w http.ResponseWriter, r *http.Request) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *Proxy) getPortalRoutes(u handlers.UserInfoData) []portal.Route {
|
func (p *Proxy) getPortalRoutes(ctx context.Context, u handlers.UserInfoData) []portal.Route {
|
||||||
options := p.currentOptions.Load()
|
options := p.currentOptions.Load()
|
||||||
pu := p.getPortalUser(u)
|
pu := p.getPortalUser(u)
|
||||||
var routes []*config.Policy
|
var routes []*config.Policy
|
||||||
|
@ -45,7 +48,25 @@ func (p *Proxy) getPortalRoutes(u handlers.UserInfoData) []portal.Route {
|
||||||
routes = append(routes, route)
|
routes = append(routes, route)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return portal.RoutesFromConfigRoutes(routes)
|
portalRoutes := portal.RoutesFromConfigRoutes(routes)
|
||||||
|
for i, pr := range portalRoutes {
|
||||||
|
r := routes[i]
|
||||||
|
for _, to := range r.To {
|
||||||
|
if pr.LogoURL == "" {
|
||||||
|
var err error
|
||||||
|
pr.LogoURL, err = p.logoProvider.GetLogoURL(ctx, pr.From, to.URL.String())
|
||||||
|
if err != nil && !errors.Is(err, portal.ErrLogoNotFound) {
|
||||||
|
log.Ctx(ctx).Error().
|
||||||
|
Err(err).
|
||||||
|
Str("from", pr.From).
|
||||||
|
Str("to", to.URL.String()).
|
||||||
|
Msg("error retrieving logo for route")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
portalRoutes[i] = pr
|
||||||
|
}
|
||||||
|
return portalRoutes
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *Proxy) getPortalUser(u handlers.UserInfoData) portal.User {
|
func (p *Proxy) getPortalUser(u handlers.UserInfoData) portal.User {
|
||||||
|
|
230
proxy/portal/logo_provider.go
Normal file
230
proxy/portal/logo_provider.go
Normal file
|
@ -0,0 +1,230 @@
|
||||||
|
package portal
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/base64"
|
||||||
|
"errors"
|
||||||
|
"io"
|
||||||
|
"iter"
|
||||||
|
"mime"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"golang.org/x/net/html"
|
||||||
|
"golang.org/x/sync/semaphore"
|
||||||
|
|
||||||
|
"github.com/pomerium/pomerium/internal/httputil"
|
||||||
|
"github.com/pomerium/pomerium/internal/log"
|
||||||
|
"github.com/pomerium/pomerium/internal/urlutil"
|
||||||
|
)
|
||||||
|
|
||||||
|
// errors
|
||||||
|
var ErrLogoNotFound = errors.New("logo not found")
|
||||||
|
|
||||||
|
// A LogoProvider gets logo urls for routes.
|
||||||
|
type LogoProvider interface {
|
||||||
|
GetLogoURL(ctx context.Context, from, to string) (string, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewLogoProvider creates a new LogoProvider.
|
||||||
|
func NewLogoProvider() LogoProvider {
|
||||||
|
return newFaviconDiscoveryLogoProvider()
|
||||||
|
}
|
||||||
|
|
||||||
|
type faviconCacheValue struct {
|
||||||
|
sem *semaphore.Weighted
|
||||||
|
url string
|
||||||
|
err error
|
||||||
|
expiry time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
type faviconDiscoveryLogoProvider struct {
|
||||||
|
mu sync.Mutex
|
||||||
|
cache map[string]*faviconCacheValue
|
||||||
|
successTTL time.Duration
|
||||||
|
failureTTL time.Duration
|
||||||
|
}
|
||||||
|
|
||||||
|
func newFaviconDiscoveryLogoProvider() *faviconDiscoveryLogoProvider {
|
||||||
|
return &faviconDiscoveryLogoProvider{
|
||||||
|
cache: make(map[string]*faviconCacheValue),
|
||||||
|
successTTL: time.Hour,
|
||||||
|
failureTTL: 10 * time.Minute,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *faviconDiscoveryLogoProvider) GetLogoURL(ctx context.Context, _, to string) (string, error) {
|
||||||
|
p.mu.Lock()
|
||||||
|
v, ok := p.cache[to]
|
||||||
|
if !ok {
|
||||||
|
v = &faviconCacheValue{
|
||||||
|
sem: semaphore.NewWeighted(1),
|
||||||
|
}
|
||||||
|
p.cache[to] = v
|
||||||
|
}
|
||||||
|
p.mu.Unlock()
|
||||||
|
|
||||||
|
// take the semaphore
|
||||||
|
err := v.sem.Acquire(ctx, 1)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
defer v.sem.Release(1)
|
||||||
|
|
||||||
|
// if we have a valid cached url or error, return it
|
||||||
|
if v.expiry.After(time.Now()) {
|
||||||
|
return v.url, v.err
|
||||||
|
}
|
||||||
|
|
||||||
|
// attempt to discover the logo url and save the url or the error
|
||||||
|
v.url, v.err = p.discoverLogoURL(ctx, to)
|
||||||
|
if v.err == nil {
|
||||||
|
v.expiry = time.Now().Add(p.successTTL)
|
||||||
|
} else {
|
||||||
|
v.expiry = time.Now().Add(p.failureTTL)
|
||||||
|
}
|
||||||
|
|
||||||
|
return v.url, v.err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *faviconDiscoveryLogoProvider) discoverLogoURL(ctx context.Context, rawURL string) (string, error) {
|
||||||
|
u, err := urlutil.ParseAndValidateURL(rawURL)
|
||||||
|
if err != nil {
|
||||||
|
return "", ErrLogoNotFound
|
||||||
|
}
|
||||||
|
|
||||||
|
if !(u.Scheme == "http" || u.Scheme == "https") {
|
||||||
|
return "", ErrLogoNotFound
|
||||||
|
}
|
||||||
|
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, rawURL, nil)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
t := httputil.GetInsecureTransport()
|
||||||
|
c := &http.Client{
|
||||||
|
Transport: t,
|
||||||
|
}
|
||||||
|
|
||||||
|
res, err := c.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
defer res.Body.Close()
|
||||||
|
|
||||||
|
// look for any logos in the html
|
||||||
|
r := io.LimitReader(res.Body, 10*1024)
|
||||||
|
for link := range findIconLinksInHTML(r) {
|
||||||
|
linkURL, err := u.Parse(link)
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
logoURL := p.fetchLogoURL(ctx, c, linkURL)
|
||||||
|
if logoURL != "" {
|
||||||
|
return logoURL, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// try just the /favicon.ico
|
||||||
|
logoURL := p.fetchLogoURL(ctx, c, u.ResolveReference(&url.URL{Path: "/favicon.ico"}))
|
||||||
|
if logoURL != "" {
|
||||||
|
return logoURL, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return "", ErrLogoNotFound
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *faviconDiscoveryLogoProvider) fetchLogoURL(ctx context.Context, client *http.Client, logoURL *url.URL) string {
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, logoURL.String(), nil)
|
||||||
|
if err != nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
res, err := client.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
log.Ctx(ctx).Debug().Str("url", logoURL.String()).Err(err).Msg("error fetching logo contents")
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
defer res.Body.Close()
|
||||||
|
|
||||||
|
if res.StatusCode/100 != 2 {
|
||||||
|
log.Ctx(ctx).Debug().Int("status-code", res.StatusCode).Str("url", logoURL.String()).Msg("error fetching logo contents")
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
const maxImageSize = 1024 * 1024
|
||||||
|
bs, err := io.ReadAll(io.LimitReader(res.Body, maxImageSize))
|
||||||
|
if err != nil {
|
||||||
|
log.Ctx(ctx).Debug().Str("url", logoURL.String()).Err(err).Msg("error reading logo contents")
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// first use the Content-Type header to determine the format
|
||||||
|
if mtype, _, err := mime.ParseMediaType(res.Header.Get("Content-Type")); err == nil {
|
||||||
|
if isSupportedImageType(mtype) {
|
||||||
|
return "data:" + mtype + ";base64," + base64.StdEncoding.EncodeToString(bs)
|
||||||
|
}
|
||||||
|
log.Ctx(ctx).Debug().Str("mime-type", mtype).Str("url", logoURL.String()).Msg("rejecting logo")
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// next try to use mimetype sniffing
|
||||||
|
mtype := http.DetectContentType(bs)
|
||||||
|
if isSupportedImageType(mtype) {
|
||||||
|
return "data:" + mtype + ";base64," + base64.StdEncoding.EncodeToString(bs)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Ctx(ctx).Debug().Str("mime-type", mtype).Str("url", logoURL.String()).Msg("rejecting logo")
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func isSupportedImageType(mtype string) bool {
|
||||||
|
return mtype == "image/vnd.microsoft.icon" ||
|
||||||
|
mtype == "image/png" ||
|
||||||
|
mtype == "image/svg+xml" ||
|
||||||
|
mtype == "image/jpeg" ||
|
||||||
|
mtype == "image/gif"
|
||||||
|
}
|
||||||
|
|
||||||
|
func findIconLinksInHTML(r io.Reader) iter.Seq[string] {
|
||||||
|
return func(yield func(string) bool) {
|
||||||
|
z := html.NewTokenizer(r)
|
||||||
|
for {
|
||||||
|
tt := z.Next()
|
||||||
|
if tt == html.ErrorToken {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
switch tt {
|
||||||
|
case html.StartTagToken, html.SelfClosingTagToken:
|
||||||
|
name, attr := parseTag(z)
|
||||||
|
if name == "link" && attr["href"] != "" && (attr["rel"] == "shortcut icon" || attr["rel"] == "icon") {
|
||||||
|
if !yield(attr["href"]) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseTag(z *html.Tokenizer) (name string, attributes map[string]string) {
|
||||||
|
n, hasAttr := z.TagName()
|
||||||
|
name = string(n)
|
||||||
|
if !hasAttr {
|
||||||
|
return name, attributes
|
||||||
|
}
|
||||||
|
attributes = make(map[string]string)
|
||||||
|
for {
|
||||||
|
k, v, m := z.TagAttr()
|
||||||
|
attributes[string(k)] = string(v)
|
||||||
|
if !m {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return name, attributes
|
||||||
|
}
|
37
proxy/portal/logo_provider_test.go
Normal file
37
proxy/portal/logo_provider_test.go
Normal file
|
@ -0,0 +1,37 @@
|
||||||
|
package portal_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
|
||||||
|
"github.com/pomerium/pomerium/internal/testutil"
|
||||||
|
"github.com/pomerium/pomerium/proxy/portal"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestLogoProvider(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
switch r.URL.Path {
|
||||||
|
case "/icon":
|
||||||
|
w.Header().Set("Content-Type", "image/vnd.microsoft.icon")
|
||||||
|
io.WriteString(w, "NOT ACTUALLY AN ICON")
|
||||||
|
case "/":
|
||||||
|
io.WriteString(w, `<!doctype html><html><head><link rel="icon" href="/icon" /></head><body></body></html>`)
|
||||||
|
default:
|
||||||
|
http.NotFound(w, r)
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
t.Cleanup(srv.Close)
|
||||||
|
|
||||||
|
ctx := testutil.GetContext(t, time.Minute)
|
||||||
|
p := portal.NewLogoProvider()
|
||||||
|
u, err := p.GetLogoURL(ctx, "", srv.URL)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, "data:image/vnd.microsoft.icon;base64,Tk9UIEFDVFVBTExZIEFOIElDT04=", u)
|
||||||
|
}
|
|
@ -21,6 +21,7 @@ import (
|
||||||
"github.com/pomerium/pomerium/internal/telemetry/metrics"
|
"github.com/pomerium/pomerium/internal/telemetry/metrics"
|
||||||
"github.com/pomerium/pomerium/internal/telemetry/trace"
|
"github.com/pomerium/pomerium/internal/telemetry/trace"
|
||||||
"github.com/pomerium/pomerium/pkg/cryptutil"
|
"github.com/pomerium/pomerium/pkg/cryptutil"
|
||||||
|
"github.com/pomerium/pomerium/proxy/portal"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
@ -60,6 +61,7 @@ type Proxy struct {
|
||||||
currentRouter *atomicutil.Value[*mux.Router]
|
currentRouter *atomicutil.Value[*mux.Router]
|
||||||
webauthn *webauthn.Handler
|
webauthn *webauthn.Handler
|
||||||
tracerProvider oteltrace.TracerProvider
|
tracerProvider oteltrace.TracerProvider
|
||||||
|
logoProvider portal.LogoProvider
|
||||||
}
|
}
|
||||||
|
|
||||||
// New takes a Proxy service from options and a validation function.
|
// New takes a Proxy service from options and a validation function.
|
||||||
|
@ -76,6 +78,7 @@ func New(ctx context.Context, cfg *config.Config) (*Proxy, error) {
|
||||||
state: atomicutil.NewValue(state),
|
state: atomicutil.NewValue(state),
|
||||||
currentOptions: config.NewAtomicOptions(),
|
currentOptions: config.NewAtomicOptions(),
|
||||||
currentRouter: atomicutil.NewValue(httputil.NewRouter()),
|
currentRouter: atomicutil.NewValue(httputil.NewRouter()),
|
||||||
|
logoProvider: portal.NewLogoProvider(),
|
||||||
}
|
}
|
||||||
p.OnConfigChange(ctx, cfg)
|
p.OnConfigChange(ctx, cfg)
|
||||||
p.webauthn = webauthn.New(p.getWebauthnState)
|
p.webauthn = webauthn.New(p.getWebauthnState)
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue