[dtensor][view_op] add as_strided op support to DTensor in FakeTensorMode #151495
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/151495
Note: Links to docs will display an error until the docs builds have been completed.
✅ You can merge normally! (1 Unrelated Failure)
As of commit 6a84ea5 with merge base b7c7000:
BROKEN TRUNK: The following job failed but was present on the merge base. 👉 Rebase onto the `viable/strict` branch to avoid these failures.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Can you explain further why we need this support only in fake mode? Is it because we trace through this op and the resulting graph needs to capture a local as_strided operation, but the operation would run on a plain tensor after compilation?
```python
@register_op_strategy(aten.as_strided.default, schema_info=RuntimeSchemaInfo(1))
def as_strided_strategy(op_schema: OpSchema) -> StrategyType:
```
Where do we enforce that this is only used in fake mode?
This is the part I don't feel confident about -- I did not enforce that. I only tested that it works for my use case in FakeTensorMode, and I'm relying on users to read the comment and not use it in other cases.
> Can you explain further why we need this support only in fake mode?

When I tried calling `flex_attention` on DTensor in #145353, I hit an error about DTensor `as_strided` not being implemented. This is because `flex_attention` in dynamo calls `flex_attention_fake_impl` to trace the shape and stride of its output, and `flex_attention_fake_impl` uses `as_strided`. As @bdhirsh said in a #145353 comment, one solution is to have dynamo call `flex_attention` rather than `flex_attention_fake_impl`, but that requires a code change in dynamo. This PR is a quick workaround.
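For context, here is a minimal sketch of the restriding that `flex_attention_fake_impl` needs (illustrative only, not the actual PyTorch implementation; the function name is made up): the fake output keeps its own shape but takes a stride layout that fills dims in the same order as `query`'s strides, which amounts to an `as_strided` call with `size == input.shape` and no `storage_offset`.

```python
import torch

def restride_like_query(out: torch.Tensor, query: torch.Tensor) -> torch.Tensor:
    # Illustrative sketch: fill out's dims in the same order as query's
    # strides (smallest query stride -> fastest-moving dim of out).
    fill_order = sorted(range(query.dim()), key=lambda d: query.stride(d))
    new_strides = [0] * out.dim()
    step = 1
    for d in fill_order:
        new_strides[d] = step
        step *= out.shape[d]
    # size == out.shape and no storage_offset: exactly the subset of
    # as_strided this PR supports on DTensor in FakeTensorMode.
    return out.as_strided(out.shape, new_strides)
```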
Starting merge as part of PR stack under #151507
See comments inline.
```python
for target_stride in itertools.permutations(stride):
    dtensor_y = dtensor_x.as_strided(size, target_stride)
    tensor_y = tensor_x.as_strided(size, target_stride)
```
I think you should also test the shape/stride assertions you added in the op?
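For instance, something along these lines (a sketch; the error-message regex comes from the assertions in this diff, and it assumes the failure surfaces as an `AssertionError`):

```python
# Hypothetical negative test for the size check added in the op.
bad_size = tuple(s + 1 for s in size)  # same rank, different size
with self.assertRaisesRegex(
    AssertionError, "as_strided only supports the same size"
):
    dtensor_x.as_strided(bad_size, stride)
```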
```python
target_size = op_schema.args_schema[1]
assert isinstance(target_size, (tuple, list))
target_stride = op_schema.args_schema[2]
assert isinstance(target_stride, (tuple, list))
```
You don't need the tuple/list assertions if you're converting them to tuples anyway?
f"as_strided only supports the same size: input has size {inp_size} but target size is {target_size}" | ||
) | ||
|
||
assert len(target_size) == len(target_stride), ( |
Don't you also need to guard that the strides are the same?
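For example, a hypothetical guard (not in this diff) could require the target stride to be a permutation of the input's stride, which keeps the restriding use case working while rejecting arbitrary strides; `input_strategy` and the `tensor_meta` accessor below are assumptions about names inside the strategy function:

```python
# Hypothetical: restrict target_stride to permutations of the input stride.
inp_stride = input_strategy.strategies[0].output_spec.tensor_meta.stride
assert sorted(target_stride) == sorted(inp_stride), (
    f"as_strided only supports permuted strides: input has stride {inp_stride} "
    f"but target stride is {target_stride}"
)
```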
f"size and stride should have the same length, but got {len(target_size)} and {len(target_stride)}" | ||
) | ||
|
||
from torch.distributed.tensor._ops._tensor_ops import default_strategy |
I would prefer not to re-use the default strategy; it's simple enough to write its own.
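For what it's worth, a self-contained version might look like the sketch below. This is not the PR's actual code: the import paths and the `PlacementStrategy` fields reflect DTensor internals at the time and may differ across PyTorch versions.

```python
import torch
from torch.distributed.tensor._op_schema import (
    OpSchema,
    OpStrategy,
    PlacementStrategy,
    RuntimeSchemaInfo,
    StrategyType,
)
from torch.distributed.tensor._ops.utils import register_op_strategy

aten = torch.ops.aten


@register_op_strategy(aten.as_strided.default, schema_info=RuntimeSchemaInfo(1))
def as_strided_strategy(op_schema: OpSchema) -> StrategyType:
    # The limited as_strided never changes size or sharding, so every input
    # placement is also a valid output placement and no redistribution is
    # needed. (The size/stride validation from the diff above would go here.)
    input_strategy = op_schema.args_schema[0]
    assert isinstance(input_strategy, OpStrategy)
    return OpStrategy(
        [
            PlacementStrategy(
                output_specs=strategy.output_spec,
                input_specs=(strategy.output_spec,),
            )
            for strategy in input_strategy.strategies
        ]
    )
```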
Looks like this PR hasn't been updated in a while so we're going to go ahead and mark this as Stale.
Stack from [ghstack](https://github.com/ezyang/ghstack) (oldest at bottom):
* #151497
* #151507
* __->__ #151495

## Introduction
`flex_attention`'s FakeTensor propagation `flex_attention_fake_impl` [permutes](https://github.com/pytorch/pytorch/blob/fb6ac2f16132f7953711ce6924bc2ee4a033228c/torch/_higher_order_ops/flex_attention.py#L459) the stride of `out` (the attention score) based on `query`'s stride. Enabling the `flex_attention` call on DTensor therefore requires adding `as_strided` support on DTensor in `FakeTensorMode`.

## Limited Support
Due to the complexity of supporting the actual `as_strided` on DTensor, I chose to enable only a limited subset:
1. `as_strided` only works correctly in `FakeTensorMode`, i.e. shape and stride propagation.
2. `as_strided` is only allowed when `size == input.shape`, because this PR specifically unblocks the use case of `flex_attention_fake_impl`.
3. `as_strided` requires `storage_offset=None` because the other case is not defined for DTensor.

A usage sketch illustrating these constraints follows after the Test section.

## Test
`pytest test/distributed/tensor/test_view_ops.py -s -k test_as_strided`

cc @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @tianyu-l
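To make the limited subset concrete, here is a sketch with plain tensors of which calls the DTensor support would accept or reject under the three constraints above (the tensors here are ordinary ones purely for illustration; on DTensor the same calls are only meaningful inside `FakeTensorMode`):

```python
import torch

x = torch.randn(4, 8)  # contiguous: stride (8, 1)

# Accepted: same size, permuted strides, storage_offset left unset.
y = x.as_strided((4, 8), (1, 4))
print(y.shape, y.stride())  # torch.Size([4, 8]) (1, 4)

# Rejected on DTensor: target size differs from input size.
# x.as_strided((2, 8), (8, 1))

# Rejected on DTensor: explicit storage_offset.
# x.as_strided((4, 8), (8, 1), 2)
```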