[1/n] refactor the ring attention implementation #155441


Closed
wants to merge 2 commits

Conversation

wanchaol
Collaborator

@wanchaol wanchaol commented Jun 9, 2025

Stack from ghstack (oldest at bottom):

as titled, I'm working on a series of changes to make the ring attention
implementation and DTensor work better together. This PR specifically
refactors the current implementation to:

  • remove dead/unused code
  • restructure the functions so they stay organized
  • remove or improve error messages

cc @H-Huang @awgu @fegin @fduwjj @wz337 @wconstab @d4l3k

[ghstack-poisoned]

pytorch-bot bot commented Jun 9, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/155441

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEV

There is 1 currently active SEV. If your PR is affected, please view it below:

❌ 1 Cancelled Job

As of commit 6011e48 with merge base 0f56318:

CANCELLED JOB - The following job was cancelled. Please retry:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added the ciflow/inductor and oncall: distributed labels Jun 9, 2025
@wanchaol wanchaol added the topic: not user facing label Jun 9, 2025
[ghstack-poisoned]
@wanchaol wanchaol requested a review from fegin June 9, 2025 21:00
@pytorchmergebot
Collaborator

Starting merge as part of PR stack under #155442

pytorchmergebot pushed a commit that referenced this pull request Jun 27, 2025
This PR rewrites how load balancing and sharding work in the current
context parallel implementation.

Why the changes? We should NOT expose another layer of "sharding" as a
concept, since it would confuse users about how it differs from DTensor
sharding. The current CP implementation shards in an odd way simply because
it mixes the concepts of load balancing and sharding.

I think load balancing and sharding need to be decoupled into separate
layers (see the index-generation sketch after this list):

* The load balancing layer is responsible for reordering the input sequence
so that the attention computation is evenly balanced across rows/ranks.
* Sharding is a separate layer after it: it simply takes the input reordered
by the load balancer and shards it exactly as DTensor shards a tensor
sequentially.
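
To make the load-balancing layer concrete, here is a minimal sketch of
round-robin index generation for a causal mask; the function name
`round_robin_indices` and the exact chunk pairing are illustrative
assumptions, not code from this PR stack:

```python
import torch

def round_robin_indices(seq_len: int, world_size: int) -> torch.Tensor:
    # Hypothetical helper, not from this PR. Under a causal mask, later
    # positions attend to more keys, so a plain sequential split leaves the
    # last rank with the most work. Splitting the sequence into
    # 2 * world_size chunks and pairing chunk r with chunk
    # (2 * world_size - 1 - r) gives every rank one cheap and one expensive
    # chunk, evening out the attention FLOPs.
    assert seq_len % (2 * world_size) == 0, "sequence must split evenly"
    chunks = torch.arange(seq_len).chunk(2 * world_size)
    order = []
    for r in range(world_size):
        order.append(chunks[r])                       # cheap (early) chunk
        order.append(chunks[2 * world_size - 1 - r])  # expensive (late) chunk
    return torch.cat(order)
```

Reordering the sequence dimension with these indices and then sharding
sequentially lands chunks r and 2 * world_size - 1 - r on rank r.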

In this PR:
* I removed the mixed usage of "Sharder" and "LoadBalancer", and simply
generate round-robin indices when the mask is a causal mask
* use `distribute_tensor` to perform the sharding (see the sketch below). We
still keep the local shard instead of the DTensor objects to allow maximum
compatibility with arbitrary model architectures, given that DTensor op
coverage is not high enough.
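
As a rough illustration of that sharding step, here is a minimal sketch
using the public DTensor APIs; `world_size`, `seq_dim`, and `x_balanced`
(the load-balanced input from the previous sketch) are assumed to be
defined, and this is not the PR's actual code:

```python
import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Shard, distribute_tensor

# Assumes the process group is already initialized (e.g. via torchrun) and
# that x_balanced is the input tensor after the load-balancing reorder.
mesh = init_device_mesh("cuda", (world_size,))

# Shard sequentially on the sequence dimension, exactly as DTensor would
# shard any other tensor.
dt = distribute_tensor(x_balanced, mesh, [Shard(seq_dim)])

# Hand the plain local shard to the model rather than the DTensor itself,
# so arbitrary architectures work even where DTensor op coverage is thin.
local_chunk = dt.to_local()
```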

One alternative design is to still keep the LoadBalancer and make index
generation and restoration the protocol of the LoadBalancer. I thought it
through and think we might want to expose the load-balancing indices
directly as an argument instead of through a dedicated class interface, so
I removed it here (both shapes are sketched below). More discussion on this
is welcome.
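
For concreteness, the two interface shapes under discussion might look
roughly like this; both signatures are hypothetical sketches, not APIs that
exist in PyTorch:

```python
from typing import Protocol

import torch

# Shape taken here (hypothetical signature): callers pass the reorder
# indices directly as a plain argument, e.g.
#   out = ring_attention(q, k, v, load_balance_indices=indices)

# Rejected alternative (sketch): a dedicated LoadBalancer protocol that
# owns both index generation and restoring the original order on the output.
class LoadBalancer(Protocol):
    def balance_indices(self, seq_len: int) -> torch.Tensor: ...
    def restore(self, out: torch.Tensor) -> torch.Tensor: ...
```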

Pull Request resolved: #155442
Approved by: https://github.com/XilunWu
ghstack dependencies: #155441
superiwan pushed a commit to superiwan/pytorch that referenced this pull request Jul 14, 2025
ghstack-source-id: e6c861b
Pull-Request-resolved: pytorch/pytorch#155441
Labels: ciflow/inductor, Merged, oncall: distributed, open source, topic: not user facing