[cp] dispatch flex_attention to CP impl in TorchDispatchMode #151497
base: gh/XilunWu/133/base
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/151497. Note: links to docs will display an error until the docs builds have completed.

✅ You can merge normally! (4 unrelated failures.) As of commit 362b6e8 with merge base ba51f48. BROKEN TRUNK: the following jobs failed but were also present on the merge base. 👉 Rebase onto the `viable/strict` branch to avoid these failures.

This comment was automatically generated by Dr. CI and updates every 15 minutes.
…rt to DTensor in FakeTensorMode"

Stack from [ghstack](https://github.com/ezyang/ghstack) (oldest at bottom):
* #151497
* #151496
* __->__ #151495

## Introduction

`flex_attention`'s FakeTensor propagation `flex_attention_fake_impl` [permutes](https://github.com/pytorch/pytorch/blob/fb6ac2f16132f7953711ce6924bc2ee4a033228c/torch/_higher_order_ops/flex_attention.py#L459) the stride of `out` (the attention score) based on `query`'s stride. To enable calling `flex_attention` on DTensor, this requires us to add `as_strided` support on DTensor in `FakeTensorMode`.

## Limited Support

Due to the complexity of supporting a full `as_strided` on DTensor, I chose to enable only a limited subset:
1. `as_strided` only works correctly in `FakeTensorMode`, i.e. shape and stride propagation.
2. `as_strided` is only allowed in the case where `size == input.shape`, because this PR specifically unblocks the use case of `flex_attention_fake_impl`.
3. `as_strided` requires `storage_offset=None` because the other case is not defined for DTensor.

## Test

`pytest test/distributed/tensor/test_view_ops.py -s -k test_as_strided`

cc H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k tianyu-l
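As a rough illustration of the restriction above (a hypothetical helper plus a plain-tensor example, not the actual DTensor code), the fake-mode-only `as_strided` effectively reduces to reassigning strides on a tensor whose shape is unchanged, which is exactly what `flex_attention_fake_impl` needs:

```python
import torch

def _limited_as_strided_ok(inp, size, stride, storage_offset=None) -> bool:
    # Hypothetical check mirroring the three rules above: only same-shape
    # restriding with no explicit storage_offset is accepted
    # (the requested stride itself is not constrained).
    return list(size) == list(inp.shape) and storage_offset is None

# flex_attention_fake_impl restrides `out` to follow `query`'s stride order
# without changing its shape, so it falls inside the limited subset:
q = torch.empty(2, 8, 16, 64).permute(0, 2, 1, 3)      # non-contiguous query
out = torch.empty_like(q, memory_format=torch.contiguous_format)
assert _limited_as_strided_ok(out, out.shape, q.stride())
restrided = out.as_strided(out.shape, q.stride())
assert restrided.stride() == q.stride()
```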
With #151719 (and a few changes to DTensor's …
## Test

`pytest test/distributed/tensor/test_attention.py -s -k test_ring_flex_attention`

cc H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k
Looks pretty good, let me know when it's ready for another review.
Stamp to unblock, please address the comments before landing.
```python
def _context_parallel_buffers(
    mesh: DeviceMesh,
    buffers: list[torch.Tensor],
    buffer_seq_dims: list[int],
    sharder: Optional[Union[_Sharder, type[_Sharder]]] = None,
```
If FlexAttention expects users to always pass a `_Sharder` object instead of a class, we should just require users to do so: either `None` for SDPA to always use `_RoundRobinSharder`, or users have to pass in a `_Sharder` object.
For public APIs (i.e. `context_parallel` and `context_parallel_unshard`), I think it's nice to follow this logic. For private APIs (i.e. `_context_parallel`, `_context_parallel_buffers`), I suggest we enforce that `sharder` is not `None`. WDYT?
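A small sketch of the convention being suggested (the `_Sharder`/`_RoundRobinSharder` names come from this PR, but the stub classes and the normalization helper below are purely illustrative):

```python
from typing import Optional, Union

class _Sharder:                        # stand-in for the PR's _Sharder base class
    ...

class _RoundRobinSharder(_Sharder):    # stand-in for the PR's default SDPA sharder
    ...

def _resolve_sharder(
    sharder: Optional[Union[_Sharder, type[_Sharder]]],
    *,
    private_api: bool,
) -> _Sharder:
    """Hypothetical normalization reflecting the suggestion above."""
    if sharder is None:
        if private_api:
            # private helpers (_context_parallel, _context_parallel_buffers)
            # would require an explicit sharder
            raise ValueError("private CP APIs require an explicit _Sharder instance")
        # public APIs could keep None as "use round-robin for SDPA"
        return _RoundRobinSharder()
    if isinstance(sharder, type):
        # accept only instances, not the class itself
        raise TypeError("pass a _Sharder instance, not the _Sharder class")
    return sharder
```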
@pytorchbot rebase

@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here.

Successfully rebased.
Requesting changes for a few reasons:
- For a relatively complicated change like this, please update your PR summary with the design so that reviewers can read it, instead of only leaving a test command.
- From the API side, I don't think the CP APIs should leak the "Sharder" concept to users; it is confusing with other concepts (i.e. the Shard placement). I think it should not actually be a "Sharder": it is mainly a way to achieve load balancing, and the API design needs some further discussion.
```python
cp_v = v.detach().clone()

# create sharder
sharder = _FlexAttentionSequentialSharder(device_mesh, block_mask=block_mask)
```
I think exposing a way for users to control how the mask is partitioned is fine. But I don't think we should do it in the current way; specifically, I don't think we should leak the concept of a "Sharder", as it would confuse users about its relationship with the Shard placement in DTensor.
```python
# register flex_attention HOP to CommDebugMode
@flex_attention_hop.py_impl(CommDebugMode)
```
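For reference, this follows the usual higher-order-op pattern of keying an implementation on a `TorchDispatchMode` subclass via `py_impl`. A minimal sketch of the registration shape (import paths assumed; the pass-through body is illustrative, not this PR's actual handler):

```python
from torch._higher_order_ops.flex_attention import flex_attention as flex_attention_hop
from torch.distributed.tensor.debug import CommDebugMode

@flex_attention_hop.py_impl(CommDebugMode)
def flex_attention_comm_debug_mode(mode, *args, **kwargs):
    # The mode is temporarily popped off the dispatch stack while this handler
    # runs, so re-invoking the HOP dispatches below CommDebugMode: a plain
    # pass-through. Without a registration like this, CommDebugMode would have
    # no rule for the flex_attention HOP.
    return flex_attention_hop(*args, **kwargs)
```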
Why does one need to register a HOP here?
Regarding the naming issue, we can discuss alternative names if … To clarify, is the concern related to the naming of …? An alternative is to make configuring …
I very much believe we do not want to rely on auto load balancing for all users, and we need a way for users to explicitly partition …
@fegin @drisspg I am not debating whether to expose a way to do load balancing. I think it could be very valuable to allow users (especially advanced users) to control how to do load balancing instead of fully relying on auto load balancing. But we should NOT leak any sort of sharding concept (or the input partition concept you mentioned) to users; it would confuse users about its relationship with DTensor sharding, why there is another sharding method, etc. So yes, the concern is mainly the exposure of the sharding (or, as you call it, input_partition) concept. I think we should not mix the concepts of load balancing and sharding.
It is not just a naming issue. I don't think this is "unrelated to DTensor because we don't represent the input batch as a DTensor": context parallelism is a prototype feature developed within the dtensor module and it uses quite a few concepts from DTensor. It is critical to make sure CP works well with DTensor. IMO, if it does not work well with DTensor and does not provide good abstraction layers, I don't know whether there is materialized value for users in our solution compared to other OSS solutions or even writing their own (i.e. if we are just allgathering k/v, relying on the user to define how to shard in detail, and not doing anything else, doing it manually is even better for the user).

To be concrete, the current interaction between DTensor and CP is very weird. Although the Sharder/input_partition does not use DTensor (this is debatable; what if a user just wants to pass their own DTensor to this flex attention path?) and we don't represent the input batch as a DTensor, the current way to dispatch F.scaled_dot_product_attention to ring attention (not this flex attention impl) is by constructing DTensor inputs sharded on the sequence dim and dispatching the op. But when using the RoundRobinSharder for load balancing, we should NOT construct a DTensor anymore, because the RoundRobinSharder does not align with the …
@wanchaol I agree that the current CP doesn't work well with DTensor. But I don't agree that load balancing should be decoupled from the input sharding. How the input is sharded and arranged affects the load balancing. Even though we simply perform allgather on KV, Q is not gathered, and how Q is sharded is decided by the input sharding. The use case you mentioned is definitely an important and advanced one. However, a collective (likely an alltoall) for Q is required for that case, even for the ring attention implementation. IMO, we should not sacrifice the performance of the basic use case. I definitely agree this is a direction we should solve, but I will argue that we still need to pass the information of …

IMO, what @kwen2501 proposed may be one way to achieve good DTensor semantics for load balancing: we let DTensor support a special shard (a sharding that can represent rank 0 having chunk 0 and chunk N, rank 1 having chunk 1 and chunk N-1, ...); I don't know a suitable name, bit-map-annotation shard? With this representation, users of the current CP use case can simply partition the input batch using this bit-map-annotation shard to achieve load balancing without an extra alltoall. Users can also opt not to shard the input in this way if they have a special attention use case and are okay with doing an alltoall. Either way, this sharding information serves as a carrier to pass the information CP needs to perform load balancing.

Could you give me some high-level idea of how you imagine decoupling load balancing from sharding? I may have misinterpreted what you mean by load balancing being decoupled from sharding.
@fegin Oh, IIUC even allgathering KV requires one to know how the load balancing is performed; it is not only Q that is decided by load balancing. So if users define their own load balancing approach, it might be easier for them to shard/allgather QKV themselves. I do think load balancing and input sharding are two separate concepts, though maybe not totally orthogonal.
I have a pretty strong opinion that we should NOT add a new special Shard placement that cannot be generalized, as that would significantly complicate DTensor without bringing much value (i.e. the current round-robin is just one way to do load balancing; there could be many other ways, so it would not generalize well). I also feel there is actually no gap in DTensor for implementing CP today, i.e. even in today's state, I could refactor how CP works with SDPA or flex attention without touching any core parts of DTensor and still provide a good UX. It is more a question of how CP and its load balancing functionality should be designed.
Sure, so the high-level idea is: load balancing and sharding need to be decoupled into separate layers. The load balancing layer is responsible for reordering the input sequence so that the attention computation is evenly balanced across rows/ranks. Sharding is a separate layer after it: it simply takes the input reordered by the load balancer and shards it exactly as DTensor shards tensors sequentially today. DTensor does not need to be aware of load balancing at all, since load balancing is a separate layer that is exposed to and possibly controlled by users. I actually had a draft refactor while I was refactoring the current ring attention implementation; I'll publish the changes soon :)
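A rough sketch of that layering under the assumptions above (the reorder below is the classic 2*world_size round-robin used for causal attention; the function name and exact scheme are illustrative, not the actual refactor):

```python
import torch

def round_robin_reorder(x: torch.Tensor, seq_dim: int, world_size: int) -> torch.Tensor:
    # Load-balancing layer: reorder the sequence so that a plain contiguous
    # Shard(seq_dim) afterwards gives rank r one "early" chunk (r) and one
    # "late" chunk (2*world_size - 1 - r), balancing causal-attention work.
    # Assumes the sequence length is divisible by 2 * world_size.
    chunks = x.chunk(2 * world_size, dim=seq_dim)
    order = []
    for r in range(world_size):
        order += [r, 2 * world_size - 1 - r]
    return torch.cat([chunks[i] for i in order], dim=seq_dim)

# Sharding layer: after the reorder, sharding is just the ordinary contiguous
# DTensor sharding, e.g. distribute_tensor(reordered, mesh, [Shard(seq_dim)]);
# DTensor itself stays unaware of the load-balancing step.
reordered = round_robin_reorder(torch.arange(16).view(1, 16), seq_dim=1, world_size=2)
print(reordered)  # chunk order 0, 3, 1, 2  ->  [0..3, 12..15, 4..7, 8..11]
```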
ghstack-source-id: deeedba Pull Request resolved: pytorch/pytorch#151497
ghstack-source-id: ed62ca7 Pull Request resolved: pytorch/pytorch#151497
Stack from ghstack (oldest at bottom):
## Test

`pytest test/distributed/tensor/test_attention.py -s -k test_ring_flex_attention`
cc @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k