diff --git a/config/envoyconfig/http_connection_manager.go b/config/envoyconfig/http_connection_manager.go index 5d956b373..14dbbff50 100644 --- a/config/envoyconfig/http_connection_manager.go +++ b/config/envoyconfig/http_connection_manager.go @@ -55,13 +55,19 @@ func (b *Builder) buildLocalReplyConfig( headers = toEnvoyHeaders(options.GetSetResponseHeaders()) } - data := map[string]any{ - "status": "%RESPONSE_CODE%", - "statusText": "%RESPONSE_CODE_DETAILS%", - "requestId": "%STREAM_ID%", - "responseFlags": "%RESPONSE_FLAGS%", - } + data := make(map[string]any) httputil.AddBrandingOptionsToMap(data, options.BrandingOptions) + for k, v := range data { + // Escape any % signs in the branding options data, as Envoy will + // interpret the page output as a substitution format string. + if s, ok := v.(string); ok { + data[k] = strings.ReplaceAll(s, "%", "%%") + } + } + data["status"] = "%RESPONSE_CODE%" + data["statusText"] = "%RESPONSE_CODE_DETAILS%" + data["requestId"] = "%STREAM_ID%" + data["responseFlags"] = "%RESPONSE_FLAGS%" bs, err := ui.RenderPage("Error", "Error", data) if err != nil { diff --git a/config/envoyconfig/http_connection_manager_test.go b/config/envoyconfig/http_connection_manager_test.go new file mode 100644 index 000000000..e272bdc73 --- /dev/null +++ b/config/envoyconfig/http_connection_manager_test.go @@ -0,0 +1,64 @@ +package envoyconfig + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "google.golang.org/protobuf/proto" + + "github.com/pomerium/pomerium/config" + configpb "github.com/pomerium/pomerium/pkg/grpc/config" +) + +func Test_buildLocalReplyConfig(t *testing.T) { + b := Builder{} + opts := config.NewDefaultOptions() + opts.BrandingOptions = &configpb.Settings{ + LogoUrl: proto.String("http://example.com/my%20branding%20logo.png"), + ErrorMessageFirstParagraph: proto.String("It's 100% broken."), + } + lrc, err := b.buildLocalReplyConfig(opts) + require.NoError(t, err) + tmpl := string(lrc.Mappers[0].GetBodyFormatOverride().GetTextFormatSource().GetInlineBytes()) + assert.Equal(t, ` + +
+ + + + + + +