package config import ( "encoding/base64" "net/http" "net/url" "testing" "github.com/google/go-cmp/cmp" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/pomerium/pomerium/internal/encoding/jws" "github.com/pomerium/pomerium/internal/httputil" "github.com/pomerium/pomerium/internal/sessions" "github.com/pomerium/pomerium/internal/urlutil" "github.com/pomerium/pomerium/pkg/cryptutil" ) func TestSessionStore_LoadSessionState(t *testing.T) { t.Parallel() sharedKey := cryptutil.NewKey() options := NewDefaultOptions() options.SharedKey = base64.StdEncoding.EncodeToString(sharedKey) options.Provider = "oidc" options.ProviderURL = "https://oidc.example.com" options.ClientID = "client_id" options.ClientSecret = "client_secret" options.Policies = append(options.Policies, Policy{ From: "https://p1.example.com", To: mustParseWeightedURLs(t, "https://p1"), IDPClientID: "client_id_1", IDPClientSecret: "client_secret_1", }, Policy{ From: "https://p2.example.com", To: mustParseWeightedURLs(t, "https://p2"), IDPClientID: "client_id_2", IDPClientSecret: "client_secret_2", }) require.NoError(t, options.Validate()) store, err := NewSessionStore(options) require.NoError(t, err) idp1, err := options.GetIdentityProviderForPolicy(nil) require.NoError(t, err) require.NotNil(t, idp1) idp2, err := options.GetIdentityProviderForPolicy(&options.Policies[0]) require.NoError(t, err) require.NotNil(t, idp2) idp3, err := options.GetIdentityProviderForPolicy(&options.Policies[1]) require.NoError(t, err) require.NotNil(t, idp3) makeJWS := func(t *testing.T, state *sessions.State) string { e, err := jws.NewHS256Signer(sharedKey) require.NoError(t, err) rawJWS, err := e.Marshal(state) require.NoError(t, err) return string(rawJWS) } t.Run("mssing", func(t *testing.T) { r, err := http.NewRequest(http.MethodGet, "https://p1.example.com", nil) require.NoError(t, err) s, err := store.LoadSessionStateAndCheckIDP(r) assert.ErrorIs(t, err, sessions.ErrNoSessionFound) assert.Nil(t, s) }) t.Run("query", func(t *testing.T) { rawJWS := makeJWS(t, &sessions.State{ Issuer: "authenticate.example.com", ID: "example", IdentityProviderID: idp2.GetId(), }) r, err := http.NewRequest(http.MethodGet, "https://p1.example.com?"+url.Values{ urlutil.QuerySession: {rawJWS}, }.Encode(), nil) require.NoError(t, err) s, err := store.LoadSessionStateAndCheckIDP(r) assert.NoError(t, err) assert.Empty(t, cmp.Diff(&sessions.State{ Issuer: "authenticate.example.com", ID: "example", IdentityProviderID: idp2.GetId(), }, s)) }) t.Run("header", func(t *testing.T) { rawJWS := makeJWS(t, &sessions.State{ Issuer: "authenticate.example.com", ID: "example", IdentityProviderID: idp3.GetId(), }) r, err := http.NewRequest(http.MethodGet, "https://p2.example.com", nil) require.NoError(t, err) r.Header.Set(httputil.HeaderPomeriumAuthorization, rawJWS) s, err := store.LoadSessionStateAndCheckIDP(r) assert.NoError(t, err) assert.Empty(t, cmp.Diff(&sessions.State{ Issuer: "authenticate.example.com", ID: "example", IdentityProviderID: idp3.GetId(), }, s)) }) t.Run("wrong idp", func(t *testing.T) { rawJWS := makeJWS(t, &sessions.State{ Issuer: "authenticate.example.com", ID: "example", IdentityProviderID: idp1.GetId(), }) r, err := http.NewRequest(http.MethodGet, "https://p2.example.com", nil) require.NoError(t, err) r.Header.Set(httputil.HeaderPomeriumAuthorization, rawJWS) s, err := store.LoadSessionStateAndCheckIDP(r) assert.Error(t, err) assert.Nil(t, s) }) t.Run("blank idp", func(t *testing.T) { rawJWS := makeJWS(t, &sessions.State{ Issuer: "authenticate.example.com", ID: "example", }) r, err := http.NewRequest(http.MethodGet, "https://p2.example.com", nil) require.NoError(t, err) r.Header.Set(httputil.HeaderPomeriumAuthorization, rawJWS) s, err := store.LoadSessionStateAndCheckIDP(r) assert.NoError(t, err) assert.Empty(t, cmp.Diff(&sessions.State{ Issuer: "authenticate.example.com", ID: "example", }, s)) }) } func TestGetIdentityProviderDetectsChangesToAuthenticateServiceURL(t *testing.T) { t.Parallel() options := NewDefaultOptions() options.AuthenticateURLString = "https://authenticate.example.com" options.Provider = "oidc" options.ProviderURL = "https://oidc.example.com" options.ClientID = "client_id" options.ClientSecret = "client_secret" idp1, err := options.GetIdentityProviderForPolicy(nil) require.NoError(t, err) options.AuthenticateURLString = "" idp2, err := options.GetIdentityProviderForPolicy(nil) require.NoError(t, err) assert.NotEqual(t, idp1.GetId(), idp2.GetId(), "identity provider should change when authenticate service url changes") }