diff --git a/authorize/grpc.go b/authorize/grpc.go index e8a8ddbb1..2c93fdbb5 100644 --- a/authorize/grpc.go +++ b/authorize/grpc.go @@ -13,6 +13,7 @@ import ( "github.com/pomerium/pomerium/config" "github.com/pomerium/pomerium/internal/log" "github.com/pomerium/pomerium/internal/sessions" + "github.com/pomerium/pomerium/internal/telemetry/requestid" "github.com/pomerium/pomerium/internal/telemetry/trace" "github.com/pomerium/pomerium/internal/urlutil" "github.com/pomerium/pomerium/pkg/grpc/user" @@ -39,6 +40,7 @@ func (a *Authorize) Check(ctx context.Context, in *envoy_service_auth_v3.CheckRe // convert the incoming envoy-style http request into a go-style http request hreq := getHTTPRequestFromCheckRequest(in) + ctx = requestid.WithValue(ctx, requestid.FromHTTPHeader(hreq.Header)) isForwardAuth := a.isForwardAuth(in) if isForwardAuth { diff --git a/integration/policy_test.go b/integration/policy_test.go index ce03f57ee..ab3116e60 100644 --- a/integration/policy_test.go +++ b/integration/policy_test.go @@ -322,13 +322,20 @@ func TestLoadBalancer(t *testing.T) { if !assert.NoError(t, err) { return distribution } + defer res.Body.Close() + + bs, err := io.ReadAll(res.Body) + if !assert.NoError(t, err) { + return distribution + } var result struct { Hostname string `json:"hostname"` } - err = json.NewDecoder(res.Body).Decode(&result) - _ = res.Body.Close() - assert.NoError(t, err) + err = json.Unmarshal(bs, &result) + if !assert.NoError(t, err, "invalid json: %s", bs) { + return distribution + } distribution[result.Hostname]++ } diff --git a/internal/telemetry/requestid/requestid.go b/internal/telemetry/requestid/requestid.go index 83e514eda..df780b24d 100644 --- a/internal/telemetry/requestid/requestid.go +++ b/internal/telemetry/requestid/requestid.go @@ -11,16 +11,16 @@ import ( const headerName = "x-request-id" -var contextKey struct{} +type contextKey struct{} // WithValue returns a new context from the parent context with a request id value set. func WithValue(parent context.Context, requestID string) context.Context { - return context.WithValue(parent, contextKey, requestID) + return context.WithValue(parent, contextKey{}, requestID) } // FromContext gets the request id from a context. func FromContext(ctx context.Context) string { - if id, ok := ctx.Value(contextKey).(string); ok { + if id, ok := ctx.Value(contextKey{}).(string); ok { return id } return "" diff --git a/pkg/storage/querier.go b/pkg/storage/querier.go index 1eac9a2f5..7a9736d16 100644 --- a/pkg/storage/querier.go +++ b/pkg/storage/querier.go @@ -33,11 +33,11 @@ func (nilQuerier) Query(ctx context.Context, in *databroker.QueryRequest, opts . return nil, status.Error(codes.NotFound, "not found") } -var querierKey struct{} +type querierKey struct{} // GetQuerier gets the databroker Querier from the context. func GetQuerier(ctx context.Context) Querier { - q, ok := ctx.Value(querierKey).(Querier) + q, ok := ctx.Value(querierKey{}).(Querier) if !ok { q = nilQuerier{} } @@ -46,7 +46,7 @@ func GetQuerier(ctx context.Context) Querier { // WithQuerier sets the databroker Querier on a context. func WithQuerier(ctx context.Context, querier Querier) context.Context { - return context.WithValue(ctx, querierKey, querier) + return context.WithValue(ctx, querierKey{}, querier) } type staticQuerier struct {