From 84fcf69bc94c36b0b7dd04920fcc4e3b3bf06878 Mon Sep 17 00:00:00 2001 From: Caleb Doxsey Date: Mon, 17 Feb 2025 16:51:16 -0700 Subject: [PATCH] add test --- config/session.go | 12 +++--- config/session_test.go | 83 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 89 insertions(+), 6 deletions(-) diff --git a/config/session.go b/config/session.go index 574e7b464..1c8db4acd 100644 --- a/config/session.go +++ b/config/session.go @@ -361,18 +361,18 @@ func (cfg *Config) GetIncomingIDPAccessTokenForPolicy(policy *Policy, r *http.Re if auth := r.Header.Get(httputil.HeaderAuthorization); auth != "" { prefix := httputil.AuthorizationTypePomeriumIDPAccessToken + " " if strings.HasPrefix(strings.ToLower(auth), strings.ToLower(prefix)) { - return strings.TrimPrefix(auth, prefix), true + return auth[len(prefix):], true } prefix = "Bearer " + httputil.AuthorizationTypePomeriumIDPAccessToken + "-" if strings.HasPrefix(strings.ToLower(auth), strings.ToLower(prefix)) { - return strings.TrimPrefix(auth, prefix), true + return auth[len(prefix):], true } prefix = "Bearer " if strings.HasPrefix(strings.ToLower(auth), strings.ToLower(prefix)) && bearerTokenFormat == BearerTokenFormatIDPAccessToken { - return strings.TrimPrefix(auth, prefix), true + return auth[len(prefix):], true } } @@ -396,18 +396,18 @@ func (cfg *Config) GetIncomingIDPIdentityTokenForPolicy(policy *Policy, r *http. if auth := r.Header.Get(httputil.HeaderAuthorization); auth != "" { prefix := httputil.AuthorizationTypePomeriumIDPIdentityToken + " " if strings.HasPrefix(strings.ToLower(auth), strings.ToLower(prefix)) { - return strings.TrimPrefix(auth, prefix), true + return auth[len(prefix):], true } prefix = "Bearer " + httputil.AuthorizationTypePomeriumIDPIdentityToken + "-" if strings.HasPrefix(strings.ToLower(auth), strings.ToLower(prefix)) { - return strings.TrimPrefix(auth, prefix), true + return auth[len(prefix):], true } prefix = "Bearer " if strings.HasPrefix(strings.ToLower(auth), strings.ToLower(prefix)) && bearerTokenFormat == BearerTokenFormatIDPIdentityToken { - return strings.TrimPrefix(auth, prefix), true + return auth[len(prefix):], true } } diff --git a/config/session_test.go b/config/session_test.go index 8ee9b5fe6..2e23968aa 100644 --- a/config/session_test.go +++ b/config/session_test.go @@ -179,3 +179,86 @@ func Test_getTokenSessionID(t *testing.T) { Response: &DirectResponse{Status: 204}, }, "TOKEN")) } + +func TestGetIncomingIDPIdentityTokenForPolicy(t *testing.T) { + t.Parallel() + + bearerTokenFormatIDPIdentityToken := BearerTokenFormatIDPIdentityToken + + for _, tc := range []struct { + name string + globalBearerTokenFormat *BearerTokenFormat + routeBearerTokenFormat *BearerTokenFormat + headers http.Header + expectedOK bool + expectedToken string + }{ + { + name: "empty headers", + expectedOK: false, + }, + { + name: "custom header", + headers: http.Header{"X-Pomerium-Idp-Identity-Token": {"identity token via custom header"}}, + expectedOK: true, + expectedToken: "identity token via custom header", + }, + { + name: "custom authorization", + headers: http.Header{"Authorization": {"Pomerium-Idp-Identity-Token identity token via custom authorization"}}, + expectedOK: true, + expectedToken: "identity token via custom authorization", + }, + { + name: "custom bearer", + headers: http.Header{"Authorization": {"Bearer Pomerium-Idp-Identity-Token-identity token via custom bearer"}}, + expectedOK: true, + expectedToken: "identity token via custom bearer", + }, + { + name: "bearer disabled", + headers: http.Header{"Authorization": {"Bearer identity token via bearer"}}, + expectedOK: false, + }, + { + name: "bearer enabled via options", + globalBearerTokenFormat: &bearerTokenFormatIDPIdentityToken, + headers: http.Header{"Authorization": {"Bearer identity token via bearer"}}, + expectedOK: true, + expectedToken: "identity token via bearer", + }, + { + name: "bearer enabled via route", + routeBearerTokenFormat: &bearerTokenFormatIDPIdentityToken, + headers: http.Header{"Authorization": {"Bearer identity token via bearer"}}, + expectedOK: true, + expectedToken: "identity token via bearer", + }, + } { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + cfg := &Config{ + Options: NewDefaultOptions(), + } + cfg.Options.BearerTokenFormat = tc.globalBearerTokenFormat + + var route *Policy + if tc.routeBearerTokenFormat != nil { + route = &Policy{ + BearerTokenFormat: tc.routeBearerTokenFormat, + } + } + + r, err := http.NewRequest(http.MethodGet, "https://example.com", nil) + require.NoError(t, err) + if tc.headers != nil { + r.Header = tc.headers + } + + actualToken, actualOK := cfg.GetIncomingIDPIdentityTokenForPolicy(route, r) + assert.Equal(t, tc.expectedOK, actualOK) + assert.Equal(t, tc.expectedToken, actualToken) + }) + } +}