diff --git a/authorize/evaluator/config.go b/authorize/evaluator/config.go index be0957faa..c7a6d01ad 100644 --- a/authorize/evaluator/config.go +++ b/authorize/evaluator/config.go @@ -1,8 +1,7 @@ package evaluator import ( - "time" - + "github.com/pomerium/pomerium/authorize/internal/store" "github.com/pomerium/pomerium/config" "github.com/pomerium/pomerium/internal/hashutil" ) @@ -19,7 +18,7 @@ type evaluatorConfig struct { JWTClaimsHeaders config.JWTClaimHeaders JWTGroupsFilter config.JWTGroupsFilter DefaultJWTIssuerFormat config.JWTIssuerFormat - MCPAccessTokenProvider func(string, time.Time) (string, error) `hash:"-"` + MCPAccessTokenProvider store.MCPAccessTokenProvider `hash:"-"` } // cacheKey() returns a hash over the configuration, except for the policies. @@ -118,7 +117,7 @@ func WithDefaultJWTIssuerFormat(format config.JWTIssuerFormat) Option { } // WithMCPAccessTokenProvider sets the MCP access token in the config. -func WithMCPAccessTokenProvider(fn func(sessionID string, expires time.Time) (string, error)) Option { +func WithMCPAccessTokenProvider(fn store.MCPAccessTokenProvider) Option { return func(cfg *evaluatorConfig) { cfg.MCPAccessTokenProvider = fn } diff --git a/authorize/evaluator/headers_evaluator_evaluation.go b/authorize/evaluator/headers_evaluator_evaluation.go index c6bbbaf68..82ba647d2 100644 --- a/authorize/evaluator/headers_evaluator_evaluation.go +++ b/authorize/evaluator/headers_evaluator_evaluation.go @@ -93,24 +93,37 @@ func (e *headersEvaluatorEvaluation) fillJWTClaimHeaders(ctx context.Context) er return nil } -func (e *headersEvaluatorEvaluation) fillMCPHeaders() error { - if e.request == nil || - e.request.Policy == nil || - e.request.Policy.MCP == nil || - !e.request.Policy.MCP.PassUpstreamAccessToken { +func (e *headersEvaluatorEvaluation) fillMCPHeaders(ctx context.Context) (err error) { + if e.request == nil || e.request.Policy == nil || e.request.Policy.MCP == nil || e.request.Session.ID == "" { return nil } - if e.request.Session.ID == "" { + p := e.evaluator.store.GetMCPAccessTokenProvider() + if p == nil { return nil } - accessToken, err := e.evaluator.store.GetMCPAccessTokenProvider()(e.request.Session.ID, e.now.Add(time.Hour)) - if err != nil { - return fmt.Errorf("authorize/header-evaluator: error getting MCP access token: %w", err) + var accessToken string + if e.request.Policy.MCP.IsUpstreamClientNeedsAccessToken() { + accessToken, err = p.GetAccessTokenForSession(e.request.Session.ID, time.Now().Add(5*time.Minute)) + if err != nil { + return fmt.Errorf("authorize/header-evaluator: error getting MCP access token: %w", err) + } + e.response.Headers.Set("Authorization", "Bearer "+accessToken) + return nil } - e.response.Headers.Set("Authorization", "Bearer "+accessToken) + if e.request.Policy.MCP.HasUpstreamOAuth2() { + user := e.getUser(ctx) + accessToken, err = p.GetUpstreamOAuth2Token(ctx, e.request.HTTP.Host, user.Id) + if err != nil { + return fmt.Errorf("authorize/header-evaluator: error getting upstream oauth2 token: %w", err) + } + e.response.Headers.Set("Authorization", "Bearer "+accessToken) + return nil + } + + e.response.Headers.Del("Authorization") return nil } @@ -198,7 +211,7 @@ func (e *headersEvaluatorEvaluation) fillHeaders(ctx context.Context) error { if err := e.fillJWTClaimHeaders(ctx); err != nil { return err } - err := e.fillMCPHeaders() + err := e.fillMCPHeaders(ctx) if err != nil { return err } diff --git a/authorize/internal/store/store.go b/authorize/internal/store/store.go index 970b97da5..9ac99acc7 100644 --- a/authorize/internal/store/store.go +++ b/authorize/internal/store/store.go @@ -37,7 +37,17 @@ type Store struct { defaultJWTIssuerFormat atomic.Pointer[config.JWTIssuerFormat] signingKey atomic.Pointer[jose.JSONWebKey] - mcpAccessTokenProvider atomicutil.Value[func(string, time.Time) (string, error)] + mcpAccessTokenProvider atomicutil.Value[MCPAccessTokenProvider] +} + +type MCPAccessTokenProvider interface { + // GetAccessToken returns an access token for the given session ID and expiration time, + // that may be upsed by the MCP client to interact with the MCP servers fronted by Pomerium. + GetAccessTokenForSession(sessionID string, expiresAt time.Time) (string, error) + + // GetUpstreamOAuth2Token returns an upstream OAuth2 token for the given host and session ID + // that is used by the MCP server to interact with the upstream APIs. + GetUpstreamOAuth2Token(ctx context.Context, host, sessionID string) (string, error) } // New creates a new Store. @@ -81,13 +91,8 @@ func (s *Store) GetSigningKey() *jose.JSONWebKey { return s.signingKey.Load() } -func (s *Store) GetMCPAccessTokenProvider() func(string, time.Time) (string, error) { - if f := s.mcpAccessTokenProvider.Load(); f != nil { - return f - } - return func(string, time.Time) (string, error) { - return "", fmt.Errorf("no mcp access token provider") - } +func (s *Store) GetMCPAccessTokenProvider() MCPAccessTokenProvider { + return s.mcpAccessTokenProvider.Load() } // UpdateGoogleCloudServerlessAuthenticationServiceAccount updates the google cloud serverless authentication @@ -127,9 +132,11 @@ func (s *Store) UpdateSigningKey(signingKey *jose.JSONWebKey) { s.signingKey.Store(signingKey) } -func (s *Store) UpdateMCPAccessTokenProvider(mcpAccessTokenProvider func(string, time.Time) (string, error)) { +func (s *Store) UpdateMCPAccessTokenProvider(mcpAccessTokenProvider MCPAccessTokenProvider) { // This isn't used by the Rego code, so we don't need to write it to the opastorage.Store instance. - s.mcpAccessTokenProvider.Store(mcpAccessTokenProvider) + if mcpAccessTokenProvider != nil { + s.mcpAccessTokenProvider.Store(mcpAccessTokenProvider) + } } func (s *Store) write(rawPath string, value any) { diff --git a/authorize/state.go b/authorize/state.go index d52bbceb6..e6545a9ce 100644 --- a/authorize/state.go +++ b/authorize/state.go @@ -65,7 +65,7 @@ func newAuthorizeStateFromConfig( } state.mcp = mcp - state.evaluator, err = newPolicyEvaluator(ctx, cfg.Options, store, previousEvaluator, evaluator.WithMCPAccessTokenProvider(mcp.CreateAccessTokenForSession)) + state.evaluator, err = newPolicyEvaluator(ctx, cfg.Options, store, previousEvaluator, evaluator.WithMCPAccessTokenProvider(mcp)) if err != nil { return nil, fmt.Errorf("authorize: failed to update policy with options: %w", err) } diff --git a/config/policy.go b/config/policy.go index f3e817802..e0786cf1b 100644 --- a/config/policy.go +++ b/config/policy.go @@ -215,6 +215,16 @@ type MCP struct { PassUpstreamAccessToken bool `mapstructure:"pass_upstream_access_token" yaml:"pass_upstream_access_token,omitempty" json:"pass_upstream_access_token,omitempty"` } +// HasUpstreamOAuth2 checks if the route is for the MCP Server and if it has an upstream OAuth2 configuration +func (p *MCP) HasUpstreamOAuth2() bool { + return p != nil && p.UpstreamOAuth2 != nil +} + +// IsUpstreamClientNeedsAccessToken checks if the route is for the MCP Client and if it needs to pass the upstream access token +func (p *MCP) IsUpstreamClientNeedsAccessToken() bool { + return p != nil && p.UpstreamOAuth2 != nil && p.PassUpstreamAccessToken +} + type UpstreamOAuth2 struct { ClientID string `mapstructure:"client_id" yaml:"client_id,omitempty" json:"client_id,omitempty"` ClientSecret string `mapstructure:"client_secret" yaml:"client_secret,omitempty" json:"client_secret,omitempty"` diff --git a/internal/mcp/handler_authorization.go b/internal/mcp/handler_authorization.go index 7e8c5f0cf..e7c3cc0ce 100644 --- a/internal/mcp/handler_authorization.go +++ b/internal/mcp/handler_authorization.go @@ -8,7 +8,9 @@ import ( "net/url" "time" + "github.com/bufbuild/protovalidate-go" "github.com/go-jose/go-jose/v3/jwt" + "golang.org/x/sync/errgroup" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" @@ -27,19 +29,38 @@ func (srv *Handler) Authorize(w http.ResponseWriter, r *http.Request) { ctx := r.Context() - sessionID, err := getSessionFromRequest(r) + claims, err := getClaimsFromRequest(r) if err != nil { + log.Ctx(ctx).Error().Err(err).Msg("failed to get claims from request") + http.Error(w, "invalid request", http.StatusBadRequest) + return + } + sessionID, ok := getSessionIDFromClaims(claims) + if !ok { log.Ctx(ctx).Error().Err(err).Msg("session is not present, this is a misconfigured request") http.Error(w, "internal server error", http.StatusInternalServerError) return } + userID, ok := getUserIDFromClaims(claims) + if !ok { + log.Ctx(ctx).Error().Err(err).Msg("user id is not present, this is a misconfigured request") + http.Error(w, "internal server error", http.StatusInternalServerError) + return + } - v, err := oauth21.ParseCodeGrantAuthorizeRequest(r, sessionID) + v, err := oauth21.ParseCodeGrantAuthorizeRequest(r) if err != nil { log.Ctx(ctx).Error().Err(err).Msg("failed to parse authorization request") oauth21.ErrorResponse(w, http.StatusBadRequest, oauth21.InvalidRequest) return } + v.UserId = userID + v.SessionId = sessionID + if err := protovalidate.Validate(v); err != nil { + log.Ctx(ctx).Error().Err(err).Msg("failed to validate authorization request") + oauth21.ErrorResponse(w, http.StatusBadRequest, oauth21.InvalidRequest) + return + } client, err := srv.storage.GetClient(ctx, v.ClientId) if err != nil && status.Code(err) == codes.NotFound { @@ -61,19 +82,51 @@ func (srv *Handler) Authorize(w http.ResponseWriter, r *http.Request) { return } - id, err := srv.storage.CreateAuthorizationRequest(ctx, v) - if err != nil { - log.Ctx(ctx).Error().Err(err).Msg("failed to store authorization request") - http.Error(w, "cannot create authorization request", http.StatusInternalServerError) + requiresUpstreamOAuth2Token := srv.relyingParties.HasConfigForHost(r.Host) + var authReqID string + var hasUpstreamOAuth2Token bool + { + eg, ctx := errgroup.WithContext(ctx) + eg.Go(func() error { + var err error + authReqID, err = srv.storage.CreateAuthorizationRequest(ctx, v) + if err != nil { + return fmt.Errorf("failed to create authorization request: %w", err) + } + return nil + }) + eg.Go(func() error { + if !requiresUpstreamOAuth2Token { + return nil + } + + var err error + token, err := srv.GetUpstreamOAuth2Token(ctx, r.Host, userID) + if err != nil && status.Code(err) != codes.NotFound { + return fmt.Errorf("failed to get upstream oauth2 token: %w", err) + } + hasUpstreamOAuth2Token = token != "" + return nil + }) + + err := eg.Wait() + if err != nil { + log.Ctx(ctx).Error().Err(err).Msg("prepare for authorization redirect") + http.Error(w, "internal error", http.StatusInternalServerError) + return + } + } + + if !requiresUpstreamOAuth2Token || hasUpstreamOAuth2Token { + srv.AuthorizationResponse(ctx, w, r, authReqID, v) return } - loginURL, ok := srv.relyingParties.GetLoginURLForHost(r.Host, id) + loginURL, ok := srv.relyingParties.GetLoginURLForHost(r.Host, authReqID) if ok { http.Redirect(w, r, loginURL, http.StatusFound) - } else { - srv.AuthorizationResponse(ctx, w, r, id, v) } + log.Ctx(ctx).Error().Msg("authorize: must have login URL, this is a bug") } // AuthorizationResponse generates the successful authorization response @@ -111,22 +164,31 @@ func (srv *Handler) AuthorizationResponse( http.Redirect(w, r, to.String(), http.StatusFound) } -func getSessionFromRequest(r *http.Request) (string, error) { +func getClaimsFromRequest(r *http.Request) (map[string]any, error) { h := r.Header.Get(httputil.HeaderPomeriumJWTAssertion) if h == "" { - return "", fmt.Errorf("missing %s header", httputil.HeaderPomeriumJWTAssertion) + return nil, fmt.Errorf("missing %s header", httputil.HeaderPomeriumJWTAssertion) } token, err := jwt.ParseSigned(h) if err != nil { - return "", fmt.Errorf("failed to parse JWT: %w", err) + return nil, fmt.Errorf("failed to parse JWT: %w", err) } var m map[string]any - _ = token.UnsafeClaimsWithoutVerification(&m) - sessionID, ok := m["sid"].(string) - if !ok { - return "", fmt.Errorf("missing session ID in JWT") + err = token.UnsafeClaimsWithoutVerification(&m) + if err != nil { + return nil, fmt.Errorf("failed to parse JWT claims: %w", err) } - return sessionID, nil + return m, nil +} + +func getSessionIDFromClaims(claims map[string]any) (string, bool) { + sessionID, ok := claims["sid"].(string) + return sessionID, ok +} + +func getUserIDFromClaims(claims map[string]any) (string, bool) { + userID, ok := claims["sub"].(string) + return userID, ok } diff --git a/internal/mcp/handler_oauth_callback.go b/internal/mcp/handler_oauth_callback.go index 911a7aef4..60ac5a226 100644 --- a/internal/mcp/handler_oauth_callback.go +++ b/internal/mcp/handler_oauth_callback.go @@ -1,9 +1,63 @@ package mcp import ( + "fmt" "net/http" + + "golang.org/x/oauth2" + "golang.org/x/sync/errgroup" + + "github.com/pomerium/pomerium/internal/log" + oauth21proto "github.com/pomerium/pomerium/internal/oauth21/gen" ) -func (srv *Handler) OAuthCallback(w http.ResponseWriter, _ *http.Request) { - w.WriteHeader(http.StatusNotImplemented) +func (srv *Handler) OAuthCallback(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + + code := r.URL.Query().Get("code") + authReqID := r.URL.Query().Get("state") + if code == "" || authReqID == "" { + http.Error(w, "Invalid callback request: missing code or state", http.StatusBadRequest) + return + } + + var token *oauth2.Token + var authReq *oauth21proto.AuthorizationRequest + + { + eg, ctx := errgroup.WithContext(ctx) + eg.Go(func() error { + var err error + token, err = srv.relyingParties.CodeExchangeForHost(ctx, r.Host, code) + if err != nil { + return fmt.Errorf("oauth2: failed to exchange code: %w", err) + } + return nil + }) + eg.Go(func() error { + var err error + authReq, err = srv.storage.GetAuthorizationRequest(ctx, authReqID) + if err != nil { + return fmt.Errorf("failed to get authorization request: %w", err) + } + + return nil + }) + + err := eg.Wait() + if err != nil { + log.Ctx(ctx).Error().Err(err).Msg("failed to exchange code") + http.Error(w, "Failed to exchange code", http.StatusInternalServerError) + return + } + } + + err := srv.storage.StoreUpstreamOAuth2Token(ctx, authReq.UserId, r.Host, OAuth2TokenToPB(token)) + if err != nil { + log.Ctx(ctx).Error().Err(err).Msg("failed to store upstream oauth2 token") + http.Error(w, "Failed to store upstream oauth2 token", http.StatusInternalServerError) + return + } + + srv.AuthorizationResponse(ctx, w, r, authReqID, authReq) } diff --git a/internal/mcp/handler_token.go b/internal/mcp/handler_token.go index aaae8d953..d9618ff79 100644 --- a/internal/mcp/handler_token.go +++ b/internal/mcp/handler_token.go @@ -87,7 +87,7 @@ func (srv *Handler) handleAuthorizationCodeToken(w http.ResponseWriter, r *http. return } - accessToken, err := srv.CreateAccessTokenForSession(session.Id, session.ExpiresAt.AsTime()) + accessToken, err := srv.GetAccessTokenForSession(session.Id, session.ExpiresAt.AsTime()) if err != nil { http.Error(w, "internal error", http.StatusInternalServerError) return diff --git a/internal/mcp/oauth_config.go b/internal/mcp/oauth_config.go index 648384e56..1fe3ac4df 100644 --- a/internal/mcp/oauth_config.go +++ b/internal/mcp/oauth_config.go @@ -1,14 +1,19 @@ package mcp import ( + "context" + "fmt" "net/http" "net/url" "path" "sync" + "time" "golang.org/x/oauth2" + "google.golang.org/protobuf/proto" "github.com/pomerium/pomerium/config" + oauth21proto "github.com/pomerium/pomerium/internal/oauth21/gen" ) type OAuth2Configs struct { @@ -31,6 +36,26 @@ func NewOAuthConfig( } } +func (r *OAuth2Configs) CodeExchangeForHost( + ctx context.Context, + host string, + code string, +) (*oauth2.Token, error) { + r.buildOnce.Do(r.build) + cfg, ok := r.perHost[host] + if !ok { + return nil, fmt.Errorf("no oauth2 config for host %s", host) + } + + return cfg.Exchange(ctx, code) +} + +func (r *OAuth2Configs) HasConfigForHost(host string) bool { + r.buildOnce.Do(r.build) + _, ok := r.perHost[host] + return ok +} + func (r *OAuth2Configs) GetLoginURLForHost(host string, state string) (string, bool) { r.buildOnce.Do(r.build) @@ -91,3 +116,25 @@ func authStyleEnum(o config.OAuth2EndpointAuthStyle) oauth2.AuthStyle { return oauth2.AuthStyleAutoDetect } } + +func OAuth2TokenToPB(src *oauth2.Token) *oauth21proto.TokenResponse { + return &oauth21proto.TokenResponse{ + AccessToken: src.AccessToken, + TokenType: src.TokenType, + RefreshToken: proto.String(src.RefreshToken), + ExpiresIn: proto.Int64(src.ExpiresIn), + } +} + +func PBToOAuth2Token(src *oauth21proto.TokenResponse, now time.Time) oauth2.Token { + token := oauth2.Token{ + AccessToken: src.GetAccessToken(), + TokenType: src.GetTokenType(), + ExpiresIn: src.GetExpiresIn(), + RefreshToken: src.GetRefreshToken(), + } + if token.ExpiresIn > 0 { + token.Expiry = now.Add(time.Duration(token.ExpiresIn) * time.Second) + } + return token +} diff --git a/internal/mcp/storage.go b/internal/mcp/storage.go index 69da0824d..675818786 100644 --- a/internal/mcp/storage.go +++ b/internal/mcp/storage.go @@ -145,3 +145,47 @@ func (storage *Storage) GetSession(ctx context.Context, id string) (*session.Ses return v, nil } + +// StoreUpstreamOAuth2Token stores the upstream OAuth2 token for a given session and a host +func (storage *Storage) StoreUpstreamOAuth2Token( + ctx context.Context, + host string, + userID string, + token *oauth21proto.TokenResponse, +) error { + data := protoutil.NewAny(token) + _, err := storage.client.Put(ctx, &databroker.PutRequest{ + Records: []*databroker.Record{{ + Id: fmt.Sprintf("%s|%s", host, userID), + Data: data, + Type: data.TypeUrl, + }}, + }) + if err != nil { + return fmt.Errorf("failed to store upstream oauth2 token for session: %w", err) + } + return nil +} + +// GetUpstreamOAuth2Token loads the upstream OAuth2 token for a given session and a host +func (storage *Storage) GetUpstreamOAuth2Token( + ctx context.Context, + host string, + userID string, +) (*oauth21proto.TokenResponse, error) { + v := new(oauth21proto.TokenResponse) + rec, err := storage.client.Get(ctx, &databroker.GetRequest{ + Type: protoutil.GetTypeURL(v), + Id: fmt.Sprintf("%s|%s", host, userID), + }) + if err != nil { + return nil, fmt.Errorf("failed to get upstream oauth2 token for session: %w", err) + } + + err = anypb.UnmarshalTo(rec.Record.Data, v, proto.UnmarshalOptions{}) + if err != nil { + return nil, fmt.Errorf("failed to unmarshal upstream oauth2 token: %w", err) + } + + return v, nil +} diff --git a/internal/mcp/token.go b/internal/mcp/token.go index 1ebac34c4..8342d4c04 100644 --- a/internal/mcp/token.go +++ b/internal/mcp/token.go @@ -1,6 +1,7 @@ package mcp import ( + "context" "fmt" "time" @@ -27,9 +28,9 @@ func CheckPKCE( return nil } -// CreateAuthorizationCode creates an access token based on the session -func (srv *Handler) CreateAccessTokenForSession(id string, expiresAt time.Time) (string, error) { - return CreateCode(CodeTypeAccess, id, expiresAt, "", srv.cipher) +// GetAccessTokenForSession returns an access token for a given session and expiration time. +func (srv *Handler) GetAccessTokenForSession(sessionID string, sessionExpiresAt time.Time) (string, error) { + return CreateCode(CodeTypeAccess, sessionID, sessionExpiresAt, "", srv.cipher) } // DecryptAuthorizationCode decrypts the authorization code and returns the underlying session ID @@ -41,3 +42,16 @@ func (srv *Handler) GetSessionIDFromAccessToken(accessToken string) (string, err return code.Id, nil } + +func (srv *Handler) GetUpstreamOAuth2Token( + ctx context.Context, + host string, + userID string, +) (string, error) { + token, err := srv.storage.GetUpstreamOAuth2Token(ctx, userID, host) + if err != nil { + return "", fmt.Errorf("failed to get upstream oauth2 token: %w", err) + } + + return token.AccessToken, nil +} diff --git a/internal/oauth21/authorize.go b/internal/oauth21/authorize.go index bd0fc0d4d..3beb781b5 100644 --- a/internal/oauth21/authorize.go +++ b/internal/oauth21/authorize.go @@ -4,15 +4,13 @@ import ( "fmt" "net/http" - "github.com/bufbuild/protovalidate-go" - "github.com/pomerium/pomerium/internal/oauth21/gen" ) // ParseCodeGrantAuthorizeRequest parses the authorization request for the code grant flow. // see https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-12#section-4.1.1 // scopes are ignored -func ParseCodeGrantAuthorizeRequest(r *http.Request, sessionID string) (*gen.AuthorizationRequest, error) { +func ParseCodeGrantAuthorizeRequest(r *http.Request) (*gen.AuthorizationRequest, error) { if err := r.ParseForm(); err != nil { return nil, fmt.Errorf("failed to parse form: %w", err) } @@ -24,11 +22,6 @@ func ParseCodeGrantAuthorizeRequest(r *http.Request, sessionID string) (*gen.Aut State: optionalFormParam(r, "state"), CodeChallenge: r.Form.Get("code_challenge"), CodeChallengeMethod: optionalFormParam(r, "code_challenge_method"), - SessionId: sessionID, - } - - if err := protovalidate.Validate(v); err != nil { - return nil, fmt.Errorf("invalid request: %w", err) } return v, nil diff --git a/internal/oauth21/gen/authorization_request.pb.go b/internal/oauth21/gen/authorization_request.pb.go index 8a426080d..d1af9ebb6 100644 --- a/internal/oauth21/gen/authorization_request.pb.go +++ b/internal/oauth21/gen/authorization_request.pb.go @@ -50,7 +50,10 @@ type AuthorizationRequest struct { CodeChallengeMethod *string `protobuf:"bytes,7,opt,name=code_challenge_method,json=codeChallengeMethod,proto3,oneof" json:"code_challenge_method,omitempty"` // session this authorization request is associated with. // This is a Pomerium implementation specific field. - SessionId string `protobuf:"bytes,8,opt,name=session_id,json=sessionId,proto3" json:"session_id,omitempty"` + SessionId string `protobuf:"bytes,8,opt,name=session_id,json=sessionId,proto3" json:"session_id,omitempty"` + // user id this authorization request is associated with. + // This is a Pomerium implementation specific field. + UserId string `protobuf:"bytes,9,opt,name=user_id,json=userId,proto3" json:"user_id,omitempty"` unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache } @@ -141,6 +144,13 @@ func (x *AuthorizationRequest) GetSessionId() string { return "" } +func (x *AuthorizationRequest) GetUserId() string { + if x != nil { + return x.UserId + } + return "" +} + var File_authorization_request_proto protoreflect.FileDescriptor var file_authorization_request_proto_rawDesc = string([]byte{ @@ -148,7 +158,7 @@ var file_authorization_request_proto_rawDesc = string([]byte{ 0x72, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x07, 0x6f, 0x61, 0x75, 0x74, 0x68, 0x32, 0x31, 0x1a, 0x1b, 0x62, 0x75, 0x66, 0x2f, 0x76, 0x61, 0x6c, 0x69, 0x64, 0x61, 0x74, 0x65, 0x2f, 0x76, 0x61, 0x6c, 0x69, 0x64, 0x61, 0x74, 0x65, 0x2e, 0x70, 0x72, - 0x6f, 0x74, 0x6f, 0x22, 0xaa, 0x03, 0x0a, 0x14, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, + 0x6f, 0x74, 0x6f, 0x22, 0xcb, 0x03, 0x0a, 0x14, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x23, 0x0a, 0x09, 0x63, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x5f, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x42, 0x06, 0xba, 0x48, 0x03, 0xc8, 0x01, 0x01, 0x52, 0x08, 0x63, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x49, @@ -171,21 +181,23 @@ var file_authorization_request_proto_rawDesc = string([]byte{ 0x64, 0x65, 0x43, 0x68, 0x61, 0x6c, 0x6c, 0x65, 0x6e, 0x67, 0x65, 0x4d, 0x65, 0x74, 0x68, 0x6f, 0x64, 0x88, 0x01, 0x01, 0x12, 0x25, 0x0a, 0x0a, 0x73, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x5f, 0x69, 0x64, 0x18, 0x08, 0x20, 0x01, 0x28, 0x09, 0x42, 0x06, 0xba, 0x48, 0x03, 0xc8, 0x01, 0x01, - 0x52, 0x09, 0x73, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x49, 0x64, 0x42, 0x0f, 0x0a, 0x0d, 0x5f, - 0x72, 0x65, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x5f, 0x75, 0x72, 0x69, 0x42, 0x08, 0x0a, 0x06, - 0x5f, 0x73, 0x74, 0x61, 0x74, 0x65, 0x42, 0x18, 0x0a, 0x16, 0x5f, 0x63, 0x6f, 0x64, 0x65, 0x5f, - 0x63, 0x68, 0x61, 0x6c, 0x6c, 0x65, 0x6e, 0x67, 0x65, 0x5f, 0x6d, 0x65, 0x74, 0x68, 0x6f, 0x64, - 0x42, 0x97, 0x01, 0x0a, 0x0b, 0x63, 0x6f, 0x6d, 0x2e, 0x6f, 0x61, 0x75, 0x74, 0x68, 0x32, 0x31, - 0x42, 0x19, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x52, - 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x50, 0x01, 0x5a, 0x31, 0x67, - 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x70, 0x6f, 0x6d, 0x65, 0x72, 0x69, - 0x75, 0x6d, 0x2f, 0x70, 0x6f, 0x6d, 0x65, 0x72, 0x69, 0x75, 0x6d, 0x2f, 0x69, 0x6e, 0x74, 0x65, - 0x72, 0x6e, 0x61, 0x6c, 0x2f, 0x6f, 0x61, 0x75, 0x74, 0x68, 0x32, 0x31, 0x2f, 0x67, 0x65, 0x6e, - 0xa2, 0x02, 0x03, 0x4f, 0x58, 0x58, 0xaa, 0x02, 0x07, 0x4f, 0x61, 0x75, 0x74, 0x68, 0x32, 0x31, - 0xca, 0x02, 0x07, 0x4f, 0x61, 0x75, 0x74, 0x68, 0x32, 0x31, 0xe2, 0x02, 0x13, 0x4f, 0x61, 0x75, - 0x74, 0x68, 0x32, 0x31, 0x5c, 0x47, 0x50, 0x42, 0x4d, 0x65, 0x74, 0x61, 0x64, 0x61, 0x74, 0x61, - 0xea, 0x02, 0x07, 0x4f, 0x61, 0x75, 0x74, 0x68, 0x32, 0x31, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, - 0x6f, 0x33, + 0x52, 0x09, 0x73, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x49, 0x64, 0x12, 0x1f, 0x0a, 0x07, 0x75, + 0x73, 0x65, 0x72, 0x5f, 0x69, 0x64, 0x18, 0x09, 0x20, 0x01, 0x28, 0x09, 0x42, 0x06, 0xba, 0x48, + 0x03, 0xc8, 0x01, 0x01, 0x52, 0x06, 0x75, 0x73, 0x65, 0x72, 0x49, 0x64, 0x42, 0x0f, 0x0a, 0x0d, + 0x5f, 0x72, 0x65, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x5f, 0x75, 0x72, 0x69, 0x42, 0x08, 0x0a, + 0x06, 0x5f, 0x73, 0x74, 0x61, 0x74, 0x65, 0x42, 0x18, 0x0a, 0x16, 0x5f, 0x63, 0x6f, 0x64, 0x65, + 0x5f, 0x63, 0x68, 0x61, 0x6c, 0x6c, 0x65, 0x6e, 0x67, 0x65, 0x5f, 0x6d, 0x65, 0x74, 0x68, 0x6f, + 0x64, 0x42, 0x97, 0x01, 0x0a, 0x0b, 0x63, 0x6f, 0x6d, 0x2e, 0x6f, 0x61, 0x75, 0x74, 0x68, 0x32, + 0x31, 0x42, 0x19, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, + 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x50, 0x01, 0x5a, 0x31, + 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x70, 0x6f, 0x6d, 0x65, 0x72, + 0x69, 0x75, 0x6d, 0x2f, 0x70, 0x6f, 0x6d, 0x65, 0x72, 0x69, 0x75, 0x6d, 0x2f, 0x69, 0x6e, 0x74, + 0x65, 0x72, 0x6e, 0x61, 0x6c, 0x2f, 0x6f, 0x61, 0x75, 0x74, 0x68, 0x32, 0x31, 0x2f, 0x67, 0x65, + 0x6e, 0xa2, 0x02, 0x03, 0x4f, 0x58, 0x58, 0xaa, 0x02, 0x07, 0x4f, 0x61, 0x75, 0x74, 0x68, 0x32, + 0x31, 0xca, 0x02, 0x07, 0x4f, 0x61, 0x75, 0x74, 0x68, 0x32, 0x31, 0xe2, 0x02, 0x13, 0x4f, 0x61, + 0x75, 0x74, 0x68, 0x32, 0x31, 0x5c, 0x47, 0x50, 0x42, 0x4d, 0x65, 0x74, 0x61, 0x64, 0x61, 0x74, + 0x61, 0xea, 0x02, 0x07, 0x4f, 0x61, 0x75, 0x74, 0x68, 0x32, 0x31, 0x62, 0x06, 0x70, 0x72, 0x6f, + 0x74, 0x6f, 0x33, }) var ( diff --git a/internal/oauth21/proto/authorization_request.proto b/internal/oauth21/proto/authorization_request.proto index 532167ce2..a76e54d2e 100644 --- a/internal/oauth21/proto/authorization_request.proto +++ b/internal/oauth21/proto/authorization_request.proto @@ -10,7 +10,7 @@ option go_package = "github.com/pomerium/pomerium/internal/oauth21/gen"; // https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-12#section-4.1.1 message AuthorizationRequest { // The client identifier as described in Section 2.2. - string client_id = 1 [ (buf.validate.field).required = true ]; + string client_id = 1 [(buf.validate.field).required = true]; // OPTIONAL if only one redirect URI is registered for this client. REQUIRED // if multiple redirict URIs are registered for this client. @@ -23,7 +23,7 @@ message AuthorizationRequest { // code flow. string response_type = 3 [ (buf.validate.field).required = true, - (buf.validate.field).string = {in : [ "code" ]} + (buf.validate.field).string = {in: ["code"]} ]; // OPTIONAL. An opaque value used by the client to maintain state between the @@ -37,15 +37,17 @@ message AuthorizationRequest { // REQUIRED, assumes https://www.rfc-editor.org/rfc/rfc7636.html#section-4.1 string code_challenge = 6 [ (buf.validate.field).required = true, - (buf.validate.field).string = {min_len : 43, max_len : 128} + (buf.validate.field).string = {min_len: 43, max_len: 128} ]; // OPTIONAL, defaults to plain if not present in the request. Code verifier // transformation method is S256 or plain. - optional string code_challenge_method = 7 - [ (buf.validate.field).string = {in : [ "S256", "plain" ]} ]; + optional string code_challenge_method = 7 [(buf.validate.field).string = {in: ["S256", "plain"]}]; // session this authorization request is associated with. // This is a Pomerium implementation specific field. - string session_id = 8 [ (buf.validate.field).required = true ]; + string session_id = 8 [(buf.validate.field).required = true]; + // user id this authorization request is associated with. + // This is a Pomerium implementation specific field. + string user_id = 9 [(buf.validate.field).required = true]; }