From cef15378e19542adc209f32f5571d74929ffc487 Mon Sep 17 00:00:00 2001
From: Brian Hirsh
Date: Fri, 25 Apr 2025 09:07:07 -0700
Subject: [PATCH 1/2] SAC: fix recompute tag propagation for ops with
 list[tensor] inputs

[ghstack-poisoned]
---
 test/dynamo/test_activation_checkpointing.py | 42 ++++++++++++++++++++
 torch/utils/checkpoint.py                    |  8 +---
 2 files changed, 44 insertions(+), 6 deletions(-)

diff --git a/test/dynamo/test_activation_checkpointing.py b/test/dynamo/test_activation_checkpointing.py
index 17021eb46565..8c456f7fb59b 100644
--- a/test/dynamo/test_activation_checkpointing.py
+++ b/test/dynamo/test_activation_checkpointing.py
@@ -967,6 +967,48 @@ def fn(x, y):
         self._validate(fn, backend, x, y)
         self._compare_orig_and_checkpointed_fns(gn, fn, x, y)
 
+    @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
+    def test_compile_selective_checkpoint_list_ops(self, device):
+        def selective_checkpointing_context_fn():
+            # recompute everything
+            no_recompute_list = []
+            return create_selective_checkpoint_contexts(
+                _get_custom_policy(no_recompute_list=no_recompute_list)
+            )
+
+        def gn(x, y):
+            return torch.cat([x, y]).sin()
+
+        def fn(x, y):
+            return torch.utils.checkpoint.checkpoint(
+                gn,
+                x,
+                y,
+                use_reentrant=False,
+                context_fn=selective_checkpointing_context_fn,
+            )
+
+        x = torch.randn(4, 4, requires_grad=True, device=device)
+        y = torch.randn(4, 4, requires_grad=True, device=device)
+
+        fw_compiler = functools.partial(
+            count_ops,
+            freqs=[1],
+            ops=[torch.ops.aten.cat.default],
+        )
+        bw_compiler = functools.partial(
+            count_ops,
+            freqs=[1],
+            ops=[torch.ops.aten.cat.default],
+        )
+        backend = aot_autograd(
+            fw_compiler=fw_compiler,
+            bw_compiler=bw_compiler,
+            partition_fn=min_cut_rematerialization_partition,
+        )
+        self._validate(fn, backend, x, y)
+        self._compare_orig_and_checkpointed_fns(gn, fn, x, y)
+
     @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
     @unittest.skip(
         "In-place op support in selective checkpointing + torch.compile "
diff --git a/torch/utils/checkpoint.py b/torch/utils/checkpoint.py
index 8ca576818da8..05b49288e6bc 100644
--- a/torch/utils/checkpoint.py
+++ b/torch/utils/checkpoint.py
@@ -1153,12 +1153,8 @@ def unpack_hook_with_error_cb(holder):
 
 def _is_compiling(func, args, kwargs):
     # Check if we are under AOTAutograd tracing
-    # There should probably be a better way to do this...
-    # TODO: unify _is_compiling across all compile stacks
-    for arg in args:
-        if isinstance(arg, torch.Tensor) and is_fun(arg):
-            return True
-    return False
+    # Checking that a functional mode is active should always do what we want
+    return torch._C._get_dispatch_mode(torch._C._TorchDispatchModeKey.FUNCTIONAL) is not None
 
 
 class _VersionWrapper:

From 5b752e34cf790864662ecafc85b94f87ea52014c Mon Sep 17 00:00:00 2001
From: Brian Hirsh
Date: Fri, 25 Apr 2025 10:44:20 -0700
Subject: [PATCH 2/2] Update on "SAC: fix recompute tag propagation for ops
 with list[tensor] inputs"

There's an "are we compiling" check in SAC, which we rely on to know when to
propagate recompute tags during tracing. This check was a bit brittle and
missed ops whose inputs are lists of tensors - I updated it to check whether a
`FunctionalTensorMode` is active, which should be a 100% reliable way to know
if AOTDispatcher is in the middle of running.

There is a long-standing follow-up here around unifying
`torch.compiler.is_compiling()` to work in all cases. We should probably just
update it to always check whether FakeMode/FunctionalMode are active and use
it there. That has a bit of BC risk though, so I opted for the more local fix
to SAC.
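For reference, a minimal sketch (not part of this patch) of why the old
per-argument check misses ops that take a list of tensors; `_sees_a_tensor` is
a hypothetical stand-in for the old loop in `_is_compiling`:

```python
import torch

# Hypothetical stand-in for the old per-argument check: it only inspects
# top-level positional args, so tensors nested inside a list are never seen.
def _sees_a_tensor(args):
    return any(isinstance(arg, torch.Tensor) for arg in args)

x, y = torch.randn(2), torch.randn(2)

# sin-style op: the tensor is a direct argument, so it is detected.
print(_sees_a_tensor((x,)))       # True

# cat-style op: the tensors arrive wrapped in a list, so the check returns
# False even while AOTAutograd is tracing, and recompute tags are not
# propagated for these ops.
print(_sees_a_tensor(([x, y],)))  # False
```

The new check instead asks the dispatcher whether a FUNCTIONAL dispatch mode
is active, so it does not depend on argument structure at all.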
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames

[ghstack-poisoned]
---
 torch/utils/checkpoint.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/torch/utils/checkpoint.py b/torch/utils/checkpoint.py
index 05b49288e6bc..ccde9707a7eb 100644
--- a/torch/utils/checkpoint.py
+++ b/torch/utils/checkpoint.py
@@ -11,7 +11,6 @@
 
 import torch
 import torch.fx.traceback as fx_traceback
-from torch._functorch._aot_autograd.functional_utils import is_fun
 from torch.utils._pytree import tree_map
 from torch.testing._internal.logging_tensor import capture_logs, LoggingTensorMode
 from torch.utils._python_dispatch import TorchDispatchMode
