
[export] run_decompositions generates inefficient operations #157289

@titaiwangms

Description

run_decompositions (this could come from AOT or from functionalization) generates inefficient operations: a chain of meaningless full-range slice ops followed by slice_scatter just to copy data back. It might be worth optimizing this pattern.

Repro (this is a very common pattern in LLMs):

import torch
 
class MyModule(torch.nn.Module):
    def __init__(self):
        super(MyModule, self).__init__()
 
    def forward(self, causal_mask, fill_value):
        causal_mask = causal_mask.clone()
        mask_length = fill_value.shape[-1]
        causal_mask[:, :, :, :mask_length] = fill_value
        return causal_mask
 
model = MyModule()
model.eval()
 
B = 2
N = 2
S = 3
T = 4
T2 = 3
causal_mask = torch.randn(B, N, S, T)
fill_value = torch.randn(B, N, S, T2)
inputs = (causal_mask, fill_value)
 
 
symM = torch.export.Dim("M")
symN = torch.export.Dim("N")
 
dynamic_shapes = {
    'causal_mask': {3: symM},
    'fill_value': {3: symN},
}
 
program = torch.export.export(model, inputs, dynamic_shapes=dynamic_shapes)
print(program)
decomposed_program = program.run_decompositions()
print(decomposed_program)

Exported_program:

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, causal_mask: "f32[2, 2, 3, s30]", fill_value: "f32[2, 2, 3, s58]"):
             # 
            sym_size_int_4: "Sym(s58)" = torch.ops.aten.sym_size.int(fill_value, 3)
            
             # File: /home/titaiwang/pytorch/test_slice_scatter.py:10 in forward, code: causal_mask = causal_mask.clone()
            clone: "f32[2, 2, 3, s30]" = torch.ops.aten.clone.default(causal_mask);  causal_mask = None
            
             # File: /home/titaiwang/pytorch/test_slice_scatter.py:12 in forward, code: causal_mask[:, :, :, :mask_length] = fill_value
            slice_1: "f32[2, 2, 3, s30]" = torch.ops.aten.slice.Tensor(clone, 0, 0, 9223372036854775807)
            slice_2: "f32[2, 2, 3, s30]" = torch.ops.aten.slice.Tensor(slice_1, 1, 0, 9223372036854775807);  slice_1 = None
            slice_3: "f32[2, 2, 3, s30]" = torch.ops.aten.slice.Tensor(slice_2, 2, 0, 9223372036854775807);  slice_2 = None
            slice_4: "f32[2, 2, 3, s58]" = torch.ops.aten.slice.Tensor(slice_3, 3, 0, sym_size_int_4);  slice_3 = sym_size_int_4 = None
            copy_: "f32[2, 2, 3, s58]" = torch.ops.aten.copy_.default(slice_4, fill_value);  slice_4 = fill_value = copy_ = None
            return (clone,)
            
Graph signature: 
    # inputs
    causal_mask: USER_INPUT
    fill_value: USER_INPUT
    
    # outputs
    clone: USER_OUTPUT
    
Range constraints: {s30: VR[0, int_oo], s58: VR[0, int_oo]}
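
For context: slice_1 through slice_3 above span the full range of dims 0-2 (end = 9223372036854775807, i.e. INT64_MAX), so they are no-op views; only slice_4 actually narrows the last dimension before the in-place copy_. A minimal sketch of how such no-op slices could be spotted in the graph (illustrative helper only, not an existing PyTorch API; it assumes the slice arguments are passed positionally, as export produces them):

import torch

INT64_MAX = 2**63 - 1  # sentinel torch uses for "slice to the end of the dim"

def is_noop_slice(node: torch.fx.Node) -> bool:
    """True if a call to aten.slice.Tensor spans its whole dimension (pure view)."""
    if node.op != "call_function" or node.target != torch.ops.aten.slice.Tensor:
        return False
    # args layout: (input, dim, start, end, step); trailing defaults may be omitted
    args = list(node.args) + [None] * (5 - len(node.args))
    _, _, start, end, step = args
    return (
        start in (None, 0)
        and (end is None or end == INT64_MAX)
        and step in (None, 1)
    )

noop = [n.name for n in program.graph_module.graph.nodes if is_noop_slice(n)]
print(noop)  # expected here: ['slice_1', 'slice_2', 'slice_3']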

Decomposed_program:

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, causal_mask: "f32[2, 2, 3, s30]", fill_value: "f32[2, 2, 3, s58]"):
             # 
            sym_size_int_5: "Sym(s58)" = torch.ops.aten.sym_size.int(fill_value, 3)
            
             # File: /home/titaiwang/pytorch/test_slice_scatter.py:10 in forward, code: causal_mask = causal_mask.clone()
            clone: "f32[2, 2, 3, s30]" = torch.ops.aten.clone.default(causal_mask);  causal_mask = None
            
             # File: /home/titaiwang/pytorch/test_slice_scatter.py:12 in forward, code: causal_mask[:, :, :, :mask_length] = fill_value
            slice_1: "f32[2, 2, 3, s30]" = torch.ops.aten.slice.Tensor(clone, 0, 0, 9223372036854775807)
            slice_2: "f32[2, 2, 3, s30]" = torch.ops.aten.slice.Tensor(slice_1, 1, 0, 9223372036854775807);  slice_1 = None
            slice_3: "f32[2, 2, 3, s30]" = torch.ops.aten.slice.Tensor(slice_2, 2, 0, 9223372036854775807);  slice_2 = None
            slice_4: "f32[2, 2, 3, s58]" = torch.ops.aten.slice.Tensor(slice_3, 3, 0, sym_size_int_5);  slice_3 = None
            copy: "f32[2, 2, 3, s58]" = torch.ops.aten.copy.default(slice_4, fill_value);  slice_4 = fill_value = None
            slice_5: "f32[2, 2, 3, s30]" = torch.ops.aten.slice.Tensor(clone, 0, 0, 9223372036854775807)
            slice_6: "f32[2, 2, 3, s30]" = torch.ops.aten.slice.Tensor(slice_5, 1, 0, 9223372036854775807)
            slice_7: "f32[2, 2, 3, s30]" = torch.ops.aten.slice.Tensor(slice_6, 2, 0, 9223372036854775807)
            slice_scatter: "f32[2, 2, 3, s30]" = torch.ops.aten.slice_scatter.default(slice_7, copy, 3, 0, sym_size_int_5);  slice_7 = copy = sym_size_int_5 = None
            slice_scatter_1: "f32[2, 2, 3, s30]" = torch.ops.aten.slice_scatter.default(slice_6, slice_scatter, 2, 0, 9223372036854775807);  slice_6 = slice_scatter = None
            slice_scatter_2: "f32[2, 2, 3, s30]" = torch.ops.aten.slice_scatter.default(slice_5, slice_scatter_1, 1, 0, 9223372036854775807);  slice_5 = slice_scatter_1 = None
            slice_scatter_3: "f32[2, 2, 3, s30]" = torch.ops.aten.slice_scatter.default(clone, slice_scatter_2, 0, 0, 9223372036854775807);  clone = slice_scatter_2 = None
            return (slice_scatter_3,)
            
Graph signature: 
    # inputs
    causal_mask: USER_INPUT
    fill_value: USER_INPUT
    
    # outputs
    slice_scatter_3: USER_OUTPUT
    
Range constraints: {s30: VR[0, int_oo], s58: VR[0, int_oo]}
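
One possible direction is a post-decomposition cleanup pass that folds this pattern: a full-range aten.slice is just a view of its input, and a full-range aten.slice_scatter is value-wise equal to its src argument, so both can be bypassed. A rough sketch of such a pass (hypothetical, value semantics only; it ignores stride/aliasing subtleties and does not update the ExportedProgram's graph signature):

import torch

INT64_MAX = 2**63 - 1

def fold_full_range_slices(gm: torch.fx.GraphModule) -> None:
    def full_range(start=None, end=None, step=1):
        return (
            start in (None, 0)
            and (end is None or end == INT64_MAX)
            and step in (None, 1)
        )

    for node in list(gm.graph.nodes):
        if node.op != "call_function":
            continue
        if node.target == torch.ops.aten.slice.Tensor and full_range(*node.args[2:]):
            # full-range slice: forward the input itself
            node.replace_all_uses_with(node.args[0])
            gm.graph.erase_node(node)
        elif node.target == torch.ops.aten.slice_scatter.default and full_range(*node.args[3:]):
            # full-range scatter: the result is value-wise the src tensor
            node.replace_all_uses_with(node.args[1])
            gm.graph.erase_node(node)
    gm.graph.lint()
    gm.recompile()

fold_full_range_slices(decomposed_program.graph_module)
# after folding, only the dim-3 slice, the copy, and a single dim-3 slice_scatter into clone remain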

cc @chauhang @penguinwu @avikchaudhuri @gmagogsfm @zhxchen17 @tugsbayasgalan @angelayi @suo @ydwu4
