mirror of
https://github.com/pomerium/pomerium.git
synced 2025-06-06 12:52:53 +02:00
add support for pomerium.request.headers for set_request_headers (#5563)
* add support for pomerium.request.headers for set_request_headers * add peg grammar
This commit is contained in:
parent
5f95dd32db
commit
a1eb75a8fe
4 changed files with 338 additions and 15 deletions
|
@ -7,8 +7,8 @@ import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
|
||||||
"reflect"
|
"reflect"
|
||||||
|
"slices"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
@ -20,6 +20,8 @@ import (
|
||||||
|
|
||||||
"github.com/pomerium/datasource/pkg/directory"
|
"github.com/pomerium/datasource/pkg/directory"
|
||||||
"github.com/pomerium/pomerium/config"
|
"github.com/pomerium/pomerium/config"
|
||||||
|
"github.com/pomerium/pomerium/internal/headertemplate"
|
||||||
|
"github.com/pomerium/pomerium/internal/httputil"
|
||||||
"github.com/pomerium/pomerium/internal/log"
|
"github.com/pomerium/pomerium/internal/log"
|
||||||
"github.com/pomerium/pomerium/pkg/cryptutil"
|
"github.com/pomerium/pomerium/pkg/cryptutil"
|
||||||
"github.com/pomerium/pomerium/pkg/grpc/session"
|
"github.com/pomerium/pomerium/pkg/grpc/session"
|
||||||
|
@ -149,20 +151,20 @@ func (e *headersEvaluatorEvaluation) fillSetRequestHeaders(ctx context.Context)
|
||||||
}
|
}
|
||||||
|
|
||||||
for k, v := range e.request.Policy.SetRequestHeaders {
|
for k, v := range e.request.Policy.SetRequestHeaders {
|
||||||
e.response.Headers.Add(k, os.Expand(v, func(name string) string {
|
e.response.Headers.Add(k, headertemplate.Render(v, func(ref []string) string {
|
||||||
switch name {
|
switch {
|
||||||
case "$":
|
case slices.Equal(ref, []string{"pomerium", "access_token"}):
|
||||||
return "$"
|
|
||||||
case "pomerium.access_token":
|
|
||||||
s, _ := e.getSessionOrServiceAccount(ctx)
|
s, _ := e.getSessionOrServiceAccount(ctx)
|
||||||
return s.GetOauthToken().GetAccessToken()
|
return s.GetOauthToken().GetAccessToken()
|
||||||
case "pomerium.client_cert_fingerprint":
|
case slices.Equal(ref, []string{"pomerium", "client_cert_fingerprint"}):
|
||||||
return e.getClientCertFingerprint()
|
return e.getClientCertFingerprint()
|
||||||
case "pomerium.id_token":
|
case slices.Equal(ref, []string{"pomerium", "id_token"}):
|
||||||
s, _ := e.getSessionOrServiceAccount(ctx)
|
s, _ := e.getSessionOrServiceAccount(ctx)
|
||||||
return s.GetIdToken().GetRaw()
|
return s.GetIdToken().GetRaw()
|
||||||
case "pomerium.jwt":
|
case slices.Equal(ref, []string{"pomerium", "jwt"}):
|
||||||
return e.getSignedJWT(ctx)
|
return e.getSignedJWT(ctx)
|
||||||
|
case len(ref) > 3 && ref[0] == "pomerium" && ref[1] == "request" && ref[2] == "headers":
|
||||||
|
return e.request.HTTP.Headers[httputil.CanonicalHeaderKey(ref[3])]
|
||||||
}
|
}
|
||||||
|
|
||||||
return ""
|
return ""
|
||||||
|
|
|
@ -218,6 +218,9 @@ func TestHeadersEvaluator(t *testing.T) {
|
||||||
HTTP: RequestHTTP{
|
HTTP: RequestHTTP{
|
||||||
Hostname: "from.example.com",
|
Hostname: "from.example.com",
|
||||||
ClientCertificate: ClientCertificateInfo{Leaf: testValidCert},
|
ClientCertificate: ClientCertificateInfo{Leaf: testValidCert},
|
||||||
|
Headers: map[string]string{
|
||||||
|
"X-Incoming-Header": "INCOMING",
|
||||||
|
},
|
||||||
},
|
},
|
||||||
Policy: &config.Policy{
|
Policy: &config.Policy{
|
||||||
SetRequestHeaders: map[string]string{
|
SetRequestHeaders: map[string]string{
|
||||||
|
@ -227,6 +230,7 @@ func TestHeadersEvaluator(t *testing.T) {
|
||||||
"Client-Cert-Fingerprint": "${pomerium.client_cert_fingerprint}",
|
"Client-Cert-Fingerprint": "${pomerium.client_cert_fingerprint}",
|
||||||
"Authorization": "Bearer ${pomerium.jwt}",
|
"Authorization": "Bearer ${pomerium.jwt}",
|
||||||
"Foo": "escaped $$dollar sign",
|
"Foo": "escaped $$dollar sign",
|
||||||
|
"X-Incoming-Custom-Header": `From-Incoming ${pomerium.request.headers["X-Incoming-Header"]}`,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
Session: RequestSession{ID: "s1"},
|
Session: RequestSession{ID: "s1"},
|
||||||
|
@ -239,6 +243,7 @@ func TestHeadersEvaluator(t *testing.T) {
|
||||||
assert.Equal(t, "3febe6467787e93f0a01030e0803072feaa710f724a9dc74de05cfba3d4a6d23",
|
assert.Equal(t, "3febe6467787e93f0a01030e0803072feaa710f724a9dc74de05cfba3d4a6d23",
|
||||||
output.Headers.Get("Client-Cert-Fingerprint"))
|
output.Headers.Get("Client-Cert-Fingerprint"))
|
||||||
assert.Equal(t, "escaped $dollar sign", output.Headers.Get("Foo"))
|
assert.Equal(t, "escaped $dollar sign", output.Headers.Get("Foo"))
|
||||||
|
assert.Equal(t, "From-Incoming INCOMING", output.Headers.Get("X-Incoming-Custom-Header"))
|
||||||
authHeader := output.Headers.Get("Authorization")
|
authHeader := output.Headers.Get("Authorization")
|
||||||
assert.True(t, strings.HasPrefix(authHeader, "Bearer "))
|
assert.True(t, strings.HasPrefix(authHeader, "Bearer "))
|
||||||
authHeader = strings.TrimPrefix(authHeader, "Bearer ")
|
authHeader = strings.TrimPrefix(authHeader, "Bearer ")
|
||||||
|
|
255
internal/headertemplate/headertemplate.go
Normal file
255
internal/headertemplate/headertemplate.go
Normal file
|
@ -0,0 +1,255 @@
|
||||||
|
// Package headertemplate contains functions for rendering header templates.
|
||||||
|
package headertemplate
|
||||||
|
|
||||||
|
import "strings"
|
||||||
|
|
||||||
|
// Render renders a header template string.
|
||||||
|
func Render(src string, fn func(ref []string) string) string {
|
||||||
|
p := newParser(src, fn)
|
||||||
|
return p.parse()
|
||||||
|
}
|
||||||
|
|
||||||
|
// This is a hand written parser attempting to model this peg grammar:
|
||||||
|
//
|
||||||
|
// Grammar <- ( Variable / Text )* !.
|
||||||
|
// Text <- .
|
||||||
|
// Variable <- EscapedVariable / SimpleVariable / ComplexVariable
|
||||||
|
// EscapedVariable <- '$' '$'
|
||||||
|
// SimpleVariable <- '$' SimpleExpression
|
||||||
|
// SimpleExpression <- identifier ( '.' identifier )*
|
||||||
|
// ComplexVariable <- '$' '{' _ ComplexExpression _ '}'
|
||||||
|
// ComplexExpression <- identifier _ (ComplexSelector / ComplexIndex)*
|
||||||
|
// ComplexSelector <- '.' _ ComplexExpression _
|
||||||
|
// ComplexIndex <- '[' _ StringLiteral _ ']' _
|
||||||
|
// StringLiteral <- '"' (('\\'.) / [^"])* '"'
|
||||||
|
// identifier <- [a-zA-Z0-9_] [a-zA-Z0-9_\-]*
|
||||||
|
// _ <- ( ' ' / '\t' )*
|
||||||
|
|
||||||
|
type parser struct {
|
||||||
|
buffer []byte
|
||||||
|
pos int
|
||||||
|
stack []int
|
||||||
|
visit func(ref []string) string
|
||||||
|
}
|
||||||
|
|
||||||
|
func newParser(src string, visit func(ref []string) string) *parser {
|
||||||
|
return &parser{buffer: []byte(src), visit: visit}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *parser) save() {
|
||||||
|
p.stack = append(p.stack, p.pos)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *parser) restore() {
|
||||||
|
p.pos = p.stack[len(p.stack)-1]
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *parser) pop() {
|
||||||
|
p.stack = p.stack[:len(p.stack)-1]
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *parser) peek() byte {
|
||||||
|
if p.pos < len(p.buffer) {
|
||||||
|
return p.buffer[p.pos]
|
||||||
|
}
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *parser) next() byte {
|
||||||
|
if p.pos < len(p.buffer) {
|
||||||
|
c := p.buffer[p.pos]
|
||||||
|
p.pos++
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *parser) parse() string {
|
||||||
|
var b strings.Builder
|
||||||
|
for p.pos < len(p.buffer) {
|
||||||
|
if v, ok := p.parseVariable(); ok {
|
||||||
|
b.WriteString(v)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
b.WriteByte(p.next())
|
||||||
|
}
|
||||||
|
return b.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *parser) parseVariable() (string, bool) {
|
||||||
|
if p.peek() != '$' {
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
|
||||||
|
p.save()
|
||||||
|
defer p.pop()
|
||||||
|
|
||||||
|
// $$ becomes $
|
||||||
|
p.next()
|
||||||
|
if p.peek() == '$' {
|
||||||
|
p.next()
|
||||||
|
return "$", true
|
||||||
|
}
|
||||||
|
|
||||||
|
if p.peek() == '{' {
|
||||||
|
p.next()
|
||||||
|
e, ok := p.parseComplexExpression()
|
||||||
|
if !ok {
|
||||||
|
p.restore()
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
if p.next() != '}' {
|
||||||
|
p.restore()
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
return e, true
|
||||||
|
}
|
||||||
|
|
||||||
|
e, ok := p.parseSimpleExpression()
|
||||||
|
if !ok {
|
||||||
|
p.restore()
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
|
||||||
|
return e, true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *parser) parseComplexExpression() (string, bool) {
|
||||||
|
p.save()
|
||||||
|
defer p.pop()
|
||||||
|
|
||||||
|
p.skipWhitespace()
|
||||||
|
|
||||||
|
var ref []string
|
||||||
|
id, ok := p.parseIdentifier()
|
||||||
|
if !ok {
|
||||||
|
p.restore()
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
ref = append(ref, id)
|
||||||
|
|
||||||
|
for {
|
||||||
|
p.skipWhitespace()
|
||||||
|
|
||||||
|
if p.peek() == '.' {
|
||||||
|
p.next()
|
||||||
|
p.skipWhitespace()
|
||||||
|
|
||||||
|
id, ok := p.parseIdentifier()
|
||||||
|
if !ok {
|
||||||
|
p.restore()
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
ref = append(ref, id)
|
||||||
|
|
||||||
|
} else if p.peek() == '[' {
|
||||||
|
p.next()
|
||||||
|
p.skipWhitespace()
|
||||||
|
|
||||||
|
s, ok := p.parseString()
|
||||||
|
if !ok {
|
||||||
|
p.restore()
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
ref = append(ref, s)
|
||||||
|
|
||||||
|
p.skipWhitespace()
|
||||||
|
if p.next() != ']' {
|
||||||
|
p.restore()
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return p.visit(ref), true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *parser) parseString() (string, bool) {
|
||||||
|
p.save()
|
||||||
|
defer p.pop()
|
||||||
|
|
||||||
|
if p.next() != '"' {
|
||||||
|
p.restore()
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
|
||||||
|
var b strings.Builder
|
||||||
|
for {
|
||||||
|
c := p.next()
|
||||||
|
switch c {
|
||||||
|
case '"':
|
||||||
|
return b.String(), true
|
||||||
|
case 0:
|
||||||
|
p.restore()
|
||||||
|
return "", false
|
||||||
|
case '\\':
|
||||||
|
c = p.next()
|
||||||
|
if c == 0 {
|
||||||
|
p.restore()
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
b.WriteByte(c)
|
||||||
|
default:
|
||||||
|
b.WriteByte(c)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *parser) parseSimpleExpression() (string, bool) {
|
||||||
|
p.save()
|
||||||
|
defer p.pop()
|
||||||
|
|
||||||
|
var ref []string
|
||||||
|
for {
|
||||||
|
id, ok := p.parseIdentifier()
|
||||||
|
if !ok {
|
||||||
|
p.restore()
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
ref = append(ref, id)
|
||||||
|
|
||||||
|
if p.peek() != '.' {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
p.next()
|
||||||
|
}
|
||||||
|
|
||||||
|
return p.visit(ref), true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *parser) parseIdentifier() (string, bool) {
|
||||||
|
p.save()
|
||||||
|
defer p.pop()
|
||||||
|
|
||||||
|
var b strings.Builder
|
||||||
|
for isIdentifierCharacter(p.peek()) {
|
||||||
|
b.WriteByte(p.next())
|
||||||
|
}
|
||||||
|
|
||||||
|
if b.Len() == 0 {
|
||||||
|
p.restore()
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
|
||||||
|
return b.String(), true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *parser) skipWhitespace() {
|
||||||
|
for isWhitespaceCharacter(p.peek()) {
|
||||||
|
p.next()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func isIdentifierCharacter(c byte) bool {
|
||||||
|
return (c >= '0' && c <= '9') ||
|
||||||
|
(c >= 'a' && c <= 'z') ||
|
||||||
|
(c >= 'A' && c <= 'Z') ||
|
||||||
|
c == '_' ||
|
||||||
|
c == '-'
|
||||||
|
}
|
||||||
|
|
||||||
|
func isWhitespaceCharacter(c byte) bool {
|
||||||
|
return c == ' ' || c == '\t'
|
||||||
|
}
|
61
internal/headertemplate/headertemplate_test.go
Normal file
61
internal/headertemplate/headertemplate_test.go
Normal file
|
@ -0,0 +1,61 @@
|
||||||
|
package headertemplate_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
|
||||||
|
"github.com/pomerium/pomerium/internal/headertemplate"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestRender(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
for _, tc := range []struct {
|
||||||
|
in string
|
||||||
|
expect string
|
||||||
|
}{
|
||||||
|
{"x $$ y $$ z", "x $ y $ z"},
|
||||||
|
{`${x.y.z}`, `<x,y,z>`},
|
||||||
|
{`${ x . y . z }`, `<x,y,z>`},
|
||||||
|
{`${x["y"].z}`, `<x,y,z>`},
|
||||||
|
{`${x["`, `${x["`},
|
||||||
|
{`${`, `${`},
|
||||||
|
{`${}`, `${}`},
|
||||||
|
{`${x["\\"]}`, `<x,\>`},
|
||||||
|
{`${x["\""]}`, `<x,">`},
|
||||||
|
|
||||||
|
{`${pomerium.access_token}`, `<pomerium,access_token>`},
|
||||||
|
{`$pomerium.access_token`, `<pomerium,access_token>`},
|
||||||
|
{`${pomerium.client_cert_fingerprint}`, `<pomerium,client_cert_fingerprint>`},
|
||||||
|
{`$pomerium.client_cert_fingerprint`, `<pomerium,client_cert_fingerprint>`},
|
||||||
|
{`${pomerium.id_token}`, `<pomerium,id_token>`},
|
||||||
|
{`$pomerium.id_token`, `<pomerium,id_token>`},
|
||||||
|
{`${pomerium.jwt}`, `<pomerium,jwt>`},
|
||||||
|
{`$pomerium.jwt`, `<pomerium,jwt>`},
|
||||||
|
{`${pomerium.request.headers["X-Access-Token"]}`, `<pomerium,request,headers,X-Access-Token>`},
|
||||||
|
{`$pomerium.request.headers.X-Access-Token`, `<pomerium,request,headers,X-Access-Token>`},
|
||||||
|
} {
|
||||||
|
actual := headertemplate.Render(tc.in, func(ref []string) string {
|
||||||
|
return "<" + strings.Join(ref, ",") + ">"
|
||||||
|
})
|
||||||
|
assert.Equal(t, tc.expect, actual)
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.Equal(t, "x $ y $ z", headertemplate.Render("x $$ y $$ z", func(_ []string) string {
|
||||||
|
return ""
|
||||||
|
}))
|
||||||
|
assert.Equal(t, "before JWT after", headertemplate.Render("before $pomerium.jwt after", func(ref []string) string {
|
||||||
|
assert.Equal(t, []string{"pomerium", "jwt"}, ref)
|
||||||
|
return "JWT"
|
||||||
|
}))
|
||||||
|
assert.Equal(t, "before JWT after", headertemplate.Render("before ${ pomerium . jwt } after", func(ref []string) string {
|
||||||
|
assert.Equal(t, []string{"pomerium", "jwt"}, ref)
|
||||||
|
return "JWT"
|
||||||
|
}))
|
||||||
|
assert.Equal(t, "before JWT after", headertemplate.Render("before ${ pomerium . jwt } after", func(ref []string) string {
|
||||||
|
assert.Equal(t, []string{"pomerium", "jwt"}, ref)
|
||||||
|
return "JWT"
|
||||||
|
}))
|
||||||
|
}
|
Loading…
Add table
Add a link
Reference in a new issue