diff --git a/test/inductor/test_torchinductor_strided_blocks.py b/test/inductor/test_torchinductor_strided_blocks.py
index b0a6c4d4441e..895c536ed326 100644
--- a/test/inductor/test_torchinductor_strided_blocks.py
+++ b/test/inductor/test_torchinductor_strided_blocks.py
@@ -116,6 +116,9 @@ def _assert_reduction_ndims(self, code, num_dims: int) -> None:
         for unexpected_block in reduction_blocks[num_dims:]:
             self.assertNotIn(unexpected_block, code)
 
+    def _get_lines_containing_substr(self, code: str, substr: str) -> str:
+        return "\n".join(line for line in code.split("\n") if substr in line)
+
 
 @instantiate_parametrized_tests
 class CommonTemplate:
@@ -348,29 +351,29 @@ def test_pointwise_broadcast_nonzero_strides(self, prefer_nd_tiling: bool):
         # Check the code for broadcasts.
         # We shouldn't see any strides of 0.
         load_lines, store_lines = tuple(
-            [line for line in triton_code.split("\n") if substr in line]
+            self._get_lines_containing_substr(triton_code, substr)
             for substr in ("tl.load", "tl.store")
         )
         if prefer_nd_tiling:
             self.assertExpectedInline(
-                "\n".join(load_lines),
+                load_lines,
                 """\
-    tmp0 = tl.load(tl.make_block_ptr(in_ptr0, shape=[8, 8], strides=[1, 8], block_shape=[XBLOCK, YBLOCK], order=[1, 0], offsets=[xoffset, yoffset]), boundary_check=[0, 1])
-    tmp1 = tl.load(tl.make_block_ptr(in_ptr1, shape=[8], strides=[8], block_shape=[YBLOCK], order=[0], offsets=[yoffset]), boundary_check=[0], eviction_policy='evict_last')[None, :]""",  # noqa: B950
+    tmp0 = tl.load(tl.make_block_ptr(in_ptr0, shape=[8, 8], strides=[8, 1], block_shape=[YBLOCK, XBLOCK], order=[1, 0], offsets=[yoffset, xoffset]), boundary_check=[0, 1])
+    tmp1 = tl.load(tl.make_block_ptr(in_ptr1, shape=[8], strides=[8], block_shape=[YBLOCK], order=[0], offsets=[yoffset]), boundary_check=[0], eviction_policy='evict_last')[:, None]""",  # noqa: B950
             )
             self.assertExpectedInline(
-                "\n".join(store_lines),
-                """    tl.store(tl.make_block_ptr(out_ptr0, shape=[8, 8], strides=[1, 8], block_shape=[XBLOCK, YBLOCK], order=[1, 0], offsets=[xoffset, yoffset]), tl.broadcast_to(tmp2, [XBLOCK, YBLOCK]).to(tl.float32), boundary_check=[0, 1])""",  # noqa: B950
+                store_lines,
+                """    tl.store(tl.make_block_ptr(out_ptr0, shape=[8, 8], strides=[8, 1], block_shape=[YBLOCK, XBLOCK], order=[1, 0], offsets=[yoffset, xoffset]), tl.broadcast_to(tmp2, [YBLOCK, XBLOCK]).to(tl.float32), boundary_check=[0, 1])""",  # noqa: B950
             )
         else:
             self.assertExpectedInline(
-                "\n".join(load_lines),
+                load_lines,
                 """\
     tmp0 = tl.load(tl.make_block_ptr(in_ptr0, shape=[64], strides=[1], block_shape=[XBLOCK], order=[0], offsets=[xoffset]), boundary_check=[0])
     tmp1 = tl.reshape(tl.broadcast_to(tl.load(tl.make_block_ptr(in_ptr1, shape=[8], strides=[8], block_shape=[(7 + XBLOCK) // 8], order=[0], offsets=[xoffset // 8]), boundary_check=[0], eviction_policy='evict_last')[:, None, None], [(7 + XBLOCK) // 8, ((1) * ((1) <= ((7 + XBLOCK) // 8)) + ((7 + XBLOCK) // 8) * (((7 + XBLOCK) // 8) < (1))), ((8) * ((8) <= (XBLOCK)) + (XBLOCK) * ((XBLOCK) < (8)))]), [XBLOCK])""",  # noqa: B950
             )
             self.assertExpectedInline(
-                "\n".join(store_lines),
+                store_lines,
                 """    tl.store(tl.make_block_ptr(out_ptr0, shape=[64], strides=[1], block_shape=[XBLOCK], order=[0], offsets=[xoffset]), tl.broadcast_to(tmp2, [XBLOCK]).to(tl.float32), boundary_check=[0])""",  # noqa: B950
             )
 
@@ -952,6 +955,54 @@ def fn(a):
             rtol=0.06,
         )
 
+    def test_pointwise_index_order(self):
+        """
+        Test the order of indices in pointwise kernels. Expect Z to be the leading dim,
+        then Y, then X.
+ """ + + inps = [ + self._discontiguous_tensor((5, 5, 5), device=self.device) for _ in range(2) + ] + + result, (triton_code,) = run_and_compare( + self, + torch.add, + *inps, + expected_num_triton_kernels=1, + expected_num_block_pointers=3, + config_patches={ + "triton.max_tiles": 3, + "triton.prefer_nd_tiling": True, + }, + ) + + # Check the load and store for block pointer strides. + load_lines, store_lines, index_lines = tuple( + self._get_lines_containing_substr(triton_code, substr) + for substr in ("tl.load", "tl.store", "index =") + ) + self.assertExpectedInline( + load_lines, + """\ + tmp0 = tl.load(tl.make_block_ptr(in_ptr0, shape=[5, 5, 5], strides=[100, 10, 1], block_shape=[ZBLOCK, YBLOCK, XBLOCK], order=[2, 1, 0], offsets=[zoffset, yoffset, xoffset]), boundary_check=[0, 1, 2]) + tmp1 = tl.load(tl.make_block_ptr(in_ptr1, shape=[5, 5, 5], strides=[100, 10, 1], block_shape=[ZBLOCK, YBLOCK, XBLOCK], order=[2, 1, 0], offsets=[zoffset, yoffset, xoffset]), boundary_check=[0, 1, 2])""", # noqa: B950 + ) + + self.assertExpectedInline( + store_lines, + """ tl.store(tl.make_block_ptr(out_ptr0, shape=[5, 5, 5], strides=[25, 5, 1], block_shape=[ZBLOCK, YBLOCK, XBLOCK], order=[2, 1, 0], offsets=[zoffset, yoffset, xoffset]), tl.broadcast_to(tmp2, [ZBLOCK, YBLOCK, XBLOCK]).to(tl.float32), boundary_check=[0, 1, 2])""", # noqa: B950 + ) + + # Check the indices. These are used for non-block pointers. + self.assertExpectedInline( + index_lines, + """\ + zindex = zoffset + tl.arange(0, ZBLOCK)[:, None, None] + yindex = yoffset + tl.arange(0, YBLOCK)[None, :, None] + xindex = xoffset + tl.arange(0, XBLOCK)[None, None, :]""", # noqa: B950 + ) + @unittest.skipIf(not TRITON_HAS_CPU, "requires triton CPU backend") @config.patch(cpu_backend="triton") diff --git a/torch/_inductor/codegen/simd.py b/torch/_inductor/codegen/simd.py index 6e5d56fecb41..37f8a4aee2c6 100644 --- a/torch/_inductor/codegen/simd.py +++ b/torch/_inductor/codegen/simd.py @@ -424,13 +424,14 @@ def filtered_index_map(seq, mask) -> dict[Any, int]: } grid_dims = ["x", "y", "z"] + pointwise_tensor_dims = list(reversed(grid_dims)) reduction_dims = ["r0_", "r1_"] if no_x_dim: tensor_dims = reduction_dims elif no_r_dim: - tensor_dims = grid_dims + tensor_dims = pointwise_tensor_dims else: - tensor_dims = grid_dims + reduction_dims + tensor_dims = pointwise_tensor_dims + reduction_dims # Filter out unused tensor dims. # Convert to dicts for O(1) index lookup. 
@@ -814,17 +815,10 @@ def prepare_indexing(
 
         return self.codegen_indexing(simp_index)
 
-    def active_range_trees(self, reorder: bool = False) -> list[IterationRangesRoot]:
-        trees = [
+    def active_range_trees(self) -> list[IterationRangesRoot]:
+        return [
             t for t in self.range_trees if not t.is_reduction or self.inside_reduction
         ]
-        if reorder and len(trees) > 1:
-            count = sum(t.prefix in "xyz" for t in trees)
-            assert "".join(t.prefix for t in trees[:count]) == "zyx"[-count:], [
-                t.prefix for t in trees[:count]
-            ]
-            trees[:count] = reversed(trees[:count])
-        return trees
 
     def codegen_indexing(self, expr: sympy.Expr) -> sympy.Expr:
         expr = V.graph.sizevars.simplify_with_ranges(expr, self.var_ranges())
diff --git a/torch/_inductor/codegen/triton.py b/torch/_inductor/codegen/triton.py
index 6e34b99cb70c..a64a1c88c735 100644
--- a/torch/_inductor/codegen/triton.py
+++ b/torch/_inductor/codegen/triton.py
@@ -1966,7 +1966,7 @@ def match_block_pointer() -> Optional[BlockPtrOptions]:
                 index_relative_to_xyr_index = sympy_subs(
                     index, {v: t.expr for v, t in self.range_tree_nodes.items()}
                 )
-                range_trees = self.active_range_trees(reorder=True)
+                range_trees = self.active_range_trees()
 
                 # Partition the index into subexpressions pertaining to each range tree.
                 # For example xindex * 5 + r0_index * 3 is partitioned to
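To make the new test's index expectations concrete: with z as the leading dim, the `[:, None, None]` / `[None, :, None]` / `[None, None, :]` broadcasts put x in the fastest-varying position, which lines up with the descending strides in the block pointers. A small NumPy illustration (NumPy stands in for `tl.arange` here purely for demonstration):

```python
import numpy as np

# Stand-ins for ZBLOCK/YBLOCK/XBLOCK.
ZBLOCK, YBLOCK, XBLOCK = 2, 3, 4

# Same broadcast pattern as the expected zindex/yindex/xindex lines:
# z varies along axis 0, y along axis 1, x along axis 2.
zindex = np.arange(ZBLOCK)[:, None, None]
yindex = np.arange(YBLOCK)[None, :, None]
xindex = np.arange(XBLOCK)[None, None, :]

# With row-major strides (Y*X, X, 1), the combined index enumerates the
# tile in memory order, i.e. x is the innermost/fastest dim.
linear = zindex * (YBLOCK * XBLOCK) + yindex * XBLOCK + xindex
assert linear.shape == (ZBLOCK, YBLOCK, XBLOCK)
assert (linear.ravel() == np.arange(ZBLOCK * YBLOCK * XBLOCK)).all()
```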
