
Add a test for AsyncCollectiveTensor handling for maybe-view ops #152688


Closed
wants to merge 1 commit

Conversation

bdhirsh (Contributor) commented on May 2, 2025

pytorch-bot (bot) commented on May 2, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/152688

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEV

There is 1 currently active SEV. If your PR is affected, please view it below:

✅ No Failures

As of commit 9cc4a79 with merge base a4a7716:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

```python
out2 = out1.to(dtype=torch.bfloat16)
# this tests that the call to .to() properly triggered a wait() on the AsyncCollectiveTensor
self.assertTrue(type(out2) is torch.Tensor)
self.assertEqual(
```
bdhirsh (Contributor, Author):

@kwen2501 I wasn't able to repro the actual silent correctness issue in this test when I backed out the original fix (I confirmed that we were calling out1.to(...) and doing a proper dtype conversion before issuing a proper wait_tensor() on the input, but the final output was still correct).

I did confirm that the earlier assert failed before the fix, though, which I think is a reasonable proxy - the output of calling .to() should give us a plain tensor, not an ACT, to represent the fact that we properly issued a sync.
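
A minimal, hypothetical sketch of the kind of check being described here (not the PR's actual test): it assumes a process group is already initialized, that the eager functional-collectives API (torch.distributed._functional_collectives) returns an AsyncCollectiveTensor before any wait is issued, and the helper name is made up.

```python
import torch
import torch.distributed as dist
import torch.distributed._functional_collectives as funcol

def check_to_unwraps_act() -> None:
    # Assumes dist.init_process_group(...) has already run on every rank.
    x = torch.ones(4)
    out1 = funcol.all_reduce(x, "sum", dist.group.WORLD)
    # out1 is expected to be an AsyncCollectiveTensor: the collective has been
    # issued but not yet waited on.
    assert isinstance(out1, funcol.AsyncCollectiveTensor)

    out2 = out1.to(dtype=torch.bfloat16)
    # The proxy check: a maybe-view op like .to() should unwrap the ACT,
    # issue the wait, and hand back a plain torch.Tensor, not another ACT.
    assert type(out2) is torch.Tensor
    assert out2.dtype == torch.bfloat16
```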

kwen2501 (Contributor):

Hi @bdhirsh, I think assertEqual would eventually trigger a wait on out2 -- that's why you wouldn't see data corruption, with or without the fix.

In fact, the return of an ACT instead of a Tensor is the only problem (copied from #152534):

> The issue seems to be that for an AsyncCollectiveTensor t, invoking t.float() does not trigger the wait_tensor (which would have returned a regular torch.Tensor); instead it returns a new AsyncCollectiveTensor with garbage data.

If the fix makes a difference on self.assertTrue(type(out2) is torch.Tensor), then we are good.
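
A sketch of why the value comparison alone would not expose the bug, assuming, as stated above, that a compute op dispatched on an unwaited ACT forces the pending wait before any data is read; the helper name is hypothetical.

```python
import torch
import torch.distributed._functional_collectives as funcol

def value_check_hides_bug(out2: torch.Tensor, expected: torch.Tensor) -> bool:
    # Even if a buggy .to() left out2 as an unwaited AsyncCollectiveTensor,
    # the comparison below dispatches a compute op on it, which (per the
    # comment above) triggers the wait before values are read. The data then
    # compares equal with or without the fix; only type(out2) reveals the bug.
    if isinstance(out2, funcol.AsyncCollectiveTensor):
        print("bug: .to() returned an AsyncCollectiveTensor, not a plain Tensor")
    return torch.equal(out2, expected)
```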

bdhirsh (Contributor, Author):

> I think assertEqual would eventually trigger a wait on out2 -- that's why you wouldn't see data corruption, with or without the fix.

So we are agreed on this. What I'm surprised by, though, is that even if assertEqual issues a wait, that is still not early enough. If we issue the .to() dtype-conversion kernel before running wait_tensor, I would have imagined that the .to() kernel would read garbage data (the allgather output buffer) before synchronizing.
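
For contrast, the ordering being described can be spelled out explicitly. This is a sketch under the assumption that funcol.wait_tensor can be called manually on a collective result; the helper name is mine.

```python
import torch
import torch.distributed._functional_collectives as funcol

def convert_after_explicit_wait(out1: torch.Tensor) -> torch.Tensor:
    # Wait on the collective first, so the dtype-conversion kernel launched by
    # .to() only ever reads the completed all_gather output buffer.
    waited = funcol.wait_tensor(out1)
    return waited.to(dtype=torch.bfloat16)
```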

bdhirsh (Contributor, Author):

Either way, I agree that my torch.Tensor assertion is enough to convince myself that this test is OK.

kwen2501 (Contributor) commented on May 5, 2025:

Arg, you are right. After we made .to() return a torch.Tensor, we'd still need to decide whether that torch.Tensor should be a waited tensor or not.
If it is an unwaited tensor, it seems we would get into trouble as well, because the user has now lost the ACT handle and cannot call wait_tensor anymore.
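
To make that hazard concrete, a hypothetical sketch (not what the fix does): if .to() unwrapped the ACT without waiting, a caller holding only the result would have nothing left to synchronize on.

```python
import torch
import torch.distributed._functional_collectives as funcol

def hazard_if_unwrapped_without_wait(out1: funcol.AsyncCollectiveTensor) -> torch.Tensor:
    out2 = out1.to(dtype=torch.bfloat16)
    # If out2 were a plain but *unwaited* torch.Tensor here, a caller holding
    # only out2 would have lost the ACT handle and could no longer call
    # wait_tensor() before reading it. Hence the fix must both unwrap and wait.
    return out2
```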

kwen2501 (Contributor):

But wait, wouldn't the transition from ACT to Tensor already trigger a wait? Are you saying that this is not observed in your test?

kwen2501 (Contributor) left a review comment:

Thanks for adding the test!


kwen2501 (Contributor) commented on May 2, 2025:

Resolves #152534.

albanD removed their request for review on May 2, 2025 at 17:23
bdhirsh (Contributor, Author) commented on May 5, 2025:

@pytorchbot merge

pytorch-bot (bot) added the ciflow/trunk label (Trigger trunk jobs on your pull request) on May 5, 2025
pytorchmergebot (Collaborator) commented:

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

Labels: ciflow/trunk (Trigger trunk jobs on your pull request), Merged, oncall: distributed (Add this issue/PR to distributed oncall triage queue), topic: not user facing (topic category)