telemetry: add tracing

- telemetry/tace: add traces throughout code
- telemetry/metrics: nest metrics and trace under telemetry
- telemetry/tace: add service name span to HTTPMetricsHandler.
- telemetry/metrics: removed chain dependency middleware_tests.
- telemetry/metrics: wrap and encapsulate variatic view registration.
- telemetry/tace: add jaeger support for tracing.
- cmd/pomerium: move `parseOptions` to internal/config.
- cmd/pomerium: offload server handling to httputil and sub pkgs.
- httputil: standardize creation/shutdown of http listeners.
- httputil: prefer curve X25519 to P256 when negotiating TLS.
- fileutil: use standardized Getw

Signed-off-by: Bobby DeSimone <bobbydesimone@gmail.com>
This commit is contained in:
Bobby DeSimone 2019-07-24 09:20:16 -07:00
parent 6b61a48fce
commit 5edfa7b03f
No known key found for this signature in database
GPG key ID: AEE4CF12FE86D07E
49 changed files with 1524 additions and 758 deletions

View file

@ -1,16 +1,5 @@
package config // import "github.com/pomerium/pomerium/internal/config"
import "os"
// findPwd returns best guess at current working directory
func findPwd() string {
p, err := os.Getwd()
if err != nil {
return "."
}
return p
}
// IsValidService checks to see if a service is a valid service mode
func IsValidService(s string) bool {
switch s {

View file

@ -7,11 +7,14 @@ import (
"net/url"
"path/filepath"
"reflect"
"strconv"
"strings"
"time"
"github.com/pomerium/pomerium/internal/cryptutil"
"github.com/pomerium/pomerium/internal/fileutil"
"github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/telemetry/metrics"
"github.com/pomerium/pomerium/internal/urlutil"
"github.com/mitchellh/hashstructure"
@ -129,6 +132,19 @@ type Options struct {
// Address/Port to bind to for prometheus metrics
MetricsAddr string `mapstructure:"metrics_address"`
// Tracing shared settings
TracingProvider string `mapstructure:"tracing_provider"`
TracingDebug bool `mapstructure:"tracing_debug"`
// Jaeger
// CollectorEndpoint is the full url to the Jaeger HTTP Thrift collector.
// For example, http://localhost:14268/api/traces
TracingJaegerCollectorEndpoint string `mapstructure:"tracing_jaeger_collector_endpoint"`
// AgentEndpoint instructs exporter to send spans to jaeger-agent at this address.
// For example, localhost:6831.
TracingJaegerAgentEndpoint string `mapstructure:"tracing_jaeger_agent_endpoint"`
}
var defaultOptions = Options{
@ -148,8 +164,8 @@ var defaultOptions = Options{
"Strict-Transport-Security": "max-age=31536000; includeSubDomains; preload",
},
Addr: ":https",
CertFile: filepath.Join(findPwd(), "cert.pem"),
KeyFile: filepath.Join(findPwd(), "privkey.pem"),
CertFile: filepath.Join(fileutil.Getwd(), "cert.pem"),
KeyFile: filepath.Join(fileutil.Getwd(), "privkey.pem"),
ReadHeaderTimeout: 10 * time.Second,
ReadTimeout: 30 * time.Second,
WriteTimeout: 0, // support streaming by default
@ -339,3 +355,56 @@ func (o *Options) Checksum() string {
}
return fmt.Sprintf("%x", hash)
}
func ParseOptions(configFile string) (*Options, error) {
o, err := OptionsFromViper(configFile)
if err != nil {
return nil, err
}
if o.Debug {
log.SetDebugMode()
}
if o.LogLevel != "" {
log.SetLevel(o.LogLevel)
}
metrics.AddPolicyCountCallback(o.Services, func() int64 {
return int64(len(o.Policies))
})
checksumDec, err := strconv.ParseUint(o.Checksum(), 16, 64)
if err != nil {
log.Warn().Err(err).Msg("Could not parse config checksum into decimal")
}
metrics.SetConfigChecksum(o.Services, checksumDec)
return o, nil
}
func HandleConfigUpdate(configFile string, opt *Options, services []OptionsUpdater) *Options {
newOpt, err := ParseOptions(configFile)
if err != nil {
log.Error().Err(err).Msg("cmd/pomerium: could not reload configuration")
return opt
}
optChecksum := opt.Checksum()
newOptChecksum := newOpt.Checksum()
log.Debug().
Str("old-checksum", optChecksum).
Str("new-checksum", newOptChecksum).
Msg("cmd/pomerium: configuration file changed")
if newOptChecksum == optChecksum {
log.Debug().Msg("cmd/pomerium: loaded configuration has not changed")
return opt
}
log.Info().Str("checksum", newOptChecksum).Msg("cmd/pomerium: checksum changed")
for _, service := range services {
if err := service.UpdateOptions(*newOpt); err != nil {
log.Error().Err(err).Msg("cmd/pomerium: could not update options")
}
}
return newOpt
}

View file

@ -408,3 +408,99 @@ func TestOptionsFromViper(t *testing.T) {
})
}
}
func Test_parseOptions(t *testing.T) {
viper.Reset()
tests := []struct {
name string
envKey string
envValue string
servicesEnvKey string
servicesEnvValue string
wantSharedKey string
wantErr bool
}{
{"no shared secret", "", "", "SERVICES", "authenticate", "skip", true},
{"no shared secret in all mode", "", "", "", "", "", false},
{"good", "SHARED_SECRET", "YixWi1MYh77NMECGGIJQevoonYtVF+ZPRkQZrrmeRqM=", "", "", "YixWi1MYh77NMECGGIJQevoonYtVF+ZPRkQZrrmeRqM=", false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
os.Setenv(tt.servicesEnvKey, tt.servicesEnvValue)
os.Setenv(tt.envKey, tt.envValue)
defer os.Unsetenv(tt.envKey)
defer os.Unsetenv(tt.servicesEnvKey)
got, err := ParseOptions("")
if (err != nil) != tt.wantErr {
t.Errorf("ParseOptions() error = %v, wantErr %v", err, tt.wantErr)
return
}
if got != nil && got.Services != "all" && got.SharedKey != tt.wantSharedKey {
t.Errorf("ParseOptions()\n")
t.Errorf("got: %+v\n", got.SharedKey)
t.Errorf("want: %+v\n", tt.wantSharedKey)
}
})
}
}
type mockService struct {
fail bool
Updated bool
}
func (m *mockService) UpdateOptions(o Options) error {
m.Updated = true
if m.fail {
return fmt.Errorf("failed")
}
return nil
}
func Test_HandleConfigUpdate(t *testing.T) {
os.Clearenv()
os.Setenv("SHARED_SECRET", "foo")
defer os.Unsetenv("SHARED_SECRET")
blankOpts, err := NewOptions("https://authenticate.example", "https://authorize.example")
if err != nil {
t.Fatal(err)
}
goodOpts, err := OptionsFromViper("")
if err != nil {
t.Fatal(err)
}
tests := []struct {
name string
envarKey string
envarValue string
service *mockService
oldOpts Options
wantUpdate bool
}{
{"good", "", "", &mockService{fail: false}, *blankOpts, true},
{"good set debug", "POMERIUM_DEBUG", "true", &mockService{fail: false}, *blankOpts, true},
{"bad", "", "", &mockService{fail: true}, *blankOpts, true},
{"no change", "", "", &mockService{fail: false}, *goodOpts, false},
{"bad policy file unmarshal error", "POLICY", base64.StdEncoding.EncodeToString([]byte("{json:}")), &mockService{fail: false}, *blankOpts, false},
{"bad header key", "SERVICES", "error", &mockService{fail: false}, *blankOpts, false},
{"bad header header value", "HEADERS", "x;y;z", &mockService{fail: false}, *blankOpts, false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
os.Setenv(tt.envarKey, tt.envarValue)
defer os.Unsetenv(tt.envarKey)
HandleConfigUpdate("", &tt.oldOpts, []OptionsUpdater{tt.service})
if tt.service.Updated != tt.wantUpdate {
t.Errorf("Failed to update config on service")
}
})
}
}

View file

@ -30,3 +30,17 @@ func IsReadableFile(path string) (bool, error) {
fd.Close()
return true, nil // Item exists and is readable.
}
// Getwd returns a rooted path name corresponding to the
// current directory. If the current directory can be
// reached via multiple paths (due to symbolic links),
// Getwd may return any one of them.
//
// On failure, will return "."
func Getwd() string {
p, err := os.Getwd()
if err != nil {
return "."
}
return p
}

View file

@ -1,6 +1,9 @@
package fileutil // import "github.com/pomerium/pomerium/internal/fileutil"
package fileutil
import "testing"
import (
"strings"
"testing"
)
func TestIsReadableFile(t *testing.T) {
@ -27,3 +30,19 @@ func TestIsReadableFile(t *testing.T) {
})
}
}
func TestGetwd(t *testing.T) {
tests := []struct {
name string
want string
}{
{"most basic example", "internal/fileutil"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := Getwd(); strings.Contains(tt.want, got) {
t.Errorf("Getwd() = %v, want %v", got, tt.want)
}
})
}
}

View file

@ -23,21 +23,11 @@ func (h Error) Error() string {
return fmt.Sprintf("%d %s: %s", h.Code, http.StatusText(h.Code), h.Message)
}
// CodeForError maps an error type and returns a corresponding http.Status
func CodeForError(err error) int {
switch err {
case ErrTokenRevoked:
return http.StatusUnauthorized
}
return http.StatusInternalServerError
}
// ErrorResponse renders an error page for errors given a message and a status code.
// If no message is passed, defaults to the text of the status code.
func ErrorResponse(rw http.ResponseWriter, r *http.Request, e *Error) {
requestID := ""
id, ok := log.IDFromRequest(r)
if ok {
var requestID string
if id, ok := log.IDFromRequest(r); ok {
requestID = id
}
if r.Header.Get("Accept") == "application/json" {

View file

@ -0,0 +1,49 @@
package httputil
import (
"net/http"
"net/http/httptest"
"testing"
)
func TestErrorResponse(t *testing.T) {
tests := []struct {
name string
rw http.ResponseWriter
r *http.Request
e *Error
}{
{"good", httptest.NewRecorder(), &http.Request{Method: http.MethodGet}, &Error{Code: http.StatusBadRequest, Message: "missing id token"}},
{"good json", httptest.NewRecorder(), &http.Request{Method: http.MethodGet, Header: http.Header{"Accept": []string{"application/json"}}}, &Error{Code: http.StatusBadRequest, Message: "missing id token"}},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ErrorResponse(tt.rw, tt.r, tt.e)
})
}
}
func TestError_Error(t *testing.T) {
tests := []struct {
name string
Message string
Code int
CanDebug bool
want string
}{
{"good", "short and stout", http.StatusTeapot, false, "418 I'm a teapot: short and stout"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
h := Error{
Message: tt.Message,
Code: tt.Code,
CanDebug: tt.CanDebug,
}
if got := h.Error(); got != tt.want {
t.Errorf("Error.Error() = %v, want %v", got, tt.want)
}
})
}
}

