File tree Expand file tree Collapse file tree 2 files changed +3
-26
lines changed Expand file tree Collapse file tree 2 files changed +3
-26
lines changed Original file line number Diff line number Diff line change @@ -815,34 +815,11 @@ def prepare_indexing(
815
815
816
816
return self .codegen_indexing (simp_index )
817
817
818
- def active_range_trees (self , reorder : bool = False ) -> list [IterationRangesRoot ]:
819
- trees = [
818
+ def active_range_trees (self ) -> list [IterationRangesRoot ]:
819
+ return [
820
820
t for t in self .range_trees if not t .is_reduction or self .inside_reduction
821
821
]
822
822
823
- if not reorder :
824
- return trees
825
-
826
- def tree_key (tree : IterationRangesRoot ) -> int :
827
- assert tree .tensor_dim is not None , f"Missing tensor_dim for tree { tree } "
828
- return tree .tensor_dim
829
-
830
- # Keep trees with tensor_dim=None in their current positions.
831
- sorted_trees = sorted (
832
- (tree for tree in trees if tree .tensor_dim is not None ),
833
- key = tree_key ,
834
- )
835
- sorted_idx = 0
836
- final_trees = []
837
- for tree in trees :
838
- if tree .tensor_dim is None :
839
- final_trees .append (tree )
840
- else :
841
- final_trees .append (sorted_trees [sorted_idx ])
842
- sorted_idx += 1
843
-
844
- return final_trees
845
-
846
823
def codegen_indexing (self , expr : sympy .Expr ) -> sympy .Expr :
847
824
expr = V .graph .sizevars .simplify_with_ranges (expr , self .var_ranges ())
848
825
for sym in sorted (expr .free_symbols , key = str ):
Original file line number Diff line number Diff line change @@ -1966,7 +1966,7 @@ def match_block_pointer() -> Optional[BlockPtrOptions]:
1966
1966
index_relative_to_xyr_index = sympy_subs (
1967
1967
index , {v : t .expr for v , t in self .range_tree_nodes .items ()}
1968
1968
)
1969
- range_trees = self .active_range_trees (reorder = True )
1969
+ range_trees = self .active_range_trees ()
1970
1970
1971
1971
# Partition the index into subexpressions pertaining to each range tree.
1972
1972
# For example xindex * 5 + r0_index * 3 is partitioned to
You can’t perform that action at this time.
0 commit comments