-
Notifications
You must be signed in to change notification settings - Fork 24.7k
Closed
Description
run_decompositions
(Could be AOT or functionalization) generates inefficient operation sequences composed of meaningless slices, and uses slice_scatter to copy the result back. It might be worth optimizing this pattern.
Repro (this is a very common pattern in LLMs):
import os
import torch
import transformers
class MyModule(torch.nn.Module):
def __init__(self):
super(MyModule, self).__init__()
def forward(self, causal_mask, fill_value):
causal_mask = causal_mask.clone()
mask_length = fill_value.shape[-1]
causal_mask[:, :, :, :mask_length] = fill_value
return causal_mask
model = MyModule()
model.eval()
B = 2
N = 2
S = 3
T = 4
T2 = 3
causal_mask = torch.randn(B, N, S, T)
fill_value = torch.randn(B, N, S, T2)
inputs = (causal_mask, fill_value)
symM = torch.export.Dim("M")
symN = torch.export.Dim("N")
dynamic_shapes = {
'causal_mask': {3: symM},
'fill_value': {3: symN},
}
program = torch.export.export(model, inputs, dynamic_shapes=dynamic_shapes)
print(program)
decomposed_program = program.run_decompositions()
print(decomposed_program)
Exported_program:
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, causal_mask: "f32[2, 2, 3, s30]", fill_value: "f32[2, 2, 3, s58]"):
#
sym_size_int_4: "Sym(s58)" = torch.ops.aten.sym_size.int(fill_value, 3)
# File: /home/titaiwang/pytorch/test_slice_scatter.py:10 in forward, code: causal_mask = causal_mask.clone()
clone: "f32[2, 2, 3, s30]" = torch.ops.aten.clone.default(causal_mask); causal_mask = None
# File: /home/titaiwang/pytorch/test_slice_scatter.py:12 in forward, code: causal_mask[:, :, :, :mask_length] = fill_value
slice_1: "f32[2, 2, 3, s30]" = torch.ops.aten.slice.Tensor(clone, 0, 0, 9223372036854775807)
slice_2: "f32[2, 2, 3, s30]" = torch.ops.aten.slice.Tensor(slice_1, 1, 0, 9223372036854775807); slice_1 = None
slice_3: "f32[2, 2, 3, s30]" = torch.ops.aten.slice.Tensor(slice_2, 2, 0, 9223372036854775807); slice_2 = None
slice_4: "f32[2, 2, 3, s58]" = torch.ops.aten.slice.Tensor(slice_3, 3, 0, sym_size_int_4); slice_3 = sym_size_int_4 = None
copy_: "f32[2, 2, 3, s58]" = torch.ops.aten.copy_.default(slice_4, fill_value); slice_4 = fill_value = copy_ = None
return (clone,)
Graph signature:
# inputs
causal_mask: USER_INPUT
fill_value: USER_INPUT
# outputs
clone: USER_OUTPUT
Range constraints: {s30: VR[0, int_oo], s58: VR[0, int_oo]}
Decomposed_program:
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, causal_mask: "f32[2, 2, 3, s30]", fill_value: "f32[2, 2, 3, s58]"):
#
sym_size_int_5: "Sym(s58)" = torch.ops.aten.sym_size.int(fill_value, 3)
# File: /home/titaiwang/pytorch/test_slice_scatter.py:10 in forward, code: causal_mask = causal_mask.clone()
clone: "f32[2, 2, 3, s30]" = torch.ops.aten.clone.default(causal_mask); causal_mask = None
# File: /home/titaiwang/pytorch/test_slice_scatter.py:12 in forward, code: causal_mask[:, :, :, :mask_length] = fill_value
slice_1: "f32[2, 2, 3, s30]" = torch.ops.aten.slice.Tensor(clone, 0, 0, 9223372036854775807)
slice_2: "f32[2, 2, 3, s30]" = torch.ops.aten.slice.Tensor(slice_1, 1, 0, 9223372036854775807); slice_1 = None
slice_3: "f32[2, 2, 3, s30]" = torch.ops.aten.slice.Tensor(slice_2, 2, 0, 9223372036854775807); slice_2 = None
slice_4: "f32[2, 2, 3, s58]" = torch.ops.aten.slice.Tensor(slice_3, 3, 0, sym_size_int_5); slice_3 = None
copy: "f32[2, 2, 3, s58]" = torch.ops.aten.copy.default(slice_4, fill_value); slice_4 = fill_value = None
slice_5: "f32[2, 2, 3, s30]" = torch.ops.aten.slice.Tensor(clone, 0, 0, 9223372036854775807)
slice_6: "f32[2, 2, 3, s30]" = torch.ops.aten.slice.Tensor(slice_5, 1, 0, 9223372036854775807)
slice_7: "f32[2, 2, 3, s30]" = torch.ops.aten.slice.Tensor(slice_6, 2, 0, 9223372036854775807)
slice_scatter: "f32[2, 2, 3, s30]" = torch.ops.aten.slice_scatter.default(slice_7, copy, 3, 0, sym_size_int_5); slice_7 = copy = sym_size_int_5 = None
slice_scatter_1: "f32[2, 2, 3, s30]" = torch.ops.aten.slice_scatter.default(slice_6, slice_scatter, 2, 0, 9223372036854775807); slice_6 = slice_scatter = None
slice_scatter_2: "f32[2, 2, 3, s30]" = torch.ops.aten.slice_scatter.default(slice_5, slice_scatter_1, 1, 0, 9223372036854775807); slice_5 = slice_scatter_1 = None
slice_scatter_3: "f32[2, 2, 3, s30]" = torch.ops.aten.slice_scatter.default(clone, slice_scatter_2, 0, 0, 9223372036854775807); clone = slice_scatter_2 = None
return (slice_scatter_3,)
Graph signature:
# inputs
causal_mask: USER_INPUT
fill_value: USER_INPUT
# outputs
slice_scatter_3: USER_OUTPUT
Range constraints: {s30: VR[0, int_oo], s58: VR[0, int_oo]}
cc @chauhang @penguinwu @avikchaudhuri @gmagogsfm @zhxchen17 @tugsbayasgalan @angelayi @suo @ydwu4