package httputil
import (
"crypto/tls"
"fmt"
"io/ioutil"
"log"
"net/http"
"net/http/httptest"
"os"
"os/signal"
"sync"
"syscall"
"testing"
"time"
"github.com/google/go-cmp/cmp"
)
// The tests in this file no longer call the "log" package; this blank
// reference keeps the existing import compiling until that import is removed.
var _ = log.Printf

// TestNewServer checks NewServer against a table of option sets. Each case
// asserts whether construction fails as expected. When it succeeds, the
// returned server's Handler is served through an httptest.Server and must
// answer 200 with the handler's greeting. Any returned server is then handed
// to Shutdown while a SIGINT is sent to this process.
func TestNewServer(t *testing.T) {
	// to support envs that won't let us use 443 without root.
	// This is written before t.Parallel, so it happens while tests still run
	// sequentially. It is never restored, so later tests also see ":0".
	defaultServerOptions.Addr = ":0"
	t.Parallel()

	// Every case serves the same handler; wantBody is exactly what it writes.
	const wantBody = "Hello, http\n"
	hello := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		fmt.Fprintln(w, "Hello, http")
	})

	tests := []struct {
		name        string
		opt         *ServerOptions
		httpHandler http.Handler
		wantErr     bool
	}{
		{
			name:        "good basic http handler",
			opt:         &ServerOptions{Addr: ":0", Insecure: true},
			httpHandler: hello,
		},
		{
			name:        "bad neither insecure nor certs set",
			opt:         &ServerOptions{Addr: ":0"},
			httpHandler: hello,
			wantErr:     true,
		},
		{
			name:        "good no address",
			opt:         &ServerOptions{Insecure: true},
			httpHandler: hello,
		},
		{
			// NOTE: despite its name, this case passes nil options, not a nil handler.
			name:        "empty handler",
			opt:         nil,
			httpHandler: hello,
			wantErr:     true,
		},
		{
			name:        "bad port - invalid port range ",
			opt:         &ServerOptions{Addr: ":65536", Insecure: true},
			httpHandler: hello,
			wantErr:     true,
		},
		{
			name:        "good tls set",
			opt:         &ServerOptions{TLSConfig: &tls.Config{}},
			httpHandler: hello,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			var wg sync.WaitGroup
			srv, err := NewServer(tt.opt, tt.httpHandler, &wg)
			if (err != nil) != tt.wantErr {
				t.Errorf("NewServer() error = %v, wantErr %v", err, tt.wantErr)
				return
			}
			if err == nil {
				// Serve the configured handler through httptest rather than
				// srv's own address, so no real listener or certificates are needed.
				ts := httptest.NewServer(srv.Handler)
				defer ts.Close()

				res, err := ts.Client().Get(ts.URL)
				if err != nil {
					// t.Fatalf (not log.Fatal) so deferred cleanup runs and the
					// failure is reported against this subtest.
					t.Fatalf("GET %s: %v", ts.URL, err)
				}
				body, err := ioutil.ReadAll(res.Body)
				res.Body.Close()
				if err != nil {
					t.Fatalf("reading response body: %v", err)
				}
				if res.StatusCode != http.StatusOK {
					t.Errorf("status = %d, want %d", res.StatusCode, http.StatusOK)
				}
				if got := string(body); got != wantBody {
					t.Errorf("body = %q, want %q", got, wantBody)
				}
			}
			if srv != nil {
				// Simulate an interrupt and clean up the server. Registering c
				// before sending the signal means SIGINT is relayed to c instead
				// of triggering the default action, which would kill the test binary.
				c := make(chan os.Signal, 1)
				signal.Notify(c, syscall.SIGINT)
				defer signal.Stop(c)
				go Shutdown(srv)
				// NOTE(review): presumably Shutdown waits for the signal itself;
				// if it registers after the Kill below, it may miss it and that
				// goroutine would not return. Confirm against Shutdown.
				if err := syscall.Kill(syscall.Getpid(), syscall.SIGINT); err != nil {
					t.Fatalf("sending SIGINT: %v", err)
				}
				waitSig(t, c, syscall.SIGINT)
			}
		})
	}
}
// waitSig waits up to one second for a signal on c. It fails the test if the
// first signal received is not sig, or if no signal arrives before the timeout.
func waitSig(t *testing.T, c <-chan os.Signal, sig os.Signal) {
	// Mark as a helper so failures are reported at the caller's line.
	t.Helper()
	select {
	case s := <-c:
		if s != sig {
			t.Fatalf("signal was %v, want %v", s, sig)
		}
	case <-time.After(1 * time.Second):
		t.Fatalf("timeout waiting for %v", sig)
	}
}
// TestRedirectHandler drives RedirectHandler with GET requests for plain-HTTP
// URLs (bare host, explicit port, and port plus path and query). It compares
// the recorded status code and body against each case's expectation.
func TestRedirectHandler(t *testing.T) {
	// Every case currently expects a 301 with this fixed body.
	const (
		moved     = http.StatusMovedPermanently
		movedBody = "Moved Permanently.\n\n"
	)
	cases := []struct {
		url        string
		wantStatus int
		wantBody   string
	}{
		{url: "http://example", wantStatus: moved, wantBody: movedBody},
		{url: "http://example:8080", wantStatus: moved, wantBody: movedBody},
		{url: "http://example:8080/some/path?x=y", wantStatus: moved, wantBody: movedBody},
	}
	for _, tc := range cases {
		t.Run(tc.url, func(t *testing.T) {
			rec := httptest.NewRecorder()
			req := httptest.NewRequest(http.MethodGet, tc.url, nil)
			RedirectHandler().ServeHTTP(rec, req)

			statusDiff := cmp.Diff(tc.wantStatus, rec.Code)
			if statusDiff != "" {
				t.Errorf("TestRedirectHandler() code diff :%s", statusDiff)
			}
			bodyDiff := cmp.Diff(tc.wantBody, rec.Body.String())
			if bodyDiff != "" {
				t.Errorf("TestRedirectHandler() body diff :%s", bodyDiff)
			}
		})
	}
}