mirror of
https://github.com/pomerium/pomerium.git
synced 2025-05-24 14:37:12 +02:00
mcp: handle and pass upstream oauth2 tokens (#5595)
This commit is contained in:
parent
561b6040b5
commit
9d66f762e1
14 changed files with 337 additions and 80 deletions
|
@ -1,8 +1,7 @@
|
|||
package evaluator
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/pomerium/pomerium/authorize/internal/store"
|
||||
"github.com/pomerium/pomerium/config"
|
||||
"github.com/pomerium/pomerium/internal/hashutil"
|
||||
)
|
||||
|
@ -19,7 +18,7 @@ type evaluatorConfig struct {
|
|||
JWTClaimsHeaders config.JWTClaimHeaders
|
||||
JWTGroupsFilter config.JWTGroupsFilter
|
||||
DefaultJWTIssuerFormat config.JWTIssuerFormat
|
||||
MCPAccessTokenProvider func(string, time.Time) (string, error) `hash:"-"`
|
||||
MCPAccessTokenProvider store.MCPAccessTokenProvider `hash:"-"`
|
||||
}
|
||||
|
||||
// cacheKey() returns a hash over the configuration, except for the policies.
|
||||
|
@ -118,7 +117,7 @@ func WithDefaultJWTIssuerFormat(format config.JWTIssuerFormat) Option {
|
|||
}
|
||||
|
||||
// WithMCPAccessTokenProvider sets the MCP access token in the config.
|
||||
func WithMCPAccessTokenProvider(fn func(sessionID string, expires time.Time) (string, error)) Option {
|
||||
func WithMCPAccessTokenProvider(fn store.MCPAccessTokenProvider) Option {
|
||||
return func(cfg *evaluatorConfig) {
|
||||
cfg.MCPAccessTokenProvider = fn
|
||||
}
|
||||
|
|
|
@ -93,27 +93,40 @@ func (e *headersEvaluatorEvaluation) fillJWTClaimHeaders(ctx context.Context) er
|
|||
return nil
|
||||
}
|
||||
|
||||
func (e *headersEvaluatorEvaluation) fillMCPHeaders() error {
|
||||
if e.request == nil ||
|
||||
e.request.Policy == nil ||
|
||||
e.request.Policy.MCP == nil ||
|
||||
!e.request.Policy.MCP.PassUpstreamAccessToken {
|
||||
func (e *headersEvaluatorEvaluation) fillMCPHeaders(ctx context.Context) (err error) {
|
||||
if e.request == nil || e.request.Policy == nil || e.request.Policy.MCP == nil || e.request.Session.ID == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
if e.request.Session.ID == "" {
|
||||
p := e.evaluator.store.GetMCPAccessTokenProvider()
|
||||
if p == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
accessToken, err := e.evaluator.store.GetMCPAccessTokenProvider()(e.request.Session.ID, e.now.Add(time.Hour))
|
||||
var accessToken string
|
||||
if e.request.Policy.MCP.IsUpstreamClientNeedsAccessToken() {
|
||||
accessToken, err = p.GetAccessTokenForSession(e.request.Session.ID, time.Now().Add(5*time.Minute))
|
||||
if err != nil {
|
||||
return fmt.Errorf("authorize/header-evaluator: error getting MCP access token: %w", err)
|
||||
}
|
||||
|
||||
e.response.Headers.Set("Authorization", "Bearer "+accessToken)
|
||||
return nil
|
||||
}
|
||||
|
||||
if e.request.Policy.MCP.HasUpstreamOAuth2() {
|
||||
user := e.getUser(ctx)
|
||||
accessToken, err = p.GetUpstreamOAuth2Token(ctx, e.request.HTTP.Host, user.Id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("authorize/header-evaluator: error getting upstream oauth2 token: %w", err)
|
||||
}
|
||||
e.response.Headers.Set("Authorization", "Bearer "+accessToken)
|
||||
return nil
|
||||
}
|
||||
|
||||
e.response.Headers.Del("Authorization")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (e *headersEvaluatorEvaluation) fillKubernetesHeaders(ctx context.Context) {
|
||||
if e.request.Policy == nil {
|
||||
return
|
||||
|
@ -198,7 +211,7 @@ func (e *headersEvaluatorEvaluation) fillHeaders(ctx context.Context) error {
|
|||
if err := e.fillJWTClaimHeaders(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
err := e.fillMCPHeaders()
|
||||
err := e.fillMCPHeaders(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
|
@ -37,7 +37,17 @@ type Store struct {
|
|||
defaultJWTIssuerFormat atomic.Pointer[config.JWTIssuerFormat]
|
||||
signingKey atomic.Pointer[jose.JSONWebKey]
|
||||
|
||||
mcpAccessTokenProvider atomicutil.Value[func(string, time.Time) (string, error)]
|
||||
mcpAccessTokenProvider atomicutil.Value[MCPAccessTokenProvider]
|
||||
}
|
||||
|
||||
type MCPAccessTokenProvider interface {
|
||||
// GetAccessToken returns an access token for the given session ID and expiration time,
|
||||
// that may be upsed by the MCP client to interact with the MCP servers fronted by Pomerium.
|
||||
GetAccessTokenForSession(sessionID string, expiresAt time.Time) (string, error)
|
||||
|
||||
// GetUpstreamOAuth2Token returns an upstream OAuth2 token for the given host and session ID
|
||||
// that is used by the MCP server to interact with the upstream APIs.
|
||||
GetUpstreamOAuth2Token(ctx context.Context, host, sessionID string) (string, error)
|
||||
}
|
||||
|
||||
// New creates a new Store.
|
||||
|
@ -81,13 +91,8 @@ func (s *Store) GetSigningKey() *jose.JSONWebKey {
|
|||
return s.signingKey.Load()
|
||||
}
|
||||
|
||||
func (s *Store) GetMCPAccessTokenProvider() func(string, time.Time) (string, error) {
|
||||
if f := s.mcpAccessTokenProvider.Load(); f != nil {
|
||||
return f
|
||||
}
|
||||
return func(string, time.Time) (string, error) {
|
||||
return "", fmt.Errorf("no mcp access token provider")
|
||||
}
|
||||
func (s *Store) GetMCPAccessTokenProvider() MCPAccessTokenProvider {
|
||||
return s.mcpAccessTokenProvider.Load()
|
||||
}
|
||||
|
||||
// UpdateGoogleCloudServerlessAuthenticationServiceAccount updates the google cloud serverless authentication
|
||||
|
@ -127,10 +132,12 @@ func (s *Store) UpdateSigningKey(signingKey *jose.JSONWebKey) {
|
|||
s.signingKey.Store(signingKey)
|
||||
}
|
||||
|
||||
func (s *Store) UpdateMCPAccessTokenProvider(mcpAccessTokenProvider func(string, time.Time) (string, error)) {
|
||||
func (s *Store) UpdateMCPAccessTokenProvider(mcpAccessTokenProvider MCPAccessTokenProvider) {
|
||||
// This isn't used by the Rego code, so we don't need to write it to the opastorage.Store instance.
|
||||
if mcpAccessTokenProvider != nil {
|
||||
s.mcpAccessTokenProvider.Store(mcpAccessTokenProvider)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Store) write(rawPath string, value any) {
|
||||
ctx := context.TODO()
|
||||
|
|
|
@ -65,7 +65,7 @@ func newAuthorizeStateFromConfig(
|
|||
}
|
||||
state.mcp = mcp
|
||||
|
||||
state.evaluator, err = newPolicyEvaluator(ctx, cfg.Options, store, previousEvaluator, evaluator.WithMCPAccessTokenProvider(mcp.CreateAccessTokenForSession))
|
||||
state.evaluator, err = newPolicyEvaluator(ctx, cfg.Options, store, previousEvaluator, evaluator.WithMCPAccessTokenProvider(mcp))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("authorize: failed to update policy with options: %w", err)
|
||||
}
|
||||
|
|
|
@ -215,6 +215,16 @@ type MCP struct {
|
|||
PassUpstreamAccessToken bool `mapstructure:"pass_upstream_access_token" yaml:"pass_upstream_access_token,omitempty" json:"pass_upstream_access_token,omitempty"`
|
||||
}
|
||||
|
||||
// HasUpstreamOAuth2 checks if the route is for the MCP Server and if it has an upstream OAuth2 configuration
|
||||
func (p *MCP) HasUpstreamOAuth2() bool {
|
||||
return p != nil && p.UpstreamOAuth2 != nil
|
||||
}
|
||||
|
||||
// IsUpstreamClientNeedsAccessToken checks if the route is for the MCP Client and if it needs to pass the upstream access token
|
||||
func (p *MCP) IsUpstreamClientNeedsAccessToken() bool {
|
||||
return p != nil && p.UpstreamOAuth2 != nil && p.PassUpstreamAccessToken
|
||||
}
|
||||
|
||||
type UpstreamOAuth2 struct {
|
||||
ClientID string `mapstructure:"client_id" yaml:"client_id,omitempty" json:"client_id,omitempty"`
|
||||
ClientSecret string `mapstructure:"client_secret" yaml:"client_secret,omitempty" json:"client_secret,omitempty"`
|
||||
|
|
|
@ -8,7 +8,9 @@ import (
|
|||
"net/url"
|
||||
"time"
|
||||
|
||||
"github.com/bufbuild/protovalidate-go"
|
||||
"github.com/go-jose/go-jose/v3/jwt"
|
||||
"golang.org/x/sync/errgroup"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
|
||||
|
@ -27,19 +29,38 @@ func (srv *Handler) Authorize(w http.ResponseWriter, r *http.Request) {
|
|||
|
||||
ctx := r.Context()
|
||||
|
||||
sessionID, err := getSessionFromRequest(r)
|
||||
claims, err := getClaimsFromRequest(r)
|
||||
if err != nil {
|
||||
log.Ctx(ctx).Error().Err(err).Msg("failed to get claims from request")
|
||||
http.Error(w, "invalid request", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
sessionID, ok := getSessionIDFromClaims(claims)
|
||||
if !ok {
|
||||
log.Ctx(ctx).Error().Err(err).Msg("session is not present, this is a misconfigured request")
|
||||
http.Error(w, "internal server error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
userID, ok := getUserIDFromClaims(claims)
|
||||
if !ok {
|
||||
log.Ctx(ctx).Error().Err(err).Msg("user id is not present, this is a misconfigured request")
|
||||
http.Error(w, "internal server error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
v, err := oauth21.ParseCodeGrantAuthorizeRequest(r, sessionID)
|
||||
v, err := oauth21.ParseCodeGrantAuthorizeRequest(r)
|
||||
if err != nil {
|
||||
log.Ctx(ctx).Error().Err(err).Msg("failed to parse authorization request")
|
||||
oauth21.ErrorResponse(w, http.StatusBadRequest, oauth21.InvalidRequest)
|
||||
return
|
||||
}
|
||||
v.UserId = userID
|
||||
v.SessionId = sessionID
|
||||
if err := protovalidate.Validate(v); err != nil {
|
||||
log.Ctx(ctx).Error().Err(err).Msg("failed to validate authorization request")
|
||||
oauth21.ErrorResponse(w, http.StatusBadRequest, oauth21.InvalidRequest)
|
||||
return
|
||||
}
|
||||
|
||||
client, err := srv.storage.GetClient(ctx, v.ClientId)
|
||||
if err != nil && status.Code(err) == codes.NotFound {
|
||||
|
@ -61,19 +82,51 @@ func (srv *Handler) Authorize(w http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
id, err := srv.storage.CreateAuthorizationRequest(ctx, v)
|
||||
requiresUpstreamOAuth2Token := srv.relyingParties.HasConfigForHost(r.Host)
|
||||
var authReqID string
|
||||
var hasUpstreamOAuth2Token bool
|
||||
{
|
||||
eg, ctx := errgroup.WithContext(ctx)
|
||||
eg.Go(func() error {
|
||||
var err error
|
||||
authReqID, err = srv.storage.CreateAuthorizationRequest(ctx, v)
|
||||
if err != nil {
|
||||
log.Ctx(ctx).Error().Err(err).Msg("failed to store authorization request")
|
||||
http.Error(w, "cannot create authorization request", http.StatusInternalServerError)
|
||||
return fmt.Errorf("failed to create authorization request: %w", err)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
eg.Go(func() error {
|
||||
if !requiresUpstreamOAuth2Token {
|
||||
return nil
|
||||
}
|
||||
|
||||
var err error
|
||||
token, err := srv.GetUpstreamOAuth2Token(ctx, r.Host, userID)
|
||||
if err != nil && status.Code(err) != codes.NotFound {
|
||||
return fmt.Errorf("failed to get upstream oauth2 token: %w", err)
|
||||
}
|
||||
hasUpstreamOAuth2Token = token != ""
|
||||
return nil
|
||||
})
|
||||
|
||||
err := eg.Wait()
|
||||
if err != nil {
|
||||
log.Ctx(ctx).Error().Err(err).Msg("prepare for authorization redirect")
|
||||
http.Error(w, "internal error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if !requiresUpstreamOAuth2Token || hasUpstreamOAuth2Token {
|
||||
srv.AuthorizationResponse(ctx, w, r, authReqID, v)
|
||||
return
|
||||
}
|
||||
|
||||
loginURL, ok := srv.relyingParties.GetLoginURLForHost(r.Host, id)
|
||||
loginURL, ok := srv.relyingParties.GetLoginURLForHost(r.Host, authReqID)
|
||||
if ok {
|
||||
http.Redirect(w, r, loginURL, http.StatusFound)
|
||||
} else {
|
||||
srv.AuthorizationResponse(ctx, w, r, id, v)
|
||||
}
|
||||
log.Ctx(ctx).Error().Msg("authorize: must have login URL, this is a bug")
|
||||
}
|
||||
|
||||
// AuthorizationResponse generates the successful authorization response
|
||||
|
@ -111,22 +164,31 @@ func (srv *Handler) AuthorizationResponse(
|
|||
http.Redirect(w, r, to.String(), http.StatusFound)
|
||||
}
|
||||
|
||||
func getSessionFromRequest(r *http.Request) (string, error) {
|
||||
func getClaimsFromRequest(r *http.Request) (map[string]any, error) {
|
||||
h := r.Header.Get(httputil.HeaderPomeriumJWTAssertion)
|
||||
if h == "" {
|
||||
return "", fmt.Errorf("missing %s header", httputil.HeaderPomeriumJWTAssertion)
|
||||
return nil, fmt.Errorf("missing %s header", httputil.HeaderPomeriumJWTAssertion)
|
||||
}
|
||||
|
||||
token, err := jwt.ParseSigned(h)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to parse JWT: %w", err)
|
||||
return nil, fmt.Errorf("failed to parse JWT: %w", err)
|
||||
}
|
||||
var m map[string]any
|
||||
_ = token.UnsafeClaimsWithoutVerification(&m)
|
||||
sessionID, ok := m["sid"].(string)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("missing session ID in JWT")
|
||||
err = token.UnsafeClaimsWithoutVerification(&m)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse JWT claims: %w", err)
|
||||
}
|
||||
|
||||
return sessionID, nil
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func getSessionIDFromClaims(claims map[string]any) (string, bool) {
|
||||
sessionID, ok := claims["sid"].(string)
|
||||
return sessionID, ok
|
||||
}
|
||||
|
||||
func getUserIDFromClaims(claims map[string]any) (string, bool) {
|
||||
userID, ok := claims["sub"].(string)
|
||||
return userID, ok
|
||||
}
|
||||
|
|
|
@ -1,9 +1,63 @@
|
|||
package mcp
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"golang.org/x/oauth2"
|
||||
"golang.org/x/sync/errgroup"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/log"
|
||||
oauth21proto "github.com/pomerium/pomerium/internal/oauth21/gen"
|
||||
)
|
||||
|
||||
func (srv *Handler) OAuthCallback(w http.ResponseWriter, _ *http.Request) {
|
||||
w.WriteHeader(http.StatusNotImplemented)
|
||||
func (srv *Handler) OAuthCallback(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
|
||||
code := r.URL.Query().Get("code")
|
||||
authReqID := r.URL.Query().Get("state")
|
||||
if code == "" || authReqID == "" {
|
||||
http.Error(w, "Invalid callback request: missing code or state", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
var token *oauth2.Token
|
||||
var authReq *oauth21proto.AuthorizationRequest
|
||||
|
||||
{
|
||||
eg, ctx := errgroup.WithContext(ctx)
|
||||
eg.Go(func() error {
|
||||
var err error
|
||||
token, err = srv.relyingParties.CodeExchangeForHost(ctx, r.Host, code)
|
||||
if err != nil {
|
||||
return fmt.Errorf("oauth2: failed to exchange code: %w", err)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
eg.Go(func() error {
|
||||
var err error
|
||||
authReq, err = srv.storage.GetAuthorizationRequest(ctx, authReqID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get authorization request: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
err := eg.Wait()
|
||||
if err != nil {
|
||||
log.Ctx(ctx).Error().Err(err).Msg("failed to exchange code")
|
||||
http.Error(w, "Failed to exchange code", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
err := srv.storage.StoreUpstreamOAuth2Token(ctx, authReq.UserId, r.Host, OAuth2TokenToPB(token))
|
||||
if err != nil {
|
||||
log.Ctx(ctx).Error().Err(err).Msg("failed to store upstream oauth2 token")
|
||||
http.Error(w, "Failed to store upstream oauth2 token", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
srv.AuthorizationResponse(ctx, w, r, authReqID, authReq)
|
||||
}
|
||||
|
|
|
@ -87,7 +87,7 @@ func (srv *Handler) handleAuthorizationCodeToken(w http.ResponseWriter, r *http.
|
|||
return
|
||||
}
|
||||
|
||||
accessToken, err := srv.CreateAccessTokenForSession(session.Id, session.ExpiresAt.AsTime())
|
||||
accessToken, err := srv.GetAccessTokenForSession(session.Id, session.ExpiresAt.AsTime())
|
||||
if err != nil {
|
||||
http.Error(w, "internal error", http.StatusInternalServerError)
|
||||
return
|
||||
|
|
|
@ -1,14 +1,19 @@
|
|||
package mcp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"path"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"golang.org/x/oauth2"
|
||||
"google.golang.org/protobuf/proto"
|
||||
|
||||
"github.com/pomerium/pomerium/config"
|
||||
oauth21proto "github.com/pomerium/pomerium/internal/oauth21/gen"
|
||||
)
|
||||
|
||||
type OAuth2Configs struct {
|
||||
|
@ -31,6 +36,26 @@ func NewOAuthConfig(
|
|||
}
|
||||
}
|
||||
|
||||
func (r *OAuth2Configs) CodeExchangeForHost(
|
||||
ctx context.Context,
|
||||
host string,
|
||||
code string,
|
||||
) (*oauth2.Token, error) {
|
||||
r.buildOnce.Do(r.build)
|
||||
cfg, ok := r.perHost[host]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no oauth2 config for host %s", host)
|
||||
}
|
||||
|
||||
return cfg.Exchange(ctx, code)
|
||||
}
|
||||
|
||||
func (r *OAuth2Configs) HasConfigForHost(host string) bool {
|
||||
r.buildOnce.Do(r.build)
|
||||
_, ok := r.perHost[host]
|
||||
return ok
|
||||
}
|
||||
|
||||
func (r *OAuth2Configs) GetLoginURLForHost(host string, state string) (string, bool) {
|
||||
r.buildOnce.Do(r.build)
|
||||
|
||||
|
@ -91,3 +116,25 @@ func authStyleEnum(o config.OAuth2EndpointAuthStyle) oauth2.AuthStyle {
|
|||
return oauth2.AuthStyleAutoDetect
|
||||
}
|
||||
}
|
||||
|
||||
func OAuth2TokenToPB(src *oauth2.Token) *oauth21proto.TokenResponse {
|
||||
return &oauth21proto.TokenResponse{
|
||||
AccessToken: src.AccessToken,
|
||||
TokenType: src.TokenType,
|
||||
RefreshToken: proto.String(src.RefreshToken),
|
||||
ExpiresIn: proto.Int64(src.ExpiresIn),
|
||||
}
|
||||
}
|
||||
|
||||
func PBToOAuth2Token(src *oauth21proto.TokenResponse, now time.Time) oauth2.Token {
|
||||
token := oauth2.Token{
|
||||
AccessToken: src.GetAccessToken(),
|
||||
TokenType: src.GetTokenType(),
|
||||
ExpiresIn: src.GetExpiresIn(),
|
||||
RefreshToken: src.GetRefreshToken(),
|
||||
}
|
||||
if token.ExpiresIn > 0 {
|
||||
token.Expiry = now.Add(time.Duration(token.ExpiresIn) * time.Second)
|
||||
}
|
||||
return token
|
||||
}
|
||||
|
|
|
@ -145,3 +145,47 @@ func (storage *Storage) GetSession(ctx context.Context, id string) (*session.Ses
|
|||
|
||||
return v, nil
|
||||
}
|
||||
|
||||
// StoreUpstreamOAuth2Token stores the upstream OAuth2 token for a given session and a host
|
||||
func (storage *Storage) StoreUpstreamOAuth2Token(
|
||||
ctx context.Context,
|
||||
host string,
|
||||
userID string,
|
||||
token *oauth21proto.TokenResponse,
|
||||
) error {
|
||||
data := protoutil.NewAny(token)
|
||||
_, err := storage.client.Put(ctx, &databroker.PutRequest{
|
||||
Records: []*databroker.Record{{
|
||||
Id: fmt.Sprintf("%s|%s", host, userID),
|
||||
Data: data,
|
||||
Type: data.TypeUrl,
|
||||
}},
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to store upstream oauth2 token for session: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetUpstreamOAuth2Token loads the upstream OAuth2 token for a given session and a host
|
||||
func (storage *Storage) GetUpstreamOAuth2Token(
|
||||
ctx context.Context,
|
||||
host string,
|
||||
userID string,
|
||||
) (*oauth21proto.TokenResponse, error) {
|
||||
v := new(oauth21proto.TokenResponse)
|
||||
rec, err := storage.client.Get(ctx, &databroker.GetRequest{
|
||||
Type: protoutil.GetTypeURL(v),
|
||||
Id: fmt.Sprintf("%s|%s", host, userID),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get upstream oauth2 token for session: %w", err)
|
||||
}
|
||||
|
||||
err = anypb.UnmarshalTo(rec.Record.Data, v, proto.UnmarshalOptions{})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal upstream oauth2 token: %w", err)
|
||||
}
|
||||
|
||||
return v, nil
|
||||
}
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package mcp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
|
@ -27,9 +28,9 @@ func CheckPKCE(
|
|||
return nil
|
||||
}
|
||||
|
||||
// CreateAuthorizationCode creates an access token based on the session
|
||||
func (srv *Handler) CreateAccessTokenForSession(id string, expiresAt time.Time) (string, error) {
|
||||
return CreateCode(CodeTypeAccess, id, expiresAt, "", srv.cipher)
|
||||
// GetAccessTokenForSession returns an access token for a given session and expiration time.
|
||||
func (srv *Handler) GetAccessTokenForSession(sessionID string, sessionExpiresAt time.Time) (string, error) {
|
||||
return CreateCode(CodeTypeAccess, sessionID, sessionExpiresAt, "", srv.cipher)
|
||||
}
|
||||
|
||||
// DecryptAuthorizationCode decrypts the authorization code and returns the underlying session ID
|
||||
|
@ -41,3 +42,16 @@ func (srv *Handler) GetSessionIDFromAccessToken(accessToken string) (string, err
|
|||
|
||||
return code.Id, nil
|
||||
}
|
||||
|
||||
func (srv *Handler) GetUpstreamOAuth2Token(
|
||||
ctx context.Context,
|
||||
host string,
|
||||
userID string,
|
||||
) (string, error) {
|
||||
token, err := srv.storage.GetUpstreamOAuth2Token(ctx, userID, host)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to get upstream oauth2 token: %w", err)
|
||||
}
|
||||
|
||||
return token.AccessToken, nil
|
||||
}
|
||||
|
|
|
@ -4,15 +4,13 @@ import (
|
|||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/bufbuild/protovalidate-go"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/oauth21/gen"
|
||||
)
|
||||
|
||||
// ParseCodeGrantAuthorizeRequest parses the authorization request for the code grant flow.
|
||||
// see https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-12#section-4.1.1
|
||||
// scopes are ignored
|
||||
func ParseCodeGrantAuthorizeRequest(r *http.Request, sessionID string) (*gen.AuthorizationRequest, error) {
|
||||
func ParseCodeGrantAuthorizeRequest(r *http.Request) (*gen.AuthorizationRequest, error) {
|
||||
if err := r.ParseForm(); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse form: %w", err)
|
||||
}
|
||||
|
@ -24,11 +22,6 @@ func ParseCodeGrantAuthorizeRequest(r *http.Request, sessionID string) (*gen.Aut
|
|||
State: optionalFormParam(r, "state"),
|
||||
CodeChallenge: r.Form.Get("code_challenge"),
|
||||
CodeChallengeMethod: optionalFormParam(r, "code_challenge_method"),
|
||||
SessionId: sessionID,
|
||||
}
|
||||
|
||||
if err := protovalidate.Validate(v); err != nil {
|
||||
return nil, fmt.Errorf("invalid request: %w", err)
|
||||
}
|
||||
|
||||
return v, nil
|
||||
|
|
|
@ -51,6 +51,9 @@ type AuthorizationRequest struct {
|
|||
// session this authorization request is associated with.
|
||||
// This is a Pomerium implementation specific field.
|
||||
SessionId string `protobuf:"bytes,8,opt,name=session_id,json=sessionId,proto3" json:"session_id,omitempty"`
|
||||
// user id this authorization request is associated with.
|
||||
// This is a Pomerium implementation specific field.
|
||||
UserId string `protobuf:"bytes,9,opt,name=user_id,json=userId,proto3" json:"user_id,omitempty"`
|
||||
unknownFields protoimpl.UnknownFields
|
||||
sizeCache protoimpl.SizeCache
|
||||
}
|
||||
|
@ -141,6 +144,13 @@ func (x *AuthorizationRequest) GetSessionId() string {
|
|||
return ""
|
||||
}
|
||||
|
||||
func (x *AuthorizationRequest) GetUserId() string {
|
||||
if x != nil {
|
||||
return x.UserId
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
var File_authorization_request_proto protoreflect.FileDescriptor
|
||||
|
||||
var file_authorization_request_proto_rawDesc = string([]byte{
|
||||
|
@ -148,7 +158,7 @@ var file_authorization_request_proto_rawDesc = string([]byte{
|
|||
0x72, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x07, 0x6f,
|
||||
0x61, 0x75, 0x74, 0x68, 0x32, 0x31, 0x1a, 0x1b, 0x62, 0x75, 0x66, 0x2f, 0x76, 0x61, 0x6c, 0x69,
|
||||
0x64, 0x61, 0x74, 0x65, 0x2f, 0x76, 0x61, 0x6c, 0x69, 0x64, 0x61, 0x74, 0x65, 0x2e, 0x70, 0x72,
|
||||
0x6f, 0x74, 0x6f, 0x22, 0xaa, 0x03, 0x0a, 0x14, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a,
|
||||
0x6f, 0x74, 0x6f, 0x22, 0xcb, 0x03, 0x0a, 0x14, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a,
|
||||
0x61, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x23, 0x0a, 0x09,
|
||||
0x63, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x5f, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x42,
|
||||
0x06, 0xba, 0x48, 0x03, 0xc8, 0x01, 0x01, 0x52, 0x08, 0x63, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x49,
|
||||
|
@ -171,21 +181,23 @@ var file_authorization_request_proto_rawDesc = string([]byte{
|
|||
0x64, 0x65, 0x43, 0x68, 0x61, 0x6c, 0x6c, 0x65, 0x6e, 0x67, 0x65, 0x4d, 0x65, 0x74, 0x68, 0x6f,
|
||||
0x64, 0x88, 0x01, 0x01, 0x12, 0x25, 0x0a, 0x0a, 0x73, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x5f,
|
||||
0x69, 0x64, 0x18, 0x08, 0x20, 0x01, 0x28, 0x09, 0x42, 0x06, 0xba, 0x48, 0x03, 0xc8, 0x01, 0x01,
|
||||
0x52, 0x09, 0x73, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x49, 0x64, 0x42, 0x0f, 0x0a, 0x0d, 0x5f,
|
||||
0x72, 0x65, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x5f, 0x75, 0x72, 0x69, 0x42, 0x08, 0x0a, 0x06,
|
||||
0x5f, 0x73, 0x74, 0x61, 0x74, 0x65, 0x42, 0x18, 0x0a, 0x16, 0x5f, 0x63, 0x6f, 0x64, 0x65, 0x5f,
|
||||
0x63, 0x68, 0x61, 0x6c, 0x6c, 0x65, 0x6e, 0x67, 0x65, 0x5f, 0x6d, 0x65, 0x74, 0x68, 0x6f, 0x64,
|
||||
0x42, 0x97, 0x01, 0x0a, 0x0b, 0x63, 0x6f, 0x6d, 0x2e, 0x6f, 0x61, 0x75, 0x74, 0x68, 0x32, 0x31,
|
||||
0x42, 0x19, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x52,
|
||||
0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x50, 0x01, 0x5a, 0x31, 0x67,
|
||||
0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x70, 0x6f, 0x6d, 0x65, 0x72, 0x69,
|
||||
0x75, 0x6d, 0x2f, 0x70, 0x6f, 0x6d, 0x65, 0x72, 0x69, 0x75, 0x6d, 0x2f, 0x69, 0x6e, 0x74, 0x65,
|
||||
0x72, 0x6e, 0x61, 0x6c, 0x2f, 0x6f, 0x61, 0x75, 0x74, 0x68, 0x32, 0x31, 0x2f, 0x67, 0x65, 0x6e,
|
||||
0xa2, 0x02, 0x03, 0x4f, 0x58, 0x58, 0xaa, 0x02, 0x07, 0x4f, 0x61, 0x75, 0x74, 0x68, 0x32, 0x31,
|
||||
0xca, 0x02, 0x07, 0x4f, 0x61, 0x75, 0x74, 0x68, 0x32, 0x31, 0xe2, 0x02, 0x13, 0x4f, 0x61, 0x75,
|
||||
0x74, 0x68, 0x32, 0x31, 0x5c, 0x47, 0x50, 0x42, 0x4d, 0x65, 0x74, 0x61, 0x64, 0x61, 0x74, 0x61,
|
||||
0xea, 0x02, 0x07, 0x4f, 0x61, 0x75, 0x74, 0x68, 0x32, 0x31, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74,
|
||||
0x6f, 0x33,
|
||||
0x52, 0x09, 0x73, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x49, 0x64, 0x12, 0x1f, 0x0a, 0x07, 0x75,
|
||||
0x73, 0x65, 0x72, 0x5f, 0x69, 0x64, 0x18, 0x09, 0x20, 0x01, 0x28, 0x09, 0x42, 0x06, 0xba, 0x48,
|
||||
0x03, 0xc8, 0x01, 0x01, 0x52, 0x06, 0x75, 0x73, 0x65, 0x72, 0x49, 0x64, 0x42, 0x0f, 0x0a, 0x0d,
|
||||
0x5f, 0x72, 0x65, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x5f, 0x75, 0x72, 0x69, 0x42, 0x08, 0x0a,
|
||||
0x06, 0x5f, 0x73, 0x74, 0x61, 0x74, 0x65, 0x42, 0x18, 0x0a, 0x16, 0x5f, 0x63, 0x6f, 0x64, 0x65,
|
||||
0x5f, 0x63, 0x68, 0x61, 0x6c, 0x6c, 0x65, 0x6e, 0x67, 0x65, 0x5f, 0x6d, 0x65, 0x74, 0x68, 0x6f,
|
||||
0x64, 0x42, 0x97, 0x01, 0x0a, 0x0b, 0x63, 0x6f, 0x6d, 0x2e, 0x6f, 0x61, 0x75, 0x74, 0x68, 0x32,
|
||||
0x31, 0x42, 0x19, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e,
|
||||
0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x50, 0x01, 0x5a, 0x31,
|
||||
0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x70, 0x6f, 0x6d, 0x65, 0x72,
|
||||
0x69, 0x75, 0x6d, 0x2f, 0x70, 0x6f, 0x6d, 0x65, 0x72, 0x69, 0x75, 0x6d, 0x2f, 0x69, 0x6e, 0x74,
|
||||
0x65, 0x72, 0x6e, 0x61, 0x6c, 0x2f, 0x6f, 0x61, 0x75, 0x74, 0x68, 0x32, 0x31, 0x2f, 0x67, 0x65,
|
||||
0x6e, 0xa2, 0x02, 0x03, 0x4f, 0x58, 0x58, 0xaa, 0x02, 0x07, 0x4f, 0x61, 0x75, 0x74, 0x68, 0x32,
|
||||
0x31, 0xca, 0x02, 0x07, 0x4f, 0x61, 0x75, 0x74, 0x68, 0x32, 0x31, 0xe2, 0x02, 0x13, 0x4f, 0x61,
|
||||
0x75, 0x74, 0x68, 0x32, 0x31, 0x5c, 0x47, 0x50, 0x42, 0x4d, 0x65, 0x74, 0x61, 0x64, 0x61, 0x74,
|
||||
0x61, 0xea, 0x02, 0x07, 0x4f, 0x61, 0x75, 0x74, 0x68, 0x32, 0x31, 0x62, 0x06, 0x70, 0x72, 0x6f,
|
||||
0x74, 0x6f, 0x33,
|
||||
})
|
||||
|
||||
var (
|
||||
|
|
|
@ -42,10 +42,12 @@ message AuthorizationRequest {
|
|||
|
||||
// OPTIONAL, defaults to plain if not present in the request. Code verifier
|
||||
// transformation method is S256 or plain.
|
||||
optional string code_challenge_method = 7
|
||||
[ (buf.validate.field).string = {in : [ "S256", "plain" ]} ];
|
||||
optional string code_challenge_method = 7 [(buf.validate.field).string = {in: ["S256", "plain"]}];
|
||||
|
||||
// session this authorization request is associated with.
|
||||
// This is a Pomerium implementation specific field.
|
||||
string session_id = 8 [(buf.validate.field).required = true];
|
||||
// user id this authorization request is associated with.
|
||||
// This is a Pomerium implementation specific field.
|
||||
string user_id = 9 [(buf.validate.field).required = true];
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue