Description
🚀 The feature, motivation and pitch
I'm working on a feature that compares model outputs across multiple devices for accuracy purposes. When executing in eager mode, a dispatch mode takes care of this; when compiling graph regions with dynamo, a custom backend handles it.
However, if I run a model as a mix of both (i.e., compiling a model that contains some dynamo-disabled modules), I would like to use both mechanisms at the same time to validate outputs.
For example, for the given model:
```python
@torch._dynamo.disable()
class TestSubModule(torch.nn.Module):
    def forward(self, x):
        return 2 * x

class TestModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.submod = TestSubModule()

    def forward(self, x):
        tmp = self.submod(x)
        return tmp + 4
```
I would like my custom dispatch mode to run ONLY on the disabled region of TestSubModule.
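For context, the eager-side comparison can be expressed as a `TorchDispatchMode` that intercepts every ATen op. This is a minimal sketch: the recording logic stands in for the actual cross-device comparison, which is not shown here.

```python
import torch
from torch.utils._python_dispatch import TorchDispatchMode

class RecordingMode(TorchDispatchMode):
    """Record every ATen op that passes through the dispatcher.

    Stand-in for the accuracy-comparison mode described above; a real
    mode would compare outputs across devices instead of logging names.
    """

    def __init__(self):
        super().__init__()
        self.ops = []

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        self.ops.append(str(func))
        return func(*args, **(kwargs or {}))

mode = RecordingMode()
with mode:
    y = 2 * torch.ones(3)  # aten.mul (and aten.ones) are seen by the mode

print(mode.ops)
```

When the model is partially compiled, however, the mode only fires for ops that actually reach the Python dispatcher in eager mode, which is exactly why routing it to the dynamo-disabled regions is the hard part.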
Alternatives
- A (hacky) alternative I have working: search through the model's modules and wrap them with the dispatch mode if the flags that dynamo uses for disabling are present.
- Rewrite the dispatch mode as a tensor subclass, and use `torch.compiler.is_compiling` to check whether we are in a disabled/untraced region.
Additional context
No response
cc @chauhang @penguinwu @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @amjames @zou3519 @bdhirsh