mirror of
https://github.com/pomerium/pomerium.git
synced 2025-06-24 13:38:17 +02:00
proxy: add per-route request headers setting (#346)
Signed-off-by: Bobby DeSimone <bobbydesimone@gmail.com>
This commit is contained in:
parent
c95a72e12a
commit
a96aec57d5
7 changed files with 90 additions and 13 deletions
|
@ -111,3 +111,17 @@ func (p *Proxy) reqNeedsAuthentication(w http.ResponseWriter, r *http.Request) {
|
|||
uri := urlutil.SignedRedirectURL(p.SharedKey, p.authenticateSigninURL, urlutil.GetAbsoluteURL(r))
|
||||
http.Redirect(w, r, uri.String(), http.StatusFound)
|
||||
}
|
||||
|
||||
// SetResponseHeaders sets a map of response headers.
|
||||
func SetResponseHeaders(headers map[string]string) func(next http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx, span := trace.StartSpan(r.Context(), "middleware.SetResponseHeaders")
|
||||
defer span.End()
|
||||
for key, val := range headers {
|
||||
r.Header.Set(key, val)
|
||||
}
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
@ -6,9 +6,11 @@ import (
|
|||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/pomerium/pomerium/internal/identity"
|
||||
"github.com/pomerium/pomerium/internal/sessions"
|
||||
"github.com/pomerium/pomerium/proxy/clients"
|
||||
|
@ -173,3 +175,40 @@ func TestProxy_SignRequest(t *testing.T) {
|
|||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestProxy_SetResponseHeaders(t *testing.T) {
|
||||
t.Parallel()
|
||||
fn := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
||||
w.Header().Set("X-Content-Type-Options", "nosniff")
|
||||
var sb strings.Builder
|
||||
for k, v := range r.Header {
|
||||
k = strings.ToLower(k)
|
||||
for _, h := range v {
|
||||
sb.WriteString(fmt.Sprintf("%v: %v\n", k, h))
|
||||
}
|
||||
}
|
||||
fmt.Fprint(w, sb.String())
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
tests := []struct {
|
||||
name string
|
||||
setHeaders map[string]string
|
||||
wantHeaders string
|
||||
}{
|
||||
{"good", map[string]string{"x-gonna": "give-it-to-ya"}, "x-gonna: give-it-to-ya\n"},
|
||||
{"nil", nil, ""},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
|
||||
r := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
w := httptest.NewRecorder()
|
||||
got := SetResponseHeaders(tt.setHeaders)(fn)
|
||||
got.ServeHTTP(w, r)
|
||||
if diff := cmp.Diff(w.Body.String(), tt.wantHeaders); diff != "" {
|
||||
t.Errorf("SignRequest() :\n %s", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
@ -242,7 +242,11 @@ func (p *Proxy) reverseProxyHandler(r *mux.Router, policy *config.Policy) (*mux.
|
|||
}
|
||||
rp.Use(p.SignRequest(signer))
|
||||
}
|
||||
|
||||
// Optional: if additional headers are to be set for this url
|
||||
if len(policy.SetRequestHeaders) != 0 {
|
||||
log.Warn().Interface("headers", policy.SetRequestHeaders).Msg("proxy: set request headers")
|
||||
rp.Use(SetResponseHeaders(policy.SetRequestHeaders))
|
||||
}
|
||||
return r, nil
|
||||
}
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue