SAC: fix recompute tag propagation for ops with list[tensor] inputs #152195
Closed · wants to merge 5 commits
Changes from 1 commit
SAC: fix recompute tag propagation for ops with list[tensor] inputs
[ghstack-poisoned]
bdhirsh committed Apr 25, 2025
commit cef15378e19542adc209f32f5571d74929ffc487
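
(Summary inferred from the diff below.) torch.utils.checkpoint._is_compiling used to detect AOTAutograd tracing by scanning args for functional tensors. Ops whose tensor inputs arrive inside a list, such as aten.cat, expose no bare Tensor arguments, so the scan returned False and selective activation checkpointing (SAC) recompute tags were not propagated under torch.compile. The fix keys off the active FUNCTIONAL dispatch mode instead, and a new test exercises torch.cat under selective checkpointing.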
42 changes: 42 additions & 0 deletions test/dynamo/test_activation_checkpointing.py
@@ -967,6 +967,48 @@ def fn(x, y):
         self._validate(fn, backend, x, y)
         self._compare_orig_and_checkpointed_fns(gn, fn, x, y)
 
+    @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
+    def test_compile_selective_checkpoint_list_ops(self, device):
+        def selective_checkpointing_context_fn():
+            # recompute everything
+            no_recompute_list = []
+            return create_selective_checkpoint_contexts(
+                _get_custom_policy(no_recompute_list=no_recompute_list)
+            )
+
+        def gn(x, y):
+            return torch.cat([x, y]).sin()
+
+        def fn(x, y):
+            return torch.utils.checkpoint.checkpoint(
+                gn,
+                x,
+                y,
+                use_reentrant=False,
+                context_fn=selective_checkpointing_context_fn,
+            )
+
+        x = torch.randn(4, 4, requires_grad=True, device=device)
+        y = torch.randn(4, 4, requires_grad=True, device=device)
+
+        fw_compiler = functools.partial(
+            count_ops,
+            freqs=[1],
+            ops=[torch.ops.aten.cat.default],
+        )
+        bw_compiler = functools.partial(
+            count_ops,
+            freqs=[1],
+            ops=[torch.ops.aten.cat.default],
+        )
+        backend = aot_autograd(
+            fw_compiler=fw_compiler,
+            bw_compiler=bw_compiler,
+            partition_fn=min_cut_rematerialization_partition,
+        )
+        self._validate(fn, backend, x, y)
+        self._compare_orig_and_checkpointed_fns(gn, fn, x, y)
+
     @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
     @unittest.skip(
         "In-place op support in selective checkpointing + torch.compile "
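
A note on what the test asserts (my reading, not text from the PR): with an empty no_recompute_list the policy recomputes every op, and torch.cat takes a list[tensor] input, the exact case the fix targets. The count_ops compilers then require aten.cat.default to appear once in the forward graph and once in the backward graph, i.e. the cat is recomputed during backward rather than having its output saved; before the fix the recompute tag was presumably lost for such ops, which these counts would catch.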
8 changes: 2 additions & 6 deletions torch/utils/checkpoint.py
@@ -1153,12 +1153,8 @@ def unpack_hook_with_error_cb(holder):
 
 def _is_compiling(func, args, kwargs):
     # Check if we are under AOTAutograd tracing
-    # There should probably be a better way to do this...
-    # TODO: unify _is_compiling across all compile stacks
-    for arg in args:
-        if isinstance(arg, torch.Tensor) and is_fun(arg):
-            return True
-    return False
+    # Checking that a functional mode is active should always do what we want
+    return torch._C._get_dispatch_mode(torch._C._TorchDispatchModeKey.FUNCTIONAL) is not None
 
 
 class _VersionWrapper:
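
Why the old scan failed for list[tensor] ops: a minimal sketch (the `args` tuple below is illustrative of what _is_compiling received for torch.cat, not captured from a real trace). An op like aten.cat gets its tensors inside a single Python list, so no element of `args` is itself a Tensor, the isinstance check fails before is_fun is ever consulted, and the old loop returned False even under AOTAutograd:

import torch

# Argument layout for an op with a list[tensor] input, e.g. torch.cat([x, y]):
# a single list argument, no bare Tensor positional args.
args = ([torch.randn(2), torch.randn(2)],)

# Old behavior (is_fun elided, since isinstance already fails for the list):
old_is_compiling = any(isinstance(arg, torch.Tensor) for arg in args)
print(old_is_compiling)  # False -> recompute tags silently not propagated

# New behavior: ask the dispatcher whether a functionalization mode is active,
# which is independent of how tensors are nested in the op's arguments.
new_is_compiling = (
    torch._C._get_dispatch_mode(torch._C._TorchDispatchModeKey.FUNCTIONAL)
    is not None
)
print(new_is_compiling)  # False in plain eager; True during AOTAutograd tracing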