singleflight incoming idp token session creation (#5491)

This commit is contained in:
Caleb Doxsey 2025-02-24 08:24:57 -07:00 committed by GitHub
parent 4b95eda51e
commit f15400493d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -8,6 +8,7 @@ import (
"time"
"github.com/google/uuid"
"golang.org/x/sync/singleflight"
"google.golang.org/protobuf/types/known/timestamppb"
"github.com/pomerium/pomerium/internal/encoding"
@ -135,9 +136,10 @@ type IncomingIDPTokenSessionCreator interface {
}
type incomingIDPTokenSessionCreator struct {
timeNow func() time.Time
getRecord func(ctx context.Context, recordType, recordID string) (*databroker.Record, error)
putRecords func(ctx context.Context, records []*databroker.Record) error
timeNow func() time.Time
getRecord func(ctx context.Context, recordType, recordID string) (*databroker.Record, error)
putRecords func(ctx context.Context, records []*databroker.Record) error
singleflight singleflight.Group
}
func NewIncomingIDPTokenSessionCreator(
@ -179,41 +181,47 @@ func (c *incomingIDPTokenSessionCreator) createSessionAccessToken(
}
sessionID := getAccessTokenSessionID(idp, rawAccessToken)
s, err := c.getSession(ctx, sessionID)
if err == nil {
res, err, _ := c.singleflight.Do(sessionID, func() (any, error) {
s, err := c.getSession(ctx, sessionID)
if err == nil {
return s, nil
} else if !storage.IsNotFound(err) {
return nil, err
}
authenticateURL, transport, err := cfg.resolveAuthenticateURL()
if err != nil {
return nil, fmt.Errorf("error resolving authenticate url to verify access token: %w", err)
}
res, err := authenticateapi.New(authenticateURL, transport).VerifyAccessToken(ctx, &authenticateapi.VerifyAccessTokenRequest{
AccessToken: rawAccessToken,
IdentityProviderID: idp.GetId(),
})
if err != nil {
return nil, fmt.Errorf("error verifying access token: %w", err)
} else if !res.Valid {
return nil, fmt.Errorf("invalid access token")
}
s = c.newSessionFromIDPClaims(cfg, sessionID, res.Claims)
s.OauthToken = &session.OAuthToken{
TokenType: "Bearer",
AccessToken: rawAccessToken,
ExpiresAt: s.ExpiresAt,
}
u := c.newUserFromIDPClaims(res.Claims)
err = c.putSessionAndUser(ctx, s, u)
if err != nil {
return nil, fmt.Errorf("error saving session and user: %w", err)
}
return s, nil
} else if !storage.IsNotFound(err) {
return nil, err
}
authenticateURL, transport, err := cfg.resolveAuthenticateURL()
if err != nil {
return nil, fmt.Errorf("error resolving authenticate url to verify access token: %w", err)
}
res, err := authenticateapi.New(authenticateURL, transport).VerifyAccessToken(ctx, &authenticateapi.VerifyAccessTokenRequest{
AccessToken: rawAccessToken,
IdentityProviderID: idp.GetId(),
})
if err != nil {
return nil, fmt.Errorf("error verifying access token: %w", err)
} else if !res.Valid {
return nil, fmt.Errorf("invalid access token")
return nil, err
}
s = c.newSessionFromIDPClaims(cfg, sessionID, res.Claims)
s.OauthToken = &session.OAuthToken{
TokenType: "Bearer",
AccessToken: rawAccessToken,
ExpiresAt: s.ExpiresAt,
}
u := c.newUserFromIDPClaims(res.Claims)
err = c.putSessionAndUser(ctx, s, u)
if err != nil {
return nil, fmt.Errorf("error saving session and user: %w", err)
}
return s, nil
return res.(*session.Session), nil
}
func (c *incomingIDPTokenSessionCreator) createSessionForIdentityToken(
@ -228,37 +236,43 @@ func (c *incomingIDPTokenSessionCreator) createSessionForIdentityToken(
}
sessionID := getIdentityTokenSessionID(idp, rawIdentityToken)
s, err := c.getSession(ctx, sessionID)
if err == nil {
res, err, _ := c.singleflight.Do(sessionID, func() (any, error) {
s, err := c.getSession(ctx, sessionID)
if err == nil {
return s, nil
} else if !storage.IsNotFound(err) {
return nil, err
}
authenticateURL, transport, err := cfg.resolveAuthenticateURL()
if err != nil {
return nil, fmt.Errorf("error resolving authenticate url to verify identity token: %w", err)
}
res, err := authenticateapi.New(authenticateURL, transport).VerifyIdentityToken(ctx, &authenticateapi.VerifyIdentityTokenRequest{
IdentityToken: rawIdentityToken,
IdentityProviderID: idp.GetId(),
})
if err != nil {
return nil, fmt.Errorf("error verifying identity token: %w", err)
} else if !res.Valid {
return nil, fmt.Errorf("invalid identity token")
}
s = c.newSessionFromIDPClaims(cfg, sessionID, res.Claims)
s.SetRawIDToken(rawIdentityToken)
u := c.newUserFromIDPClaims(res.Claims)
err = c.putSessionAndUser(ctx, s, u)
if err != nil {
return nil, fmt.Errorf("error saving session and user: %w", err)
}
return s, nil
} else if !storage.IsNotFound(err) {
return nil, err
}
authenticateURL, transport, err := cfg.resolveAuthenticateURL()
if err != nil {
return nil, fmt.Errorf("error resolving authenticate url to verify identity token: %w", err)
}
res, err := authenticateapi.New(authenticateURL, transport).VerifyIdentityToken(ctx, &authenticateapi.VerifyIdentityTokenRequest{
IdentityToken: rawIdentityToken,
IdentityProviderID: idp.GetId(),
})
if err != nil {
return nil, fmt.Errorf("error verifying identity token: %w", err)
} else if !res.Valid {
return nil, fmt.Errorf("invalid identity token")
return nil, err
}
s = c.newSessionFromIDPClaims(cfg, sessionID, res.Claims)
s.SetRawIDToken(rawIdentityToken)
u := c.newUserFromIDPClaims(res.Claims)
err = c.putSessionAndUser(ctx, s, u)
if err != nil {
return nil, fmt.Errorf("error saving session and user: %w", err)
}
return s, nil
return res.(*session.Session), nil
}
func (c *incomingIDPTokenSessionCreator) newSessionFromIDPClaims(