diff --git a/authorize/session.go b/authorize/session.go index e6e466904..dac5aa5cb 100644 --- a/authorize/session.go +++ b/authorize/session.go @@ -2,9 +2,7 @@ package authorize import ( "errors" - "fmt" "net/http" - "net/http/httptest" "github.com/pomerium/pomerium/config" "github.com/pomerium/pomerium/internal/encoding" @@ -64,22 +62,3 @@ func getCookieStore(options *config.Options, encoder encoding.MarshalUnmarshaler } return cookieStore, nil } - -func getJWTSetCookieHeaders(cookieStore sessions.SessionStore, rawjwt []byte) (map[string]string, error) { - recorder := httptest.NewRecorder() - err := cookieStore.SaveSession(recorder, nil /* unused by cookie store */, string(rawjwt)) - if err != nil { - return nil, fmt.Errorf("authorize: error saving cookie: %w", err) - } - - res := recorder.Result() - res.Body.Close() - - hdrs := make(map[string]string) - for k, vs := range res.Header { - for _, v := range vs { - hdrs[k] = v - } - } - return hdrs, nil -} diff --git a/authorize/session_test.go b/authorize/session_test.go index fbe5a75f5..725007e51 100644 --- a/authorize/session_test.go +++ b/authorize/session_test.go @@ -2,7 +2,6 @@ package authorize import ( "net/url" - "regexp" "testing" envoy_service_auth_v3 "github.com/envoyproxy/go-control-plane/envoy/service/auth/v3" @@ -45,31 +44,6 @@ func TestLoadSession(t *testing.T) { return &state, nil } - t.Run("cookie", func(t *testing.T) { - cookieStore, err := getCookieStore(opts, encoder) - if !assert.NoError(t, err) { - return - } - hdrs, err := getJWTSetCookieHeaders(cookieStore, rawjwt) - if !assert.NoError(t, err) { - return - } - cookie := regexp.MustCompile(`^([^;]+)(;.*)?$`).ReplaceAllString(hdrs["Set-Cookie"], "$1") - - hattrs := &envoy_service_auth_v3.AttributeContext_HttpRequest{ - Id: "req-1", - Method: "GET", - Headers: map[string]string{ - "Cookie": cookie, - }, - Path: "/hello/world", - Host: "example.com", - Scheme: "https", - } - sess, err := load(t, hattrs) - assert.NoError(t, err) - assert.NotNil(t, sess) - }) t.Run("header", func(t *testing.T) { hattrs := &envoy_service_auth_v3.AttributeContext_HttpRequest{ Id: "req-1", diff --git a/config/custom.go b/config/custom.go index 204adc61f..bb68e2469 100644 --- a/config/custom.go +++ b/config/custom.go @@ -448,34 +448,6 @@ func parseTo(raw interface{}) ([]WeightedURL, error) { return ParseWeightedUrls(slc...) } -func weightedStrings(src StringSlice) (endpoints StringSlice, weights []uint32, err error) { - weights = make([]uint32, len(src)) - endpoints = make([]string, len(src)) - - noWeight := false - hasWeight := false - for i, str := range src { - endpoints[i], weights[i], err = weightedString(str) - if err != nil { - return nil, nil, err - } - if weights[i] == 0 { - noWeight = true - } else { - hasWeight = true - } - } - - if noWeight == hasWeight { - return nil, nil, errEndpointWeightsSpec - } - - if noWeight { - return endpoints, nil, nil - } - return endpoints, weights, nil -} - // parses URL followed by weighted func weightedString(str string) (string, uint32, error) { i := strings.IndexRune(str, ',') diff --git a/config/custom_test.go b/config/custom_test.go index a85af5cd0..20acec625 100644 --- a/config/custom_test.go +++ b/config/custom_test.go @@ -3,7 +3,6 @@ package config import ( "encoding/base64" "encoding/json" - "fmt" "testing" "github.com/mitchellh/mapstructure" @@ -147,52 +146,6 @@ func TestSerializable(t *testing.T) { require.NoError(t, err, "json marshal") } -func TestWeightedStringSlice(t *testing.T) { - tcases := []struct { - In StringSlice - Out StringSlice - Weights []uint32 - Error bool - }{ - { - StringSlice{"https://srv-1.int.corp.com,1", "https://srv-2.int.corp.com,2", "http://10.0.1.1:8080,3", "http://localhost:8000,4"}, - StringSlice{"https://srv-1.int.corp.com", "https://srv-2.int.corp.com", "http://10.0.1.1:8080", "http://localhost:8000"}, - []uint32{1, 2, 3, 4}, - false, - }, - { // all should be provided - StringSlice{"https://srv-1.int.corp.com,1", "https://srv-2.int.corp.com", "http://10.0.1.1:8080,3", "http://localhost:8000,4"}, - nil, - nil, - true, - }, - { // or none - StringSlice{"https://srv-1.int.corp.com", "https://srv-2.int.corp.com", "http://10.0.1.1:8080", "http://localhost:8000"}, - StringSlice{"https://srv-1.int.corp.com", "https://srv-2.int.corp.com", "http://10.0.1.1:8080", "http://localhost:8000"}, - nil, - false, - }, - { // IPv6 https://tools.ietf.org/html/rfc2732 - StringSlice{"http://[::FFFF:129.144.52.38]:8080,1", "http://[FEDC:BA98:7654:3210:FEDC:BA98:7654:3210]:8080/,2"}, - StringSlice{"http://[::FFFF:129.144.52.38]:8080", "http://[FEDC:BA98:7654:3210:FEDC:BA98:7654:3210]:8080/"}, - []uint32{1, 2}, - false, - }, - } - - for _, tc := range tcases { - name := fmt.Sprintf("%s", tc.In) - out, weights, err := weightedStrings(tc.In) - if tc.Error { - assert.Error(t, err, name) - } else { - assert.NoError(t, err, name) - } - assert.Equal(t, tc.Out, out, name) - assert.Equal(t, tc.Weights, weights, name) - } -} - func TestDecodePPLPolicyHookFunc(t *testing.T) { var withPolicy struct { Policy *PPLPolicy `mapstructure:"policy"` diff --git a/internal/log/middleware.go b/internal/log/middleware.go index e2b3cabc9..c30168dc9 100644 --- a/internal/log/middleware.go +++ b/internal/log/middleware.go @@ -1,8 +1,6 @@ package log import ( - "crypto/rand" - "fmt" "net" "net/http" "time" @@ -26,48 +24,6 @@ func NewHandler(getLogger func() *zerolog.Logger) func(http.Handler) http.Handle } } -// URLHandler adds the requested URL as a field to the context's logger -// using fieldKey as field key. -func URLHandler(fieldKey string) func(next http.Handler) http.Handler { - return func(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - log := zerolog.Ctx(r.Context()) - log.UpdateContext(func(c zerolog.Context) zerolog.Context { - return c.Str(fieldKey, r.URL.String()) - }) - next.ServeHTTP(w, r) - }) - } -} - -// MethodHandler adds the request method as a field to the context's logger -// using fieldKey as field key. -func MethodHandler(fieldKey string) func(next http.Handler) http.Handler { - return func(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - log := zerolog.Ctx(r.Context()) - log.UpdateContext(func(c zerolog.Context) zerolog.Context { - return c.Str(fieldKey, r.Method) - }) - next.ServeHTTP(w, r) - }) - } -} - -// RequestHandler adds the request method and URL as a field to the context's logger -// using fieldKey as field key. -func RequestHandler(fieldKey string) func(next http.Handler) http.Handler { - return func(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - log := zerolog.Ctx(r.Context()) - log.UpdateContext(func(c zerolog.Context) zerolog.Context { - return c.Str(fieldKey, r.Method+" "+r.URL.String()) - }) - next.ServeHTTP(w, r) - }) - } -} - // RemoteAddrHandler adds the request's remote address as a field to the context's logger // using fieldKey as field key. func RemoteAddrHandler(fieldKey string) func(next http.Handler) http.Handler { @@ -165,12 +121,3 @@ func HeadersHandler(headers []string) func(next http.Handler) http.Handler { }) } } - -// uuid generates a random 128-bit non-RFC UUID. -func uuid() string { - buf := make([]byte, 16) - if _, err := rand.Read(buf); err != nil { - return "" - } - return fmt.Sprintf("%x-%x-%x-%x-%x", buf[0:4], buf[4:6], buf[6:8], buf[8:10], buf[10:]) -} diff --git a/internal/log/middleware_test.go b/internal/log/middleware_test.go index dff755b55..2a2dc1c6b 100644 --- a/internal/log/middleware_test.go +++ b/internal/log/middleware_test.go @@ -3,12 +3,9 @@ package log import ( "bytes" "fmt" - "io/ioutil" "net/http" "net/http/httptest" - "net/url" "reflect" - "regexp" "testing" "time" @@ -18,23 +15,6 @@ import ( "github.com/pomerium/pomerium/internal/telemetry/requestid" ) -func TestGenerateUUID(t *testing.T) { - prev := uuid() - for i := 0; i < 100; i++ { - id := uuid() - if id == "" { - t.Fatal("random pool failure") - } - if prev == id { - t.Fatalf("Should get a new ID!") - } - matched := regexp.MustCompile("[a-fA-F0-9]{8}-[a-fA-F0-9]{4}-[a-fA-F0-9]{4}-[a-fA-F0-9]{4}-[a-fA-F0-9]{12}").MatchString(id) - if !matched { - t.Fatalf("expected match %s %v", id, matched) - } - } -} - func decodeIfBinary(out fmt.Stringer) string { return out.String() } @@ -53,58 +33,6 @@ func TestNewHandler(t *testing.T) { h.ServeHTTP(nil, &http.Request{}) } -func TestURLHandler(t *testing.T) { - out := &bytes.Buffer{} - r := &http.Request{ - URL: &url.URL{Path: "/path", RawQuery: "foo=bar"}, - } - h := URLHandler("url")(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - l := FromRequest(r) - l.Log().Msg("") - })) - log := zerolog.New(out) - h = NewHandler(func() *zerolog.Logger { return &log })(h) - h.ServeHTTP(nil, r) - if want, got := `{"url":"/path?foo=bar"}`+"\n", decodeIfBinary(out); want != got { - t.Errorf("Invalid log output, got: %s, want: %s", got, want) - } -} - -func TestMethodHandler(t *testing.T) { - out := &bytes.Buffer{} - r := &http.Request{ - Method: "POST", - } - h := MethodHandler("method")(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - l := FromRequest(r) - l.Log().Msg("") - })) - log := zerolog.New(out) - h = NewHandler(func() *zerolog.Logger { return &log })(h) - h.ServeHTTP(nil, r) - if want, got := `{"method":"POST"}`+"\n", decodeIfBinary(out); want != got { - t.Errorf("Invalid log output, got: %s, want: %s", got, want) - } -} - -func TestRequestHandler(t *testing.T) { - out := &bytes.Buffer{} - r := &http.Request{ - Method: "POST", - URL: &url.URL{Path: "/path", RawQuery: "foo=bar"}, - } - h := RequestHandler("request")(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - l := FromRequest(r) - l.Log().Msg("") - })) - log := zerolog.New(out) - h = NewHandler(func() *zerolog.Logger { return &log })(h) - h.ServeHTTP(nil, r) - if want, got := `{"request":"POST /path?foo=bar"}`+"\n", decodeIfBinary(out); want != got { - t.Errorf("Invalid log output, got: %s, want: %s", got, want) - } -} - func TestRemoteAddrHandler(t *testing.T) { out := &bytes.Buffer{} r := &http.Request{ @@ -198,62 +126,6 @@ func TestRequestIDHandler(t *testing.T) { h.ServeHTTP(httptest.NewRecorder(), r) } -func TestCombinedHandlers(t *testing.T) { - out := &bytes.Buffer{} - r := &http.Request{ - Method: "POST", - URL: &url.URL{Path: "/path", RawQuery: "foo=bar"}, - } - h := MethodHandler("method")(RequestHandler("request")(URLHandler("url")(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - l := FromRequest(r) - l.Log().Msg("") - })))) - log := zerolog.New(out) - h = NewHandler(func() *zerolog.Logger { return &log })(h) - h.ServeHTTP(nil, r) - if want, got := `{"method":"POST","request":"POST /path?foo=bar","url":"/path?foo=bar"}`+"\n", decodeIfBinary(out); want != got { - t.Errorf("Invalid log output, got: %s, want: %s", got, want) - } -} - -func BenchmarkHandlers(b *testing.B) { - r := &http.Request{ - Method: "POST", - URL: &url.URL{Path: "/path", RawQuery: "foo=bar"}, - } - h1 := URLHandler("url")(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - l := FromRequest(r) - l.Log().Msg("") - })) - h2 := MethodHandler("method")(RequestHandler("request")(h1)) - handlers := map[string]http.Handler{ - "Single": NewHandler(func() *zerolog.Logger { - log := zerolog.New(ioutil.Discard) - return &log - })(h1), - "Combined": NewHandler((func() *zerolog.Logger { - log := zerolog.New(ioutil.Discard) - return &log - }))(h2), - "SingleDisabled": NewHandler((func() *zerolog.Logger { - log := zerolog.New(ioutil.Discard).Level(zerolog.Disabled) - return &log - }))(h1), - "CombinedDisabled": NewHandler((func() *zerolog.Logger { - log := zerolog.New(ioutil.Discard).Level(zerolog.Disabled) - return &log - }))(h2), - } - for name := range handlers { - h := handlers[name] - b.Run(name, func(b *testing.B) { - for i := 0; i < b.N; i++ { - h.ServeHTTP(nil, r) - } - }) - } -} - func BenchmarkDataRace(b *testing.B) { log := zerolog.New(nil).With(). Str("foo", "bar"). diff --git a/internal/middleware/middleware.go b/internal/middleware/middleware.go index 1de90f028..79454f6d5 100644 --- a/internal/middleware/middleware.go +++ b/internal/middleware/middleware.go @@ -4,8 +4,6 @@ import ( "crypto/sha256" "crypto/subtle" "net/http" - "strings" - "time" "github.com/pomerium/pomerium/internal/httputil" "github.com/pomerium/pomerium/internal/telemetry/trace" @@ -48,36 +46,6 @@ func ValidateRequestURL(r *http.Request, key []byte) error { return urlutil.NewSignedURL(key, urlutil.GetAbsoluteURL(r)).Validate() } -// StripCookie strips the cookie from the downstram request. -func StripCookie(cookieName string) func(next http.Handler) http.Handler { - return func(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - ctx, span := trace.StartSpan(r.Context(), "middleware.StripCookie") - defer span.End() - - headers := make([]string, 0, len(r.Cookies())) - for _, cookie := range r.Cookies() { - if !strings.HasPrefix(cookie.Name, cookieName) { - headers = append(headers, cookie.String()) - } - } - r.Header.Set("Cookie", strings.Join(headers, ";")) - next.ServeHTTP(w, r.WithContext(ctx)) - }) - } -} - -// TimeoutHandlerFunc wraps http.TimeoutHandler -func TimeoutHandlerFunc(timeout time.Duration, timeoutError string) func(next http.Handler) http.Handler { - return func(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - ctx, span := trace.StartSpan(r.Context(), "middleware.TimeoutHandlerFunc") - defer span.End() - http.TimeoutHandler(next, timeout, timeoutError).ServeHTTP(w, r.WithContext(ctx)) - }) - } -} - // RequireBasicAuth creates a new handler that requires basic auth from the client before // calling the underlying handler. func RequireBasicAuth(username, password string) func(next http.Handler) http.Handler { diff --git a/internal/middleware/middleware_test.go b/internal/middleware/middleware_test.go index fab373b90..26d80125b 100644 --- a/internal/middleware/middleware_test.go +++ b/internal/middleware/middleware_test.go @@ -6,7 +6,6 @@ import ( "net/http/httptest" "net/url" "testing" - "time" "github.com/google/go-cmp/cmp" @@ -41,81 +40,6 @@ func TestSetHeaders(t *testing.T) { } } -func TestStripCookie(t *testing.T) { - tests := []struct { - name string - pomeriumCookie string - otherCookies []string - }{ - {"good", "pomerium", []string{"x", "y", "z"}}, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - for _, cookie := range r.Cookies() { - if cookie.Name == tt.pomeriumCookie { - t.Errorf("cookie not stripped %s", r.Cookies()) - } - } - }) - rr := httptest.NewRecorder() - for _, cn := range tt.otherCookies { - http.SetCookie(rr, &http.Cookie{ - Name: cn, - Value: "some other cookie", - }) - } - - http.SetCookie(rr, &http.Cookie{ - Name: tt.pomeriumCookie, - Value: "pomerium cookie!", - }) - - http.SetCookie(rr, &http.Cookie{ - Name: tt.pomeriumCookie + "_csrf", - Value: "pomerium csrf cookie!", - }) - req := &http.Request{Header: http.Header{"Cookie": rr.Header()["Set-Cookie"]}} - - handler := StripCookie(tt.pomeriumCookie)(testHandler) - handler.ServeHTTP(rr, req) - }) - } -} - -func TestTimeoutHandlerFunc(t *testing.T) { - t.Parallel() - fn := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - fmt.Fprint(w, http.StatusText(http.StatusOK)) - w.WriteHeader(http.StatusOK) - }) - - tests := []struct { - name string - timeout time.Duration - timeoutError string - wantStatus int - wantBody string - }{ - {"good", 1 * time.Second, "good timed out!?", http.StatusOK, http.StatusText(http.StatusOK)}, - {"timeout!", 1 * time.Nanosecond, "ruh roh", http.StatusServiceUnavailable, "ruh roh"}, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - r := httptest.NewRequest(http.MethodGet, "/", nil) - w := httptest.NewRecorder() - got := TimeoutHandlerFunc(tt.timeout, tt.timeoutError)(fn) - got.ServeHTTP(w, r) - if status := w.Code; status != tt.wantStatus { - t.Errorf("SignRequest() error = %v, wantErr %v\n%v", w.Result().StatusCode, tt.wantStatus, w.Body.String()) - } - if body := w.Body.String(); tt.wantBody != body { - t.Errorf("SignRequest() body = %v, want %v", body, tt.wantBody) - } - }) - } -} - func TestValidateSignature(t *testing.T) { t.Parallel() fn := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -143,11 +67,11 @@ func TestValidateSignature(t *testing.T) { got := ValidateSignature(tt.secretA)(fn) got.ServeHTTP(w, r) if status := w.Code; status != tt.wantStatus { - t.Errorf("SignRequest() error = %v, wantErr %v\n%v", w.Result().StatusCode, tt.wantStatus, w.Body.String()) + t.Errorf("ValidateSignature() error = %v, wantErr %v\n%v", w.Result().StatusCode, tt.wantStatus, w.Body.String()) } body := w.Body.String() if diff := cmp.Diff(body, tt.wantBody); diff != "" { - t.Errorf("SignRequest() %s", diff) + t.Errorf("ValidateSignature() %s", diff) t.Errorf("%s", signedURL) } })