76
internal/httputil/http.go Normal file
View file

@ -0,0 +1,76 @@
package httputil // import "github.com/pomerium/pomerium/internal/httputil"
import (
"context"
"fmt"
stdlog "log"
"net/http"
"os"
"os/signal"
"syscall"
"time"
"github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/urlutil"
)
// NewHTTPServer starts a http server given a set of options and a handler.
//
// It is the caller's responsibility to Close() or Shutdown() the returned
// server.
func NewHTTPServer(opt *ServerOptions, h http.Handler) *http.Server {
if opt == nil {
opt = defaultHTTPServerOptions
} else {
opt.applyHTTPDefaults()
}
sublogger := log.With().Str("addr", opt.Addr).Logger()
srv := http.Server{
Addr: opt.Addr,
ReadHeaderTimeout: opt.ReadHeaderTimeout,
ReadTimeout: opt.ReadTimeout,
WriteTimeout: opt.WriteTimeout,
IdleTimeout: opt.IdleTimeout,
Handler: h,
ErrorLog: stdlog.New(&log.StdLogWrapper{Logger: &sublogger}, "", 0),
}
go func() {
if err := srv.ListenAndServe(); err != http.ErrServerClosed {
log.Error().Str("addr", opt.Addr).Err(err).Msg("internal/httputil: unexpected shutdown")
}
}()
return &srv
}
func RedirectHandler() http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Connection", "close")
url := fmt.Sprintf("https://%s%s", urlutil.StripPort(r.Host), r.URL.String())
http.Redirect(w, r, url, http.StatusMovedPermanently)
})
}
// Shutdown attempts to shut down the server when a os interrupt or sigterm
// signal are received without interrupting any
// active connections. Shutdown works by first closing all open
// listeners, then closing all idle connections, and then waiting
// indefinitely for connections to return to idle and then shut down.
// If the provided context expires before the shutdown is complete,
// Shutdown returns the context's error, otherwise it returns any
// error returned from closing the Server's underlying Listener(s).
//
// When Shutdown is called, Serve, ListenAndServe, and
// ListenAndServeTLS immediately return ErrServerClosed.
func Shutdown(srv *http.Server) {
sigint := make(chan os.Signal, 1)
signal.Notify(sigint, os.Interrupt)
signal.Notify(sigint, syscall.SIGTERM)
rec := <-sigint
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
log.Info().Str("signal", rec.String()).Msg("internal/httputil: shutting down servers")
if err := srv.Shutdown(ctx); err != nil {
log.Error().Err(err).Msg("internal/httputil: shutdown failed")
}
}

View file

@ -0,0 +1,49 @@
package httputil
import (
"fmt"
"io/ioutil"
"log"
"net/http"
"net/http/httptest"
"testing"
)
func TestNewHTTPServer(t *testing.T) {
tests := []struct {
name string
opts *ServerOptions
// wantErr bool
}{
{"localhost:9232", &ServerOptions{Addr: "localhost:9232"}},
{"localhost:65536", &ServerOptions{Addr: "localhost:-1"}}, // will fail, but won't err
{"empty", &ServerOptions{}},
{"empty", nil},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
srv := NewHTTPServer(tt.opts, RedirectHandler())
defer srv.Close()
// we cheat a little bit here and use the httptest server to test the client
ts := httptest.NewServer(srv.Handler)
defer ts.Close()
client := ts.Client()
client.CheckRedirect = func(req *http.Request, via []*http.Request) error {
return http.ErrUseLastResponse
}
res, err := client.Get(ts.URL)
if err != nil {
log.Fatal(err)
}
greeting, err := ioutil.ReadAll(res.Body)
res.Body.Close()
if err != nil {
log.Fatal(err)
}
fmt.Printf("%s", greeting)
})
}
}

View file

@ -0,0 +1,87 @@
package httputil // import "github.com/pomerium/pomerium/internal/httputil"
import (
"path/filepath"
"time"
"github.com/pomerium/pomerium/internal/fileutil"
)
// ServerOptions contains the configurations settings for a http server.
type ServerOptions struct {
// Addr specifies the host and port on which the server should serve
// HTTPS requests. If empty, ":https" is used.
Addr string
// TLS certificates to use.
Cert string
Key string
CertFile string
KeyFile string
// Timeouts
ReadHeaderTimeout time.Duration
ReadTimeout time.Duration
WriteTimeout time.Duration
IdleTimeout time.Duration
}
var defaultTLSServerOptions = &ServerOptions{
Addr: ":https",
CertFile: filepath.Join(fileutil.Getwd(), "cert.pem"),
KeyFile: filepath.Join(fileutil.Getwd(), "privkey.pem"),
ReadHeaderTimeout: 10 * time.Second,
ReadTimeout: 30 * time.Second,
WriteTimeout: 0, // support streaming by default
IdleTimeout: 5 * time.Minute,
}
func (o *ServerOptions) applyTLSDefaults() {
if o.Addr == "" {
o.Addr = defaultTLSServerOptions.Addr
}
if o.Cert == "" && o.CertFile == "" {
o.CertFile = defaultTLSServerOptions.CertFile
}
if o.Key == "" && o.KeyFile == "" {
o.KeyFile = defaultTLSServerOptions.KeyFile
}
if o.ReadHeaderTimeout == 0 {
o.ReadHeaderTimeout = defaultTLSServerOptions.ReadHeaderTimeout
}
if o.ReadTimeout == 0 {
o.ReadTimeout = defaultTLSServerOptions.ReadTimeout
}
if o.WriteTimeout == 0 {
o.WriteTimeout = defaultTLSServerOptions.WriteTimeout
}
if o.IdleTimeout == 0 {
o.IdleTimeout = defaultTLSServerOptions.IdleTimeout
}
}
var defaultHTTPServerOptions = &ServerOptions{
Addr: ":http",
ReadHeaderTimeout: 10 * time.Second,
ReadTimeout: 5 * time.Second,
WriteTimeout: 5 * time.Second,
IdleTimeout: 5 * time.Minute,
}
func (o *ServerOptions) applyHTTPDefaults() {
if o.Addr == "" {
o.Addr = defaultHTTPServerOptions.Addr
}
if o.ReadHeaderTimeout == 0 {
o.ReadHeaderTimeout = defaultHTTPServerOptions.ReadHeaderTimeout
}
if o.ReadTimeout == 0 {
o.ReadTimeout = defaultHTTPServerOptions.ReadTimeout
}
if o.WriteTimeout == 0 {
o.WriteTimeout = defaultHTTPServerOptions.WriteTimeout
}
if o.IdleTimeout == 0 {
o.IdleTimeout = defaultHTTPServerOptions.IdleTimeout
}
}

View file

@ -0,0 +1,10 @@
-----BEGIN CERTIFICATE-----
MIIBeDCCAR+gAwIBAgIUUGE8w2S7XzpkVLbNq5QUxyVOwqEwCgYIKoZIzj0EAwIw
ETEPMA0GA1UEAwwGdW51c2VkMCAXDTE5MDcxNTIzNDQyOVoYDzQ3NTcwNjExMjM0
NDI5WjARMQ8wDQYDVQQDDAZ1bnVzZWQwWTATBgcqhkjOPQIBBggqhkjOPQMBBwNC
AAQW6Z1KsR712c8RRTcu7ILyXowzo9582ClKxEvgasPbZchMyOoMoWuOolN/QWjV
labi/4R2zqzzyuwvMQL5wotFo1MwUTAdBgNVHQ4EFgQURYdcaniRqBHXeaM79LtV
pyJ4EwAwHwYDVR0jBBgwFoAURYdcaniRqBHXeaM79LtVpyJ4EwAwDwYDVR0TAQH/
BAUwAwEB/zAKBggqhkjOPQQDAgNHADBEAiBHbhVnGbwXqaMZ1dB8eBAK56jyeWDZ
2PWXmFMTu7+RywIgaZ7UwVNB2k7KjEEBiLm0PIRcpJmczI2cP9+ZMIkPHHw=
-----END CERTIFICATE-----

View file

@ -0,0 +1,5 @@
-----BEGIN EC PRIVATE KEY-----
MHcCAQEEIMQiDy26/R4ca/OdnjIf8OEDeHcw8yB5SDV9FD500CW5oAoGCCqGSM49
AwEHoUQDQgAEFumdSrEe9dnPEUU3LuyC8l6MM6PefNgpSsRL4GrD22XITMjqDKFr
jqJTf0Fo1ZWm4v+Eds6s88rsLzEC+cKLRQ==
-----END EC PRIVATE KEY-----

View file

@ -7,83 +7,20 @@ import (
stdlog "log"
"net"
"net/http"
"os"
"path/filepath"
"strings"
"time"
"github.com/pomerium/pomerium/internal/fileutil"
"github.com/pomerium/pomerium/internal/log"
)
// Options contains the configurations settings for a TLS http server.
type Options struct {
// Addr specifies the host and port on which the server should serve
// HTTPS requests. If empty, ":https" is used.
Addr string
// TLS certificates to use.
Cert string
Key string
CertFile string
KeyFile string
// Timeouts
ReadHeaderTimeout time.Duration
ReadTimeout time.Duration
WriteTimeout time.Duration
IdleTimeout time.Duration
}
var defaultOptions = &Options{
Addr: ":https",
CertFile: filepath.Join(findKeyDir(), "cert.pem"),
KeyFile: filepath.Join(findKeyDir(), "privkey.pem"),
ReadHeaderTimeout: 10 * time.Second,
ReadTimeout: 30 * time.Second,
WriteTimeout: 0, // support streaming by default
IdleTimeout: 5 * time.Minute,
}
func findKeyDir() string {
p, err := os.Getwd()
if err != nil {
return "."
}
return p
}
func (o *Options) applyDefaults() {
if o.Addr == "" {
o.Addr = defaultOptions.Addr
}
if o.Cert == "" && o.CertFile == "" {
o.CertFile = defaultOptions.CertFile
}
if o.Key == "" && o.KeyFile == "" {
o.KeyFile = defaultOptions.KeyFile
}
if o.ReadHeaderTimeout == 0 {
o.ReadHeaderTimeout = defaultOptions.ReadHeaderTimeout
}
if o.ReadTimeout == 0 {
o.ReadTimeout = defaultOptions.ReadTimeout
}
if o.WriteTimeout == 0 {
o.WriteTimeout = defaultOptions.WriteTimeout
}
if o.IdleTimeout == 0 {
o.IdleTimeout = defaultOptions.IdleTimeout
}
}
// ListenAndServeTLS serves the provided handlers by HTTPS
// using the provided options.
func ListenAndServeTLS(opt *Options, httpHandler http.Handler, grpcHandler http.Handler) error {
// NewTLSServer creates a new TLS server given a set of options, handlers, and
// optionally a set of gRPC endpoints as well.
// It is the callers responsibility to close the resturned server.
func NewTLSServer(opt *ServerOptions, httpHandler http.Handler, grpcHandler http.Handler) (*http.Server, error) {
if opt == nil {
opt = defaultOptions
opt = defaultTLSServerOptions
} else {
opt.applyDefaults()
opt.applyTLSDefaults()
}
var cert *tls.Certificate
var err error
@ -93,12 +30,12 @@ func ListenAndServeTLS(opt *Options, httpHandler http.Handler, grpcHandler http.
cert, err = readCertificateFile(opt.CertFile, opt.KeyFile)
}
if err != nil {
return fmt.Errorf("https: failed loading x509 certificate: %v", err)
return nil, fmt.Errorf("internal/httputil: failed loading x509 certificate: %v", err)
}
config := newDefaultTLSConfig(cert)
ln, err := net.Listen("tcp", opt.Addr)
if err != nil {
return err
return nil, err
}
ln = tls.NewListener(ln, config)
@ -112,7 +49,7 @@ func ListenAndServeTLS(opt *Options, httpHandler http.Handler, grpcHandler http.
sublogger := log.With().Str("addr", opt.Addr).Logger()
// Set up the main server.
server := &http.Server{
srv := &http.Server{
ReadHeaderTimeout: opt.ReadHeaderTimeout,
ReadTimeout: opt.ReadTimeout,
WriteTimeout: opt.WriteTimeout,
@ -121,8 +58,13 @@ func ListenAndServeTLS(opt *Options, httpHandler http.Handler, grpcHandler http.
Handler: h,
ErrorLog: stdlog.New(&log.StdLogWrapper{Logger: &sublogger}, "", 0),
}
go func() {
if err := srv.Serve(ln); err != http.ErrServerClosed {
log.Error().Err(err).Msg("internal/httputil: tls server crashed")
}
}()
return server.Serve(ln)
return srv, nil
}
func decodeCertificate(cert, key string) (*tls.Certificate, error) {
@ -189,8 +131,8 @@ func newDefaultTLSConfig(cert *tls.Certificate) *tls.Config {
return tlsConfig
}
// grpcHandlerFunc splits request serving between gRPC and HTTPS depending on the request type.
// Requires HTTP/2.
// grpcHandlerFunc splits request serving between gRPC and HTTPS depending on
// the request type. Requires HTTP/2 to be enabled.
func grpcHandlerFunc(rpcServer http.Handler, other http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ct := r.Header.Get("Content-Type")

View file

@ -0,0 +1,210 @@
package httputil
import (
"encoding/base64"
"fmt"
"io/ioutil"
"log"
"net/http"
"net/http/httptest"
"os"
"os/signal"
"syscall"
"testing"
"time"
)
const privKey = `-----BEGIN EC PRIVATE KEY-----
MHcCAQEEIMQiDy26/R4ca/OdnjIf8OEDeHcw8yB5SDV9FD500CW5oAoGCCqGSM49
AwEHoUQDQgAEFumdSrEe9dnPEUU3LuyC8l6MM6PefNgpSsRL4GrD22XITMjqDKFr
jqJTf0Fo1ZWm4v+Eds6s88rsLzEC+cKLRQ==
-----END EC PRIVATE KEY-----`
const pubKey = `-----BEGIN CERTIFICATE-----
MIIBeDCCAR+gAwIBAgIUUGE8w2S7XzpkVLbNq5QUxyVOwqEwCgYIKoZIzj0EAwIw
ETEPMA0GA1UEAwwGdW51c2VkMCAXDTE5MDcxNTIzNDQyOVoYDzQ3NTcwNjExMjM0
NDI5WjARMQ8wDQYDVQQDDAZ1bnVzZWQwWTATBgcqhkjOPQIBBggqhkjOPQMBBwNC
AAQW6Z1KsR712c8RRTcu7ILyXowzo9582ClKxEvgasPbZchMyOoMoWuOolN/QWjV
labi/4R2zqzzyuwvMQL5wotFo1MwUTAdBgNVHQ4EFgQURYdcaniRqBHXeaM79LtV
pyJ4EwAwHwYDVR0jBBgwFoAURYdcaniRqBHXeaM79LtVpyJ4EwAwDwYDVR0TAQH/
BAUwAwEB/zAKBggqhkjOPQQDAgNHADBEAiBHbhVnGbwXqaMZ1dB8eBAK56jyeWDZ
2PWXmFMTu7+RywIgaZ7UwVNB2k7KjEEBiLm0PIRcpJmczI2cP9+ZMIkPHHw=
-----END CERTIFICATE-----`
func TestNewTLSServer(t *testing.T) {
t.Parallel()
tests := []struct {
name string
opt *ServerOptions
httpHandler http.Handler
grpcHandler http.Handler
// want *http.Server
wantErr bool
}{
{"good basic http handler",
&ServerOptions{
Addr: "127.0.0.1:0",
Cert: base64.StdEncoding.EncodeToString([]byte(pubKey)),
Key: base64.StdEncoding.EncodeToString([]byte(privKey)),
},
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, "Hello, http")
}),
nil,
false},
{"good basic http and grpc handler",
&ServerOptions{
Addr: "127.0.0.1:0",
Cert: base64.StdEncoding.EncodeToString([]byte(pubKey)),
Key: base64.StdEncoding.EncodeToString([]byte(privKey)),
},
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, "Hello, http")
}),
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, "Hello, grpc")
}),
false},
{"good with cert files",
&ServerOptions{
Addr: "127.0.0.1:0",
CertFile: "test_data/cert.pem",
KeyFile: "test_data/privkey.pem",
},
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, "Hello, http")
}),
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, "Hello, grpc")
}),
false},
{"unreadable cert file",
&ServerOptions{
Addr: "127.0.0.1:0",
CertFile: "test_data",
KeyFile: "test_data/privkey.pem",
},
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, "Hello, http")
}),
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, "Hello, grpc")
}),
true},
{"unreadable key file",
&ServerOptions{
Addr: "127.0.0.1:0",
CertFile: "./test_data/cert.pem",
KeyFile: "./test_data",
},
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, "Hello, http")
}),
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, "Hello, grpc")
}),
true},
{"unreadable key file",
&ServerOptions{
Addr: "127.0.0.1:0",
CertFile: "./test_data/cert.pem",
KeyFile: "./test_data/file-does-not-exist",
},
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, "Hello, http")
}),
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, "Hello, grpc")
}),
true},
{"bad private key base64",
&ServerOptions{
Addr: "127.0.0.1:0",
Cert: base64.StdEncoding.EncodeToString([]byte(pubKey)),
Key: "bad guy",
}, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, "Hello, http")
}),
nil,
true},
{"bad public key base64",
&ServerOptions{
Addr: "127.0.0.1:9999",
Key: base64.StdEncoding.EncodeToString([]byte(pubKey)),
Cert: "bad guy",
}, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, "Hello, http")
}),
nil,
true},
{"bad port - invalid port range ",
&ServerOptions{
Addr: "127.0.0.1:65536",
Cert: base64.StdEncoding.EncodeToString([]byte(pubKey)),
Key: base64.StdEncoding.EncodeToString([]byte(privKey)),
}, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, "Hello, http")
}),
nil,
true},
{"nil apply default but will fail",
nil,
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, "Hello, http")
}),
nil,
true},
{"empty, apply defaults to missing",
&ServerOptions{},
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, "Hello, http")
}),
nil,
true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
srv, err := NewTLSServer(tt.opt, tt.httpHandler, tt.grpcHandler)
if (err != nil) != tt.wantErr {
t.Errorf("NewTLSServer() error = %v, wantErr %v", err, tt.wantErr)
return
}
if err == nil {
// we cheat a little bit here and use the httptest server to test the client
ts := httptest.NewTLSServer(srv.Handler)
defer ts.Close()
client := ts.Client()
res, err := client.Get(ts.URL)
if err != nil {
log.Fatal(err)
}
greeting, err := ioutil.ReadAll(res.Body)
res.Body.Close()
if err != nil {
log.Fatal(err)
}
fmt.Printf("%s", greeting)
}
if srv != nil {
// simulate a sigterm and cleanup the server
c := make(chan os.Signal, 1)
signal.Notify(c, syscall.SIGINT)
defer signal.Stop(c)
go Shutdown(srv)
syscall.Kill(syscall.Getpid(), syscall.SIGINT)
waitSig(t, c, syscall.SIGINT)
}
})
}
}
func waitSig(t *testing.T, c <-chan os.Signal, sig os.Signal) {
select {
case s := <-c:
if s != sig {
t.Fatalf("signal was %v, want %v", s, sig)
}
case <-time.After(1 * time.Second):
t.Fatalf("timeout waiting for %v", sig)
}
}

View file

@ -9,11 +9,12 @@ import (
"net/url"
"time"
oidc "github.com/pomerium/go-oidc"
"golang.org/x/oauth2"
"github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/sessions"
"github.com/pomerium/pomerium/internal/telemetry/trace"
oidc "github.com/pomerium/go-oidc"
"golang.org/x/oauth2"
)
const (
@ -117,6 +118,8 @@ func (p *Provider) GetSignInURL(state string) string {
// Validate does NOT check if revoked.
// https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation
func (p *Provider) Validate(ctx context.Context, idToken string) (bool, error) {
ctx, span := trace.StartSpan(ctx, "identity.provider.Validate")
defer span.End()
_, err := p.verifier.Verify(ctx, idToken)
if err != nil {
log.Error().Err(err).Msg("identity: failed to verify session state")

View file

@ -1,30 +0,0 @@
package metrics // import "github.com/pomerium/pomerium/internal/metrics"
import (
"net/http"
ocProm "contrib.go.opencensus.io/exporter/prometheus"
prom "github.com/prometheus/client_golang/prometheus"
"go.opencensus.io/stats/view"
)
//NewPromHTTPListener creates a prometheus exporter on ListenAddr
func NewPromHTTPListener(addr string) error {
return http.ListenAndServe(addr, newPromHTTPHandler())
}
// newPromHTTPHandler creates a new prometheus exporter handler for /metrics
func newPromHTTPHandler() http.Handler {
// TODO this is a cheap way to get thorough go process
// stats. It will not work with additional exporters.
// It should turn into an FR to the OC framework
reg := prom.DefaultRegisterer.(*prom.Registry)
pe, _ := ocProm.NewExporter(ocProm.Options{
Namespace: "pomerium",
Registry: reg,
})
view.RegisterExporter(pe)
mux := http.NewServeMux()
mux.Handle("/metrics", pe)
return mux
}

View file

@ -1,151 +0,0 @@
package metrics // import "github.com/pomerium/pomerium/internal/metrics"
import (
"net/http"
"go.opencensus.io/plugin/ochttp"
"github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/tripper"
"go.opencensus.io/stats/view"
"go.opencensus.io/tag"
)
var (
httpSizeDistribution = view.Distribution(
1, 256, 512, 1024, 2048, 8192, 16384, 32768, 65536, 131072, 262144, 524288,
1048576, 2097152, 4194304, 8388608,
)
httpLatencyDistrubtion = view.Distribution(
1, 2, 5, 7, 10, 25, 500, 750,
100, 250, 500, 750,
1000, 2500, 5000, 7500,
10000, 25000, 50000, 75000,
100000,
)
// httpClientRequestCount = stats.Int64("http_client_requests_total", "Total HTTP Client Requests", "1")
// httpClientResponseSize = stats.Int64("http_client_response_size_bytes", "HTTP Client Response Size in bytes", "bytes")
// httpClientRequestDuration = stats.Int64("http_client_request_duration_ms", "HTTP Client Request duration in ms", "ms")
// HTTPServerRequestCountView is an OpenCensus View that tracks HTTP server requests by pomerium service, host, method and status
HTTPServerRequestCountView = &view.View{
Name: "http_server_requests_total",
Measure: ochttp.ServerLatency,
Description: "Total HTTP Requests",
TagKeys: []tag.Key{keyService, keyHost, keyHTTPMethod, ochttp.StatusCode},
Aggregation: view.Count(),
}
// HTTPServerRequestDurationView is an OpenCensus view that tracks HTTP server request duration by pomerium service, host, method and status
HTTPServerRequestDurationView = &view.View{
Name: "http_server_request_duration_ms",
Measure: ochttp.ServerLatency,
Description: "HTTP Request duration in ms",
TagKeys: []tag.Key{keyService, keyHost, keyHTTPMethod, ochttp.StatusCode},
Aggregation: httpLatencyDistrubtion,
}
// HTTPServerRequestSizeView is an OpenCensus view that tracks HTTP server request size by pomerium service, host and method
HTTPServerRequestSizeView = &view.View{
Name: "http_server_request_size_bytes",
Measure: ochttp.ServerRequestBytes,
Description: "HTTP Server Request Size in bytes",
TagKeys: []tag.Key{keyService, keyHost, keyHTTPMethod},
Aggregation: httpSizeDistribution,
}
// HTTPServerResponseSizeView is an OpenCensus view that tracks HTTP server response size by pomerium service, host, method and status
HTTPServerResponseSizeView = &view.View{
Name: "http_server_response_size_bytes",
Measure: ochttp.ServerResponseBytes,
Description: "HTTP Server Response Size in bytes",
TagKeys: []tag.Key{keyService, keyHost, keyHTTPMethod, ochttp.StatusCode},
Aggregation: httpSizeDistribution,
}
// HTTPClientRequestCountView is an OpenCensus View that tracks HTTP client requests by pomerium service, destination, host, method and status
HTTPClientRequestCountView = &view.View{
Name: "http_client_requests_total",
Measure: ochttp.ClientRoundtripLatency,
Description: "Total HTTP Client Requests",
TagKeys: []tag.Key{keyService, keyHost, keyHTTPMethod, ochttp.StatusCode, keyDestination},
Aggregation: view.Count(),
}
// HTTPClientRequestDurationView is an OpenCensus view that tracks HTTP client request duration by pomerium service, destination, host, method and status
HTTPClientRequestDurationView = &view.View{
Name: "http_client_request_duration_ms",
Measure: ochttp.ClientRoundtripLatency,
Description: "HTTP Client Request duration in ms",
TagKeys: []tag.Key{keyService, keyHost, keyHTTPMethod, ochttp.StatusCode, keyDestination},
Aggregation: httpLatencyDistrubtion,
}
// HTTPClientResponseSizeView is an OpenCensus view that tracks HTTP client response size by pomerium service, destination, host, method and status
HTTPClientResponseSizeView = &view.View{
Name: "http_client_response_size_bytes",
Measure: ochttp.ClientReceivedBytes,
Description: "HTTP Client Response Size in bytes",
TagKeys: []tag.Key{keyService, keyHost, keyHTTPMethod, ochttp.StatusCode, keyDestination},
Aggregation: httpSizeDistribution,
}
// HTTPClientRequestSizeView is an OpenCensus view that tracks HTTP client request size by pomerium service, destination, host and method
HTTPClientRequestSizeView = &view.View{
Name: "http_client_response_size_bytes",
Measure: ochttp.ClientSentBytes,
Description: "HTTP Client Response Size in bytes",
TagKeys: []tag.Key{keyService, keyHost, keyHTTPMethod, keyDestination},
Aggregation: httpSizeDistribution,
}
)
// HTTPMetricsHandler creates a metrics middleware for incoming HTTP requests
func HTTPMetricsHandler(service string) func(next http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx, tagErr := tag.New(
r.Context(),
tag.Insert(keyService, service),
tag.Insert(keyHost, r.Host),
tag.Insert(keyHTTPMethod, r.Method),
)
if tagErr != nil {
log.Warn().Err(tagErr).Str("context", "HTTPMetricsHandler").Msg("internal/metrics: Failed to create metrics context tag")
next.ServeHTTP(w, r)
return
}
ocHandler := ochttp.Handler{Handler: next}
ocHandler.ServeHTTP(w, r.WithContext(ctx))
})
}
}
// HTTPMetricsRoundTripper creates a metrics tracking tripper for outbound HTTP Requests
func HTTPMetricsRoundTripper(service string, destination string) func(next http.RoundTripper) http.RoundTripper {
return func(next http.RoundTripper) http.RoundTripper {
return tripper.RoundTripperFunc(func(r *http.Request) (*http.Response, error) {
ctx, tagErr := tag.New(
r.Context(),
tag.Insert(keyService, service),
tag.Insert(keyHost, r.Host),
tag.Insert(keyHTTPMethod, r.Method),
tag.Insert(keyDestination, destination),
)
if tagErr != nil {
log.Warn().Err(tagErr).Str("context", "HTTPMetricsRoundTripper").Msg("internal/metrics: Failed to create context tag")
return next.RoundTrip(r)
}
ocTransport := ochttp.Transport{Base: next}
return ocTransport.RoundTrip(r.WithContext(ctx))
})
}
}

View file

@ -1,14 +0,0 @@
package metrics
import (
"go.opencensus.io/tag"
)
var (
keyHTTPMethod tag.Key = tag.MustNewKey("http_method")
keyService tag.Key = tag.MustNewKey("service")
keyGRPCService tag.Key = tag.MustNewKey("grpc_service")
keyGRPCMethod tag.Key = tag.MustNewKey("grpc_method")
keyHost tag.Key = tag.MustNewKey("host")
keyDestination tag.Key = tag.MustNewKey("destination")
)

View file

@ -1,32 +0,0 @@
package metrics
import (
"github.com/pomerium/pomerium/internal/log"
"go.opencensus.io/stats/view"
)
var (
// HTTPClientViews contains opencensus views for HTTP Client metrics
HTTPClientViews = []*view.View{HTTPClientRequestCountView, HTTPClientRequestDurationView, HTTPClientResponseSizeView}
// HTTPServerViews contains opencensus views for HTTP Server metrics
HTTPServerViews = []*view.View{HTTPServerRequestCountView, HTTPServerRequestDurationView, HTTPServerRequestSizeView, HTTPServerResponseSizeView}
// GRPCClientViews contains opencensus views for GRPC Client metrics
GRPCClientViews = []*view.View{GRPCClientRequestCountView, GRPCClientRequestDurationView, GRPCClientResponseSizeView, GRPCClientRequestSizeView}
// GRPCServerViews contains opencensus views for GRPC Server metrics
GRPCServerViews = []*view.View{GRPCServerRequestCountView, GRPCServerRequestDurationView, GRPCServerResponseSizeView, GRPCServerRequestSizeView}
// InfoViews contains opencensus views for Info metrics
InfoViews = []*view.View{ConfigLastReloadView, ConfigLastReloadSuccessView}
)
// RegisterView registers one of the defined metrics views. It must be called for metrics to see metrics
// in the configured exporters
func RegisterView(v []*view.View) {
if err := view.Register(v...); err != nil {
log.Warn().Str("context", "RegisterView").Err(err).Msg("internal/metrics: Could not register view")
}
}
// UnRegisterView unregisters one of the defined metrics views.
func UnRegisterView(v []*view.View) {
view.Unregister(v...)
}

View file

@ -1,25 +0,0 @@
package metrics
import (
"testing"
"go.opencensus.io/stats/view"
)
func Test_RegisterView(t *testing.T) {
RegisterView(HTTPClientViews)
for _, v := range HTTPClientViews {
if view.Find(v.Name) != v {
t.Errorf("Failed to find registered view %s", v.Name)
}
}
}
func Test_UnregisterView(t *testing.T) {
UnRegisterView(HTTPClientViews)
for _, v := range HTTPClientViews {
if view.Find(v.Name) == v {
t.Errorf("Found unregistered view %s", v.Name)
}
}
}

View file

@ -3,6 +3,8 @@ package middleware // import "github.com/pomerium/pomerium/internal/middleware"
import (
"context"
"github.com/pomerium/pomerium/internal/telemetry/trace"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/metadata"
@ -30,6 +32,9 @@ func (s SharedSecretCred) RequireTransportSecurity() bool { return false }
// handler and returns an error. Otherwise, the interceptor invokes the unary
// handler.
func (s SharedSecretCred) ValidateRequest(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
ctx, span := trace.StartSpan(ctx, "middleware.grpc.ValidateRequest")
defer span.End()
md, ok := metadata.FromIncomingContext(ctx)
if !ok {
return nil, status.Errorf(codes.InvalidArgument, "missing metadata")

View file

@ -12,6 +12,8 @@ import (
"github.com/pomerium/pomerium/internal/cryptutil"
"github.com/pomerium/pomerium/internal/httputil"
"github.com/pomerium/pomerium/internal/telemetry/trace"
"golang.org/x/net/publicsuffix"
)
@ -19,10 +21,12 @@ import (
func SetHeaders(securityHeaders map[string]string) func(next http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx, span := trace.StartSpan(r.Context(), "middleware.SetHeaders")
defer span.End()
for key, val := range securityHeaders {
w.Header().Set(key, val)
}
next.ServeHTTP(w, r)
next.ServeHTTP(w, r.WithContext(ctx))
})
}
}
@ -32,6 +36,9 @@ func SetHeaders(securityHeaders map[string]string) func(next http.Handler) http.
func ValidateClientSecret(sharedSecret string) func(next http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx, span := trace.StartSpan(r.Context(), "middleware.ValidateClientSecret")
defer span.End()
if err := r.ParseForm(); err != nil {
httpErr := &httputil.Error{Message: err.Error(), Code: http.StatusBadRequest}
httputil.ErrorResponse(w, r, httpErr)
@ -47,7 +54,7 @@ func ValidateClientSecret(sharedSecret string) func(next http.Handler) http.Hand
httputil.ErrorResponse(w, r, &httputil.Error{Code: http.StatusInternalServerError})
return
}
next.ServeHTTP(w, r)
next.ServeHTTP(w, r.WithContext(ctx))
})
}
}
@ -57,6 +64,8 @@ func ValidateClientSecret(sharedSecret string) func(next http.Handler) http.Hand
func ValidateRedirectURI(rootDomain *url.URL) func(next http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx, span := trace.StartSpan(r.Context(), "middleware.ValidateRedirectURI")
defer span.End()
err := r.ParseForm()
if err != nil {
httpErr := &httputil.Error{
@ -80,7 +89,7 @@ func ValidateRedirectURI(rootDomain *url.URL) func(next http.Handler) http.Handl
httputil.ErrorResponse(w, r, httpErr)
return
}
next.ServeHTTP(w, r)
next.ServeHTTP(w, r.WithContext(ctx))
})
}
}
@ -103,6 +112,9 @@ func SameDomain(u, j *url.URL) bool {
func ValidateSignature(sharedSecret string) func(next http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx, span := trace.StartSpan(r.Context(), "middleware.ValidateSignature")
defer span.End()
err := r.ParseForm()
if err != nil {
httpErr := &httputil.Error{Message: err.Error(), Code: http.StatusBadRequest}
@ -120,7 +132,7 @@ func ValidateSignature(sharedSecret string) func(next http.Handler) http.Handler
return
}
next.ServeHTTP(w, r)
next.ServeHTTP(w, r.WithContext(ctx))
})
}
}
@ -129,11 +141,14 @@ func ValidateSignature(sharedSecret string) func(next http.Handler) http.Handler
func ValidateHost(validHost func(host string) bool) func(next http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx, span := trace.StartSpan(r.Context(), "middleware.ValidateHost")
defer span.End()
if !validHost(r.Host) {
httputil.ErrorResponse(w, r, &httputil.Error{Code: http.StatusNotFound})
return
}
next.ServeHTTP(w, r)
next.ServeHTTP(w, r.WithContext(ctx))
})
}
}
@ -145,13 +160,16 @@ func ValidateHost(validHost func(host string) bool) func(next http.Handler) http
func Healthcheck(endpoint, msg string) func(http.Handler) http.Handler {
f := func(next http.Handler) http.Handler {
fn := func(w http.ResponseWriter, r *http.Request) {
ctx, span := trace.StartSpan(r.Context(), "middleware.Healthcheck")
defer span.End()
if r.Method == "GET" && strings.EqualFold(r.URL.Path, endpoint) {
w.Header().Set("Content-Type", "text/plain")
w.WriteHeader(http.StatusOK)
w.Write([]byte(msg))
return
}
next.ServeHTTP(w, r)
next.ServeHTTP(w, r.WithContext(ctx))
}
return http.HandlerFunc(fn)
}

View file

@ -6,11 +6,14 @@ import (
"github.com/pomerium/pomerium/internal/cryptutil"
"github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/telemetry/trace"
)
func SignRequest(signer cryptutil.JWTSigner, id, email, groups, header string) func(next http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx, span := trace.StartSpan(r.Context(), "middleware.SignRequest")
defer span.End()
jwt, err := signer.SignJWT(
r.Header.Get(id),
r.Header.Get(email),
@ -20,7 +23,7 @@ func SignRequest(signer cryptutil.JWTSigner, id, email, groups, header string) f
} else {
r.Header.Set(header, jwt)
}
next.ServeHTTP(w, r)
next.ServeHTTP(w, r.WithContext(ctx))
})
}
}
@ -29,6 +32,9 @@ func SignRequest(signer cryptutil.JWTSigner, id, email, groups, header string) f
func StripPomeriumCookie(cookieName string) func(next http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx, span := trace.StartSpan(r.Context(), "middleware.SignRequest")
defer span.End()
headers := make([]string, len(r.Cookies()))
for _, cookie := range r.Cookies() {
if cookie.Name != cookieName {
@ -36,7 +42,7 @@ func StripPomeriumCookie(cookieName string) func(next http.Handler) http.Handler
}
}
r.Header.Set("Cookie", strings.Join(headers, ";"))
next.ServeHTTP(w, r)
next.ServeHTTP(w, r.WithContext(ctx))
})
}
}

