
Commit 1bb9b18

Xia-Weiwen authored and pytorchmergebot committed
[CPU][Inductor] Improve A16W4 GEMM template performance by using block_n=32 (#156174)
**Summary**
We found that using `block_n=32` brings better performance for A16W4 than `block_n=64`: cache locality is better, and parallelism is better when N is small and more cores are used. For example, when running Llama-3.1-8B with A16W4 and batch size = 16 on 43 cores, `block_n=32` is faster by >10% E2E for both first and next token.

**Test plan**
```
pytest test/inductor/test_cpu_select_algorithm.py -k test_int4_woq_mm_amx
```

Pull Request resolved: #156174
Approved by: https://github.com/leslie-fang-intel
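Below is a minimal, illustrative sketch (not part of this commit) of the kind of A16W4 GEMM this template serves: a bf16 activation multiplied by an int4-packed weight via `torch.ops.aten._weight_int4pack_mm_for_cpu`, compiled with Inductor max-autotune so the C++ GEMM template can be considered. Shapes, group size, and the config knobs are assumptions for demonstration; whether the AMX WOQ int4 micro-kernel is actually selected depends on the CPU and on settings covered by the test above.

```python
import torch
import torch._inductor.config as inductor_config

# Illustrative shapes; the PR's numbers come from Llama-3.1-8B layers.
M, N, K, group_size = 16, 4096, 4096, 128

x = torch.randn(M, K, dtype=torch.bfloat16)
# int4 weight packed two values per byte, viewed as [N, K // 2] uint8
w_packed = torch.randint(0, 256, (N, K // 2), dtype=torch.uint8)
# per-group scales and zero points: [K // group_size, N, 2], bf16
scales_and_zeros = torch.rand(K // group_size, N, 2, dtype=torch.bfloat16)

def a16w4_mm(a):
    return torch.ops.aten._weight_int4pack_mm_for_cpu(
        a, w_packed, group_size, scales_and_zeros
    )

# Let Inductor autotune GEMMs so the C++ WOQ int4 template (block_n=32 on
# AMX-capable CPUs) can compete with the ATen kernel.
inductor_config.max_autotune = True
inductor_config.max_autotune_gemm_backends = "CPP,ATEN"

out = torch.compile(a16w4_mm)(x)
print(out.shape)  # torch.Size([16, 4096])
```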
1 parent d99cac2 commit 1bb9b18

2 files changed: +122 -76


torch/_inductor/codegen/cpp_gemm_template.py

Lines changed: 67 additions & 3 deletions
```diff
@@ -234,7 +234,7 @@
     {%- set tile_X = kernel.slice_nd(X, [("m_start", "m_end"), ("k_start", "k_end")]) %}
     for (int64_t nci = nc; nci < nc_block_end; nci++) {
     {%- set acc_slice = kernel.slice_nd(acc, [("0", "m_end - m_start"), ("(nci - nc)*Nr", "(nci - nc + 1)*Nr")]) %}
-    {%- if template.should_block_weights %}
+    {%- if template.should_block_weights and not is_woq_int4 %}
     {%- set tile_W_3d = kernel.slice_nd(W, [("nci", "nci + 1"), ("k_start", "k_end"), ()]) %}
     {%- set tile_W = kernel.view(tile_W_3d, ["k_end - k_start", micro_gemm.register_blocking.block_n]) %}
     {%- else %}
@@ -1125,9 +1125,12 @@ def prep_weight(
         new_size, padded_n = cls.get_padded_size(n, block_n, k, should_block_weight)
         padding = padded_n - n

-        if should_block_weight:
+        if should_block_weight and not cls.is_woq_int4():
             blocked_w = cls.block_weight(W, new_size, padding)
             new_inputs[1] = cls.pack_vnni_weight(blocked_w, micro_gemm, new_size)
+        elif should_block_weight:
+            assert cls.is_woq_int4()
+            new_inputs[1] = cls.block_weight(W, new_size, padding)
         elif isinstance(W, ir.IRNode):
             # Require W layout to be fixed & contiguous, happens inplace.
             ir.ExternKernel.require_contiguous(W)
@@ -1689,7 +1692,68 @@ def q_group_size(cls):
         @staticmethod
         def check_if_block_weight(W, micro_gemm):
             # For WOQ INT4, weight is already packed
-            return False
+            # However, for AMX microkernel, we want to change the blocking of weight
+            from .cpp_micro_gemm import CppMicroGemmWoQInt4Amx
+
+            return isinstance(micro_gemm, CppMicroGemmWoQInt4Amx)
+
+        @classmethod
+        def block_weight(cls, W, new_size, padding):
+            # This method is called only if AMX microkernels are used.
+            # In this case, we unpack and repack weight so that block_n=32
+            # the format of packed weight is described here:
+            # https://github.com/pytorch/pytorch/blob/32eee8ed225d9f10fbbcb38c24b8b44c24c0c97c/aten/src/ATen/native/cpu/int4mm_kernel.cpp#L583
+            if isinstance(W, ir.IRNode):
+                # in this case, we do nothing
+                ir.ExternKernel.require_contiguous(W)
+                blocked_w = W
+            else:
+                # in this case, we unpack and repack weight
+                assert isinstance(W, torch.Tensor)
+                assert W.dim() == 2
+                N = W.size(0)
+                K = W.size(-1) * 2
+                G = cls.q_group_size()
+                # x and qscales_and_zeros are in bfloat16 instead of float to use the optimized kernel
+                # so that the unpacking process is faster
+                x = torch.eye(K).bfloat16()
+                # Here we use scale=1 and qzero=8 because we want to unpack weight
+                # without dequantizing it. The qzero here is 8 instead of 0 because
+                # int4 values are converted to [-7, 8] in the _weight_int4pack_mm_for_cpu kernel:
+                # https://github.com/pytorch/pytorch/blob/32eee8ed225d9f10fbbcb38c24b8b44c24c0c97c/aten/src/ATen/native/cpu/int4mm_kernel.cpp#L95
+                qscales_and_zeros = (
+                    torch.tensor([1.0, 8.0])
+                    .bfloat16()
+                    .expand(K // G, N, 2)
+                    .contiguous()
+                )
+                # shape: [K, N]
+                unpacked_w = torch.ops.aten._weight_int4pack_mm_for_cpu(
+                    x,
+                    W,
+                    G,
+                    qscales_and_zeros,
+                ).to(torch.uint8)
+                block_n = 32
+                # shape: [N // block_n, K, block_n]
+                w_blocked = (
+                    unpacked_w.view(K, N // block_n, block_n)
+                    .permute(1, 0, 2)
+                    .contiguous()
+                )
+                # pack 2 int4 -> 1 int8
+                # block_n: [a0, a1, ..., a15, b0, b1, ..., b15]
+                # -> [(a0 & 0xf) | (b0 << 4), (a1 & 0xf) | (b1 << 4), ...]
+                # shape: [N // block_n, K, 2, block_n // 2]
+                w_blocked = w_blocked.view(N // block_n, K, 2, block_n // 2)
+                # shape: [N // block_n, K, block_n // 2]
+                w_blocked_packed = (w_blocked[:, :, 0, :] & 0xF) | (
+                    w_blocked[:, :, 1, :] << 4
+                )
+                # shape: [N, K // 2]
+                blocked_w = w_blocked_packed.view(N, K // 2)
+
+            return blocked_w

     return CppWoqInt4GemmTemplateInstance
```

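The nibble layout produced by the new `block_weight` above (within each `block_n=32` group, byte `j` of a row holds column `j` in the low nibble and column `j + 16` in the high nibble) can be sanity-checked with a short, self-contained sketch; the tensor names and shapes here are illustrative and not part of the commit:

```python
import torch

K, N, block_n = 4, 64, 32
# Pretend these are already-unpacked int4 codes (0..15), blocked as [N // block_n, K, block_n].
w = torch.randint(0, 16, (N // block_n, K, block_n), dtype=torch.uint8)

# Pack as block_weight does: split each 32-wide block into two halves of 16 and
# store them as the low/high nibbles of one byte.
w4 = w.view(N // block_n, K, 2, block_n // 2)
packed = (w4[:, :, 0, :] & 0xF) | (w4[:, :, 1, :] << 4)  # [N // block_n, K, 16]
packed_2d = packed.view(N, K // 2)                       # the [N, K // 2] view handed to the kernel

# Unpack the way the micro-kernel reads it: low nibble first, then shift right by 4.
lo = packed & 0xF   # columns 0..15 of each block
hi = packed >> 4    # columns 16..31 of each block
assert torch.equal(torch.cat([lo, hi], dim=-1), w)
```
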
torch/_inductor/codegen/cpp_micro_gemm.py

Lines changed: 55 additions & 73 deletions
```diff
@@ -1231,10 +1231,6 @@ def codegen_define(self, kernel: CppTemplateKernel) -> str:
         else:
             assert block_k == 32, "Only support block_k = 32 for AMX Bfloat16/Float16"
         num_columns = block_n // 16
-        if self.is_woq_int4():
-            # block_n for woq int4 is 64, which is too large for micro kernel
-            # so we split it into 2x32. Here num_columns = 2.
-            num_columns //= 2
         options = {
             "declare_kernel": self.get_kernel_declaration(),
             "use_cached_dequantized_B": (
@@ -1633,8 +1629,8 @@ def is_woq_int4(self):
     *generate_gemm_config(
         VecAMX,
         [  # (block_m, block_n, block_k)
-            (16, 64, 32),
-            (32, 64, 32),
+            (16, 32, 32),
+            (32, 32, 32),
         ],
         input_dtype=torch.bfloat16,
         input2_dtype=torch.uint8,
@@ -1646,8 +1642,8 @@ def is_woq_int4(self):
 class CppMicroGemmWoQInt4Amx(CppMicroGemmAMX):
     """
     This class generates the code for WoQ int4 micro gemm using AMX intrinsics,
-    which are available on 4th and 5th generation Intel Xeon.
-    Shape of packed weight = [N // 64, K, 32], viewed as [N, K // 2]
+    which are available on 4th and newer generations of Intel Xeon.
+    Shape of packed weight = [N // 32, K, 16], viewed as [N, K // 2]
     Shape of packed ScalesAndZeros = [K // group_size, N, 2]
     Reuse TEMPLATE_KERNEL of CppMicroGemmAMX.
     """
@@ -1660,7 +1656,7 @@ class CppMicroGemmWoQInt4Amx(CppMicroGemmAMX):
 {{declare_kernel}} {
   {{kernel.assert_function}}(N % {{block_n}} == 0, "N dimension must be multiple of {{block_n}}");
   {{kernel.assert_function}}(K % 2 == 0, "K dimension must be multiple of 2");
-  {{kernel.assert_function}}({{block_n}} == 64, "block_n must be 64 for WOQ int4");
+  {{kernel.assert_function}}({{block_n}} == 32, "block_n must be 32 for WOQ int4");

   // Create a stack-allocated buffer for tiles of B.
   // Except maybe for the tail-case, an AMX tile of B has 16x32 BF16 elements.
@@ -1674,6 +1670,7 @@ class CppMicroGemmWoQInt4Amx(CppMicroGemmAMX):
   const int PREFETCH_SIZE_KB = (PREFETCH_SIZE_K + BLOCK_K - 1) / BLOCK_K;
   const int KB = K / BLOCK_K;

+  __m512i b32[COLS * 2];
   __m512 vb[COLS * 2];
   __m512 scale[COLS];
   __m512 zero[COLS];
@@ -1759,7 +1756,7 @@
   // Dequantize a B block of 2 * block_n into bf16
   // So, it handles k and k+1 at the same time
   auto dequantize_B = [&](int n) {
-    constexpr int64_t ldb_int4 = BLOCK_N / 2; // 32
+    constexpr int64_t ldb_int4 = BLOCK_N / 2; // 16
     for (int k = 0, kb = 0; k < K; k += 2) {
       // Since block_k must be 32 for AMX microkernels, k_start may not be
       // a multiple of q_group_size. In that case, we need to load scales
@@ -1769,35 +1766,25 @@
       }

       // load 256 bits = 64 elements in int4
-      __m256i b4 = _mm256_loadu_si256((__m256i*)(B + n * K + k * ldb_int4));
       if (k + PREFETCH_SIZE_K < K) {
         _mm_prefetch(B + (k + PREFETCH_SIZE_K) * ldb_int4, _MM_HINT_T0);
       }

-      __m512i b32 = _mm512_cvtepu8_epi32(_mm256_castsi256_si128(b4));
-      vb[0] = _mm512_permutexvar_ps(b32, lut);
+      __m128i b4 = _mm_loadu_si128((__m128i*)(B + n / 2 * K + k * ldb_int4));
+      b32[0] = _mm512_cvtepu8_epi32(b4);
+      b32[1] = _mm512_srli_epi32(b32[0], 4);
+      vb[0] = _mm512_permutexvar_ps(b32[0], lut);
       vb[0] = _mm512_fmadd_ps(vb[0], scale[0], zero[0]);
-      vb[2] = _mm512_permutexvar_ps(_mm512_srli_epi32(b32, 4), lut);
-      vb[2] = _mm512_fmadd_ps(vb[2], scale[2], zero[2]);
-
-      b32 = _mm512_cvtepu8_epi32(_mm256_extracti128_si256(b4, 1));
-      vb[1] = _mm512_permutexvar_ps(b32, lut);
+      vb[1] = _mm512_permutexvar_ps(b32[1], lut);
       vb[1] = _mm512_fmadd_ps(vb[1], scale[1], zero[1]);
-      vb[3] = _mm512_permutexvar_ps(_mm512_srli_epi32(b32, 4), lut);
-      vb[3] = _mm512_fmadd_ps(vb[3], scale[3], zero[3]);

-      b4 = _mm256_loadu_si256((__m256i*)(B + n * K + (k + 1) * ldb_int4));
-      b32 = _mm512_cvtepu8_epi32(_mm256_castsi256_si128(b4));
-      vb[0 + COLS] = _mm512_permutexvar_ps(b32, lut);
+      b4 = _mm_loadu_si128((__m128i*)(B + n / 2 * K + (k + 1) * ldb_int4));
+      b32[0 + COLS] = _mm512_cvtepu8_epi32(b4);
+      b32[1 + COLS] = _mm512_srli_epi32(b32[0 + COLS], 4);
+      vb[0 + COLS] = _mm512_permutexvar_ps(b32[0 + COLS], lut);
       vb[0 + COLS] = _mm512_fmadd_ps(vb[0 + COLS], scale[0], zero[0]);
-      vb[2 + COLS] = _mm512_permutexvar_ps(_mm512_srli_epi32(b32, 4), lut);
-      vb[2 + COLS] = _mm512_fmadd_ps(vb[2 + COLS], scale[2], zero[2]);
-
-      b32 = _mm512_cvtepu8_epi32(_mm256_extracti128_si256(b4, 1));
-      vb[1 + COLS] = _mm512_permutexvar_ps(b32, lut);
+      vb[1 + COLS] = _mm512_permutexvar_ps(b32[1 + COLS], lut);
       vb[1 + COLS] = _mm512_fmadd_ps(vb[1 + COLS], scale[1], zero[1]);
-      vb[3 + COLS] = _mm512_permutexvar_ps(_mm512_srli_epi32(b32, 4), lut);
-      vb[3 + COLS] = _mm512_fmadd_ps(vb[3 + COLS], scale[3], zero[3]);

       for (int i = 0; i < COLS; i++) {
         // convert to VNNI
@@ -1811,57 +1798,52 @@
         auto v = _mm512_castsi256_si512(v0_bf16);
         v = _mm512_inserti64x4(v, v1_bf16, 1);
         // store the VNNI format bfloat16 values
-        // split block_n into 2x32
-        {{input_t}}* addr = dequantized_B_buf + K * 32 * (i / 2) + k * 32 + (i % 2) * 32;
+        {{input_t}}* addr = dequantized_B_buf + k * 32 + (i % 2) * 32;
         _mm512_storeu_si512(addr, v);
       }
     }
   };

-  const int64_t updated_ldb = {{block_n}} / 2;
   for (int64_t n = 0; n < N; n += {{block_n}}) {
     // Dequantize K * block_n int8 B elements into BF16
     dequantize_B(n);
-    // for woq int4, block_n is 64, which is too large for micro kernel
-    for (int64_t ni = 0; ni < {{block_n}}; ni += 32) {
-      for (int64_t m = 0; m < M; m += {{block_m}}) {
-        int64_t block_m = std::min<int64_t>(M - m, {{block_m}});
-        int64_t m_tail = m;
-        {%- for num_rows in range(block_m, 0, -16) %}
-        {%- if num_rows != block_m %}
-        else
-        {%- endif %}
-        if (block_m >= {{num_rows}}) {
-          {{kernel_name}}_amx_kernel_{{num_rows}}_{{num_columns}}<accum>(
-            amx_state,
-            A + m * lda,
-            dequantized_B_buf + ni * K,
-            C + m * ldc + n + ni,
-            K,
-            lda,
-            updated_ldb,
-            ldc,
-            16
-          );
-          block_m -= {{num_rows}};
-          m_tail += {{num_rows}};
-        }
-        {%- endfor %}
-        if (block_m > 0) {
-          {{kernel_name}}_amx_kernel_16_{{num_columns}}<accum>(
-            amx_state,
-            A + m_tail * lda,
-            dequantized_B_buf + ni * K,
-            C + m_tail * ldc + n + ni,
-            K,
-            lda,
-            updated_ldb,
-            ldc,
-            block_m
-          );
-        }
-      } // for m
-    } // for ni
+    for (int64_t m = 0; m < M; m += {{block_m}}) {
+      int64_t block_m = std::min<int64_t>(M - m, {{block_m}});
+      int64_t m_tail = m;
+      {%- for num_rows in range(block_m, 0, -16) %}
+      {%- if num_rows != block_m %}
+      else
+      {%- endif %}
+      if (block_m >= {{num_rows}}) {
+        {{kernel_name}}_amx_kernel_{{num_rows}}_{{num_columns}}<accum>(
+          amx_state,
+          A + m * lda,
+          dequantized_B_buf + n * K,
+          C + m * ldc + n,
+          K,
+          lda,
+          {{block_n}},
+          ldc,
+          16
+        );
+        block_m -= {{num_rows}};
+        m_tail += {{num_rows}};
+      }
+      {%- endfor %}
+      if (block_m > 0) {
+        {{kernel_name}}_amx_kernel_16_{{num_columns}}<accum>(
+          amx_state,
+          A + m_tail * lda,
+          dequantized_B_buf + n * K,
+          C + m_tail * ldc + n,
+          K,
+          lda,
+          {{block_n}},
+          ldc,
+          block_m
+        );
+      }
+    } // for m
   } // for n
 }
 """
```
