mirror of
https://github.com/pomerium/pomerium.git
synced 2025-04-29 18:36:30 +02:00
- gofumpt everything - fix TLS MinVersion to be at least 1.2 - add octal syntax - remove newlines - fix potential decompression bomb in ecjson - remove implicit memory aliasing in for loops. Signed-off-by: Bobby DeSimone <bobbydesimone@gmail.com>
165 lines
3.8 KiB
Go
165 lines
3.8 KiB
Go
package httputil
|
|
|
|
import (
|
|
"crypto/tls"
|
|
"fmt"
|
|
"io/ioutil"
|
|
"log"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"os"
|
|
"os/signal"
|
|
"sync"
|
|
"syscall"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/google/go-cmp/cmp"
|
|
)
|
|
|
|
func TestNewServer(t *testing.T) {
|
|
// to support envs that won't let us use 443 without root
|
|
defaultServerOptions.Addr = ":0"
|
|
|
|
t.Parallel()
|
|
tests := []struct {
|
|
name string
|
|
opt *ServerOptions
|
|
httpHandler http.Handler
|
|
// want *http.Server
|
|
wantErr bool
|
|
}{
|
|
|
|
{
|
|
"good basic http handler",
|
|
&ServerOptions{
|
|
Addr: ":0",
|
|
Insecure: true,
|
|
},
|
|
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
fmt.Fprintln(w, "Hello, http")
|
|
}),
|
|
false,
|
|
},
|
|
{
|
|
"bad neither insecure nor certs set",
|
|
&ServerOptions{
|
|
Addr: ":0",
|
|
},
|
|
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
fmt.Fprintln(w, "Hello, http")
|
|
}),
|
|
true,
|
|
},
|
|
{
|
|
"good no address",
|
|
&ServerOptions{
|
|
Insecure: true,
|
|
},
|
|
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
fmt.Fprintln(w, "Hello, http")
|
|
}),
|
|
false,
|
|
},
|
|
{
|
|
"empty handler",
|
|
nil,
|
|
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
fmt.Fprintln(w, "Hello, http")
|
|
}),
|
|
true,
|
|
},
|
|
{
|
|
"bad port - invalid port range ",
|
|
&ServerOptions{
|
|
Addr: ":65536",
|
|
Insecure: true,
|
|
}, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
fmt.Fprintln(w, "Hello, http")
|
|
}),
|
|
true,
|
|
},
|
|
{
|
|
"good tls set",
|
|
&ServerOptions{
|
|
TLSConfig: &tls.Config{},
|
|
},
|
|
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
fmt.Fprintln(w, "Hello, http")
|
|
}),
|
|
false,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
var wg sync.WaitGroup
|
|
srv, err := NewServer(tt.opt, tt.httpHandler, &wg)
|
|
if (err != nil) != tt.wantErr {
|
|
t.Errorf("NewServer() error = %v, wantErr %v", err, tt.wantErr)
|
|
return
|
|
}
|
|
if err == nil {
|
|
// we cheat a little bit here and use the httptest server to test the client
|
|
ts := httptest.NewServer(srv.Handler)
|
|
defer ts.Close()
|
|
client := ts.Client()
|
|
res, err := client.Get(ts.URL)
|
|
if err != nil {
|
|
log.Fatal(err)
|
|
}
|
|
greeting, err := ioutil.ReadAll(res.Body)
|
|
res.Body.Close()
|
|
if err != nil {
|
|
log.Fatal(err)
|
|
}
|
|
fmt.Printf("%s", greeting)
|
|
}
|
|
if srv != nil {
|
|
// simulate a sigterm and cleanup the server
|
|
c := make(chan os.Signal, 1)
|
|
signal.Notify(c, syscall.SIGINT)
|
|
defer signal.Stop(c)
|
|
go Shutdown(srv)
|
|
syscall.Kill(syscall.Getpid(), syscall.SIGINT)
|
|
waitSig(t, c, syscall.SIGINT)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// waitSig waits for a signal to arrive on c, for at most one second.
//
// It fails the test immediately (t.Fatalf) when the received signal differs
// from sig, or when nothing is received before the timeout. It returns
// normally only if the first signal read from c equals sig.
func waitSig(t *testing.T, c <-chan os.Signal, sig os.Signal) {
	// Mark this function as a helper so failures are attributed to the
	// calling test line instead of a line inside waitSig.
	t.Helper()

	select {
	case s := <-c:
		if s != sig {
			t.Fatalf("signal was %v, want %v", s, sig)
		}
	case <-time.After(1 * time.Second):
		t.Fatalf("timeout waiting for %v", sig)
	}
}
|
|
|
|
func TestRedirectHandler(t *testing.T) {
|
|
tests := []struct {
|
|
url string
|
|
wantStatus int
|
|
wantBody string
|
|
}{
|
|
{"http://example", http.StatusMovedPermanently, "<a href=\"https://example\">Moved Permanently</a>.\n\n"},
|
|
{"http://example:8080", http.StatusMovedPermanently, "<a href=\"https://example\">Moved Permanently</a>.\n\n"},
|
|
{"http://example:8080/some/path?x=y", http.StatusMovedPermanently, "<a href=\"https://example/some/path?x=y\">Moved Permanently</a>.\n\n"},
|
|
}
|
|
for _, tt := range tests {
|
|
t.Run(tt.url, func(t *testing.T) {
|
|
req := httptest.NewRequest(http.MethodGet, tt.url, nil)
|
|
rr := httptest.NewRecorder()
|
|
RedirectHandler().ServeHTTP(rr, req)
|
|
if diff := cmp.Diff(tt.wantStatus, rr.Code); diff != "" {
|
|
t.Errorf("TestRedirectHandler() code diff :%s", diff)
|
|
}
|
|
if diff := cmp.Diff(tt.wantBody, rr.Body.String()); diff != "" {
|
|
t.Errorf("TestRedirectHandler() body diff :%s", diff)
|
|
}
|
|
})
|
|
}
|
|
}
|