
Commit a4a7716

wconstab authored and pytorchmergebot committed
[pt2d] Add reorder_comms_preserving_peak_memory pass (#146562)
This is a new pass to replace the pre-existing passes. It has the same basic goal, achieving communication overlap (latency hiding), but it also constrains the solution so that peak memory does not increase. The principles of operation are detailed in code comments; in summary:

- never reorder collectives relative to each other (TBD if we should relax this later)
- before performing reordering, push all comm and wait nodes as late as possible, respecting data dependencies
- estimate peak memory and current memory at each scheduler node
- move collective nodes forward one position at a time, if the move does not increase current memory beyond peak memory

The pass logs a summary table for each graph to TORCH_LOGS=overlap, e.g. (the exact format may have been tweaked, but this shows the idea):

```
[rank0]:[rank0]:I0210 17:24:28.494000 2711253 torch/_inductor/comms.py:195] [0/0] [__overlap] Collective node    initial exposed    final exposed    improvement    limiting factor    moves
[rank0]:[rank0]:I0210 17:24:28.494000 2711253 torch/_inductor/comms.py:195] [0/0] [__overlap] -----------------  -----------------  ---------------  -------------  -----------------  -------
[rank0]:[rank0]:I0210 17:24:28.494000 2711253 torch/_inductor/comms.py:195] [0/0] [__overlap] ExternKernelSchedulerNode(name='op2') (torch.ops._c10d_functional.all_gather_into_tensor.default) (size=[2256, 256], stride=[256, 1]) (buf2) (12142 ns)    12141.6    6514.53    5627.08    prefetch limit    75
[rank0]:[rank0]:I0210 17:24:28.494000 2711253 torch/_inductor/comms.py:195] [0/0] [__overlap] ExternKernelSchedulerNode(name='op6') (torch.ops._c10d_functional.reduce_scatter_tensor.default) (size=[282, 256], stride=[256, 1]) (buf7) (32266 ns)    32265.8    28429.2    3836.61    data dependency    78
[rank0]:[rank0]:I0210 17:24:28.494000 2711253 torch/_inductor/comms.py:195] [0/0] [__overlap] ExternKernelSchedulerNode(name='op9') (torch.ops._c10d_functional.all_gather_into_tensor.default) (size=[256], stride=[1]) (buf11) (10801 ns)    10800.6    10732.3    68.254    peak memory    1
[rank0]:[rank0]:I0210 17:24:28.494000 2711253 torch/_inductor/comms.py:195] [0/0] [__overlap] ExternKernelSchedulerNode(name='op14') (torch.ops._c10d_functional.reduce_scatter_tensor.default) (size=[32], stride=[1]) (buf17) (10810 ns)    10809.5    10809.5    0    data dependency    4
[rank
```

Pull Request resolved: #146562
Approved by: https://github.com/eellison
ghstack dependencies: #152060, #146561
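
As an editorial aid (not part of the original commit message), the sketch below illustrates the greedy reordering loop summarized above on a toy schedule. It is not the torch/_inductor/comms.py implementation: `Node`, `deps`, `mem_delta`, and `running_memory` are hypothetical stand-ins for Inductor's scheduler nodes and its runtime/memory estimators, and "forward" is interpreted as moving a collective earlier in the schedule so it can overlap with compute (matching the "prefetch limit" limiting factor in the table).

```python
# Toy sketch of the greedy pass described above (hypothetical names; not the
# real Inductor implementation).
from dataclasses import dataclass, field


@dataclass
class Node:
    name: str
    is_collective: bool = False
    deps: set[str] = field(default_factory=set)  # names of nodes this node reads
    mem_delta: int = 0  # bytes allocated (+) or freed (-) when this node runs


def running_memory(seq: list[Node]) -> list[int]:
    """Prefix-sum memory estimate after each node in the schedule."""
    total, out = 0, []
    for n in seq:
        total += n.mem_delta
        out.append(total)
    return out


def reorder_preserving_peak_memory(nodes: list[Node]) -> list[Node]:
    """Hop each collective one slot earlier at a time, stopping at another
    collective, a data dependency, or any move that would push the memory
    estimate above the original peak."""
    order = list(nodes)
    baseline_peak = max(running_memory(order), default=0)

    for coll in [n for n in order if n.is_collective]:
        i = order.index(coll)
        while i > 0:
            prev = order[i - 1]
            if prev.is_collective:
                break  # never reorder collectives relative to each other
            if prev.name in coll.deps:
                break  # limiting factor: data dependency
            candidate = order[: i - 1] + [coll, prev] + order[i + 1 :]
            if max(running_memory(candidate)) > baseline_peak:
                break  # limiting factor: peak memory
            order = candidate
            i -= 1
    return order


if __name__ == "__main__":
    schedule = [
        Node("matmul1", mem_delta=4096),
        Node("matmul2", deps={"matmul1"}, mem_delta=4096),
        Node("all_gather", is_collective=True, deps={"matmul1"}, mem_delta=8192),
        Node("wait", deps={"all_gather"}),
    ]
    # all_gather hops over matmul2 (no dependency, peak unchanged) and then
    # stops at matmul1, which produces its input.
    print([n.name for n in reorder_preserving_peak_memory(schedule)])
    # -> ['matmul1', 'all_gather', 'matmul2', 'wait']
```

The two inner `break` branches correspond to the "data dependency" and "peak memory" limiting factors reported in the log table; the real pass additionally tracks a prefetch limit and the exposed-communication-time improvement per collective, which this toy version omits.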
1 parent e35e316 commit a4a7716

File tree

3 files changed: +303 −14 lines


test/distributed/test_inductor_collectives.py

Lines changed: 86 additions & 0 deletions
@@ -3,6 +3,7 @@
 import functools
 import unittest
 from collections import defaultdict
+from typing import Optional
 from unittest.mock import patch
 
 import torch
@@ -15,7 +16,12 @@
 from torch._C import FileCheck
 from torch._dynamo.testing import CompileCounter
 from torch._dynamo.utils import same
+from torch._inductor.comms import (
+    _reorder_communication_preserving_peak_memory_internal,
+    ReorderInfo,
+)
 from torch._inductor.compile_fx import compile_fx as inductor_compile_fx
+from torch._inductor.scheduler import BaseSchedulerNode
 from torch._inductor.utils import run_and_get_triton_code
 from torch.distributed.distributed_c10d import GroupMember
 from torch.fx.experimental.proxy_tensor import make_fx
@@ -1400,6 +1406,86 @@ def func(inp, *, tag, ranks, group_size):
         correct = func(inputs, **self.get_world_trs())
         assert same(out, correct), f"{out} va {correct}"
 
+    @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
+    def test_reorder_peak_memory(self):
+        """
+        TODO(whc)
+        - check each of the `limiting_factor` cases
+        - confirm peak memory is respected in some adversarial case
+        - check whether it is expected / correct that the "buf7 = buf0; del buf0 # reuse" statement materially changes
+        """
+
+        def func(inp, *, tag, ranks, group_size):
+            x = inp + 1
+            tensor_list = torch.ops.c10d_functional.reduce_scatter_tensor_coalesced(
+                [x, inp], "sum", tag, ranks, group_size
+            )
+            y = x + 2
+            ar0 = torch.ops.c10d_functional.wait_tensor(tensor_list[0])
+            ar1 = torch.ops.c10d_functional.wait_tensor(tensor_list[1])
+            # ensure other is not incorrectly aliasing ar's buffer
+            other = torch.ones_like(inp) + 22
+            return ar0, y, other, ar1
+
+        inputs = torch.ones(4, 4, device="cuda")
+
+        # get stats directly from the internal helper without affecting the real pass's signature
+        node_stats: Optional[dict[BaseSchedulerNode, ReorderInfo]] = None
+
+        def _reorder_communication_preserving_peak_memory(
+            snodes: list[BaseSchedulerNode],
+        ) -> list[BaseSchedulerNode]:
+            nonlocal node_stats
+            (
+                reordered_snodes,
+                node_stats,
+            ) = _reorder_communication_preserving_peak_memory_internal(snodes)
+            return reordered_snodes
+
+        with torch._inductor.config.patch(
+            {
+                "reorder_for_compute_comm_overlap": True,
+                "reorder_for_compute_comm_overlap_passes": [
+                    "sink_waits",
+                    # same as reorder_communication_preserving_peak_memory but returns debug info structures directly
+                    _reorder_communication_preserving_peak_memory,
+                ],
+            }
+        ):
+            compiled = torch.compile(func)
+            code = run_and_get_triton_code(compiled, inputs, **self.get_world_trs())
+            # NOTE: The first return value should be the output of the first wait_tensor.
+            # We want to make sure no unnecessary copy is made.
+            (
+                FileCheck()
+                .check("buf0 = empty_strided")
+                .check("buf6 = empty_strided")
+                .check(".run(arg0_1, buf0, buf6, 16")
+                .check(
+                    "buf1 = torch.ops._c10d_functional.reduce_scatter_tensor_coalesced.default([buf0, arg0_1]"
+                )
+                # .check("buf2 = buf1[0]")
+                # .check("buf3 = buf1[1]")
+                .check("torch.ops._c10d_functional.wait_tensor.default(buf2")
+                # .check("buf7 = buf0; del buf0 # reuse")
+                # .check(".run(buf7, 16")
+                .check("torch.ops._c10d_functional.wait_tensor.default(buf3")
+                .check("return (buf2, buf6, buf7, buf3")
+                .run(code)
+            )
+            out = compiled(inputs, **self.get_world_trs())
+            correct = func(inputs, **self.get_world_trs())
+            assert same(out, correct), f"{out} vs {correct}"
+
+            # TODO make the test case more interesting and validate the actual desired behavior
+            assert node_stats is not None
+            self.assertTrue(isinstance(node_stats, dict))
+            self.assertEqual(len(node_stats), 1)
+            for stats in node_stats.values():
+                self.assertEqual(stats.initial_exposed, 0)
+                self.assertEqual(stats.limiting_factor, "data dependency")
+                self.assertEqual(stats.moves, 0)
+
 
 if __name__ == "__main__":
     from torch._dynamo.test_case import run_tests
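
For context on how the pass would be exercised outside this test harness, here is a hedged usage sketch. It assumes the string pass name `reorder_communication_preserving_peak_memory` referenced in the test's comment and the `TORCH_LOGS=overlap` channel mentioned in the commit message; config keys, defaults, and pass names may differ across PyTorch versions, and `my_distributed_fn` is a placeholder.

```python
# Hedged usage sketch, assuming the config keys patched in the test above and
# the string pass name mentioned in its comment.
import torch


def my_distributed_fn(x: torch.Tensor) -> torch.Tensor:
    # placeholder: in practice this would issue c10d_functional collectives
    return x + 1


with torch._inductor.config.patch(
    {
        "reorder_for_compute_comm_overlap": True,
        "reorder_for_compute_comm_overlap_passes": [
            "sink_waits",
            "reorder_communication_preserving_peak_memory",
        ],
    }
):
    compiled_fn = torch.compile(my_distributed_fn)
    out = compiled_fn(torch.randn(8, 8))  # compilation (and thus the pass) runs at first call

# Run under e.g. `TORCH_LOGS=overlap torchrun --nproc-per-node=8 train.py`
# to see the per-collective summary table that the pass logs for each graph.
```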

0 commit comments
