
Commit 38d5e2d

Update on "[cp] dispatch flex_attention_backward to CP impl in TorchDispatchMode"
cc H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k [ghstack-poisoned]
2 parents: 49a3caf + 4c5af9a


2 files changed: +6, -2 lines


test/distributed/tensor/test_attention.py

Lines changed: 0 additions & 1 deletion
@@ -449,7 +449,6 @@ def world_size(self) -> int:
     def test_ring_flex_attention(self) -> None:
         def causal_mask(b, h, q_idx, kv_idx):
             return q_idx >= kv_idx
-            # return q_idx >= 0

         from torch.nn.attention.flex_attention import create_block_mask, flex_attention
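
For context, here is a minimal sketch (not part of this commit) of how a mask_mod like the causal_mask above is consumed: create_block_mask compiles the boolean predicate into a block-sparse mask that flex_attention then applies. The tensor shapes and device handling are illustrative assumptions, not taken from the test.

```python
import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention


def causal_mask(b, h, q_idx, kv_idx):
    # A query position may only attend to key positions at or before it.
    return q_idx >= kv_idx


# Illustrative sizes: batch, heads, sequence length, head dim.
device = "cuda" if torch.cuda.is_available() else "cpu"
B, H, S, D = 2, 4, 128, 64
q, k, v = (torch.randn(B, H, S, D, device=device) for _ in range(3))

# Compile the predicate into a block-sparse mask, then run attention with it.
block_mask = create_block_mask(causal_mask, B=B, H=H, Q_LEN=S, KV_LEN=S, device=device)
out = flex_attention(q, k, v, block_mask=block_mask)
```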

torch/distributed/tensor/experimental/_attention.py

Lines changed: 6 additions & 1 deletion
@@ -1509,9 +1509,14 @@ def context_parallel(
     ):
         yield

-    return
     for buffer, original_buffer in zip(buffers, original_buffers):
         if original_buffer is not None:
+            # tensor cannot resize if requires_grad is True
+            # key and value's requires_grad has been set to False in manual comm calls
+            # unless via DTensor.
+            if buffer.requires_grad:
+                buffer.requires_grad = False
+
             buffer.resize_(original_buffer.shape)
             buffer.copy_(original_buffer)
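
As a standalone illustration of why the added guard is needed (a sketch, not taken from the commit): calling resize_() on a leaf tensor with requires_grad=True raises a RuntimeError, so the restore path clears requires_grad before resizing the buffer back to its original shape. Shapes below are illustrative stand-ins for a CP shard buffer and its original full buffer.

```python
import torch

# Stand-in tensors; shapes are illustrative assumptions.
buffer = torch.randn(4, 8, requires_grad=True)
original_buffer = torch.randn(2, 8)

try:
    buffer.resize_(original_buffer.shape)
except RuntimeError as err:
    # PyTorch refuses to resize a tensor that requires grad.
    print(f"resize_ failed: {err}")

# The guard added in this diff: drop requires_grad first, then restore the buffer.
if buffer.requires_grad:
    buffer.requires_grad = False
buffer.resize_(original_buffer.shape)
buffer.copy_(original_buffer)
```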

0 commit comments