View file

@ -0,0 +1,41 @@
package metrics // import "github.com/pomerium/pomerium/internal/telemetry/metrics"
import (
"go.opencensus.io/plugin/ocgrpc"
"go.opencensus.io/stats/view"
"go.opencensus.io/tag"
)
// The following tags are applied to stats recorded by this package.
var (
TagKeyHTTPMethod tag.Key = tag.MustNewKey("http_method")
TagKeyService tag.Key = tag.MustNewKey("service")
TagKeyGRPCService tag.Key = tag.MustNewKey("grpc_service")
TagKeyGRPCMethod tag.Key = tag.MustNewKey("grpc_method")
TagKeyHost tag.Key = tag.MustNewKey("host")
TagKeyDestination tag.Key = tag.MustNewKey("destination")
)
// Default distributions used by views in this package.
var (
DefaulHTTPSizeDistribution = view.Distribution(
1, 256, 512, 1024, 2048, 8192, 16384, 32768, 65536, 131072, 262144,
524288, 1048576, 2097152, 4194304, 8388608)
DefaultHTTPLatencyDistrubtion = view.Distribution(
1, 2, 5, 7, 10, 25, 500, 750, 100, 250, 500, 750, 1000, 2500, 5000,
7500, 10000, 25000, 50000, 75000, 100000)
grpcSizeDistribution = view.Distribution(
1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024,
2048, 4096, 8192, 16384,
)
DefaultMillisecondsDistribution = ocgrpc.DefaultMillisecondsDistribution
)
// DefaultViews are a set of default views to view HTTP and GRPC metrics.
var (
DefaultViews = [][]*view.View{
GRPCServerViews,
HTTPServerViews,
GRPCClientViews,
GRPCServerViews}
)

View file

@ -1,4 +1,4 @@
package metrics // import "github.com/pomerium/pomerium/internal/metrics"
package metrics // import "github.com/pomerium/pomerium/internal/telemetry/metrics"
import (
"context"
@ -12,93 +12,98 @@ import (
grpcstats "google.golang.org/grpc/stats"
)
// GRPC Views
var (
grpcSizeDistribution = view.Distribution(
1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024,
2048, 4096, 8192, 16384,
)
grcpLatencyDistribution = view.Distribution(
1, 2, 5, 7, 10, 25, 50, 75,
100, 250, 500, 750, 1000,
)
// GRPCClientViews contains opencensus views for GRPC Client metrics.
GRPCClientViews = []*view.View{
GRPCClientRequestCountView,
GRPCClientRequestDurationView,
GRPCClientResponseSizeView,
GRPCClientRequestSizeView}
// GRPCServerViews contains opencensus views for GRPC Server metrics.
GRPCServerViews = []*view.View{
GRPCServerRequestCountView,
GRPCServerRequestDurationView,
GRPCServerResponseSizeView,
GRPCServerRequestSizeView}
// GRPCServerRequestCountView is an OpenCensus view which counts GRPC Server
// requests by pomerium service, grpc service, grpc method, and status
GRPCServerRequestCountView = &view.View{
Name: "grpc_server_requests_total",
Name: "grpc/server/requests_total",
Measure: ocgrpc.ServerLatency,
Description: "Total grpc Requests",
TagKeys: []tag.Key{keyService, keyGRPCMethod, ocgrpc.KeyServerStatus, keyGRPCService},
TagKeys: []tag.Key{TagKeyService, TagKeyGRPCMethod, ocgrpc.KeyServerStatus, TagKeyGRPCService},
Aggregation: view.Count(),
}
// GRPCServerRequestDurationView is an OpenCensus view which tracks GRPC Server
// request duration by pomerium service, grpc service, grpc method, and status
GRPCServerRequestDurationView = &view.View{
Name: "grpc_server_request_duration_ms",
Name: "grpc/server/request_duration_ms",
Measure: ocgrpc.ServerLatency,
Description: "grpc Request duration in ms",
TagKeys: []tag.Key{keyService, keyGRPCMethod, ocgrpc.KeyServerStatus, keyGRPCService},
Aggregation: grcpLatencyDistribution,
TagKeys: []tag.Key{TagKeyService, TagKeyGRPCMethod, ocgrpc.KeyServerStatus, TagKeyGRPCService},
Aggregation: DefaultMillisecondsDistribution,
}
// GRPCServerResponseSizeView is an OpenCensus view which tracks GRPC Server
// response size by pomerium service, grpc service, grpc method, and status
GRPCServerResponseSizeView = &view.View{
Name: "grpc_server_response_size_bytes",
Name: "grpc/server/response_size_bytes",
Measure: ocgrpc.ServerSentBytesPerRPC,
Description: "grpc Server Response Size in bytes",
TagKeys: []tag.Key{keyService, keyGRPCMethod, ocgrpc.KeyServerStatus, keyGRPCService},
TagKeys: []tag.Key{TagKeyService, TagKeyGRPCMethod, ocgrpc.KeyServerStatus, TagKeyGRPCService},
Aggregation: grpcSizeDistribution,
}
// GRPCServerRequestSizeView is an OpenCensus view which tracks GRPC Server
// request size by pomerium service, grpc service, grpc method, and status
GRPCServerRequestSizeView = &view.View{
Name: "grpc_server_request_size_bytes",
Name: "grpc/server/request_size_bytes",
Measure: ocgrpc.ServerReceivedBytesPerRPC,
Description: "grpc Server Request Size in bytes",
TagKeys: []tag.Key{keyService, keyGRPCMethod, ocgrpc.KeyServerStatus, keyGRPCService},
TagKeys: []tag.Key{TagKeyService, TagKeyGRPCMethod, ocgrpc.KeyServerStatus, TagKeyGRPCService},
Aggregation: grpcSizeDistribution,
}
// GRPCClientRequestCountView is an OpenCensus view which tracks GRPC Client
// requests by pomerium service, target host, grpc service, grpc method, and status
GRPCClientRequestCountView = &view.View{
Name: "grpc_client_requests_total",
Name: "grpc/client/requests_total",
Measure: ocgrpc.ClientRoundtripLatency,
Description: "Total grpc Client Requests",
TagKeys: []tag.Key{keyService, keyHost, keyGRPCMethod, keyGRPCService, ocgrpc.KeyClientStatus},
TagKeys: []tag.Key{TagKeyService, TagKeyHost, TagKeyGRPCMethod, TagKeyGRPCService, ocgrpc.KeyClientStatus},
Aggregation: view.Count(),
}
// GRPCClientRequestDurationView is an OpenCensus view which tracks GRPC Client
// request duration by pomerium service, target host, grpc service, grpc method, and status
GRPCClientRequestDurationView = &view.View{
Name: "grpc_client_request_duration_ms",
Name: "grpc/client/request_duration_ms",
Measure: ocgrpc.ClientRoundtripLatency,
Description: "grpc Client Request duration in ms",
TagKeys: []tag.Key{keyService, keyHost, keyGRPCMethod, keyGRPCService, ocgrpc.KeyClientStatus},
Aggregation: grcpLatencyDistribution,
TagKeys: []tag.Key{TagKeyService, TagKeyHost, TagKeyGRPCMethod, TagKeyGRPCService, ocgrpc.KeyClientStatus},
Aggregation: DefaultMillisecondsDistribution,
}
// GRPCClientResponseSizeView is an OpenCensus view which tracks GRPC Client
// response size by pomerium service, target host, grpc service, grpc method, and status
GRPCClientResponseSizeView = &view.View{
Name: "grpc_client_response_size_bytes",
Name: "grpc/client/response_size_bytes",
Measure: ocgrpc.ClientReceivedBytesPerRPC,
Description: "grpc Client Response Size in bytes",
TagKeys: []tag.Key{keyService, keyHost, keyGRPCMethod, keyGRPCService, ocgrpc.KeyClientStatus},
TagKeys: []tag.Key{TagKeyService, TagKeyHost, TagKeyGRPCMethod, TagKeyGRPCService, ocgrpc.KeyClientStatus},
Aggregation: grpcSizeDistribution,
}
// GRPCClientRequestSizeView is an OpenCensus view which tracks GRPC Client
// request size by pomerium service, target host, grpc service, grpc method, and status
GRPCClientRequestSizeView = &view.View{
Name: "grpc_client_request_size_bytes",
Name: "grpc/client/request_size_bytes",
Measure: ocgrpc.ClientSentBytesPerRPC,
Description: "grpc Client Request Size in bytes",
TagKeys: []tag.Key{keyService, keyHost, keyGRPCMethod, keyGRPCService, ocgrpc.KeyClientStatus},
TagKeys: []tag.Key{TagKeyService, TagKeyHost, TagKeyGRPCMethod, TagKeyGRPCService, ocgrpc.KeyClientStatus},
Aggregation: grpcSizeDistribution,
}
)
@ -126,13 +131,13 @@ func GRPCClientInterceptor(service string) grpc.UnaryClientInterceptor {
taggedCtx, tagErr := tag.New(
ctx,
tag.Insert(keyService, service),
tag.Insert(keyHost, cc.Target()),
tag.Insert(keyGRPCMethod, rpcMethod),
tag.Insert(keyGRPCService, rpcService),
tag.Insert(TagKeyService, service),
tag.Insert(TagKeyHost, cc.Target()),
tag.Insert(TagKeyGRPCMethod, rpcMethod),
tag.Insert(TagKeyGRPCService, rpcService),
)
if tagErr != nil {
log.Warn().Err(tagErr).Str("context", "GRPCClientInterceptor").Msg("internal/metrics: Failed to create context")
log.Warn().Err(tagErr).Str("context", "GRPCClientInterceptor").Msg("internal/telemetry: Failed to create context")
return invoker(ctx, method, req, reply, cc, opts...)
}
@ -165,12 +170,12 @@ func (h *GRPCServerStatsHandler) TagRPC(ctx context.Context, tagInfo *grpcstats.
taggedCtx, tagErr := tag.New(
handledCtx,
tag.Insert(keyService, h.service),
tag.Insert(keyGRPCMethod, rpcMethod),
tag.Insert(keyGRPCService, rpcService),
tag.Insert(TagKeyService, h.service),
tag.Insert(TagKeyGRPCMethod, rpcMethod),
tag.Insert(TagKeyGRPCService, rpcService),
)
if tagErr != nil {
log.Warn().Err(tagErr).Str("context", "GRPCServerStatsHandler").Msg("internal/metrics: Failed to create context")
log.Warn().Err(tagErr).Str("context", "GRPCServerStatsHandler").Msg("internal/telemetry: Failed to create context")
return handledCtx
}
@ -180,6 +185,5 @@ func (h *GRPCServerStatsHandler) TagRPC(ctx context.Context, tagInfo *grpcstats.
// NewGRPCServerStatsHandler creates a new GRPCServerStatsHandler for a pomerium service
func NewGRPCServerStatsHandler(service string) grpcstats.Handler {
return &GRPCServerStatsHandler{service: service, Handler: &ocgrpc.ServerHandler{}}
}

View file

@ -1,10 +1,11 @@
package metrics
package metrics // import "github.com/pomerium/pomerium/internal/telemetry/metrics"
import (
"context"
"testing"
"go.opencensus.io/plugin/ocgrpc"
"go.opencensus.io/stats/view"
"google.golang.org/grpc"
"google.golang.org/grpc/stats"
"google.golang.org/grpc/status"
@ -97,8 +98,8 @@ func Test_GRPCClientInterceptor(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
UnRegisterView(GRPCClientViews)
RegisterView(GRPCClientViews)
view.Unregister(GRPCClientViews...)
view.Register(GRPCClientViews...)
invoker := testInvoker{
invokeResult: tt.errorCode,
@ -167,8 +168,8 @@ func Test_GRPCServerStatsHandler(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
UnRegisterView(GRPCServerViews)
RegisterView(GRPCServerViews)
view.Unregister(GRPCServerViews...)
view.Register(GRPCServerViews...)
statsHandler := NewGRPCServerStatsHandler("test_service")
mockServerRPCHandle(statsHandler, tt.method, tt.errorCode)

View file

@ -1,41 +1,12 @@
package metrics
package metrics // import "github.com/pomerium/pomerium/internal/telemetry/metrics"
import (
"strings"
"testing"
"github.com/google/go-cmp/cmp"
"go.opencensus.io/metric/metricdata"
"go.opencensus.io/stats/view"
)
func testDataRetrieval(v *view.View, t *testing.T, want string) {
if v == nil {
t.Fatalf("%s: nil view passed", t.Name())
}
name := v.Name
data, err := view.RetrieveData(name)
if err != nil {
t.Fatalf("%s: failed to retrieve data line %s", name, err)
}
if want != "" && len(data) != 1 {
t.Fatalf("%s: received incorrect number of data rows: %d", name, len(data))
}
if want == "" && len(data) > 0 {
t.Fatalf("%s: received incorrect number of data rows: %d", name, len(data))
} else if want == "" {
return
}
dataString := data[0].String()
if want != "" && !strings.HasPrefix(dataString, want) {
t.Errorf("%s: Found unexpected data row: \nwant: %s\ngot: %s\n", name, want, dataString)
}
}
func testMetricRetrieval(metrics []*metricdata.Metric, t *testing.T, labels []metricdata.LabelValue, value interface{}, name string) {
switch value.(type) {
case int64:

View file

@ -0,0 +1,157 @@
package metrics // import "github.com/pomerium/pomerium/internal/telemetry/metrics"
import (
"fmt"
"net/http"
"github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/tripper"
"go.opencensus.io/plugin/ochttp"
"go.opencensus.io/stats/view"
"go.opencensus.io/tag"
)
// HTTP Views
var (
// HTTPClientViews contains opencensus views for HTTP Client metrics.
HTTPClientViews = []*view.View{
HTTPClientRequestCountView,
HTTPClientRequestDurationView,
HTTPClientResponseSizeView}
// HTTPServerViews contains opencensus views for HTTP Server metrics.
HTTPServerViews = []*view.View{
HTTPServerRequestCountView,
HTTPServerRequestDurationView,
HTTPServerRequestSizeView,
HTTPServerResponseSizeView}
// HTTPServerRequestCountView is an OpenCensus View that tracks HTTP server
// requests by pomerium service, host, method and status
HTTPServerRequestCountView = &view.View{
Name: "http/server/requests_total",
Measure: ochttp.ServerLatency,
Description: "Total HTTP Requests",
TagKeys: []tag.Key{TagKeyService, TagKeyHost, TagKeyHTTPMethod, ochttp.StatusCode},
Aggregation: view.Count(),
}
// HTTPServerRequestDurationView is an OpenCensus view that tracks HTTP
// server request duration by pomerium service, host, method and status
HTTPServerRequestDurationView = &view.View{
Name: "http/server/request_duration_ms",
Measure: ochttp.ServerLatency,
Description: "HTTP Request duration in ms",
TagKeys: []tag.Key{TagKeyService, TagKeyHost, TagKeyHTTPMethod, ochttp.StatusCode},
Aggregation: DefaultHTTPLatencyDistrubtion,
}
// HTTPServerRequestSizeView is an OpenCensus view that tracks HTTP server
// request size by pomerium service, host and method
HTTPServerRequestSizeView = &view.View{
Name: "http/server/request_size_bytes",
Measure: ochttp.ServerRequestBytes,
Description: "HTTP Server Request Size in bytes",
TagKeys: []tag.Key{TagKeyService, TagKeyHost, TagKeyHTTPMethod},
Aggregation: DefaulHTTPSizeDistribution,
}
// HTTPServerResponseSizeView is an OpenCensus view that tracks HTTP server
// response size by pomerium service, host, method and status
HTTPServerResponseSizeView = &view.View{
Name: "http/server/response_size_bytes",
Measure: ochttp.ServerResponseBytes,
Description: "HTTP Server Response Size in bytes",
TagKeys: []tag.Key{TagKeyService, TagKeyHost, TagKeyHTTPMethod, ochttp.StatusCode},
Aggregation: DefaulHTTPSizeDistribution,
}
// HTTPClientRequestCountView is an OpenCensus View that tracks HTTP client
// requests by pomerium service, destination, host, method and status
HTTPClientRequestCountView = &view.View{
Name: "http/client/requests_total",
Measure: ochttp.ClientRoundtripLatency,
Description: "Total HTTP Client Requests",
TagKeys: []tag.Key{TagKeyService, TagKeyHost, TagKeyHTTPMethod, ochttp.StatusCode, TagKeyDestination},
Aggregation: view.Count(),
}
// HTTPClientRequestDurationView is an OpenCensus view that tracks HTTP
// client request duration by pomerium service, destination, host, method and status
HTTPClientRequestDurationView = &view.View{
Name: "http/client/request_duration_ms",
Measure: ochttp.ClientRoundtripLatency,
Description: "HTTP Client Request duration in ms",
TagKeys: []tag.Key{TagKeyService, TagKeyHost, TagKeyHTTPMethod, ochttp.StatusCode, TagKeyDestination},
Aggregation: DefaultHTTPLatencyDistrubtion,
}
// HTTPClientResponseSizeView is an OpenCensus view that tracks HTTP client
// esponse size by pomerium service, destination, host, method and status
HTTPClientResponseSizeView = &view.View{
Name: "http/client/response_size_bytes",
Measure: ochttp.ClientReceivedBytes,
Description: "HTTP Client Response Size in bytes",
TagKeys: []tag.Key{TagKeyService, TagKeyHost, TagKeyHTTPMethod, ochttp.StatusCode, TagKeyDestination},
Aggregation: DefaulHTTPSizeDistribution,
}
// HTTPClientRequestSizeView is an OpenCensus view that tracks HTTP client
//request size by pomerium service, destination, host and method
HTTPClientRequestSizeView = &view.View{
Name: "http/client/response_size_bytes",
Measure: ochttp.ClientSentBytes,
Description: "HTTP Client Response Size in bytes",
TagKeys: []tag.Key{TagKeyService, TagKeyHost, TagKeyHTTPMethod, TagKeyDestination},
Aggregation: DefaulHTTPSizeDistribution,
}
)
// HTTPMetricsHandler creates a metrics middleware for incoming HTTP requests
func HTTPMetricsHandler(service string) func(next http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx, tagErr := tag.New(
r.Context(),
tag.Insert(TagKeyService, service),
tag.Insert(TagKeyHost, r.Host),
tag.Insert(TagKeyHTTPMethod, r.Method),
)
if tagErr != nil {
log.Warn().Err(tagErr).Str("context", "HTTPMetricsHandler").
Msg("telemetry/metrics: failed to create metrics tag")
next.ServeHTTP(w, r)
return
}
ocHandler := ochttp.Handler{
Handler: next,
FormatSpanName: func(r *http.Request) string {
return fmt.Sprintf("%s%s", r.Host, r.URL.Path)
},
}
ocHandler.ServeHTTP(w, r.WithContext(ctx))
})
}
}
// HTTPMetricsRoundTripper creates a metrics tracking tripper for outbound HTTP Requests
func HTTPMetricsRoundTripper(service string, destination string) func(next http.RoundTripper) http.RoundTripper {
return func(next http.RoundTripper) http.RoundTripper {
return tripper.RoundTripperFunc(func(r *http.Request) (*http.Response, error) {
ctx, tagErr := tag.New(
r.Context(),
tag.Insert(TagKeyService, service),
tag.Insert(TagKeyHost, r.Host),
tag.Insert(TagKeyHTTPMethod, r.Method),
tag.Insert(TagKeyDestination, destination),
)
if tagErr != nil {
log.Warn().Err(tagErr).Str("context", "HTTPMetricsRoundTripper").Msg("telemetry/metrics: failed to create metrics tag")
return next.RoundTrip(r)
}
ocTransport := ochttp.Transport{Base: next}
return ocTransport.RoundTrip(r.WithContext(ctx))
})
}
}

View file

@ -1,4 +1,4 @@
package metrics // import "github.com/pomerium/pomerium/internal/metrics"
package metrics // import "github.com/pomerium/pomerium/internal/telemetry/metrics"
import (
"bytes"
@ -7,13 +7,40 @@ import (
"io/ioutil"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/pomerium/pomerium/internal/middleware"
"github.com/pomerium/pomerium/internal/tripper"
"go.opencensus.io/stats/view"
)
func testDataRetrieval(v *view.View, t *testing.T, want string) {
if v == nil {
t.Fatalf("%s: nil view passed", t.Name())
}
name := v.Name
data, err := view.RetrieveData(name)
if err != nil {
t.Fatalf("%s: failed to retrieve data line %s", name, err)
}
if want != "" && len(data) != 1 {
t.Fatalf("%s: received incorrect number of data rows: %d", name, len(data))
}
if want == "" && len(data) > 0 {
t.Fatalf("%s: received incorrect number of data rows: %d", name, len(data))
} else if want == "" {
return
}
dataString := data[0].String()
if want != "" && !strings.HasPrefix(dataString, want) {
t.Errorf("%s: Found unexpected data row: \nwant: %s\ngot: %s\n", name, want, dataString)
}
}
func newTestMux() http.Handler {
mux := http.NewServeMux()
mux.HandleFunc("/good", func(w http.ResponseWriter, r *http.Request) {
@ -25,10 +52,6 @@ func newTestMux() http.Handler {
func Test_HTTPMetricsHandler(t *testing.T) {
chain := middleware.NewChain()
chain = chain.Append(HTTPMetricsHandler("test_service"))
chainHandler := chain.Then(newTestMux())
tests := []struct {
name string
url string
@ -73,7 +96,9 @@ func Test_HTTPMetricsHandler(t *testing.T) {
req := httptest.NewRequest(tt.verb, tt.url, new(bytes.Buffer))
rec := httptest.NewRecorder()
chainHandler.ServeHTTP(rec, req)
h := HTTPMetricsHandler("test_service")(newTestMux())
h.ServeHTTP(rec, req)
testDataRetrieval(HTTPServerRequestSizeView, t, tt.wanthttpServerRequestSize)
testDataRetrieval(HTTPServerResponseSizeView, t, tt.wanthttpServerResponseSize)

View file

@ -1,4 +1,4 @@
package metrics // import "github.com/pomerium/pomerium/internal/metrics"
package metrics // import "github.com/pomerium/pomerium/internal/telemetry/metrics"
import (
"context"
@ -8,6 +8,7 @@ import (
"github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/version"
"go.opencensus.io/metric"
"go.opencensus.io/metric/metricdata"
"go.opencensus.io/metric/metricproducer"
@ -17,44 +18,53 @@ import (
)
var (
//buildInfo = stats.Int64("build_info", "Build Metadata", "1")
configLastReload = stats.Int64("config_last_reload_success_timestamp", "Timestamp of last successful config reload", "seconds")
configLastReloadSuccess = stats.Int64("config_last_reload_success", "Returns 1 if last reload was successful", "1")
registry = newMetricRegistry()
// InfoViews contains opencensus views for informational metrics about
// pomerium itself.
InfoViews = []*view.View{ConfigLastReloadView, ConfigLastReloadSuccessView}
configLastReload = stats.Int64(
"config_last_reload_success_timestamp",
"Timestamp of last successful config reload",
"seconds")
configLastReloadSuccess = stats.Int64(
"config_last_reload_success",
"Returns 1 if last reload was successful",
"1")
registry = newMetricRegistry()
// ConfigLastReloadView contains the timestamp the configuration was last
// reloaded, labeled by service
// reloaded, labeled by service.
ConfigLastReloadView = &view.View{
Name: configLastReload.Name(),
Description: configLastReload.Description(),
Measure: configLastReload,
TagKeys: []tag.Key{keyService},
TagKeys: []tag.Key{TagKeyService},
Aggregation: view.LastValue(),
}
// ConfigLastReloadSuccessView contains the result of the last configuration
// reload, labeled by service
// reload, labeled by service.
ConfigLastReloadSuccessView = &view.View{
Name: configLastReloadSuccess.Name(),
Description: configLastReloadSuccess.Description(),
Measure: configLastReloadSuccess,
TagKeys: []tag.Key{keyService},
TagKeys: []tag.Key{TagKeyService},
Aggregation: view.LastValue(),
}
)
// SetConfigInfo records the status, checksum and timestamp of a configuration reload. You must register InfoViews or the related
// config views before calling
// SetConfigInfo records the status, checksum and timestamp of a configuration
// reload. You must register InfoViews or the related config views before calling
func SetConfigInfo(service string, success bool, checksum string) {
if success {
serviceTag := tag.Insert(keyService, service)
serviceTag := tag.Insert(TagKeyService, service)
if err := stats.RecordWithTags(
context.Background(),
[]tag.Mutator{serviceTag},
configLastReload.M(time.Now().Unix()),
); err != nil {
log.Error().Err(err).Msg("internal/metrics: failed to record config checksum timestamp")
log.Error().Err(err).Msg("internal/telemetry: failed to record config checksum timestamp")
}
if err := stats.RecordWithTags(
@ -62,7 +72,7 @@ func SetConfigInfo(service string, success bool, checksum string) {
[]tag.Mutator{serviceTag},
configLastReloadSuccess.M(1),
); err != nil {
log.Error().Err(err).Msg("internal/metrics: failed to record config reload")
log.Error().Err(err).Msg("internal/telemetry: failed to record config reload")
}
} else {
stats.Record(context.Background(), configLastReloadSuccess.M(0))
@ -96,7 +106,7 @@ func (r *metricRegistry) init() {
metric.WithLabelKeys("service", "version", "revision", "goversion"),
)
if err != nil {
log.Error().Err(err).Msg("internal/metrics: failed to register build info metric")
log.Error().Err(err).Msg("internal/telemetry: failed to register build info metric")
}
r.configChecksum, err = r.registry.AddFloat64Gauge("config_checksum_decimal",
@ -104,7 +114,7 @@ func (r *metricRegistry) init() {
metric.WithLabelKeys("service"),
)
if err != nil {
log.Error().Err(err).Msg("internal/metrics: failed to register config checksum metric")
log.Error().Err(err).Msg("internal/telemetry: failed to register config checksum metric")
}
r.policyCount, err = r.registry.AddInt64DerivedGauge("policy_count_total",
@ -112,7 +122,7 @@ func (r *metricRegistry) init() {
metric.WithLabelKeys("service"),
)
if err != nil {
log.Error().Err(err).Msg("internal/metrics: failed to register policy count metric")
log.Error().Err(err).Msg("internal/telemetry: failed to register policy count metric")
}
})
}
@ -130,7 +140,7 @@ func (r *metricRegistry) setBuildInfo(service string) {
metricdata.NewLabelValue((runtime.Version())),
)
if err != nil {
log.Error().Err(err).Msg("internal/metrics: failed to get build info metric")
log.Error().Err(err).Msg("internal/telemetry: failed to get build info metric")
}
// This sets our build_info metric to a constant 1 per
@ -155,7 +165,7 @@ func (r *metricRegistry) setConfigChecksum(service string, checksum uint64) {
}
m, err := r.configChecksum.GetEntry(metricdata.NewLabelValue(service))
if err != nil {
log.Error().Err(err).Msg("internal/metrics: failed to get config checksum metric")
log.Error().Err(err).Msg("internal/telemetry: failed to get config checksum metric")
}
m.Set(float64(checksum))
}
@ -172,7 +182,7 @@ func (r *metricRegistry) addPolicyCountCallback(service string, f func() int64)
}
err := r.policyCount.UpsertEntry(f, metricdata.NewLabelValue(service))
if err != nil {
log.Error().Err(err).Msg("internal/metrics: failed to get policy count metric")
log.Error().Err(err).Msg("internal/telemetry: failed to get policy count metric")
}
}

View file

@ -1,4 +1,4 @@
package metrics // import "github.com/pomerium/pomerium/internal/metrics"
package metrics // import "github.com/pomerium/pomerium/internal/telemetry/metrics"
import (
"runtime"
@ -8,6 +8,7 @@ import (
"go.opencensus.io/metric/metricdata"
"go.opencensus.io/metric/metricproducer"
"go.opencensus.io/stats/view"
)
func Test_SetConfigInfo(t *testing.T) {
@ -24,9 +25,8 @@ func Test_SetConfigInfo(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
UnRegisterView(InfoViews)
RegisterView(InfoViews)
view.Unregister(InfoViews...)
view.Register(InfoViews...)
SetConfigInfo("test_service", tt.success, tt.checksum)
testDataRetrieval(ConfigLastReloadView, t, tt.wantLastReload)

View file

@ -0,0 +1,39 @@
package metrics // import "github.com/pomerium/pomerium/internal/telemetry/metrics"
import (
"fmt"
"net/http"
ocprom "contrib.go.opencensus.io/exporter/prometheus"
prom "github.com/prometheus/client_golang/prometheus"
"go.opencensus.io/stats/view"
)
// PrometheusHandler creates an exporter that exports stats to Prometheus
// and returns a handler suitable for exporting metrics.
func PrometheusHandler() (http.Handler, error) {
if err := registerDefaultViews(); err != nil {
return nil, fmt.Errorf("internal/telemetry: failed registering views")
}
reg := prom.DefaultRegisterer.(*prom.Registry)
exporter, err := ocprom.NewExporter(
ocprom.Options{
Namespace: "pomerium",
Registry: reg,
})
if err != nil {
return nil, fmt.Errorf("internal/telemetry: prometheus exporter: %v", err)
}
view.RegisterExporter(exporter)
mux := http.NewServeMux()
mux.Handle("/metrics", exporter)
return mux, nil
}
func registerDefaultViews() error {
var views []*view.View
for _, v := range DefaultViews {
views = append(views, v...)
}
return view.Register(views...)
}

View file

@ -1,4 +1,4 @@
package metrics // import "github.com/pomerium/pomerium/internal/metrics"
package metrics // import "github.com/pomerium/pomerium/internal/telemetry/metrics"
import (
"bytes"
@ -8,9 +8,11 @@ import (
"testing"
)
func Test_newPromHTTPHandler(t *testing.T) {
h := newPromHTTPHandler()
func Test_PrometheusHandler(t *testing.T) {
h, err := PrometheusHandler()
if err != nil {
t.Fatal(err)
}
req := httptest.NewRequest("GET", "http://test.local/metrics", new(bytes.Buffer))
rec := httptest.NewRecorder()
h.ServeHTTP(rec, req)

View file

@ -0,0 +1,74 @@
package trace // import "github.com/pomerium/pomerium/internal/telemetry/trace"
import (
"context"
"fmt"
"github.com/pomerium/pomerium/internal/log"
"contrib.go.opencensus.io/exporter/jaeger"
"go.opencensus.io/trace"
)
const (
JaegerTracingProviderName = "jaeger"
)
// TracingOptions contains the configurations settings for a http server.
type TracingOptions struct {
// Shared
Provider string
Service string
Debug bool
// Jaeger
// CollectorEndpoint is the full url to the Jaeger HTTP Thrift collector.
// For example, http://localhost:14268/api/traces
JaegerCollectorEndpoint string `mapstructure:"tracing_jaeger_collector_endpoint"`
// AgentEndpoint instructs exporter to send spans to jaeger-agent at this address.
// For example, localhost:6831.
JaegerAgentEndpoint string `mapstructure:"tracing_jaeger_agent_endpoint"`
}
func RegisterTracing(opts *TracingOptions) error {
var err error
switch opts.Provider {
case JaegerTracingProviderName:
err = registerJaeger(opts)
default:
return fmt.Errorf("telemetry/trace: provider %s unknown", opts.Provider)
}
if err != nil {
return err
}
if opts.Debug {
log.Debug().Msg("telemetry/trace: debug on, sample everything")
trace.ApplyConfig(trace.Config{DefaultSampler: trace.AlwaysSample()})
}
log.Debug().Interface("Opts", opts).Msg("telemetry/trace: exporter created")
return nil
}
func registerJaeger(opts *TracingOptions) error {
jex, err := jaeger.NewExporter(
jaeger.Options{
AgentEndpoint: opts.JaegerAgentEndpoint,
CollectorEndpoint: opts.JaegerCollectorEndpoint,
ServiceName: opts.Service,
})
if err != nil {
return err
}
trace.RegisterExporter(jex)
return nil
}
// StartSpan starts a new child span of the current span in the context. If
// there is no span in the context, creates a new trace and span.
//
// Returned context contains the newly created span. You can use it to
// propagate the returned span in process.
func StartSpan(ctx context.Context, name string, o ...trace.StartOption) (context.Context, *trace.Span) {
return trace.StartSpan(ctx, name, o...)
}

View file

@ -0,0 +1,23 @@
package trace // import "github.com/pomerium/pomerium/internal/telemetry/trace"
import "testing"
func TestRegisterTracing(t *testing.T) {
tests := []struct {
name string
opts *TracingOptions
wantErr bool
}{
{"jaeger", &TracingOptions{JaegerAgentEndpoint: "localhost:6831", Service: "all", Provider: "jaeger"}, false},
{"jaeger with debug", &TracingOptions{JaegerAgentEndpoint: "localhost:6831", Service: "all", Provider: "jaeger", Debug: true}, false},
{"jaeger no endpoint", &TracingOptions{JaegerAgentEndpoint: "", Service: "all", Provider: "jaeger"}, true},
{"unknown provider", &TracingOptions{JaegerAgentEndpoint: "localhost:0", Service: "all", Provider: "Lucius Cornelius Sulla"}, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if err := RegisterTracing(tt.opts); (err != nil) != tt.wantErr {
t.Errorf("RegisterTracing() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}