diff --git a/config/session.go b/config/session.go index 6f3b36ccb..f9a494005 100644 --- a/config/session.go +++ b/config/session.go @@ -69,14 +69,16 @@ func (store *SessionStore) LoadSessionState(r *http.Request) (*sessions.State, e } // confirm that the identity provider id matches the state - idp, err := store.options.GetIdentityProviderForRequestURL(urlutil.GetAbsoluteURL(r).String()) - if err != nil { - return nil, err - } + if state.IdentityProviderID != "" { + idp, err := store.options.GetIdentityProviderForRequestURL(urlutil.GetAbsoluteURL(r).String()) + if err != nil { + return nil, err + } - if idp.GetId() != state.IdentityProviderID { - return nil, fmt.Errorf("unexpected session state identity provider id: %s != %s", - idp.GetId(), state.IdentityProviderID) + if idp.GetId() != state.IdentityProviderID { + return nil, fmt.Errorf("unexpected session state identity provider id: %s != %s", + idp.GetId(), state.IdentityProviderID) + } } return &state, nil diff --git a/config/session_test.go b/config/session_test.go index 6a1471b26..058850b86 100644 --- a/config/session_test.go +++ b/config/session_test.go @@ -125,4 +125,20 @@ func TestSessionStore_LoadSessionState(t *testing.T) { assert.Error(t, err) assert.Nil(t, s) }) + t.Run("blank idp", func(t *testing.T) { + rawJWS := makeJWS(t, &sessions.State{ + Issuer: "authenticate.example.com", + ID: "example", + }) + + r, err := http.NewRequest(http.MethodGet, "https://p2.example.com", nil) + require.NoError(t, err) + r.Header.Set(httputil.HeaderPomeriumAuthorization, rawJWS) + s, err := store.LoadSessionState(r) + assert.NoError(t, err) + assert.Empty(t, cmp.Diff(&sessions.State{ + Issuer: "authenticate.example.com", + ID: "example", + }, s)) + }) }