diff --git a/test/distributed/tensor/test_attention.py b/test/distributed/tensor/test_attention.py
index 52e222a64764..2bcbb094d0f4 100644
--- a/test/distributed/tensor/test_attention.py
+++ b/test/distributed/tensor/test_attention.py
@@ -460,7 +460,7 @@ def causal_mask(b, h, q_idx, kv_idx):
         from torch.nn.attention.flex_attention import create_block_mask, flex_attention
 
         # Compile the flex_attention function
-        flex_attention = torch.compile(flex_attention, dynamic=False)
+        flex_attention = torch.compile(flex_attention, dynamic=False, fullgraph=True)
 
         Q_BLOCK_SIZE_DEFAULT = 128
         KV_BLOCK_SIZE_DEFAULT = Q_BLOCK_SIZE_DEFAULT
@@ -500,10 +500,12 @@ def causal_mask(b, h, q_idx, kv_idx):
             device=self.device_type,
         )
 
+        """
         expect_out, expect_lse = flex_attention(
             q, k, v, block_mask=block_mask, return_lse=True
         )
         expect_out.sum().backward()
+        """
 
         # test flex attention on DTensor
         device_mesh = init_device_mesh(
@@ -544,41 +546,53 @@ def causal_mask(b, h, q_idx, kv_idx):
         # create sharder
         sharder = _FlexAttentionSequentialSharder(device_mesh, block_mask=block_mask)
 
-        with CommDebugMode() as comm_mode:
-            with context_parallel(
-                device_mesh,
-                buffers=[cp_q, cp_k, cp_v],
-                buffer_seq_dims=[2, 2, 2],
-                sharder=sharder,
-            ):
-                cp_q.requires_grad = True
-                cp_k.requires_grad = True
-                cp_v.requires_grad = True
-                cp_out, cp_lse = flex_attention(
-                    cp_q,
-                    cp_k,
-                    cp_v,
-                    block_mask=block_mask_post_sharding,
-                    return_lse=True,
-                )
-                cp_out.sum().backward()
+        # dump profiling snapshot
+        import os
 
-                cp_q.requires_grad = False
-                cp_k.requires_grad = False
-                cp_v.requires_grad = False
+        snapshot_dir = os.path.join("./", "memory_snapshot/")
+        if not os.path.exists(snapshot_dir):
+            os.makedirs(snapshot_dir, exist_ok=True)
 
-        self.assertDictEqual(
-            comm_mode.get_comm_counts(),
-            {
-                c10d_functional.all_gather_into_tensor: 2 * 2,  # forward + backward
-                c10d_functional.reduce_scatter_tensor: 2,  # backward
-            },  # currently we have k and v all-gather separate
-        )
+        torch.cuda.memory._record_memory_history(max_entries=100000)
+
+        with context_parallel(
+            device_mesh,
+            buffers=[cp_q, cp_k, cp_v],
+            buffer_seq_dims=[2, 2, 2],
+            sharder=sharder,
+        ):
+            cp_q.requires_grad = True
+            cp_k.requires_grad = True
+            cp_v.requires_grad = True
+            cp_out, cp_lse = flex_attention(
+                cp_q,
+                cp_k,
+                cp_v,
+                block_mask=block_mask_post_sharding,
+                return_lse=True,
+            )
+            cp_out.sum().backward()
+
+            cp_q.requires_grad = False
+            cp_k.requires_grad = False
+            cp_v.requires_grad = False
+
+        try:
+            if self.rank == 0:
+                import pickle
+
+                with open(
+                    f"./memory_snapshot/rank{self.rank}_memory_snapshot.pickle", "wb"
+                ) as output:
+                    pickle.dump(torch.cuda.memory._snapshot(), output)
+        finally:
+            torch.cuda.memory._record_memory_history(enabled=None)
 
         # unshard the output
         cp_out, cp_lse = context_parallel_unshard(
             device_mesh, [cp_out, cp_lse], [2, 2], sharder
         )
+        """
         torch.testing.assert_close(cp_out, expect_out, atol=1e-6, rtol=1e-2)
         torch.testing.assert_close(cp_lse, expect_lse, atol=1e-6, rtol=1e-2)
 
@@ -592,7 +606,7 @@ def causal_mask(b, h, q_idx, kv_idx):
         torch.testing.assert_close(cp_q_grad, q.grad, atol=1e-6, rtol=1e-2)
         torch.testing.assert_close(cp_k_grad, k.grad, atol=1e-6, rtol=1e-2)
         torch.testing.assert_close(cp_v_grad, v.grad, atol=1e-6, rtol=1e-2)
-
+        """
         # reset to the default mode
         _set_dispatch_mode("monkey_patch")
 
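For reference, the snapshot file this patched test dumps can be inspected offline. The sketch below is not part of the patch: it assumes the hard-coded ./memory_snapshot/rank0_memory_snapshot.pickle path from the diff and uses torch.cuda._memory_viz.trace_plot, a private helper whose signature may change between PyTorch releases. The same pickle can also be dropped onto https://pytorch.org/memory_viz for interactive viewing.

import pickle

# Private visualization helper shipped with PyTorch; treat as version-dependent.
from torch.cuda._memory_viz import trace_plot

# Path hard-coded in the patched test (only rank 0 writes a snapshot).
with open("./memory_snapshot/rank0_memory_snapshot.pickle", "rb") as f:
    snapshot = pickle.load(f)

# Render the allocation trace into a standalone HTML page for browsing.
with open("rank0_memory_trace.html", "w") as f:
    f.write(trace_plot(snapshot))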