diff --git a/test/distributed/tensor/test_attention.py b/test/distributed/tensor/test_attention.py
index fd63f88a0d2d..52e222a64764 100644
--- a/test/distributed/tensor/test_attention.py
+++ b/test/distributed/tensor/test_attention.py
@@ -503,6 +503,7 @@ def causal_mask(b, h, q_idx, kv_idx):
         expect_out, expect_lse = flex_attention(
             q, k, v, block_mask=block_mask, return_lse=True
         )
+        expect_out.sum().backward()
 
         # test flex attention on DTensor
         device_mesh = init_device_mesh(
@@ -550,6 +551,9 @@ def causal_mask(b, h, q_idx, kv_idx):
             buffer_seq_dims=[2, 2, 2],
             sharder=sharder,
         ):
+            cp_q.requires_grad = True
+            cp_k.requires_grad = True
+            cp_v.requires_grad = True
             cp_out, cp_lse = flex_attention(
                 cp_q,
                 cp_k,
@@ -557,11 +561,17 @@ def causal_mask(b, h, q_idx, kv_idx):
                 block_mask=block_mask_post_sharding,
                 return_lse=True,
             )
+            cp_out.sum().backward()
+
+            cp_q.requires_grad = False
+            cp_k.requires_grad = False
+            cp_v.requires_grad = False
 
         self.assertDictEqual(
             comm_mode.get_comm_counts(),
             {
-                c10d_functional.all_gather_into_tensor: 2
+                c10d_functional.all_gather_into_tensor: 2 * 2,  # forward + backward
+                c10d_functional.reduce_scatter_tensor: 2,  # backward
             },  # currently we have k and v all-gather separate
         )
 
@@ -572,6 +582,17 @@ def causal_mask(b, h, q_idx, kv_idx):
         torch.testing.assert_close(cp_out, expect_out, atol=1e-6, rtol=1e-2)
         torch.testing.assert_close(cp_lse, expect_lse, atol=1e-6, rtol=1e-2)
 
+        # unshard the gradient
+        cp_q_grad, cp_k_grad, cp_v_grad = context_parallel_unshard(
+            device_mesh,
+            [cp_q.grad, cp_k.grad, cp_v.grad],
+            [2, 2, 2],
+            sharder,
+        )
+        torch.testing.assert_close(cp_q_grad, q.grad, atol=1e-6, rtol=1e-2)
+        torch.testing.assert_close(cp_k_grad, k.grad, atol=1e-6, rtol=1e-2)
+        torch.testing.assert_close(cp_v_grad, v.grad, atol=1e-6, rtol=1e-2)
+
         # reset to the default mode
         _set_dispatch_mode("monkey_patch")
 
diff --git a/torch/distributed/tensor/debug/_comm_mode.py b/torch/distributed/tensor/debug/_comm_mode.py
index 1221ecbb0425..1a5b078fa1c4 100644
--- a/torch/distributed/tensor/debug/_comm_mode.py
+++ b/torch/distributed/tensor/debug/_comm_mode.py
@@ -4,15 +4,19 @@
 import re
 import weakref
 from collections import defaultdict
-from typing import Any, Callable
+from typing import Any, Callable, Optional, Union
 
 import torch
 import torch.nn
 from torch._guards import detect_fake_mode
-from torch._higher_order_ops import flex_attention as flex_attention_hop
+from torch._higher_order_ops.flex_attention import (
+    flex_attention as flex_attention_hop,
+    flex_attention_backward as flex_attention_backward_hop,
+)
 from torch.autograd.graph import register_multi_grad_hook
 from torch.distributed._tools.mod_tracker import ModTracker
 from torch.distributed.tensor._api import DTensor
+from torch.fx.graph_module import GraphModule
 from torch.nn.modules.module import (
     register_module_forward_hook,
     register_module_forward_pre_hook,
@@ -713,16 +717,23 @@ def __torch_dispatch__(self, func, types, args=(), kwargs=None):
                 ] += 1
 
             # adds collective count to parent modules
-            for par in self.advanced_module_tracker.module_parents_dict[
-                self.advanced_module_tracker.name
-            ]:
-                # makes sure we aren't double counting when current sub-module hasn't been removed from parents
-                if par != self.advanced_module_tracker.name:
-                    if par not in self.comm_module_counts:
-                        self.comm_module_counts[par] = {}
-                        self.comm_module_counts[par]["forward"] = defaultdict(int)
-                        self.comm_module_counts[par]["backward"] = defaultdict(int)
-                    self.comm_module_counts[par][key][func_packet] += 1
+            # TODO (xilunwu): this is a temporary hack to unblock the issue
+            # in tracking flex_attention_backward. Need to fix it later on.
+            # The issue happens when we call flex_attention_backward which
+            # sets ``self.advanced_module_tracker.name`` to "" and
+            # ``self.advanced_module_tracker.module_parents_dict[""]``
+            # results in KeyError.
+            if self.advanced_module_tracker.name != "":
+                for par in self.advanced_module_tracker.module_parents_dict[
+                    self.advanced_module_tracker.name
+                ]:
+                    # makes sure we aren't double counting when current sub-module hasn't been removed from parents
+                    if par != self.advanced_module_tracker.name:
+                        if par not in self.comm_module_counts:
+                            self.comm_module_counts[par] = {}
+                            self.comm_module_counts[par]["forward"] = defaultdict(int)
+                            self.comm_module_counts[par]["backward"] = defaultdict(int)
+                        self.comm_module_counts[par][key][func_packet] += 1
 
         # if tensor op uses fake tensors, return
         if detect_fake_mode(args):
@@ -761,3 +772,42 @@ def flex_attention_comm_debug_mode(
         score_mod_other_buffers,
         mask_mod_other_buffers,
     )
+
+
+# register flex_attention_backward HOP to CommDebugMode
+@flex_attention_backward_hop.py_impl(CommDebugMode)
+def flex_attention_backward_comm_debug_mode(
+    mode: CommDebugMode,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    out: torch.Tensor,
+    logsumexp: torch.Tensor,
+    grad_out: torch.Tensor,
+    grad_logsumexp: torch.Tensor,
+    fw_graph: Union[Callable, GraphModule],
+    joint_graph: GraphModule,
+    block_mask: tuple,
+    scale: float,
+    kernel_options: dict[str, Any],
+    score_mod_other_buffers: tuple = (),
+    mask_mod_other_buffers: tuple = (),
+) -> tuple[
+    torch.Tensor, torch.Tensor, torch.Tensor, tuple[Optional[torch.Tensor], ...]
+]:
+    return flex_attention_backward_hop(
+        query,
+        key,
+        value,
+        out,
+        logsumexp,
+        grad_out,
+        grad_logsumexp,
+        fw_graph,
+        joint_graph,
+        block_mask,
+        scale,
+        kernel_options,
+        score_mod_other_buffers,
+        mask_mod_other_buffers,
+    )
diff --git a/torch/distributed/tensor/experimental/_attention.py b/torch/distributed/tensor/experimental/_attention.py
index be452d23e624..c6ba502bb3c2 100644
--- a/torch/distributed/tensor/experimental/_attention.py
+++ b/torch/distributed/tensor/experimental/_attention.py
@@ -17,12 +17,16 @@
 import torch.distributed._functional_collectives as ft_c
 import torch.nn.functional as F
 from torch import nn
-from torch._higher_order_ops.flex_attention import flex_attention as flex_attention_hop
+from torch._higher_order_ops.flex_attention import (
+    flex_attention as flex_attention_hop,
+    flex_attention_backward as flex_attention_backward_hop,
+)
 from torch._ops import TorchDispatchMode
 from torch._prims_common import DeviceLikeType
 from torch.distributed.device_mesh import DeviceMesh
 from torch.distributed.tensor import distribute_module, DTensor, Replicate, Shard
 from torch.distributed.tensor.parallel.style import ParallelStyle
+from torch.fx.graph_module import GraphModule
 from torch.nn.attention.flex_attention import (
     _identity,
     _mask_mod_signature,
@@ -1789,3 +1793,72 @@ def cp_flex_attention_dispatch_mode(
     )
 
     return out, lse
+
+
+@flex_attention_backward_hop.py_impl(ContextParallelMode)
+def cp_flex_attention_backward_dispatch_mode(
+    mode: ContextParallelMode,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    out: torch.Tensor,
+    logsumexp: torch.Tensor,
+    grad_out: torch.Tensor,
+    grad_logsumexp: torch.Tensor,
+    fw_graph: Union[Callable, GraphModule],
+    joint_graph: GraphModule,
+    block_mask: tuple,
+    scale: float,
+    kernel_options: dict[str, Any],
+    score_mod_other_buffers: tuple = (),
+    mask_mod_other_buffers: tuple = (),
+) -> tuple[
+    torch.Tensor, torch.Tensor, torch.Tensor, tuple[Optional[torch.Tensor], ...]
+]:
+    assert mode._sharder is not None, (
+        "flex_attention is called but ContextParallelMode._sharder is not initialized. "
+        "Please pass the `sharder` argument to `context_parallel`."
+    )
+
+    assert isinstance(mode._sharder, _FlexAttentionSharder)
+    device_mesh = mode._sharder._mesh
+    cp_block_mask = mode._sharder.get_cp_block_mask(mode._sharder._block_mask)
+
+    # all-gather KV
+    seq_dim = 2
+    k_global = mode._sharder.unshard(key, device_mesh, seq_dim)
+    v_global = mode._sharder.unshard(value, device_mesh, seq_dim)
+
+    # TODO: add kv reorder
+
+    (
+        grad_query,
+        grad_key,
+        grad_value,
+        grad_score_mod_captured,
+    ) = flex_attention_backward_hop(
+        query,
+        k_global,  # key
+        v_global,  # value
+        out,
+        logsumexp,
+        grad_out,
+        grad_logsumexp,
+        fw_graph,
+        joint_graph,
+        cp_block_mask.as_tuple(),  # block_mask
+        scale,
+        kernel_options,
+        score_mod_other_buffers,
+        mask_mod_other_buffers,
+    )
+
+    # reduce-scatter KV grads
+    grad_key = ft_c.reduce_scatter_tensor(
+        grad_key, reduceOp="sum", scatter_dim=seq_dim, group=device_mesh
+    )
+    grad_value = ft_c.reduce_scatter_tensor(
+        grad_value, reduceOp="sum", scatter_dim=seq_dim, group=device_mesh
+    )
+
+    return grad_query, grad_key, grad_value, grad_score_mod_captured
