mirror of
https://github.com/pomerium/pomerium.git
synced 2025-04-30 10:56:28 +02:00
singleflight incoming idp token session creation (#5491)
This commit is contained in:
parent
4b95eda51e
commit
f15400493d
1 changed files with 75 additions and 61 deletions
|
@ -8,6 +8,7 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
|
"golang.org/x/sync/singleflight"
|
||||||
"google.golang.org/protobuf/types/known/timestamppb"
|
"google.golang.org/protobuf/types/known/timestamppb"
|
||||||
|
|
||||||
"github.com/pomerium/pomerium/internal/encoding"
|
"github.com/pomerium/pomerium/internal/encoding"
|
||||||
|
@ -135,9 +136,10 @@ type IncomingIDPTokenSessionCreator interface {
|
||||||
}
|
}
|
||||||
|
|
||||||
type incomingIDPTokenSessionCreator struct {
|
type incomingIDPTokenSessionCreator struct {
|
||||||
timeNow func() time.Time
|
timeNow func() time.Time
|
||||||
getRecord func(ctx context.Context, recordType, recordID string) (*databroker.Record, error)
|
getRecord func(ctx context.Context, recordType, recordID string) (*databroker.Record, error)
|
||||||
putRecords func(ctx context.Context, records []*databroker.Record) error
|
putRecords func(ctx context.Context, records []*databroker.Record) error
|
||||||
|
singleflight singleflight.Group
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewIncomingIDPTokenSessionCreator(
|
func NewIncomingIDPTokenSessionCreator(
|
||||||
|
@ -179,41 +181,47 @@ func (c *incomingIDPTokenSessionCreator) createSessionAccessToken(
|
||||||
}
|
}
|
||||||
|
|
||||||
sessionID := getAccessTokenSessionID(idp, rawAccessToken)
|
sessionID := getAccessTokenSessionID(idp, rawAccessToken)
|
||||||
s, err := c.getSession(ctx, sessionID)
|
res, err, _ := c.singleflight.Do(sessionID, func() (any, error) {
|
||||||
if err == nil {
|
s, err := c.getSession(ctx, sessionID)
|
||||||
|
if err == nil {
|
||||||
|
return s, nil
|
||||||
|
} else if !storage.IsNotFound(err) {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
authenticateURL, transport, err := cfg.resolveAuthenticateURL()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("error resolving authenticate url to verify access token: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
res, err := authenticateapi.New(authenticateURL, transport).VerifyAccessToken(ctx, &authenticateapi.VerifyAccessTokenRequest{
|
||||||
|
AccessToken: rawAccessToken,
|
||||||
|
IdentityProviderID: idp.GetId(),
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("error verifying access token: %w", err)
|
||||||
|
} else if !res.Valid {
|
||||||
|
return nil, fmt.Errorf("invalid access token")
|
||||||
|
}
|
||||||
|
|
||||||
|
s = c.newSessionFromIDPClaims(cfg, sessionID, res.Claims)
|
||||||
|
s.OauthToken = &session.OAuthToken{
|
||||||
|
TokenType: "Bearer",
|
||||||
|
AccessToken: rawAccessToken,
|
||||||
|
ExpiresAt: s.ExpiresAt,
|
||||||
|
}
|
||||||
|
u := c.newUserFromIDPClaims(res.Claims)
|
||||||
|
err = c.putSessionAndUser(ctx, s, u)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("error saving session and user: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
return s, nil
|
return s, nil
|
||||||
} else if !storage.IsNotFound(err) {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
authenticateURL, transport, err := cfg.resolveAuthenticateURL()
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("error resolving authenticate url to verify access token: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
res, err := authenticateapi.New(authenticateURL, transport).VerifyAccessToken(ctx, &authenticateapi.VerifyAccessTokenRequest{
|
|
||||||
AccessToken: rawAccessToken,
|
|
||||||
IdentityProviderID: idp.GetId(),
|
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("error verifying access token: %w", err)
|
return nil, err
|
||||||
} else if !res.Valid {
|
|
||||||
return nil, fmt.Errorf("invalid access token")
|
|
||||||
}
|
}
|
||||||
|
return res.(*session.Session), nil
|
||||||
s = c.newSessionFromIDPClaims(cfg, sessionID, res.Claims)
|
|
||||||
s.OauthToken = &session.OAuthToken{
|
|
||||||
TokenType: "Bearer",
|
|
||||||
AccessToken: rawAccessToken,
|
|
||||||
ExpiresAt: s.ExpiresAt,
|
|
||||||
}
|
|
||||||
u := c.newUserFromIDPClaims(res.Claims)
|
|
||||||
err = c.putSessionAndUser(ctx, s, u)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("error saving session and user: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return s, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *incomingIDPTokenSessionCreator) createSessionForIdentityToken(
|
func (c *incomingIDPTokenSessionCreator) createSessionForIdentityToken(
|
||||||
|
@ -228,37 +236,43 @@ func (c *incomingIDPTokenSessionCreator) createSessionForIdentityToken(
|
||||||
}
|
}
|
||||||
|
|
||||||
sessionID := getIdentityTokenSessionID(idp, rawIdentityToken)
|
sessionID := getIdentityTokenSessionID(idp, rawIdentityToken)
|
||||||
s, err := c.getSession(ctx, sessionID)
|
res, err, _ := c.singleflight.Do(sessionID, func() (any, error) {
|
||||||
if err == nil {
|
s, err := c.getSession(ctx, sessionID)
|
||||||
|
if err == nil {
|
||||||
|
return s, nil
|
||||||
|
} else if !storage.IsNotFound(err) {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
authenticateURL, transport, err := cfg.resolveAuthenticateURL()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("error resolving authenticate url to verify identity token: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
res, err := authenticateapi.New(authenticateURL, transport).VerifyIdentityToken(ctx, &authenticateapi.VerifyIdentityTokenRequest{
|
||||||
|
IdentityToken: rawIdentityToken,
|
||||||
|
IdentityProviderID: idp.GetId(),
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("error verifying identity token: %w", err)
|
||||||
|
} else if !res.Valid {
|
||||||
|
return nil, fmt.Errorf("invalid identity token")
|
||||||
|
}
|
||||||
|
|
||||||
|
s = c.newSessionFromIDPClaims(cfg, sessionID, res.Claims)
|
||||||
|
s.SetRawIDToken(rawIdentityToken)
|
||||||
|
u := c.newUserFromIDPClaims(res.Claims)
|
||||||
|
err = c.putSessionAndUser(ctx, s, u)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("error saving session and user: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
return s, nil
|
return s, nil
|
||||||
} else if !storage.IsNotFound(err) {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
authenticateURL, transport, err := cfg.resolveAuthenticateURL()
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("error resolving authenticate url to verify identity token: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
res, err := authenticateapi.New(authenticateURL, transport).VerifyIdentityToken(ctx, &authenticateapi.VerifyIdentityTokenRequest{
|
|
||||||
IdentityToken: rawIdentityToken,
|
|
||||||
IdentityProviderID: idp.GetId(),
|
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("error verifying identity token: %w", err)
|
return nil, err
|
||||||
} else if !res.Valid {
|
|
||||||
return nil, fmt.Errorf("invalid identity token")
|
|
||||||
}
|
}
|
||||||
|
return res.(*session.Session), nil
|
||||||
s = c.newSessionFromIDPClaims(cfg, sessionID, res.Claims)
|
|
||||||
s.SetRawIDToken(rawIdentityToken)
|
|
||||||
u := c.newUserFromIDPClaims(res.Claims)
|
|
||||||
err = c.putSessionAndUser(ctx, s, u)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("error saving session and user: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return s, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *incomingIDPTokenSessionCreator) newSessionFromIDPClaims(
|
func (c *incomingIDPTokenSessionCreator) newSessionFromIDPClaims(
|
||||||
|
|
Loading…
Add table
Reference in a new issue