add support for pomerium.request.headers for set_request_headers (#5563)

* add support for pomerium.request.headers for set_request_headers

* add peg grammar
This commit is contained in:
Caleb Doxsey 2025-04-07 10:32:03 -06:00 committed by GitHub
parent 5f95dd32db
commit a1eb75a8fe
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 338 additions and 15 deletions

View file

@ -7,8 +7,8 @@ import (
"encoding/json"
"fmt"
"net/http"
"os"
"reflect"
"slices"
"strings"
"time"
@ -20,6 +20,8 @@ import (
"github.com/pomerium/datasource/pkg/directory"
"github.com/pomerium/pomerium/config"
"github.com/pomerium/pomerium/internal/headertemplate"
"github.com/pomerium/pomerium/internal/httputil"
"github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/pkg/cryptutil"
"github.com/pomerium/pomerium/pkg/grpc/session"
@ -149,20 +151,20 @@ func (e *headersEvaluatorEvaluation) fillSetRequestHeaders(ctx context.Context)
}
for k, v := range e.request.Policy.SetRequestHeaders {
e.response.Headers.Add(k, os.Expand(v, func(name string) string {
switch name {
case "$":
return "$"
case "pomerium.access_token":
e.response.Headers.Add(k, headertemplate.Render(v, func(ref []string) string {
switch {
case slices.Equal(ref, []string{"pomerium", "access_token"}):
s, _ := e.getSessionOrServiceAccount(ctx)
return s.GetOauthToken().GetAccessToken()
case "pomerium.client_cert_fingerprint":
case slices.Equal(ref, []string{"pomerium", "client_cert_fingerprint"}):
return e.getClientCertFingerprint()
case "pomerium.id_token":
case slices.Equal(ref, []string{"pomerium", "id_token"}):
s, _ := e.getSessionOrServiceAccount(ctx)
return s.GetIdToken().GetRaw()
case "pomerium.jwt":
case slices.Equal(ref, []string{"pomerium", "jwt"}):
return e.getSignedJWT(ctx)
case len(ref) > 3 && ref[0] == "pomerium" && ref[1] == "request" && ref[2] == "headers":
return e.request.HTTP.Headers[httputil.CanonicalHeaderKey(ref[3])]
}
return ""

View file

@ -218,15 +218,19 @@ func TestHeadersEvaluator(t *testing.T) {
HTTP: RequestHTTP{
Hostname: "from.example.com",
ClientCertificate: ClientCertificateInfo{Leaf: testValidCert},
Headers: map[string]string{
"X-Incoming-Header": "INCOMING",
},
},
Policy: &config.Policy{
SetRequestHeaders: map[string]string{
"X-Custom-Header": "CUSTOM_VALUE",
"X-ID-Token": "${pomerium.id_token}",
"X-Access-Token": "${pomerium.access_token}",
"Client-Cert-Fingerprint": "${pomerium.client_cert_fingerprint}",
"Authorization": "Bearer ${pomerium.jwt}",
"Foo": "escaped $$dollar sign",
"X-Custom-Header": "CUSTOM_VALUE",
"X-ID-Token": "${pomerium.id_token}",
"X-Access-Token": "${pomerium.access_token}",
"Client-Cert-Fingerprint": "${pomerium.client_cert_fingerprint}",
"Authorization": "Bearer ${pomerium.jwt}",
"Foo": "escaped $$dollar sign",
"X-Incoming-Custom-Header": `From-Incoming ${pomerium.request.headers["X-Incoming-Header"]}`,
},
},
Session: RequestSession{ID: "s1"},
@ -239,6 +243,7 @@ func TestHeadersEvaluator(t *testing.T) {
assert.Equal(t, "3febe6467787e93f0a01030e0803072feaa710f724a9dc74de05cfba3d4a6d23",
output.Headers.Get("Client-Cert-Fingerprint"))
assert.Equal(t, "escaped $dollar sign", output.Headers.Get("Foo"))
assert.Equal(t, "From-Incoming INCOMING", output.Headers.Get("X-Incoming-Custom-Header"))
authHeader := output.Headers.Get("Authorization")
assert.True(t, strings.HasPrefix(authHeader, "Bearer "))
authHeader = strings.TrimPrefix(authHeader, "Bearer ")

View file

@ -0,0 +1,255 @@
// Package headertemplate contains functions for rendering header templates.
package headertemplate
import "strings"
// Render renders a header template string.
func Render(src string, fn func(ref []string) string) string {
p := newParser(src, fn)
return p.parse()
}
// This is a hand written parser attempting to model this peg grammar:
//
// Grammar <- ( Variable / Text )* !.
// Text <- .
// Variable <- EscapedVariable / SimpleVariable / ComplexVariable
// EscapedVariable <- '$' '$'
// SimpleVariable <- '$' SimpleExpression
// SimpleExpression <- identifier ( '.' identifier )*
// ComplexVariable <- '$' '{' _ ComplexExpression _ '}'
// ComplexExpression <- identifier _ (ComplexSelector / ComplexIndex)*
// ComplexSelector <- '.' _ ComplexExpression _
// ComplexIndex <- '[' _ StringLiteral _ ']' _
// StringLiteral <- '"' (('\\'.) / [^"])* '"'
// identifier <- [a-zA-Z0-9_] [a-zA-Z0-9_\-]*
// _ <- ( ' ' / '\t' )*
type parser struct {
buffer []byte
pos int
stack []int
visit func(ref []string) string
}
func newParser(src string, visit func(ref []string) string) *parser {
return &parser{buffer: []byte(src), visit: visit}
}
func (p *parser) save() {
p.stack = append(p.stack, p.pos)
}
func (p *parser) restore() {
p.pos = p.stack[len(p.stack)-1]
}
func (p *parser) pop() {
p.stack = p.stack[:len(p.stack)-1]
}
func (p *parser) peek() byte {
if p.pos < len(p.buffer) {
return p.buffer[p.pos]
}
return 0
}
func (p *parser) next() byte {
if p.pos < len(p.buffer) {
c := p.buffer[p.pos]
p.pos++
return c
}
return 0
}
func (p *parser) parse() string {
var b strings.Builder
for p.pos < len(p.buffer) {
if v, ok := p.parseVariable(); ok {
b.WriteString(v)
continue
}
b.WriteByte(p.next())
}
return b.String()
}
func (p *parser) parseVariable() (string, bool) {
if p.peek() != '$' {
return "", false
}
p.save()
defer p.pop()
// $$ becomes $
p.next()
if p.peek() == '$' {
p.next()
return "$", true
}
if p.peek() == '{' {
p.next()
e, ok := p.parseComplexExpression()
if !ok {
p.restore()
return "", false
}
if p.next() != '}' {
p.restore()
return "", false
}
return e, true
}
e, ok := p.parseSimpleExpression()
if !ok {
p.restore()
return "", false
}
return e, true
}
func (p *parser) parseComplexExpression() (string, bool) {
p.save()
defer p.pop()
p.skipWhitespace()
var ref []string
id, ok := p.parseIdentifier()
if !ok {
p.restore()
return "", false
}
ref = append(ref, id)
for {
p.skipWhitespace()
if p.peek() == '.' {
p.next()
p.skipWhitespace()
id, ok := p.parseIdentifier()
if !ok {
p.restore()
return "", false
}
ref = append(ref, id)
} else if p.peek() == '[' {
p.next()
p.skipWhitespace()
s, ok := p.parseString()
if !ok {
p.restore()
return "", false
}
ref = append(ref, s)
p.skipWhitespace()
if p.next() != ']' {
p.restore()
return "", false
}
} else {
break
}
}
return p.visit(ref), true
}
func (p *parser) parseString() (string, bool) {
p.save()
defer p.pop()
if p.next() != '"' {
p.restore()
return "", false
}
var b strings.Builder
for {
c := p.next()
switch c {
case '"':
return b.String(), true
case 0:
p.restore()
return "", false
case '\\':
c = p.next()
if c == 0 {
p.restore()
return "", false
}
b.WriteByte(c)
default:
b.WriteByte(c)
}
}
}
func (p *parser) parseSimpleExpression() (string, bool) {
p.save()
defer p.pop()
var ref []string
for {
id, ok := p.parseIdentifier()
if !ok {
p.restore()
return "", false
}
ref = append(ref, id)
if p.peek() != '.' {
break
}
p.next()
}
return p.visit(ref), true
}
func (p *parser) parseIdentifier() (string, bool) {
p.save()
defer p.pop()
var b strings.Builder
for isIdentifierCharacter(p.peek()) {
b.WriteByte(p.next())
}
if b.Len() == 0 {
p.restore()
return "", false
}
return b.String(), true
}
func (p *parser) skipWhitespace() {
for isWhitespaceCharacter(p.peek()) {
p.next()
}
}
func isIdentifierCharacter(c byte) bool {
return (c >= '0' && c <= '9') ||
(c >= 'a' && c <= 'z') ||
(c >= 'A' && c <= 'Z') ||
c == '_' ||
c == '-'
}
func isWhitespaceCharacter(c byte) bool {
return c == ' ' || c == '\t'
}

View file

@ -0,0 +1,61 @@
package headertemplate_test
import (
"strings"
"testing"
"github.com/stretchr/testify/assert"
"github.com/pomerium/pomerium/internal/headertemplate"
)
func TestRender(t *testing.T) {
t.Parallel()
for _, tc := range []struct {
in string
expect string
}{
{"x $$ y $$ z", "x $ y $ z"},
{`${x.y.z}`, `<x,y,z>`},
{`${ x . y . z }`, `<x,y,z>`},
{`${x["y"].z}`, `<x,y,z>`},
{`${x["`, `${x["`},
{`${`, `${`},
{`${}`, `${}`},
{`${x["\\"]}`, `<x,\>`},
{`${x["\""]}`, `<x,">`},
{`${pomerium.access_token}`, `<pomerium,access_token>`},
{`$pomerium.access_token`, `<pomerium,access_token>`},
{`${pomerium.client_cert_fingerprint}`, `<pomerium,client_cert_fingerprint>`},
{`$pomerium.client_cert_fingerprint`, `<pomerium,client_cert_fingerprint>`},
{`${pomerium.id_token}`, `<pomerium,id_token>`},
{`$pomerium.id_token`, `<pomerium,id_token>`},
{`${pomerium.jwt}`, `<pomerium,jwt>`},
{`$pomerium.jwt`, `<pomerium,jwt>`},
{`${pomerium.request.headers["X-Access-Token"]}`, `<pomerium,request,headers,X-Access-Token>`},
{`$pomerium.request.headers.X-Access-Token`, `<pomerium,request,headers,X-Access-Token>`},
} {
actual := headertemplate.Render(tc.in, func(ref []string) string {
return "<" + strings.Join(ref, ",") + ">"
})
assert.Equal(t, tc.expect, actual)
}
assert.Equal(t, "x $ y $ z", headertemplate.Render("x $$ y $$ z", func(_ []string) string {
return ""
}))
assert.Equal(t, "before JWT after", headertemplate.Render("before $pomerium.jwt after", func(ref []string) string {
assert.Equal(t, []string{"pomerium", "jwt"}, ref)
return "JWT"
}))
assert.Equal(t, "before JWT after", headertemplate.Render("before ${ pomerium . jwt } after", func(ref []string) string {
assert.Equal(t, []string{"pomerium", "jwt"}, ref)
return "JWT"
}))
assert.Equal(t, "before JWT after", headertemplate.Render("before ${ pomerium . jwt } after", func(ref []string) string {
assert.Equal(t, []string{"pomerium", "jwt"}, ref)
return "JWT"
}))
}