mirror of
https://github.com/pomerium/pomerium.git
synced 2025-06-01 18:33:19 +02:00
add support for pomerium.request.headers for set_request_headers (#5563)
* add support for pomerium.request.headers for set_request_headers * add peg grammar
This commit is contained in:
parent
5f95dd32db
commit
a1eb75a8fe
4 changed files with 338 additions and 15 deletions
|
@ -7,8 +7,8 @@ import (
|
|||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
"reflect"
|
||||
"slices"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
|
@ -20,6 +20,8 @@ import (
|
|||
|
||||
"github.com/pomerium/datasource/pkg/directory"
|
||||
"github.com/pomerium/pomerium/config"
|
||||
"github.com/pomerium/pomerium/internal/headertemplate"
|
||||
"github.com/pomerium/pomerium/internal/httputil"
|
||||
"github.com/pomerium/pomerium/internal/log"
|
||||
"github.com/pomerium/pomerium/pkg/cryptutil"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/session"
|
||||
|
@ -149,20 +151,20 @@ func (e *headersEvaluatorEvaluation) fillSetRequestHeaders(ctx context.Context)
|
|||
}
|
||||
|
||||
for k, v := range e.request.Policy.SetRequestHeaders {
|
||||
e.response.Headers.Add(k, os.Expand(v, func(name string) string {
|
||||
switch name {
|
||||
case "$":
|
||||
return "$"
|
||||
case "pomerium.access_token":
|
||||
e.response.Headers.Add(k, headertemplate.Render(v, func(ref []string) string {
|
||||
switch {
|
||||
case slices.Equal(ref, []string{"pomerium", "access_token"}):
|
||||
s, _ := e.getSessionOrServiceAccount(ctx)
|
||||
return s.GetOauthToken().GetAccessToken()
|
||||
case "pomerium.client_cert_fingerprint":
|
||||
case slices.Equal(ref, []string{"pomerium", "client_cert_fingerprint"}):
|
||||
return e.getClientCertFingerprint()
|
||||
case "pomerium.id_token":
|
||||
case slices.Equal(ref, []string{"pomerium", "id_token"}):
|
||||
s, _ := e.getSessionOrServiceAccount(ctx)
|
||||
return s.GetIdToken().GetRaw()
|
||||
case "pomerium.jwt":
|
||||
case slices.Equal(ref, []string{"pomerium", "jwt"}):
|
||||
return e.getSignedJWT(ctx)
|
||||
case len(ref) > 3 && ref[0] == "pomerium" && ref[1] == "request" && ref[2] == "headers":
|
||||
return e.request.HTTP.Headers[httputil.CanonicalHeaderKey(ref[3])]
|
||||
}
|
||||
|
||||
return ""
|
||||
|
|
|
@ -218,15 +218,19 @@ func TestHeadersEvaluator(t *testing.T) {
|
|||
HTTP: RequestHTTP{
|
||||
Hostname: "from.example.com",
|
||||
ClientCertificate: ClientCertificateInfo{Leaf: testValidCert},
|
||||
Headers: map[string]string{
|
||||
"X-Incoming-Header": "INCOMING",
|
||||
},
|
||||
},
|
||||
Policy: &config.Policy{
|
||||
SetRequestHeaders: map[string]string{
|
||||
"X-Custom-Header": "CUSTOM_VALUE",
|
||||
"X-ID-Token": "${pomerium.id_token}",
|
||||
"X-Access-Token": "${pomerium.access_token}",
|
||||
"Client-Cert-Fingerprint": "${pomerium.client_cert_fingerprint}",
|
||||
"Authorization": "Bearer ${pomerium.jwt}",
|
||||
"Foo": "escaped $$dollar sign",
|
||||
"X-Custom-Header": "CUSTOM_VALUE",
|
||||
"X-ID-Token": "${pomerium.id_token}",
|
||||
"X-Access-Token": "${pomerium.access_token}",
|
||||
"Client-Cert-Fingerprint": "${pomerium.client_cert_fingerprint}",
|
||||
"Authorization": "Bearer ${pomerium.jwt}",
|
||||
"Foo": "escaped $$dollar sign",
|
||||
"X-Incoming-Custom-Header": `From-Incoming ${pomerium.request.headers["X-Incoming-Header"]}`,
|
||||
},
|
||||
},
|
||||
Session: RequestSession{ID: "s1"},
|
||||
|
@ -239,6 +243,7 @@ func TestHeadersEvaluator(t *testing.T) {
|
|||
assert.Equal(t, "3febe6467787e93f0a01030e0803072feaa710f724a9dc74de05cfba3d4a6d23",
|
||||
output.Headers.Get("Client-Cert-Fingerprint"))
|
||||
assert.Equal(t, "escaped $dollar sign", output.Headers.Get("Foo"))
|
||||
assert.Equal(t, "From-Incoming INCOMING", output.Headers.Get("X-Incoming-Custom-Header"))
|
||||
authHeader := output.Headers.Get("Authorization")
|
||||
assert.True(t, strings.HasPrefix(authHeader, "Bearer "))
|
||||
authHeader = strings.TrimPrefix(authHeader, "Bearer ")
|
||||
|
|
255
internal/headertemplate/headertemplate.go
Normal file
255
internal/headertemplate/headertemplate.go
Normal file
|
@ -0,0 +1,255 @@
|
|||
// Package headertemplate contains functions for rendering header templates.
|
||||
package headertemplate
|
||||
|
||||
import "strings"
|
||||
|
||||
// Render renders a header template string.
|
||||
func Render(src string, fn func(ref []string) string) string {
|
||||
p := newParser(src, fn)
|
||||
return p.parse()
|
||||
}
|
||||
|
||||
// This is a hand written parser attempting to model this peg grammar:
|
||||
//
|
||||
// Grammar <- ( Variable / Text )* !.
|
||||
// Text <- .
|
||||
// Variable <- EscapedVariable / SimpleVariable / ComplexVariable
|
||||
// EscapedVariable <- '$' '$'
|
||||
// SimpleVariable <- '$' SimpleExpression
|
||||
// SimpleExpression <- identifier ( '.' identifier )*
|
||||
// ComplexVariable <- '$' '{' _ ComplexExpression _ '}'
|
||||
// ComplexExpression <- identifier _ (ComplexSelector / ComplexIndex)*
|
||||
// ComplexSelector <- '.' _ ComplexExpression _
|
||||
// ComplexIndex <- '[' _ StringLiteral _ ']' _
|
||||
// StringLiteral <- '"' (('\\'.) / [^"])* '"'
|
||||
// identifier <- [a-zA-Z0-9_] [a-zA-Z0-9_\-]*
|
||||
// _ <- ( ' ' / '\t' )*
|
||||
|
||||
type parser struct {
|
||||
buffer []byte
|
||||
pos int
|
||||
stack []int
|
||||
visit func(ref []string) string
|
||||
}
|
||||
|
||||
func newParser(src string, visit func(ref []string) string) *parser {
|
||||
return &parser{buffer: []byte(src), visit: visit}
|
||||
}
|
||||
|
||||
func (p *parser) save() {
|
||||
p.stack = append(p.stack, p.pos)
|
||||
}
|
||||
|
||||
func (p *parser) restore() {
|
||||
p.pos = p.stack[len(p.stack)-1]
|
||||
}
|
||||
|
||||
func (p *parser) pop() {
|
||||
p.stack = p.stack[:len(p.stack)-1]
|
||||
}
|
||||
|
||||
func (p *parser) peek() byte {
|
||||
if p.pos < len(p.buffer) {
|
||||
return p.buffer[p.pos]
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func (p *parser) next() byte {
|
||||
if p.pos < len(p.buffer) {
|
||||
c := p.buffer[p.pos]
|
||||
p.pos++
|
||||
return c
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func (p *parser) parse() string {
|
||||
var b strings.Builder
|
||||
for p.pos < len(p.buffer) {
|
||||
if v, ok := p.parseVariable(); ok {
|
||||
b.WriteString(v)
|
||||
continue
|
||||
}
|
||||
b.WriteByte(p.next())
|
||||
}
|
||||
return b.String()
|
||||
}
|
||||
|
||||
func (p *parser) parseVariable() (string, bool) {
|
||||
if p.peek() != '$' {
|
||||
return "", false
|
||||
}
|
||||
|
||||
p.save()
|
||||
defer p.pop()
|
||||
|
||||
// $$ becomes $
|
||||
p.next()
|
||||
if p.peek() == '$' {
|
||||
p.next()
|
||||
return "$", true
|
||||
}
|
||||
|
||||
if p.peek() == '{' {
|
||||
p.next()
|
||||
e, ok := p.parseComplexExpression()
|
||||
if !ok {
|
||||
p.restore()
|
||||
return "", false
|
||||
}
|
||||
if p.next() != '}' {
|
||||
p.restore()
|
||||
return "", false
|
||||
}
|
||||
return e, true
|
||||
}
|
||||
|
||||
e, ok := p.parseSimpleExpression()
|
||||
if !ok {
|
||||
p.restore()
|
||||
return "", false
|
||||
}
|
||||
|
||||
return e, true
|
||||
}
|
||||
|
||||
func (p *parser) parseComplexExpression() (string, bool) {
|
||||
p.save()
|
||||
defer p.pop()
|
||||
|
||||
p.skipWhitespace()
|
||||
|
||||
var ref []string
|
||||
id, ok := p.parseIdentifier()
|
||||
if !ok {
|
||||
p.restore()
|
||||
return "", false
|
||||
}
|
||||
ref = append(ref, id)
|
||||
|
||||
for {
|
||||
p.skipWhitespace()
|
||||
|
||||
if p.peek() == '.' {
|
||||
p.next()
|
||||
p.skipWhitespace()
|
||||
|
||||
id, ok := p.parseIdentifier()
|
||||
if !ok {
|
||||
p.restore()
|
||||
return "", false
|
||||
}
|
||||
ref = append(ref, id)
|
||||
|
||||
} else if p.peek() == '[' {
|
||||
p.next()
|
||||
p.skipWhitespace()
|
||||
|
||||
s, ok := p.parseString()
|
||||
if !ok {
|
||||
p.restore()
|
||||
return "", false
|
||||
}
|
||||
ref = append(ref, s)
|
||||
|
||||
p.skipWhitespace()
|
||||
if p.next() != ']' {
|
||||
p.restore()
|
||||
return "", false
|
||||
}
|
||||
} else {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return p.visit(ref), true
|
||||
}
|
||||
|
||||
func (p *parser) parseString() (string, bool) {
|
||||
p.save()
|
||||
defer p.pop()
|
||||
|
||||
if p.next() != '"' {
|
||||
p.restore()
|
||||
return "", false
|
||||
}
|
||||
|
||||
var b strings.Builder
|
||||
for {
|
||||
c := p.next()
|
||||
switch c {
|
||||
case '"':
|
||||
return b.String(), true
|
||||
case 0:
|
||||
p.restore()
|
||||
return "", false
|
||||
case '\\':
|
||||
c = p.next()
|
||||
if c == 0 {
|
||||
p.restore()
|
||||
return "", false
|
||||
}
|
||||
b.WriteByte(c)
|
||||
default:
|
||||
b.WriteByte(c)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (p *parser) parseSimpleExpression() (string, bool) {
|
||||
p.save()
|
||||
defer p.pop()
|
||||
|
||||
var ref []string
|
||||
for {
|
||||
id, ok := p.parseIdentifier()
|
||||
if !ok {
|
||||
p.restore()
|
||||
return "", false
|
||||
}
|
||||
ref = append(ref, id)
|
||||
|
||||
if p.peek() != '.' {
|
||||
break
|
||||
}
|
||||
p.next()
|
||||
}
|
||||
|
||||
return p.visit(ref), true
|
||||
}
|
||||
|
||||
func (p *parser) parseIdentifier() (string, bool) {
|
||||
p.save()
|
||||
defer p.pop()
|
||||
|
||||
var b strings.Builder
|
||||
for isIdentifierCharacter(p.peek()) {
|
||||
b.WriteByte(p.next())
|
||||
}
|
||||
|
||||
if b.Len() == 0 {
|
||||
p.restore()
|
||||
return "", false
|
||||
}
|
||||
|
||||
return b.String(), true
|
||||
}
|
||||
|
||||
func (p *parser) skipWhitespace() {
|
||||
for isWhitespaceCharacter(p.peek()) {
|
||||
p.next()
|
||||
}
|
||||
}
|
||||
|
||||
func isIdentifierCharacter(c byte) bool {
|
||||
return (c >= '0' && c <= '9') ||
|
||||
(c >= 'a' && c <= 'z') ||
|
||||
(c >= 'A' && c <= 'Z') ||
|
||||
c == '_' ||
|
||||
c == '-'
|
||||
}
|
||||
|
||||
func isWhitespaceCharacter(c byte) bool {
|
||||
return c == ' ' || c == '\t'
|
||||
}
|
61
internal/headertemplate/headertemplate_test.go
Normal file
61
internal/headertemplate/headertemplate_test.go
Normal file
|
@ -0,0 +1,61 @@
|
|||
package headertemplate_test
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/headertemplate"
|
||||
)
|
||||
|
||||
func TestRender(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
for _, tc := range []struct {
|
||||
in string
|
||||
expect string
|
||||
}{
|
||||
{"x $$ y $$ z", "x $ y $ z"},
|
||||
{`${x.y.z}`, `<x,y,z>`},
|
||||
{`${ x . y . z }`, `<x,y,z>`},
|
||||
{`${x["y"].z}`, `<x,y,z>`},
|
||||
{`${x["`, `${x["`},
|
||||
{`${`, `${`},
|
||||
{`${}`, `${}`},
|
||||
{`${x["\\"]}`, `<x,\>`},
|
||||
{`${x["\""]}`, `<x,">`},
|
||||
|
||||
{`${pomerium.access_token}`, `<pomerium,access_token>`},
|
||||
{`$pomerium.access_token`, `<pomerium,access_token>`},
|
||||
{`${pomerium.client_cert_fingerprint}`, `<pomerium,client_cert_fingerprint>`},
|
||||
{`$pomerium.client_cert_fingerprint`, `<pomerium,client_cert_fingerprint>`},
|
||||
{`${pomerium.id_token}`, `<pomerium,id_token>`},
|
||||
{`$pomerium.id_token`, `<pomerium,id_token>`},
|
||||
{`${pomerium.jwt}`, `<pomerium,jwt>`},
|
||||
{`$pomerium.jwt`, `<pomerium,jwt>`},
|
||||
{`${pomerium.request.headers["X-Access-Token"]}`, `<pomerium,request,headers,X-Access-Token>`},
|
||||
{`$pomerium.request.headers.X-Access-Token`, `<pomerium,request,headers,X-Access-Token>`},
|
||||
} {
|
||||
actual := headertemplate.Render(tc.in, func(ref []string) string {
|
||||
return "<" + strings.Join(ref, ",") + ">"
|
||||
})
|
||||
assert.Equal(t, tc.expect, actual)
|
||||
}
|
||||
|
||||
assert.Equal(t, "x $ y $ z", headertemplate.Render("x $$ y $$ z", func(_ []string) string {
|
||||
return ""
|
||||
}))
|
||||
assert.Equal(t, "before JWT after", headertemplate.Render("before $pomerium.jwt after", func(ref []string) string {
|
||||
assert.Equal(t, []string{"pomerium", "jwt"}, ref)
|
||||
return "JWT"
|
||||
}))
|
||||
assert.Equal(t, "before JWT after", headertemplate.Render("before ${ pomerium . jwt } after", func(ref []string) string {
|
||||
assert.Equal(t, []string{"pomerium", "jwt"}, ref)
|
||||
return "JWT"
|
||||
}))
|
||||
assert.Equal(t, "before JWT after", headertemplate.Render("before ${ pomerium . jwt } after", func(ref []string) string {
|
||||
assert.Equal(t, []string{"pomerium", "jwt"}, ref)
|
||||
return "JWT"
|
||||
}))
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue