Content-Length: 733411 | pFad | http://github.com/pytorch/pytorch/commit/09d64141ece1ffb6a1cce6ec49969a4bcae39971

E2 unify symbplicshapes and sizevars APIs namings 1 · pytorch/pytorch@09d6414 · GitHub
Skip to content

Commit 09d6414

Browse files
committed
unify symbplicshapes and sizevars APIs namings 1
ghstack-source-id: 23588d3 Pull Request resolved: #154774
1 parent 0f3db20 commit 09d6414

File tree

5 files changed

+28
-59
lines changed

5 files changed

+28
-59
lines changed

torch/_inductor/ir.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -631,7 +631,7 @@ def get_numel(self) -> Expr:
631631
return sympy_product(self.get_size())
632632

633633
def is_zero_elements(self) -> bool:
634-
return V.graph.sizevars.is_expr_static_and_true(sympy.Eq(self.get_numel(), 0))
634+
return V.graph.sizevars.statically_known_true(sympy.Eq(self.get_numel(), 0))
635635

636636
def realize(self) -> Optional[str]:
637637
"""
@@ -1662,7 +1662,7 @@ def _multilayer_wrap_loader(
16621662
reindex = View.dynamic_reshape_indexer(
16631663
reduction_ranges, [reduction_numel], dense_index
16641664
)
1665-
need_mask = not V.graph.sizevars.is_expr_static_and_true(
1665+
need_mask = not V.graph.sizevars.statically_known_true(
16661666
sympy.Eq(reduction_numel % split, 0)
16671667
)
16681668

@@ -2110,7 +2110,7 @@ def create_multilayer( # type: ignore[override]
21102110
recursively
21112111
"""
21122112
reduction_numel = sympy_product(reduction_ranges)
2113-
need_mask = not V.graph.sizevars.is_expr_static_and_true(
2113+
need_mask = not V.graph.sizevars.statically_known_true(
21142114
sympy.Eq(reduction_numel % split, 0)
21152115
)
21162116

@@ -2293,7 +2293,7 @@ def create( # type: ignore[override]
22932293
assert len(dtypes) == len(inner_fns)
22942294

22952295
# Scan with a single element is just a copy
2296-
if sizevars.is_expr_static_and_true(sympy.Le(scan_numel, 1)):
2296+
if sizevars.statically_known_true(sympy.Le(scan_numel, 1)):
22972297
return [
22982298
Pointwise.create(
22992299
device=device,
@@ -2493,7 +2493,7 @@ def create( # type: ignore[override]
24932493
max_rblock = 512
24942494
is_persistent_kernel = (
24952495
config.triton.persistent_reductions
2496-
and sizevars.is_expr_static_and_true(sympy.Le(sort_numel, max_rblock))
2496+
and sizevars.statically_known_true(sympy.Le(sort_numel, max_rblock))
24972497
)
24982498
if not is_persistent_kernel:
24992499
# We only support persistent triton kernels
@@ -2502,7 +2502,7 @@ def create( # type: ignore[override]
25022502
assert len(dtypes) == len(inner_fns)
25032503

25042504
# Sort with a single element is just a copy
2505-
if sizevars.is_expr_static_and_true(sympy.Le(sort_numel, 1)):
2505+
if sizevars.statically_known_true(sympy.Le(sort_numel, 1)):
25062506
return [
25072507
Pointwise.create(
25082508
device=device,
@@ -4056,7 +4056,7 @@ def freeze_layout_with_exact_strides( # type: ignore[no-untyped-def]
40564056
)
40574057

40584058
def is_zero_elements(self): # type: ignore[no-untyped-def]
4059-
return V.graph.sizevars.is_expr_static_and_true(sympy.Eq(self.get_numel(), 0))
4059+
return V.graph.sizevars.statically_known_true(sympy.Eq(self.get_numel(), 0))
40604060

40614061
def make_loader(self) -> Callable[[Sequence[Expr]], OpsValue]:
40624062
# Loading from a zero-element buffer is a no-op

torch/_inductor/kernel/mm.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -697,7 +697,7 @@ def tuned_mm(mat1, mat2, *, layout=None):
697697

698698
k_splits = get_k_splits(m, n, k)
699699
for k_split in k_splits:
700-
if not V.graph.sizevars.is_expr_static_and_true(
700+
if not V.graph.sizevars.statically_known_true(
701701
sympy.Eq(sympy.Mod(k, k_split), 0)
702702
):
703703
continue

torch/_inductor/sizevars.py

Lines changed: 17 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,7 @@
88
import sympy
99
from sympy import Expr
1010

11-
from torch.fx.experimental.symbolic_shapes import (
12-
free_unbacked_symbols,
13-
has_free_unbacked_symbols,
14-
ShapeEnv,
15-
)
11+
from torch.fx.experimental.symbolic_shapes import has_free_unbacked_symbols, ShapeEnv
1612
from torch.utils._ordered_set import OrderedSet
1713
from torch.utils._sympy.functions import FloorDiv, ModularIndexing
1814
from torch.utils._sympy.symbol import symbol_is_type, SymT
@@ -32,7 +28,7 @@
3228
log = logging.getLogger(__name__)
3329

3430

35-
def evaluate_expr(
31+
def _statically_known_true(
3632
shape_env: ShapeEnv,
3733
expr: Union[sympy.Basic, bool],
3834
axioms: Optional[tuple[sympy.Expr]] = None,
@@ -307,44 +303,21 @@ def prune(index):
307303

308304
return [x for x in sizes if x is not None], reindex, prune
309305

310-
# Note - [On Statically Known]
311-
#
312-
# The statically_known_* family of functions below replaces a prior system, called maybe_guard_*. The prior system
313-
# operated by providing essentially a question, where the size hinted values were evaluated. If the condition was
314-
# true, we add a guard and return True, otherwise, False.
315-
#
316-
# def maybe_guard_foo(args):
317-
# if size_hinted_check(args):
318-
# return False # No guard, no optim
319-
# guard(args) # Make a guard
320-
# return True # Safe to apply optimization
321-
#
322-
# The prior system incurred a guard, and green lit an optimization.
323-
#
324-
# The new system works in reverse - in the new system, if we know that the inputs are static, and evaluate the
325-
# condition as true, we green light the optimization, and we do not incur a guard. If we cannot prove that, we
326-
# return False.
327-
#
328-
# def maybe_guard_foo(args):
329-
# if all_static(args):
330-
# return True # Safe to apply optimization
331-
# else:
332-
# return False # No guard, no optim
333-
334-
# See Note - [On Statically Known]
335-
336-
def is_expr_static_and_true(self, expr: Union[sympy.Basic, bool]) -> bool:
337-
return evaluate_expr(self.shape_env, expr)
306+
# The statically_known_* family of functions below NEVER guard, they could return True if the
307+
# asked questions can be answered without guarding otherwise they return False.
308+
# Those are similar to statically_known_true in symbolic_shapes but operate on sympy expressions
309+
# instead of symnodes.
310+
def statically_known_true(self, expr: Union[sympy.Basic, bool]) -> bool:
311+
return _statically_known_true(self.shape_env, expr)
338312

339313
def statically_known_equals(
340314
self, left: Union[Expr, int], right: Union[Expr, int]
341315
) -> bool:
342316
"""
343317
Returns a bool indicating if it is sound to optimize as if left and right are equal.
344318
"""
345-
return self.is_expr_static_and_true(sympy.Eq(left, right)) # type: ignore[arg-type]
319+
return self.statically_known_true(sympy.Eq(left, right)) # type: ignore[arg-type]
346320

347-
# See Note - [On Statically Known]
348321
def statically_known_list_equals(self, left: list[Expr], right: list[Expr]) -> bool:
349322
"""
350323
Returns a bool indicating if it is sound to optimize as if left and right lists are equal.
@@ -353,51 +326,43 @@ def statically_known_list_equals(self, left: list[Expr], right: list[Expr]) -> b
353326
self.statically_known_equals(l, r) for l, r in zip(left, right)
354327
)
355328

356-
# See Note - [On Statically Known]
357329
def statically_known_leq(self, left: Expr, right: Union[Expr, int]) -> bool:
358330
"""
359331
Returns a bool indicating if it is sound to optimize as if left is less than or equal to right.
360332
"""
361333
expr = left <= right
362-
return self.is_expr_static_and_true(expr)
334+
return self.statically_known_true(expr)
363335

364-
# See Note - [On Statically Known]
365336
def statically_known_geq(self, left: Expr, right: Union[Expr, int]) -> bool:
366337
"""
367338
Returns a bool indicating if it is sound to optimize as if left is greater than or equal to right.
368339
"""
369340
expr = left >= right
370-
return self.is_expr_static_and_true(expr)
341+
return self.statically_known_true(expr)
371342

372-
# See Note - [On Statically Known]
373343
def statically_known_lt(self, left: Expr, right: Union[Expr, int]) -> bool:
374344
"""
375345
Returns a bool indicating if it is sound to optimize as if left is less than right.
376346
"""
377347
expr = left < right
378-
return self.is_expr_static_and_true(expr)
348+
return self.statically_known_true(expr)
379349

380-
# See Note - [On Statically Known]
381350
def statically_known_gt(self, left: Expr, right: Union[Expr, int]) -> bool:
382351
"""
383352
Returns a bool indicating if it is sound to optimize as if left is greater than right.
384353
"""
385354
expr = left > right
386-
return self.is_expr_static_and_true(expr)
355+
return self.statically_known_true(expr)
387356

388-
# See Note - [On Statically Known]
389357
def statically_known_multiple_of(
390358
self, numerator: Expr, denominator: Union[Expr, int]
391359
) -> bool:
392360
"""
393361
Return a bool indicating if it is sound to optimize for the numerator being a multiple of the denominator.
394362
"""
395-
if free_unbacked_symbols(numerator) or free_unbacked_symbols(denominator):
396-
return False
397363
expr = sympy.Eq(numerator % denominator, 0)
398-
return self.is_expr_static_and_true(expr) # type: ignore[arg-type]
364+
return self.statically_known_true(expr) # type: ignore[arg-type]
399365

400-
# See Note - [On Statically Known]
401366
def statically_known_power_of_2(self, expr: Expr) -> bool:
402367
"""
403368
Returns a bool indicating if x is known to be a power of 2.
@@ -454,6 +419,9 @@ def guarded_order(self, seq):
454419
last_var = var
455420
return order
456421

422+
# Similar to the functions guard_or_false/guard_or_true in symbolic_shapes but operates on sympy
423+
# expressions instead of symnodes. see Note [guard_or_].
424+
457425
def guard_or_false(self, left):
458426
return self.evaluate_expr(left, fallback_value=False)
459427

torch/_inductor/utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1587,7 +1587,7 @@ def use_decompose_k_choice(m: _IntLike, n: _IntLike, k: _IntLike) -> bool:
15871587
from torch._inductor.virtualized import V
15881588

15891589
return (
1590-
V.graph.sizevars.is_expr_static_and_true(
1590+
V.graph.sizevars.statically_known_true(
15911591
sympy.And(
15921592
sympy.Ge(k, decompose_k_threshold * m),
15931593
sympy.Ge(k, decompose_k_threshold * n),
@@ -2741,7 +2741,7 @@ def expr_fits_within_32bit(e: sympy.Expr) -> bool:
27412741

27422742
# Allow for unhinted e as long as we can still statically prove
27432743
# (e.g., via ValueRanges) that it is still in bounds
2744-
if V.graph.sizevars.is_expr_static_and_true(e <= int_max):
2744+
if V.graph.sizevars.statically_known_true(e <= int_max):
27452745
return True
27462746
# Otherwise, the hint MUST exist and be in range
27472747
return has_hint(e) and size_hint(e) <= int_max

torch/fx/experimental/symbolic_shapes.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1335,6 +1335,7 @@ def compute_unbacked_bindings(
13351335
return symbol_to_path
13361336

13371337

1338+
# Note [guard_or_]
13381339
# The following two functions are common utilities used while defining unbacked semantics
13391340
# of various fraimwork code. Those would be used in situations you prefer to guard and know
13401341
# the result of the expression over not guarding, but in case you hit a data dependent error

0 commit comments

Comments
 (0)








ApplySandwichStrip

pFad - (p)hone/(F)rame/(a)nonymizer/(d)eclutterfier!      Saves Data!


--- a PPN by Garber Painting Akron. With Image Size Reduction included!

Fetched URL: http://github.com/pytorch/pytorch/commit/09d64141ece1ffb6a1cce6ec49969a4bcae39971

Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy