mirror of
https://github.com/pomerium/pomerium.git
synced 2025-08-06 10:21:05 +02:00
mcp: authorize: load session from the access token (#5591)
This commit is contained in:
parent
0602f5e00d
commit
daaf5b8e30
4 changed files with 112 additions and 17 deletions
|
@ -16,6 +16,7 @@ import (
|
|||
"github.com/pomerium/pomerium/config"
|
||||
"github.com/pomerium/pomerium/internal/atomicutil"
|
||||
"github.com/pomerium/pomerium/internal/log"
|
||||
"github.com/pomerium/pomerium/internal/mcp"
|
||||
"github.com/pomerium/pomerium/internal/telemetry/metrics"
|
||||
"github.com/pomerium/pomerium/pkg/cryptutil"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||
|
@ -28,6 +29,7 @@ type Authorize struct {
|
|||
store *store.Store
|
||||
currentConfig *atomicutil.Value[*config.Config]
|
||||
accessTracker *AccessTracker
|
||||
mcp *atomicutil.Value[*mcp.Handler]
|
||||
|
||||
tracerProvider oteltrace.TracerProvider
|
||||
tracer oteltrace.Tracer
|
||||
|
@ -37,11 +39,18 @@ type Authorize struct {
|
|||
func New(ctx context.Context, cfg *config.Config) (*Authorize, error) {
|
||||
tracerProvider := trace.NewTracerProvider(ctx, "Authorize")
|
||||
tracer := tracerProvider.Tracer(trace.PomeriumCoreTracer)
|
||||
|
||||
mcp, err := mcp.New(ctx, mcp.DefaultPrefix, cfg)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("authorize: failed to create mcp handler: %w", err)
|
||||
}
|
||||
|
||||
a := &Authorize{
|
||||
currentConfig: atomicutil.NewValue(cfg),
|
||||
store: store.New(),
|
||||
tracerProvider: tracerProvider,
|
||||
tracer: tracer,
|
||||
mcp: atomicutil.NewValue(mcp),
|
||||
}
|
||||
a.accessTracker = NewAccessTracker(a, accessTrackerMaxSize, accessTrackerDebouncePeriod)
|
||||
|
||||
|
@ -151,4 +160,11 @@ func (a *Authorize) OnConfigChange(ctx context.Context, cfg *config.Config) {
|
|||
} else {
|
||||
a.state.Store(newState)
|
||||
}
|
||||
|
||||
mcp, err := mcp.New(ctx, mcp.DefaultPrefix, cfg)
|
||||
if err != nil {
|
||||
log.Ctx(ctx).Error().Err(err).Msg("proxy: failed to update authorize state from configuration settings")
|
||||
} else {
|
||||
a.mcp.Store(mcp)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -21,7 +21,9 @@ import (
|
|||
"github.com/pomerium/pomerium/internal/sessions"
|
||||
"github.com/pomerium/pomerium/pkg/contextutil"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/session"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/user"
|
||||
"github.com/pomerium/pomerium/pkg/grpcutil"
|
||||
"github.com/pomerium/pomerium/pkg/storage"
|
||||
"github.com/pomerium/pomerium/pkg/telemetry/requestid"
|
||||
)
|
||||
|
@ -88,29 +90,14 @@ func (a *Authorize) loadSession(
|
|||
) (s sessionOrServiceAccount, err error) {
|
||||
requestID := requestid.FromHTTPHeader(hreq.Header)
|
||||
|
||||
// attempt to create a session from an incoming idp token
|
||||
s, err = config.NewIncomingIDPTokenSessionCreator(
|
||||
func(ctx context.Context, recordType, recordID string) (*databroker.Record, error) {
|
||||
return storage.GetDataBrokerRecord(ctx, recordType, recordID, 0)
|
||||
},
|
||||
func(ctx context.Context, records []*databroker.Record) error {
|
||||
_, err := a.state.Load().dataBrokerClient.Put(ctx, &databroker.PutRequest{
|
||||
Records: records,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
storage.InvalidateCacheForDataBrokerRecords(ctx, records...)
|
||||
return nil
|
||||
},
|
||||
).CreateSession(ctx, a.currentConfig.Load(), req.Policy, hreq)
|
||||
s, err = a.maybeGetSessionFromRequest(ctx, hreq, req.Policy)
|
||||
if err == nil {
|
||||
return s, nil
|
||||
} else if !errors.Is(err, sessions.ErrNoSessionFound) {
|
||||
log.Ctx(ctx).Info().
|
||||
Str("request-id", requestID).
|
||||
Err(err).
|
||||
Msg("error creating session for incoming idp token")
|
||||
Msg("error creating session from incoming request")
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
@ -131,6 +118,76 @@ func (a *Authorize) loadSession(
|
|||
return s, nil
|
||||
}
|
||||
|
||||
func (a *Authorize) maybeGetSessionFromRequest(
|
||||
ctx context.Context,
|
||||
hreq *http.Request,
|
||||
policy *config.Policy,
|
||||
) (*session.Session, error) {
|
||||
if policy.IsMCP() {
|
||||
s, err := a.getMCPSession(ctx, hreq)
|
||||
if err != nil {
|
||||
log.Ctx(ctx).Error().Err(err).Msg("error getting mcp session")
|
||||
return nil, err
|
||||
}
|
||||
return s, nil
|
||||
}
|
||||
|
||||
// attempt to create a session from an incoming idp token
|
||||
return config.NewIncomingIDPTokenSessionCreator(
|
||||
func(ctx context.Context, recordType, recordID string) (*databroker.Record, error) {
|
||||
return storage.GetDataBrokerRecord(ctx, recordType, recordID, 0)
|
||||
},
|
||||
func(ctx context.Context, records []*databroker.Record) error {
|
||||
_, err := a.state.Load().dataBrokerClient.Put(ctx, &databroker.PutRequest{
|
||||
Records: records,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
storage.InvalidateCacheForDataBrokerRecords(ctx, records...)
|
||||
return nil
|
||||
},
|
||||
).CreateSession(ctx, a.currentConfig.Load(), policy, hreq)
|
||||
}
|
||||
|
||||
func (a *Authorize) getMCPSession(
|
||||
ctx context.Context,
|
||||
hreq *http.Request,
|
||||
) (*session.Session, error) {
|
||||
auth := hreq.Header.Get(httputil.HeaderAuthorization)
|
||||
if auth == "" {
|
||||
return nil, fmt.Errorf("no authorization header was provided: %w", sessions.ErrNoSessionFound)
|
||||
}
|
||||
|
||||
prefix := "Bearer "
|
||||
if !strings.HasPrefix(strings.ToLower(auth), strings.ToLower(prefix)) {
|
||||
return nil, fmt.Errorf("authorization header does not start with %q: %w", prefix, sessions.ErrNoSessionFound)
|
||||
}
|
||||
|
||||
accessToken := auth[len(prefix):]
|
||||
sessionID, ok := a.mcp.Load().GetSessionIDFromAccessToken(ctx, accessToken)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no session found for access token: %w", sessions.ErrNoSessionFound)
|
||||
}
|
||||
|
||||
record, err := storage.GetDataBrokerRecord(ctx, grpcutil.GetTypeURL(new(session.Session)), sessionID, 0)
|
||||
if storage.IsNotFound(err) {
|
||||
return nil, fmt.Errorf("session databroker record not found: %w", sessions.ErrNoSessionFound)
|
||||
}
|
||||
|
||||
msg, err := record.GetData().UnmarshalNew()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error unmarshalling session: %w: %w", err, sessions.ErrNoSessionFound)
|
||||
}
|
||||
|
||||
s, ok := msg.(*session.Session)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("unexpected session type: %T: %w", msg, sessions.ErrNoSessionFound)
|
||||
}
|
||||
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func (a *Authorize) getEvaluatorRequestFromCheckRequest(
|
||||
ctx context.Context,
|
||||
in *envoy_service_auth_v3.CheckRequest,
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package mcp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"time"
|
||||
|
@ -116,3 +117,12 @@ func (srv *Handler) handleAuthorizationCodeToken(w http.ResponseWriter, r *http.
|
|||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write(data)
|
||||
}
|
||||
|
||||
func (srv *Handler) GetSessionIDFromAccessToken(ctx context.Context, accessToken string) (string, bool) {
|
||||
sessionID, err := DecryptAccessToken(accessToken, srv.cipher)
|
||||
if err != nil {
|
||||
log.Ctx(ctx).Error().Err(err).Msg("failed to decrypt access token")
|
||||
return "", false
|
||||
}
|
||||
return sessionID, true
|
||||
}
|
||||
|
|
|
@ -3,6 +3,7 @@ package mcp
|
|||
import (
|
||||
"crypto/cipher"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/oauth21"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/session"
|
||||
|
@ -28,6 +29,17 @@ func CheckPKCE(
|
|||
return nil
|
||||
}
|
||||
|
||||
// CreateAuthorizationCode creates an access token based on the session
|
||||
func CreateAccessToken(src *session.Session, cipher cipher.AEAD) (string, error) {
|
||||
return CreateCode(CodeTypeAccess, src.Id, src.ExpiresAt.AsTime(), "", cipher)
|
||||
}
|
||||
|
||||
// DecryptAuthorizationCode decrypts the authorization code and returns the underlying session ID
|
||||
func DecryptAccessToken(accessToken string, cipher cipher.AEAD) (string, error) {
|
||||
code, err := DecryptCode(CodeTypeAccess, accessToken, cipher, "", time.Now())
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return code.Id, nil
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue