mirror of
https://github.com/pomerium/pomerium.git
synced 2025-07-24 20:18:13 +02:00
telemetry: add tracing
- telemetry/tace: add traces throughout code - telemetry/metrics: nest metrics and trace under telemetry - telemetry/tace: add service name span to HTTPMetricsHandler. - telemetry/metrics: removed chain dependency middleware_tests. - telemetry/metrics: wrap and encapsulate variatic view registration. - telemetry/tace: add jaeger support for tracing. - cmd/pomerium: move `parseOptions` to internal/config. - cmd/pomerium: offload server handling to httputil and sub pkgs. - httputil: standardize creation/shutdown of http listeners. - httputil: prefer curve X25519 to P256 when negotiating TLS. - fileutil: use standardized Getw Signed-off-by: Bobby DeSimone <bobbydesimone@gmail.com>
This commit is contained in:
parent
6b61a48fce
commit
5edfa7b03f
49 changed files with 1524 additions and 758 deletions
|
@ -23,21 +23,11 @@ func (h Error) Error() string {
|
|||
return fmt.Sprintf("%d %s: %s", h.Code, http.StatusText(h.Code), h.Message)
|
||||
}
|
||||
|
||||
// CodeForError maps an error type and returns a corresponding http.Status
|
||||
func CodeForError(err error) int {
|
||||
switch err {
|
||||
case ErrTokenRevoked:
|
||||
return http.StatusUnauthorized
|
||||
}
|
||||
return http.StatusInternalServerError
|
||||
}
|
||||
|
||||
// ErrorResponse renders an error page for errors given a message and a status code.
|
||||
// If no message is passed, defaults to the text of the status code.
|
||||
func ErrorResponse(rw http.ResponseWriter, r *http.Request, e *Error) {
|
||||
requestID := ""
|
||||
id, ok := log.IDFromRequest(r)
|
||||
if ok {
|
||||
var requestID string
|
||||
if id, ok := log.IDFromRequest(r); ok {
|
||||
requestID = id
|
||||
}
|
||||
if r.Header.Get("Accept") == "application/json" {
|
||||
|
|
49
internal/httputil/errors_test.go
Normal file
49
internal/httputil/errors_test.go
Normal file
|
@ -0,0 +1,49 @@
|
|||
package httputil
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestErrorResponse(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
rw http.ResponseWriter
|
||||
r *http.Request
|
||||
e *Error
|
||||
}{
|
||||
{"good", httptest.NewRecorder(), &http.Request{Method: http.MethodGet}, &Error{Code: http.StatusBadRequest, Message: "missing id token"}},
|
||||
{"good json", httptest.NewRecorder(), &http.Request{Method: http.MethodGet, Header: http.Header{"Accept": []string{"application/json"}}}, &Error{Code: http.StatusBadRequest, Message: "missing id token"}},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ErrorResponse(tt.rw, tt.r, tt.e)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestError_Error(t *testing.T) {
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
Message string
|
||||
Code int
|
||||
CanDebug bool
|
||||
want string
|
||||
}{
|
||||
{"good", "short and stout", http.StatusTeapot, false, "418 I'm a teapot: short and stout"},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
h := Error{
|
||||
Message: tt.Message,
|
||||
Code: tt.Code,
|
||||
CanDebug: tt.CanDebug,
|
||||
}
|
||||
if got := h.Error(); got != tt.want {
|
||||
t.Errorf("Error.Error() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
76
internal/httputil/http.go
Normal file
76
internal/httputil/http.go
Normal file
|
@ -0,0 +1,76 @@
|
|||
package httputil // import "github.com/pomerium/pomerium/internal/httputil"
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
stdlog "log"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/log"
|
||||
"github.com/pomerium/pomerium/internal/urlutil"
|
||||
)
|
||||
|
||||
// NewHTTPServer starts a http server given a set of options and a handler.
|
||||
//
|
||||
// It is the caller's responsibility to Close() or Shutdown() the returned
|
||||
// server.
|
||||
func NewHTTPServer(opt *ServerOptions, h http.Handler) *http.Server {
|
||||
if opt == nil {
|
||||
opt = defaultHTTPServerOptions
|
||||
} else {
|
||||
opt.applyHTTPDefaults()
|
||||
}
|
||||
sublogger := log.With().Str("addr", opt.Addr).Logger()
|
||||
srv := http.Server{
|
||||
Addr: opt.Addr,
|
||||
ReadHeaderTimeout: opt.ReadHeaderTimeout,
|
||||
ReadTimeout: opt.ReadTimeout,
|
||||
WriteTimeout: opt.WriteTimeout,
|
||||
IdleTimeout: opt.IdleTimeout,
|
||||
Handler: h,
|
||||
ErrorLog: stdlog.New(&log.StdLogWrapper{Logger: &sublogger}, "", 0),
|
||||
}
|
||||
|
||||
go func() {
|
||||
if err := srv.ListenAndServe(); err != http.ErrServerClosed {
|
||||
log.Error().Str("addr", opt.Addr).Err(err).Msg("internal/httputil: unexpected shutdown")
|
||||
}
|
||||
}()
|
||||
return &srv
|
||||
}
|
||||
|
||||
func RedirectHandler() http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Connection", "close")
|
||||
url := fmt.Sprintf("https://%s%s", urlutil.StripPort(r.Host), r.URL.String())
|
||||
http.Redirect(w, r, url, http.StatusMovedPermanently)
|
||||
})
|
||||
}
|
||||
|
||||
// Shutdown attempts to shut down the server when a os interrupt or sigterm
|
||||
// signal are received without interrupting any
|
||||
// active connections. Shutdown works by first closing all open
|
||||
// listeners, then closing all idle connections, and then waiting
|
||||
// indefinitely for connections to return to idle and then shut down.
|
||||
// If the provided context expires before the shutdown is complete,
|
||||
// Shutdown returns the context's error, otherwise it returns any
|
||||
// error returned from closing the Server's underlying Listener(s).
|
||||
//
|
||||
// When Shutdown is called, Serve, ListenAndServe, and
|
||||
// ListenAndServeTLS immediately return ErrServerClosed.
|
||||
func Shutdown(srv *http.Server) {
|
||||
sigint := make(chan os.Signal, 1)
|
||||
signal.Notify(sigint, os.Interrupt)
|
||||
signal.Notify(sigint, syscall.SIGTERM)
|
||||
rec := <-sigint
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
log.Info().Str("signal", rec.String()).Msg("internal/httputil: shutting down servers")
|
||||
if err := srv.Shutdown(ctx); err != nil {
|
||||
log.Error().Err(err).Msg("internal/httputil: shutdown failed")
|
||||
}
|
||||
}
|
49
internal/httputil/http_test.go
Normal file
49
internal/httputil/http_test.go
Normal file
|
@ -0,0 +1,49 @@
|
|||
package httputil
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestNewHTTPServer(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
opts *ServerOptions
|
||||
// wantErr bool
|
||||
}{
|
||||
{"localhost:9232", &ServerOptions{Addr: "localhost:9232"}},
|
||||
{"localhost:65536", &ServerOptions{Addr: "localhost:-1"}}, // will fail, but won't err
|
||||
{"empty", &ServerOptions{}},
|
||||
{"empty", nil},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
srv := NewHTTPServer(tt.opts, RedirectHandler())
|
||||
|
||||
defer srv.Close()
|
||||
|
||||
// we cheat a little bit here and use the httptest server to test the client
|
||||
ts := httptest.NewServer(srv.Handler)
|
||||
defer ts.Close()
|
||||
client := ts.Client()
|
||||
client.CheckRedirect = func(req *http.Request, via []*http.Request) error {
|
||||
return http.ErrUseLastResponse
|
||||
}
|
||||
res, err := client.Get(ts.URL)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
greeting, err := ioutil.ReadAll(res.Body)
|
||||
res.Body.Close()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
fmt.Printf("%s", greeting)
|
||||
|
||||
})
|
||||
}
|
||||
}
|
87
internal/httputil/options.go
Normal file
87
internal/httputil/options.go
Normal file
|
@ -0,0 +1,87 @@
|
|||
package httputil // import "github.com/pomerium/pomerium/internal/httputil"
|
||||
|
||||
import (
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/fileutil"
|
||||
)
|
||||
|
||||
// ServerOptions contains the configurations settings for a http server.
|
||||
type ServerOptions struct {
|
||||
// Addr specifies the host and port on which the server should serve
|
||||
// HTTPS requests. If empty, ":https" is used.
|
||||
Addr string
|
||||
|
||||
// TLS certificates to use.
|
||||
Cert string
|
||||
Key string
|
||||
CertFile string
|
||||
KeyFile string
|
||||
|
||||
// Timeouts
|
||||
ReadHeaderTimeout time.Duration
|
||||
ReadTimeout time.Duration
|
||||
WriteTimeout time.Duration
|
||||
IdleTimeout time.Duration
|
||||
}
|
||||
|
||||
var defaultTLSServerOptions = &ServerOptions{
|
||||
Addr: ":https",
|
||||
CertFile: filepath.Join(fileutil.Getwd(), "cert.pem"),
|
||||
KeyFile: filepath.Join(fileutil.Getwd(), "privkey.pem"),
|
||||
ReadHeaderTimeout: 10 * time.Second,
|
||||
ReadTimeout: 30 * time.Second,
|
||||
WriteTimeout: 0, // support streaming by default
|
||||
IdleTimeout: 5 * time.Minute,
|
||||
}
|
||||
|
||||
func (o *ServerOptions) applyTLSDefaults() {
|
||||
if o.Addr == "" {
|
||||
o.Addr = defaultTLSServerOptions.Addr
|
||||
}
|
||||
if o.Cert == "" && o.CertFile == "" {
|
||||
o.CertFile = defaultTLSServerOptions.CertFile
|
||||
}
|
||||
if o.Key == "" && o.KeyFile == "" {
|
||||
o.KeyFile = defaultTLSServerOptions.KeyFile
|
||||
}
|
||||
if o.ReadHeaderTimeout == 0 {
|
||||
o.ReadHeaderTimeout = defaultTLSServerOptions.ReadHeaderTimeout
|
||||
}
|
||||
if o.ReadTimeout == 0 {
|
||||
o.ReadTimeout = defaultTLSServerOptions.ReadTimeout
|
||||
}
|
||||
if o.WriteTimeout == 0 {
|
||||
o.WriteTimeout = defaultTLSServerOptions.WriteTimeout
|
||||
}
|
||||
if o.IdleTimeout == 0 {
|
||||
o.IdleTimeout = defaultTLSServerOptions.IdleTimeout
|
||||
}
|
||||
}
|
||||
|
||||
var defaultHTTPServerOptions = &ServerOptions{
|
||||
Addr: ":http",
|
||||
ReadHeaderTimeout: 10 * time.Second,
|
||||
ReadTimeout: 5 * time.Second,
|
||||
WriteTimeout: 5 * time.Second,
|
||||
IdleTimeout: 5 * time.Minute,
|
||||
}
|
||||
|
||||
func (o *ServerOptions) applyHTTPDefaults() {
|
||||
if o.Addr == "" {
|
||||
o.Addr = defaultHTTPServerOptions.Addr
|
||||
}
|
||||
if o.ReadHeaderTimeout == 0 {
|
||||
o.ReadHeaderTimeout = defaultHTTPServerOptions.ReadHeaderTimeout
|
||||
}
|
||||
if o.ReadTimeout == 0 {
|
||||
o.ReadTimeout = defaultHTTPServerOptions.ReadTimeout
|
||||
}
|
||||
if o.WriteTimeout == 0 {
|
||||
o.WriteTimeout = defaultHTTPServerOptions.WriteTimeout
|
||||
}
|
||||
if o.IdleTimeout == 0 {
|
||||
o.IdleTimeout = defaultHTTPServerOptions.IdleTimeout
|
||||
}
|
||||
}
|
10
internal/httputil/test_data/cert.pem
Normal file
10
internal/httputil/test_data/cert.pem
Normal file
|
@ -0,0 +1,10 @@
|
|||
-----BEGIN CERTIFICATE-----
|
||||
MIIBeDCCAR+gAwIBAgIUUGE8w2S7XzpkVLbNq5QUxyVOwqEwCgYIKoZIzj0EAwIw
|
||||
ETEPMA0GA1UEAwwGdW51c2VkMCAXDTE5MDcxNTIzNDQyOVoYDzQ3NTcwNjExMjM0
|
||||
NDI5WjARMQ8wDQYDVQQDDAZ1bnVzZWQwWTATBgcqhkjOPQIBBggqhkjOPQMBBwNC
|
||||
AAQW6Z1KsR712c8RRTcu7ILyXowzo9582ClKxEvgasPbZchMyOoMoWuOolN/QWjV
|
||||
labi/4R2zqzzyuwvMQL5wotFo1MwUTAdBgNVHQ4EFgQURYdcaniRqBHXeaM79LtV
|
||||
pyJ4EwAwHwYDVR0jBBgwFoAURYdcaniRqBHXeaM79LtVpyJ4EwAwDwYDVR0TAQH/
|
||||
BAUwAwEB/zAKBggqhkjOPQQDAgNHADBEAiBHbhVnGbwXqaMZ1dB8eBAK56jyeWDZ
|
||||
2PWXmFMTu7+RywIgaZ7UwVNB2k7KjEEBiLm0PIRcpJmczI2cP9+ZMIkPHHw=
|
||||
-----END CERTIFICATE-----
|
5
internal/httputil/test_data/privkey.pem
Normal file
5
internal/httputil/test_data/privkey.pem
Normal file
|
@ -0,0 +1,5 @@
|
|||
-----BEGIN EC PRIVATE KEY-----
|
||||
MHcCAQEEIMQiDy26/R4ca/OdnjIf8OEDeHcw8yB5SDV9FD500CW5oAoGCCqGSM49
|
||||
AwEHoUQDQgAEFumdSrEe9dnPEUU3LuyC8l6MM6PefNgpSsRL4GrD22XITMjqDKFr
|
||||
jqJTf0Fo1ZWm4v+Eds6s88rsLzEC+cKLRQ==
|
||||
-----END EC PRIVATE KEY-----
|
|
@ -7,83 +7,20 @@ import (
|
|||
stdlog "log"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/fileutil"
|
||||
"github.com/pomerium/pomerium/internal/log"
|
||||
)
|
||||
|
||||
// Options contains the configurations settings for a TLS http server.
|
||||
type Options struct {
|
||||
// Addr specifies the host and port on which the server should serve
|
||||
// HTTPS requests. If empty, ":https" is used.
|
||||
Addr string
|
||||
|
||||
// TLS certificates to use.
|
||||
Cert string
|
||||
Key string
|
||||
CertFile string
|
||||
KeyFile string
|
||||
|
||||
// Timeouts
|
||||
ReadHeaderTimeout time.Duration
|
||||
ReadTimeout time.Duration
|
||||
WriteTimeout time.Duration
|
||||
IdleTimeout time.Duration
|
||||
}
|
||||
|
||||
var defaultOptions = &Options{
|
||||
Addr: ":https",
|
||||
CertFile: filepath.Join(findKeyDir(), "cert.pem"),
|
||||
KeyFile: filepath.Join(findKeyDir(), "privkey.pem"),
|
||||
ReadHeaderTimeout: 10 * time.Second,
|
||||
ReadTimeout: 30 * time.Second,
|
||||
WriteTimeout: 0, // support streaming by default
|
||||
IdleTimeout: 5 * time.Minute,
|
||||
}
|
||||
|
||||
func findKeyDir() string {
|
||||
p, err := os.Getwd()
|
||||
if err != nil {
|
||||
return "."
|
||||
}
|
||||
return p
|
||||
}
|
||||
|
||||
func (o *Options) applyDefaults() {
|
||||
if o.Addr == "" {
|
||||
o.Addr = defaultOptions.Addr
|
||||
}
|
||||
if o.Cert == "" && o.CertFile == "" {
|
||||
o.CertFile = defaultOptions.CertFile
|
||||
}
|
||||
if o.Key == "" && o.KeyFile == "" {
|
||||
o.KeyFile = defaultOptions.KeyFile
|
||||
}
|
||||
if o.ReadHeaderTimeout == 0 {
|
||||
o.ReadHeaderTimeout = defaultOptions.ReadHeaderTimeout
|
||||
}
|
||||
if o.ReadTimeout == 0 {
|
||||
o.ReadTimeout = defaultOptions.ReadTimeout
|
||||
}
|
||||
if o.WriteTimeout == 0 {
|
||||
o.WriteTimeout = defaultOptions.WriteTimeout
|
||||
}
|
||||
if o.IdleTimeout == 0 {
|
||||
o.IdleTimeout = defaultOptions.IdleTimeout
|
||||
}
|
||||
}
|
||||
|
||||
// ListenAndServeTLS serves the provided handlers by HTTPS
|
||||
// using the provided options.
|
||||
func ListenAndServeTLS(opt *Options, httpHandler http.Handler, grpcHandler http.Handler) error {
|
||||
// NewTLSServer creates a new TLS server given a set of options, handlers, and
|
||||
// optionally a set of gRPC endpoints as well.
|
||||
// It is the callers responsibility to close the resturned server.
|
||||
func NewTLSServer(opt *ServerOptions, httpHandler http.Handler, grpcHandler http.Handler) (*http.Server, error) {
|
||||
if opt == nil {
|
||||
opt = defaultOptions
|
||||
opt = defaultTLSServerOptions
|
||||
} else {
|
||||
opt.applyDefaults()
|
||||
opt.applyTLSDefaults()
|
||||
}
|
||||
var cert *tls.Certificate
|
||||
var err error
|
||||
|
@ -93,12 +30,12 @@ func ListenAndServeTLS(opt *Options, httpHandler http.Handler, grpcHandler http.
|
|||
cert, err = readCertificateFile(opt.CertFile, opt.KeyFile)
|
||||
}
|
||||
if err != nil {
|
||||
return fmt.Errorf("https: failed loading x509 certificate: %v", err)
|
||||
return nil, fmt.Errorf("internal/httputil: failed loading x509 certificate: %v", err)
|
||||
}
|
||||
config := newDefaultTLSConfig(cert)
|
||||
ln, err := net.Listen("tcp", opt.Addr)
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ln = tls.NewListener(ln, config)
|
||||
|
@ -112,7 +49,7 @@ func ListenAndServeTLS(opt *Options, httpHandler http.Handler, grpcHandler http.
|
|||
sublogger := log.With().Str("addr", opt.Addr).Logger()
|
||||
|
||||
// Set up the main server.
|
||||
server := &http.Server{
|
||||
srv := &http.Server{
|
||||
ReadHeaderTimeout: opt.ReadHeaderTimeout,
|
||||
ReadTimeout: opt.ReadTimeout,
|
||||
WriteTimeout: opt.WriteTimeout,
|
||||
|
@ -121,8 +58,13 @@ func ListenAndServeTLS(opt *Options, httpHandler http.Handler, grpcHandler http.
|
|||
Handler: h,
|
||||
ErrorLog: stdlog.New(&log.StdLogWrapper{Logger: &sublogger}, "", 0),
|
||||
}
|
||||
go func() {
|
||||
if err := srv.Serve(ln); err != http.ErrServerClosed {
|
||||
log.Error().Err(err).Msg("internal/httputil: tls server crashed")
|
||||
}
|
||||
}()
|
||||
|
||||
return server.Serve(ln)
|
||||
return srv, nil
|
||||
}
|
||||
|
||||
func decodeCertificate(cert, key string) (*tls.Certificate, error) {
|
||||
|
@ -189,8 +131,8 @@ func newDefaultTLSConfig(cert *tls.Certificate) *tls.Config {
|
|||
return tlsConfig
|
||||
}
|
||||
|
||||
// grpcHandlerFunc splits request serving between gRPC and HTTPS depending on the request type.
|
||||
// Requires HTTP/2.
|
||||
// grpcHandlerFunc splits request serving between gRPC and HTTPS depending on
|
||||
// the request type. Requires HTTP/2 to be enabled.
|
||||
func grpcHandlerFunc(rpcServer http.Handler, other http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
ct := r.Header.Get("Content-Type")
|
210
internal/httputil/tls_test.go
Normal file
210
internal/httputil/tls_test.go
Normal file
|
@ -0,0 +1,210 @@
|
|||
package httputil
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
const privKey = `-----BEGIN EC PRIVATE KEY-----
|
||||
MHcCAQEEIMQiDy26/R4ca/OdnjIf8OEDeHcw8yB5SDV9FD500CW5oAoGCCqGSM49
|
||||
AwEHoUQDQgAEFumdSrEe9dnPEUU3LuyC8l6MM6PefNgpSsRL4GrD22XITMjqDKFr
|
||||
jqJTf0Fo1ZWm4v+Eds6s88rsLzEC+cKLRQ==
|
||||
-----END EC PRIVATE KEY-----`
|
||||
const pubKey = `-----BEGIN CERTIFICATE-----
|
||||
MIIBeDCCAR+gAwIBAgIUUGE8w2S7XzpkVLbNq5QUxyVOwqEwCgYIKoZIzj0EAwIw
|
||||
ETEPMA0GA1UEAwwGdW51c2VkMCAXDTE5MDcxNTIzNDQyOVoYDzQ3NTcwNjExMjM0
|
||||
NDI5WjARMQ8wDQYDVQQDDAZ1bnVzZWQwWTATBgcqhkjOPQIBBggqhkjOPQMBBwNC
|
||||
AAQW6Z1KsR712c8RRTcu7ILyXowzo9582ClKxEvgasPbZchMyOoMoWuOolN/QWjV
|
||||
labi/4R2zqzzyuwvMQL5wotFo1MwUTAdBgNVHQ4EFgQURYdcaniRqBHXeaM79LtV
|
||||
pyJ4EwAwHwYDVR0jBBgwFoAURYdcaniRqBHXeaM79LtVpyJ4EwAwDwYDVR0TAQH/
|
||||
BAUwAwEB/zAKBggqhkjOPQQDAgNHADBEAiBHbhVnGbwXqaMZ1dB8eBAK56jyeWDZ
|
||||
2PWXmFMTu7+RywIgaZ7UwVNB2k7KjEEBiLm0PIRcpJmczI2cP9+ZMIkPHHw=
|
||||
-----END CERTIFICATE-----`
|
||||
|
||||
func TestNewTLSServer(t *testing.T) {
|
||||
t.Parallel()
|
||||
tests := []struct {
|
||||
name string
|
||||
opt *ServerOptions
|
||||
httpHandler http.Handler
|
||||
grpcHandler http.Handler
|
||||
// want *http.Server
|
||||
wantErr bool
|
||||
}{
|
||||
{"good basic http handler",
|
||||
&ServerOptions{
|
||||
Addr: "127.0.0.1:0",
|
||||
Cert: base64.StdEncoding.EncodeToString([]byte(pubKey)),
|
||||
Key: base64.StdEncoding.EncodeToString([]byte(privKey)),
|
||||
},
|
||||
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
fmt.Fprintln(w, "Hello, http")
|
||||
}),
|
||||
nil,
|
||||
false},
|
||||
{"good basic http and grpc handler",
|
||||
&ServerOptions{
|
||||
Addr: "127.0.0.1:0",
|
||||
Cert: base64.StdEncoding.EncodeToString([]byte(pubKey)),
|
||||
Key: base64.StdEncoding.EncodeToString([]byte(privKey)),
|
||||
},
|
||||
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
fmt.Fprintln(w, "Hello, http")
|
||||
}),
|
||||
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
fmt.Fprintln(w, "Hello, grpc")
|
||||
}),
|
||||
false},
|
||||
{"good with cert files",
|
||||
&ServerOptions{
|
||||
Addr: "127.0.0.1:0",
|
||||
CertFile: "test_data/cert.pem",
|
||||
KeyFile: "test_data/privkey.pem",
|
||||
},
|
||||
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
fmt.Fprintln(w, "Hello, http")
|
||||
}),
|
||||
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
fmt.Fprintln(w, "Hello, grpc")
|
||||
}),
|
||||
false},
|
||||
{"unreadable cert file",
|
||||
&ServerOptions{
|
||||
Addr: "127.0.0.1:0",
|
||||
CertFile: "test_data",
|
||||
KeyFile: "test_data/privkey.pem",
|
||||
},
|
||||
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
fmt.Fprintln(w, "Hello, http")
|
||||
}),
|
||||
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
fmt.Fprintln(w, "Hello, grpc")
|
||||
}),
|
||||
true},
|
||||
{"unreadable key file",
|
||||
&ServerOptions{
|
||||
Addr: "127.0.0.1:0",
|
||||
CertFile: "./test_data/cert.pem",
|
||||
KeyFile: "./test_data",
|
||||
},
|
||||
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
fmt.Fprintln(w, "Hello, http")
|
||||
}),
|
||||
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
fmt.Fprintln(w, "Hello, grpc")
|
||||
}),
|
||||
true},
|
||||
{"unreadable key file",
|
||||
&ServerOptions{
|
||||
Addr: "127.0.0.1:0",
|
||||
CertFile: "./test_data/cert.pem",
|
||||
KeyFile: "./test_data/file-does-not-exist",
|
||||
},
|
||||
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
fmt.Fprintln(w, "Hello, http")
|
||||
}),
|
||||
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
fmt.Fprintln(w, "Hello, grpc")
|
||||
}),
|
||||
true},
|
||||
{"bad private key base64",
|
||||
&ServerOptions{
|
||||
Addr: "127.0.0.1:0",
|
||||
Cert: base64.StdEncoding.EncodeToString([]byte(pubKey)),
|
||||
Key: "bad guy",
|
||||
}, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
fmt.Fprintln(w, "Hello, http")
|
||||
}),
|
||||
nil,
|
||||
true},
|
||||
{"bad public key base64",
|
||||
&ServerOptions{
|
||||
Addr: "127.0.0.1:9999",
|
||||
Key: base64.StdEncoding.EncodeToString([]byte(pubKey)),
|
||||
Cert: "bad guy",
|
||||
}, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
fmt.Fprintln(w, "Hello, http")
|
||||
}),
|
||||
nil,
|
||||
true},
|
||||
{"bad port - invalid port range ",
|
||||
&ServerOptions{
|
||||
Addr: "127.0.0.1:65536",
|
||||
Cert: base64.StdEncoding.EncodeToString([]byte(pubKey)),
|
||||
Key: base64.StdEncoding.EncodeToString([]byte(privKey)),
|
||||
}, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
fmt.Fprintln(w, "Hello, http")
|
||||
}),
|
||||
nil,
|
||||
true},
|
||||
{"nil apply default but will fail",
|
||||
nil,
|
||||
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
fmt.Fprintln(w, "Hello, http")
|
||||
}),
|
||||
nil,
|
||||
true},
|
||||
{"empty, apply defaults to missing",
|
||||
&ServerOptions{},
|
||||
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
fmt.Fprintln(w, "Hello, http")
|
||||
}),
|
||||
nil,
|
||||
true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
srv, err := NewTLSServer(tt.opt, tt.httpHandler, tt.grpcHandler)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("NewTLSServer() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if err == nil {
|
||||
// we cheat a little bit here and use the httptest server to test the client
|
||||
ts := httptest.NewTLSServer(srv.Handler)
|
||||
defer ts.Close()
|
||||
client := ts.Client()
|
||||
res, err := client.Get(ts.URL)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
greeting, err := ioutil.ReadAll(res.Body)
|
||||
res.Body.Close()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
fmt.Printf("%s", greeting)
|
||||
}
|
||||
if srv != nil {
|
||||
// simulate a sigterm and cleanup the server
|
||||
c := make(chan os.Signal, 1)
|
||||
signal.Notify(c, syscall.SIGINT)
|
||||
defer signal.Stop(c)
|
||||
go Shutdown(srv)
|
||||
syscall.Kill(syscall.Getpid(), syscall.SIGINT)
|
||||
waitSig(t, c, syscall.SIGINT)
|
||||
}
|
||||
|
||||
})
|
||||
}
|
||||
}
|
||||
func waitSig(t *testing.T, c <-chan os.Signal, sig os.Signal) {
|
||||
select {
|
||||
case s := <-c:
|
||||
if s != sig {
|
||||
t.Fatalf("signal was %v, want %v", s, sig)
|
||||
}
|
||||
case <-time.After(1 * time.Second):
|
||||
t.Fatalf("timeout waiting for %v", sig)
|
||||
}
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue