mirror of
https://github.com/pomerium/pomerium.git
synced 2025-05-29 17:07:24 +02:00
internal/httputil: add HTTPStatsRoundTripper to DefaultClient (#828)
This commit is contained in:
parent
7abe3a3b02
commit
9e711b4612
2 changed files with 41 additions and 6 deletions
|
@ -12,19 +12,29 @@ import (
|
|||
"net/url"
|
||||
"time"
|
||||
|
||||
"go.opencensus.io/plugin/ochttp"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/telemetry/metrics"
|
||||
"github.com/pomerium/pomerium/internal/telemetry/requestid"
|
||||
"github.com/pomerium/pomerium/internal/tripper"
|
||||
)
|
||||
|
||||
// ErrTokenRevoked signifies a token revokation or expiration error
|
||||
var ErrTokenRevoked = errors.New("token expired or revoked")
|
||||
|
||||
type httpClient struct {
|
||||
*http.Client
|
||||
requestIDTripper http.RoundTripper
|
||||
}
|
||||
|
||||
func (c *httpClient) Do(req *http.Request) (*http.Response, error) {
|
||||
tripperChain := tripper.NewChain(metrics.HTTPMetricsRoundTripper("idp_http_client", req.Host))
|
||||
c.Client.Transport = tripperChain.Then(c.requestIDTripper)
|
||||
return c.Client.Do(req)
|
||||
}
|
||||
|
||||
// DefaultClient avoids leaks by setting an upper limit for timeouts.
|
||||
var DefaultClient = &http.Client{
|
||||
Timeout: 1 * time.Minute,
|
||||
//todo(bdd): incorporate metrics.HTTPMetricsRoundTripper
|
||||
Transport: requestid.NewRoundTripper(&ochttp.Transport{}),
|
||||
var DefaultClient = &httpClient{
|
||||
&http.Client{Timeout: 1 * time.Minute},
|
||||
requestid.NewRoundTripper(http.DefaultTransport),
|
||||
}
|
||||
|
||||
// Client provides a simple helper interface to make HTTP requests
|
||||
|
|
25
internal/httputil/client_test.go
Normal file
25
internal/httputil/client_test.go
Normal file
|
@ -0,0 +1,25 @@
|
|||
package httputil
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/telemetry/requestid"
|
||||
)
|
||||
|
||||
func TestDefaultClient(t *testing.T) {
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
for _, header := range []string{"X-B3-Sampled", "X-B3-Spanid", "X-B3-Traceid", "X-Request-Id"} {
|
||||
if _, ok := r.Header[header]; !ok {
|
||||
t.Errorf("header %s is not set", header)
|
||||
}
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer ts.Close()
|
||||
req, _ := http.NewRequest(http.MethodGet, ts.URL, nil)
|
||||
req = req.WithContext(requestid.WithValue(context.Background(), "foo"))
|
||||
_, _ = DefaultClient.Do(req)
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue