Also support non-contiguous activation for torch._weight_int8pack_mm on CPU #147588
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/147588
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit 91b6790 with merge base 8a5265c.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
The activation could be a slice of another tensor, but it should still be in row-major order.
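For illustration (a sketch, not code from this PR), here is the kind of activation being described: a slice that is non-contiguous overall yet still row-major, with unit stride on the last dimension:

```python
import torch

# A bf16 "activation" and a column slice of it: the slice shares storage
# with x, so its rows are spaced 16 elements apart (non-contiguous),
# but within each row the elements are adjacent (stride(1) == 1).
x = torch.randn(4, 16, dtype=torch.bfloat16)
a = x[:, :8]

print(a.is_contiguous())  # False: rows are not densely packed
print(a.stride())         # (16, 1): row-major, unit stride on last dim
```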
We don't need to benchmark this; it won't affect performance. Just change the error message a little bit.
```cpp
TORCH_CHECK(A.dim() == 2,
    __func__, " : expect A to be 2D tensor.");

TORCH_CHECK(A.stride(1) == 1,
    __func__, " : A must be row-major even if it's non-contiguous");
```
Change the error message to "A must be contiguous on last dimension."
`A must be contiguous on last dimension.` seems more sensible.
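As a sketch (mine, not from the PR) of what the final check `A.stride(1) == 1` accepts and rejects:

```python
import torch

x = torch.randn(8, 16, dtype=torch.bfloat16)

# Accepted: fully contiguous, stride (16, 1), so stride(1) == 1.
a_contig = x

# Accepted after this PR: a column slice, stride (16, 1), so stride(1) == 1,
# even though the tensor as a whole is not contiguous.
a_slice = x[:, :8]

# Rejected: a transposed view, stride (1, 16), so stride(1) != 1,
# i.e. not contiguous on the last dimension.
a_t = x.t()

for a in (a_contig, a_slice, a_t):
    print(a.stride(), a.stride(1) == 1)
```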
@pytorchbot merge
Merge failed. Reason: Approvers from one of the following sets are needed:
Hi @malfet, can you please help review this PR? It requires approval from a core reviewer/maintainer for landing. Thanks!
Hmm, this change looks BC-breaking to me (at quick glance, afk atm). Any reason for making it incompatible with contiguous A?
Hi @malfet, thanks for promptly following up! This change makes non-contiguous activations (so long as they're contiguous in the last dimension) compatible with `torch._weight_int8pack_mm`; contiguous activations remain supported as before, so it isn't BC-breaking. Thanks!
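A minimal sketch of the compatibility claim (not from the PR), assuming the CPU signature `torch._weight_int8pack_mm(A, B, scales)` with `A` an (M, K) bf16 activation, `B` a plain row-major int8 (N, K) weight, and `scales` of shape (N,) in A's dtype:

```python
import torch

M, K, N = 4, 64, 32

# Illustrative symmetric per-channel int8 quantization of a weight.
w = torch.randn(N, K)
scales = w.abs().amax(dim=1) / 127.0
b_int8 = torch.round(w / scales[:, None]).to(torch.int8)
scales_bf16 = scales.to(torch.bfloat16)

# A wide buffer whose left half serves as the activation.
a_full = torch.randn(M, 2 * K, dtype=torch.bfloat16)

# Contiguous activation: worked before this PR and still works (no BC break).
y_contig = torch._weight_int8pack_mm(a_full[:, :K].contiguous(), b_int8, scales_bf16)

# Non-contiguous slice with stride (2*K, 1): newly accepted by this PR.
y_slice = torch._weight_int8pack_mm(a_full[:, :K], b_int8, scales_bf16)

print(torch.allclose(y_contig, y_slice))
```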
LGTM, thank you for the explanation
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Also support non-contiguous activation for torch._weight_int8pack_mm on CPU (pytorch#147588)

### Problem
Non-contiguous activation for `torch._weight_int8pack_mm` is unsupported on CPU. So, with int8 WoQ with BF16 activation with torchao, for batch size 2 and above, an assertion is hit regarding non-contiguous A being unsupported. Such an issue was encountered with LLaMA models.

### Solution
Also support non-contiguous activation for `torch._weight_int8pack_mm`, so long as it's contiguous on the last dimension, and remove the assertion that requires contiguous activation.

### Alternative solutions considered
We could modify the LLaMA model in the transformers library to call `contiguous` after obtaining the final hidden state, just before computing logits with the LM head. However, it (huggingface/transformers#36078) might cause some regression for other users of that code.

Another aspect to this issue: is latency always lower if we make an activation tensor contiguous before `linear` or `torch._weight_int8pack_mm` is called on CPU? We would need some data points to analyze this, although performance should be good enough with this patch, since the first cache lines of rows of A are explicitly prefetched in the existing code (and it also avoids the copy that a `contiguous` call would do).

Pull Request resolved: pytorch#147588
Approved by: https://github.com/mingfeima, https://github.com/leslie-fang-intel, https://github.com/malfet

cc @jgong5 @mingfeima @XiaobingSuper @ashokei @jingxu10
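To make the motivating case concrete, here is a hypothetical sketch of the LLaMA-style pattern (names illustrative, not taken from transformers): with batch size ≥ 2, slicing the final hidden state before the LM head yields a non-contiguous but last-dim-contiguous activation, which this patch lets the op consume without a `contiguous()` copy:

```python
import torch

batch, seq, hidden = 2, 5, 64

# Decoder output: (batch, seq, hidden), contiguous.
hidden_states = torch.randn(batch, seq, hidden, dtype=torch.bfloat16)

# Keep only the last position before computing logits, as decoding loops do.
# Strides become (seq * hidden, 1): row-major, but non-contiguous for batch >= 2.
last = hidden_states[:, -1, :]
print(last.is_contiguous(), last.stride())  # False, (320, 1)

# With this PR, `last` can feed torch._weight_int8pack_mm directly;
# previously it first required an extra copy:
#   last = last.contiguous()
```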