
[cp] dispatch flex_attention to CP impl in TorchDispatchMode #151497

Open

wants to merge 21 commits into base: gh/XilunWu/133/base

Changes from 1 commit

Commits (21)
6bfb512  DTensor HOP dispatch (XilunWu, Apr 17, 2025)
0e1171f  Update on "DTensor HOP dispatch" (XilunWu, Apr 17, 2025)
efc5e67  Update on "DTensor HOP dispatch" (XilunWu, Apr 17, 2025)
4f8e9d4  Update on "DTensor HOP dispatch" (XilunWu, Apr 18, 2025)
19ac1ce  Update on "DTensor HOP dispatch" (XilunWu, Apr 18, 2025)
9ecc674  Update on "DTensor HOP dispatch" (XilunWu, Apr 22, 2025)
c71ff48  Update (XilunWu, Apr 24, 2025)
2836766  Update (XilunWu, Apr 24, 2025)
103f0ef  Update (XilunWu, Apr 24, 2025)
bdda34a  Update on "[cp] dispatch flex_attention to CP impl in TorchDispatchMode" (XilunWu, Apr 27, 2025)
929f157  Update on "[cp] dispatch flex_attention to CP impl in TorchDispatchMode" (XilunWu, Apr 27, 2025)
0ba9169  Update on "[cp] dispatch flex_attention to CP impl in TorchDispatchMode" (XilunWu, Apr 28, 2025)
644ec19  Update on "[cp] dispatch flex_attention to CP impl in TorchDispatchMode" (XilunWu, May 9, 2025)
048010a  Update on "[cp] dispatch flex_attention to CP impl in TorchDispatchMode" (XilunWu, May 9, 2025)
6ad5a79  Update on "[cp] dispatch flex_attention to CP impl in TorchDispatchMode" (XilunWu, May 21, 2025)
bced5f0  Update on "[cp] dispatch flex_attention to CP impl in TorchDispatchMode" (XilunWu, May 27, 2025)
177f177  Update on "[cp] dispatch flex_attention to CP impl in TorchDispatchMode" (XilunWu, May 27, 2025)
446fb80  Update on "[cp] dispatch flex_attention to CP impl in TorchDispatchMode" (XilunWu, May 29, 2025)
ddd1ab5  Update (XilunWu, May 30, 2025)
43c91ff  Update on "[cp] dispatch flex_attention to CP impl in TorchDispatchMode" (XilunWu, May 30, 2025)
362b6e8  Update on "[cp] dispatch flex_attention to CP impl in TorchDispatchMode" (XilunWu, May 30, 2025)
Update on "[cp] dispatch flex_attention to CP impl in TorchDispatchMode"
## Test
`pytest test/distributed/tensor/test_attention.py -s -k test_ring_flex_attention`

cc H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k

[ghstack-poisoned]
XilunWu committed May 30, 2025
commit 362b6e88e289e270b55f512335a282f017b4bbb2
8 changes: 3 additions & 5 deletions torch/distributed/tensor/experimental/_attention.py
@@ -1768,13 +1768,11 @@ def cp_flex_attention_dispatch_mode(
     # NOTE: if we know that there will only be one block_mask in the model, we can
     # memorize this cp_block_mask in the context instead of hitting cache every time
     cp_block_mask = mode._sharder.get_cp_block_mask(mode._sharder._block_mask)
+    device_mesh = mode._sharder._mesh

     seq_dim = 2
-    sharding = Shard(seq_dim)
-    k_dist = DTensor.from_local(key, mode._sharder._mesh, [sharding])
-    v_dist = DTensor.from_local(value, mode._sharder._mesh, [sharding])
-    k_global = k_dist.full_tensor()
-    v_global = v_dist.full_tensor()
+    k_global = mode._sharder.unshard(key, device_mesh, seq_dim)
+    v_global = mode._sharder.unshard(value, device_mesh, seq_dim)

     # TODO: add kv reorder

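This commit replaces the inline `DTensor.from_local(...).full_tensor()` round trip for key/value with a call to the sharder's `unshard` helper. As a rough illustration of the equivalence, here is a minimal sketch of what such a helper could do; the name, signature, and internals are assumptions for illustration only, and the actual `mode._sharder.unshard` in this PR may be implemented differently.

```python
# Hedged sketch only: mirrors the removed DTensor round trip to show what an
# unshard(local, mesh, dim) helper could look like. Not the PR's real implementation.
import torch
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor import DTensor, Shard


def unshard(local_tensor: torch.Tensor, mesh: DeviceMesh, dim: int) -> torch.Tensor:
    # Wrap the per-rank shard as a DTensor sharded along `dim`, then
    # all-gather it back into the full (replicated) global tensor.
    dist_tensor = DTensor.from_local(local_tensor, mesh, [Shard(dim)])
    return dist_tensor.full_tensor()
```

With a helper of this shape, the call site only needs the device mesh and the sequence dimension, which is what the added `device_mesh = mode._sharder._mesh` line provides.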