pomerium/internal/mcp/token.go
Denis Mishin f6ddb8878d
mcp: if upstream oauth does not return a refresh token, keep previous (#5738)
## Summary

Upstream OAuth2 providers may not return a refresh token with every
access token renewal response. This PR ensures we do not accidentally
overwrite the refresh token we already hold with an empty string.

## Related issues

Fix
https://linear.app/pomerium/issue/ENG-2619/mcp-upstream-oauth2-google-drive-did-not-return-refresh-token

## User Explanation

<!-- How would you explain this change to the user? If this
change doesn't create any user-facing changes, you can leave
this blank. If filled out, add the `docs` label -->

## Checklist

- [x] reference any related issues
- [ ] updated unit tests
- [ ] add appropriate label (`enhancement`, `bug`, `breaking`,
`dependencies`, `ci`)
- [ ] ready for review
2025-07-21 21:10:32 -04:00

87 lines
2.6 KiB
Go

package mcp
import (
"context"
"fmt"
"time"
"github.com/pomerium/pomerium/internal/oauth21"
)
// CheckPKCE verifies that codeVerifier matches codeChallenge under the given
// PKCE code challenge method.
//
// An empty method is treated the same as "plain". "S256" is also supported;
// any other method is rejected with an error. A nil return means the verifier
// matched the challenge.
func CheckPKCE(
	codeChallengeMethod string,
	codeChallenge string,
	codeVerifier string,
) error {
	switch codeChallengeMethod {
	case "", "plain":
		if oauth21.VerifyPKCEPlain(codeVerifier, codeChallenge) {
			return nil
		}
		return fmt.Errorf("plain: code verifier does not match code challenge")
	case "S256":
		if oauth21.VerifyPKCES256(codeVerifier, codeChallenge) {
			return nil
		}
		return fmt.Errorf("S256: code verifier does not match code challenge")
	default:
		return fmt.Errorf("unsupported code challenge method: %s", codeChallengeMethod)
	}
}
// GetAccessTokenForSession builds an access token for sessionID by creating a
// CodeTypeAccess code via CreateCode with the handler's cipher, passing
// sessionExpiresAt along (presumably as the token's expiry — confirm against
// CreateCode).
func (srv *Handler) GetAccessTokenForSession(sessionID string, sessionExpiresAt time.Time) (string, error) {
	accessToken, err := CreateCode(CodeTypeAccess, sessionID, sessionExpiresAt, "", srv.cipher)
	return accessToken, err
}
// GetSessionIDFromAccessToken decrypts an access token and returns the session
// ID it carries.
//
// It is the counterpart of GetAccessTokenForSession: it decrypts with the same
// code type (CodeTypeAccess) and the same handler cipher. Any error from
// DecryptCode is returned unchanged. The current time is passed to DecryptCode,
// presumably so that expired tokens are rejected (confirm against DecryptCode).
func (srv *Handler) GetSessionIDFromAccessToken(accessToken string) (string, error) {
	code, err := DecryptCode(CodeTypeAccess, accessToken, srv.cipher, "", time.Now())
	if err != nil {
		return "", err
	}
	return code.Id, nil
}
// GetUpstreamOAuth2Token returns an upstream OAuth2 access token for the given
// host and user ID.
//
// The stored token is loaded from storage and passed through the host's OAuth2
// token source, which refreshes it if it is no longer valid. If the access
// token or refresh token changed, the updated token is written back to storage.
// Upstream providers do not always return a refresh token on renewal. When the
// renewed token carries none, the previously stored refresh token is kept
// rather than overwritten with an empty string.
//
// Concurrent calls for the same (host, userID) pair share a single fetch via
// srv.hostsSingleFlight.
func (srv *Handler) GetUpstreamOAuth2Token(
	ctx context.Context,
	host string,
	userID string,
) (string, error) {
	// The flight key must identify both the host and the user. The work below
	// loads and returns a per-user token, so keying by host alone would hand one
	// user's access token to every other user concurrently requesting the same
	// host. A NUL byte cannot appear in a hostname, so it unambiguously splits
	// the key back into its two parts.
	flightKey := host + "\x00" + userID

	// NOTE(review): assuming hostsSingleFlight is an x/sync/singleflight.Group,
	// the fetch runs with the ctx of whichever caller started the flight. If that
	// ctx is cancelled, every caller sharing the flight receives the resulting
	// error.
	result, err, _ := srv.hostsSingleFlight.Do(flightKey, func() (any, error) {
		tokenPB, err := srv.storage.GetUpstreamOAuth2Token(ctx, host, userID)
		if err != nil {
			return "", fmt.Errorf("failed to get upstream oauth2 token: %w", err)
		}

		cfg, ok := srv.hosts.GetOAuth2ConfigForHost(host)
		if !ok {
			return "", fmt.Errorf("no OAuth2 config found for host %s", host)
		}

		current, err := cfg.TokenSource(ctx, PBToOAuth2Token(tokenPB)).Token()
		if err != nil {
			return "", fmt.Errorf("failed to get OAuth2 token: %w", err)
		}

		// Providers may omit the refresh token on renewal; keep the one we have.
		if current.RefreshToken == "" {
			current.RefreshToken = tokenPB.GetRefreshToken()
		}

		// Only hit storage when something actually changed.
		if current.AccessToken != tokenPB.GetAccessToken() ||
			current.RefreshToken != tokenPB.GetRefreshToken() {
			if err := srv.storage.StoreUpstreamOAuth2Token(ctx, host, userID, OAuth2TokenToPB(current)); err != nil {
				return "", fmt.Errorf("failed to store updated upstream oauth2 token: %w", err)
			}
		}
		return current.AccessToken, nil
	})
	if err != nil {
		return "", err
	}
	return result.(string), nil
}