all: remove unused handler code (#2439)

* - Remove unused middleware

Signed-off-by: Bobby DeSimone <bobbydesimone@gmail.com>

* remove unused func weightedStrings

Signed-off-by: Bobby DeSimone <bobbydesimone@gmail.com>

* remove unused func getJWTSetCookieHeaders

Signed-off-by: Bobby DeSimone <bobbydesimone@gmail.com>

* Fix test name
This commit is contained in:
bobby 2021-08-16 13:04:39 -07:00 committed by GitHub
parent 87c9ace12c
commit 87c3c675d2
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 2 additions and 413 deletions

View file

@ -2,9 +2,7 @@ package authorize
import (
"errors"
"fmt"
"net/http"
"net/http/httptest"
"github.com/pomerium/pomerium/config"
"github.com/pomerium/pomerium/internal/encoding"
@ -64,22 +62,3 @@ func getCookieStore(options *config.Options, encoder encoding.MarshalUnmarshaler
}
return cookieStore, nil
}
func getJWTSetCookieHeaders(cookieStore sessions.SessionStore, rawjwt []byte) (map[string]string, error) {
recorder := httptest.NewRecorder()
err := cookieStore.SaveSession(recorder, nil /* unused by cookie store */, string(rawjwt))
if err != nil {
return nil, fmt.Errorf("authorize: error saving cookie: %w", err)
}
res := recorder.Result()
res.Body.Close()
hdrs := make(map[string]string)
for k, vs := range res.Header {
for _, v := range vs {
hdrs[k] = v
}
}
return hdrs, nil
}

View file

@ -2,7 +2,6 @@ package authorize
import (
"net/url"
"regexp"
"testing"
envoy_service_auth_v3 "github.com/envoyproxy/go-control-plane/envoy/service/auth/v3"
@ -45,31 +44,6 @@ func TestLoadSession(t *testing.T) {
return &state, nil
}
t.Run("cookie", func(t *testing.T) {
cookieStore, err := getCookieStore(opts, encoder)
if !assert.NoError(t, err) {
return
}
hdrs, err := getJWTSetCookieHeaders(cookieStore, rawjwt)
if !assert.NoError(t, err) {
return
}
cookie := regexp.MustCompile(`^([^;]+)(;.*)?$`).ReplaceAllString(hdrs["Set-Cookie"], "$1")
hattrs := &envoy_service_auth_v3.AttributeContext_HttpRequest{
Id: "req-1",
Method: "GET",
Headers: map[string]string{
"Cookie": cookie,
},
Path: "/hello/world",
Host: "example.com",
Scheme: "https",
}
sess, err := load(t, hattrs)
assert.NoError(t, err)
assert.NotNil(t, sess)
})
t.Run("header", func(t *testing.T) {
hattrs := &envoy_service_auth_v3.AttributeContext_HttpRequest{
Id: "req-1",

View file

@ -448,34 +448,6 @@ func parseTo(raw interface{}) ([]WeightedURL, error) {
return ParseWeightedUrls(slc...)
}
func weightedStrings(src StringSlice) (endpoints StringSlice, weights []uint32, err error) {
weights = make([]uint32, len(src))
endpoints = make([]string, len(src))
noWeight := false
hasWeight := false
for i, str := range src {
endpoints[i], weights[i], err = weightedString(str)
if err != nil {
return nil, nil, err
}
if weights[i] == 0 {
noWeight = true
} else {
hasWeight = true
}
}
if noWeight == hasWeight {
return nil, nil, errEndpointWeightsSpec
}
if noWeight {
return endpoints, nil, nil
}
return endpoints, weights, nil
}
// parses URL followed by weighted
func weightedString(str string) (string, uint32, error) {
i := strings.IndexRune(str, ',')

View file

@ -3,7 +3,6 @@ package config
import (
"encoding/base64"
"encoding/json"
"fmt"
"testing"
"github.com/mitchellh/mapstructure"
@ -147,52 +146,6 @@ func TestSerializable(t *testing.T) {
require.NoError(t, err, "json marshal")
}
func TestWeightedStringSlice(t *testing.T) {
tcases := []struct {
In StringSlice
Out StringSlice
Weights []uint32
Error bool
}{
{
StringSlice{"https://srv-1.int.corp.com,1", "https://srv-2.int.corp.com,2", "http://10.0.1.1:8080,3", "http://localhost:8000,4"},
StringSlice{"https://srv-1.int.corp.com", "https://srv-2.int.corp.com", "http://10.0.1.1:8080", "http://localhost:8000"},
[]uint32{1, 2, 3, 4},
false,
},
{ // all should be provided
StringSlice{"https://srv-1.int.corp.com,1", "https://srv-2.int.corp.com", "http://10.0.1.1:8080,3", "http://localhost:8000,4"},
nil,
nil,
true,
},
{ // or none
StringSlice{"https://srv-1.int.corp.com", "https://srv-2.int.corp.com", "http://10.0.1.1:8080", "http://localhost:8000"},
StringSlice{"https://srv-1.int.corp.com", "https://srv-2.int.corp.com", "http://10.0.1.1:8080", "http://localhost:8000"},
nil,
false,
},
{ // IPv6 https://tools.ietf.org/html/rfc2732
StringSlice{"http://[::FFFF:129.144.52.38]:8080,1", "http://[FEDC:BA98:7654:3210:FEDC:BA98:7654:3210]:8080/,2"},
StringSlice{"http://[::FFFF:129.144.52.38]:8080", "http://[FEDC:BA98:7654:3210:FEDC:BA98:7654:3210]:8080/"},
[]uint32{1, 2},
false,
},
}
for _, tc := range tcases {
name := fmt.Sprintf("%s", tc.In)
out, weights, err := weightedStrings(tc.In)
if tc.Error {
assert.Error(t, err, name)
} else {
assert.NoError(t, err, name)
}
assert.Equal(t, tc.Out, out, name)
assert.Equal(t, tc.Weights, weights, name)
}
}
func TestDecodePPLPolicyHookFunc(t *testing.T) {
var withPolicy struct {
Policy *PPLPolicy `mapstructure:"policy"`

View file

@ -1,8 +1,6 @@
package log
import (
"crypto/rand"
"fmt"
"net"
"net/http"
"time"
@ -26,48 +24,6 @@ func NewHandler(getLogger func() *zerolog.Logger) func(http.Handler) http.Handle
}
}
// URLHandler adds the requested URL as a field to the context's logger
// using fieldKey as field key.
func URLHandler(fieldKey string) func(next http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
log := zerolog.Ctx(r.Context())
log.UpdateContext(func(c zerolog.Context) zerolog.Context {
return c.Str(fieldKey, r.URL.String())
})
next.ServeHTTP(w, r)
})
}
}
// MethodHandler adds the request method as a field to the context's logger
// using fieldKey as field key.
func MethodHandler(fieldKey string) func(next http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
log := zerolog.Ctx(r.Context())
log.UpdateContext(func(c zerolog.Context) zerolog.Context {
return c.Str(fieldKey, r.Method)
})
next.ServeHTTP(w, r)
})
}
}
// RequestHandler adds the request method and URL as a field to the context's logger
// using fieldKey as field key.
func RequestHandler(fieldKey string) func(next http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
log := zerolog.Ctx(r.Context())
log.UpdateContext(func(c zerolog.Context) zerolog.Context {
return c.Str(fieldKey, r.Method+" "+r.URL.String())
})
next.ServeHTTP(w, r)
})
}
}
// RemoteAddrHandler adds the request's remote address as a field to the context's logger
// using fieldKey as field key.
func RemoteAddrHandler(fieldKey string) func(next http.Handler) http.Handler {
@ -165,12 +121,3 @@ func HeadersHandler(headers []string) func(next http.Handler) http.Handler {
})
}
}
// uuid generates a random 128-bit non-RFC UUID.
func uuid() string {
buf := make([]byte, 16)
if _, err := rand.Read(buf); err != nil {
return ""
}
return fmt.Sprintf("%x-%x-%x-%x-%x", buf[0:4], buf[4:6], buf[6:8], buf[8:10], buf[10:])
}

