@@ -1231,10 +1231,6 @@ def codegen_define(self, kernel: CppTemplateKernel) -> str:
1231
1231
else :
1232
1232
assert block_k == 32 , "Only support block_k = 32 for AMX Bfloat16/Float16"
1233
1233
num_columns = block_n // 16
1234
- if self .is_woq_int4 ():
1235
- # block_n for woq int4 is 64, which is too large for micro kernel
1236
- # so we split it into 2x32. Here num_columns = 2.
1237
- num_columns //= 2
1238
1234
options = {
1239
1235
"declare_kernel" : self .get_kernel_declaration (),
1240
1236
"use_cached_dequantized_B" : (
@@ -1633,8 +1629,8 @@ def is_woq_int4(self):
1633
1629
* generate_gemm_config (
1634
1630
VecAMX ,
1635
1631
[ # (block_m, block_n, block_k)
1636
- (16 , 64 , 32 ),
1637
- (32 , 64 , 32 ),
1632
+ (16 , 32 , 32 ),
1633
+ (32 , 32 , 32 ),
1638
1634
],
1639
1635
input_dtype = torch .bfloat16 ,
1640
1636
input2_dtype = torch .uint8 ,
@@ -1646,8 +1642,8 @@ def is_woq_int4(self):
1646
1642
class CppMicroGemmWoQInt4Amx (CppMicroGemmAMX ):
1647
1643
"""
1648
1644
This class generates the code for WoQ int4 micro gemm using AMX intrinsics,
1649
- which are available on 4th and 5th generation Intel Xeon.
1650
- Shape of packed weight = [N // 64 , K, 32 ], viewed as [N, K // 2]
1645
+ which are available on 4th and newer generations of Intel Xeon.
1646
+ Shape of packed weight = [N // 32, K, 16], viewed as [N, K // 2]
1651
1647
Shape of packed ScalesAndZeros = [K // group_size, N, 2]
1652
1648
Reuse TEMPLATE_KERNEL of CppMicroGemmAMX.
1653
1649
"""
@@ -1660,7 +1656,7 @@ class CppMicroGemmWoQInt4Amx(CppMicroGemmAMX):
1660
1656
{{declare_kernel}} {
1661
1657
{{kernel.assert_function}}(N % {{block_n}} == 0, "N dimension must be multiple of {{block_n}}");
1662
1658
{{kernel.assert_function}}(K % 2 == 0, "K dimension must be multiple of 2");
1663
- {{kernel.assert_function}}({{block_n}} == 64 , "block_n must be 64 for WOQ int4");
1659
+ {{kernel.assert_function}}({{block_n}} == 32 , "block_n must be 32 for WOQ int4");
1664
1660
1665
1661
// Create a stack-allocated buffer for tiles of B.
1666
1662
// Except maybe for the tail-case, an AMX tile of B has 16x32 BF16 elements.
@@ -1674,6 +1670,7 @@ class CppMicroGemmWoQInt4Amx(CppMicroGemmAMX):
1674
1670
const int PREFETCH_SIZE_KB = (PREFETCH_SIZE_K + BLOCK_K - 1) / BLOCK_K;
1675
1671
const int KB = K / BLOCK_K;
1676
1672
1673
+ __m512i b32[COLS * 2];
1677
1674
__m512 vb[COLS * 2];
1678
1675
__m512 scale[COLS];
1679
1676
__m512 zero[COLS];
@@ -1759,7 +1756,7 @@ class CppMicroGemmWoQInt4Amx(CppMicroGemmAMX):
1759
1756
// Dequantize a B block of 2 * block_n into bf16
1760
1757
// So, it handles k and k+1 at the same time
1761
1758
auto dequantize_B = [&](int n) {
1762
- constexpr int64_t ldb_int4 = BLOCK_N / 2; // 32
1759
+ constexpr int64_t ldb_int4 = BLOCK_N / 2; // 16
1763
1760
for (int k = 0, kb = 0; k < K; k += 2) {
1764
1761
// Since block_k must be 32 for AMX microkernels, k_start may not be
1765
1762
// a multiple of q_group_size. In that case, we need to load scales
@@ -1769,35 +1766,25 @@ class CppMicroGemmWoQInt4Amx(CppMicroGemmAMX):
1769
1766
}
1770
1767
1771
1768
// load 128 bits = 32 elements in int4
1772
- __m256i b4 = _mm256_loadu_si256((__m256i*)(B + n * K + k * ldb_int4));
1773
1769
if (k + PREFETCH_SIZE_K < K) {
1774
1770
_mm_prefetch(B + (k + PREFETCH_SIZE_K) * ldb_int4, _MM_HINT_T0);
1775
1771
}
1776
1772
1777
- __m512i b32 = _mm512_cvtepu8_epi32(_mm256_castsi256_si128(b4));
1778
- vb[0] = _mm512_permutexvar_ps(b32, lut);
1773
+ __m128i b4 = _mm_loadu_si128((__m128i*)(B + n / 2 * K + k * ldb_int4));
1774
+ b32[0] = _mm512_cvtepu8_epi32(b4);
1775
+ b32[1] = _mm512_srli_epi32(b32[0], 4);
1776
+ vb[0] = _mm512_permutexvar_ps(b32[0] , lut);
1779
1777
vb[0] = _mm512_fmadd_ps(vb[0], scale[0], zero[0]);
1780
- vb[2] = _mm512_permutexvar_ps(_mm512_srli_epi32(b32, 4), lut);
1781
- vb[2] = _mm512_fmadd_ps(vb[2], scale[2], zero[2]);
1782
-
1783
- b32 = _mm512_cvtepu8_epi32(_mm256_extracti128_si256(b4, 1));
1784
- vb[1] = _mm512_permutexvar_ps(b32, lut);
1778
+ vb[1] = _mm512_permutexvar_ps(b32[1], lut);
1785
1779
vb[1] = _mm512_fmadd_ps(vb[1], scale[1], zero[1]);
1786
- vb[3] = _mm512_permutexvar_ps(_mm512_srli_epi32(b32, 4), lut);
1787
- vb[3] = _mm512_fmadd_ps(vb[3], scale[3], zero[3]);
1788
1780
1789
- b4 = _mm256_loadu_si256((__m256i*)(B + n * K + (k + 1) * ldb_int4));
1790
- b32 = _mm512_cvtepu8_epi32(_mm256_castsi256_si128(b4));
1791
- vb[0 + COLS] = _mm512_permutexvar_ps(b32, lut);
1781
+ b4 = _mm_loadu_si128((__m128i*)(B + n / 2 * K + (k + 1) * ldb_int4));
1782
+ b32[0 + COLS] = _mm512_cvtepu8_epi32(b4);
1783
+ b32[1 + COLS] = _mm512_srli_epi32(b32[0 + COLS], 4);
1784
+ vb[0 + COLS] = _mm512_permutexvar_ps(b32[0 + COLS] , lut);
1792
1785
vb[0 + COLS] = _mm512_fmadd_ps(vb[0 + COLS], scale[0], zero[0]);
1793
- vb[2 + COLS] = _mm512_permutexvar_ps(_mm512_srli_epi32(b32, 4), lut);
1794
- vb[2 + COLS] = _mm512_fmadd_ps(vb[2 + COLS], scale[2], zero[2]);
1795
-
1796
- b32 = _mm512_cvtepu8_epi32(_mm256_extracti128_si256(b4, 1));
1797
- vb[1 + COLS] = _mm512_permutexvar_ps(b32, lut);
1786
+ vb[1 + COLS] = _mm512_permutexvar_ps(b32[1 + COLS], lut);
1798
1787
vb[1 + COLS] = _mm512_fmadd_ps(vb[1 + COLS], scale[1], zero[1]);
1799
- vb[3 + COLS] = _mm512_permutexvar_ps(_mm512_srli_epi32(b32, 4), lut);
1800
- vb[3 + COLS] = _mm512_fmadd_ps(vb[3 + COLS], scale[3], zero[3]);
1801
1788
1802
1789
for (int i = 0; i < COLS; i++) {
1803
1790
// convert to VNNI
@@ -1811,57 +1798,52 @@ class CppMicroGemmWoQInt4Amx(CppMicroGemmAMX):
1811
1798
auto v = _mm512_castsi256_si512(v0_bf16);
1812
1799
v = _mm512_inserti64x4(v, v1_bf16, 1);
1813
1800
// store the VNNI format bfloat16 values
1814
- // split block_n into 2x32
1815
- {{input_t}}* addr = dequantized_B_buf + K * 32 * (i / 2) + k * 32 + (i % 2) * 32;
1801
+ {{input_t}}* addr = dequantized_B_buf + k * 32 + (i % 2) * 32;
1816
1802
_mm512_storeu_si512(addr, v);
1817
1803
}
1818
1804
}
1819
1805
};
1820
1806
1821
- const int64_t updated_ldb = {{block_n}} / 2;
1822
1807
for (int64_t n = 0; n < N; n += {{block_n}}) {
1823
1808
// Dequantize K * block_n int4 B elements into BF16
1824
1809
dequantize_B(n);
1825
- // for woq int4, block_n is 64, which is too large for micro kernel
1826
- for (int64_t ni = 0; ni < {{block_n}}; ni += 32) {
1827
- for (int64_t m = 0; m < M; m += {{block_m}}) {
1828
- int64_t block_m = std::min<int64_t>(M - m, {{block_m}});
1829
- int64_t m_tail = m;
1830
- {%- for num_rows in range(block_m, 0, -16) %}
1831
- {%- if num_rows != block_m %}
1832
- else
1833
- {%- endif %}
1834
- if (block_m >= {{num_rows}}) {
1835
- {{kernel_name}}_amx_kernel_{{num_rows}}_{{num_columns}}<accum>(
1836
- amx_state,
1837
- A + m * lda,
1838
- dequantized_B_buf + ni * K,
1839
- C + m * ldc + n + ni,
1840
- K,
1841
- lda,
1842
- updated_ldb,
1843
- ldc,
1844
- 16
1845
- );
1846
- block_m -= {{num_rows}};
1847
- m_tail += {{num_rows}};
1848
- }
1849
- {%- endfor %}
1850
- if (block_m > 0) {
1851
- {{kernel_name}}_amx_kernel_16_{{num_columns}}<accum>(
1852
- amx_state,
1853
- A + m_tail * lda,
1854
- dequantized_B_buf + ni * K,
1855
- C + m_tail * ldc + n + ni,
1856
- K,
1857
- lda,
1858
- updated_ldb,
1859
- ldc,
1860
- block_m
1861
- );
1862
- }
1863
- } // for m
1864
- } // for ni
1810
+ for (int64_t m = 0; m < M; m += {{block_m}}) {
1811
+ int64_t block_m = std::min<int64_t>(M - m, {{block_m}});
1812
+ int64_t m_tail = m;
1813
+ {%- for num_rows in range(block_m, 0, -16) %}
1814
+ {%- if num_rows != block_m %}
1815
+ else
1816
+ {%- endif %}
1817
+ if (block_m >= {{num_rows}}) {
1818
+ {{kernel_name}}_amx_kernel_{{num_rows}}_{{num_columns}}<accum>(
1819
+ amx_state,
1820
+ A + m * lda,
1821
+ dequantized_B_buf + n * K,
1822
+ C + m * ldc + n,
1823
+ K,
1824
+ lda,
1825
+ {{block_n}},
1826
+ ldc,
1827
+ 16
1828
+ );
1829
+ block_m -= {{num_rows}};
1830
+ m_tail += {{num_rows}};
1831
+ }
1832
+ {%- endfor %}
1833
+ if (block_m > 0) {
1834
+ {{kernel_name}}_amx_kernel_16_{{num_columns}}<accum>(
1835
+ amx_state,
1836
+ A + m_tail * lda,
1837
+ dequantized_B_buf + n * K,
1838
+ C + m_tail * ldc + n,
1839
+ K,
1840
+ lda,
1841
+ {{block_n}},
1842
+ ldc,
1843
+ block_m
1844
+ );
1845
+ }
1846
+ } // for m
1865
1847
} // for n
1866
1848
}
1867
1849
"""
0 commit comments