
[dtensor][view_op] add as_strided op support to DTensor in FakeTensorMode #151495


Open

XilunWu wants to merge 5 commits into gh/XilunWu/131/base

Conversation

@XilunWu XilunWu (Contributor) commented Apr 17, 2025

Stack from [ghstack](https://github.com/ezyang/ghstack) (oldest at bottom):

* #151497
* #151507
* __->__ #151495

## Introduction

`flex_attention`'s FakeTensor propagation `flex_attention_fake_impl` [permutes](https://github.com/pytorch/pytorch/blob/fb6ac2f16132f7953711ce6924bc2ee4a033228c/torch/_higher_order_ops/flex_attention.py#L459) the stride of `out` (the attention score) based on `query`'s stride. Enabling `flex_attention` calls on DTensor therefore requires us to add `as_strided` support on DTensor in `FakeTensorMode`.
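For intuition, here is a minimal sketch of that stride-permutation trick; the helper name and example shapes are assumptions, not the actual `flex_attention_fake_impl`:

```python
import torch

def permute_strides_like(out: torch.Tensor, query: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper: order dims from largest to smallest query stride,
    # then compute the strides a contiguous tensor would have in that order.
    order = sorted(range(query.dim()), key=lambda d: query.stride(d), reverse=True)
    strides = [0] * out.dim()
    acc = 1
    for d in reversed(order):
        strides[d] = acc
        acc *= out.size(d)
    # The size is unchanged, which is exactly the restricted case supported below.
    return out.as_strided(out.size(), strides)

query = torch.empty(2, 8, 16, 32).permute(0, 2, 1, 3)  # non-contiguous layout
out = torch.empty(2, 16, 8, 32)                         # same shape as query
print(permute_strides_like(out, query).stride())        # matches query's order
```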

## Limited Support

Due to the complexity of supporting the full `as_strided` semantics on DTensor, I chose to enable only a limited subset (a sketch of the resulting checks follows this list):

1. `as_strided` only works correctly in `FakeTensorMode`, i.e. for shape and stride propagation.
2. `as_strided` is only allowed in the case where `size == input.shape`, because this PR specifically unblocks the use case of `flex_attention_fake_impl`.
3. `as_strided` requires `storage_offset=None`, because the other case is not defined for DTensor.
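A minimal sketch of the guard logic this subset implies, reconstructed from the code fragments quoted in the review below; the exact import paths and argument handling are assumptions:

```python
import torch
from torch.distributed.tensor._op_schema import (  # import paths assumed
    OpSchema,
    OpStrategy,
    RuntimeSchemaInfo,
    StrategyType,
)
from torch.distributed.tensor._ops.utils import register_op_strategy

aten = torch.ops.aten

@register_op_strategy(aten.as_strided.default, schema_info=RuntimeSchemaInfo(1))
def as_strided_strategy(op_schema: OpSchema) -> StrategyType:
    # NOTE: only correct for shape/stride propagation under FakeTensorMode.
    inp_strategy = op_schema.args_schema[0]
    assert isinstance(inp_strategy, OpStrategy)
    target_size = tuple(op_schema.args_schema[1])
    target_stride = tuple(op_schema.args_schema[2])
    inp_size = tuple(inp_strategy.shape)

    # Restriction 2: the target size must equal the input size.
    assert target_size == inp_size, (
        f"as_strided only supports the same size: "
        f"input has size {inp_size} but target size is {target_size}"
    )
    assert len(target_size) == len(target_stride), (
        f"size and stride should have the same length, "
        f"but got {len(target_size)} and {len(target_stride)}"
    )
    # Restriction 3: a storage_offset is not defined for DTensor.
    if len(op_schema.args_schema) > 3:
        assert op_schema.args_schema[3] is None, "as_strided requires storage_offset=None"

    # The PR follows the input placements by reusing the default strategy.
    from torch.distributed.tensor._ops._tensor_ops import default_strategy
    return default_strategy(op_schema)
```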

## Test

`pytest test/distributed/tensor/test_view_ops.py -s -k test_as_strided`
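Schematically, the test body looks like the loop quoted later in the review; the harness details below (test base class, mesh helper, shapes) are assumptions:

```python
import itertools

import torch
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.distributed.tensor import Replicate, distribute_tensor

def test_as_strided(self):  # assumes a DTensorTestBase-style harness
    mesh = self.build_device_mesh()
    with FakeTensorMode():
        tensor_x = torch.randn(2, 4, 8, device=self.device_type)
        dtensor_x = distribute_tensor(tensor_x, mesh, [Replicate()])
        size, stride = tensor_x.shape, tensor_x.stride()
        # Only same-size stride permutations are supported, so compare
        # DTensor's fake-mode propagation against plain tensors.
        for target_stride in itertools.permutations(stride):
            dtensor_y = dtensor_x.as_strided(size, target_stride)
            tensor_y = tensor_x.as_strided(size, target_stride)
            self.assertEqual(dtensor_y.stride(), tensor_y.stride())
```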

cc @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @tianyu-l


pytorch-bot bot commented Apr 17, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/151495

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (1 Unrelated Failure)

As of commit 6a84ea5 with merge base b7c7000:

BROKEN TRUNK - The following job failed but was present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added the ciflow/inductor and oncall: distributed labels Apr 17, 2025
@XilunWu XilunWu added the module: dtensor label Apr 17, 2025
@XilunWu XilunWu changed the title [dtensor][view_op] add as_strided op support to DTensor [dtensor][view_op] add as_strided op support to DTensor in FakeTensorMode Apr 17, 2025
@XilunWu XilunWu added the topic: not user facing label Apr 17, 2025
@XilunWu XilunWu requested review from wz337, tianyu-l and wanchaol and removed request for tianyu-l April 17, 2025 02:34
… FakeTensorMode"

Stack from [ghstack](https://github.com/ezyang/ghstack) (oldest at bottom):

*  #151497
*  #151496
* __->__ #151495

## Introduction
`flex_attention`'s FakeTensor propagation `flex_attention_fake_impl` [permutes](https://github.com/pytorch/pytorch/blob/fb6ac2f16132f7953711ce6924bc2ee4a033228c/torch/_higher_order_ops/flex_attention.py#L459) the stride of `out` (the attention score) based on `query`'s stride. To enable `flex_attention` call on DTensor, this requires us add `as_strided` support on DTensor in `FakeTensorMode`. 

## Limited Support
Due to the complexity of supporting actual `as_strided` on DTensor, I choose to only enable a limited subset:
1. `as_strided` only works correctly in `FakeTensorMode` i.e. shape and strided propagation.
2. `as_strided` is only allowed in case where `size == input.shape` because this PR specifically unblocks the use case of `flex_attention_fake_impl`. 
3. `as_strided` requires `storage_offset=None` because the other case is not defined in DTensor.

## Test
`pytest test/distributed/tensor/test_view_ops.py -s -k test_as_strided`

cc H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k tianyu-l

[ghstack-poisoned]
@XilunWu XilunWu requested a review from bdhirsh April 17, 2025 02:34
XilunWu added a commit that referenced this pull request Apr 17, 2025
… FakeTensorMode"

Stack from [ghstack](https://github.com/ezyang/ghstack) (oldest at bottom):

*  #151497
*  #151496
* __->__ #151495

## Introduction
`flex_attention`'s FakeTensor propagation `flex_attention_fake_impl` [permutes](https://github.com/pytorch/pytorch/blob/fb6ac2f16132f7953711ce6924bc2ee4a033228c/torch/_higher_order_ops/flex_attention.py#L459) the stride of `out` (the attention score) based on `query`'s stride. To enable `flex_attention` call on DTensor, this requires us add `as_strided` support on DTensor in `FakeTensorMode`. 

## Limited Support
Due to the complexity of supporting actual `as_strided` on DTensor, I choose to only enable a limited subset:
1. `as_strided` only works correctly in `FakeTensorMode` i.e. shape and strided propagation.
2. `as_strided` is only allowed in case where `size == input.shape` because this PR specifically unblocks the use case of `flex_attention_fake_impl`. 
3. `as_strided` requires `storage_offset=None` because the other case is not defined in DTensor.

## Test
`pytest test/distributed/tensor/test_view_ops.py -s -k test_as_strided`

cc H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k tianyu-l

[ghstack-poisoned]
XilunWu added a commit that referenced this pull request Apr 17, 2025
wconstab (Contributor) left a comment:

Can you explain further why we need this support only in fake mode? Is it because we trace through this op and the resulting graph needs to capture a local as_strided operation, but the operation would run on a plain tensor after compilation?



```python
@register_op_strategy(aten.as_strided.default, schema_info=RuntimeSchemaInfo(1))
def as_strided_strategy(op_schema: OpSchema) -> StrategyType:
```
Where do we enforce that this is only used in fake mode?

@XilunWu XilunWu (Author) commented Apr 17, 2025

This is the part I don't feel confident about: I did not enforce that. I only tested that it works for my use case in FakeTensorMode, and I rely entirely on users reading the comment and not using it in other cases.

> Can you explain further why we need this support only in fake mode?

When I tried calling `flex_attention` on DTensor in #145353, I hit an error because `as_strided` was not implemented for DTensor. This is because `flex_attention` in dynamo calls `flex_attention_fake_impl` to trace the shape and stride of its output, and `flex_attention_fake_impl` uses `as_strided`. As @bdhirsh said in a #145353 comment, one solution is to have dynamo call `flex_attention` rather than `flex_attention_fake_impl`, but that requires a code change in dynamo. This PR is a quick workaround.

… FakeTensorMode"

Stack from [ghstack](https://github.com/ezyang/ghstack) (oldest at bottom):

*  #151497
*  #151507
* __->__ #151495

## Introduction
`flex_attention`'s FakeTensor propagation `flex_attention_fake_impl` [permutes](https://github.com/pytorch/pytorch/blob/fb6ac2f16132f7953711ce6924bc2ee4a033228c/torch/_higher_order_ops/flex_attention.py#L459) the stride of `out` (the attention score) based on `query`'s stride. To enable `flex_attention` call on DTensor, this requires us add `as_strided` support on DTensor in `FakeTensorMode`. 

## Limited Support
Due to the complexity of supporting actual `as_strided` on DTensor, I choose to only enable a limited subset:
1. `as_strided` only works correctly in `FakeTensorMode` i.e. shape and strided propagation.
2. `as_strided` is only allowed in case where `size == input.shape` because this PR specifically unblocks the use case of `flex_attention_fake_impl`. 
3. `as_strided` requires `storage_offset=None` because the other case is not defined in DTensor.

## Test
`pytest test/distributed/tensor/test_view_ops.py -s -k test_as_strided`

cc H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k tianyu-l

[ghstack-poisoned]
XilunWu added a commit that referenced this pull request Apr 18, 2025
pytorchmergebot (Collaborator) commented:

Starting merge as part of PR stack under #151507

… FakeTensorMode"

Stack from [ghstack](https://github.com/ezyang/ghstack) (oldest at bottom):

*  #151497
*  #151507
* __->__ #151495

## Introduction
`flex_attention`'s FakeTensor propagation `flex_attention_fake_impl` [permutes](https://github.com/pytorch/pytorch/blob/fb6ac2f16132f7953711ce6924bc2ee4a033228c/torch/_higher_order_ops/flex_attention.py#L459) the stride of `out` (the attention score) based on `query`'s stride. To enable `flex_attention` call on DTensor, this requires us add `as_strided` support on DTensor in `FakeTensorMode`. 

## Limited Support
Due to the complexity of supporting actual `as_strided` on DTensor, I choose to only enable a limited subset:
1. `as_strided` only works correctly in `FakeTensorMode` i.e. shape and strided propagation.
2. `as_strided` is only allowed in case where `size == input.shape` because this PR specifically unblocks the use case of `flex_attention_fake_impl`. 
3. `as_strided` requires `storage_offset=None` because the other case is not defined in DTensor.

## Test
`pytest test/distributed/tensor/test_view_ops.py -s -k test_as_strided`

cc H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k tianyu-l

[ghstack-poisoned]
XilunWu added a commit that referenced this pull request Apr 22, 2025
Divigroup-RAP pushed a commit to Divigroup-RAP/PYTORCH that referenced this pull request Apr 22, 2025
Divigroup-RAP pushed a commit to Divigroup-RAP/PYTORCH that referenced this pull request Apr 22, 2025
wanchaol (Collaborator) left a comment:

See comments inline.


```python
for target_stride in itertools.permutations(stride):
    dtensor_y = dtensor_x.as_strided(size, target_stride)
    tensor_y = tensor_x.as_strided(size, target_stride)
```
I think you should also test the shape/stride assertions you added in the op?

```python
target_size = op_schema.args_schema[1]
assert isinstance(target_size, (tuple, list))
target_stride = op_schema.args_schema[2]
assert isinstance(target_stride, (tuple, list))
```
You don't need the tuple/list assertions if you are converting them to tuples anyway?
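A sketch of the simplification this suggests, using the names from the quoted fragment:

```python
# tuple(...) already fails on non-iterable arguments, which makes the
# separate isinstance assertions redundant.
target_size = tuple(op_schema.args_schema[1])
target_stride = tuple(op_schema.args_schema[2])
```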

f"as_strided only supports the same size: input has size {inp_size} but target size is {target_size}"
)

assert len(target_size) == len(target_stride), (
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

aren't you also need to guard the stride are the same?

f"size and stride should have the same length, but got {len(target_size)} and {len(target_stride)}"
)

```python
from torch.distributed.tensor._ops._tensor_ops import default_strategy
```
I would prefer not to reuse the default strategy; it's simple enough to write its own.
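A hedged sketch of what a standalone strategy could look like, simply following the input placements; the class and field names are assumptions based on DTensor internals:

```python
from torch.distributed.tensor._op_schema import (  # names assumed
    OpSchema,
    OpStrategy,
    PlacementStrategy,
    StrategyType,
)

def as_strided_strategy(op_schema: OpSchema) -> StrategyType:
    inp_strategy = op_schema.args_schema[0]
    assert isinstance(inp_strategy, OpStrategy)
    # The size (and hence the sharding) is unchanged, so each output spec
    # can simply mirror the corresponding input spec.
    return OpStrategy(
        [
            PlacementStrategy(
                output_specs=s.output_spec,
                input_specs=[s.output_spec],
            )
            for s in inp_strategy.strategies
        ]
    )
```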

github-actions bot commented:

Looks like this PR hasn't been updated in a while so we're going to go ahead and mark this as Stale.
Feel free to remove the Stale label if you feel this was a mistake.
If you are unable to remove the Stale label please contact a maintainer in order to do so.
If you want the bot to never mark this PR stale again, add the no-stale label.
Stale pull requests will automatically be closed after 30 days of inactivity.

@github-actions github-actions bot added the Stale label Jun 23, 2025
superiwan pushed a commit to superiwan/pytorch that referenced this pull request Jul 14, 2025