mirror of
https://github.com/pomerium/pomerium.git
synced 2025-07-30 06:51:30 +02:00
## Summary

Upstream OAuth2 providers may not return a refresh token with every access token renewal response. This PR ensures we do not accidentally overwrite the refresh token we already hold with an empty string.

## Related issues

Fix https://linear.app/pomerium/issue/ENG-2619/mcp-upstream-oauth2-google-drive-did-not-return-refresh-token

## User Explanation

<!-- How would you explain this change to the user? If this change doesn't create any user-facing changes, you can leave this blank. If filled out, add the `docs` label -->

## Checklist

- [x] reference any related issues
- [ ] updated unit tests
- [ ] add appropriate label (`enhancement`, `bug`, `breaking`, `dependencies`, `ci`)
- [ ] ready for review
87 lines
2.6 KiB
Go
87 lines
2.6 KiB
Go
package mcp
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"time"
|
|
|
|
"github.com/pomerium/pomerium/internal/oauth21"
|
|
)
|
|
|
|
func CheckPKCE(
|
|
codeChallengeMethod string,
|
|
codeChallenge string,
|
|
codeVerifier string,
|
|
) error {
|
|
if codeChallengeMethod == "" || codeChallengeMethod == "plain" {
|
|
if !oauth21.VerifyPKCEPlain(codeVerifier, codeChallenge) {
|
|
return fmt.Errorf("plain: code verifier does not match code challenge")
|
|
}
|
|
} else if codeChallengeMethod == "S256" {
|
|
if !oauth21.VerifyPKCES256(codeVerifier, codeChallenge) {
|
|
return fmt.Errorf("S256: code verifier does not match code challenge")
|
|
}
|
|
} else {
|
|
return fmt.Errorf("unsupported code challenge method: %s", codeChallengeMethod)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// GetAccessTokenForSession returns an access token for a given session and expiration time.
|
|
func (srv *Handler) GetAccessTokenForSession(sessionID string, sessionExpiresAt time.Time) (string, error) {
|
|
return CreateCode(CodeTypeAccess, sessionID, sessionExpiresAt, "", srv.cipher)
|
|
}
|
|
|
|
// DecryptAuthorizationCode decrypts the authorization code and returns the underlying session ID
|
|
func (srv *Handler) GetSessionIDFromAccessToken(accessToken string) (string, error) {
|
|
code, err := DecryptCode(CodeTypeAccess, accessToken, srv.cipher, "", time.Now())
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
return code.Id, nil
|
|
}
|
|
|
|
// GetUpstreamOAuth2Token retrieves the OAuth2 token for a given host and user ID.
|
|
// it also checks if the token is still valid and refreshes it if necessary.
|
|
func (srv *Handler) GetUpstreamOAuth2Token(
|
|
ctx context.Context,
|
|
host string,
|
|
userID string,
|
|
) (string, error) {
|
|
token, err, _ := srv.hostsSingleFlight.Do(host, func() (any, error) {
|
|
tokenPB, err := srv.storage.GetUpstreamOAuth2Token(ctx, host, userID)
|
|
if err != nil {
|
|
return "", fmt.Errorf("failed to get upstream oauth2 token: %w", err)
|
|
}
|
|
|
|
cfg, ok := srv.hosts.GetOAuth2ConfigForHost(host)
|
|
if !ok {
|
|
return "", fmt.Errorf("no OAuth2 config found for host %s", host)
|
|
}
|
|
|
|
token, err := cfg.TokenSource(ctx, PBToOAuth2Token(tokenPB)).Token()
|
|
if err != nil {
|
|
return "", fmt.Errorf("failed to get OAuth2 token: %w", err)
|
|
}
|
|
|
|
if token.RefreshToken == "" {
|
|
token.RefreshToken = tokenPB.GetRefreshToken()
|
|
}
|
|
|
|
if token.AccessToken != tokenPB.GetAccessToken() ||
|
|
token.RefreshToken != tokenPB.GetRefreshToken() {
|
|
err = srv.storage.StoreUpstreamOAuth2Token(ctx, host, userID, OAuth2TokenToPB(token))
|
|
if err != nil {
|
|
return "", fmt.Errorf("failed to store updated upstream oauth2 token: %w", err)
|
|
}
|
|
}
|
|
|
|
return token.AccessToken, nil
|
|
})
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
return token.(string), nil
|
|
}
|