authorize: add request id to context (#3497)

* authorize: add request id to context

* fix context keys
This commit is contained in:
Caleb Doxsey 2022-07-26 14:34:48 -06:00 committed by GitHub
parent 06ee1c8711
commit 89a105c8e6
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 18 additions and 9 deletions

View file

@ -13,6 +13,7 @@ import (
"github.com/pomerium/pomerium/config" "github.com/pomerium/pomerium/config"
"github.com/pomerium/pomerium/internal/log" "github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/sessions" "github.com/pomerium/pomerium/internal/sessions"
"github.com/pomerium/pomerium/internal/telemetry/requestid"
"github.com/pomerium/pomerium/internal/telemetry/trace" "github.com/pomerium/pomerium/internal/telemetry/trace"
"github.com/pomerium/pomerium/internal/urlutil" "github.com/pomerium/pomerium/internal/urlutil"
"github.com/pomerium/pomerium/pkg/grpc/user" "github.com/pomerium/pomerium/pkg/grpc/user"
@ -39,6 +40,7 @@ func (a *Authorize) Check(ctx context.Context, in *envoy_service_auth_v3.CheckRe
// convert the incoming envoy-style http request into a go-style http request // convert the incoming envoy-style http request into a go-style http request
hreq := getHTTPRequestFromCheckRequest(in) hreq := getHTTPRequestFromCheckRequest(in)
ctx = requestid.WithValue(ctx, requestid.FromHTTPHeader(hreq.Header))
isForwardAuth := a.isForwardAuth(in) isForwardAuth := a.isForwardAuth(in)
if isForwardAuth { if isForwardAuth {

View file

@ -322,13 +322,20 @@ func TestLoadBalancer(t *testing.T) {
if !assert.NoError(t, err) { if !assert.NoError(t, err) {
return distribution return distribution
} }
defer res.Body.Close()
bs, err := io.ReadAll(res.Body)
if !assert.NoError(t, err) {
return distribution
}
var result struct { var result struct {
Hostname string `json:"hostname"` Hostname string `json:"hostname"`
} }
err = json.NewDecoder(res.Body).Decode(&result) err = json.Unmarshal(bs, &result)
_ = res.Body.Close() if !assert.NoError(t, err, "invalid json: %s", bs) {
assert.NoError(t, err) return distribution
}
distribution[result.Hostname]++ distribution[result.Hostname]++
} }

View file

@ -11,16 +11,16 @@ import (
const headerName = "x-request-id" const headerName = "x-request-id"
var contextKey struct{} type contextKey struct{}
// WithValue returns a new context from the parent context with a request id value set. // WithValue returns a new context from the parent context with a request id value set.
func WithValue(parent context.Context, requestID string) context.Context { func WithValue(parent context.Context, requestID string) context.Context {
return context.WithValue(parent, contextKey, requestID) return context.WithValue(parent, contextKey{}, requestID)
} }
// FromContext gets the request id from a context. // FromContext gets the request id from a context.
func FromContext(ctx context.Context) string { func FromContext(ctx context.Context) string {
if id, ok := ctx.Value(contextKey).(string); ok { if id, ok := ctx.Value(contextKey{}).(string); ok {
return id return id
} }
return "" return ""

View file

@ -33,11 +33,11 @@ func (nilQuerier) Query(ctx context.Context, in *databroker.QueryRequest, opts .
return nil, status.Error(codes.NotFound, "not found") return nil, status.Error(codes.NotFound, "not found")
} }
var querierKey struct{} type querierKey struct{}
// GetQuerier gets the databroker Querier from the context. // GetQuerier gets the databroker Querier from the context.
func GetQuerier(ctx context.Context) Querier { func GetQuerier(ctx context.Context) Querier {
q, ok := ctx.Value(querierKey).(Querier) q, ok := ctx.Value(querierKey{}).(Querier)
if !ok { if !ok {
q = nilQuerier{} q = nilQuerier{}
} }
@ -46,7 +46,7 @@ func GetQuerier(ctx context.Context) Querier {
// WithQuerier sets the databroker Querier on a context. // WithQuerier sets the databroker Querier on a context.
func WithQuerier(ctx context.Context, querier Querier) context.Context { func WithQuerier(ctx context.Context, querier Querier) context.Context {
return context.WithValue(ctx, querierKey, querier) return context.WithValue(ctx, querierKey{}, querier)
} }
type staticQuerier struct { type staticQuerier struct {