diff --git a/authorize/evaluator/headers_evaluator_evaluation.go b/authorize/evaluator/headers_evaluator_evaluation.go index 6d5621b00..8dce405d9 100644 --- a/authorize/evaluator/headers_evaluator_evaluation.go +++ b/authorize/evaluator/headers_evaluator_evaluation.go @@ -7,8 +7,8 @@ import ( "encoding/json" "fmt" "net/http" - "os" "reflect" + "slices" "strings" "time" @@ -20,6 +20,8 @@ import ( "github.com/pomerium/datasource/pkg/directory" "github.com/pomerium/pomerium/config" + "github.com/pomerium/pomerium/internal/headertemplate" + "github.com/pomerium/pomerium/internal/httputil" "github.com/pomerium/pomerium/internal/log" "github.com/pomerium/pomerium/pkg/cryptutil" "github.com/pomerium/pomerium/pkg/grpc/session" @@ -149,20 +151,20 @@ func (e *headersEvaluatorEvaluation) fillSetRequestHeaders(ctx context.Context) } for k, v := range e.request.Policy.SetRequestHeaders { - e.response.Headers.Add(k, os.Expand(v, func(name string) string { - switch name { - case "$": - return "$" - case "pomerium.access_token": + e.response.Headers.Add(k, headertemplate.Render(v, func(ref []string) string { + switch { + case slices.Equal(ref, []string{"pomerium", "access_token"}): s, _ := e.getSessionOrServiceAccount(ctx) return s.GetOauthToken().GetAccessToken() - case "pomerium.client_cert_fingerprint": + case slices.Equal(ref, []string{"pomerium", "client_cert_fingerprint"}): return e.getClientCertFingerprint() - case "pomerium.id_token": + case slices.Equal(ref, []string{"pomerium", "id_token"}): s, _ := e.getSessionOrServiceAccount(ctx) return s.GetIdToken().GetRaw() - case "pomerium.jwt": + case slices.Equal(ref, []string{"pomerium", "jwt"}): return e.getSignedJWT(ctx) + case len(ref) > 3 && ref[0] == "pomerium" && ref[1] == "request" && ref[2] == "headers": + return e.request.HTTP.Headers[httputil.CanonicalHeaderKey(ref[3])] } return "" diff --git a/authorize/evaluator/headers_evaluator_test.go b/authorize/evaluator/headers_evaluator_test.go index 67145a24b..00b0d57c7 100644 --- a/authorize/evaluator/headers_evaluator_test.go +++ b/authorize/evaluator/headers_evaluator_test.go @@ -218,15 +218,19 @@ func TestHeadersEvaluator(t *testing.T) { HTTP: RequestHTTP{ Hostname: "from.example.com", ClientCertificate: ClientCertificateInfo{Leaf: testValidCert}, + Headers: map[string]string{ + "X-Incoming-Header": "INCOMING", + }, }, Policy: &config.Policy{ SetRequestHeaders: map[string]string{ - "X-Custom-Header": "CUSTOM_VALUE", - "X-ID-Token": "${pomerium.id_token}", - "X-Access-Token": "${pomerium.access_token}", - "Client-Cert-Fingerprint": "${pomerium.client_cert_fingerprint}", - "Authorization": "Bearer ${pomerium.jwt}", - "Foo": "escaped $$dollar sign", + "X-Custom-Header": "CUSTOM_VALUE", + "X-ID-Token": "${pomerium.id_token}", + "X-Access-Token": "${pomerium.access_token}", + "Client-Cert-Fingerprint": "${pomerium.client_cert_fingerprint}", + "Authorization": "Bearer ${pomerium.jwt}", + "Foo": "escaped $$dollar sign", + "X-Incoming-Custom-Header": `From-Incoming ${pomerium.request.headers["X-Incoming-Header"]}`, }, }, Session: RequestSession{ID: "s1"}, @@ -239,6 +243,7 @@ func TestHeadersEvaluator(t *testing.T) { assert.Equal(t, "3febe6467787e93f0a01030e0803072feaa710f724a9dc74de05cfba3d4a6d23", output.Headers.Get("Client-Cert-Fingerprint")) assert.Equal(t, "escaped $dollar sign", output.Headers.Get("Foo")) + assert.Equal(t, "From-Incoming INCOMING", output.Headers.Get("X-Incoming-Custom-Header")) authHeader := output.Headers.Get("Authorization") assert.True(t, strings.HasPrefix(authHeader, "Bearer ")) authHeader = strings.TrimPrefix(authHeader, "Bearer ") diff --git a/internal/headertemplate/headertemplate.go b/internal/headertemplate/headertemplate.go new file mode 100644 index 000000000..ab44f8ffc --- /dev/null +++ b/internal/headertemplate/headertemplate.go @@ -0,0 +1,255 @@ +// Package headertemplate contains functions for rendering header templates. +package headertemplate + +import "strings" + +// Render renders a header template string. +func Render(src string, fn func(ref []string) string) string { + p := newParser(src, fn) + return p.parse() +} + +// This is a hand written parser attempting to model this peg grammar: +// +// Grammar <- ( Variable / Text )* !. +// Text <- . +// Variable <- EscapedVariable / SimpleVariable / ComplexVariable +// EscapedVariable <- '$' '$' +// SimpleVariable <- '$' SimpleExpression +// SimpleExpression <- identifier ( '.' identifier )* +// ComplexVariable <- '$' '{' _ ComplexExpression _ '}' +// ComplexExpression <- identifier _ (ComplexSelector / ComplexIndex)* +// ComplexSelector <- '.' _ ComplexExpression _ +// ComplexIndex <- '[' _ StringLiteral _ ']' _ +// StringLiteral <- '"' (('\\'.) / [^"])* '"' +// identifier <- [a-zA-Z0-9_] [a-zA-Z0-9_\-]* +// _ <- ( ' ' / '\t' )* + +type parser struct { + buffer []byte + pos int + stack []int + visit func(ref []string) string +} + +func newParser(src string, visit func(ref []string) string) *parser { + return &parser{buffer: []byte(src), visit: visit} +} + +func (p *parser) save() { + p.stack = append(p.stack, p.pos) +} + +func (p *parser) restore() { + p.pos = p.stack[len(p.stack)-1] +} + +func (p *parser) pop() { + p.stack = p.stack[:len(p.stack)-1] +} + +func (p *parser) peek() byte { + if p.pos < len(p.buffer) { + return p.buffer[p.pos] + } + return 0 +} + +func (p *parser) next() byte { + if p.pos < len(p.buffer) { + c := p.buffer[p.pos] + p.pos++ + return c + } + return 0 +} + +func (p *parser) parse() string { + var b strings.Builder + for p.pos < len(p.buffer) { + if v, ok := p.parseVariable(); ok { + b.WriteString(v) + continue + } + b.WriteByte(p.next()) + } + return b.String() +} + +func (p *parser) parseVariable() (string, bool) { + if p.peek() != '$' { + return "", false + } + + p.save() + defer p.pop() + + // $$ becomes $ + p.next() + if p.peek() == '$' { + p.next() + return "$", true + } + + if p.peek() == '{' { + p.next() + e, ok := p.parseComplexExpression() + if !ok { + p.restore() + return "", false + } + if p.next() != '}' { + p.restore() + return "", false + } + return e, true + } + + e, ok := p.parseSimpleExpression() + if !ok { + p.restore() + return "", false + } + + return e, true +} + +func (p *parser) parseComplexExpression() (string, bool) { + p.save() + defer p.pop() + + p.skipWhitespace() + + var ref []string + id, ok := p.parseIdentifier() + if !ok { + p.restore() + return "", false + } + ref = append(ref, id) + + for { + p.skipWhitespace() + + if p.peek() == '.' { + p.next() + p.skipWhitespace() + + id, ok := p.parseIdentifier() + if !ok { + p.restore() + return "", false + } + ref = append(ref, id) + + } else if p.peek() == '[' { + p.next() + p.skipWhitespace() + + s, ok := p.parseString() + if !ok { + p.restore() + return "", false + } + ref = append(ref, s) + + p.skipWhitespace() + if p.next() != ']' { + p.restore() + return "", false + } + } else { + break + } + } + + return p.visit(ref), true +} + +func (p *parser) parseString() (string, bool) { + p.save() + defer p.pop() + + if p.next() != '"' { + p.restore() + return "", false + } + + var b strings.Builder + for { + c := p.next() + switch c { + case '"': + return b.String(), true + case 0: + p.restore() + return "", false + case '\\': + c = p.next() + if c == 0 { + p.restore() + return "", false + } + b.WriteByte(c) + default: + b.WriteByte(c) + } + } +} + +func (p *parser) parseSimpleExpression() (string, bool) { + p.save() + defer p.pop() + + var ref []string + for { + id, ok := p.parseIdentifier() + if !ok { + p.restore() + return "", false + } + ref = append(ref, id) + + if p.peek() != '.' { + break + } + p.next() + } + + return p.visit(ref), true +} + +func (p *parser) parseIdentifier() (string, bool) { + p.save() + defer p.pop() + + var b strings.Builder + for isIdentifierCharacter(p.peek()) { + b.WriteByte(p.next()) + } + + if b.Len() == 0 { + p.restore() + return "", false + } + + return b.String(), true +} + +func (p *parser) skipWhitespace() { + for isWhitespaceCharacter(p.peek()) { + p.next() + } +} + +func isIdentifierCharacter(c byte) bool { + return (c >= '0' && c <= '9') || + (c >= 'a' && c <= 'z') || + (c >= 'A' && c <= 'Z') || + c == '_' || + c == '-' +} + +func isWhitespaceCharacter(c byte) bool { + return c == ' ' || c == '\t' +} diff --git a/internal/headertemplate/headertemplate_test.go b/internal/headertemplate/headertemplate_test.go new file mode 100644 index 000000000..8896c7707 --- /dev/null +++ b/internal/headertemplate/headertemplate_test.go @@ -0,0 +1,61 @@ +package headertemplate_test + +import ( + "strings" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/pomerium/pomerium/internal/headertemplate" +) + +func TestRender(t *testing.T) { + t.Parallel() + + for _, tc := range []struct { + in string + expect string + }{ + {"x $$ y $$ z", "x $ y $ z"}, + {`${x.y.z}`, ``}, + {`${ x . y . z }`, ``}, + {`${x["y"].z}`, ``}, + {`${x["`, `${x["`}, + {`${`, `${`}, + {`${}`, `${}`}, + {`${x["\\"]}`, ``}, + {`${x["\""]}`, ``}, + + {`${pomerium.access_token}`, ``}, + {`$pomerium.access_token`, ``}, + {`${pomerium.client_cert_fingerprint}`, ``}, + {`$pomerium.client_cert_fingerprint`, ``}, + {`${pomerium.id_token}`, ``}, + {`$pomerium.id_token`, ``}, + {`${pomerium.jwt}`, ``}, + {`$pomerium.jwt`, ``}, + {`${pomerium.request.headers["X-Access-Token"]}`, ``}, + {`$pomerium.request.headers.X-Access-Token`, ``}, + } { + actual := headertemplate.Render(tc.in, func(ref []string) string { + return "<" + strings.Join(ref, ",") + ">" + }) + assert.Equal(t, tc.expect, actual) + } + + assert.Equal(t, "x $ y $ z", headertemplate.Render("x $$ y $$ z", func(_ []string) string { + return "" + })) + assert.Equal(t, "before JWT after", headertemplate.Render("before $pomerium.jwt after", func(ref []string) string { + assert.Equal(t, []string{"pomerium", "jwt"}, ref) + return "JWT" + })) + assert.Equal(t, "before JWT after", headertemplate.Render("before ${ pomerium . jwt } after", func(ref []string) string { + assert.Equal(t, []string{"pomerium", "jwt"}, ref) + return "JWT" + })) + assert.Equal(t, "before JWT after", headertemplate.Render("before ${ pomerium . jwt } after", func(ref []string) string { + assert.Equal(t, []string{"pomerium", "jwt"}, ref) + return "JWT" + })) +}