mirror of
https://github.com/pomerium/pomerium.git
synced 2025-08-02 08:19:23 +02:00
Merge pull request #187 from travisgroth/feature/metrics
Proxy Client Metrics
This commit is contained in:
commit
7191ed6fb1
7 changed files with 454 additions and 58 deletions
|
@ -181,7 +181,10 @@ Expose a prometheus format HTTP endpoint on the specified port. Disabled by def
|
|||
|:------------- |:-------------|:-----|
|
||||
|http_server_requests_total| Counter | Total HTTP server requests handled by service|
|
||||
|http_server_response_size_bytes| Histogram | HTTP server response size by service|
|
||||
|http_server_request_duration_ms| Histogram | HTTP server request duration by service\
|
||||
|http_server_request_duration_ms| Histogram | HTTP server request duration by service|
|
||||
|http_client_requests_total| Counter | Total HTTP client requests made by service|
|
||||
|http_client_response_size_bytes| Histogram | HTTP client response size by service|
|
||||
|http_client_request_duration_ms| Histogram | HTTP client request duration by service|
|
||||
|
||||
### Policy
|
||||
|
||||
|
|
|
@ -6,11 +6,11 @@ import (
|
|||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/middleware/responsewriter"
|
||||
"github.com/pomerium/pomerium/internal/tripper"
|
||||
"go.opencensus.io/stats"
|
||||
"go.opencensus.io/stats/view"
|
||||
"go.opencensus.io/tag"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/middleware/responsewriter"
|
||||
)
|
||||
|
||||
var (
|
||||
|
@ -19,22 +19,27 @@ var (
|
|||
keyService, _ = tag.NewKey("service")
|
||||
keyHost, _ = tag.NewKey("host")
|
||||
|
||||
httpRequestCount = stats.Int64("http_server_requests_total", "Total HTTP Requests", "1")
|
||||
httpResponseSize = stats.Int64("http_server_response_size_bytes", "HTTP Server Response Size in bytes", "bytes")
|
||||
httpRequestDuration = stats.Int64("http_server_request_duration_ms", "HTTP Request duration in ms", "ms")
|
||||
httpServerRequestCount = stats.Int64("http_server_requests_total", "Total HTTP Requests", "1")
|
||||
httpServerResponseSize = stats.Int64("http_server_response_size_bytes", "HTTP Server Response Size in bytes", "bytes")
|
||||
httpServerRequestDuration = stats.Int64("http_server_request_duration_ms", "HTTP Request duration in ms", "ms")
|
||||
|
||||
httpClientRequestCount = stats.Int64("http_client_requests_total", "Total HTTP Client Requests", "1")
|
||||
httpClientResponseSize = stats.Int64("http_client_response_size_bytes", "HTTP Client Response Size in bytes", "bytes")
|
||||
httpClientRequestDuration = stats.Int64("http_client_request_duration_ms", "HTTP Client Request duration in ms", "ms")
|
||||
|
||||
views = []*view.View{
|
||||
//HTTP Server
|
||||
{
|
||||
Name: httpRequestCount.Name(),
|
||||
Measure: httpRequestCount,
|
||||
Description: httpRequestCount.Description(),
|
||||
Name: httpServerRequestCount.Name(),
|
||||
Measure: httpServerRequestCount,
|
||||
Description: httpServerRequestCount.Description(),
|
||||
TagKeys: []tag.Key{keyService, keyHost, keyMethod, keyStatus},
|
||||
Aggregation: view.Count(),
|
||||
},
|
||||
{
|
||||
Name: httpRequestDuration.Name(),
|
||||
Measure: httpRequestDuration,
|
||||
Description: httpRequestDuration.Description(),
|
||||
Name: httpServerRequestDuration.Name(),
|
||||
Measure: httpServerRequestDuration,
|
||||
Description: httpServerRequestDuration.Description(),
|
||||
TagKeys: []tag.Key{keyService, keyHost, keyMethod, keyStatus},
|
||||
Aggregation: view.Distribution(
|
||||
1, 2, 5, 7, 10, 25, 500, 750,
|
||||
|
@ -45,9 +50,41 @@ var (
|
|||
),
|
||||
},
|
||||
{
|
||||
Name: httpResponseSize.Name(),
|
||||
Measure: httpResponseSize,
|
||||
Description: httpResponseSize.Description(),
|
||||
Name: httpServerResponseSize.Name(),
|
||||
Measure: httpServerResponseSize,
|
||||
Description: httpServerResponseSize.Description(),
|
||||
TagKeys: []tag.Key{keyService, keyHost, keyMethod, keyStatus},
|
||||
Aggregation: view.Distribution(
|
||||
1, 256, 512, 1024, 2048, 8192, 16384, 32768, 65536, 131072, 262144, 524288,
|
||||
1048576, 2097152, 4194304, 8388608,
|
||||
),
|
||||
},
|
||||
|
||||
//HTTP Client
|
||||
{
|
||||
Name: httpClientRequestCount.Name(),
|
||||
Measure: httpClientRequestCount,
|
||||
Description: httpClientRequestCount.Description(),
|
||||
TagKeys: []tag.Key{keyService, keyHost, keyMethod, keyStatus},
|
||||
Aggregation: view.Count(),
|
||||
},
|
||||
{
|
||||
Name: httpClientRequestDuration.Name(),
|
||||
Measure: httpClientRequestDuration,
|
||||
Description: httpClientRequestDuration.Description(),
|
||||
TagKeys: []tag.Key{keyService, keyHost, keyMethod, keyStatus},
|
||||
Aggregation: view.Distribution(
|
||||
1, 2, 5, 7, 10, 25, 500, 750,
|
||||
100, 250, 500, 750,
|
||||
1000, 2500, 5000, 7500,
|
||||
10000, 25000, 50000, 75000,
|
||||
100000,
|
||||
),
|
||||
},
|
||||
{
|
||||
Name: httpClientResponseSize.Name(),
|
||||
Measure: httpClientResponseSize,
|
||||
Description: httpClientResponseSize.Description(),
|
||||
TagKeys: []tag.Key{keyService, keyHost, keyMethod, keyStatus},
|
||||
Aggregation: view.Distribution(
|
||||
1, 256, 512, 1024, 2048, 8192, 16384, 32768, 65536, 131072, 262144, 524288,
|
||||
|
@ -71,18 +108,52 @@ func HTTPMetricsHandler(service string) func(next http.Handler) http.Handler {
|
|||
|
||||
next.ServeHTTP(m, r)
|
||||
|
||||
ctx, _ := tag.New(
|
||||
ctx, tagErr := tag.New(
|
||||
context.Background(),
|
||||
tag.Insert(keyService, service),
|
||||
tag.Insert(keyHost, r.Host),
|
||||
tag.Insert(keyMethod, r.Method),
|
||||
tag.Insert(keyStatus, strconv.Itoa(m.Status())),
|
||||
)
|
||||
stats.Record(ctx,
|
||||
httpRequestCount.M(1),
|
||||
httpRequestDuration.M(time.Since(startTime).Nanoseconds()/int64(time.Millisecond)),
|
||||
httpResponseSize.M(int64(m.BytesWritten())),
|
||||
)
|
||||
|
||||
if tagErr == nil {
|
||||
stats.Record(ctx,
|
||||
httpServerRequestCount.M(1),
|
||||
httpServerRequestDuration.M(time.Since(startTime).Nanoseconds()/int64(time.Millisecond)),
|
||||
httpServerResponseSize.M(int64(m.BytesWritten())),
|
||||
)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// HTTPMetricsRoundTripper creates a metrics tracking tripper for outbound HTTP Requests
|
||||
func HTTPMetricsRoundTripper(service string) func(next http.RoundTripper) http.RoundTripper {
|
||||
|
||||
return func(next http.RoundTripper) http.RoundTripper {
|
||||
return tripper.RoundTripperFunc(func(r *http.Request) (*http.Response, error) {
|
||||
startTime := time.Now()
|
||||
|
||||
resp, err := next.RoundTrip(r)
|
||||
|
||||
if resp != nil && err == nil {
|
||||
ctx, tagErr := tag.New(
|
||||
context.Background(),
|
||||
tag.Insert(keyService, service),
|
||||
tag.Insert(keyHost, r.Host),
|
||||
tag.Insert(keyMethod, r.Method),
|
||||
tag.Insert(keyStatus, strconv.Itoa(resp.StatusCode)),
|
||||
)
|
||||
|
||||
if tagErr == nil {
|
||||
stats.Record(ctx,
|
||||
httpClientRequestCount.M(1),
|
||||
httpClientRequestDuration.M(time.Since(startTime).Nanoseconds()/int64(time.Millisecond)),
|
||||
httpClientResponseSize.M(resp.ContentLength),
|
||||
)
|
||||
}
|
||||
}
|
||||
return resp, err
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
@ -2,6 +2,7 @@ package metrics // import "github.com/pomerium/pomerium/internal/metrics"
|
|||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
|
@ -9,6 +10,7 @@ import (
|
|||
"testing"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/middleware"
|
||||
"github.com/pomerium/pomerium/internal/tripper"
|
||||
"go.opencensus.io/stats/view"
|
||||
)
|
||||
|
||||
|
@ -34,36 +36,36 @@ func Test_HTTPMetricsHandler(t *testing.T) {
|
|||
chainHandler := chain.Then(newTestMux())
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
url string
|
||||
verb string
|
||||
wanthttpResponseSize string
|
||||
wanthttpRequestDuration string
|
||||
wanthttpRequestCount string
|
||||
name string
|
||||
url string
|
||||
verb string
|
||||
wanthttpServerResponseSize string
|
||||
wanthttpServerRequestDuration string
|
||||
wanthttpServerRequestCount string
|
||||
}{
|
||||
{
|
||||
name: "good get",
|
||||
url: "http://test.local/good",
|
||||
verb: "GET",
|
||||
wanthttpResponseSize: "{ { {host test.local}{method GET}{service test_service}{status 200} }&{1 5 5 5 0 [0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]",
|
||||
wanthttpRequestDuration: "{ { {host test.local}{method GET}{service test_service}{status 200} }&{1",
|
||||
wanthttpRequestCount: "{ { {host test.local}{method GET}{service test_service}{status 200} }&{1",
|
||||
name: "good get",
|
||||
url: "http://test.local/good",
|
||||
verb: "GET",
|
||||
wanthttpServerResponseSize: "{ { {host test.local}{method GET}{service test_service}{status 200} }&{1 5 5 5 0 [0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]",
|
||||
wanthttpServerRequestDuration: "{ { {host test.local}{method GET}{service test_service}{status 200} }&{1",
|
||||
wanthttpServerRequestCount: "{ { {host test.local}{method GET}{service test_service}{status 200} }&{1",
|
||||
},
|
||||
{
|
||||
name: "good post",
|
||||
url: "http://test.local/good",
|
||||
verb: "POST",
|
||||
wanthttpResponseSize: "{ { {host test.local}{method POST}{service test_service}{status 200} }&{1 5 5 5 0 [0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]",
|
||||
wanthttpRequestDuration: "{ { {host test.local}{method POST}{service test_service}{status 200} }&{1",
|
||||
wanthttpRequestCount: "{ { {host test.local}{method POST}{service test_service}{status 200} }&{1",
|
||||
name: "good post",
|
||||
url: "http://test.local/good",
|
||||
verb: "POST",
|
||||
wanthttpServerResponseSize: "{ { {host test.local}{method POST}{service test_service}{status 200} }&{1 5 5 5 0 [0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]",
|
||||
wanthttpServerRequestDuration: "{ { {host test.local}{method POST}{service test_service}{status 200} }&{1",
|
||||
wanthttpServerRequestCount: "{ { {host test.local}{method POST}{service test_service}{status 200} }&{1",
|
||||
},
|
||||
{
|
||||
name: "bad post",
|
||||
url: "http://test.local/bad",
|
||||
verb: "POST",
|
||||
wanthttpResponseSize: "{ { {host test.local}{method POST}{service test_service}{status 404} }&{1 19 19 19 0 [0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]",
|
||||
wanthttpRequestDuration: "{ { {host test.local}{method POST}{service test_service}{status 404} }&{1",
|
||||
wanthttpRequestCount: "{ { {host test.local}{method POST}{service test_service}{status 404} }&{1",
|
||||
name: "bad post",
|
||||
url: "http://test.local/bad",
|
||||
verb: "POST",
|
||||
wanthttpServerResponseSize: "{ { {host test.local}{method POST}{service test_service}{status 404} }&{1 19 19 19 0 [0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]",
|
||||
wanthttpServerRequestDuration: "{ { {host test.local}{method POST}{service test_service}{status 404} }&{1",
|
||||
wanthttpServerRequestCount: "{ { {host test.local}{method POST}{service test_service}{status 404} }&{1",
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
|
@ -76,35 +78,145 @@ func Test_HTTPMetricsHandler(t *testing.T) {
|
|||
chainHandler.ServeHTTP(rec, req)
|
||||
|
||||
// httpResponseSize
|
||||
data, _ := view.RetrieveData(httpResponseSize.Name())
|
||||
data, _ := view.RetrieveData(httpServerResponseSize.Name())
|
||||
if len(data) != 1 {
|
||||
t.Errorf("httpResponseSize: received wrong number of data rows: %d", len(data))
|
||||
t.Errorf("httpServerResponseSize: received wrong number of data rows: %d", len(data))
|
||||
return
|
||||
}
|
||||
|
||||
if !strings.HasPrefix(data[0].String(), tt.wanthttpResponseSize) {
|
||||
t.Errorf("httpResponseSize: Found unexpected data row: \nwant: %s\ngot: %s\n", tt.wanthttpResponseSize, data[0].String())
|
||||
if !strings.HasPrefix(data[0].String(), tt.wanthttpServerResponseSize) {
|
||||
t.Errorf("httpServerResponseSize: Found unexpected data row: \nwant: %s\ngot: %s\n", tt.wanthttpServerResponseSize, data[0].String())
|
||||
}
|
||||
|
||||
// httpResponseSize
|
||||
data, _ = view.RetrieveData(httpRequestDuration.Name())
|
||||
// httpRequestDuration
|
||||
data, _ = view.RetrieveData(httpServerRequestDuration.Name())
|
||||
if len(data) != 1 {
|
||||
t.Errorf("httpRequestDuration: received too many data rows: %d", len(data))
|
||||
t.Errorf("httpServerRequestDuration: received too many data rows: %d", len(data))
|
||||
}
|
||||
|
||||
if !strings.HasPrefix(data[0].String(), tt.wanthttpRequestDuration) {
|
||||
t.Errorf("httpRequestDuration: Found unexpected data row: \nwant: %s\ngot: %s\n", tt.wanthttpRequestDuration, data[0].String())
|
||||
if !strings.HasPrefix(data[0].String(), tt.wanthttpServerRequestDuration) {
|
||||
t.Errorf("httpServerRequestDuration: Found unexpected data row: \nwant: %s\ngot: %s\n", tt.wanthttpServerRequestDuration, data[0].String())
|
||||
}
|
||||
|
||||
// httpRequestCount
|
||||
data, _ = view.RetrieveData(httpRequestCount.Name())
|
||||
data, _ = view.RetrieveData(httpServerRequestCount.Name())
|
||||
if len(data) != 1 {
|
||||
t.Errorf("httpRequestCount: received too many data rows: %d", len(data))
|
||||
t.Errorf("httpServerRequestCount: received too many data rows: %d", len(data))
|
||||
}
|
||||
|
||||
if !strings.HasPrefix(data[0].String(), tt.wanthttpRequestCount) {
|
||||
t.Errorf("httpRequestCount: Found unexpected data row: \nwant: %s\ngot: %s\n", tt.wanthttpRequestCount, data[0].String())
|
||||
if !strings.HasPrefix(data[0].String(), tt.wanthttpServerRequestCount) {
|
||||
t.Errorf("httpServerRequestCount: Found unexpected data row: \nwant: %s\ngot: %s\n", tt.wanthttpServerRequestCount, data[0].String())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func newTestTransport() http.RoundTripper {
|
||||
return tripper.RoundTripperFunc(func(r *http.Request) (*http.Response, error) {
|
||||
resp := httptest.NewRecorder()
|
||||
newTestMux().ServeHTTP(resp, r)
|
||||
resp.Flush()
|
||||
result := resp.Result()
|
||||
|
||||
// This really looks like a regression / bug?
|
||||
// https://github.com/golang/go/issues/16952
|
||||
result.ContentLength = int64(len(resp.Body.Bytes()))
|
||||
return result, nil
|
||||
})
|
||||
}
|
||||
|
||||
func newFailingTestTransport() http.RoundTripper {
|
||||
return tripper.RoundTripperFunc(func(r *http.Request) (*http.Response, error) {
|
||||
return nil, errors.New("failure")
|
||||
})
|
||||
}
|
||||
|
||||
func Test_HTTPMetricsRoundTripper(t *testing.T) {
|
||||
chain := tripper.NewChain(HTTPMetricsRoundTripper("test_service"))
|
||||
rt := chain.Then(newTestTransport())
|
||||
client := http.Client{Transport: rt}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
url string
|
||||
verb string
|
||||
wanthttpClientResponseSize string
|
||||
wanthttpClientRequestDuration string
|
||||
wanthttpClientRequestCount string
|
||||
}{
|
||||
{
|
||||
name: "good get",
|
||||
url: "http://test.local/good",
|
||||
verb: "GET",
|
||||
wanthttpClientResponseSize: "{ { {host test.local}{method GET}{service test_service}{status 200} }&{1 5 5 5 0 [0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]",
|
||||
wanthttpClientRequestDuration: "{ { {host test.local}{method GET}{service test_service}{status 200} }&{1",
|
||||
wanthttpClientRequestCount: "{ { {host test.local}{method GET}{service test_service}{status 200} }&{1",
|
||||
},
|
||||
{
|
||||
name: "good post",
|
||||
url: "http://test.local/good",
|
||||
verb: "POST",
|
||||
wanthttpClientResponseSize: "{ { {host test.local}{method POST}{service test_service}{status 200} }&{1 5 5 5 0 [0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]",
|
||||
wanthttpClientRequestDuration: "{ { {host test.local}{method POST}{service test_service}{status 200} }&{1",
|
||||
wanthttpClientRequestCount: "{ { {host test.local}{method POST}{service test_service}{status 200} }&{1",
|
||||
},
|
||||
{
|
||||
name: "bad post",
|
||||
url: "http://test.local/bad",
|
||||
verb: "POST",
|
||||
wanthttpClientResponseSize: "{ { {host test.local}{method POST}{service test_service}{status 404} }&{1 19 19 19 0 [0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]",
|
||||
wanthttpClientRequestDuration: "{ { {host test.local}{method POST}{service test_service}{status 404} }&{1",
|
||||
wanthttpClientRequestCount: "{ { {host test.local}{method POST}{service test_service}{status 404} }&{1",
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
view.Unregister(views...)
|
||||
view.Register(views...)
|
||||
|
||||
req, _ := http.NewRequest(tt.verb, tt.url, new(bytes.Buffer))
|
||||
resp, err := client.Do(req)
|
||||
|
||||
t.Logf("response: %#v, %#v", resp, err)
|
||||
|
||||
// httpClientResponseSize
|
||||
data, _ := view.RetrieveData(httpClientResponseSize.Name())
|
||||
if len(data) != 1 {
|
||||
t.Errorf("httpClientResponseSize: received wrong number of data rows: %d", len(data))
|
||||
return
|
||||
}
|
||||
|
||||
if !strings.HasPrefix(data[0].String(), tt.wanthttpClientResponseSize) {
|
||||
t.Errorf("httpResponseSize: Found unexpected data row: \nwant: %s\ngot: %s\n", tt.wanthttpClientResponseSize, data[0].String())
|
||||
}
|
||||
|
||||
// httpClientRequestDuration
|
||||
data, _ = view.RetrieveData(httpClientRequestDuration.Name())
|
||||
if len(data) != 1 {
|
||||
t.Errorf("httpClientRequestDuration: received too many data rows: %d", len(data))
|
||||
}
|
||||
|
||||
if !strings.HasPrefix(data[0].String(), tt.wanthttpClientRequestDuration) {
|
||||
t.Errorf("httpClientRequestDuration: Found unexpected data row: \nwant: %s\ngot: %s\n", tt.wanthttpClientRequestDuration, data[0].String())
|
||||
}
|
||||
|
||||
// httpClientRequestCount
|
||||
data, _ = view.RetrieveData(httpClientRequestCount.Name())
|
||||
if len(data) != 1 {
|
||||
t.Errorf("httpRequestCount: received too many data rows: %d", len(data))
|
||||
}
|
||||
|
||||
if !strings.HasPrefix(data[0].String(), tt.wanthttpClientRequestCount) {
|
||||
t.Errorf("httpRequestCount: Found unexpected data row: \nwant: %s\ngot: %s\n", tt.wanthttpClientRequestCount, data[0].String())
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// Check for transport Errors
|
||||
client = http.Client{Transport: chain.Then(newFailingTestTransport())}
|
||||
req, _ := http.NewRequest("GET", "http://test.local", new(bytes.Buffer))
|
||||
resp, err := client.Do(req)
|
||||
if err == nil || resp != nil {
|
||||
t.Error("Transport error not surfaced properly")
|
||||
}
|
||||
}
|
||||
|
|
69
internal/tripper/chain.go
Normal file
69
internal/tripper/chain.go
Normal file
|
@ -0,0 +1,69 @@
|
|||
package tripper // import "github.com/pomerium/pomerium/internal/tripper"
|
||||
|
||||
import "net/http"
|
||||
|
||||
// Constructor is a type alias for func(http.RoundTripper) http.RoundTripper
|
||||
type Constructor func(http.RoundTripper) http.RoundTripper
|
||||
|
||||
// Chain acts as a list of http.RoundTripper constructors.
|
||||
// Chain is effectively immutable:
|
||||
// once created, it will always hold
|
||||
// the same set of constructors in the same order.
|
||||
type Chain struct {
|
||||
constructors []Constructor
|
||||
}
|
||||
|
||||
// NewChain creates a new chain,
|
||||
// memorizing the given list of tripper constructors.
|
||||
// New serves no other function,
|
||||
// constructors are only called upon a call to Then().
|
||||
func NewChain(constructors ...Constructor) Chain {
|
||||
return Chain{append(([]Constructor)(nil), constructors...)}
|
||||
}
|
||||
|
||||
// Then chains the trippers and returns the final http.RoundTripper.
|
||||
// NewChain(m1, m2, m3).Then(h)
|
||||
// is equivalent to:
|
||||
// m1(m2(m3(h)))
|
||||
// When the request comes in, it will be passed to m1, then m2, then m3
|
||||
// and finally, the given roundtripper
|
||||
// (assuming every tripper calls the following one).
|
||||
//
|
||||
// A chain can be safely reused by calling Then() several times.
|
||||
// stdStack := tripper.NewChain(ratelimitTripper, csrfTripper)
|
||||
// tracePipe = stdStack.Then(traceTripper)
|
||||
// authPipe = stdStack.Then(authTripper)
|
||||
// Note that constructors are called on every call to Then()
|
||||
// and thus several instances of the same tripper will be created
|
||||
// when a chain is reused in this way.
|
||||
// For proper tripper implementations, this should cause no problems.
|
||||
//
|
||||
// Then() treats nil as http.DefaultTransport.
|
||||
func (c Chain) Then(h http.RoundTripper) http.RoundTripper {
|
||||
if h == nil {
|
||||
h = http.DefaultTransport
|
||||
}
|
||||
|
||||
for i := range c.constructors {
|
||||
h = c.constructors[len(c.constructors)-1-i](h)
|
||||
}
|
||||
|
||||
return h
|
||||
}
|
||||
|
||||
// Append extends a chain, adding the specified constructors
|
||||
// as the last ones in the request flow.
|
||||
//
|
||||
// Append returns a new chain, leaving the original one untouched.
|
||||
//
|
||||
// stdChain := middleware.NewChain(m1, m2)
|
||||
// extChain := stdChain.Append(m3, m4)
|
||||
// // requests in stdChain go m1 -> m2
|
||||
// // requests in extChain go m1 -> m2 -> m3 -> m4
|
||||
func (c Chain) Append(constructors ...Constructor) Chain {
|
||||
newCons := make([]Constructor, 0, len(c.constructors)+len(constructors))
|
||||
newCons = append(newCons, c.constructors...)
|
||||
newCons = append(newCons, constructors...)
|
||||
|
||||
return Chain{newCons}
|
||||
}
|
123
internal/tripper/chain_test.go
Normal file
123
internal/tripper/chain_test.go
Normal file
|
@ -0,0 +1,123 @@
|
|||
package tripper // import "github.com/pomerium/pomerium/internal/tripper"
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
)
|
||||
|
||||
type mockTransport struct {
|
||||
id string
|
||||
}
|
||||
|
||||
func (t *mockTransport) RoundTrip(r *http.Request) (*http.Response, error) {
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
w.WriteString(t.id)
|
||||
return w.Result(), nil
|
||||
}
|
||||
|
||||
// mockMiddleware appends the id into the response body as
|
||||
// the call stack unwinds.
|
||||
//
|
||||
// If your chain is c1->c2->t, it should return 't,c2,c1'
|
||||
func mockMiddleware(id string) func(next http.RoundTripper) http.RoundTripper {
|
||||
return func(next http.RoundTripper) http.RoundTripper {
|
||||
return RoundTripperFunc(func(r *http.Request) (*http.Response, error) {
|
||||
|
||||
resp, _ := next.RoundTrip(r)
|
||||
|
||||
body, _ := ioutil.ReadAll(resp.Body)
|
||||
mockResp := httptest.NewRecorder()
|
||||
mockResp.Write(body)
|
||||
mockResp.WriteString(fmt.Sprintf(",%s", id))
|
||||
return mockResp.Result(), nil
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNew(t *testing.T) {
|
||||
m1 := mockMiddleware("c1")
|
||||
m2 := mockMiddleware("c2")
|
||||
t1 := &mockTransport{id: "t"}
|
||||
want := "t,c2,c1"
|
||||
|
||||
chain := NewChain(m1, m2)
|
||||
|
||||
resp, _ := chain.Then(t1).
|
||||
RoundTrip(httptest.NewRequest("GET", "/", nil))
|
||||
|
||||
if len(chain.constructors) != 2 {
|
||||
t.Errorf("Wrong number of constructors in chain")
|
||||
}
|
||||
|
||||
b, _ := ioutil.ReadAll(resp.Body)
|
||||
if string(b) != want {
|
||||
t.Errorf("Wrong constructors. want=%s, got=%s", want, b)
|
||||
}
|
||||
}
|
||||
|
||||
func TestThenNoMiddleware(t *testing.T) {
|
||||
chain := NewChain()
|
||||
t1 := &mockTransport{id: "t"}
|
||||
want := "t"
|
||||
|
||||
resp, _ := chain.Then(t1).
|
||||
RoundTrip(httptest.NewRequest("GET", "/", nil))
|
||||
|
||||
b, _ := ioutil.ReadAll(resp.Body)
|
||||
if string(b) != want {
|
||||
t.Errorf("Wrong constructors. want=%s, got=%s", want, b)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNilThen(t *testing.T) {
|
||||
if NewChain().Then(nil) != http.DefaultTransport {
|
||||
t.Error("Then does not treat nil as DefaultTransport")
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestAppend(t *testing.T) {
|
||||
chain := NewChain(mockMiddleware("c1"))
|
||||
if len(chain.constructors) != 1 {
|
||||
t.Errorf("Wrong number of constructors in chain")
|
||||
}
|
||||
|
||||
chain = chain.Append(mockMiddleware("c2"))
|
||||
t1 := &mockTransport{id: "t"}
|
||||
want := "t,c2,c1"
|
||||
|
||||
resp, _ := chain.Then(t1).
|
||||
RoundTrip(httptest.NewRequest("GET", "/", nil))
|
||||
|
||||
if len(chain.constructors) != 2 {
|
||||
t.Errorf("Wrong number of constructors in chain")
|
||||
}
|
||||
|
||||
b, _ := ioutil.ReadAll(resp.Body)
|
||||
if string(b) != want {
|
||||
t.Errorf("Wrong constructors. want=%s, got=%s", want, b)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAppendImmutability(t *testing.T) {
|
||||
chain := NewChain(mockMiddleware("c1"))
|
||||
chain.Append(mockMiddleware("c2"))
|
||||
t1 := &mockTransport{id: "t"}
|
||||
want := "t,c1"
|
||||
|
||||
if len(chain.constructors) != 1 {
|
||||
t.Errorf("Append does not respect immutability")
|
||||
}
|
||||
|
||||
resp, _ := chain.Then(t1).
|
||||
RoundTrip(httptest.NewRequest("GET", "/", nil))
|
||||
|
||||
b, _ := ioutil.ReadAll(resp.Body)
|
||||
if string(b) != want {
|
||||
t.Errorf("Wrong constructors. want=%s, got=%s", want, b)
|
||||
}
|
||||
}
|
13
internal/tripper/roundtripper.go
Normal file
13
internal/tripper/roundtripper.go
Normal file
|
@ -0,0 +1,13 @@
|
|||
package tripper // import "github.com/pomerium/pomerium/internal/tripper"
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
)
|
||||
|
||||
// RoundTripperFunc wraps a function in a RoundTripper interface similar to HandlerFunc
|
||||
type RoundTripperFunc func(*http.Request) (*http.Response, error)
|
||||
|
||||
// RoundTrip calls the underlying tripper function in the RoundTripperFunc
|
||||
func (f RoundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
return f(req)
|
||||
}
|
|
@ -15,9 +15,11 @@ import (
|
|||
"github.com/pomerium/pomerium/internal/config"
|
||||
"github.com/pomerium/pomerium/internal/cryptutil"
|
||||
"github.com/pomerium/pomerium/internal/log"
|
||||
"github.com/pomerium/pomerium/internal/metrics"
|
||||
"github.com/pomerium/pomerium/internal/policy"
|
||||
"github.com/pomerium/pomerium/internal/sessions"
|
||||
"github.com/pomerium/pomerium/internal/templates"
|
||||
"github.com/pomerium/pomerium/internal/tripper"
|
||||
"github.com/pomerium/pomerium/proxy/clients"
|
||||
)
|
||||
|
||||
|
@ -251,6 +253,9 @@ func NewReverseProxy(to *url.URL) *httputil.ReverseProxy {
|
|||
director(req)
|
||||
req.Host = to.Host
|
||||
}
|
||||
|
||||
chain := tripper.NewChain().Append(metrics.HTTPMetricsRoundTripper("proxy"))
|
||||
proxy.Transport = chain.Then(nil)
|
||||
return proxy
|
||||
}
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue