From d3fec42b55546a467fd8c17ce01ce835849e63e2 Mon Sep 17 00:00:00 2001
From: Blaine Burton Rister <145300525+blaine-rister@users.noreply.github.com>
Date: Mon, 17 Mar 2025 12:21:02 -0700
Subject: [PATCH 1/9] move y first

---
 torch/_inductor/codegen/simd.py | 22 ++++++++++++----------
 1 file changed, 12 insertions(+), 10 deletions(-)

diff --git a/torch/_inductor/codegen/simd.py b/torch/_inductor/codegen/simd.py
index 6e5d56fecb41..07e1ffc4797d 100644
--- a/torch/_inductor/codegen/simd.py
+++ b/torch/_inductor/codegen/simd.py
@@ -423,7 +423,7 @@ def filtered_index_map(seq, mask) -> dict[Any, int]:
             val: idx for idx, val in enumerate(val for val in seq if val in mask)
         }

-        grid_dims = ["x", "y", "z"]
+        grid_dims = ["z", "y", "x"]
         reduction_dims = ["r0_", "r1_"]
         if no_x_dim:
             tensor_dims = reduction_dims
@@ -667,9 +667,9 @@ def getter(flat_vars: list[sympy.Expr]) -> sympy.Expr:
             )
             return_getters_groups.append(return_getters)

-        assert all(V.graph.sizevars.size_hint(s) == 1 for s in remaining), (
-            f"failed to set ranges {remaining} {lengths}"
-        )
+        assert all(
+            V.graph.sizevars.size_hint(s) == 1 for s in remaining
+        ), f"failed to set ranges {remaining} {lengths}"

         return new_ranges, return_getters_groups

@@ -818,12 +818,14 @@ def active_range_trees(self, reorder: bool = False) -> list[IterationRangesRoot]
         trees = [
             t for t in self.range_trees if not t.is_reduction or self.inside_reduction
         ]
-        if reorder and len(trees) > 1:
-            count = sum(t.prefix in "xyz" for t in trees)
-            assert "".join(t.prefix for t in trees[:count]) == "zyx"[-count:], [
-                t.prefix for t in trees[:count]
-            ]
-            trees[:count] = reversed(trees[:count])
+
+        def tree_order(tree: IterationRangesRoot) -> int:
+            assert tree.tensor_dim is not None, f"Invalid tensor dim: {tree.tensor_dim}"
+            return tree.tensor_dim
+
+        if reorder:
+            trees = sorted(trees, key=tree_order)
+
         return trees

     def codegen_indexing(self, expr: sympy.Expr) -> sympy.Expr:

From be6c9beb96919eca1150ba2f29f4ec935b80238e Mon Sep 17 00:00:00 2001
From: Blaine Burton Rister <145300525+blaine-rister@users.noreply.github.com>
Date: Mon, 17 Mar 2025 12:48:52 -0700
Subject: [PATCH 2/9] fix

---
 torch/_inductor/codegen/simd.py | 19 ++++++++++++-------
 1 file changed, 12 insertions(+), 7 deletions(-)

diff --git a/torch/_inductor/codegen/simd.py b/torch/_inductor/codegen/simd.py
index 07e1ffc4797d..5d0ed5ce4648 100644
--- a/torch/_inductor/codegen/simd.py
+++ b/torch/_inductor/codegen/simd.py
@@ -423,14 +423,15 @@ def filtered_index_map(seq, mask) -> dict[Any, int]:
             val: idx for idx, val in enumerate(val for val in seq if val in mask)
         }

-        grid_dims = ["z", "y", "x"]
+        grid_dims = ["x", "y", "z"]
+        pointwise_tensor_dims = list(reversed(grid_dims))
         reduction_dims = ["r0_", "r1_"]
         if no_x_dim:
             tensor_dims = reduction_dims
         elif no_r_dim:
-            tensor_dims = grid_dims
+            tensor_dims = pointwise_tensor_dims
         else:
-            tensor_dims = grid_dims + reduction_dims
+            tensor_dims = pointwise_tensor_dims + reduction_dims

         # Filter out unused tensor dims.
         # Convert to dicts for O(1) index lookup.
@@ -819,12 +820,16 @@ def active_range_trees(self, reorder: bool = False) -> list[IterationRangesRoot]
             t for t in self.range_trees if not t.is_reduction or self.inside_reduction
         ]

-        def tree_order(tree: IterationRangesRoot) -> int:
-            assert tree.tensor_dim is not None, f"Invalid tensor dim: {tree.tensor_dim}"
-            return tree.tensor_dim
+        # Put all trees with tensor_dim=None at the end, in their current order.
+        end_key = (
+            max(tree.tensor_dim for tree in trees if tree.tensor_dim is not None) + 1
+        )
+
+        def tree_key(tree: IterationRangesRoot) -> int:
+            return tree.tensor_dim if tree.tensor_dim is not None else end_key

         if reorder:
-            trees = sorted(trees, key=tree_order)
+            trees = sorted(trees, key=tree_key)

         return trees
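
A note on [PATCH 2/9]: the fix keeps grid_dims in launch order ("x", "y", "z") but hands tensors the reversed order, making "z" the slowest-varying tensor axis. Below is a minimal standalone sketch of what filtered_index_map yields under each ordering; the function body is copied from the diff, while the plain lists and the active set are illustrative stand-ins for the kernel's real state.

    from typing import Any

    def filtered_index_map(seq, mask) -> dict[Any, int]:
        # Map each value of seq that survives the mask to its position
        # among the survivors, giving O(1) index lookup afterwards.
        return {val: idx for idx, val in enumerate(v for v in seq if v in mask)}

    grid_dims = ["x", "y", "z"]
    pointwise_tensor_dims = list(reversed(grid_dims))

    active = {"x", "y"}  # e.g. a 2-D pointwise kernel
    print(filtered_index_map(grid_dims, active))              # {'x': 0, 'y': 1}
    print(filtered_index_map(pointwise_tensor_dims, active))  # {'y': 0, 'x': 1}

For a 2-D kernel, "y" becomes tensor dim 0 and "x" tensor dim 1, which is what lets active_range_trees sort by tensor_dim rather than reversing prefixes.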
From bae117fd75144d157814c6dbd2393f4672608d79 Mon Sep 17 00:00:00 2001
From: Blaine Burton Rister <145300525+blaine-rister@users.noreply.github.com>
Date: Mon, 17 Mar 2025 18:40:17 -0700
Subject: [PATCH 3/9] supply default for max

---
 torch/_inductor/codegen/simd.py | 12 ++++++++----
 1 file changed, 8 insertions(+), 4 deletions(-)

diff --git a/torch/_inductor/codegen/simd.py b/torch/_inductor/codegen/simd.py
index 5d0ed5ce4648..7f78dc307a32 100644
--- a/torch/_inductor/codegen/simd.py
+++ b/torch/_inductor/codegen/simd.py
@@ -668,9 +668,9 @@ def getter(flat_vars: list[sympy.Expr]) -> sympy.Expr:
             )
             return_getters_groups.append(return_getters)

-        assert all(
-            V.graph.sizevars.size_hint(s) == 1 for s in remaining
-        ), f"failed to set ranges {remaining} {lengths}"
+        assert all(V.graph.sizevars.size_hint(s) == 1 for s in remaining), (
+            f"failed to set ranges {remaining} {lengths}"
+        )

         return new_ranges, return_getters_groups

@@ -822,7 +822,11 @@ def active_range_trees(self, reorder: bool = False) -> list[IterationRangesRoot]

         # Put all trees with tensor_dim=None at the end, in their current order.
         end_key = (
-            max(tree.tensor_dim for tree in trees if tree.tensor_dim is not None) + 1
+            max(
+                (tree.tensor_dim for tree in trees if tree.tensor_dim is not None),
+                default=0,
+            )
+            + 1
         )

         def tree_key(tree: IterationRangesRoot) -> int:
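
A note on [PATCH 3/9]: max() raises ValueError on an empty iterable, which happens here whenever no tree carries a tensor_dim; the default= keyword covers that case. A quick illustration with a hypothetical dims list:

    dims = [None, None]  # suppose no tree has a tensor_dim

    # Without a default, max() on an empty generator raises.
    try:
        end_key = max(d for d in dims if d is not None) + 1
    except ValueError:
        print("empty sequence")

    # With default=0, the empty case falls through to end_key = 1.
    end_key = max((d for d in dims if d is not None), default=0) + 1
    print(end_key)  # 1

The extra parentheses around the generator in the patch are forced by the second argument: a bare generator expression is only legal syntax when it is the sole argument.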
From eca6fc5867627aa60ba555f5724760d7f22018d9 Mon Sep 17 00:00:00 2001
From: Blaine Burton Rister <145300525+blaine-rister@users.noreply.github.com>
Date: Mon, 17 Mar 2025 19:52:12 -0700
Subject: [PATCH 4/9] handle no_x_dim

---
 torch/_inductor/codegen/simd.py | 30 ++++++++++++++++++------------
 1 file changed, 18 insertions(+), 12 deletions(-)

diff --git a/torch/_inductor/codegen/simd.py b/torch/_inductor/codegen/simd.py
index 7f78dc307a32..446d25015cdc 100644
--- a/torch/_inductor/codegen/simd.py
+++ b/torch/_inductor/codegen/simd.py
@@ -820,22 +820,28 @@ def active_range_trees(self, reorder: bool = False) -> list[IterationRangesRoot]
             t for t in self.range_trees if not t.is_reduction or self.inside_reduction
         ]

-        # Put all trees with tensor_dim=None at the end, in their current order.
-        end_key = (
-            max(
-                (tree.tensor_dim for tree in trees if tree.tensor_dim is not None),
-                default=0,
-            )
-            + 1
-        )
+        if not reorder:
+            return trees

         def tree_key(tree: IterationRangesRoot) -> int:
-            return tree.tensor_dim if tree.tensor_dim is not None else end_key
+            assert tree.tensor_dim is not None, f"Missing tensor_dim for tree {tree}"
+            return tree.tensor_dim

-        if reorder:
-            trees = sorted(trees, key=tree_key)
+        # Keep all trees with tensor_dim=None in their current position.
+        sorted_trees = sorted(
+            (tree for tree in trees if tree.tensor_dim is not None),
+            key=tree_key,
+        )
+        sorted_idx = 0
+        final_trees = []
+        for tree in trees:
+            if tree.tensor_dim is None:
+                final_trees.append(tree)
+            else:
+                final_trees.append(sorted_trees[sorted_idx])
+                sorted_idx += 1

-        return trees
+        return final_trees

     def codegen_indexing(self, expr: sympy.Expr) -> sympy.Expr:
         expr = V.graph.sizevars.simplify_with_ranges(expr, self.var_ranges())
         for sym in sorted(expr.free_symbols, key=str):
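
A note on [PATCH 4/9]: in no_x_dim kernels some trees legitimately have tensor_dim=None, so the patch sorts only the keyed trees and splices them back around the unkeyed ones, which keep their slots. A self-contained sketch of that splice; sort_keyed_in_place is an illustrative helper, and (prefix, tensor_dim) tuples stand in for IterationRangesRoot objects:

    def sort_keyed_in_place(items, key_of):
        # Sort the items that have a key; items whose key is None
        # keep their original positions, as in the loop above.
        keyed = sorted((x for x in items if key_of(x) is not None), key=key_of)
        it = iter(keyed)
        return [x if key_of(x) is None else next(it) for x in items]

    trees = [("x", 2), ("r0_", None), ("z", 0), ("y", 1)]
    print(sort_keyed_in_place(trees, lambda t: t[1]))
    # [('z', 0), ('r0_', None), ('y', 1), ('x', 2)]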
From d9138c2a0b22781204c0bd38f5d8e93e383939f3 Mon Sep 17 00:00:00 2001
From: Blaine Burton Rister <145300525+blaine-rister@users.noreply.github.com>
Date: Mon, 17 Mar 2025 20:11:23 -0700
Subject: [PATCH 5/9] add test case

---
 .../test_torchinductor_strided_blocks.py     | 68 ++++++++++++++++---
 1 file changed, 60 insertions(+), 8 deletions(-)

diff --git a/test/inductor/test_torchinductor_strided_blocks.py b/test/inductor/test_torchinductor_strided_blocks.py
index b0a6c4d4441e..c1a1b71dd98d 100644
--- a/test/inductor/test_torchinductor_strided_blocks.py
+++ b/test/inductor/test_torchinductor_strided_blocks.py
@@ -116,6 +116,9 @@ def _assert_reduction_ndims(self, code, num_dims: int) -> None:
         for unexpected_block in reduction_blocks[num_dims:]:
             self.assertNotIn(unexpected_block, code)

+    def _get_lines_containing_substr(self, code: str, substr: str) -> str:
+        return "\n".join(line for line in code.split("\n") if substr in line)
+

 @instantiate_parametrized_tests
 class CommonTemplate:
@@ -348,29 +351,29 @@ def test_pointwise_broadcast_nonzero_strides(self, prefer_nd_tiling: bool):
         # Check the code for broadcasts.
         # We shouldn't see any strides of 0.
         load_lines, store_lines = tuple(
-            [line for line in triton_code.split("\n") if substr in line]
+            self._get_lines_containing_substr(triton_code, substr)
             for substr in ("tl.load", "tl.store")
         )
         if prefer_nd_tiling:
             self.assertExpectedInline(
-                "\n".join(load_lines),
+                load_lines,
                 """\
-    tmp0 = tl.load(tl.make_block_ptr(in_ptr0, shape=[8, 8], strides=[1, 8], block_shape=[XBLOCK, YBLOCK], order=[1, 0], offsets=[xoffset, yoffset]), boundary_check=[0, 1])
-    tmp1 = tl.load(tl.make_block_ptr(in_ptr1, shape=[8], strides=[8], block_shape=[YBLOCK], order=[0], offsets=[yoffset]), boundary_check=[0], eviction_policy='evict_last')[None, :]""",  # noqa: B950
+    tmp0 = tl.load(tl.make_block_ptr(in_ptr0, shape=[8, 8], strides=[8, 1], block_shape=[YBLOCK, XBLOCK], order=[1, 0], offsets=[yoffset, xoffset]), boundary_check=[0, 1])
+    tmp1 = tl.load(tl.make_block_ptr(in_ptr1, shape=[8], strides=[8], block_shape=[XBLOCK], order=[0], offsets=[xoffset]), boundary_check=[0], eviction_policy='evict_last')[None, :]""",  # noqa: B950
             )
             self.assertExpectedInline(
-                "\n".join(store_lines),
-                """    tl.store(tl.make_block_ptr(out_ptr0, shape=[8, 8], strides=[1, 8], block_shape=[XBLOCK, YBLOCK], order=[1, 0], offsets=[xoffset, yoffset]), tl.broadcast_to(tmp2, [XBLOCK, YBLOCK]).to(tl.float32), boundary_check=[0, 1])""",  # noqa: B950
+                store_lines,
+                """    tl.store(tl.make_block_ptr(out_ptr0, shape=[8, 8], strides=[8, 1], block_shape=[YBLOCK, XBLOCK], order=[1, 0], offsets=[yoffset, xoffset]), tl.broadcast_to(tmp2, [YBLOCK, XBLOCK]).to(tl.float32), boundary_check=[0, 1])""",  # noqa: B950
             )
         else:
             self.assertExpectedInline(
-                "\n".join(load_lines),
+                load_lines,
                 """\
     tmp0 = tl.load(tl.make_block_ptr(in_ptr0, shape=[64], strides=[1], block_shape=[XBLOCK], order=[0], offsets=[xoffset]), boundary_check=[0])
     tmp1 = tl.reshape(tl.broadcast_to(tl.load(tl.make_block_ptr(in_ptr1, shape=[8], strides=[8], block_shape=[(7 + XBLOCK) // 8], order=[0], offsets=[xoffset // 8]), boundary_check=[0], eviction_policy='evict_last')[:, None, None], [(7 + XBLOCK) // 8, ((1) * ((1) <= ((7 + XBLOCK) // 8)) + ((7 + XBLOCK) // 8) * (((7 + XBLOCK) // 8) < (1))), ((8) * ((8) <= (XBLOCK)) + (XBLOCK) * ((XBLOCK) < (8)))]), [XBLOCK])""",  # noqa: B950
             )
             self.assertExpectedInline(
-                "\n".join(store_lines),
+                store_lines,
                 """    tl.store(tl.make_block_ptr(out_ptr0, shape=[64], strides=[1], block_shape=[XBLOCK], order=[0], offsets=[xoffset]), tl.broadcast_to(tmp2, [XBLOCK]).to(tl.float32), boundary_check=[0])""",  # noqa: B950
             )
@@ -952,6 +955,55 @@ def fn(a):
             rtol=0.06,
         )

+    def test_pointwise_index_order(self):
+        """
+        Test the order of indices in pointwise kernels. Expect Z to be the leading dim,
+        then Y, then X.
+        """
+
+        inps = [
+            self._discontiguous_tensor((5, 5, 5), device=self.device) for _ in range(2)
+        ]
+
+        result, (triton_code,) = run_and_compare(
+            self,
+            torch.add,
+            *inps,
+            expected_num_triton_kernels=1,
+            expected_num_block_pointers=3,
+            config_patches={
+                "triton.max_tiles": 3,
+                "triton.prefer_nd_tiling": True,
+            },
+        )
+
+        # Check the load and store for block pointer strides.
+        load_lines, store_lines, index_lines = tuple(
+            self._get_lines_containing_substr(triton_code, substr)
+            for substr in ("tl.load", "tl.store", "index =")
+        )
+        self.assertExpectedInline(
+            load_lines,
+            """\
+    tmp0 = tl.load(tl.make_block_ptr(in_ptr0, shape=[5, 5, 5], strides=[100, 10, 1], block_shape=[ZBLOCK, YBLOCK, XBLOCK], order=[2, 1, 0], offsets=[zoffset, yoffset, xoffset]), boundary_check=[0, 1, 2])
+    tmp1 = tl.load(tl.make_block_ptr(in_ptr1, shape=[5, 5, 5], strides=[100, 10, 1], block_shape=[ZBLOCK, YBLOCK, XBLOCK], order=[2, 1, 0], offsets=[zoffset, yoffset, xoffset]), boundary_check=[0, 1, 2])""",  # noqa: B950
+        )
+
+        store_lines = [line for line in triton_code.split("\n") if "tl.store" in line]
+        self.assertExpectedInline(
+            store_lines,
+            """[' tl.store(tl.make_block_ptr(out_ptr0, shape=[5, 5, 5], strides=[25, 5, 1], block_shape=[ZBLOCK, YBLOCK, XBLOCK], order=[2, 1, 0], offsets=[zoffset, yoffset, xoffset]), tl.broadcast_to(tmp2, [ZBLOCK, YBLOCK, XBLOCK]).to(tl.float32), boundary_check=[0, 1, 2])']""",  # noqa: B950
+        )
+
+        # Check the indices. These are used for non-block pointers.
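
A note on [PATCH 5/9]: the expected zindex/yindex/xindex lines follow the usual outer-product layout, where each range occupies its own tensor axis and broadcasting combines them. A torch analogue of the shapes involved; torch stands in for Triton here, whose broadcasting follows the same rules, and the block sizes are arbitrary:

    import torch

    ZBLOCK, YBLOCK, XBLOCK = 2, 3, 4
    zindex = torch.arange(ZBLOCK)[:, None, None]  # shape (ZBLOCK, 1, 1)
    yindex = torch.arange(YBLOCK)[None, :, None]  # shape (1, YBLOCK, 1)
    xindex = torch.arange(XBLOCK)[None, None, :]  # shape (1, 1, XBLOCK)

    # Broadcasting builds the full 3-D grid with z slowest and x fastest,
    # matching order=[2, 1, 0] in the block pointers above.
    flat = zindex * YBLOCK * XBLOCK + yindex * XBLOCK + xindex
    print(flat.shape)          # torch.Size([2, 3, 4])
    print(flat.flatten()[:6])  # tensor([0, 1, 2, 3, 4, 5])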
+        self.assertExpectedInline(
+            index_lines,
+            """\
+    zindex = zoffset + tl.arange(0, ZBLOCK)[:, None, None]
+    yindex = yoffset + tl.arange(0, YBLOCK)[None, :, None]
+    xindex = xoffset + tl.arange(0, XBLOCK)[None, None, :]""",  # noqa: B950
+        )
+

 @unittest.skipIf(not TRITON_HAS_CPU, "requires triton CPU backend")
 @config.patch(cpu_backend="triton")

From e5018b9ecff9345526f496107247e103b7859440 Mon Sep 17 00:00:00 2001
From: Blaine Burton Rister <145300525+blaine-rister@users.noreply.github.com>
Date: Mon, 17 Mar 2025 20:15:50 -0700
Subject: [PATCH 6/9] update expected code

---
 test/inductor/test_torchinductor_strided_blocks.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/test/inductor/test_torchinductor_strided_blocks.py b/test/inductor/test_torchinductor_strided_blocks.py
index c1a1b71dd98d..08b0d593d1c5 100644
--- a/test/inductor/test_torchinductor_strided_blocks.py
+++ b/test/inductor/test_torchinductor_strided_blocks.py
@@ -359,7 +359,7 @@ def test_pointwise_broadcast_nonzero_strides(self, prefer_nd_tiling: bool):
                 load_lines,
                 """\
     tmp0 = tl.load(tl.make_block_ptr(in_ptr0, shape=[8, 8], strides=[8, 1], block_shape=[YBLOCK, XBLOCK], order=[1, 0], offsets=[yoffset, xoffset]), boundary_check=[0, 1])
-    tmp1 = tl.load(tl.make_block_ptr(in_ptr1, shape=[8], strides=[8], block_shape=[XBLOCK], order=[0], offsets=[xoffset]), boundary_check=[0], eviction_policy='evict_last')[None, :]""",  # noqa: B950
+    tmp1 = tl.load(tl.make_block_ptr(in_ptr1, shape=[8], strides=[8], block_shape=[YBLOCK], order=[0], offsets=[yoffset]), boundary_check=[0], eviction_policy='evict_last')[:, None]""",  # noqa: B950
             )
             self.assertExpectedInline(
                 store_lines,

From a8ade54743bd59a7b47d13ef25caea28f7865e1f Mon Sep 17 00:00:00 2001
From: Blaine Burton Rister <145300525+blaine-rister@users.noreply.github.com>
Date: Mon, 17 Mar 2025 20:20:16 -0700
Subject: [PATCH 7/9] clean up test

---
 test/inductor/test_torchinductor_strided_blocks.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/test/inductor/test_torchinductor_strided_blocks.py b/test/inductor/test_torchinductor_strided_blocks.py
index 08b0d593d1c5..895c536ed326 100644
--- a/test/inductor/test_torchinductor_strided_blocks.py
+++ b/test/inductor/test_torchinductor_strided_blocks.py
@@ -989,10 +989,9 @@ def test_pointwise_index_order(self):
     tmp1 = tl.load(tl.make_block_ptr(in_ptr1, shape=[5, 5, 5], strides=[100, 10, 1], block_shape=[ZBLOCK, YBLOCK, XBLOCK], order=[2, 1, 0], offsets=[zoffset, yoffset, xoffset]), boundary_check=[0, 1, 2])""",  # noqa: B950
         )

-        store_lines = [line for line in triton_code.split("\n") if "tl.store" in line]
         self.assertExpectedInline(
             store_lines,
-            """[' tl.store(tl.make_block_ptr(out_ptr0, shape=[5, 5, 5], strides=[25, 5, 1], block_shape=[ZBLOCK, YBLOCK, XBLOCK], order=[2, 1, 0], offsets=[zoffset, yoffset, xoffset]), tl.broadcast_to(tmp2, [ZBLOCK, YBLOCK, XBLOCK]).to(tl.float32), boundary_check=[0, 1, 2])']""",  # noqa: B950
+            """    tl.store(tl.make_block_ptr(out_ptr0, shape=[5, 5, 5], strides=[25, 5, 1], block_shape=[ZBLOCK, YBLOCK, XBLOCK], order=[2, 1, 0], offsets=[zoffset, yoffset, xoffset]), tl.broadcast_to(tmp2, [ZBLOCK, YBLOCK, XBLOCK]).to(tl.float32), boundary_check=[0, 1, 2])""",  # noqa: B950
         )

         # Check the indices. These are used for non-block pointers.
From fe18118abfe674d6a5b806d6310166af8bb8dd1b Mon Sep 17 00:00:00 2001
From: blaine-rister <145300525+blaine-rister@users.noreply.github.com>
Date: Mon, 17 Mar 2025 20:37:57 -0700
Subject: [PATCH 8/9] Update torch/_inductor/codegen/simd.py

---
 torch/_inductor/codegen/simd.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/torch/_inductor/codegen/simd.py b/torch/_inductor/codegen/simd.py
index 446d25015cdc..7c75448014fc 100644
--- a/torch/_inductor/codegen/simd.py
+++ b/torch/_inductor/codegen/simd.py
@@ -827,7 +827,7 @@ def tree_key(tree: IterationRangesRoot) -> int:
             assert tree.tensor_dim is not None, f"Missing tensor_dim for tree {tree}"
             return tree.tensor_dim

-        # Keep all trees with tensor_dim=None in their current position.
+        # Keep trees with tensor_dim=None in their current positions.
         sorted_trees = sorted(
             (tree for tree in trees if tree.tensor_dim is not None),
             key=tree_key,

From 81b4fca9544771de675893d9e88435bcb9073368 Mon Sep 17 00:00:00 2001
From: Blaine Burton Rister <145300525+blaine-rister@users.noreply.github.com>
Date: Mon, 17 Mar 2025 23:13:02 -0700
Subject: [PATCH 9/9] remove superfluous code

---
 torch/_inductor/codegen/simd.py   | 27 ++-------------------------
 torch/_inductor/codegen/triton.py |  2 +-
 2 files changed, 3 insertions(+), 26 deletions(-)

diff --git a/torch/_inductor/codegen/simd.py b/torch/_inductor/codegen/simd.py
index 7c75448014fc..37f8a4aee2c6 100644
--- a/torch/_inductor/codegen/simd.py
+++ b/torch/_inductor/codegen/simd.py
@@ -815,34 +815,11 @@ def prepare_indexing(

         return self.codegen_indexing(simp_index)

-    def active_range_trees(self, reorder: bool = False) -> list[IterationRangesRoot]:
-        trees = [
+    def active_range_trees(self) -> list[IterationRangesRoot]:
+        return [
             t for t in self.range_trees if not t.is_reduction or self.inside_reduction
         ]

-        if not reorder:
-            return trees
-
-        def tree_key(tree: IterationRangesRoot) -> int:
-            assert tree.tensor_dim is not None, f"Missing tensor_dim for tree {tree}"
-            return tree.tensor_dim
-
-        # Keep trees with tensor_dim=None in their current positions.
-        sorted_trees = sorted(
-            (tree for tree in trees if tree.tensor_dim is not None),
-            key=tree_key,
-        )
-        sorted_idx = 0
-        final_trees = []
-        for tree in trees:
-            if tree.tensor_dim is None:
-                final_trees.append(tree)
-            else:
-                final_trees.append(sorted_trees[sorted_idx])
-                sorted_idx += 1
-
-        return final_trees
-
     def codegen_indexing(self, expr: sympy.Expr) -> sympy.Expr:
         expr = V.graph.sizevars.simplify_with_ranges(expr, self.var_ranges())
         for sym in sorted(expr.free_symbols, key=str):

diff --git a/torch/_inductor/codegen/triton.py b/torch/_inductor/codegen/triton.py
index 6e34b99cb70c..a64a1c88c735 100644
--- a/torch/_inductor/codegen/triton.py
+++ b/torch/_inductor/codegen/triton.py
@@ -1966,7 +1966,7 @@ def match_block_pointer() -> Optional[BlockPtrOptions]:
                 index_relative_to_xyr_index = sympy_subs(
                     index, {v: t.expr for v, t in self.range_tree_nodes.items()}
                 )
-                range_trees = self.active_range_trees(reorder=True)
+                range_trees = self.active_range_trees()

                 # Partition the index into subexpressions pertaining to each range tree.
                 # For example xindex * 5 + r0_index * 3 is partitioned to
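
A note on [PATCH 9/9]: once the earlier patches guarantee that range trees are constructed in tensor_dim order, the reorder pass is dead code and active_range_trees reduces to a filter. A sketch of the surviving behavior, with a namedtuple standing in for IterationRangesRoot:

    from collections import namedtuple

    # Stand-in carrying only the fields the filter reads.
    Tree = namedtuple("Tree", ["prefix", "is_reduction"])

    def active_range_trees(range_trees, inside_reduction):
        # Pointwise trees are always active; reduction trees are active
        # only while generating the reduction body.
        return [t for t in range_trees if not t.is_reduction or inside_reduction]

    trees = [Tree("z", False), Tree("y", False), Tree("x", False), Tree("r0_", True)]
    print([t.prefix for t in active_range_trees(trees, True)])   # ['z', 'y', 'x', 'r0_']
    print([t.prefix for t in active_range_trees(trees, False)])  # ['z', 'y', 'x']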