From 7e3db1a39fbac577fd8be2a106bfec50b99202c7 Mon Sep 17 00:00:00 2001 From: Caleb Doxsey Date: Mon, 17 Feb 2025 16:29:06 -0700 Subject: [PATCH] make session ids route-specific --- config/session.go | 36 +++++++++++++++++++++++++++++------- config/session_test.go | 15 +++++++++++++++ 2 files changed, 44 insertions(+), 7 deletions(-) diff --git a/config/session.go b/config/session.go index fef0040e1..574e7b464 100644 --- a/config/session.go +++ b/config/session.go @@ -2,6 +2,7 @@ package config import ( "context" + "encoding/binary" "fmt" "net/http" "strings" @@ -129,11 +130,6 @@ func (store *SessionStore) SaveSession(w http.ResponseWriter, r *http.Request, v return store.store.SaveSession(w, r, v) } -var ( - accessTokenUUIDNamespace = uuid.MustParse("0194f6f8-e760-76a0-8917-e28ac927a34d") - identityTokenUUIDNamespace = uuid.MustParse("0194f6f9-aec0-704e-bb4a-51054f17ad17") -) - type IncomingIDPTokenSessionCreator interface { CreateSession(ctx context.Context, cfg *Config, policy *Policy, r *http.Request) (*session.Session, error) } @@ -176,7 +172,7 @@ func (c *incomingIDPTokenSessionCreator) createSessionAccessToken( policy *Policy, rawAccessToken string, ) (*session.Session, error) { - sessionID := uuid.NewSHA1(accessTokenUUIDNamespace, []byte(rawAccessToken)).String() + sessionID := getAccessTokenSessionID(policy, rawAccessToken) s, err := c.getSession(ctx, sessionID) if err == nil { return s, nil @@ -225,7 +221,7 @@ func (c *incomingIDPTokenSessionCreator) createSessionForIdentityToken( policy *Policy, rawIdentityToken string, ) (*session.Session, error) { - sessionID := uuid.NewSHA1(identityTokenUUIDNamespace, []byte(rawIdentityToken)).String() + sessionID := getIdentityTokenSessionID(policy, rawIdentityToken) s, err := c.getSession(ctx, sessionID) if err == nil { return s, nil @@ -417,3 +413,29 @@ func (cfg *Config) GetIncomingIDPIdentityTokenForPolicy(policy *Policy, r *http. return "", false } + +var accessTokenUUIDNamespace = uuid.MustParse("0194f6f8-e760-76a0-8917-e28ac927a34d") + +func getAccessTokenSessionID(policy *Policy, rawAccessToken string) string { + namespace := accessTokenUUIDNamespace + // make the session ID per-route + if policy != nil { + var data [8]byte + binary.BigEndian.PutUint64(data[:], policy.MustRouteID()) + namespace = uuid.NewSHA1(namespace, data[:]) + } + return uuid.NewSHA1(namespace, []byte(rawAccessToken)).String() +} + +var identityTokenUUIDNamespace = uuid.MustParse("0194f6f9-aec0-704e-bb4a-51054f17ad17") + +func getIdentityTokenSessionID(policy *Policy, rawIdentityToken string) string { + namespace := identityTokenUUIDNamespace + // make the session ID per-route + if policy != nil { + var data [8]byte + binary.BigEndian.PutUint64(data[:], policy.MustRouteID()) + namespace = uuid.NewSHA1(namespace, data[:]) + } + return uuid.NewSHA1(namespace, []byte(rawIdentityToken)).String() +} diff --git a/config/session_test.go b/config/session_test.go index 936b45c7c..3204b1fdc 100644 --- a/config/session_test.go +++ b/config/session_test.go @@ -164,3 +164,18 @@ func TestGetIdentityProviderDetectsChangesToAuthenticateServiceURL(t *testing.T) assert.NotEqual(t, idp1.GetId(), idp2.GetId(), "identity provider should change when authenticate service url changes") } + +func Test_getAccessTokenSessionID(t *testing.T) { + t.Parallel() + + assert.Equal(t, "532b0a3d-b413-50a0-8c9f-e6eb340a05d3", getAccessTokenSessionID(nil, "TOKEN")) + assert.Equal(t, "e0b8096c-54dd-5623-8098-5488f9c302db", getIdentityTokenSessionID(nil, "TOKEN")) + assert.Equal(t, "c58990ec-85d4-5054-b27f-e7c5d9c602c5", getAccessTokenSessionID(&Policy{ + From: "https://from.example.com", + Response: &DirectResponse{Status: 204}, + }, "TOKEN")) + assert.Equal(t, "4dff4540-493b-502a-bdec-2f346e6e480d", getIdentityTokenSessionID(&Policy{ + From: "https://from.example.com", + Response: &DirectResponse{Status: 204}, + }, "TOKEN")) +}