mirror of
https://github.com/pomerium/pomerium.git
synced 2025-08-04 09:19:39 +02:00
proxy: add JWT request signing support (#19)
- Refactored middleware and request hander logging. - Request refactored to use context.Context. - Add helper (based on Alice) to allow middleware chaining. - Add helper scripts to generate elliptic curve self-signed certificate that can be used to sign JWT. - Changed LetsEncrypt scripts to use acme instead of certbot. - Add script to have LetsEncrypt sign an RSA based certificate. - Add documentation to explain how to verify headers. - Refactored internal/cryptutil signer's code to expect a valid EC priv key. - Changed JWT expiries to use default leeway period. - Update docs and add screenshots. - Replaced logging handler logic to use context.Context. - Removed specific XML error handling. - Refactored handler function signatures to prefer standard go idioms.
This commit is contained in:
parent
98b8c7481f
commit
426e003b03
30 changed files with 1711 additions and 588 deletions
102
internal/cryptutil/marshal.go
Normal file
102
internal/cryptutil/marshal.go
Normal file
|
@ -0,0 +1,102 @@
|
|||
// Package cryptutil provides encoding and decoding routines for various cryptographic structures.
|
||||
package cryptutil
|
||||
|
||||
import (
|
||||
"crypto/ecdsa"
|
||||
"crypto/x509"
|
||||
"encoding/base64"
|
||||
"encoding/pem"
|
||||
"errors"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// DecodePublicKey decodes a PEM-encoded ECDSA public key.
|
||||
func DecodePublicKey(encodedKey []byte) (*ecdsa.PublicKey, error) {
|
||||
block, _ := pem.Decode(encodedKey)
|
||||
if block == nil || block.Type != "PUBLIC KEY" {
|
||||
return nil, fmt.Errorf("marshal: could not decode PEM block type %s", block.Type)
|
||||
|
||||
}
|
||||
|
||||
pub, err := x509.ParsePKIXPublicKey(block.Bytes)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ecdsaPub, ok := pub.(*ecdsa.PublicKey)
|
||||
if !ok {
|
||||
return nil, errors.New("marshal: data was not an ECDSA public key")
|
||||
}
|
||||
|
||||
return ecdsaPub, nil
|
||||
}
|
||||
|
||||
// EncodePublicKey encodes an ECDSA public key to PEM format.
|
||||
func EncodePublicKey(key *ecdsa.PublicKey) ([]byte, error) {
|
||||
derBytes, err := x509.MarshalPKIXPublicKey(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
block := &pem.Block{
|
||||
Type: "PUBLIC KEY",
|
||||
Bytes: derBytes,
|
||||
}
|
||||
|
||||
return pem.EncodeToMemory(block), nil
|
||||
}
|
||||
|
||||
// DecodePrivateKey decodes a PEM-encoded ECDSA private key.
|
||||
func DecodePrivateKey(encodedKey []byte) (*ecdsa.PrivateKey, error) {
|
||||
var skippedTypes []string
|
||||
var block *pem.Block
|
||||
|
||||
for {
|
||||
block, encodedKey = pem.Decode(encodedKey)
|
||||
|
||||
if block == nil {
|
||||
return nil, fmt.Errorf("failed to find EC PRIVATE KEY in PEM data after skipping types %v", skippedTypes)
|
||||
}
|
||||
|
||||
if block.Type == "EC PRIVATE KEY" {
|
||||
break
|
||||
} else {
|
||||
skippedTypes = append(skippedTypes, block.Type)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
privKey, err := x509.ParseECPrivateKey(block.Bytes)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return privKey, nil
|
||||
}
|
||||
|
||||
// EncodePrivateKey encodes an ECDSA private key to PEM format.
|
||||
func EncodePrivateKey(key *ecdsa.PrivateKey) ([]byte, error) {
|
||||
derKey, err := x509.MarshalECPrivateKey(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
keyBlock := &pem.Block{
|
||||
Type: "EC PRIVATE KEY",
|
||||
Bytes: derKey,
|
||||
}
|
||||
|
||||
return pem.EncodeToMemory(keyBlock), nil
|
||||
}
|
||||
|
||||
// EncodeSignatureJWT encodes an ECDSA signature according to
|
||||
// https://tools.ietf.org/html/rfc7515#appendix-A.3.1
|
||||
func EncodeSignatureJWT(sig []byte) string {
|
||||
return base64.RawURLEncoding.EncodeToString(sig)
|
||||
}
|
||||
|
||||
// DecodeSignatureJWT decodes an ECDSA signature according to
|
||||
// https://tools.ietf.org/html/rfc7515#appendix-A.3.1
|
||||
func DecodeSignatureJWT(b64sig string) ([]byte, error) {
|
||||
return base64.RawURLEncoding.DecodeString(b64sig)
|
||||
}
|
122
internal/cryptutil/marshal_test.go
Normal file
122
internal/cryptutil/marshal_test.go
Normal file
|
@ -0,0 +1,122 @@
|
|||
package cryptutil
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// A keypair for NIST P-256 / secp256r1
|
||||
// Generated using:
|
||||
// openssl ecparam -genkey -name prime256v1 -outform PEM
|
||||
var pemECPrivateKeyP256 = `-----BEGIN EC PARAMETERS-----
|
||||
BggqhkjOPQMBBw==
|
||||
-----END EC PARAMETERS-----
|
||||
-----BEGIN EC PRIVATE KEY-----
|
||||
MHcCAQEEIOI+EZsjyN3jvWJI/KDihFmqTuDpUe/if6f/pgGTBta/oAoGCCqGSM49
|
||||
AwEHoUQDQgAEhhObKJ1r1PcUw+3REd/TbmSZnDvXnFUSTwqQFo5gbfIlP+gvEYba
|
||||
+Rxj2hhqjfzqxIleRK40IRyEi3fJM/8Qhg==
|
||||
-----END EC PRIVATE KEY-----
|
||||
`
|
||||
|
||||
var pemECPublicKeyP256 = `-----BEGIN PUBLIC KEY-----
|
||||
MFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEhhObKJ1r1PcUw+3REd/TbmSZnDvX
|
||||
nFUSTwqQFo5gbfIlP+gvEYba+Rxj2hhqjfzqxIleRK40IRyEi3fJM/8Qhg==
|
||||
-----END PUBLIC KEY-----
|
||||
`
|
||||
|
||||
// A keypair for NIST P-384 / secp384r1
|
||||
// Generated using:
|
||||
// openssl ecparam -genkey -name secp384r1 -outform PEM
|
||||
var pemECPrivateKeyP384 = `-----BEGIN EC PARAMETERS-----
|
||||
BgUrgQQAIg==
|
||||
-----END EC PARAMETERS-----
|
||||
-----BEGIN EC PRIVATE KEY-----
|
||||
MIGkAgEBBDAhA0YPVL1kimIy+FAqzUAtmR3It2Yjv2I++YpcC4oX7wGuEWcWKBYE
|
||||
oOjj7wG/memgBwYFK4EEACKhZANiAAQub8xaaCTTW5rCHJCqUddIXpvq/TxdwViH
|
||||
+tPEQQlJAJciXStM/aNLYA7Q1K1zMjYyzKSWz5kAh/+x4rXQ9Hlm3VAwCQDVVSjP
|
||||
bfiNOXKOWfmyrGyQ7fQfs+ro1lmjLjs=
|
||||
-----END EC PRIVATE KEY-----
|
||||
`
|
||||
|
||||
var pemECPublicKeyP384 = `-----BEGIN PUBLIC KEY-----
|
||||
MHYwEAYHKoZIzj0CAQYFK4EEACIDYgAELm/MWmgk01uawhyQqlHXSF6b6v08XcFY
|
||||
h/rTxEEJSQCXIl0rTP2jS2AO0NStczI2Msykls+ZAIf/seK10PR5Zt1QMAkA1VUo
|
||||
z234jTlyjln5sqxskO30H7Pq6NZZoy47
|
||||
-----END PUBLIC KEY-----
|
||||
`
|
||||
|
||||
var garbagePEM = `-----BEGIN GARBAGE-----
|
||||
TG9yZW0gaXBzdW0gZG9sb3Igc2l0IGFtZXQ=
|
||||
-----END GARBAGE-----
|
||||
`
|
||||
|
||||
func TestPublicKeyMarshaling(t *testing.T) {
|
||||
ecKey, err := DecodePublicKey([]byte(pemECPublicKeyP256))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
pemBytes, _ := EncodePublicKey(ecKey)
|
||||
if !bytes.Equal(pemBytes, []byte(pemECPublicKeyP256)) {
|
||||
t.Fatal("public key encoding did not match")
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestPrivateKeyBadDecode(t *testing.T) {
|
||||
_, err := DecodePrivateKey([]byte(garbagePEM))
|
||||
if err == nil {
|
||||
t.Fatal("decoded garbage data without complaint")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPrivateKeyMarshaling(t *testing.T) {
|
||||
ecKey, err := DecodePrivateKey([]byte(pemECPrivateKeyP256))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
pemBytes, _ := EncodePrivateKey(ecKey)
|
||||
if !strings.HasSuffix(pemECPrivateKeyP256, string(pemBytes)) {
|
||||
t.Fatal("private key encoding did not match")
|
||||
}
|
||||
}
|
||||
|
||||
// Test vector from https://tools.ietf.org/html/rfc7515#appendix-A.3.1
|
||||
var jwtTest = []struct {
|
||||
sigBytes []byte
|
||||
b64sig string
|
||||
}{
|
||||
{
|
||||
sigBytes: []byte{14, 209, 33, 83, 121, 99, 108, 72, 60, 47, 127, 21,
|
||||
88, 7, 212, 2, 163, 178, 40, 3, 58, 249, 124, 126, 23, 129, 154, 195, 22, 158,
|
||||
166, 101, 197, 10, 7, 211, 140, 60, 112, 229, 216, 241, 45, 175,
|
||||
8, 74, 84, 128, 166, 101, 144, 197, 242, 147, 80, 154, 143, 63, 127, 138, 131,
|
||||
163, 84, 213},
|
||||
b64sig: "DtEhU3ljbEg8L38VWAfUAqOyKAM6-Xx-F4GawxaepmXFCgfTjDxw5djxLa8ISlSApmWQxfKTUJqPP3-Kg6NU1Q",
|
||||
},
|
||||
}
|
||||
|
||||
func TestJWTEncoding(t *testing.T) {
|
||||
for _, tt := range jwtTest {
|
||||
result := EncodeSignatureJWT(tt.sigBytes)
|
||||
|
||||
if strings.Compare(result, tt.b64sig) != 0 {
|
||||
t.Fatalf("expected %s, got %s\n", tt.b64sig, result)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestJWTDecoding(t *testing.T) {
|
||||
for _, tt := range jwtTest {
|
||||
resultSig, err := DecodeSignatureJWT(tt.b64sig)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
if !bytes.Equal(resultSig, tt.sigBytes) {
|
||||
t.Fatalf("decoded signature was incorrect")
|
||||
}
|
||||
}
|
||||
}
|
84
internal/cryptutil/sign.go
Normal file
84
internal/cryptutil/sign.go
Normal file
|
@ -0,0 +1,84 @@
|
|||
package cryptutil // import "github.com/pomerium/pomerium/internal/cryptutil"
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"gopkg.in/square/go-jose.v2"
|
||||
"gopkg.in/square/go-jose.v2/jwt"
|
||||
)
|
||||
|
||||
// JWTSigner implements JWT signing according to JSON Web Token (JWT) RFC7519
|
||||
// https://tools.ietf.org/html/rfc7519
|
||||
type JWTSigner interface {
|
||||
SignJWT(string, string) (string, error)
|
||||
}
|
||||
|
||||
// ES256Signer is struct containing the required fields to create a ES256 signed JSON Web Tokens
|
||||
type ES256Signer struct {
|
||||
// User (sub) is unique, stable identifier for the user.
|
||||
// Use in place of the x-pomerium-authenticated-user-id header.
|
||||
User string `json:"sub,omitempty"`
|
||||
// Email (sub) is a **private** claim name identifier for the user email address.
|
||||
// Use in place of the x-pomerium-authenticated-user-email header.
|
||||
Email string `json:"email,omitempty"`
|
||||
// Audience (aud) must be the destination of the upstream proxy locations.
|
||||
// e.g. `helloworld.corp.example.com`
|
||||
Audience jwt.Audience `json:"aud,omitempty"`
|
||||
// Issuer (iss) is the URL of the proxy.
|
||||
// e.g. `proxy.corp.example.com`
|
||||
Issuer string `json:"iss,omitempty"`
|
||||
// Expiry (exp) is the expiration time in seconds since the UNIX epoch.
|
||||
// Allow 1 minute for skew. The maximum lifetime of a token is 10 minutes + 2 * skew.
|
||||
Expiry jwt.NumericDate `json:"exp,omitempty"`
|
||||
// IssuedAt (iat) is the time is measured in seconds since the UNIX epoch.
|
||||
// Allow 1 minute for skew.
|
||||
IssuedAt jwt.NumericDate `json:"iat,omitempty"`
|
||||
// IssuedAt (nbf) is the time is measured in seconds since the UNIX epoch.
|
||||
// Allow 1 minute for skew.
|
||||
NotBefore jwt.NumericDate `json:"nbf,omitempty"`
|
||||
|
||||
signer jose.Signer
|
||||
}
|
||||
|
||||
// NewES256Signer creates an Eliptic Curve, NIST P-256 (aka secp256r1 aka prime256v1) JWT signer.
|
||||
//
|
||||
// RSA is not supported due to performance considerations of needing to sign each request.
|
||||
// Go's P-256 is constant-time and SHA-256 is faster on 64-bit machines and immune
|
||||
// to length extension attacks.
|
||||
// See also:
|
||||
// - https://cloud.google.com/iot/docs/how-tos/credentials/keys
|
||||
func NewES256Signer(privKey []byte, audience string) (*ES256Signer, error) {
|
||||
key, err := DecodePrivateKey(privKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("internal/cryptutil parsing key failed %v", err)
|
||||
}
|
||||
signer, err := jose.NewSigner(
|
||||
jose.SigningKey{
|
||||
Algorithm: jose.ES256, // ECDSA using P-256 and SHA-256
|
||||
Key: key,
|
||||
},
|
||||
(&jose.SignerOptions{}).WithType("JWT"))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("internal/cryptutil new signer failed %v", err)
|
||||
}
|
||||
return &ES256Signer{
|
||||
Issuer: "pomerium-proxy",
|
||||
Audience: jwt.Audience{audience},
|
||||
signer: signer,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// SignJWT creates a signed JWT containing claims for the logged in user id (`sub`) and email (`email`).
|
||||
func (s *ES256Signer) SignJWT(user, email string) (string, error) {
|
||||
s.User = user
|
||||
s.Email = email
|
||||
now := time.Now()
|
||||
s.IssuedAt = jwt.NewNumericDate(now)
|
||||
s.Expiry = jwt.NewNumericDate(now.Add(jwt.DefaultLeeway))
|
||||
s.NotBefore = jwt.NewNumericDate(now.Add(-1 * jwt.DefaultLeeway))
|
||||
rawJWT, err := jwt.Signed(s.signer).Claims(s).CompactSerialize()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return rawJWT, nil
|
||||
}
|
44
internal/cryptutil/sign_test.go
Normal file
44
internal/cryptutil/sign_test.go
Normal file
|
@ -0,0 +1,44 @@
|
|||
package cryptutil
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestES256Signer(t *testing.T) {
|
||||
signer, err := NewES256Signer([]byte(pemECPrivateKeyP256), "destination-url")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if signer == nil {
|
||||
t.Fatal("signer should not be nil")
|
||||
}
|
||||
rawJwt, err := signer.SignJWT("joe-user", "joe-user@example.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if rawJwt == "" {
|
||||
t.Fatal("jwt should not be nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewES256Signer(t *testing.T) {
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
privKey []byte
|
||||
audience string
|
||||
wantErr bool
|
||||
}{
|
||||
{"working example", []byte(pemECPrivateKeyP256), "some-domain.com", false},
|
||||
{"bad private key", []byte(garbagePEM), "some-domain.com", true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
_, err := NewES256Signer(tt.privKey, tt.audience)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("NewES256Signer() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
|
@ -130,6 +130,8 @@ func readCertificateFile(certFile, certKeyFile string) (*tls.Certificate, error)
|
|||
}
|
||||
|
||||
// newDefaultTLSConfig creates a new TLS config based on the certificate files given.
|
||||
// see also:
|
||||
// https://wiki.mozilla.org/Security/Server_Side_TLS#Recommended_configurations
|
||||
func newDefaultTLSConfig(cert *tls.Certificate) (*tls.Config, error) {
|
||||
tlsConfig := &tls.Config{
|
||||
CipherSuites: []uint16{
|
||||
|
|
212
internal/log/handler_log.go
Normal file
212
internal/log/handler_log.go
Normal file
|
@ -0,0 +1,212 @@
|
|||
// Package log provides a set of http.Handler helpers for zerolog.
|
||||
package log // import "github.com/pomerium/pomerium/internal/log"
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/zenazn/goji/web/mutil"
|
||||
)
|
||||
|
||||
// FromRequest gets the logger in the request's context.
|
||||
// This is a shortcut for log.Ctx(r.Context())
|
||||
func FromRequest(r *http.Request) *zerolog.Logger {
|
||||
return Ctx(r.Context())
|
||||
}
|
||||
|
||||
// NewHandler injects log into requests context.
|
||||
func NewHandler(log zerolog.Logger) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Create a copy of the logger (including internal context slice)
|
||||
// to prevent data race when using UpdateContext.
|
||||
l := log.With().Logger()
|
||||
r = r.WithContext(l.WithContext(r.Context()))
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// URLHandler adds the requested URL as a field to the context's logger
|
||||
// using fieldKey as field key.
|
||||
func URLHandler(fieldKey string) func(next http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
log := zerolog.Ctx(r.Context())
|
||||
log.UpdateContext(func(c zerolog.Context) zerolog.Context {
|
||||
return c.Str(fieldKey, r.URL.String())
|
||||
})
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// MethodHandler adds the request method as a field to the context's logger
|
||||
// using fieldKey as field key.
|
||||
func MethodHandler(fieldKey string) func(next http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
log := zerolog.Ctx(r.Context())
|
||||
log.UpdateContext(func(c zerolog.Context) zerolog.Context {
|
||||
return c.Str(fieldKey, r.Method)
|
||||
})
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// RequestHandler adds the request method and URL as a field to the context's logger
|
||||
// using fieldKey as field key.
|
||||
func RequestHandler(fieldKey string) func(next http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
log := zerolog.Ctx(r.Context())
|
||||
log.UpdateContext(func(c zerolog.Context) zerolog.Context {
|
||||
return c.Str(fieldKey, r.Method+" "+r.URL.String())
|
||||
})
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// RemoteAddrHandler adds the request's remote address as a field to the context's logger
|
||||
// using fieldKey as field key.
|
||||
func RemoteAddrHandler(fieldKey string) func(next http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if host, _, err := net.SplitHostPort(r.RemoteAddr); err == nil {
|
||||
log := zerolog.Ctx(r.Context())
|
||||
log.UpdateContext(func(c zerolog.Context) zerolog.Context {
|
||||
return c.Str(fieldKey, host)
|
||||
})
|
||||
}
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// UserAgentHandler adds the request's user-agent as a field to the context's logger
|
||||
// using fieldKey as field key.
|
||||
func UserAgentHandler(fieldKey string) func(next http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if ua := r.Header.Get("User-Agent"); ua != "" {
|
||||
log := zerolog.Ctx(r.Context())
|
||||
log.UpdateContext(func(c zerolog.Context) zerolog.Context {
|
||||
return c.Str(fieldKey, ua)
|
||||
})
|
||||
}
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// RefererHandler adds the request's referer as a field to the context's logger
|
||||
// using fieldKey as field key.
|
||||
func RefererHandler(fieldKey string) func(next http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if ref := r.Header.Get("Referer"); ref != "" {
|
||||
log := zerolog.Ctx(r.Context())
|
||||
log.UpdateContext(func(c zerolog.Context) zerolog.Context {
|
||||
return c.Str(fieldKey, ref)
|
||||
})
|
||||
}
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type idKey struct{}
|
||||
|
||||
// IDFromRequest returns the unique id associated to the request if any.
|
||||
func IDFromRequest(r *http.Request) (id string, ok bool) {
|
||||
if r == nil {
|
||||
return
|
||||
}
|
||||
return IDFromCtx(r.Context())
|
||||
}
|
||||
|
||||
// IDFromCtx returns the unique id associated to the context if any.
|
||||
func IDFromCtx(ctx context.Context) (id string, ok bool) {
|
||||
id, ok = ctx.Value(idKey{}).(string)
|
||||
return
|
||||
}
|
||||
|
||||
// RequestIDHandler returns a handler setting a unique id to the request which can
|
||||
// be gathered using IDFromRequest(req). This generated id is added as a field to the
|
||||
// logger using the passed fieldKey as field name. The id is also added as a response
|
||||
// header if the headerName is not empty.
|
||||
func RequestIDHandler(fieldKey, headerName string) func(next http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
id, ok := IDFromRequest(r)
|
||||
if !ok {
|
||||
id = uuid()
|
||||
ctx = context.WithValue(ctx, idKey{}, id)
|
||||
r = r.WithContext(ctx)
|
||||
}
|
||||
if fieldKey != "" {
|
||||
log := zerolog.Ctx(ctx)
|
||||
log.UpdateContext(func(c zerolog.Context) zerolog.Context {
|
||||
return c.Str(fieldKey, id)
|
||||
})
|
||||
}
|
||||
if headerName != "" {
|
||||
w.Header().Set(headerName, id)
|
||||
}
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// AccessHandler returns a handler that call f after each request.
|
||||
func AccessHandler(f func(r *http.Request, status, size int, duration time.Duration)) func(next http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
start := time.Now()
|
||||
lw := mutil.WrapWriter(w)
|
||||
next.ServeHTTP(lw, r)
|
||||
f(r, lw.Status(), lw.BytesWritten(), time.Since(start))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ForwardedAddrHandler returns the client IP address from a request. If present, the
|
||||
// X-Forwarded-For header is assumed to be set by a load balancer, and its
|
||||
// rightmost entry (the client IP that connected to the LB) is returned.
|
||||
func ForwardedAddrHandler(fieldKey string) func(next http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
addr := r.RemoteAddr
|
||||
if ra := r.Header.Get("X-Forwarded-For"); ra != "" {
|
||||
forwardedList := strings.Split(ra, ",")
|
||||
forwardedAddr := strings.TrimSpace(forwardedList[len(forwardedList)-1])
|
||||
if forwardedAddr != "" {
|
||||
addr = forwardedAddr
|
||||
}
|
||||
log := zerolog.Ctx(r.Context())
|
||||
log.UpdateContext(func(c zerolog.Context) zerolog.Context {
|
||||
return c.Str(fieldKey, addr)
|
||||
})
|
||||
}
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// uuid generates a random 128-bit non-RFC UUID.
|
||||
func uuid() string {
|
||||
buf := make([]byte, 16)
|
||||
if _, err := rand.Read(buf); err != nil {
|
||||
return ""
|
||||
}
|
||||
return fmt.Sprintf("%x-%x-%x-%x-%x", buf[0:4], buf[4:6], buf[6:8], buf[8:10], buf[10:])
|
||||
}
|
260
internal/log/handler_log_test.go
Normal file
260
internal/log/handler_log_test.go
Normal file
|
@ -0,0 +1,260 @@
|
|||
// Package log provides a set of http.Handler helpers for zerolog.
|
||||
package log // import "github.com/pomerium/pomerium/internal/log"
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"reflect"
|
||||
"regexp"
|
||||
"testing"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
)
|
||||
|
||||
func TestGenerateUUID(t *testing.T) {
|
||||
prev := uuid()
|
||||
for i := 0; i < 100; i++ {
|
||||
id := uuid()
|
||||
if id == "" {
|
||||
t.Fatal("random pool failure")
|
||||
}
|
||||
if prev == id {
|
||||
t.Fatalf("Should get a new ID!")
|
||||
}
|
||||
matched, err := regexp.MatchString("[a-fA-F0-9]{8}-[a-fA-F0-9]{4}-[a-fA-F0-9]{4}-[a-fA-F0-9]{4}-[a-fA-F0-9]{12}", id)
|
||||
if !matched || err != nil {
|
||||
t.Fatalf("expected match %s %v %s", id, matched, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func decodeIfBinary(out *bytes.Buffer) string {
|
||||
// p := out.Bytes()
|
||||
// if len(p) == 0 || p[0] < 0x7F {
|
||||
// return out.String()
|
||||
// }
|
||||
return out.String() //cbor.DecodeObjectToStr(p) + "\n"
|
||||
}
|
||||
|
||||
func TestNewHandler(t *testing.T) {
|
||||
log := zerolog.New(nil).With().
|
||||
Str("foo", "bar").
|
||||
Logger()
|
||||
lh := NewHandler(log)
|
||||
h := lh(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
l := FromRequest(r)
|
||||
if !reflect.DeepEqual(*l, log) {
|
||||
t.Fail()
|
||||
}
|
||||
}))
|
||||
h.ServeHTTP(nil, &http.Request{})
|
||||
}
|
||||
|
||||
func TestURLHandler(t *testing.T) {
|
||||
out := &bytes.Buffer{}
|
||||
r := &http.Request{
|
||||
URL: &url.URL{Path: "/path", RawQuery: "foo=bar"},
|
||||
}
|
||||
h := URLHandler("url")(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
l := FromRequest(r)
|
||||
l.Log().Msg("")
|
||||
}))
|
||||
h = NewHandler(zerolog.New(out))(h)
|
||||
h.ServeHTTP(nil, r)
|
||||
if want, got := `{"url":"/path?foo=bar"}`+"\n", decodeIfBinary(out); want != got {
|
||||
t.Errorf("Invalid log output, got: %s, want: %s", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMethodHandler(t *testing.T) {
|
||||
out := &bytes.Buffer{}
|
||||
r := &http.Request{
|
||||
Method: "POST",
|
||||
}
|
||||
h := MethodHandler("method")(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
l := FromRequest(r)
|
||||
l.Log().Msg("")
|
||||
}))
|
||||
h = NewHandler(zerolog.New(out))(h)
|
||||
h.ServeHTTP(nil, r)
|
||||
if want, got := `{"method":"POST"}`+"\n", decodeIfBinary(out); want != got {
|
||||
t.Errorf("Invalid log output, got: %s, want: %s", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequestHandler(t *testing.T) {
|
||||
out := &bytes.Buffer{}
|
||||
r := &http.Request{
|
||||
Method: "POST",
|
||||
URL: &url.URL{Path: "/path", RawQuery: "foo=bar"},
|
||||
}
|
||||
h := RequestHandler("request")(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
l := FromRequest(r)
|
||||
l.Log().Msg("")
|
||||
}))
|
||||
h = NewHandler(zerolog.New(out))(h)
|
||||
h.ServeHTTP(nil, r)
|
||||
if want, got := `{"request":"POST /path?foo=bar"}`+"\n", decodeIfBinary(out); want != got {
|
||||
t.Errorf("Invalid log output, got: %s, want: %s", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRemoteAddrHandler(t *testing.T) {
|
||||
out := &bytes.Buffer{}
|
||||
r := &http.Request{
|
||||
RemoteAddr: "1.2.3.4:1234",
|
||||
}
|
||||
h := RemoteAddrHandler("ip")(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
l := FromRequest(r)
|
||||
l.Log().Msg("")
|
||||
}))
|
||||
h = NewHandler(zerolog.New(out))(h)
|
||||
h.ServeHTTP(nil, r)
|
||||
if want, got := `{"ip":"1.2.3.4"}`+"\n", decodeIfBinary(out); want != got {
|
||||
t.Errorf("Invalid log output, got: %s, want: %s", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRemoteAddrHandlerIPv6(t *testing.T) {
|
||||
out := &bytes.Buffer{}
|
||||
r := &http.Request{
|
||||
RemoteAddr: "[2001:db8:a0b:12f0::1]:1234",
|
||||
}
|
||||
h := RemoteAddrHandler("ip")(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
l := FromRequest(r)
|
||||
l.Log().Msg("")
|
||||
}))
|
||||
h = NewHandler(zerolog.New(out))(h)
|
||||
h.ServeHTTP(nil, r)
|
||||
if want, got := `{"ip":"2001:db8:a0b:12f0::1"}`+"\n", decodeIfBinary(out); want != got {
|
||||
t.Errorf("Invalid log output, got: %s, want: %s", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUserAgentHandler(t *testing.T) {
|
||||
out := &bytes.Buffer{}
|
||||
r := &http.Request{
|
||||
Header: http.Header{
|
||||
"User-Agent": []string{"some user agent string"},
|
||||
},
|
||||
}
|
||||
h := UserAgentHandler("ua")(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
l := FromRequest(r)
|
||||
l.Log().Msg("")
|
||||
}))
|
||||
h = NewHandler(zerolog.New(out))(h)
|
||||
h.ServeHTTP(nil, r)
|
||||
if want, got := `{"ua":"some user agent string"}`+"\n", decodeIfBinary(out); want != got {
|
||||
t.Errorf("Invalid log output, got: %s, want: %s", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRefererHandler(t *testing.T) {
|
||||
out := &bytes.Buffer{}
|
||||
r := &http.Request{
|
||||
Header: http.Header{
|
||||
"Referer": []string{"http://foo.com/bar"},
|
||||
},
|
||||
}
|
||||
h := RefererHandler("referer")(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
l := FromRequest(r)
|
||||
l.Log().Msg("")
|
||||
}))
|
||||
h = NewHandler(zerolog.New(out))(h)
|
||||
h.ServeHTTP(nil, r)
|
||||
if want, got := `{"referer":"http://foo.com/bar"}`+"\n", decodeIfBinary(out); want != got {
|
||||
t.Errorf("Invalid log output, got: %s, want: %s", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequestIDHandler(t *testing.T) {
|
||||
out := &bytes.Buffer{}
|
||||
r := &http.Request{
|
||||
Header: http.Header{
|
||||
"Referer": []string{"http://foo.com/bar"},
|
||||
},
|
||||
}
|
||||
h := RequestIDHandler("id", "Request-Id")(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
id, ok := IDFromRequest(r)
|
||||
if !ok {
|
||||
t.Fatal("Missing id in request")
|
||||
}
|
||||
// if want, got := id.String(), w.Header().Get("Request-Id"); got != want {
|
||||
// t.Errorf("Invalid Request-Id header, got: %s, want: %s", got, want)
|
||||
// }
|
||||
l := FromRequest(r)
|
||||
l.Log().Msg("")
|
||||
if want, got := fmt.Sprintf(`{"id":"%s"}`+"\n", id), decodeIfBinary(out); want != got {
|
||||
t.Errorf("Invalid log output, got: %s, want: %s", got, want)
|
||||
}
|
||||
}))
|
||||
h = NewHandler(zerolog.New(out))(h)
|
||||
h.ServeHTTP(httptest.NewRecorder(), r)
|
||||
}
|
||||
|
||||
func TestCombinedHandlers(t *testing.T) {
|
||||
out := &bytes.Buffer{}
|
||||
r := &http.Request{
|
||||
Method: "POST",
|
||||
URL: &url.URL{Path: "/path", RawQuery: "foo=bar"},
|
||||
}
|
||||
h := MethodHandler("method")(RequestHandler("request")(URLHandler("url")(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
l := FromRequest(r)
|
||||
l.Log().Msg("")
|
||||
}))))
|
||||
h = NewHandler(zerolog.New(out))(h)
|
||||
h.ServeHTTP(nil, r)
|
||||
if want, got := `{"method":"POST","request":"POST /path?foo=bar","url":"/path?foo=bar"}`+"\n", decodeIfBinary(out); want != got {
|
||||
t.Errorf("Invalid log output, got: %s, want: %s", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkHandlers(b *testing.B) {
|
||||
r := &http.Request{
|
||||
Method: "POST",
|
||||
URL: &url.URL{Path: "/path", RawQuery: "foo=bar"},
|
||||
}
|
||||
h1 := URLHandler("url")(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
l := FromRequest(r)
|
||||
l.Log().Msg("")
|
||||
}))
|
||||
h2 := MethodHandler("method")(RequestHandler("request")(h1))
|
||||
handlers := map[string]http.Handler{
|
||||
"Single": NewHandler(zerolog.New(ioutil.Discard))(h1),
|
||||
"Combined": NewHandler(zerolog.New(ioutil.Discard))(h2),
|
||||
"SingleDisabled": NewHandler(zerolog.New(ioutil.Discard).Level(zerolog.Disabled))(h1),
|
||||
"CombinedDisabled": NewHandler(zerolog.New(ioutil.Discard).Level(zerolog.Disabled))(h2),
|
||||
}
|
||||
for name := range handlers {
|
||||
h := handlers[name]
|
||||
b.Run(name, func(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
h.ServeHTTP(nil, r)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkDataRace(b *testing.B) {
|
||||
log := zerolog.New(nil).With().
|
||||
Str("foo", "bar").
|
||||
Logger()
|
||||
lh := NewHandler(log)
|
||||
h := lh(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
l := FromRequest(r)
|
||||
l.UpdateContext(func(c zerolog.Context) zerolog.Context {
|
||||
return c.Str("bar", "baz")
|
||||
})
|
||||
l.Log().Msg("")
|
||||
}))
|
||||
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
for pb.Next() {
|
||||
h.ServeHTTP(nil, &http.Request{})
|
||||
}
|
||||
})
|
||||
}
|
|
@ -2,7 +2,7 @@
|
|||
package log // import "github.com/pomerium/pomerium/internal/log"
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"context"
|
||||
"os"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
|
@ -21,19 +21,6 @@ func With() zerolog.Context {
|
|||
return Logger.With()
|
||||
}
|
||||
|
||||
// WithRequest creates a child logger with the remote user added to its context.
|
||||
func WithRequest(req *http.Request, function string) zerolog.Logger {
|
||||
remoteUser := getRemoteAddr(req)
|
||||
return Logger.With().
|
||||
Str("function", function).
|
||||
Str("req-remote-user", remoteUser).
|
||||
Str("req-http-method", req.Method).
|
||||
Str("req-host", req.Host).
|
||||
Str("req-url", req.URL.String()).
|
||||
// Str("req-user-agent", req.Header.Get("User-Agent")).
|
||||
Logger()
|
||||
}
|
||||
|
||||
// Level creates a child logger with the minimum accepted level set to level.
|
||||
func Level(level zerolog.Level) zerolog.Logger {
|
||||
return Logger.Level(level)
|
||||
|
@ -109,3 +96,9 @@ func Print(v ...interface{}) {
|
|||
func Printf(format string, v ...interface{}) {
|
||||
Logger.Printf(format, v...)
|
||||
}
|
||||
|
||||
// Ctx returns the Logger associated with the ctx. If no logger
|
||||
// is associated, a disabled logger is returned.
|
||||
func Ctx(ctx context.Context) *zerolog.Logger {
|
||||
return zerolog.Ctx(ctx)
|
||||
}
|
||||
|
|
|
@ -1,145 +0,0 @@
|
|||
package log // import "github.com/pomerium/pomerium/internal/log"
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Used to stash the authenticated user in the response for access when logging requests.
|
||||
const loggingUserHeader = "SSO-Authenticated-User"
|
||||
const gapMetaDataHeader = "GAP-Auth"
|
||||
|
||||
// responseLogger is wrapper of http.ResponseWriter that keeps track of its HTTP status
|
||||
// code and body size
|
||||
type responseLogger struct {
|
||||
w http.ResponseWriter
|
||||
status int
|
||||
size int
|
||||
proxyHost string
|
||||
authInfo string
|
||||
}
|
||||
|
||||
func (l *responseLogger) Header() http.Header {
|
||||
return l.w.Header()
|
||||
}
|
||||
|
||||
func (l *responseLogger) extractUser() {
|
||||
authInfo := l.w.Header().Get(loggingUserHeader)
|
||||
if authInfo != "" {
|
||||
l.authInfo = authInfo
|
||||
l.w.Header().Del(loggingUserHeader)
|
||||
}
|
||||
}
|
||||
|
||||
func (l *responseLogger) ExtractGAPMetadata() {
|
||||
authInfo := l.w.Header().Get(gapMetaDataHeader)
|
||||
if authInfo != "" {
|
||||
l.authInfo = authInfo
|
||||
|
||||
l.w.Header().Del(gapMetaDataHeader)
|
||||
}
|
||||
}
|
||||
|
||||
func (l *responseLogger) Write(b []byte) (int, error) {
|
||||
if l.status == 0 {
|
||||
// The status will be StatusOK if WriteHeader has not been called yet
|
||||
l.status = http.StatusOK
|
||||
}
|
||||
l.extractUser()
|
||||
l.ExtractGAPMetadata()
|
||||
|
||||
size, err := l.w.Write(b)
|
||||
l.size += size
|
||||
return size, err
|
||||
}
|
||||
|
||||
func (l *responseLogger) WriteHeader(s int) {
|
||||
l.extractUser()
|
||||
l.ExtractGAPMetadata()
|
||||
|
||||
l.w.WriteHeader(s)
|
||||
l.status = s
|
||||
}
|
||||
|
||||
func (l *responseLogger) Status() int {
|
||||
return l.status
|
||||
}
|
||||
|
||||
func (l *responseLogger) Size() int {
|
||||
return l.size
|
||||
}
|
||||
|
||||
func (l *responseLogger) Flush() {
|
||||
f := l.w.(http.Flusher)
|
||||
f.Flush()
|
||||
}
|
||||
|
||||
// loggingHandler is the http.Handler implementation for LoggingHandlerTo and its friends
|
||||
type loggingHandler struct {
|
||||
handler http.Handler
|
||||
}
|
||||
|
||||
// NewLoggingHandler returns a new loggingHandler that wraps a handler, and writer.
|
||||
func NewLoggingHandler(h http.Handler) http.Handler {
|
||||
return loggingHandler{
|
||||
handler: h,
|
||||
}
|
||||
}
|
||||
|
||||
func (h loggingHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
|
||||
t := time.Now()
|
||||
url := *req.URL
|
||||
logger := &responseLogger{w: w, proxyHost: getProxyHost(req)}
|
||||
h.handler.ServeHTTP(logger, req)
|
||||
requestDuration := time.Since(t)
|
||||
|
||||
logRequest(logger.proxyHost, logger.authInfo, req, url, requestDuration, logger.Status())
|
||||
}
|
||||
|
||||
// logRequest logs information about a request
|
||||
func logRequest(proxyHost, username string, req *http.Request, url url.URL, requestDuration time.Duration, status int) {
|
||||
uri := req.Host + url.RequestURI()
|
||||
Info().
|
||||
Int("http-status", status).
|
||||
Str("request-method", req.Method).
|
||||
Str("request-uri", uri).
|
||||
Str("proxy-host", proxyHost).
|
||||
// Str("user-agent", req.Header.Get("User-Agent")).
|
||||
Str("remote-address", getRemoteAddr(req)).
|
||||
Dur("duration", requestDuration).
|
||||
Str("user", username).
|
||||
Msg("request")
|
||||
|
||||
}
|
||||
|
||||
// getRemoteAddr returns the client IP address from a request. If present, the
|
||||
// X-Forwarded-For header is assumed to be set by a load balancer, and its
|
||||
// rightmost entry (the client IP that connected to the LB) is returned.
|
||||
func getRemoteAddr(req *http.Request) string {
|
||||
addr := req.RemoteAddr
|
||||
forwardedHeader := req.Header.Get("X-Forwarded-For")
|
||||
if forwardedHeader != "" {
|
||||
forwardedList := strings.Split(forwardedHeader, ",")
|
||||
forwardedAddr := strings.TrimSpace(forwardedList[len(forwardedList)-1])
|
||||
if forwardedAddr != "" {
|
||||
addr = forwardedAddr
|
||||
}
|
||||
}
|
||||
return addr
|
||||
}
|
||||
|
||||
// getProxyHost attempts to get the proxy host from the redirect_uri parameter
|
||||
func getProxyHost(req *http.Request) string {
|
||||
err := req.ParseForm()
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
redirect := req.Form.Get("redirect_uri")
|
||||
redirectURL, err := url.Parse(redirect)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
return redirectURL.Host
|
||||
}
|
|
@ -1,72 +0,0 @@
|
|||
package log // import "github.com/pomerium/pomerium/internal/log"
|
||||
|
||||
import (
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestGetRemoteAddr(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
remoteAddr string
|
||||
forwardedHeader string
|
||||
expectedAddr string
|
||||
}{
|
||||
{
|
||||
name: "RemoteAddr used when no X-Forwarded-For header is given",
|
||||
remoteAddr: "1.1.1.1",
|
||||
expectedAddr: "1.1.1.1",
|
||||
},
|
||||
{
|
||||
name: "RemoteAddr used when no X-Forwarded-For header is only whitespace",
|
||||
remoteAddr: "1.1.1.1",
|
||||
forwardedHeader: " ",
|
||||
expectedAddr: "1.1.1.1",
|
||||
},
|
||||
{
|
||||
name: "RemoteAddr used when no X-Forwarded-For header is only comma-separated whitespace",
|
||||
remoteAddr: "1.1.1.1",
|
||||
forwardedHeader: " , , ",
|
||||
expectedAddr: "1.1.1.1",
|
||||
},
|
||||
{
|
||||
name: "X-Forwarded-For header is preferred to RemoteAddr",
|
||||
remoteAddr: "1.1.1.1",
|
||||
forwardedHeader: "9.9.9.9",
|
||||
expectedAddr: "9.9.9.9",
|
||||
},
|
||||
{
|
||||
name: "rightmost entry in X-Forwarded-For header is used",
|
||||
remoteAddr: "1.1.1.1",
|
||||
forwardedHeader: "2.2.2.2, 3.3.3.3, 4.4.4.4.4, 5.5.5.5",
|
||||
expectedAddr: "5.5.5.5",
|
||||
},
|
||||
{
|
||||
name: "RemoteAddr is used if rightmost entry in X-Forwarded-For header is empty",
|
||||
remoteAddr: "1.1.1.1",
|
||||
forwardedHeader: "2.2.2.2, 3.3.3.3, ",
|
||||
expectedAddr: "1.1.1.1",
|
||||
},
|
||||
{
|
||||
name: "X-Forwaded-For header entries are stripped",
|
||||
remoteAddr: "1.1.1.1",
|
||||
forwardedHeader: " 2.2.2.2, 3.3.3.3, 4.4.4.4, 5.5.5.5 ",
|
||||
expectedAddr: "5.5.5.5",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
req.RemoteAddr = tc.remoteAddr
|
||||
if tc.forwardedHeader != "" {
|
||||
req.Header.Set("X-Forwarded-For", tc.forwardedHeader)
|
||||
}
|
||||
|
||||
addr := getRemoteAddr(req)
|
||||
if addr != tc.expectedAddr {
|
||||
t.Errorf("expected remote addr = %q, got %q", tc.expectedAddr, addr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
109
internal/middleware/chain.go
Normal file
109
internal/middleware/chain.go
Normal file
|
@ -0,0 +1,109 @@
|
|||
package middleware // import "github.com/pomerium/pomerium/internal/middleware"
|
||||
|
||||
import "net/http"
|
||||
|
||||
// Constructor is a type alias for func(http.Handler) http.Handler
|
||||
type Constructor func(http.Handler) http.Handler
|
||||
|
||||
// Chain acts as a list of http.Handler constructors.
|
||||
// Chain is effectively immutable:
|
||||
// once created, it will always hold
|
||||
// the same set of constructors in the same order.
|
||||
type Chain struct {
|
||||
constructors []Constructor
|
||||
}
|
||||
|
||||
// NewChain creates a new chain,
|
||||
// memorizing the given list of middleware constructors.
|
||||
// New serves no other function,
|
||||
// constructors are only called upon a call to Then().
|
||||
func NewChain(constructors ...Constructor) Chain {
|
||||
return Chain{append(([]Constructor)(nil), constructors...)}
|
||||
}
|
||||
|
||||
// Then chains the middleware and returns the final http.Handler.
|
||||
// NewChain(m1, m2, m3).Then(h)
|
||||
// is equivalent to:
|
||||
// m1(m2(m3(h)))
|
||||
// When the request comes in, it will be passed to m1, then m2, then m3
|
||||
// and finally, the given handler
|
||||
// (assuming every middleware calls the following one).
|
||||
//
|
||||
// A chain can be safely reused by calling Then() several times.
|
||||
// stdStack := middleware.NewChain(ratelimitHandler, csrfHandler)
|
||||
// indexPipe = stdStack.Then(indexHandler)
|
||||
// authPipe = stdStack.Then(authHandler)
|
||||
// Note that constructors are called on every call to Then()
|
||||
// and thus several instances of the same middleware will be created
|
||||
// when a chain is reused in this way.
|
||||
// For proper middleware, this should cause no problems.
|
||||
//
|
||||
// Then() treats nil as http.DefaultServeMux.
|
||||
func (c Chain) Then(h http.Handler) http.Handler {
|
||||
if h == nil {
|
||||
h = http.DefaultServeMux
|
||||
}
|
||||
|
||||
for i := range c.constructors {
|
||||
h = c.constructors[len(c.constructors)-1-i](h)
|
||||
}
|
||||
|
||||
return h
|
||||
}
|
||||
|
||||
// ThenFunc works identically to Then, but takes
|
||||
// a HandlerFunc instead of a Handler.
|
||||
//
|
||||
// The following two statements are equivalent:
|
||||
// c.Then(http.HandlerFunc(fn))
|
||||
// c.ThenFunc(fn)
|
||||
//
|
||||
// ThenFunc provides all the guarantees of Then.
|
||||
func (c Chain) ThenFunc(fn http.HandlerFunc) http.Handler {
|
||||
if fn == nil {
|
||||
return c.Then(nil)
|
||||
}
|
||||
return c.Then(fn)
|
||||
}
|
||||
|
||||
// Append extends a chain, adding the specified constructors
|
||||
// as the last ones in the request flow.
|
||||
//
|
||||
// Append returns a new chain, leaving the original one untouched.
|
||||
//
|
||||
// stdChain := middleware.NewChain(m1, m2)
|
||||
// extChain := stdChain.Append(m3, m4)
|
||||
// // requests in stdChain go m1 -> m2
|
||||
// // requests in extChain go m1 -> m2 -> m3 -> m4
|
||||
func (c Chain) Append(constructors ...Constructor) Chain {
|
||||
newCons := make([]Constructor, 0, len(c.constructors)+len(constructors))
|
||||
newCons = append(newCons, c.constructors...)
|
||||
newCons = append(newCons, constructors...)
|
||||
|
||||
return Chain{newCons}
|
||||
}
|
||||
|
||||
// Extend extends a chain by adding the specified chain
|
||||
// as the last one in the request flow.
|
||||
//
|
||||
// Extend returns a new chain, leaving the original one untouched.
|
||||
//
|
||||
// stdChain := middleware.NewChain(m1, m2)
|
||||
// ext1Chain := middleware.NewChain(m3, m4)
|
||||
// ext2Chain := stdChain.Extend(ext1Chain)
|
||||
// // requests in stdChain go m1 -> m2
|
||||
// // requests in ext1Chain go m3 -> m4
|
||||
// // requests in ext2Chain go m1 -> m2 -> m3 -> m4
|
||||
//
|
||||
// Another example:
|
||||
// aHtmlAfterNosurf := middleware.NewChain(m2)
|
||||
// aHtml := middleware.NewChain(m1, func(h http.Handler) http.Handler {
|
||||
// csrf := nosurf.NewChain(h)
|
||||
// csrf.SetFailureHandler(aHtmlAfterNosurf.ThenFunc(csrfFail))
|
||||
// return csrf
|
||||
// }).Extend(aHtmlAfterNosurf)
|
||||
// // requests to aHtml hitting nosurfs success handler go m1 -> nosurf -> m2 -> target-handler
|
||||
// // requests to aHtml hitting nosurfs failure handler go m1 -> nosurf -> m2 -> csrfFail
|
||||
func (c Chain) Extend(chain Chain) Chain {
|
||||
return c.Append(chain.constructors...)
|
||||
}
|
177
internal/middleware/chain_test.go
Normal file
177
internal/middleware/chain_test.go
Normal file
|
@ -0,0 +1,177 @@
|
|||
package middleware // import "github.com/pomerium/pomerium/internal/middleware"
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// A constructor for middleware
|
||||
// that writes its own "tag" into the RW and does nothing else.
|
||||
// Useful in checking if a chain is behaving in the right order.
|
||||
func tagMiddleware(tag string) Constructor {
|
||||
return func(h http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Write([]byte(tag))
|
||||
h.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Not recommended (https://golang.org/pkg/reflect/#Value.Pointer),
|
||||
// but the best we can do.
|
||||
func funcsEqual(f1, f2 interface{}) bool {
|
||||
val1 := reflect.ValueOf(f1)
|
||||
val2 := reflect.ValueOf(f2)
|
||||
return val1.Pointer() == val2.Pointer()
|
||||
}
|
||||
|
||||
var testApp = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Write([]byte("app\n"))
|
||||
})
|
||||
|
||||
func TestNew(t *testing.T) {
|
||||
c1 := func(h http.Handler) http.Handler {
|
||||
return nil
|
||||
}
|
||||
|
||||
c2 := func(h http.Handler) http.Handler {
|
||||
return http.StripPrefix("potato", nil)
|
||||
}
|
||||
|
||||
slice := []Constructor{c1, c2}
|
||||
|
||||
chain := NewChain(slice...)
|
||||
for k := range slice {
|
||||
if !funcsEqual(chain.constructors[k], slice[k]) {
|
||||
t.Error("New does not add constructors correctly")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestThenWorksWithNoMiddleware(t *testing.T) {
|
||||
if !funcsEqual(NewChain().Then(testApp), testApp) {
|
||||
t.Error("Then does not work with no middleware")
|
||||
}
|
||||
}
|
||||
|
||||
func TestThenTreatsNilAsDefaultServeMux(t *testing.T) {
|
||||
if NewChain().Then(nil) != http.DefaultServeMux {
|
||||
t.Error("Then does not treat nil as DefaultServeMux")
|
||||
}
|
||||
}
|
||||
|
||||
func TestThenFuncTreatsNilAsDefaultServeMux(t *testing.T) {
|
||||
if NewChain().ThenFunc(nil) != http.DefaultServeMux {
|
||||
t.Error("ThenFunc does not treat nil as DefaultServeMux")
|
||||
}
|
||||
}
|
||||
|
||||
func TestThenFuncConstructsHandlerFunc(t *testing.T) {
|
||||
fn := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(200)
|
||||
})
|
||||
chained := NewChain().ThenFunc(fn)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
chained.ServeHTTP(rec, (*http.Request)(nil))
|
||||
|
||||
if reflect.TypeOf(chained) != reflect.TypeOf((http.HandlerFunc)(nil)) {
|
||||
t.Error("ThenFunc does not construct HandlerFunc")
|
||||
}
|
||||
}
|
||||
|
||||
func TestThenOrdersHandlersCorrectly(t *testing.T) {
|
||||
t1 := tagMiddleware("t1\n")
|
||||
t2 := tagMiddleware("t2\n")
|
||||
t3 := tagMiddleware("t3\n")
|
||||
|
||||
chained := NewChain(t1, t2, t3).Then(testApp)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
r, err := http.NewRequest("GET", "/", nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
chained.ServeHTTP(w, r)
|
||||
|
||||
if w.Body.String() != "t1\nt2\nt3\napp\n" {
|
||||
t.Error("Then does not order handlers correctly")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAppendAddsHandlersCorrectly(t *testing.T) {
|
||||
chain := NewChain(tagMiddleware("t1\n"), tagMiddleware("t2\n"))
|
||||
newChain := chain.Append(tagMiddleware("t3\n"), tagMiddleware("t4\n"))
|
||||
|
||||
if len(chain.constructors) != 2 {
|
||||
t.Error("chain should have 2 constructors")
|
||||
}
|
||||
if len(newChain.constructors) != 4 {
|
||||
t.Error("newChain should have 4 constructors")
|
||||
}
|
||||
|
||||
chained := newChain.Then(testApp)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
r, err := http.NewRequest("GET", "/", nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
chained.ServeHTTP(w, r)
|
||||
|
||||
if w.Body.String() != "t1\nt2\nt3\nt4\napp\n" {
|
||||
t.Error("Append does not add handlers correctly")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAppendRespectsImmutability(t *testing.T) {
|
||||
chain := NewChain(tagMiddleware(""))
|
||||
newChain := chain.Append(tagMiddleware(""))
|
||||
|
||||
if &chain.constructors[0] == &newChain.constructors[0] {
|
||||
t.Error("Apppend does not respect immutability")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtendAddsHandlersCorrectly(t *testing.T) {
|
||||
chain1 := NewChain(tagMiddleware("t1\n"), tagMiddleware("t2\n"))
|
||||
chain2 := NewChain(tagMiddleware("t3\n"), tagMiddleware("t4\n"))
|
||||
newChain := chain1.Extend(chain2)
|
||||
|
||||
if len(chain1.constructors) != 2 {
|
||||
t.Error("chain1 should contain 2 constructors")
|
||||
}
|
||||
if len(chain2.constructors) != 2 {
|
||||
t.Error("chain2 should contain 2 constructors")
|
||||
}
|
||||
if len(newChain.constructors) != 4 {
|
||||
t.Error("newChain should contain 4 constructors")
|
||||
}
|
||||
|
||||
chained := newChain.Then(testApp)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
r, err := http.NewRequest("GET", "/", nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
chained.ServeHTTP(w, r)
|
||||
|
||||
if w.Body.String() != "t1\nt2\nt3\nt4\napp\n" {
|
||||
t.Error("Extend does not add handlers in correctly")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtendRespectsImmutability(t *testing.T) {
|
||||
chain := NewChain(tagMiddleware(""))
|
||||
newChain := chain.Extend(NewChain(tagMiddleware("")))
|
||||
|
||||
if &chain.constructors[0] == &newChain.constructors[0] {
|
||||
t.Error("Extend does not respect immutability")
|
||||
}
|
||||
}
|
|
@ -1,3 +1,4 @@
|
|||
// Package middleware provides a standard set of middleware implementations for pomerium.
|
||||
package middleware // import "github.com/pomerium/pomerium/internal/middleware"
|
||||
|
||||
import (
|
||||
|
@ -14,8 +15,8 @@ import (
|
|||
"github.com/pomerium/pomerium/internal/httputil"
|
||||
)
|
||||
|
||||
// SetHeaders ensures that every response includes some basic security headers
|
||||
func SetHeaders(h http.Handler, securityHeaders map[string]string) http.Handler {
|
||||
// SetHeadersOld ensures that every response includes some basic security headers
|
||||
func SetHeadersOld(h http.Handler, securityHeaders map[string]string) http.Handler {
|
||||
return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
for key, val := range securityHeaders {
|
||||
rw.Header().Set(key, val)
|
||||
|
@ -24,6 +25,18 @@ func SetHeaders(h http.Handler, securityHeaders map[string]string) http.Handler
|
|||
})
|
||||
}
|
||||
|
||||
// SetHeaders ensures that every response includes some basic security headers
|
||||
func SetHeaders(securityHeaders map[string]string) func(next http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
for key, val := range securityHeaders {
|
||||
rw.Header().Set(key, val)
|
||||
}
|
||||
next.ServeHTTP(rw, req)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// WithMethods writes an error response if the method of the request is not included.
|
||||
func WithMethods(f http.HandlerFunc, methods ...string) http.HandlerFunc {
|
||||
methodMap := make(map[string]struct{}, len(methods))
|
||||
|
@ -116,14 +129,17 @@ func ValidateSignature(f http.HandlerFunc, sharedSecret string) http.HandlerFunc
|
|||
}
|
||||
|
||||
// ValidateHost ensures that each request's host is valid
|
||||
func ValidateHost(h http.Handler, mux map[string]*http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
if _, ok := mux[req.Host]; !ok {
|
||||
httputil.ErrorResponse(rw, req, "Unknown host to route", http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
h.ServeHTTP(rw, req)
|
||||
})
|
||||
func ValidateHost(mux map[string]*http.Handler) func(next http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
|
||||
return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
if _, ok := mux[req.Host]; !ok {
|
||||
httputil.ErrorResponse(rw, req, "Unknown host to route", http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
next.ServeHTTP(rw, req)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// RequireHTTPS reroutes a HTTP request to HTTPS
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue