package httputil import ( "crypto/tls" "fmt" "io/ioutil" "log" "net/http" "net/http/httptest" "os" "os/signal" "sync" "syscall" "testing" "time" "github.com/google/go-cmp/cmp" ) func TestNewServer(t *testing.T) { // to support envs that won't let us use 443 without root defaultServerOptions.Addr = ":0" t.Parallel() tests := []struct { name string opt *ServerOptions httpHandler http.Handler // want *http.Server wantErr bool }{ { "good basic http handler", &ServerOptions{ Addr: ":0", Insecure: true, }, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { fmt.Fprintln(w, "Hello, http") }), false, }, { "bad neither insecure nor certs set", &ServerOptions{ Addr: ":0", }, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { fmt.Fprintln(w, "Hello, http") }), true, }, { "good no address", &ServerOptions{ Insecure: true, }, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { fmt.Fprintln(w, "Hello, http") }), false, }, { "empty handler", nil, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { fmt.Fprintln(w, "Hello, http") }), true, }, { "bad port - invalid port range ", &ServerOptions{ Addr: ":65536", Insecure: true, }, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { fmt.Fprintln(w, "Hello, http") }), true, }, { "good tls set", &ServerOptions{ TLSConfig: &tls.Config{}, }, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { fmt.Fprintln(w, "Hello, http") }), false, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { var wg sync.WaitGroup srv, err := NewServer(tt.opt, tt.httpHandler, &wg) if (err != nil) != tt.wantErr { t.Errorf("NewServer() error = %v, wantErr %v", err, tt.wantErr) return } if err == nil { // we cheat a little bit here and use the httptest server to test the client ts := httptest.NewServer(srv.Handler) defer ts.Close() client := ts.Client() res, err := client.Get(ts.URL) if err != nil { log.Fatal(err) } greeting, err := ioutil.ReadAll(res.Body) res.Body.Close() if err != nil { log.Fatal(err) } fmt.Printf("%s", greeting) } if srv != nil { // simulate a sigterm and cleanup the server c := make(chan os.Signal, 1) signal.Notify(c, syscall.SIGINT) defer signal.Stop(c) go Shutdown(srv) syscall.Kill(syscall.Getpid(), syscall.SIGINT) waitSig(t, c, syscall.SIGINT) } }) } } func waitSig(t *testing.T, c <-chan os.Signal, sig os.Signal) { select { case s := <-c: if s != sig { t.Fatalf("signal was %v, want %v", s, sig) } case <-time.After(1 * time.Second): t.Fatalf("timeout waiting for %v", sig) } } func TestRedirectHandler(t *testing.T) { tests := []struct { url string wantStatus int wantBody string }{ {"http://example", http.StatusMovedPermanently, "Moved Permanently.\n\n"}, {"http://example:8080", http.StatusMovedPermanently, "Moved Permanently.\n\n"}, {"http://example:8080/some/path?x=y", http.StatusMovedPermanently, "Moved Permanently.\n\n"}, } for _, tt := range tests { t.Run(tt.url, func(t *testing.T) { req := httptest.NewRequest(http.MethodGet, tt.url, nil) rr := httptest.NewRecorder() RedirectHandler().ServeHTTP(rr, req) if diff := cmp.Diff(tt.wantStatus, rr.Code); diff != "" { t.Errorf("TestRedirectHandler() code diff :%s", diff) } if diff := cmp.Diff(tt.wantBody, rr.Body.String()); diff != "" { t.Errorf("TestRedirectHandler() body diff :%s", diff) } }) } }