mirror of
https://github.com/pomerium/pomerium.git
synced 2025-06-03 19:32:48 +02:00
all: remove unused handler code (#2439)
* - Remove unused middleware Signed-off-by: Bobby DeSimone <bobbydesimone@gmail.com> * remove unused func weightedStrings Signed-off-by: Bobby DeSimone <bobbydesimone@gmail.com> * remove unused func getJWTSetCookieHeaders Signed-off-by: Bobby DeSimone <bobbydesimone@gmail.com> * Fix test name
This commit is contained in:
parent
87c9ace12c
commit
87c3c675d2
8 changed files with 2 additions and 413 deletions
|
@ -2,9 +2,7 @@ package authorize
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
|
||||||
|
|
||||||
"github.com/pomerium/pomerium/config"
|
"github.com/pomerium/pomerium/config"
|
||||||
"github.com/pomerium/pomerium/internal/encoding"
|
"github.com/pomerium/pomerium/internal/encoding"
|
||||||
|
@ -64,22 +62,3 @@ func getCookieStore(options *config.Options, encoder encoding.MarshalUnmarshaler
|
||||||
}
|
}
|
||||||
return cookieStore, nil
|
return cookieStore, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func getJWTSetCookieHeaders(cookieStore sessions.SessionStore, rawjwt []byte) (map[string]string, error) {
|
|
||||||
recorder := httptest.NewRecorder()
|
|
||||||
err := cookieStore.SaveSession(recorder, nil /* unused by cookie store */, string(rawjwt))
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("authorize: error saving cookie: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
res := recorder.Result()
|
|
||||||
res.Body.Close()
|
|
||||||
|
|
||||||
hdrs := make(map[string]string)
|
|
||||||
for k, vs := range res.Header {
|
|
||||||
for _, v := range vs {
|
|
||||||
hdrs[k] = v
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return hdrs, nil
|
|
||||||
}
|
|
||||||
|
|
|
@ -2,7 +2,6 @@ package authorize
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"net/url"
|
"net/url"
|
||||||
"regexp"
|
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
envoy_service_auth_v3 "github.com/envoyproxy/go-control-plane/envoy/service/auth/v3"
|
envoy_service_auth_v3 "github.com/envoyproxy/go-control-plane/envoy/service/auth/v3"
|
||||||
|
@ -45,31 +44,6 @@ func TestLoadSession(t *testing.T) {
|
||||||
return &state, nil
|
return &state, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
t.Run("cookie", func(t *testing.T) {
|
|
||||||
cookieStore, err := getCookieStore(opts, encoder)
|
|
||||||
if !assert.NoError(t, err) {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
hdrs, err := getJWTSetCookieHeaders(cookieStore, rawjwt)
|
|
||||||
if !assert.NoError(t, err) {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
cookie := regexp.MustCompile(`^([^;]+)(;.*)?$`).ReplaceAllString(hdrs["Set-Cookie"], "$1")
|
|
||||||
|
|
||||||
hattrs := &envoy_service_auth_v3.AttributeContext_HttpRequest{
|
|
||||||
Id: "req-1",
|
|
||||||
Method: "GET",
|
|
||||||
Headers: map[string]string{
|
|
||||||
"Cookie": cookie,
|
|
||||||
},
|
|
||||||
Path: "/hello/world",
|
|
||||||
Host: "example.com",
|
|
||||||
Scheme: "https",
|
|
||||||
}
|
|
||||||
sess, err := load(t, hattrs)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
assert.NotNil(t, sess)
|
|
||||||
})
|
|
||||||
t.Run("header", func(t *testing.T) {
|
t.Run("header", func(t *testing.T) {
|
||||||
hattrs := &envoy_service_auth_v3.AttributeContext_HttpRequest{
|
hattrs := &envoy_service_auth_v3.AttributeContext_HttpRequest{
|
||||||
Id: "req-1",
|
Id: "req-1",
|
||||||
|
|
|
@ -448,34 +448,6 @@ func parseTo(raw interface{}) ([]WeightedURL, error) {
|
||||||
return ParseWeightedUrls(slc...)
|
return ParseWeightedUrls(slc...)
|
||||||
}
|
}
|
||||||
|
|
||||||
func weightedStrings(src StringSlice) (endpoints StringSlice, weights []uint32, err error) {
|
|
||||||
weights = make([]uint32, len(src))
|
|
||||||
endpoints = make([]string, len(src))
|
|
||||||
|
|
||||||
noWeight := false
|
|
||||||
hasWeight := false
|
|
||||||
for i, str := range src {
|
|
||||||
endpoints[i], weights[i], err = weightedString(str)
|
|
||||||
if err != nil {
|
|
||||||
return nil, nil, err
|
|
||||||
}
|
|
||||||
if weights[i] == 0 {
|
|
||||||
noWeight = true
|
|
||||||
} else {
|
|
||||||
hasWeight = true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if noWeight == hasWeight {
|
|
||||||
return nil, nil, errEndpointWeightsSpec
|
|
||||||
}
|
|
||||||
|
|
||||||
if noWeight {
|
|
||||||
return endpoints, nil, nil
|
|
||||||
}
|
|
||||||
return endpoints, weights, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// parses URL followed by weighted
|
// parses URL followed by weighted
|
||||||
func weightedString(str string) (string, uint32, error) {
|
func weightedString(str string) (string, uint32, error) {
|
||||||
i := strings.IndexRune(str, ',')
|
i := strings.IndexRune(str, ',')
|
||||||
|
|
|
@ -3,7 +3,6 @@ package config
|
||||||
import (
|
import (
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/mitchellh/mapstructure"
|
"github.com/mitchellh/mapstructure"
|
||||||
|
@ -147,52 +146,6 @@ func TestSerializable(t *testing.T) {
|
||||||
require.NoError(t, err, "json marshal")
|
require.NoError(t, err, "json marshal")
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestWeightedStringSlice(t *testing.T) {
|
|
||||||
tcases := []struct {
|
|
||||||
In StringSlice
|
|
||||||
Out StringSlice
|
|
||||||
Weights []uint32
|
|
||||||
Error bool
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
StringSlice{"https://srv-1.int.corp.com,1", "https://srv-2.int.corp.com,2", "http://10.0.1.1:8080,3", "http://localhost:8000,4"},
|
|
||||||
StringSlice{"https://srv-1.int.corp.com", "https://srv-2.int.corp.com", "http://10.0.1.1:8080", "http://localhost:8000"},
|
|
||||||
[]uint32{1, 2, 3, 4},
|
|
||||||
false,
|
|
||||||
},
|
|
||||||
{ // all should be provided
|
|
||||||
StringSlice{"https://srv-1.int.corp.com,1", "https://srv-2.int.corp.com", "http://10.0.1.1:8080,3", "http://localhost:8000,4"},
|
|
||||||
nil,
|
|
||||||
nil,
|
|
||||||
true,
|
|
||||||
},
|
|
||||||
{ // or none
|
|
||||||
StringSlice{"https://srv-1.int.corp.com", "https://srv-2.int.corp.com", "http://10.0.1.1:8080", "http://localhost:8000"},
|
|
||||||
StringSlice{"https://srv-1.int.corp.com", "https://srv-2.int.corp.com", "http://10.0.1.1:8080", "http://localhost:8000"},
|
|
||||||
nil,
|
|
||||||
false,
|
|
||||||
},
|
|
||||||
{ // IPv6 https://tools.ietf.org/html/rfc2732
|
|
||||||
StringSlice{"http://[::FFFF:129.144.52.38]:8080,1", "http://[FEDC:BA98:7654:3210:FEDC:BA98:7654:3210]:8080/,2"},
|
|
||||||
StringSlice{"http://[::FFFF:129.144.52.38]:8080", "http://[FEDC:BA98:7654:3210:FEDC:BA98:7654:3210]:8080/"},
|
|
||||||
[]uint32{1, 2},
|
|
||||||
false,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tc := range tcases {
|
|
||||||
name := fmt.Sprintf("%s", tc.In)
|
|
||||||
out, weights, err := weightedStrings(tc.In)
|
|
||||||
if tc.Error {
|
|
||||||
assert.Error(t, err, name)
|
|
||||||
} else {
|
|
||||||
assert.NoError(t, err, name)
|
|
||||||
}
|
|
||||||
assert.Equal(t, tc.Out, out, name)
|
|
||||||
assert.Equal(t, tc.Weights, weights, name)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestDecodePPLPolicyHookFunc(t *testing.T) {
|
func TestDecodePPLPolicyHookFunc(t *testing.T) {
|
||||||
var withPolicy struct {
|
var withPolicy struct {
|
||||||
Policy *PPLPolicy `mapstructure:"policy"`
|
Policy *PPLPolicy `mapstructure:"policy"`
|
||||||
|
|
|
@ -1,8 +1,6 @@
|
||||||
package log
|
package log
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/rand"
|
|
||||||
"fmt"
|
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"time"
|
"time"
|
||||||
|
@ -26,48 +24,6 @@ func NewHandler(getLogger func() *zerolog.Logger) func(http.Handler) http.Handle
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// URLHandler adds the requested URL as a field to the context's logger
|
|
||||||
// using fieldKey as field key.
|
|
||||||
func URLHandler(fieldKey string) func(next http.Handler) http.Handler {
|
|
||||||
return func(next http.Handler) http.Handler {
|
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
log := zerolog.Ctx(r.Context())
|
|
||||||
log.UpdateContext(func(c zerolog.Context) zerolog.Context {
|
|
||||||
return c.Str(fieldKey, r.URL.String())
|
|
||||||
})
|
|
||||||
next.ServeHTTP(w, r)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// MethodHandler adds the request method as a field to the context's logger
|
|
||||||
// using fieldKey as field key.
|
|
||||||
func MethodHandler(fieldKey string) func(next http.Handler) http.Handler {
|
|
||||||
return func(next http.Handler) http.Handler {
|
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
log := zerolog.Ctx(r.Context())
|
|
||||||
log.UpdateContext(func(c zerolog.Context) zerolog.Context {
|
|
||||||
return c.Str(fieldKey, r.Method)
|
|
||||||
})
|
|
||||||
next.ServeHTTP(w, r)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// RequestHandler adds the request method and URL as a field to the context's logger
|
|
||||||
// using fieldKey as field key.
|
|
||||||
func RequestHandler(fieldKey string) func(next http.Handler) http.Handler {
|
|
||||||
return func(next http.Handler) http.Handler {
|
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
log := zerolog.Ctx(r.Context())
|
|
||||||
log.UpdateContext(func(c zerolog.Context) zerolog.Context {
|
|
||||||
return c.Str(fieldKey, r.Method+" "+r.URL.String())
|
|
||||||
})
|
|
||||||
next.ServeHTTP(w, r)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// RemoteAddrHandler adds the request's remote address as a field to the context's logger
|
// RemoteAddrHandler adds the request's remote address as a field to the context's logger
|
||||||
// using fieldKey as field key.
|
// using fieldKey as field key.
|
||||||
func RemoteAddrHandler(fieldKey string) func(next http.Handler) http.Handler {
|
func RemoteAddrHandler(fieldKey string) func(next http.Handler) http.Handler {
|
||||||
|
@ -165,12 +121,3 @@ func HeadersHandler(headers []string) func(next http.Handler) http.Handler {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// uuid generates a random 128-bit non-RFC UUID.
|
|
||||||
func uuid() string {
|
|
||||||
buf := make([]byte, 16)
|
|
||||||
if _, err := rand.Read(buf); err != nil {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
return fmt.Sprintf("%x-%x-%x-%x-%x", buf[0:4], buf[4:6], buf[6:8], buf[8:10], buf[10:])
|
|
||||||
}
|
|
||||||
|
|
|
@ -3,12 +3,9 @@ package log
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io/ioutil"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"net/url"
|
|
||||||
"reflect"
|
"reflect"
|
||||||
"regexp"
|
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
@ -18,23 +15,6 @@ import (
|
||||||
"github.com/pomerium/pomerium/internal/telemetry/requestid"
|
"github.com/pomerium/pomerium/internal/telemetry/requestid"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestGenerateUUID(t *testing.T) {
|
|
||||||
prev := uuid()
|
|
||||||
for i := 0; i < 100; i++ {
|
|
||||||
id := uuid()
|
|
||||||
if id == "" {
|
|
||||||
t.Fatal("random pool failure")
|
|
||||||
}
|
|
||||||
if prev == id {
|
|
||||||
t.Fatalf("Should get a new ID!")
|
|
||||||
}
|
|
||||||
matched := regexp.MustCompile("[a-fA-F0-9]{8}-[a-fA-F0-9]{4}-[a-fA-F0-9]{4}-[a-fA-F0-9]{4}-[a-fA-F0-9]{12}").MatchString(id)
|
|
||||||
if !matched {
|
|
||||||
t.Fatalf("expected match %s %v", id, matched)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func decodeIfBinary(out fmt.Stringer) string {
|
func decodeIfBinary(out fmt.Stringer) string {
|
||||||
return out.String()
|
return out.String()
|
||||||
}
|
}
|
||||||
|
@ -53,58 +33,6 @@ func TestNewHandler(t *testing.T) {
|
||||||
h.ServeHTTP(nil, &http.Request{})
|
h.ServeHTTP(nil, &http.Request{})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestURLHandler(t *testing.T) {
|
|
||||||
out := &bytes.Buffer{}
|
|
||||||
r := &http.Request{
|
|
||||||
URL: &url.URL{Path: "/path", RawQuery: "foo=bar"},
|
|
||||||
}
|
|
||||||
h := URLHandler("url")(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
l := FromRequest(r)
|
|
||||||
l.Log().Msg("")
|
|
||||||
}))
|
|
||||||
log := zerolog.New(out)
|
|
||||||
h = NewHandler(func() *zerolog.Logger { return &log })(h)
|
|
||||||
h.ServeHTTP(nil, r)
|
|
||||||
if want, got := `{"url":"/path?foo=bar"}`+"\n", decodeIfBinary(out); want != got {
|
|
||||||
t.Errorf("Invalid log output, got: %s, want: %s", got, want)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestMethodHandler(t *testing.T) {
|
|
||||||
out := &bytes.Buffer{}
|
|
||||||
r := &http.Request{
|
|
||||||
Method: "POST",
|
|
||||||
}
|
|
||||||
h := MethodHandler("method")(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
l := FromRequest(r)
|
|
||||||
l.Log().Msg("")
|
|
||||||
}))
|
|
||||||
log := zerolog.New(out)
|
|
||||||
h = NewHandler(func() *zerolog.Logger { return &log })(h)
|
|
||||||
h.ServeHTTP(nil, r)
|
|
||||||
if want, got := `{"method":"POST"}`+"\n", decodeIfBinary(out); want != got {
|
|
||||||
t.Errorf("Invalid log output, got: %s, want: %s", got, want)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestRequestHandler(t *testing.T) {
|
|
||||||
out := &bytes.Buffer{}
|
|
||||||
r := &http.Request{
|
|
||||||
Method: "POST",
|
|
||||||
URL: &url.URL{Path: "/path", RawQuery: "foo=bar"},
|
|
||||||
}
|
|
||||||
h := RequestHandler("request")(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
l := FromRequest(r)
|
|
||||||
l.Log().Msg("")
|
|
||||||
}))
|
|
||||||
log := zerolog.New(out)
|
|
||||||
h = NewHandler(func() *zerolog.Logger { return &log })(h)
|
|
||||||
h.ServeHTTP(nil, r)
|
|
||||||
if want, got := `{"request":"POST /path?foo=bar"}`+"\n", decodeIfBinary(out); want != got {
|
|
||||||
t.Errorf("Invalid log output, got: %s, want: %s", got, want)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestRemoteAddrHandler(t *testing.T) {
|
func TestRemoteAddrHandler(t *testing.T) {
|
||||||
out := &bytes.Buffer{}
|
out := &bytes.Buffer{}
|
||||||
r := &http.Request{
|
r := &http.Request{
|
||||||
|
@ -198,62 +126,6 @@ func TestRequestIDHandler(t *testing.T) {
|
||||||
h.ServeHTTP(httptest.NewRecorder(), r)
|
h.ServeHTTP(httptest.NewRecorder(), r)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCombinedHandlers(t *testing.T) {
|
|
||||||
out := &bytes.Buffer{}
|
|
||||||
r := &http.Request{
|
|
||||||
Method: "POST",
|
|
||||||
URL: &url.URL{Path: "/path", RawQuery: "foo=bar"},
|
|
||||||
}
|
|
||||||
h := MethodHandler("method")(RequestHandler("request")(URLHandler("url")(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
l := FromRequest(r)
|
|
||||||
l.Log().Msg("")
|
|
||||||
}))))
|
|
||||||
log := zerolog.New(out)
|
|
||||||
h = NewHandler(func() *zerolog.Logger { return &log })(h)
|
|
||||||
h.ServeHTTP(nil, r)
|
|
||||||
if want, got := `{"method":"POST","request":"POST /path?foo=bar","url":"/path?foo=bar"}`+"\n", decodeIfBinary(out); want != got {
|
|
||||||
t.Errorf("Invalid log output, got: %s, want: %s", got, want)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func BenchmarkHandlers(b *testing.B) {
|
|
||||||
r := &http.Request{
|
|
||||||
Method: "POST",
|
|
||||||
URL: &url.URL{Path: "/path", RawQuery: "foo=bar"},
|
|
||||||
}
|
|
||||||
h1 := URLHandler("url")(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
l := FromRequest(r)
|
|
||||||
l.Log().Msg("")
|
|
||||||
}))
|
|
||||||
h2 := MethodHandler("method")(RequestHandler("request")(h1))
|
|
||||||
handlers := map[string]http.Handler{
|
|
||||||
"Single": NewHandler(func() *zerolog.Logger {
|
|
||||||
log := zerolog.New(ioutil.Discard)
|
|
||||||
return &log
|
|
||||||
})(h1),
|
|
||||||
"Combined": NewHandler((func() *zerolog.Logger {
|
|
||||||
log := zerolog.New(ioutil.Discard)
|
|
||||||
return &log
|
|
||||||
}))(h2),
|
|
||||||
"SingleDisabled": NewHandler((func() *zerolog.Logger {
|
|
||||||
log := zerolog.New(ioutil.Discard).Level(zerolog.Disabled)
|
|
||||||
return &log
|
|
||||||
}))(h1),
|
|
||||||
"CombinedDisabled": NewHandler((func() *zerolog.Logger {
|
|
||||||
log := zerolog.New(ioutil.Discard).Level(zerolog.Disabled)
|
|
||||||
return &log
|
|
||||||
}))(h2),
|
|
||||||
}
|
|
||||||
for name := range handlers {
|
|
||||||
h := handlers[name]
|
|
||||||
b.Run(name, func(b *testing.B) {
|
|
||||||
for i := 0; i < b.N; i++ {
|
|
||||||
h.ServeHTTP(nil, r)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func BenchmarkDataRace(b *testing.B) {
|
func BenchmarkDataRace(b *testing.B) {
|
||||||
log := zerolog.New(nil).With().
|
log := zerolog.New(nil).With().
|
||||||
Str("foo", "bar").
|
Str("foo", "bar").
|
||||||
|
|
|
@ -4,8 +4,6 @@ import (
|
||||||
"crypto/sha256"
|
"crypto/sha256"
|
||||||
"crypto/subtle"
|
"crypto/subtle"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/pomerium/pomerium/internal/httputil"
|
"github.com/pomerium/pomerium/internal/httputil"
|
||||||
"github.com/pomerium/pomerium/internal/telemetry/trace"
|
"github.com/pomerium/pomerium/internal/telemetry/trace"
|
||||||
|
@ -48,36 +46,6 @@ func ValidateRequestURL(r *http.Request, key []byte) error {
|
||||||
return urlutil.NewSignedURL(key, urlutil.GetAbsoluteURL(r)).Validate()
|
return urlutil.NewSignedURL(key, urlutil.GetAbsoluteURL(r)).Validate()
|
||||||
}
|
}
|
||||||
|
|
||||||
// StripCookie strips the cookie from the downstram request.
|
|
||||||
func StripCookie(cookieName string) func(next http.Handler) http.Handler {
|
|
||||||
return func(next http.Handler) http.Handler {
|
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
ctx, span := trace.StartSpan(r.Context(), "middleware.StripCookie")
|
|
||||||
defer span.End()
|
|
||||||
|
|
||||||
headers := make([]string, 0, len(r.Cookies()))
|
|
||||||
for _, cookie := range r.Cookies() {
|
|
||||||
if !strings.HasPrefix(cookie.Name, cookieName) {
|
|
||||||
headers = append(headers, cookie.String())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
r.Header.Set("Cookie", strings.Join(headers, ";"))
|
|
||||||
next.ServeHTTP(w, r.WithContext(ctx))
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// TimeoutHandlerFunc wraps http.TimeoutHandler
|
|
||||||
func TimeoutHandlerFunc(timeout time.Duration, timeoutError string) func(next http.Handler) http.Handler {
|
|
||||||
return func(next http.Handler) http.Handler {
|
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
ctx, span := trace.StartSpan(r.Context(), "middleware.TimeoutHandlerFunc")
|
|
||||||
defer span.End()
|
|
||||||
http.TimeoutHandler(next, timeout, timeoutError).ServeHTTP(w, r.WithContext(ctx))
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// RequireBasicAuth creates a new handler that requires basic auth from the client before
|
// RequireBasicAuth creates a new handler that requires basic auth from the client before
|
||||||
// calling the underlying handler.
|
// calling the underlying handler.
|
||||||
func RequireBasicAuth(username, password string) func(next http.Handler) http.Handler {
|
func RequireBasicAuth(username, password string) func(next http.Handler) http.Handler {
|
||||||
|
|
|
@ -6,7 +6,6 @@ import (
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"net/url"
|
"net/url"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/google/go-cmp/cmp"
|
"github.com/google/go-cmp/cmp"
|
||||||
|
|
||||||
|
@ -41,81 +40,6 @@ func TestSetHeaders(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestStripCookie(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
pomeriumCookie string
|
|
||||||
otherCookies []string
|
|
||||||
}{
|
|
||||||
{"good", "pomerium", []string{"x", "y", "z"}},
|
|
||||||
}
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
for _, cookie := range r.Cookies() {
|
|
||||||
if cookie.Name == tt.pomeriumCookie {
|
|
||||||
t.Errorf("cookie not stripped %s", r.Cookies())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
})
|
|
||||||
rr := httptest.NewRecorder()
|
|
||||||
for _, cn := range tt.otherCookies {
|
|
||||||
http.SetCookie(rr, &http.Cookie{
|
|
||||||
Name: cn,
|
|
||||||
Value: "some other cookie",
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
http.SetCookie(rr, &http.Cookie{
|
|
||||||
Name: tt.pomeriumCookie,
|
|
||||||
Value: "pomerium cookie!",
|
|
||||||
})
|
|
||||||
|
|
||||||
http.SetCookie(rr, &http.Cookie{
|
|
||||||
Name: tt.pomeriumCookie + "_csrf",
|
|
||||||
Value: "pomerium csrf cookie!",
|
|
||||||
})
|
|
||||||
req := &http.Request{Header: http.Header{"Cookie": rr.Header()["Set-Cookie"]}}
|
|
||||||
|
|
||||||
handler := StripCookie(tt.pomeriumCookie)(testHandler)
|
|
||||||
handler.ServeHTTP(rr, req)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestTimeoutHandlerFunc(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
fn := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
fmt.Fprint(w, http.StatusText(http.StatusOK))
|
|
||||||
w.WriteHeader(http.StatusOK)
|
|
||||||
})
|
|
||||||
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
timeout time.Duration
|
|
||||||
timeoutError string
|
|
||||||
wantStatus int
|
|
||||||
wantBody string
|
|
||||||
}{
|
|
||||||
{"good", 1 * time.Second, "good timed out!?", http.StatusOK, http.StatusText(http.StatusOK)},
|
|
||||||
{"timeout!", 1 * time.Nanosecond, "ruh roh", http.StatusServiceUnavailable, "ruh roh"},
|
|
||||||
}
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
r := httptest.NewRequest(http.MethodGet, "/", nil)
|
|
||||||
w := httptest.NewRecorder()
|
|
||||||
got := TimeoutHandlerFunc(tt.timeout, tt.timeoutError)(fn)
|
|
||||||
got.ServeHTTP(w, r)
|
|
||||||
if status := w.Code; status != tt.wantStatus {
|
|
||||||
t.Errorf("SignRequest() error = %v, wantErr %v\n%v", w.Result().StatusCode, tt.wantStatus, w.Body.String())
|
|
||||||
}
|
|
||||||
if body := w.Body.String(); tt.wantBody != body {
|
|
||||||
t.Errorf("SignRequest() body = %v, want %v", body, tt.wantBody)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestValidateSignature(t *testing.T) {
|
func TestValidateSignature(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
fn := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
fn := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
@ -143,11 +67,11 @@ func TestValidateSignature(t *testing.T) {
|
||||||
got := ValidateSignature(tt.secretA)(fn)
|
got := ValidateSignature(tt.secretA)(fn)
|
||||||
got.ServeHTTP(w, r)
|
got.ServeHTTP(w, r)
|
||||||
if status := w.Code; status != tt.wantStatus {
|
if status := w.Code; status != tt.wantStatus {
|
||||||
t.Errorf("SignRequest() error = %v, wantErr %v\n%v", w.Result().StatusCode, tt.wantStatus, w.Body.String())
|
t.Errorf("ValidateSignature() error = %v, wantErr %v\n%v", w.Result().StatusCode, tt.wantStatus, w.Body.String())
|
||||||
}
|
}
|
||||||
body := w.Body.String()
|
body := w.Body.String()
|
||||||
if diff := cmp.Diff(body, tt.wantBody); diff != "" {
|
if diff := cmp.Diff(body, tt.wantBody); diff != "" {
|
||||||
t.Errorf("SignRequest() %s", diff)
|
t.Errorf("ValidateSignature() %s", diff)
|
||||||
t.Errorf("%s", signedURL)
|
t.Errorf("%s", signedURL)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue