mirror of
https://github.com/pomerium/pomerium.git
synced 2025-08-04 09:19:39 +02:00
proxy: add JWT request signing support (#19)
- Refactored middleware and request hander logging. - Request refactored to use context.Context. - Add helper (based on Alice) to allow middleware chaining. - Add helper scripts to generate elliptic curve self-signed certificate that can be used to sign JWT. - Changed LetsEncrypt scripts to use acme instead of certbot. - Add script to have LetsEncrypt sign an RSA based certificate. - Add documentation to explain how to verify headers. - Refactored internal/cryptutil signer's code to expect a valid EC priv key. - Changed JWT expiries to use default leeway period. - Update docs and add screenshots. - Replaced logging handler logic to use context.Context. - Removed specific XML error handling. - Refactored handler function signatures to prefer standard go idioms.
This commit is contained in:
parent
98b8c7481f
commit
426e003b03
30 changed files with 1711 additions and 588 deletions
212
internal/log/handler_log.go
Normal file
212
internal/log/handler_log.go
Normal file
|
@ -0,0 +1,212 @@
|
|||
// Package log provides a set of http.Handler helpers for zerolog.
|
||||
package log // import "github.com/pomerium/pomerium/internal/log"
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/zenazn/goji/web/mutil"
|
||||
)
|
||||
|
||||
// FromRequest gets the logger in the request's context.
|
||||
// This is a shortcut for log.Ctx(r.Context())
|
||||
func FromRequest(r *http.Request) *zerolog.Logger {
|
||||
return Ctx(r.Context())
|
||||
}
|
||||
|
||||
// NewHandler injects log into requests context.
|
||||
func NewHandler(log zerolog.Logger) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Create a copy of the logger (including internal context slice)
|
||||
// to prevent data race when using UpdateContext.
|
||||
l := log.With().Logger()
|
||||
r = r.WithContext(l.WithContext(r.Context()))
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// URLHandler adds the requested URL as a field to the context's logger
|
||||
// using fieldKey as field key.
|
||||
func URLHandler(fieldKey string) func(next http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
log := zerolog.Ctx(r.Context())
|
||||
log.UpdateContext(func(c zerolog.Context) zerolog.Context {
|
||||
return c.Str(fieldKey, r.URL.String())
|
||||
})
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// MethodHandler adds the request method as a field to the context's logger
|
||||
// using fieldKey as field key.
|
||||
func MethodHandler(fieldKey string) func(next http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
log := zerolog.Ctx(r.Context())
|
||||
log.UpdateContext(func(c zerolog.Context) zerolog.Context {
|
||||
return c.Str(fieldKey, r.Method)
|
||||
})
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// RequestHandler adds the request method and URL as a field to the context's logger
|
||||
// using fieldKey as field key.
|
||||
func RequestHandler(fieldKey string) func(next http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
log := zerolog.Ctx(r.Context())
|
||||
log.UpdateContext(func(c zerolog.Context) zerolog.Context {
|
||||
return c.Str(fieldKey, r.Method+" "+r.URL.String())
|
||||
})
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// RemoteAddrHandler adds the request's remote address as a field to the context's logger
|
||||
// using fieldKey as field key.
|
||||
func RemoteAddrHandler(fieldKey string) func(next http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if host, _, err := net.SplitHostPort(r.RemoteAddr); err == nil {
|
||||
log := zerolog.Ctx(r.Context())
|
||||
log.UpdateContext(func(c zerolog.Context) zerolog.Context {
|
||||
return c.Str(fieldKey, host)
|
||||
})
|
||||
}
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// UserAgentHandler adds the request's user-agent as a field to the context's logger
|
||||
// using fieldKey as field key.
|
||||
func UserAgentHandler(fieldKey string) func(next http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if ua := r.Header.Get("User-Agent"); ua != "" {
|
||||
log := zerolog.Ctx(r.Context())
|
||||
log.UpdateContext(func(c zerolog.Context) zerolog.Context {
|
||||
return c.Str(fieldKey, ua)
|
||||
})
|
||||
}
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// RefererHandler adds the request's referer as a field to the context's logger
|
||||
// using fieldKey as field key.
|
||||
func RefererHandler(fieldKey string) func(next http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if ref := r.Header.Get("Referer"); ref != "" {
|
||||
log := zerolog.Ctx(r.Context())
|
||||
log.UpdateContext(func(c zerolog.Context) zerolog.Context {
|
||||
return c.Str(fieldKey, ref)
|
||||
})
|
||||
}
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type idKey struct{}
|
||||
|
||||
// IDFromRequest returns the unique id associated to the request if any.
|
||||
func IDFromRequest(r *http.Request) (id string, ok bool) {
|
||||
if r == nil {
|
||||
return
|
||||
}
|
||||
return IDFromCtx(r.Context())
|
||||
}
|
||||
|
||||
// IDFromCtx returns the unique id associated to the context if any.
|
||||
func IDFromCtx(ctx context.Context) (id string, ok bool) {
|
||||
id, ok = ctx.Value(idKey{}).(string)
|
||||
return
|
||||
}
|
||||
|
||||
// RequestIDHandler returns a handler setting a unique id to the request which can
|
||||
// be gathered using IDFromRequest(req). This generated id is added as a field to the
|
||||
// logger using the passed fieldKey as field name. The id is also added as a response
|
||||
// header if the headerName is not empty.
|
||||
func RequestIDHandler(fieldKey, headerName string) func(next http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
id, ok := IDFromRequest(r)
|
||||
if !ok {
|
||||
id = uuid()
|
||||
ctx = context.WithValue(ctx, idKey{}, id)
|
||||
r = r.WithContext(ctx)
|
||||
}
|
||||
if fieldKey != "" {
|
||||
log := zerolog.Ctx(ctx)
|
||||
log.UpdateContext(func(c zerolog.Context) zerolog.Context {
|
||||
return c.Str(fieldKey, id)
|
||||
})
|
||||
}
|
||||
if headerName != "" {
|
||||
w.Header().Set(headerName, id)
|
||||
}
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// AccessHandler returns a handler that call f after each request.
|
||||
func AccessHandler(f func(r *http.Request, status, size int, duration time.Duration)) func(next http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
start := time.Now()
|
||||
lw := mutil.WrapWriter(w)
|
||||
next.ServeHTTP(lw, r)
|
||||
f(r, lw.Status(), lw.BytesWritten(), time.Since(start))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ForwardedAddrHandler returns the client IP address from a request. If present, the
|
||||
// X-Forwarded-For header is assumed to be set by a load balancer, and its
|
||||
// rightmost entry (the client IP that connected to the LB) is returned.
|
||||
func ForwardedAddrHandler(fieldKey string) func(next http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
addr := r.RemoteAddr
|
||||
if ra := r.Header.Get("X-Forwarded-For"); ra != "" {
|
||||
forwardedList := strings.Split(ra, ",")
|
||||
forwardedAddr := strings.TrimSpace(forwardedList[len(forwardedList)-1])
|
||||
if forwardedAddr != "" {
|
||||
addr = forwardedAddr
|
||||
}
|
||||
log := zerolog.Ctx(r.Context())
|
||||
log.UpdateContext(func(c zerolog.Context) zerolog.Context {
|
||||
return c.Str(fieldKey, addr)
|
||||
})
|
||||
}
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// uuid generates a random 128-bit non-RFC UUID.
|
||||
func uuid() string {
|
||||
buf := make([]byte, 16)
|
||||
if _, err := rand.Read(buf); err != nil {
|
||||
return ""
|
||||
}
|
||||
return fmt.Sprintf("%x-%x-%x-%x-%x", buf[0:4], buf[4:6], buf[6:8], buf[8:10], buf[10:])
|
||||
}
|
260
internal/log/handler_log_test.go
Normal file
260
internal/log/handler_log_test.go
Normal file
|
@ -0,0 +1,260 @@
|
|||
// Package log provides a set of http.Handler helpers for zerolog.
|
||||
package log // import "github.com/pomerium/pomerium/internal/log"
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"reflect"
|
||||
"regexp"
|
||||
"testing"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
)
|
||||
|
||||
func TestGenerateUUID(t *testing.T) {
|
||||
prev := uuid()
|
||||
for i := 0; i < 100; i++ {
|
||||
id := uuid()
|
||||
if id == "" {
|
||||
t.Fatal("random pool failure")
|
||||
}
|
||||
if prev == id {
|
||||
t.Fatalf("Should get a new ID!")
|
||||
}
|
||||
matched, err := regexp.MatchString("[a-fA-F0-9]{8}-[a-fA-F0-9]{4}-[a-fA-F0-9]{4}-[a-fA-F0-9]{4}-[a-fA-F0-9]{12}", id)
|
||||
if !matched || err != nil {
|
||||
t.Fatalf("expected match %s %v %s", id, matched, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func decodeIfBinary(out *bytes.Buffer) string {
|
||||
// p := out.Bytes()
|
||||
// if len(p) == 0 || p[0] < 0x7F {
|
||||
// return out.String()
|
||||
// }
|
||||
return out.String() //cbor.DecodeObjectToStr(p) + "\n"
|
||||
}
|
||||
|
||||
func TestNewHandler(t *testing.T) {
|
||||
log := zerolog.New(nil).With().
|
||||
Str("foo", "bar").
|
||||
Logger()
|
||||
lh := NewHandler(log)
|
||||
h := lh(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
l := FromRequest(r)
|
||||
if !reflect.DeepEqual(*l, log) {
|
||||
t.Fail()
|
||||
}
|
||||
}))
|
||||
h.ServeHTTP(nil, &http.Request{})
|
||||
}
|
||||
|
||||
func TestURLHandler(t *testing.T) {
|
||||
out := &bytes.Buffer{}
|
||||
r := &http.Request{
|
||||
URL: &url.URL{Path: "/path", RawQuery: "foo=bar"},
|
||||
}
|
||||
h := URLHandler("url")(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
l := FromRequest(r)
|
||||
l.Log().Msg("")
|
||||
}))
|
||||
h = NewHandler(zerolog.New(out))(h)
|
||||
h.ServeHTTP(nil, r)
|
||||
if want, got := `{"url":"/path?foo=bar"}`+"\n", decodeIfBinary(out); want != got {
|
||||
t.Errorf("Invalid log output, got: %s, want: %s", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMethodHandler(t *testing.T) {
|
||||
out := &bytes.Buffer{}
|
||||
r := &http.Request{
|
||||
Method: "POST",
|
||||
}
|
||||
h := MethodHandler("method")(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
l := FromRequest(r)
|
||||
l.Log().Msg("")
|
||||
}))
|
||||
h = NewHandler(zerolog.New(out))(h)
|
||||
h.ServeHTTP(nil, r)
|
||||
if want, got := `{"method":"POST"}`+"\n", decodeIfBinary(out); want != got {
|
||||
t.Errorf("Invalid log output, got: %s, want: %s", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequestHandler(t *testing.T) {
|
||||
out := &bytes.Buffer{}
|
||||
r := &http.Request{
|
||||
Method: "POST",
|
||||
URL: &url.URL{Path: "/path", RawQuery: "foo=bar"},
|
||||
}
|
||||
h := RequestHandler("request")(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
l := FromRequest(r)
|
||||
l.Log().Msg("")
|
||||
}))
|
||||
h = NewHandler(zerolog.New(out))(h)
|
||||
h.ServeHTTP(nil, r)
|
||||
if want, got := `{"request":"POST /path?foo=bar"}`+"\n", decodeIfBinary(out); want != got {
|
||||
t.Errorf("Invalid log output, got: %s, want: %s", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRemoteAddrHandler(t *testing.T) {
|
||||
out := &bytes.Buffer{}
|
||||
r := &http.Request{
|
||||
RemoteAddr: "1.2.3.4:1234",
|
||||
}
|
||||
h := RemoteAddrHandler("ip")(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
l := FromRequest(r)
|
||||
l.Log().Msg("")
|
||||
}))
|
||||
h = NewHandler(zerolog.New(out))(h)
|
||||
h.ServeHTTP(nil, r)
|
||||
if want, got := `{"ip":"1.2.3.4"}`+"\n", decodeIfBinary(out); want != got {
|
||||
t.Errorf("Invalid log output, got: %s, want: %s", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRemoteAddrHandlerIPv6(t *testing.T) {
|
||||
out := &bytes.Buffer{}
|
||||
r := &http.Request{
|
||||
RemoteAddr: "[2001:db8:a0b:12f0::1]:1234",
|
||||
}
|
||||
h := RemoteAddrHandler("ip")(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
l := FromRequest(r)
|
||||
l.Log().Msg("")
|
||||
}))
|
||||
h = NewHandler(zerolog.New(out))(h)
|
||||
h.ServeHTTP(nil, r)
|
||||
if want, got := `{"ip":"2001:db8:a0b:12f0::1"}`+"\n", decodeIfBinary(out); want != got {
|
||||
t.Errorf("Invalid log output, got: %s, want: %s", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUserAgentHandler(t *testing.T) {
|
||||
out := &bytes.Buffer{}
|
||||
r := &http.Request{
|
||||
Header: http.Header{
|
||||
"User-Agent": []string{"some user agent string"},
|
||||
},
|
||||
}
|
||||
h := UserAgentHandler("ua")(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
l := FromRequest(r)
|
||||
l.Log().Msg("")
|
||||
}))
|
||||
h = NewHandler(zerolog.New(out))(h)
|
||||
h.ServeHTTP(nil, r)
|
||||
if want, got := `{"ua":"some user agent string"}`+"\n", decodeIfBinary(out); want != got {
|
||||
t.Errorf("Invalid log output, got: %s, want: %s", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRefererHandler(t *testing.T) {
|
||||
out := &bytes.Buffer{}
|
||||
r := &http.Request{
|
||||
Header: http.Header{
|
||||
"Referer": []string{"http://foo.com/bar"},
|
||||
},
|
||||
}
|
||||
h := RefererHandler("referer")(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
l := FromRequest(r)
|
||||
l.Log().Msg("")
|
||||
}))
|
||||
h = NewHandler(zerolog.New(out))(h)
|
||||
h.ServeHTTP(nil, r)
|
||||
if want, got := `{"referer":"http://foo.com/bar"}`+"\n", decodeIfBinary(out); want != got {
|
||||
t.Errorf("Invalid log output, got: %s, want: %s", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequestIDHandler(t *testing.T) {
|
||||
out := &bytes.Buffer{}
|
||||
r := &http.Request{
|
||||
Header: http.Header{
|
||||
"Referer": []string{"http://foo.com/bar"},
|
||||
},
|
||||
}
|
||||
h := RequestIDHandler("id", "Request-Id")(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
id, ok := IDFromRequest(r)
|
||||
if !ok {
|
||||
t.Fatal("Missing id in request")
|
||||
}
|
||||
// if want, got := id.String(), w.Header().Get("Request-Id"); got != want {
|
||||
// t.Errorf("Invalid Request-Id header, got: %s, want: %s", got, want)
|
||||
// }
|
||||
l := FromRequest(r)
|
||||
l.Log().Msg("")
|
||||
if want, got := fmt.Sprintf(`{"id":"%s"}`+"\n", id), decodeIfBinary(out); want != got {
|
||||
t.Errorf("Invalid log output, got: %s, want: %s", got, want)
|
||||
}
|
||||
}))
|
||||
h = NewHandler(zerolog.New(out))(h)
|
||||
h.ServeHTTP(httptest.NewRecorder(), r)
|
||||
}
|
||||
|
||||
func TestCombinedHandlers(t *testing.T) {
|
||||
out := &bytes.Buffer{}
|
||||
r := &http.Request{
|
||||
Method: "POST",
|
||||
URL: &url.URL{Path: "/path", RawQuery: "foo=bar"},
|
||||
}
|
||||
h := MethodHandler("method")(RequestHandler("request")(URLHandler("url")(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
l := FromRequest(r)
|
||||
l.Log().Msg("")
|
||||
}))))
|
||||
h = NewHandler(zerolog.New(out))(h)
|
||||
h.ServeHTTP(nil, r)
|
||||
if want, got := `{"method":"POST","request":"POST /path?foo=bar","url":"/path?foo=bar"}`+"\n", decodeIfBinary(out); want != got {
|
||||
t.Errorf("Invalid log output, got: %s, want: %s", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkHandlers(b *testing.B) {
|
||||
r := &http.Request{
|
||||
Method: "POST",
|
||||
URL: &url.URL{Path: "/path", RawQuery: "foo=bar"},
|
||||
}
|
||||
h1 := URLHandler("url")(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
l := FromRequest(r)
|
||||
l.Log().Msg("")
|
||||
}))
|
||||
h2 := MethodHandler("method")(RequestHandler("request")(h1))
|
||||
handlers := map[string]http.Handler{
|
||||
"Single": NewHandler(zerolog.New(ioutil.Discard))(h1),
|
||||
"Combined": NewHandler(zerolog.New(ioutil.Discard))(h2),
|
||||
"SingleDisabled": NewHandler(zerolog.New(ioutil.Discard).Level(zerolog.Disabled))(h1),
|
||||
"CombinedDisabled": NewHandler(zerolog.New(ioutil.Discard).Level(zerolog.Disabled))(h2),
|
||||
}
|
||||
for name := range handlers {
|
||||
h := handlers[name]
|
||||
b.Run(name, func(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
h.ServeHTTP(nil, r)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkDataRace(b *testing.B) {
|
||||
log := zerolog.New(nil).With().
|
||||
Str("foo", "bar").
|
||||
Logger()
|
||||
lh := NewHandler(log)
|
||||
h := lh(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
l := FromRequest(r)
|
||||
l.UpdateContext(func(c zerolog.Context) zerolog.Context {
|
||||
return c.Str("bar", "baz")
|
||||
})
|
||||
l.Log().Msg("")
|
||||
}))
|
||||
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
for pb.Next() {
|
||||
h.ServeHTTP(nil, &http.Request{})
|
||||
}
|
||||
})
|
||||
}
|
|
@ -2,7 +2,7 @@
|
|||
package log // import "github.com/pomerium/pomerium/internal/log"
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"context"
|
||||
"os"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
|
@ -21,19 +21,6 @@ func With() zerolog.Context {
|
|||
return Logger.With()
|
||||
}
|
||||
|
||||
// WithRequest creates a child logger with the remote user added to its context.
|
||||
func WithRequest(req *http.Request, function string) zerolog.Logger {
|
||||
remoteUser := getRemoteAddr(req)
|
||||
return Logger.With().
|
||||
Str("function", function).
|
||||
Str("req-remote-user", remoteUser).
|
||||
Str("req-http-method", req.Method).
|
||||
Str("req-host", req.Host).
|
||||
Str("req-url", req.URL.String()).
|
||||
// Str("req-user-agent", req.Header.Get("User-Agent")).
|
||||
Logger()
|
||||
}
|
||||
|
||||
// Level creates a child logger with the minimum accepted level set to level.
|
||||
func Level(level zerolog.Level) zerolog.Logger {
|
||||
return Logger.Level(level)
|
||||
|
@ -109,3 +96,9 @@ func Print(v ...interface{}) {
|
|||
func Printf(format string, v ...interface{}) {
|
||||
Logger.Printf(format, v...)
|
||||
}
|
||||
|
||||
// Ctx returns the Logger associated with the ctx. If no logger
|
||||
// is associated, a disabled logger is returned.
|
||||
func Ctx(ctx context.Context) *zerolog.Logger {
|
||||
return zerolog.Ctx(ctx)
|
||||
}
|
||||
|
|
|
@ -1,145 +0,0 @@
|
|||
package log // import "github.com/pomerium/pomerium/internal/log"
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Used to stash the authenticated user in the response for access when logging requests.
|
||||
const loggingUserHeader = "SSO-Authenticated-User"
|
||||
const gapMetaDataHeader = "GAP-Auth"
|
||||
|
||||
// responseLogger is wrapper of http.ResponseWriter that keeps track of its HTTP status
|
||||
// code and body size
|
||||
type responseLogger struct {
|
||||
w http.ResponseWriter
|
||||
status int
|
||||
size int
|
||||
proxyHost string
|
||||
authInfo string
|
||||
}
|
||||
|
||||
func (l *responseLogger) Header() http.Header {
|
||||
return l.w.Header()
|
||||
}
|
||||
|
||||
func (l *responseLogger) extractUser() {
|
||||
authInfo := l.w.Header().Get(loggingUserHeader)
|
||||
if authInfo != "" {
|
||||
l.authInfo = authInfo
|
||||
l.w.Header().Del(loggingUserHeader)
|
||||
}
|
||||
}
|
||||
|
||||
func (l *responseLogger) ExtractGAPMetadata() {
|
||||
authInfo := l.w.Header().Get(gapMetaDataHeader)
|
||||
if authInfo != "" {
|
||||
l.authInfo = authInfo
|
||||
|
||||
l.w.Header().Del(gapMetaDataHeader)
|
||||
}
|
||||
}
|
||||
|
||||
func (l *responseLogger) Write(b []byte) (int, error) {
|
||||
if l.status == 0 {
|
||||
// The status will be StatusOK if WriteHeader has not been called yet
|
||||
l.status = http.StatusOK
|
||||
}
|
||||
l.extractUser()
|
||||
l.ExtractGAPMetadata()
|
||||
|
||||
size, err := l.w.Write(b)
|
||||
l.size += size
|
||||
return size, err
|
||||
}
|
||||
|
||||
func (l *responseLogger) WriteHeader(s int) {
|
||||
l.extractUser()
|
||||
l.ExtractGAPMetadata()
|
||||
|
||||
l.w.WriteHeader(s)
|
||||
l.status = s
|
||||
}
|
||||
|
||||
func (l *responseLogger) Status() int {
|
||||
return l.status
|
||||
}
|
||||
|
||||
func (l *responseLogger) Size() int {
|
||||
return l.size
|
||||
}
|
||||
|
||||
func (l *responseLogger) Flush() {
|
||||
f := l.w.(http.Flusher)
|
||||
f.Flush()
|
||||
}
|
||||
|
||||
// loggingHandler is the http.Handler implementation for LoggingHandlerTo and its friends
|
||||
type loggingHandler struct {
|
||||
handler http.Handler
|
||||
}
|
||||
|
||||
// NewLoggingHandler returns a new loggingHandler that wraps a handler, and writer.
|
||||
func NewLoggingHandler(h http.Handler) http.Handler {
|
||||
return loggingHandler{
|
||||
handler: h,
|
||||
}
|
||||
}
|
||||
|
||||
func (h loggingHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
|
||||
t := time.Now()
|
||||
url := *req.URL
|
||||
logger := &responseLogger{w: w, proxyHost: getProxyHost(req)}
|
||||
h.handler.ServeHTTP(logger, req)
|
||||
requestDuration := time.Since(t)
|
||||
|
||||
logRequest(logger.proxyHost, logger.authInfo, req, url, requestDuration, logger.Status())
|
||||
}
|
||||
|
||||
// logRequest logs information about a request
|
||||
func logRequest(proxyHost, username string, req *http.Request, url url.URL, requestDuration time.Duration, status int) {
|
||||
uri := req.Host + url.RequestURI()
|
||||
Info().
|
||||
Int("http-status", status).
|
||||
Str("request-method", req.Method).
|
||||
Str("request-uri", uri).
|
||||
Str("proxy-host", proxyHost).
|
||||
// Str("user-agent", req.Header.Get("User-Agent")).
|
||||
Str("remote-address", getRemoteAddr(req)).
|
||||
Dur("duration", requestDuration).
|
||||
Str("user", username).
|
||||
Msg("request")
|
||||
|
||||
}
|
||||
|
||||
// getRemoteAddr returns the client IP address from a request. If present, the
|
||||
// X-Forwarded-For header is assumed to be set by a load balancer, and its
|
||||
// rightmost entry (the client IP that connected to the LB) is returned.
|
||||
func getRemoteAddr(req *http.Request) string {
|
||||
addr := req.RemoteAddr
|
||||
forwardedHeader := req.Header.Get("X-Forwarded-For")
|
||||
if forwardedHeader != "" {
|
||||
forwardedList := strings.Split(forwardedHeader, ",")
|
||||
forwardedAddr := strings.TrimSpace(forwardedList[len(forwardedList)-1])
|
||||
if forwardedAddr != "" {
|
||||
addr = forwardedAddr
|
||||
}
|
||||
}
|
||||
return addr
|
||||
}
|
||||
|
||||
// getProxyHost attempts to get the proxy host from the redirect_uri parameter
|
||||
func getProxyHost(req *http.Request) string {
|
||||
err := req.ParseForm()
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
redirect := req.Form.Get("redirect_uri")
|
||||
redirectURL, err := url.Parse(redirect)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
return redirectURL.Host
|
||||
}
|
|
@ -1,72 +0,0 @@
|
|||
package log // import "github.com/pomerium/pomerium/internal/log"
|
||||
|
||||
import (
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestGetRemoteAddr(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
remoteAddr string
|
||||
forwardedHeader string
|
||||
expectedAddr string
|
||||
}{
|
||||
{
|
||||
name: "RemoteAddr used when no X-Forwarded-For header is given",
|
||||
remoteAddr: "1.1.1.1",
|
||||
expectedAddr: "1.1.1.1",
|
||||
},
|
||||
{
|
||||
name: "RemoteAddr used when no X-Forwarded-For header is only whitespace",
|
||||
remoteAddr: "1.1.1.1",
|
||||
forwardedHeader: " ",
|
||||
expectedAddr: "1.1.1.1",
|
||||
},
|
||||
{
|
||||
name: "RemoteAddr used when no X-Forwarded-For header is only comma-separated whitespace",
|
||||
remoteAddr: "1.1.1.1",
|
||||
forwardedHeader: " , , ",
|
||||
expectedAddr: "1.1.1.1",
|
||||
},
|
||||
{
|
||||
name: "X-Forwarded-For header is preferred to RemoteAddr",
|
||||
remoteAddr: "1.1.1.1",
|
||||
forwardedHeader: "9.9.9.9",
|
||||
expectedAddr: "9.9.9.9",
|
||||
},
|
||||
{
|
||||
name: "rightmost entry in X-Forwarded-For header is used",
|
||||
remoteAddr: "1.1.1.1",
|
||||
forwardedHeader: "2.2.2.2, 3.3.3.3, 4.4.4.4.4, 5.5.5.5",
|
||||
expectedAddr: "5.5.5.5",
|
||||
},
|
||||
{
|
||||
name: "RemoteAddr is used if rightmost entry in X-Forwarded-For header is empty",
|
||||
remoteAddr: "1.1.1.1",
|
||||
forwardedHeader: "2.2.2.2, 3.3.3.3, ",
|
||||
expectedAddr: "1.1.1.1",
|
||||
},
|
||||
{
|
||||
name: "X-Forwaded-For header entries are stripped",
|
||||
remoteAddr: "1.1.1.1",
|
||||
forwardedHeader: " 2.2.2.2, 3.3.3.3, 4.4.4.4, 5.5.5.5 ",
|
||||
expectedAddr: "5.5.5.5",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
req.RemoteAddr = tc.remoteAddr
|
||||
if tc.forwardedHeader != "" {
|
||||
req.Header.Set("X-Forwarded-For", tc.forwardedHeader)
|
||||
}
|
||||
|
||||
addr := getRemoteAddr(req)
|
||||
if addr != tc.expectedAddr {
|
||||
t.Errorf("expected remote addr = %q, got %q", tc.expectedAddr, addr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue