diff --git a/internal/httputil/client.go b/internal/httputil/client.go index c9a396428..422375b1a 100644 --- a/internal/httputil/client.go +++ b/internal/httputil/client.go @@ -12,19 +12,29 @@ import ( "net/url" "time" - "go.opencensus.io/plugin/ochttp" - + "github.com/pomerium/pomerium/internal/telemetry/metrics" "github.com/pomerium/pomerium/internal/telemetry/requestid" + "github.com/pomerium/pomerium/internal/tripper" ) // ErrTokenRevoked signifies a token revokation or expiration error var ErrTokenRevoked = errors.New("token expired or revoked") +type httpClient struct { + *http.Client + requestIDTripper http.RoundTripper +} + +func (c *httpClient) Do(req *http.Request) (*http.Response, error) { + tripperChain := tripper.NewChain(metrics.HTTPMetricsRoundTripper("idp_http_client", req.Host)) + c.Client.Transport = tripperChain.Then(c.requestIDTripper) + return c.Client.Do(req) +} + // DefaultClient avoids leaks by setting an upper limit for timeouts. -var DefaultClient = &http.Client{ - Timeout: 1 * time.Minute, - //todo(bdd): incorporate metrics.HTTPMetricsRoundTripper - Transport: requestid.NewRoundTripper(&ochttp.Transport{}), +var DefaultClient = &httpClient{ + &http.Client{Timeout: 1 * time.Minute}, + requestid.NewRoundTripper(http.DefaultTransport), } // Client provides a simple helper interface to make HTTP requests diff --git a/internal/httputil/client_test.go b/internal/httputil/client_test.go new file mode 100644 index 000000000..b5fd6211b --- /dev/null +++ b/internal/httputil/client_test.go @@ -0,0 +1,25 @@ +package httputil + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + + "github.com/pomerium/pomerium/internal/telemetry/requestid" +) + +func TestDefaultClient(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + for _, header := range []string{"X-B3-Sampled", "X-B3-Spanid", "X-B3-Traceid", "X-Request-Id"} { + if _, ok := r.Header[header]; !ok { + t.Errorf("header %s is not set", header) + } + } + w.WriteHeader(http.StatusOK) + })) + defer ts.Close() + req, _ := http.NewRequest(http.MethodGet, ts.URL, nil) + req = req.WithContext(requestid.WithValue(context.Background(), "foo")) + _, _ = DefaultClient.Do(req) +}