#define SET1(x) _mm512_set1_ps(x)
#define MULTIPLY(x, y) _mm512_mul_ps(x, y)
#define MULTADD(x, y, z) _mm512_fmadd_ps(x, y, z)
+ #define ADD(x, y) _mm512_add_ps(x, y)
+ #define ZEROS() _mm512_setzero_ps()
#elif defined(__AVX2__)
#include <immintrin.h>
#define SIMD_WIDTH 8
#define MULTIPLY(x, y) _mm256_mul_ps(x, y)
#define MULTADD(x, y, z) _mm256_fmadd_ps(x, y, z)
#define ADD(x, y) _mm256_add_ps(x, y)
+ #define ZEROS() _mm256_setzero_ps()
#elif defined(__ARM_NEON) || defined(__ARM_NEON__)
#include <arm_neon.h>
#define SIMD_WIDTH 4
#endif
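The remaining macro definitions in the ARM NEON branch are collapsed in this view. As a minimal sketch of what the corresponding ADD and ZEROS additions would presumably look like on NEON (an assumption, not part of this diff), the standard intrinsics map directly:

#define ADD(x, y) vaddq_f32(x, y)     // NEON add of two float32x4_t vectors
#define ZEROS()   vdupq_n_f32(0.0f)   // broadcast 0.0f into all four lanes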
- inline float horizontal_sum_avx(__m256 vec) {
- // horizontal add: pairwise-sum the 8 floats down to 4 results
+ inline float horizontal_sum(__m256 vec) {
__m256 sum1 = _mm256_hadd_ps(vec, vec);
-
- // horizontal add again: pairwise-sum the 4 results down to 2
__m256 sum2 = _mm256_hadd_ps(sum1, sum1);
-
- // extract the low and high 128-bit halves and add them
__m128 sum128 = _mm_add_ps(_mm256_extractf128_ps(sum2, 0),
_mm256_extractf128_ps(sum2, 1));
-
- // extract the final result from the SSE register
float result;
_mm_store_ss(&result, sum128);
return result;
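Since the helper is renamed from horizontal_sum_avx to the ISA-neutral horizontal_sum, a NEON overload is the natural counterpart; a minimal sketch (assumed here, not part of this diff) can rely on the AArch64 across-lanes reduction:

inline float horizontal_sum(float32x4_t vec) {
    // vaddvq_f32 adds all four lanes and returns the scalar sum (AArch64 only)
    return vaddvq_f32(vec);
}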
@@ -81,7 +77,6 @@ static void rwkv_wkv_v7_impl(struct ggml_tensor * result, const struct ggml_tens
size_t t_offset = t * t_stride;

float * state_in = (t == 0) ? state : state_out;
- // transpose_square_inplace(state_in, C/H);
for (size_t h = ith; h < H; h += nth) {
size_t h_offset = h * h_stride;
size_t t_h_offset = t_offset + h_offset;
@@ -94,14 +89,24 @@ static void rwkv_wkv_v7_impl(struct ggml_tensor * result, const struct ggml_tens
memset(&result_data[t_h_offset], 0, h_stride * sizeof(float));
}

+ // auto sa_vec = ZEROS();
+ // for (size_t j = 0; j < C / H; j += SIMD_WIDTH) {
+ //     sa_vec = ADD(sa_vec, MULTIPLY(
+ //         LOAD(&a[t_h_offset + j]),
+ //         LOAD(&state_in[h_2d_i_offset + j])
+ //     )
+ //     );
+ // }
+ // float sa = horizontal_sum(sa_vec);
float sa = .0;
for (size_t j = 0; j < C / H; j++) {
sa += a[t_h_offset + j] * state_in[h_2d_i_offset + j];
}
+
auto v_vec = SET1(v[t_h_i_offset]);
- auto sa_vec = SET1(sa);
+ auto sa_vec = SET1(sa);

- auto sum = _mm256_setzero_ps();
+ auto sum = ZEROS();
for (size_t j = 0; j < C / H; j += SIMD_WIDTH) {
size_t t_h_j_offset = t_h_offset + j;
size_t h_2d_i_j_offset = h_2d_i_offset + j;
@@ -110,19 +115,23 @@ static void rwkv_wkv_v7_impl(struct ggml_tensor * result, const struct ggml_tens
auto k_val = LOAD(&k[t_h_j_offset]);
auto b_val = LOAD(&b[t_h_j_offset]);
auto prev_state_val = LOAD(&state_in[h_2d_i_j_offset]);
+
// auto kv_val = v_val * k_val;
auto kv_val = MULTIPLY(v_vec, k_val);
+
// state_out[h_2d_i_j_offset] = prev_state_val * w_val + kv_val + sa * b_val;
auto sab_val = MULTIPLY(sa_vec, b_val);
auto state_out_val = MULTADD(prev_state_val, w_val, kv_val);
state_out_val = ADD(state_out_val, sab_val);
STORE(&state_out[h_2d_i_j_offset], state_out_val);
+
// result_data[t_h_i_offset] += state_out[h_2d_i_j_offset] * r_val;
auto result = MULTIPLY(state_out_val, r_val);
+
// auto sum = LOAD(&result_data[t_h_i_offset]);
sum = ADD(sum, result);
}
- result_data[t_h_i_offset] = horizontal_sum_avx(sum);
+ result_data[t_h_i_offset] = horizontal_sum(sum);
}

}
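Putting the commented-out scalar expressions and the vectorized loop together, the per-head update this kernel computes for one token can be sketched in plain C++ roughly as follows; the flat names (state, r, w, k, v, a, b, out) are simplifications of the offset arithmetic above, not identifiers from this file:

// Scalar sketch of one head's WKV7 step (an assumed reading of the kernel above).
// n = C / H; state is an n x n row-major matrix; r, w, k, v, a, b, out are length-n vectors.
static void wkv7_head_step(int n, float * state, const float * r, const float * w,
                           const float * k, const float * v, const float * a,
                           const float * b, float * out) {
    for (int i = 0; i < n; i++) {
        // sa = dot(a, state[i, :]) -- the scalar loop kept above the SIMD loop
        float sa = 0.0f;
        for (int j = 0; j < n; j++) {
            sa += a[j] * state[i * n + j];
        }
        // state[i, j] = state[i, j] * w[j] + k[j] * v[i] + sa * b[j]; out[i] = dot(state[i, :], r)
        float sum = 0.0f;
        for (int j = 0; j < n; j++) {
            float s = state[i * n + j] * w[j] + k[j] * v[i] + sa * b[j];
            state[i * n + j] = s;
            sum += s * r[j];
        }
        out[i] = sum;
    }
}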