3 | 3 | import functools
4 | 4 | import unittest
5 | 5 | from collections import defaultdict
| 6 | +from typing import Optional |
6 | 7 | from unittest.mock import patch
7 | 8 |
8 | 9 | import torch

15 | 16 | from torch._C import FileCheck
16 | 17 | from torch._dynamo.testing import CompileCounter
17 | 18 | from torch._dynamo.utils import same
| 19 | +from torch._inductor.comms import ( |
| 20 | + _reorder_communication_preserving_peak_memory_internal, |
| 21 | + ReorderInfo, |
| 22 | +) |
18 | 23 | from torch._inductor.compile_fx import compile_fx as inductor_compile_fx
| 24 | +from torch._inductor.scheduler import BaseSchedulerNode |
19 | 25 | from torch._inductor.utils import run_and_get_triton_code
20 | 26 | from torch.distributed.distributed_c10d import GroupMember
21 | 27 | from torch.fx.experimental.proxy_tensor import make_fx
@@ -1400,6 +1406,86 @@ def func(inp, *, tag, ranks, group_size):
1400 | 1406 | correct = func(inputs, **self.get_world_trs())
1401 | 1407 | assert same(out, correct), f"{out} vs {correct}"
1402 | 1408 |
| 1409 | + @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") |
| 1410 | + def test_reorder_peak_memory(self): |
| 1411 | + """ |
| 1412 | + TODO(whc) |
| 1413 | + - check each of the `limiting_factor` cases |
| 1414 | + - confirm peak memory is respected in some adversarial case |
| 1415 | + - check whether it is expected / correct that the "buf7 = buf0; del buf0 # reuse" statement materially changed (the corresponding FileCheck assertions below are commented out) |
| 1416 | + """ |
| 1417 | + |
| 1418 | + def func(inp, *, tag, ranks, group_size): |
| 1419 | + x = inp + 1 |
| 1420 | + tensor_list = torch.ops.c10d_functional.reduce_scatter_tensor_coalesced( |
| 1421 | + [x, inp], "sum", tag, ranks, group_size |
| 1422 | + ) |
| 1423 | + y = x + 2 |
| 1424 | + ar0 = torch.ops.c10d_functional.wait_tensor(tensor_list[0]) |
| 1425 | + ar1 = torch.ops.c10d_functional.wait_tensor(tensor_list[1]) |
| 1426 | + # ensure other is not incorrectly aliasing ar's buffer |
| 1427 | + other = torch.ones_like(inp) + 22 |
| 1428 | + return ar0, y, other, ar1 |
| 1429 | + |
| 1430 | + inputs = torch.ones(4, 4, device="cuda") |
| 1431 | + |
| 1432 | + # get stats directly from the internal helper without affecting the real pass's signature |
| 1433 | + node_stats: Optional[dict[BaseSchedulerNode, ReorderInfo]] = None |
| 1434 | + |
| 1435 | + def _reorder_communication_preserving_peak_memory( |
| 1436 | + snodes: list[BaseSchedulerNode], |
| 1437 | + ) -> list[BaseSchedulerNode]: |
| 1438 | + nonlocal node_stats |
| 1439 | + ( |
| 1440 | + reordered_snodes, |
| 1441 | + node_stats, |
| 1442 | + ) = _reorder_communication_preserving_peak_memory_internal(snodes) |
| 1443 | + return reordered_snodes |
| 1444 | + |
| 1445 | + with torch._inductor.config.patch( |
| 1446 | + { |
| 1447 | + "reorder_for_compute_comm_overlap": True, |
| 1448 | + "reorder_for_compute_comm_overlap_passes": [ |
| 1449 | + "sink_waits", |
| 1450 | + # same as reorder_communication_preserving_peak_memory but returns debug info structures directly |
| 1451 | + _reorder_communication_preserving_peak_memory, |
| 1452 | + ], |
| 1453 | + } |
| 1454 | + ): |
| 1455 | + compiled = torch.compile(func) |
| 1456 | + code = run_and_get_triton_code(compiled, inputs, **self.get_world_trs()) |
| 1457 | + # NOTE: The first return value should be the output of the first wait_tensor. |
| 1458 | + # We want to make sure no unnecessary copy is made. |
| 1459 | + ( |
| 1460 | + FileCheck() |
| 1461 | + .check("buf0 = empty_strided") |
| 1462 | + .check("buf6 = empty_strided") |
| 1463 | + .check(".run(arg0_1, buf0, buf6, 16") |
| 1464 | + .check( |
| 1465 | + "buf1 = torch.ops._c10d_functional.reduce_scatter_tensor_coalesced.default([buf0, arg0_1]" |
| 1466 | + ) |
| 1467 | + # .check("buf2 = buf1[0]") |
| 1468 | + # .check("buf3 = buf1[1]") |
| 1469 | + .check("torch.ops._c10d_functional.wait_tensor.default(buf2") |
| 1470 | + # .check("buf7 = buf0; del buf0 # reuse") |
| 1471 | + # .check(".run(buf7, 16") |
| 1472 | + .check("torch.ops._c10d_functional.wait_tensor.default(buf3") |
| 1473 | + .check("return (buf2, buf6, buf7, buf3") |
| 1474 | + .run(code) |
| 1475 | + ) |
| 1476 | + out = compiled(inputs, **self.get_world_trs()) |
| 1477 | + correct = func(inputs, **self.get_world_trs()) |
| 1478 | + assert same(out, correct), f"{out} vs {correct}" |
| 1479 | + |
| 1480 | + # TODO make the test case more interesting and validate the actual desired behavior |
| 1481 | + assert node_stats is not None |
| 1482 | + self.assertTrue(isinstance(node_stats, dict)) |
| 1483 | + self.assertEqual(len(node_stats), 1) |
| 1484 | + for stats in node_stats.values(): |
| 1485 | + self.assertEqual(stats.initial_exposed, 0) |
| 1486 | + self.assertEqual(stats.limiting_factor, "data dependency") |
| 1487 | + self.assertEqual(stats.moves, 0) |
| 1488 | + |
1403 | 1489 |
1404 | 1490 | if __name__ == "__main__":
1405 | 1491 | from torch._dynamo.test_case import run_tests