singleflight incoming idp token session creation (#5491)

This commit is contained in:
Caleb Doxsey 2025-02-24 08:24:57 -07:00 committed by GitHub
parent 4b95eda51e
commit f15400493d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -8,6 +8,7 @@ import (
"time" "time"
"github.com/google/uuid" "github.com/google/uuid"
"golang.org/x/sync/singleflight"
"google.golang.org/protobuf/types/known/timestamppb" "google.golang.org/protobuf/types/known/timestamppb"
"github.com/pomerium/pomerium/internal/encoding" "github.com/pomerium/pomerium/internal/encoding"
@ -135,9 +136,10 @@ type IncomingIDPTokenSessionCreator interface {
} }
type incomingIDPTokenSessionCreator struct { type incomingIDPTokenSessionCreator struct {
timeNow func() time.Time timeNow func() time.Time
getRecord func(ctx context.Context, recordType, recordID string) (*databroker.Record, error) getRecord func(ctx context.Context, recordType, recordID string) (*databroker.Record, error)
putRecords func(ctx context.Context, records []*databroker.Record) error putRecords func(ctx context.Context, records []*databroker.Record) error
singleflight singleflight.Group
} }
func NewIncomingIDPTokenSessionCreator( func NewIncomingIDPTokenSessionCreator(
@ -179,41 +181,47 @@ func (c *incomingIDPTokenSessionCreator) createSessionAccessToken(
} }
sessionID := getAccessTokenSessionID(idp, rawAccessToken) sessionID := getAccessTokenSessionID(idp, rawAccessToken)
s, err := c.getSession(ctx, sessionID) res, err, _ := c.singleflight.Do(sessionID, func() (any, error) {
if err == nil { s, err := c.getSession(ctx, sessionID)
if err == nil {
return s, nil
} else if !storage.IsNotFound(err) {
return nil, err
}
authenticateURL, transport, err := cfg.resolveAuthenticateURL()
if err != nil {
return nil, fmt.Errorf("error resolving authenticate url to verify access token: %w", err)
}
res, err := authenticateapi.New(authenticateURL, transport).VerifyAccessToken(ctx, &authenticateapi.VerifyAccessTokenRequest{
AccessToken: rawAccessToken,
IdentityProviderID: idp.GetId(),
})
if err != nil {
return nil, fmt.Errorf("error verifying access token: %w", err)
} else if !res.Valid {
return nil, fmt.Errorf("invalid access token")
}
s = c.newSessionFromIDPClaims(cfg, sessionID, res.Claims)
s.OauthToken = &session.OAuthToken{
TokenType: "Bearer",
AccessToken: rawAccessToken,
ExpiresAt: s.ExpiresAt,
}
u := c.newUserFromIDPClaims(res.Claims)
err = c.putSessionAndUser(ctx, s, u)
if err != nil {
return nil, fmt.Errorf("error saving session and user: %w", err)
}
return s, nil return s, nil
} else if !storage.IsNotFound(err) {
return nil, err
}
authenticateURL, transport, err := cfg.resolveAuthenticateURL()
if err != nil {
return nil, fmt.Errorf("error resolving authenticate url to verify access token: %w", err)
}
res, err := authenticateapi.New(authenticateURL, transport).VerifyAccessToken(ctx, &authenticateapi.VerifyAccessTokenRequest{
AccessToken: rawAccessToken,
IdentityProviderID: idp.GetId(),
}) })
if err != nil { if err != nil {
return nil, fmt.Errorf("error verifying access token: %w", err) return nil, err
} else if !res.Valid {
return nil, fmt.Errorf("invalid access token")
} }
return res.(*session.Session), nil
s = c.newSessionFromIDPClaims(cfg, sessionID, res.Claims)
s.OauthToken = &session.OAuthToken{
TokenType: "Bearer",
AccessToken: rawAccessToken,
ExpiresAt: s.ExpiresAt,
}
u := c.newUserFromIDPClaims(res.Claims)
err = c.putSessionAndUser(ctx, s, u)
if err != nil {
return nil, fmt.Errorf("error saving session and user: %w", err)
}
return s, nil
} }
func (c *incomingIDPTokenSessionCreator) createSessionForIdentityToken( func (c *incomingIDPTokenSessionCreator) createSessionForIdentityToken(
@ -228,37 +236,43 @@ func (c *incomingIDPTokenSessionCreator) createSessionForIdentityToken(
} }
sessionID := getIdentityTokenSessionID(idp, rawIdentityToken) sessionID := getIdentityTokenSessionID(idp, rawIdentityToken)
s, err := c.getSession(ctx, sessionID) res, err, _ := c.singleflight.Do(sessionID, func() (any, error) {
if err == nil { s, err := c.getSession(ctx, sessionID)
if err == nil {
return s, nil
} else if !storage.IsNotFound(err) {
return nil, err
}
authenticateURL, transport, err := cfg.resolveAuthenticateURL()
if err != nil {
return nil, fmt.Errorf("error resolving authenticate url to verify identity token: %w", err)
}
res, err := authenticateapi.New(authenticateURL, transport).VerifyIdentityToken(ctx, &authenticateapi.VerifyIdentityTokenRequest{
IdentityToken: rawIdentityToken,
IdentityProviderID: idp.GetId(),
})
if err != nil {
return nil, fmt.Errorf("error verifying identity token: %w", err)
} else if !res.Valid {
return nil, fmt.Errorf("invalid identity token")
}
s = c.newSessionFromIDPClaims(cfg, sessionID, res.Claims)
s.SetRawIDToken(rawIdentityToken)
u := c.newUserFromIDPClaims(res.Claims)
err = c.putSessionAndUser(ctx, s, u)
if err != nil {
return nil, fmt.Errorf("error saving session and user: %w", err)
}
return s, nil return s, nil
} else if !storage.IsNotFound(err) {
return nil, err
}
authenticateURL, transport, err := cfg.resolveAuthenticateURL()
if err != nil {
return nil, fmt.Errorf("error resolving authenticate url to verify identity token: %w", err)
}
res, err := authenticateapi.New(authenticateURL, transport).VerifyIdentityToken(ctx, &authenticateapi.VerifyIdentityTokenRequest{
IdentityToken: rawIdentityToken,
IdentityProviderID: idp.GetId(),
}) })
if err != nil { if err != nil {
return nil, fmt.Errorf("error verifying identity token: %w", err) return nil, err
} else if !res.Valid {
return nil, fmt.Errorf("invalid identity token")
} }
return res.(*session.Session), nil
s = c.newSessionFromIDPClaims(cfg, sessionID, res.Claims)
s.SetRawIDToken(rawIdentityToken)
u := c.newUserFromIDPClaims(res.Claims)
err = c.putSessionAndUser(ctx, s, u)
if err != nil {
return nil, fmt.Errorf("error saving session and user: %w", err)
}
return s, nil
} }
func (c *incomingIDPTokenSessionCreator) newSessionFromIDPClaims( func (c *incomingIDPTokenSessionCreator) newSessionFromIDPClaims(