View file

@ -3,12 +3,9 @@ package log
import (
"bytes"
"fmt"
"io/ioutil"
"net/http"
"net/http/httptest"
"net/url"
"reflect"
"regexp"
"testing"
"time"
@ -18,23 +15,6 @@ import (
"github.com/pomerium/pomerium/internal/telemetry/requestid"
)
func TestGenerateUUID(t *testing.T) {
prev := uuid()
for i := 0; i < 100; i++ {
id := uuid()
if id == "" {
t.Fatal("random pool failure")
}
if prev == id {
t.Fatalf("Should get a new ID!")
}
matched := regexp.MustCompile("[a-fA-F0-9]{8}-[a-fA-F0-9]{4}-[a-fA-F0-9]{4}-[a-fA-F0-9]{4}-[a-fA-F0-9]{12}").MatchString(id)
if !matched {
t.Fatalf("expected match %s %v", id, matched)
}
}
}
func decodeIfBinary(out fmt.Stringer) string {
return out.String()
}
@ -53,58 +33,6 @@ func TestNewHandler(t *testing.T) {
h.ServeHTTP(nil, &http.Request{})
}
func TestURLHandler(t *testing.T) {
out := &bytes.Buffer{}
r := &http.Request{
URL: &url.URL{Path: "/path", RawQuery: "foo=bar"},
}
h := URLHandler("url")(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
l := FromRequest(r)
l.Log().Msg("")
}))
log := zerolog.New(out)
h = NewHandler(func() *zerolog.Logger { return &log })(h)
h.ServeHTTP(nil, r)
if want, got := `{"url":"/path?foo=bar"}`+"\n", decodeIfBinary(out); want != got {
t.Errorf("Invalid log output, got: %s, want: %s", got, want)
}
}
func TestMethodHandler(t *testing.T) {
out := &bytes.Buffer{}
r := &http.Request{
Method: "POST",
}
h := MethodHandler("method")(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
l := FromRequest(r)
l.Log().Msg("")
}))
log := zerolog.New(out)
h = NewHandler(func() *zerolog.Logger { return &log })(h)
h.ServeHTTP(nil, r)
if want, got := `{"method":"POST"}`+"\n", decodeIfBinary(out); want != got {
t.Errorf("Invalid log output, got: %s, want: %s", got, want)
}
}
func TestRequestHandler(t *testing.T) {
out := &bytes.Buffer{}
r := &http.Request{
Method: "POST",
URL: &url.URL{Path: "/path", RawQuery: "foo=bar"},
}
h := RequestHandler("request")(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
l := FromRequest(r)
l.Log().Msg("")
}))
log := zerolog.New(out)
h = NewHandler(func() *zerolog.Logger { return &log })(h)
h.ServeHTTP(nil, r)
if want, got := `{"request":"POST /path?foo=bar"}`+"\n", decodeIfBinary(out); want != got {
t.Errorf("Invalid log output, got: %s, want: %s", got, want)
}
}
func TestRemoteAddrHandler(t *testing.T) {
out := &bytes.Buffer{}
r := &http.Request{
@ -198,62 +126,6 @@ func TestRequestIDHandler(t *testing.T) {
h.ServeHTTP(httptest.NewRecorder(), r)
}
func TestCombinedHandlers(t *testing.T) {
out := &bytes.Buffer{}
r := &http.Request{
Method: "POST",
URL: &url.URL{Path: "/path", RawQuery: "foo=bar"},
}
h := MethodHandler("method")(RequestHandler("request")(URLHandler("url")(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
l := FromRequest(r)
l.Log().Msg("")
}))))
log := zerolog.New(out)
h = NewHandler(func() *zerolog.Logger { return &log })(h)
h.ServeHTTP(nil, r)
if want, got := `{"method":"POST","request":"POST /path?foo=bar","url":"/path?foo=bar"}`+"\n", decodeIfBinary(out); want != got {
t.Errorf("Invalid log output, got: %s, want: %s", got, want)
}
}
func BenchmarkHandlers(b *testing.B) {
r := &http.Request{
Method: "POST",
URL: &url.URL{Path: "/path", RawQuery: "foo=bar"},
}
h1 := URLHandler("url")(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
l := FromRequest(r)
l.Log().Msg("")
}))
h2 := MethodHandler("method")(RequestHandler("request")(h1))
handlers := map[string]http.Handler{
"Single": NewHandler(func() *zerolog.Logger {
log := zerolog.New(ioutil.Discard)
return &log
})(h1),
"Combined": NewHandler((func() *zerolog.Logger {
log := zerolog.New(ioutil.Discard)
return &log
}))(h2),
"SingleDisabled": NewHandler((func() *zerolog.Logger {
log := zerolog.New(ioutil.Discard).Level(zerolog.Disabled)
return &log
}))(h1),
"CombinedDisabled": NewHandler((func() *zerolog.Logger {
log := zerolog.New(ioutil.Discard).Level(zerolog.Disabled)
return &log
}))(h2),
}
for name := range handlers {
h := handlers[name]
b.Run(name, func(b *testing.B) {
for i := 0; i < b.N; i++ {
h.ServeHTTP(nil, r)
}
})
}
}
func BenchmarkDataRace(b *testing.B) {
log := zerolog.New(nil).With().
Str("foo", "bar").

View file

@ -4,8 +4,6 @@ import (
"crypto/sha256"
"crypto/subtle"
"net/http"
"strings"
"time"
"github.com/pomerium/pomerium/internal/httputil"
"github.com/pomerium/pomerium/internal/telemetry/trace"
@ -48,36 +46,6 @@ func ValidateRequestURL(r *http.Request, key []byte) error {
return urlutil.NewSignedURL(key, urlutil.GetAbsoluteURL(r)).Validate()
}
// StripCookie strips the cookie from the downstram request.
func StripCookie(cookieName string) func(next http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx, span := trace.StartSpan(r.Context(), "middleware.StripCookie")
defer span.End()
headers := make([]string, 0, len(r.Cookies()))
for _, cookie := range r.Cookies() {
if !strings.HasPrefix(cookie.Name, cookieName) {
headers = append(headers, cookie.String())
}
}
r.Header.Set("Cookie", strings.Join(headers, ";"))
next.ServeHTTP(w, r.WithContext(ctx))
})
}
}
// TimeoutHandlerFunc wraps http.TimeoutHandler
func TimeoutHandlerFunc(timeout time.Duration, timeoutError string) func(next http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx, span := trace.StartSpan(r.Context(), "middleware.TimeoutHandlerFunc")
defer span.End()
http.TimeoutHandler(next, timeout, timeoutError).ServeHTTP(w, r.WithContext(ctx))
})
}
}
// RequireBasicAuth creates a new handler that requires basic auth from the client before
// calling the underlying handler.
func RequireBasicAuth(username, password string) func(next http.Handler) http.Handler {

View file

@ -6,7 +6,6 @@ import (
"net/http/httptest"
"net/url"
"testing"
"time"
"github.com/google/go-cmp/cmp"
@ -41,81 +40,6 @@ func TestSetHeaders(t *testing.T) {
}
}
func TestStripCookie(t *testing.T) {
tests := []struct {
name string
pomeriumCookie string
otherCookies []string
}{
{"good", "pomerium", []string{"x", "y", "z"}},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
for _, cookie := range r.Cookies() {
if cookie.Name == tt.pomeriumCookie {
t.Errorf("cookie not stripped %s", r.Cookies())
}
}
})
rr := httptest.NewRecorder()
for _, cn := range tt.otherCookies {
http.SetCookie(rr, &http.Cookie{
Name: cn,
Value: "some other cookie",
})
}
http.SetCookie(rr, &http.Cookie{
Name: tt.pomeriumCookie,
Value: "pomerium cookie!",
})
http.SetCookie(rr, &http.Cookie{
Name: tt.pomeriumCookie + "_csrf",
Value: "pomerium csrf cookie!",
})
req := &http.Request{Header: http.Header{"Cookie": rr.Header()["Set-Cookie"]}}
handler := StripCookie(tt.pomeriumCookie)(testHandler)
handler.ServeHTTP(rr, req)
})
}
}
func TestTimeoutHandlerFunc(t *testing.T) {
t.Parallel()
fn := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprint(w, http.StatusText(http.StatusOK))
w.WriteHeader(http.StatusOK)
})
tests := []struct {
name string
timeout time.Duration
timeoutError string
wantStatus int
wantBody string
}{
{"good", 1 * time.Second, "good timed out!?", http.StatusOK, http.StatusText(http.StatusOK)},
{"timeout!", 1 * time.Nanosecond, "ruh roh", http.StatusServiceUnavailable, "ruh roh"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
r := httptest.NewRequest(http.MethodGet, "/", nil)
w := httptest.NewRecorder()
got := TimeoutHandlerFunc(tt.timeout, tt.timeoutError)(fn)
got.ServeHTTP(w, r)
if status := w.Code; status != tt.wantStatus {
t.Errorf("SignRequest() error = %v, wantErr %v\n%v", w.Result().StatusCode, tt.wantStatus, w.Body.String())
}
if body := w.Body.String(); tt.wantBody != body {
t.Errorf("SignRequest() body = %v, want %v", body, tt.wantBody)
}
})
}
}
func TestValidateSignature(t *testing.T) {
t.Parallel()
fn := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@ -143,11 +67,11 @@ func TestValidateSignature(t *testing.T) {
got := ValidateSignature(tt.secretA)(fn)
got.ServeHTTP(w, r)
if status := w.Code; status != tt.wantStatus {
t.Errorf("SignRequest() error = %v, wantErr %v\n%v", w.Result().StatusCode, tt.wantStatus, w.Body.String())
t.Errorf("ValidateSignature() error = %v, wantErr %v\n%v", w.Result().StatusCode, tt.wantStatus, w.Body.String())
}
body := w.Body.String()
if diff := cmp.Diff(body, tt.wantBody); diff != "" {
t.Errorf("SignRequest() %s", diff)
t.Errorf("ValidateSignature() %s", diff)
t.Errorf("%s", signedURL)
}
})