8
8
import sympy
9
9
from sympy import Expr
10
10
11
- from torch .fx .experimental .symbolic_shapes import (
12
- free_unbacked_symbols ,
13
- has_free_unbacked_symbols ,
14
- ShapeEnv ,
15
- )
11
+ from torch .fx .experimental .symbolic_shapes import has_free_unbacked_symbols , ShapeEnv
16
12
from torch .utils ._ordered_set import OrderedSet
17
13
from torch .utils ._sympy .functions import FloorDiv , ModularIndexing
18
14
from torch .utils ._sympy .symbol import symbol_is_type , SymT
32
28
log = logging .getLogger (__name__ )
33
29
34
30
35
- def evaluate_expr (
31
+ def statically_known_true (
36
32
shape_env : ShapeEnv ,
37
33
expr : Union [sympy .Basic , bool ],
38
34
axioms : Optional [tuple [sympy .Expr ]] = None ,
@@ -307,44 +303,21 @@ def prune(index):
307
303
308
304
return [x for x in sizes if x is not None ], reindex , prune
309
305
310
- # Note - [On Statically Known]
311
- #
312
- # The statically_known_* family of functions below replaces a prior system, called maybe_guard_*. The prior system
313
- # operated by providing essentially a question, where the size hinted values were evaluated. If the condition was
314
- # true, we add a guard and return True, otherwise, False.
315
- #
316
- # def maybe_guard_foo(args):
317
- # if size_hinted_check(args):
318
- # return False # No guard, no optim
319
- # guard(args) # Make a guard
320
- # return True # Safe to apply optimization
321
- #
322
- # The prior system incurred a guard, and green lit an optimization.
323
- #
324
- # The new system works in reverse - in the new system, if we know that the inputs are static, and evaluate the
325
- # condition as true, we green light the optimization, and we do not incur a guard. If we cannot prove that, we
326
- # return False.
327
- #
328
- # def maybe_guard_foo(args):
329
- # if all_static(args):
330
- # return True # Safe to apply optimization
331
- # else:
332
- # return False # No guard, no optim
333
-
334
- # See Note - [On Statically Known]
335
-
336
- def is_expr_static_and_true (self , expr : Union [sympy .Basic , bool ]) -> bool :
337
- return evaluate_expr (self .shape_env , expr )
306
+ # The statically_known_* family of functions below NEVER guard, they could return True if the
307
+ # asked questions can be answered without guarding otherwise they return False.
308
+ # Those are similar to statically_known_true in symbolic_shapes but operate on sympy expressions
309
+ # instead of symnodes.
310
+ def statically_known_true (self , expr : Union [sympy .Basic , bool ]) -> bool :
311
+ return statically_known_true (self .shape_env , expr )
338
312
339
313
def statically_known_equals (
340
314
self , left : Union [Expr , int ], right : Union [Expr , int ]
341
315
) -> bool :
342
316
"""
343
317
Returns a bool indicating if it is sound to optimize as if left and right are equal.
344
318
"""
345
- return self .is_expr_static_and_true (sympy .Eq (left , right )) # type: ignore[arg-type]
319
+ return self .statically_known_true (sympy .Eq (left , right )) # type: ignore[arg-type]
346
320
347
- # See Note - [On Statically Known]
348
321
def statically_known_list_equals (self , left : list [Expr ], right : list [Expr ]) -> bool :
349
322
"""
350
323
Returns a bool indicating if it is sound to optimize as if left and right lists are equal.
@@ -353,51 +326,43 @@ def statically_known_list_equals(self, left: list[Expr], right: list[Expr]) -> b
353
326
self .statically_known_equals (l , r ) for l , r in zip (left , right )
354
327
)
355
328
356
- # See Note - [On Statically Known]
357
329
def statically_known_leq (self , left : Expr , right : Union [Expr , int ]) -> bool :
358
330
"""
359
331
Returns a bool indicating if it is sound to optimize as if left is less than or equal to right.
360
332
"""
361
333
expr = left <= right
362
- return self .is_expr_static_and_true (expr )
334
+ return self .statically_known_true (expr )
363
335
364
- # See Note - [On Statically Known]
365
336
def statically_known_geq (self , left : Expr , right : Union [Expr , int ]) -> bool :
366
337
"""
367
338
Returns a bool indicating if it is sound to optimize as if left is greater than or equal to right.
368
339
"""
369
340
expr = left >= right
370
- return self .is_expr_static_and_true (expr )
341
+ return self .statically_known_true (expr )
371
342
372
- # See Note - [On Statically Known]
373
343
def statically_known_lt (self , left : Expr , right : Union [Expr , int ]) -> bool :
374
344
"""
375
345
Returns a bool indicating if it is sound to optimize as if left is less than right.
376
346
"""
377
347
expr = left < right
378
- return self .is_expr_static_and_true (expr )
348
+ return self .statically_known_true (expr )
379
349
380
- # See Note - [On Statically Known]
381
350
def statically_known_gt (self , left : Expr , right : Union [Expr , int ]) -> bool :
382
351
"""
383
352
Returns a bool indicating if it is sound to optimize as if left is greater than right.
384
353
"""
385
354
expr = left > right
386
- return self .is_expr_static_and_true (expr )
355
+ return self .statically_known_true (expr )
387
356
388
- # See Note - [On Statically Known]
389
357
def statically_known_multiple_of (
390
358
self , numerator : Expr , denominator : Union [Expr , int ]
391
359
) -> bool :
392
360
"""
393
361
Return a bool indicating if it is sound to optimize for the numerator being a multiple of the denominator.
394
362
"""
395
- if free_unbacked_symbols (numerator ) or free_unbacked_symbols (denominator ):
396
- return False
397
363
expr = sympy .Eq (numerator % denominator , 0 )
398
- return self .is_expr_static_and_true (expr ) # type: ignore[arg-type]
364
+ return self .statically_known_true (expr ) # type: ignore[arg-type]
399
365
400
- # See Note - [On Statically Known]
401
366
def statically_known_power_of_2 (self , expr : Expr ) -> bool :
402
367
"""
403
368
Returns a bool indicating if x is known to be a power of 2.
@@ -454,6 +419,9 @@ def guarded_order(self, seq):
454
419
last_var = var
455
420
return order
456
421
422
+ # Similar to the functions guard_or_false/guard_or_true in symbolic_shapes but operates on sympy
423
+ # expressions instead of symnodes. see Note [guard_or_].
424
+
457
425
def guard_or_false (self , left ):
458
426
return self .evaluate_expr (left , fallback_value = False )
459
427
0 commit comments