diff --git a/config/session.go b/config/session.go index aec2997e7..439a72cfe 100644 --- a/config/session.go +++ b/config/session.go @@ -8,6 +8,7 @@ import ( "time" "github.com/google/uuid" + "golang.org/x/sync/singleflight" "google.golang.org/protobuf/types/known/timestamppb" "github.com/pomerium/pomerium/internal/encoding" @@ -135,9 +136,10 @@ type IncomingIDPTokenSessionCreator interface { } type incomingIDPTokenSessionCreator struct { - timeNow func() time.Time - getRecord func(ctx context.Context, recordType, recordID string) (*databroker.Record, error) - putRecords func(ctx context.Context, records []*databroker.Record) error + timeNow func() time.Time + getRecord func(ctx context.Context, recordType, recordID string) (*databroker.Record, error) + putRecords func(ctx context.Context, records []*databroker.Record) error + singleflight singleflight.Group } func NewIncomingIDPTokenSessionCreator( @@ -179,41 +181,47 @@ func (c *incomingIDPTokenSessionCreator) createSessionAccessToken( } sessionID := getAccessTokenSessionID(idp, rawAccessToken) - s, err := c.getSession(ctx, sessionID) - if err == nil { + res, err, _ := c.singleflight.Do(sessionID, func() (any, error) { + s, err := c.getSession(ctx, sessionID) + if err == nil { + return s, nil + } else if !storage.IsNotFound(err) { + return nil, err + } + + authenticateURL, transport, err := cfg.resolveAuthenticateURL() + if err != nil { + return nil, fmt.Errorf("error resolving authenticate url to verify access token: %w", err) + } + + res, err := authenticateapi.New(authenticateURL, transport).VerifyAccessToken(ctx, &authenticateapi.VerifyAccessTokenRequest{ + AccessToken: rawAccessToken, + IdentityProviderID: idp.GetId(), + }) + if err != nil { + return nil, fmt.Errorf("error verifying access token: %w", err) + } else if !res.Valid { + return nil, fmt.Errorf("invalid access token") + } + + s = c.newSessionFromIDPClaims(cfg, sessionID, res.Claims) + s.OauthToken = &session.OAuthToken{ + TokenType: "Bearer", + AccessToken: rawAccessToken, + ExpiresAt: s.ExpiresAt, + } + u := c.newUserFromIDPClaims(res.Claims) + err = c.putSessionAndUser(ctx, s, u) + if err != nil { + return nil, fmt.Errorf("error saving session and user: %w", err) + } + return s, nil - } else if !storage.IsNotFound(err) { - return nil, err - } - - authenticateURL, transport, err := cfg.resolveAuthenticateURL() - if err != nil { - return nil, fmt.Errorf("error resolving authenticate url to verify access token: %w", err) - } - - res, err := authenticateapi.New(authenticateURL, transport).VerifyAccessToken(ctx, &authenticateapi.VerifyAccessTokenRequest{ - AccessToken: rawAccessToken, - IdentityProviderID: idp.GetId(), }) if err != nil { - return nil, fmt.Errorf("error verifying access token: %w", err) - } else if !res.Valid { - return nil, fmt.Errorf("invalid access token") + return nil, err } - - s = c.newSessionFromIDPClaims(cfg, sessionID, res.Claims) - s.OauthToken = &session.OAuthToken{ - TokenType: "Bearer", - AccessToken: rawAccessToken, - ExpiresAt: s.ExpiresAt, - } - u := c.newUserFromIDPClaims(res.Claims) - err = c.putSessionAndUser(ctx, s, u) - if err != nil { - return nil, fmt.Errorf("error saving session and user: %w", err) - } - - return s, nil + return res.(*session.Session), nil } func (c *incomingIDPTokenSessionCreator) createSessionForIdentityToken( @@ -228,37 +236,43 @@ func (c *incomingIDPTokenSessionCreator) createSessionForIdentityToken( } sessionID := getIdentityTokenSessionID(idp, rawIdentityToken) - s, err := c.getSession(ctx, sessionID) - if err == nil { + res, err, _ := c.singleflight.Do(sessionID, func() (any, error) { + s, err := c.getSession(ctx, sessionID) + if err == nil { + return s, nil + } else if !storage.IsNotFound(err) { + return nil, err + } + + authenticateURL, transport, err := cfg.resolveAuthenticateURL() + if err != nil { + return nil, fmt.Errorf("error resolving authenticate url to verify identity token: %w", err) + } + + res, err := authenticateapi.New(authenticateURL, transport).VerifyIdentityToken(ctx, &authenticateapi.VerifyIdentityTokenRequest{ + IdentityToken: rawIdentityToken, + IdentityProviderID: idp.GetId(), + }) + if err != nil { + return nil, fmt.Errorf("error verifying identity token: %w", err) + } else if !res.Valid { + return nil, fmt.Errorf("invalid identity token") + } + + s = c.newSessionFromIDPClaims(cfg, sessionID, res.Claims) + s.SetRawIDToken(rawIdentityToken) + u := c.newUserFromIDPClaims(res.Claims) + err = c.putSessionAndUser(ctx, s, u) + if err != nil { + return nil, fmt.Errorf("error saving session and user: %w", err) + } + return s, nil - } else if !storage.IsNotFound(err) { - return nil, err - } - - authenticateURL, transport, err := cfg.resolveAuthenticateURL() - if err != nil { - return nil, fmt.Errorf("error resolving authenticate url to verify identity token: %w", err) - } - - res, err := authenticateapi.New(authenticateURL, transport).VerifyIdentityToken(ctx, &authenticateapi.VerifyIdentityTokenRequest{ - IdentityToken: rawIdentityToken, - IdentityProviderID: idp.GetId(), }) if err != nil { - return nil, fmt.Errorf("error verifying identity token: %w", err) - } else if !res.Valid { - return nil, fmt.Errorf("invalid identity token") + return nil, err } - - s = c.newSessionFromIDPClaims(cfg, sessionID, res.Claims) - s.SetRawIDToken(rawIdentityToken) - u := c.newUserFromIDPClaims(res.Claims) - err = c.putSessionAndUser(ctx, s, u) - if err != nil { - return nil, fmt.Errorf("error saving session and user: %w", err) - } - - return s, nil + return res.(*session.Session), nil } func (c *incomingIDPTokenSessionCreator) newSessionFromIDPClaims(