Skip to content

Commit 493682d

Browse files
author
lshAlgorithm
committed
change format
Signed-off-by: lshAlgorithm <lishuhuai_brain@163.com>
1 parent 6e89c96 commit 493682d

File tree

1 file changed

+21
-12
lines changed

1 file changed

+21
-12
lines changed

rwkv_operators_wkv_v7.inc

Lines changed: 21 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99
#define SET1(x) _mm512_set1_ps(x)
1010
#define MULTIPLY(x, y) _mm512_mul_ps(x, y)
1111
#define MULTADD(x, y, z) _mm512_fmadd_ps(x, y, z)
12+
#define ADD(x, y) _mm512_add_ps(x, y)
13+
#define ZEROS() _mm512_setzero_ps()
1214
#elif defined(__AVX2__)
1315
#include <immintrin.h>
1416
#define SIMD_WIDTH 8
@@ -18,6 +20,7 @@
1820
#define MULTIPLY(x, y) _mm256_mul_ps(x, y)
1921
#define MULTADD(x, y, z) _mm256_fmadd_ps(x, y, z)
2022
#define ADD(x, y) _mm256_add_ps(x, y)
23+
#define ZEROS() _mm256_setzero_ps()
2124
#elif defined(__ARM_NEON) || defined(__ARM_NEON__)
2225
#include <arm_neon.h>
2326
#define SIMD_WIDTH 4
@@ -36,18 +39,11 @@
3639
#endif
3740

3841

39-
inline float horizontal_sum_avx(__m256 vec) {
40-
// 水平相加:将8个float两两相加,得到4个结果
42+
inline float horizontal_sum(__m256 vec) {
4143
__m256 sum1 = _mm256_hadd_ps(vec, vec);
42-
43-
// 再次水平相加:将4个结果两两相加,得到2个结果
4444
__m256 sum2 = _mm256_hadd_ps(sum1, sum1);
45-
46-
// 提取低128位和高128位
4745
__m128 sum128 = _mm_add_ps(_mm256_extractf128_ps(sum2, 0),
4846
_mm256_extractf128_ps(sum2, 1));
49-
50-
// 从SSE寄存器中提取最终结果
5147
float result;
5248
_mm_store_ss(&result, sum128);
5349
return result;
@@ -81,7 +77,6 @@ static void rwkv_wkv_v7_impl(struct ggml_tensor * result, const struct ggml_tens
8177
size_t t_offset = t * t_stride;
8278

8379
float * state_in = (t == 0) ? state : state_out;
84-
// transpose_square_inplace(state_in, C/H);
8580
for (size_t h = ith; h < H; h += nth) {
8681
size_t h_offset = h * h_stride;
8782
size_t t_h_offset = t_offset + h_offset;
@@ -94,14 +89,24 @@ static void rwkv_wkv_v7_impl(struct ggml_tensor * result, const struct ggml_tens
9489
memset(&result_data[t_h_offset], 0, h_stride * sizeof(float));
9590
}
9691

92+
// auto sa_vec = ZEROS();
93+
// for (size_t j = 0; j < C / H; j += SIMD_WIDTH) {
94+
// sa_vec = ADD(sa_vec, MULTIPLY(
95+
// LOAD(&a[t_h_offset + j]),
96+
// LOAD(&state_in[h_2d_i_offset + j])
97+
// )
98+
// );
99+
// }
100+
// float sa = horizontal_sum(sa_vec);
97101
float sa = .0;
98102
for (size_t j = 0; j < C / H; j++) {
99103
sa += a[t_h_offset + j] * state_in[h_2d_i_offset + j];
100104
}
105+
101106
auto v_vec = SET1(v[t_h_i_offset]);
102-
auto sa_vec = SET1(sa);
107+
sa_vec = SET1(sa);
103108

104-
auto sum = _mm256_setzero_ps();
109+
auto sum = ZEROS();
105110
for (size_t j = 0; j < C / H; j += SIMD_WIDTH) {
106111
size_t t_h_j_offset = t_h_offset + j;
107112
size_t h_2d_i_j_offset = h_2d_i_offset + j;
@@ -110,19 +115,23 @@ static void rwkv_wkv_v7_impl(struct ggml_tensor * result, const struct ggml_tens
110115
auto k_val = LOAD(&k[t_h_j_offset]);
111116
auto b_val = LOAD(&b[t_h_j_offset]);
112117
auto prev_state_val = LOAD(&state_in[h_2d_i_j_offset]);
118+
113119
// auto kv_val = v_val * k_val;
114120
auto kv_val = MULTIPLY(v_vec, k_val);
121+
115122
// state_out[h_2d_i_j_offset] = prev_state_val * w_val + kv_val + sa * b_val;
116123
auto sab_val = MULTIPLY(sa_vec, b_val);
117124
auto state_out_val = MULTADD(prev_state_val, w_val, kv_val);
118125
state_out_val = ADD(state_out_val, sab_val);
119126
STORE(&state_out[h_2d_i_j_offset], state_out_val);
127+
120128
// result_data[t_h_i_offset] += state_out[h_2d_i_j_offset] * r_val;
121129
auto result = MULTIPLY(state_out_val, r_val);
130+
122131
// auto sum = LOAD(&result_data[t_h_i_offset]);
123132
sum = ADD(sum, result);
124133
}
125-
result_data[t_h_i_offset] = horizontal_sum_avx(sum);
134+
result_data[t_h_i_offset] = horizontal_sum(sum);
126135
}
127136

128137
}

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy