package log import ( "bytes" "fmt" "io/ioutil" "net/http" "net/http/httptest" "net/url" "reflect" "regexp" "testing" "time" "github.com/google/go-cmp/cmp" "github.com/rs/zerolog" "github.com/pomerium/pomerium/internal/telemetry/requestid" ) func TestGenerateUUID(t *testing.T) { prev := uuid() for i := 0; i < 100; i++ { id := uuid() if id == "" { t.Fatal("random pool failure") } if prev == id { t.Fatalf("Should get a new ID!") } matched := regexp.MustCompile("[a-fA-F0-9]{8}-[a-fA-F0-9]{4}-[a-fA-F0-9]{4}-[a-fA-F0-9]{4}-[a-fA-F0-9]{12}").MatchString(id) if !matched { t.Fatalf("expected match %s %v", id, matched) } } } func decodeIfBinary(out fmt.Stringer) string { return out.String() } func TestNewHandler(t *testing.T) { log := zerolog.New(nil).With(). Str("foo", "bar"). Logger() lh := NewHandler(func() *zerolog.Logger { return &log }) h := lh(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { l := FromRequest(r) if !reflect.DeepEqual(*l, log) { t.Fail() } })) h.ServeHTTP(nil, &http.Request{}) } func TestURLHandler(t *testing.T) { out := &bytes.Buffer{} r := &http.Request{ URL: &url.URL{Path: "/path", RawQuery: "foo=bar"}, } h := URLHandler("url")(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { l := FromRequest(r) l.Log().Msg("") })) log := zerolog.New(out) h = NewHandler(func() *zerolog.Logger { return &log })(h) h.ServeHTTP(nil, r) if want, got := `{"url":"/path?foo=bar"}`+"\n", decodeIfBinary(out); want != got { t.Errorf("Invalid log output, got: %s, want: %s", got, want) } } func TestMethodHandler(t *testing.T) { out := &bytes.Buffer{} r := &http.Request{ Method: "POST", } h := MethodHandler("method")(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { l := FromRequest(r) l.Log().Msg("") })) log := zerolog.New(out) h = NewHandler(func() *zerolog.Logger { return &log })(h) h.ServeHTTP(nil, r) if want, got := `{"method":"POST"}`+"\n", decodeIfBinary(out); want != got { t.Errorf("Invalid log output, got: %s, want: %s", got, want) } } func TestRequestHandler(t *testing.T) { out := &bytes.Buffer{} r := &http.Request{ Method: "POST", URL: &url.URL{Path: "/path", RawQuery: "foo=bar"}, } h := RequestHandler("request")(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { l := FromRequest(r) l.Log().Msg("") })) log := zerolog.New(out) h = NewHandler(func() *zerolog.Logger { return &log })(h) h.ServeHTTP(nil, r) if want, got := `{"request":"POST /path?foo=bar"}`+"\n", decodeIfBinary(out); want != got { t.Errorf("Invalid log output, got: %s, want: %s", got, want) } } func TestRemoteAddrHandler(t *testing.T) { out := &bytes.Buffer{} r := &http.Request{ RemoteAddr: "1.2.3.4:1234", } h := RemoteAddrHandler("ip")(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { l := FromRequest(r) l.Log().Msg("") })) log := zerolog.New(out) h = NewHandler(func() *zerolog.Logger { return &log })(h) h.ServeHTTP(nil, r) if want, got := `{"ip":"1.2.3.4"}`+"\n", decodeIfBinary(out); want != got { t.Errorf("Invalid log output, got: %s, want: %s", got, want) } } func TestRemoteAddrHandlerIPv6(t *testing.T) { out := &bytes.Buffer{} r := &http.Request{ RemoteAddr: "[2001:db8:a0b:12f0::1]:1234", } h := RemoteAddrHandler("ip")(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { l := FromRequest(r) l.Log().Msg("") })) log := zerolog.New(out) h = NewHandler(func() *zerolog.Logger { return &log })(h) h.ServeHTTP(nil, r) if want, got := `{"ip":"2001:db8:a0b:12f0::1"}`+"\n", decodeIfBinary(out); want != got { t.Errorf("Invalid log output, got: %s, want: %s", got, want) } } func TestUserAgentHandler(t *testing.T) { out := &bytes.Buffer{} r := &http.Request{ Header: http.Header{ "User-Agent": []string{"some user agent string"}, }, } h := UserAgentHandler("ua")(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { l := FromRequest(r) l.Log().Msg("") })) log := zerolog.New(out) h = NewHandler(func() *zerolog.Logger { return &log })(h) h.ServeHTTP(nil, r) if want, got := `{"ua":"some user agent string"}`+"\n", decodeIfBinary(out); want != got { t.Errorf("Invalid log output, got: %s, want: %s", got, want) } } func TestRefererHandler(t *testing.T) { out := &bytes.Buffer{} r := &http.Request{ Header: http.Header{ "Referer": []string{"http://foo.com/bar"}, }, } h := RefererHandler("referer")(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { l := FromRequest(r) l.Log().Msg("") })) log := zerolog.New(out) h = NewHandler(func() *zerolog.Logger { return &log })(h) h.ServeHTTP(nil, r) if want, got := `{"referer":"http://foo.com/bar"}`+"\n", decodeIfBinary(out); want != got { t.Errorf("Invalid log output, got: %s, want: %s", got, want) } } func TestRequestIDHandler(t *testing.T) { out := &bytes.Buffer{} r := &http.Request{ Header: http.Header{ "X-Request-Id": []string{"1234"}, }, } h := RequestIDHandler("request-id")(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { requestID := requestid.FromContext(r.Context()) l := FromRequest(r) l.Log().Msg("") if want, got := fmt.Sprintf(`{"request-id":"%s"}`+"\n", requestID), decodeIfBinary(out); want != got { t.Errorf("Invalid log output, got: %s, want: %s", got, want) } })) log := zerolog.New(out) h = NewHandler(func() *zerolog.Logger { return &log })(h) h = requestid.HTTPMiddleware()(h) h.ServeHTTP(httptest.NewRecorder(), r) } func TestCombinedHandlers(t *testing.T) { out := &bytes.Buffer{} r := &http.Request{ Method: "POST", URL: &url.URL{Path: "/path", RawQuery: "foo=bar"}, } h := MethodHandler("method")(RequestHandler("request")(URLHandler("url")(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { l := FromRequest(r) l.Log().Msg("") })))) log := zerolog.New(out) h = NewHandler(func() *zerolog.Logger { return &log })(h) h.ServeHTTP(nil, r) if want, got := `{"method":"POST","request":"POST /path?foo=bar","url":"/path?foo=bar"}`+"\n", decodeIfBinary(out); want != got { t.Errorf("Invalid log output, got: %s, want: %s", got, want) } } func BenchmarkHandlers(b *testing.B) { r := &http.Request{ Method: "POST", URL: &url.URL{Path: "/path", RawQuery: "foo=bar"}, } h1 := URLHandler("url")(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { l := FromRequest(r) l.Log().Msg("") })) h2 := MethodHandler("method")(RequestHandler("request")(h1)) handlers := map[string]http.Handler{ "Single": NewHandler(func() *zerolog.Logger { log := zerolog.New(ioutil.Discard) return &log })(h1), "Combined": NewHandler((func() *zerolog.Logger { log := zerolog.New(ioutil.Discard) return &log }))(h2), "SingleDisabled": NewHandler((func() *zerolog.Logger { log := zerolog.New(ioutil.Discard).Level(zerolog.Disabled) return &log }))(h1), "CombinedDisabled": NewHandler((func() *zerolog.Logger { log := zerolog.New(ioutil.Discard).Level(zerolog.Disabled) return &log }))(h2), } for name := range handlers { h := handlers[name] b.Run(name, func(b *testing.B) { for i := 0; i < b.N; i++ { h.ServeHTTP(nil, r) } }) } } func BenchmarkDataRace(b *testing.B) { log := zerolog.New(nil).With(). Str("foo", "bar"). Logger() lh := NewHandler(func() *zerolog.Logger { return &log }) h := lh(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { l := FromRequest(r) l.UpdateContext(func(c zerolog.Context) zerolog.Context { return c.Str("bar", "baz") }) l.Log().Msg("") })) b.RunParallel(func(pb *testing.PB) { for pb.Next() { h.ServeHTTP(nil, &http.Request{}) } }) } func TestLogHeadersHandler(t *testing.T) { out := &bytes.Buffer{} r := httptest.NewRequest(http.MethodGet, "/", nil) r.Header.Set("X-Forwarded-For", "proxy1,proxy2,proxy3") h := HeadersHandler([]string{"X-Forwarded-For"})(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { l := FromRequest(r) l.Log().Msg("") })) log := zerolog.New(out) h = NewHandler(func() *zerolog.Logger { return &log })(h) h.ServeHTTP(nil, r) if want, got := `{"X-Forwarded-For":["proxy1,proxy2,proxy3"]}`+"\n", decodeIfBinary(out); want != got { t.Errorf("Invalid log output, got: %s, want: %s", got, want) } } func TestAccessHandler(t *testing.T) { out := &bytes.Buffer{} r := httptest.NewRequest(http.MethodGet, "/", nil) h := AccessHandler(func(r *http.Request, status, size int, duration time.Duration) { l := FromRequest(r) l.Log().Int("status", status).Int("size", size).Msg("info") })(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { l := FromRequest(r) l.Log().Msg("some inner logging") w.Write([]byte("Add something to the request of non-zero size")) })) log := zerolog.New(out) h = NewHandler(func() *zerolog.Logger { return &log })(h) w := httptest.NewRecorder() h.ServeHTTP(w, r) want := "{\"message\":\"some inner logging\"}\n{\"status\":200,\"size\":45,\"message\":\"info\"}\n" got := decodeIfBinary(out) if diff := cmp.Diff(want, got); diff != "" { t.Errorf("TestAccessHandler: %s", diff) } }