[cp] dispatch flex_attention to CP impl in TorchDispatchMode #151497

Open

wants to merge 21 commits into base: gh/XilunWu/133/base

Conversation

XilunWu
Contributor

@XilunWu XilunWu commented Apr 17, 2025

[ghstack-poisoned]

pytorch-bot bot commented Apr 17, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/151497

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (4 Unrelated Failures)

As of commit 362b6e8 with merge base ba51f48:

BROKEN TRUNK - The following jobs failed but were present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added oncall: distributed and topic: not user facing labels Apr 17, 2025
XilunWu added a commit that referenced this pull request Apr 17, 2025
ghstack-source-id: 1b88275
Pull Request resolved: #151497
@XilunWu XilunWu marked this pull request as draft April 17, 2025 01:15
cc H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k

[ghstack-poisoned]
XilunWu added a commit that referenced this pull request Apr 17, 2025
…rt to DTensor in FakeTensorMode"

Stack from [ghstack](https://github.com/ezyang/ghstack) (oldest at bottom):

*  #151497
*  #151496
* __->__ #151495

## Introduction
`flex_attention`'s FakeTensor propagation `flex_attention_fake_impl` [permutes](https://github.com/pytorch/pytorch/blob/fb6ac2f16132f7953711ce6924bc2ee4a033228c/torch/_higher_order_ops/flex_attention.py#L459) the stride of `out` (the attention output) based on `query`'s stride. To enable calling `flex_attention` on DTensor, we need to add `as_strided` support on DTensor in `FakeTensorMode`.

## Limited Support
Due to the complexity of supporting actual `as_strided` on DTensor, I choose to only enable a limited subset (sketched below):
1. `as_strided` only works correctly in `FakeTensorMode`, i.e. shape and stride propagation.
2. `as_strided` is only allowed in the case where `size == input.shape`, because this PR specifically unblocks the use case of `flex_attention_fake_impl`.
3. `as_strided` requires `storage_offset=None` because the other case is not defined for DTensor.
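
A minimal sketch of what these restrictions amount to, using a hypothetical helper name (illustrative only, not the PR's actual code):

```python
from typing import Optional, Sequence

def _check_limited_as_strided(
    input_shape: Sequence[int],
    size: Sequence[int],
    stride: Sequence[int],
    storage_offset: Optional[int] = None,
) -> None:
    # (2) Only the same-size case is supported; this is all that
    #     flex_attention_fake_impl needs in order to permute output strides.
    if tuple(size) != tuple(input_shape):
        raise NotImplementedError("DTensor as_strided only supports size == input.shape")
    # (3) A non-None storage_offset is not defined for DTensor here.
    if storage_offset is not None:
        raise NotImplementedError("DTensor as_strided requires storage_offset=None")
    # (1) Even when these checks pass, the result is only meaningful for
    #     shape/stride propagation under FakeTensorMode, not for real data movement.
```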

## Test
`pytest test/distributed/tensor/test_view_ops.py -s -k test_as_strided`

cc H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k tianyu-l

[ghstack-poisoned]
XilunWu added a commit that referenced this pull request Apr 17, 2025
ghstack-source-id: 126968f
Pull Request resolved: #151497
XilunWu added a commit that referenced this pull request Apr 18, 2025
ghstack-source-id: b60a89c
Pull Request resolved: #151497
@bdhirsh
Contributor

bdhirsh commented Apr 18, 2025

With #151719 (and a few changes to DTensor's flex_attention handling, see here: P1789852198), I could run the E2E ring attention tests locally!

## Test
`pytest test/distributed/tensor/test_attention.py -s -k test_ring_flex_attention`

XilunWu added a commit that referenced this pull request Apr 22, 2025
@drisspg
Contributor

drisspg commented May 12, 2025

Looks pretty good, let me know when it's ready for another review.

@XilunWu XilunWu requested review from fegin and drisspg May 21, 2025 23:08
Contributor

@fegin fegin left a comment


Stamp to unblock, please address the comments before landing.



def _context_parallel_buffers(
    mesh: DeviceMesh,
    buffers: list[torch.Tensor],
    buffer_seq_dims: list[int],
    sharder: Optional[Union[_Sharder, type[_Sharder]]] = None,
Contributor


If FlexAttention expects users to always pass a _Sharder object instead of a class, we should just require users to do so -- either None for SDPA to always use _RoundRobinSharder, or users have to pass in a _Sharder object.

Contributor Author


For public APIs (i.e. context_parallel and context_parallel_unshard), I think it's nice to follow this logic.

For private APIs (i.e. _context_parallel, _context_parallel_buffers), I suggest we enforce that sharder is not None.

WDYT?
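
A minimal sketch of the split proposed here, with stub classes standing in for the real _Sharder/_RoundRobinSharder types (the resolution helper is hypothetical, not code from this PR):

```python
from typing import Optional

class _Sharder:                      # stand-in for the _Sharder type in the diff
    ...

class _RoundRobinSharder(_Sharder):  # stand-in for the default SDPA load balancer
    ...

def _resolve_sharder(sharder: Optional[_Sharder]) -> _Sharder:
    # Public APIs (context_parallel / context_parallel_unshard): None is allowed
    # and falls back to the built-in round-robin behavior for SDPA.
    return sharder if sharder is not None else _RoundRobinSharder()

def _context_parallel_buffers_strict(mesh, buffers, buffer_seq_dims, sharder: _Sharder):
    # Private APIs (_context_parallel / _context_parallel_buffers): require an
    # already-constructed _Sharder instance -- never None, never a class.
    assert isinstance(sharder, _Sharder), "private CP helpers expect a _Sharder instance"
```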

XilunWu added 3 commits May 27, 2025 15:51
@XilunWu
Contributor Author

XilunWu commented May 30, 2025

@pytorchbot rebase

@XilunWu XilunWu added the ciflow/trunk label May 30, 2025
@pytorchmergebot
Collaborator

@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here

@pytorchmergebot
Collaborator

Successfully rebased gh/XilunWu/133/orig onto refs/remotes/origin/viable/strict, please pull locally before adding more changes (for example, via ghstack checkout https://github.com/pytorch/pytorch/pull/151497)

XilunWu added 2 commits May 29, 2025 22:49
Collaborator

@wanchaol wanchaol left a comment


Requesting changes for a few reasons:

  • For a relatively complicated change like this, please update your PR summary with the design so that reviewers can read it, instead of only leaving a test command.
  • From the API side, I don't think the CP APIs should leak the "Sharder" concept to users; it is easily confused with other concepts (i.e. the Shard placement). I don't think it should actually be a "Sharder"; it is mainly a way to achieve load balancing, and the API design needs some further discussion.

cp_v = v.detach().clone()

# create sharder
sharder = _FlexAttentionSequentialSharder(device_mesh, block_mask=block_mask)
Collaborator


I think exposing a way for users to control how the mask is partitioned is fine. But I don't think we should do it in the current way; specifically, I don't think we should leak the concept of "Sharder", as it would confuse users about its relationship with the Shard placement in DTensor.



# register flex_attention HOP to CommDebugMode
@flex_attention_hop.py_impl(CommDebugMode)
Collaborator


Why does one need to register a HOP here?

@fegin
Contributor

fegin commented Jun 2, 2025

@wanchaol

From the API side, I don't think the CP APIs should leak the "Sharder" concept to users; it is easily confused with other concepts (i.e. the Shard placement). I don't think it should actually be a "Sharder"; it is mainly a way to achieve load balancing, and the API design needs some further discussion.

Regarding the naming issue, we can discuss alternative names if Sharder is not suitable. However, it's essential to note that Sharder (or any other name) in this context refers specifically to partitioning the input batch along the sequence dimension to achieve load balancing, which is unrelated to DTensor as we don't represent the input batch as a DTensor.

To clarify, is the concern related to the naming of Sharder or the exposure of the input partition concept itself? The input partition needs to be exposed, as only users know how to achieve good load balancing given a customized BlockMask. Although we also plan to do automatic load balancing based on the mask, some advanced users have expressed the need to manually control the partition.

An alternative is to make configuring Sharder an opt-in API that users can call before CP. That way, most users can just use the built-in feature without worrying about the load balancing concept; in that case, users will just pass in the global BlockMask. If users would like to perform their own partitioning, they can pass the Sharder instead of the BlockMask.
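
A hypothetical sketch of that opt-in shape (the names and the mask-to-sharder conversion are assumptions, not an API from this PR):

```python
from typing import Union

class BlockMask:      # stand-in for torch.nn.attention.flex_attention.BlockMask
    ...

class _Sharder:       # stand-in for the load-balancing partitioner discussed above
    ...

class _MaskDerivedSharder(_Sharder):
    """Built-in path: derive the sequence partition from the global mask."""
    def __init__(self, block_mask: BlockMask) -> None:
        self.block_mask = block_mask

def configure_load_balancing(mask_or_sharder: Union[BlockMask, _Sharder]) -> _Sharder:
    # Advanced users opt in by passing a _Sharder and control partitioning directly.
    if isinstance(mask_or_sharder, _Sharder):
        return mask_or_sharder
    # Most users only hand over the global BlockMask and get the built-in behavior.
    return _MaskDerivedSharder(mask_or_sharder)
```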

@XilunWu Can you update the summary as @wanchaol mentioned?

@drisspg
Contributor

drisspg commented Jun 3, 2025

I very much believe we don't want to rely on auto load balancing for all users; we need a way for users to partition explicitly.

@wanchaol
Collaborator

wanchaol commented Jun 3, 2025

@fegin @drisspg I am not debating whether to expose a way to do load balancing; I think it could be very valuable to allow users (especially advanced users) to control how load balancing is done instead of fully relying on auto load balancing. But we should NOT leak any sort of sharding concept (or the input partition concept, as you called it) to users; it would confuse users about its relationship with DTensor sharding, why there's another sharding method, etc. So yes, the concern is mainly the exposure of the sharding (or input_partition) concept; I think we should not mix the concepts of load balancing and sharding.

Regarding the naming issue, we can discuss alternative names if Sharder is not suitable. However, it's essential to note that Sharder (or any other name) in this context refers specifically to partitioning the input batch along the sequence dimension to achieve load balancing, which is unrelated to DTensor as we don't represent the input batch as a DTensor.

It is not just a naming issue. I don't think this is "unrelated to DTensor as we don't represent the input batch as a DTensor": context parallelism is a prototype feature developed within the dtensor module and it uses quite a few concepts from DTensor. It is critical to make sure CP works well with DTensor. IMO, if it doesn't work well with DTensor and doesn't provide good abstraction layers, I don't know whether there is materialized value for users in choosing our solution over other OSS solutions or even writing their own (i.e. if we just allgather k/v, rely on the user to define the sharding details, and do nothing else, doing it manually is even better for the user).

To be concrete, the current interaction between DTensor and CP is very weird. Although the Sharder/input_partition does not use DTensor (this is debatable: what if a user just wants to pass their own DTensor to this flex attention path?) and we don't represent the input batch as a DTensor, the current way to dispatch to F.scaled_dot_product_attention to run ring attention (not this flex attention impl) is by constructing DTensor inputs sharded on the sequence dim and dispatching the op to the SDPA function. But when using the RoundRobinSharder for load balancing, we should NOT construct DTensors anymore, since the RoundRobinSharder does not align with the Shard placement behavior, so the constructed DTensors are actually not valid. This tells me that the current Sharder abstraction in CP does not work correctly, and this is one main reason I think we should not expose it as-is in the current way :)

@fegin
Contributor

fegin commented Jun 3, 2025

@wanchaol I agree that the current CP doesn't work well with DTensor. But I don't agree that load balancing should be decoupled from the input sharding. How the input is sharded and arranged will affect the load balancing. Even though we simply perform allgather on KV, Q is not gathered, and how Q is sharded is decided by the input sharding.

The use case you mentioned is definitely an important and advanced one. However, a collective (likely an alltoall) for Q is required for this case, even for the ring attention implementation. IMO, we should not sacrifice the performance of the basic use case. I definitely agree this is a direction we should solve, but I will argue that we still need to pass the information about how the input is partitioned to CP so that CP knows how to perform load balancing (reshuffling, or alltoall for Q if needed).

IMO, what @kwen2501 proposed may be one way to achieve good DTensor semantics for load balancing: we let DTensor support a special shard (a sharding that can represent that rank 0 has chunk 0 and chunk N, rank 1 has chunk 1 and chunk N-1, ...). I don't know a suitable name; a bit-map-annotation shard?

With this representation, users of the current CP use case can simply partition the input batch using this bit-map-annotation shard to achieve load balancing without an extra alltoall. Users can also opt not to shard the input this way if they have a special attention use case and are okay with doing an alltoall. Either way, this sharding information serves as a carrier to pass CP the information it needs to perform load balancing.
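
A small sketch of the chunk-to-rank assignment such a placement would encode, assuming the usual head/tail pairing used for causal load balancing (the exact pairing in the proposal may differ):

```python
def head_tail_chunk_assignment(world_size: int) -> dict[int, list[int]]:
    # Split the sequence into 2 * world_size chunks; rank r holds chunk r and
    # chunk (2 * world_size - 1 - r), pairing a cheap early chunk with an
    # expensive late chunk under a causal mask.
    return {r: [r, 2 * world_size - 1 - r] for r in range(world_size)}

# world_size=4 -> {0: [0, 7], 1: [1, 6], 2: [2, 5], 3: [3, 4]}
print(head_tail_chunk_assignment(4))
```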

Could you give me a high-level idea of how you imagine decoupling load balancing from sharding? I may be misinterpreting what you mean about load balancing being decoupled from sharding.

@fegin
Contributor

fegin commented Jun 3, 2025

@wanchaol As for the dispatcher, we can ask users to wrap SDPA/FlexAttention inside a module and use parallelize_module as we discussed. I don't have a strong opinion about this. cc @XilunWu

@wanchaol
Collaborator

wanchaol commented Jun 5, 2025

But I don't agree with you that load balancing should be decoupled from the input sharding. How the input is sharded and arranged will affect the load balancing. Even though we just simply perform allgather on KV, Q is not gathered and how Q is sharded is decided by the input sharding.

@fegin Oh, IIUC even allgathering KV requires one to know how the load balancing is performed; it's not only Q that is decided by load balancing. So if users define their own load balancing approach, it might be easier for them to shard/allgather QKV. I do think load balancing and input sharding are two separate concepts, though maybe not totally orthogonal.

IMO, we should not sacrifice the performance of the basic use case. I definitely agree this is a direction we should solve, but I will argue that this is the gap of DTensor.
IMO, what @kwen2501 proposed is the right solution to support both with a good UX -- we let DTensor to support a special shard (a sharding that can represent rank0 has the chunk0 and chunk N and rank has the chunk 1 and chunk N -1, ...), I don't know a suitable name, bit-map-annotation shard?

I have a pretty strong opinion that we should NOT add a new special Shard placement that cannot be generalized, as that would significantly complicate DTensor without bringing much value (i.e. the current round-robin is just one way to help load balancing; there could be many other ways, so it would not generalize well). I also feel there is actually no gap in DTensor for implementing CP today; i.e., even in today's state, I could refactor how CP works with SDPA or flex attention without touching any core parts of DTensor, and still provide a good UX. It is more a question of how CP and its load balancing functionality should be designed.

Could you give me a high-level idea of how you imagine decoupling load balancing from sharding? I may be misinterpreting what you mean about load balancing being decoupled from sharding.

Sure, so the high-level idea is: load balancing and sharding need to be decoupled into separate layers. The load balancing layer is responsible for reordering the input sequence so that the attention computation is evenly balanced across rows/ranks. Sharding is a separate layer after it: it simply takes the input reordered by the load balancer and shards it exactly how DTensor shards tensors sequentially today. DTensor does not need to be aware of load balancing at all, as load balancing is a separate layer that is exposed to and possibly controlled by users. I actually had a draft refactor while I was refactoring the current ring attention implementation; I'll publish the changes soon :)
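
A minimal sketch of that layering, with hypothetical function names (the reorder policy here is a simple head/tail pairing; the point is only that load balancing and sharding are separate steps):

```python
import torch

def load_balance_reorder(x: torch.Tensor, seq_dim: int, world_size: int) -> torch.Tensor:
    # Layer 1: reorder the sequence so per-rank causal attention work is even.
    chunks = x.chunk(2 * world_size, dim=seq_dim)
    order = [i for r in range(world_size) for i in (r, 2 * world_size - 1 - r)]
    return torch.cat([chunks[i] for i in order], dim=seq_dim)

def shard_sequentially(x: torch.Tensor, seq_dim: int, rank: int, world_size: int) -> torch.Tensor:
    # Layer 2: plain contiguous sharding, exactly what a Shard(seq_dim) placement
    # does today; it never needs to know how layer 1 reordered the sequence.
    return x.chunk(world_size, dim=seq_dim)[rank]

q = torch.randn(2, 8, 1024, 64)                       # [batch, heads, seq, head_dim]
balanced = load_balance_reorder(q, seq_dim=2, world_size=4)
local_q = shard_sequentially(balanced, seq_dim=2, rank=0, world_size=4)
```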

superiwan pushed a commit to superiwan/pytorch that referenced this pull request Jul 14, 2025
superiwan pushed a commit to superiwan/pytorch that referenced this pull request Jul 14, 2025
ghstack-source-id: ed62ca7
Pull Request resolved: pytorch/pytorch#151497
Labels
ciflow/inductor, ciflow/trunk, module: context parallel, oncall: distributed, release notes: context parallel