mirror of
https://github.com/pomerium/pomerium.git
synced 2025-05-25 23:17:18 +02:00
mcp: handle and pass upstream oauth2 tokens (#5595)
This commit is contained in:
parent
561b6040b5
commit
9d66f762e1
14 changed files with 337 additions and 80 deletions
|
@ -1,14 +1,19 @@
|
|||
package mcp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"path"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"golang.org/x/oauth2"
|
||||
"google.golang.org/protobuf/proto"
|
||||
|
||||
"github.com/pomerium/pomerium/config"
|
||||
oauth21proto "github.com/pomerium/pomerium/internal/oauth21/gen"
|
||||
)
|
||||
|
||||
type OAuth2Configs struct {
|
||||
|
@ -31,6 +36,26 @@ func NewOAuthConfig(
|
|||
}
|
||||
}
|
||||
|
||||
func (r *OAuth2Configs) CodeExchangeForHost(
|
||||
ctx context.Context,
|
||||
host string,
|
||||
code string,
|
||||
) (*oauth2.Token, error) {
|
||||
r.buildOnce.Do(r.build)
|
||||
cfg, ok := r.perHost[host]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no oauth2 config for host %s", host)
|
||||
}
|
||||
|
||||
return cfg.Exchange(ctx, code)
|
||||
}
|
||||
|
||||
func (r *OAuth2Configs) HasConfigForHost(host string) bool {
|
||||
r.buildOnce.Do(r.build)
|
||||
_, ok := r.perHost[host]
|
||||
return ok
|
||||
}
|
||||
|
||||
func (r *OAuth2Configs) GetLoginURLForHost(host string, state string) (string, bool) {
|
||||
r.buildOnce.Do(r.build)
|
||||
|
||||
|
@ -91,3 +116,25 @@ func authStyleEnum(o config.OAuth2EndpointAuthStyle) oauth2.AuthStyle {
|
|||
return oauth2.AuthStyleAutoDetect
|
||||
}
|
||||
}
|
||||
|
||||
func OAuth2TokenToPB(src *oauth2.Token) *oauth21proto.TokenResponse {
|
||||
return &oauth21proto.TokenResponse{
|
||||
AccessToken: src.AccessToken,
|
||||
TokenType: src.TokenType,
|
||||
RefreshToken: proto.String(src.RefreshToken),
|
||||
ExpiresIn: proto.Int64(src.ExpiresIn),
|
||||
}
|
||||
}
|
||||
|
||||
func PBToOAuth2Token(src *oauth21proto.TokenResponse, now time.Time) oauth2.Token {
|
||||
token := oauth2.Token{
|
||||
AccessToken: src.GetAccessToken(),
|
||||
TokenType: src.GetTokenType(),
|
||||
ExpiresIn: src.GetExpiresIn(),
|
||||
RefreshToken: src.GetRefreshToken(),
|
||||
}
|
||||
if token.ExpiresIn > 0 {
|
||||
token.Expiry = now.Add(time.Duration(token.ExpiresIn) * time.Second)
|
||||
}
|
||||
return token
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue