From 244d8a92602a284a1a69f70609cc60aefa028dd7 Mon Sep 17 00:00:00 2001 From: Caleb Doxsey Date: Tue, 18 Feb 2025 08:24:11 -0700 Subject: [PATCH] make the session id per-idp --- config/session.go | 46 +++++++++++++++++++----------------------- config/session_test.go | 11 +++------- 2 files changed, 24 insertions(+), 33 deletions(-) diff --git a/config/session.go b/config/session.go index 7e6637212..aec2997e7 100644 --- a/config/session.go +++ b/config/session.go @@ -2,7 +2,6 @@ package config import ( "context" - "encoding/binary" "fmt" "net/http" "strings" @@ -22,6 +21,7 @@ import ( "github.com/pomerium/pomerium/internal/urlutil" "github.com/pomerium/pomerium/pkg/authenticateapi" "github.com/pomerium/pomerium/pkg/grpc/databroker" + identitypb "github.com/pomerium/pomerium/pkg/grpc/identity" "github.com/pomerium/pomerium/pkg/grpc/session" "github.com/pomerium/pomerium/pkg/grpc/user" "github.com/pomerium/pomerium/pkg/grpcutil" @@ -173,7 +173,12 @@ func (c *incomingIDPTokenSessionCreator) createSessionAccessToken( policy *Policy, rawAccessToken string, ) (*session.Session, error) { - sessionID := getAccessTokenSessionID(policy, rawAccessToken) + idp, err := cfg.Options.GetIdentityProviderForPolicy(policy) + if err != nil { + return nil, fmt.Errorf("error getting identity provider to verify access token: %w", err) + } + + sessionID := getAccessTokenSessionID(idp, rawAccessToken) s, err := c.getSession(ctx, sessionID) if err == nil { return s, nil @@ -181,11 +186,6 @@ func (c *incomingIDPTokenSessionCreator) createSessionAccessToken( return nil, err } - idp, err := cfg.Options.GetIdentityProviderForPolicy(policy) - if err != nil { - return nil, fmt.Errorf("error getting identity provider to verify access token: %w", err) - } - authenticateURL, transport, err := cfg.resolveAuthenticateURL() if err != nil { return nil, fmt.Errorf("error resolving authenticate url to verify access token: %w", err) @@ -222,7 +222,12 @@ func (c *incomingIDPTokenSessionCreator) createSessionForIdentityToken( policy *Policy, rawIdentityToken string, ) (*session.Session, error) { - sessionID := getIdentityTokenSessionID(policy, rawIdentityToken) + idp, err := cfg.Options.GetIdentityProviderForPolicy(policy) + if err != nil { + return nil, fmt.Errorf("error getting identity provider to verify identity token: %w", err) + } + + sessionID := getIdentityTokenSessionID(idp, rawIdentityToken) s, err := c.getSession(ctx, sessionID) if err == nil { return s, nil @@ -230,11 +235,6 @@ func (c *incomingIDPTokenSessionCreator) createSessionForIdentityToken( return nil, err } - idp, err := cfg.Options.GetIdentityProviderForPolicy(policy) - if err != nil { - return nil, fmt.Errorf("error getting identity provider to verify identity token: %w", err) - } - authenticateURL, transport, err := cfg.resolveAuthenticateURL() if err != nil { return nil, fmt.Errorf("error resolving authenticate url to verify identity token: %w", err) @@ -417,26 +417,22 @@ func (cfg *Config) GetIncomingIDPIdentityTokenForPolicy(policy *Policy, r *http. var accessTokenUUIDNamespace = uuid.MustParse("0194f6f8-e760-76a0-8917-e28ac927a34d") -func getAccessTokenSessionID(policy *Policy, rawAccessToken string) string { +func getAccessTokenSessionID(idp *identitypb.Provider, rawAccessToken string) string { namespace := accessTokenUUIDNamespace - // make the session ID per-route - if policy != nil { - var data [8]byte - binary.BigEndian.PutUint64(data[:], policy.MustRouteID()) - namespace = uuid.NewSHA1(namespace, data[:]) + // make the session ID per-idp settings + if idp != nil { + namespace = uuid.NewSHA1(namespace, []byte(idp.GetId())) } return uuid.NewSHA1(namespace, []byte(rawAccessToken)).String() } var identityTokenUUIDNamespace = uuid.MustParse("0194f6f9-aec0-704e-bb4a-51054f17ad17") -func getIdentityTokenSessionID(policy *Policy, rawIdentityToken string) string { +func getIdentityTokenSessionID(idp *identitypb.Provider, rawIdentityToken string) string { namespace := identityTokenUUIDNamespace - // make the session ID per-route - if policy != nil { - var data [8]byte - binary.BigEndian.PutUint64(data[:], policy.MustRouteID()) - namespace = uuid.NewSHA1(namespace, data[:]) + // make the session ID per-idp settings + if idp != nil { + namespace = uuid.NewSHA1(namespace, []byte(idp.GetId())) } return uuid.NewSHA1(namespace, []byte(rawIdentityToken)).String() } diff --git a/config/session_test.go b/config/session_test.go index 1880b5cdd..6ff7fc7f0 100644 --- a/config/session_test.go +++ b/config/session_test.go @@ -19,6 +19,7 @@ import ( "github.com/pomerium/pomerium/internal/testutil" "github.com/pomerium/pomerium/internal/urlutil" "github.com/pomerium/pomerium/pkg/cryptutil" + identitypb "github.com/pomerium/pomerium/pkg/grpc/identity" "github.com/pomerium/pomerium/pkg/grpc/session" "github.com/pomerium/pomerium/pkg/grpc/user" "github.com/pomerium/pomerium/pkg/identity" @@ -177,14 +178,8 @@ func Test_getTokenSessionID(t *testing.T) { assert.Equal(t, "532b0a3d-b413-50a0-8c9f-e6eb340a05d3", getAccessTokenSessionID(nil, "TOKEN")) assert.Equal(t, "e0b8096c-54dd-5623-8098-5488f9c302db", getIdentityTokenSessionID(nil, "TOKEN")) - assert.Equal(t, "c58990ec-85d4-5054-b27f-e7c5d9c602c5", getAccessTokenSessionID(&Policy{ - From: "https://from.example.com", - Response: &DirectResponse{Status: 204}, - }, "TOKEN")) - assert.Equal(t, "4dff4540-493b-502a-bdec-2f346e6e480d", getIdentityTokenSessionID(&Policy{ - From: "https://from.example.com", - Response: &DirectResponse{Status: 204}, - }, "TOKEN")) + assert.Equal(t, "9c99d1d0-805e-51cb-b808-772ab654268b", getAccessTokenSessionID(&identitypb.Provider{Id: "IDP1"}, "TOKEN")) + assert.Equal(t, "0fe0e289-40bb-5ffe-b328-e290e043a652", getIdentityTokenSessionID(&identitypb.Provider{Id: "IDP1"}, "TOKEN")) } func TestGetIncomingIDPAccessTokenForPolicy(t *testing.T) {