package mcp

import (
	"fmt"
	"net/http"

	"golang.org/x/oauth2"
	"golang.org/x/sync/errgroup"

	"github.com/pomerium/pomerium/internal/log"
	oauth21proto "github.com/pomerium/pomerium/internal/oauth21/gen"
)

// OAuthCallback handles the redirect back from an upstream OAuth2 provider.
//
// It reads the "code" and "state" query parameters, using "state" as the ID
// of a previously stored authorization request. The code exchange and the
// authorization request lookup run concurrently; once both succeed, the
// resulting token is stored for the request's user and r.Host, and the
// response is delegated to AuthorizationResponse.
//
// It responds with 400 when either query parameter is missing, and with 500
// when the exchange, the lookup, or storing the token fails.
func (srv *Handler) OAuthCallback(w http.ResponseWriter, r *http.Request) {
	ctx := r.Context()

	// Parse the query string once and read both parameters from it.
	query := r.URL.Query()
	code := query.Get("code")
	authReqID := query.Get("state")
	if code == "" || authReqID == "" {
		http.Error(w, "Invalid callback request: missing code or state", http.StatusBadRequest)
		return
	}

	var (
		token   *oauth2.Token
		authReq *oauth21proto.AuthorizationRequest
	)

	// The group context gets its own name so it does not shadow the request
	// context: it is cancelled as soon as Wait returns and must only be used
	// by the two goroutines below.
	eg, egCtx := errgroup.WithContext(ctx)
	eg.Go(func() error {
		var err error
		token, err = srv.relyingParties.CodeExchangeForHost(egCtx, r.Host, code)
		if err != nil {
			return fmt.Errorf("oauth2: failed to exchange code: %w", err)
		}
		return nil
	})
	eg.Go(func() error {
		var err error
		authReq, err = srv.storage.GetAuthorizationRequest(egCtx, authReqID)
		if err != nil {
			return fmt.Errorf("failed to get authorization request: %w", err)
		}
		return nil
	})
	if err := eg.Wait(); err != nil {
		// Either goroutine may have failed; the wrapped error says which one.
		// The log message therefore names both operations, while the HTTP
		// response text is kept as-is for existing clients.
		log.Ctx(ctx).Error().Err(err).Msg("failed to exchange code or load authorization request")
		http.Error(w, "Failed to exchange code", http.StatusInternalServerError)
		return
	}

	if err := srv.storage.StoreUpstreamOAuth2Token(ctx, authReq.UserId, r.Host, OAuth2TokenToPB(token)); err != nil {
		log.Ctx(ctx).Error().Err(err).Msg("failed to store upstream oauth2 token")
		http.Error(w, "Failed to store upstream oauth2 token", http.StatusInternalServerError)
		return
	}

	srv.AuthorizationResponse(ctx, w, r, authReqID, authReq)
}