Skip to content

[Inductor] Improve memory locality by iterating over y dimension before x #149339

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 9 commits into from
Closed
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
handle no_x_dim
  • Loading branch information
blaine-rister committed Mar 18, 2025
commit eca6fc5867627aa60ba555f5724760d7f22018d9
30 changes: 18 additions & 12 deletions torch/_inductor/codegen/simd.py
Original file line number Diff line number Diff line change
Expand Up @@ -820,22 +820,28 @@ def active_range_trees(self, reorder: bool = False) -> list[IterationRangesRoot]
t for t in self.range_trees if not t.is_reduction or self.inside_reduction
]

# Put all trees with tensor_dim=None at the end, in their current order.
end_key = (
max(
(tree.tensor_dim for tree in trees if tree.tensor_dim is not None),
default=0,
)
+ 1
)
if not reorder:
return trees

def tree_key(tree: IterationRangesRoot) -> int:
return tree.tensor_dim if tree.tensor_dim is not None else end_key
assert tree.tensor_dim is not None, f"Missing tensor_dim for tree {tree}"
return tree.tensor_dim

if reorder:
trees = sorted(trees, key=tree_key)
# Keep all trees with tensor_dim=None in their current position.
sorted_trees = sorted(
(tree for tree in trees if tree.tensor_dim is not None),
key=tree_key,
)
sorted_idx = 0
final_trees = []
for tree in trees:
if tree.tensor_dim is None:
final_trees.append(tree)
else:
final_trees.append(sorted_trees[sorted_idx])
sorted_idx += 1

return trees
return final_trees

def codegen_indexing(self, expr: sympy.Expr) -> sympy.Expr:
expr = V.graph.sizevars.simplify_with_ranges(expr, self.var_ranges())
Expand Down
Loading
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy