Add a test for AsyncCollectiveTensor handling for maybe-view ops #152688
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/152688
Note: Links to docs will display an error until the docs builds have been completed.
❗ 1 Active SEV: there is 1 currently active SEV. If your PR is affected, please view it below.
✅ No Failures as of commit 9cc4a79 with merge base a4a7716.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
out2 = out1.to(dtype=torch.bfloat16)
# this tests that the call to .to() properly triggered a wait() on the AsyncCollectiveTensor
self.assertTrue(type(out2) is torch.Tensor)
self.assertEqual(
@kwen2501 I wasn't able to repro an actual silent-correctness issue in this test when I backed out the original fix (I confirmed that we were calling out1.to(...) and doing a proper dtype conversion before issuing a proper wait_tensor() on the input, but the final output was still correct).
I did confirm that the earlier assert failed before, though, which I think is a reasonable proxy: the output of calling .to() should give us a plain tensor, not an ACT, to represent the fact that we properly issued a sync.
Hi @bdhirsh, I think assertEqual would eventually trigger a wait on out2 -- that's why you wouldn't see data corruption, with or without the fix.
In fact, the return of an ACT instead of a Tensor is the only problem (copied from #152534):
The issue seems to be that for an AsyncCollectiveTensor t, invoking t.float() does not trigger the wait_tensor (in which case it would return a regular torch.Tensor); instead it returns a new AsyncCollectiveTensor with garbage data.
If the fix makes a difference on self.assertTrue(type(out2) is torch.Tensor), then we are good.
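For context, a minimal sketch of the behavior being asserted (illustrative only, not the actual test; it assumes an already-initialized process group and uses torch.distributed._functional_collectives directly):

import torch
import torch.distributed as dist
import torch.distributed._functional_collectives as funcol

x = torch.ones(4, device="cuda")
out1 = funcol.all_gather_tensor(x, gather_dim=0, group=dist.group.WORLD)  # returns an AsyncCollectiveTensor
out2 = out1.to(dtype=torch.bfloat16)
# with the fix, the dtype conversion first waits on the pending collective,
# so out2 is a plain torch.Tensor rather than another AsyncCollectiveTensor
assert type(out2) is torch.Tensor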
I think assertEqual would eventually trigger a wait on out2 -- that's why you wouldn't see data corruption, with or without the fix.
So we agree on this. What I'm surprised by, though, is that even if assertEqual issues a wait, that is still not early enough. If we issue the .to() dtype-conversion kernel before running wait_tensor, I would have imagined that the .to() kernel would read garbage data (the allgather output buffer) before synchronizing.
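Concretely, the ordering I'd expect to matter, as a sketch (out1 stands in for the AsyncCollectiveTensor from the test; the explicit .wait() on the ACT is assumed here as the manual sync):

waited = out1.wait()                     # block on the pending all_gather, get the inner tensor
out2 = waited.to(dtype=torch.bfloat16)   # the conversion kernel now reads completed data
# if the conversion kernel were launched before that wait, it could read the
# allgather output buffer while the collective is still in flight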
Either way, I agree that my torch.Tensor assertion is enough to convince myself that this test is ok.
Arg, you are right. After we made .to() return a torch.Tensor, we'd still need to decide whether that torch.Tensor should be a waited tensor or not.
If it is an unwaited tensor, it seems we would get into trouble as well, because the user has now lost the ACT handle and cannot call wait_tensor anymore.
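A sketch of that failure mode (hypothetical, just to spell out the concern; names mirror the test snippet above):

out1 = funcol.all_gather_tensor(x, gather_dim=0, group=dist.group.WORLD)  # AsyncCollectiveTensor
out2 = out1.to(dtype=torch.bfloat16)
del out1
# if out2 were an *unwaited* plain torch.Tensor, the ACT handle (and with it the ability
# to call out1.wait() or funcol.wait_tensor) would be gone at this point, leaving no way
# to force the sync before reading out2 -- so the conversion itself must issue the wait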
But wait, wouldn't the transition from ACT to Tensor already trigger a wait? Are you saying that this is not observed in your test?
Thanks for adding the test!
Resolves #152534.
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
We never added a proper test for the fix from #134661.
Stack from ghstack (oldest at bottom):
cc @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k