Commit 536d77e

RWKV v7
Signed-off-by: Molly Sophia <mollysophia379@gmail.com>
1 parent: b597b6e

File tree: 6 files changed, +260 -82 lines changed

python/convert_pytorch_to_ggml.py

Lines changed: 1 addition & 0 deletions

The v7 checkpoints introduce the per-layer bias vectors w0, a0 and v0 (the decay, in-context learning rate, and value-residual gates used in rwkv_att_v7 below); their name suffixes join the list of small parameter tensors the converter passes through tensor.half():

@@ -129,6 +129,7 @@ def write_state_dict(state_dict: Dict[str, torch.Tensor], dest_path: str, data_t
             '.time_',
             '.k_k', '.k_a', '.r_k',
             '.x_rwkvag', '.x_k',
+            '.w0', '.a0', '.v0',
         ]
     ):
         tensor = tensor.half()

rwkv.cpp

Lines changed: 6 additions & 1 deletion

@@ -107,7 +107,12 @@ struct rwkv_context * rwkv_init_from_file(const char * file_path, const uint32_t
     ggml_backend_cpu_set_n_threads(cpu_backend, n_threads);
     ctx->model->backends.push_back(cpu_backend);

-    RWKV_ENSURE_OR_NULL(rwkv_load_model_from_file(file_path, *ctx->model, n_gpu_layers));
+    int ngl = n_gpu_layers;
+    if (ctx->model->backends.size() == 1) {
+        ngl = 0;
+    }
+
+    RWKV_ENSURE_OR_NULL(rwkv_load_model_from_file(file_path, *ctx->model, ngl));

    RWKV_ENSURE_OR_NULL(rwkv_measure_and_build_serial_context(*ctx->model, ctx->serial_graph));
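Note on the guard: the CPU backend is pushed unconditionally just above, so a backends vector of size 1 means no GPU backend was initialized and a non-zero offload request cannot be honored. A hedged usage sketch of the resulting behavior (the model path and counts are placeholders, and this assumes rwkv_init_from_file takes n_gpu_layers as its third argument, as the truncated signature in the hunk header suggests):

    // On a CPU-only build, the guard above resets the effective GPU layer
    // count to 0, so this call loads every layer on the CPU even though
    // 32 offloaded layers were requested.
    struct rwkv_context * ctx = rwkv_init_from_file("model.bin", 4 /* n_threads */, 32 /* n_gpu_layers */);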
rwkv_graph.inc

Lines changed: 119 additions & 3 deletions

@@ -388,6 +388,103 @@ static struct ggml_tensor * rwkv_att_v6(
     return ggml_mul_mat(ctx, layer.att_output, x);
 }

+static struct ggml_tensor * rwkv_att_v7(
+    struct ggml_context * ctx,
+    struct ggml_tensor * x,
+    struct ggml_tensor * & v_first,
+    struct rwkv_layer layer,
+    struct rwkv_layer_state & state,
+    const int64_t head_count,
+    const int64_t head_size
+) {
+    size_t n_embed = x->ne[0];
+    size_t sequence_length = x->ne[1];
+
+    struct ggml_tensor * x_prev;
+    rwkv_carry_x(ctx, layer.ln1_weight, layer.ln1_bias, x, x_prev, state.att_xx);
+    state.att_xx = ggml_view_1d(ctx, x, n_embed, n_embed * (sequence_length - 1) * sizeof(float));
+
+    // sx = x_prev - x
+    struct ggml_tensor * sx = ggml_sub(ctx, x_prev, x);
+    struct ggml_tensor * dummy = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, n_embed, sequence_length, 6);
+    sx = ggml_repeat(ctx, sx, dummy);
+    struct ggml_tensor * xxx = ggml_add(ctx, ggml_mul(ctx, sx, layer.att_x_rwkvag), x);
+
+    struct ggml_tensor * xr = ggml_view_2d(ctx, xxx, n_embed, sequence_length, xxx->nb[1], 0);
+    struct ggml_tensor * xw = ggml_view_2d(ctx, xxx, n_embed, sequence_length, xxx->nb[1], n_embed * sequence_length * sizeof(float));
+    struct ggml_tensor * xk = ggml_view_2d(ctx, xxx, n_embed, sequence_length, xxx->nb[1], n_embed * sequence_length * 2 * sizeof(float));
+    struct ggml_tensor * xv = ggml_view_2d(ctx, xxx, n_embed, sequence_length, xxx->nb[1], n_embed * sequence_length * 3 * sizeof(float));
+    struct ggml_tensor * xa = ggml_view_2d(ctx, xxx, n_embed, sequence_length, xxx->nb[1], n_embed * sequence_length * 4 * sizeof(float));
+    struct ggml_tensor * xg = ggml_view_2d(ctx, xxx, n_embed, sequence_length, xxx->nb[1], n_embed * sequence_length * 5 * sizeof(float));
+
+    struct ggml_tensor * r = ggml_reshape_3d(ctx, ggml_mul_mat(ctx, layer.att_receptance, xr), head_size, head_count, sequence_length);
+    struct ggml_tensor * g = ggml_mul_mat(ctx, layer.att_g2, ggml_sigmoid(ctx, ggml_mul_mat(ctx, layer.att_g1, xg)));
+    struct ggml_tensor * a = ggml_sigmoid(ctx,
+        ggml_add(
+            ctx,
+            ggml_mul_mat(ctx, layer.att_a2, ggml_mul_mat(ctx, layer.att_a1, xa)),
+            layer.att_a0
+        )
+    );
+
+    struct ggml_tensor * w = ggml_add(
+        ctx,
+        ggml_mul_mat(ctx, layer.att_w2, ggml_tanh(ctx, ggml_mul_mat(ctx, layer.att_w1, xw))),
+        layer.att_w0
+    );
+    w = ggml_exp(ctx, ggml_scale(ctx, ggml_cast(ctx, ggml_sigmoid(ctx, w), GGML_TYPE_F32), -0.606531));
+
+    struct ggml_tensor * k = ggml_mul_mat(ctx, layer.att_key, xk);
+    struct ggml_tensor * kk = ggml_reshape_3d(ctx, ggml_mul(ctx, k, layer.att_k_k), head_size, head_count, sequence_length);
+    kk = rwkv_l2norm(ctx, kk);
+    struct ggml_tensor * ka = ggml_mul(ctx, k, layer.att_k_a);
+    k = ggml_add(ctx, k, ggml_sub(ctx, ggml_mul(ctx, a, ka), ka));
+
+    struct ggml_tensor * v = ggml_mul_mat(ctx, layer.att_value, xv);
+    if (v_first == NULL) {
+        v_first = v;
+    } else {
+        v = ggml_add(ctx, v, ggml_mul(ctx,
+            ggml_sub(ctx, v_first, v),
+            ggml_sigmoid(ctx,
+                ggml_add(ctx,
+                    ggml_mul_mat(ctx, layer.att_v2, ggml_mul_mat(ctx, layer.att_v1, xv)),
+                    layer.att_v0
+                )
+            )
+        ));
+    }
+
+    w = ggml_reshape_3d(ctx, w, head_size, head_count, sequence_length);
+    k = ggml_reshape_3d(ctx, k, head_size, head_count, sequence_length);
+    v = ggml_reshape_3d(ctx, v, head_size, head_count, sequence_length);
+    a = ggml_reshape_3d(ctx, a, head_size, head_count, sequence_length);
+
+    struct ggml_tensor * wkv_out = rwkv_wkv_v7(ctx, state.att_heads, r, w, k, v, ggml_neg(ctx, kk), ggml_mul(ctx, kk, a));
+    x = ggml_view_1d(ctx, wkv_out, n_embed * sequence_length, 0);
+
+    state.att_heads = ggml_view_1d(ctx, wkv_out, n_embed * head_size, n_embed * sequence_length * sizeof(float));
+
+    // Group norm with head_count groups.
+    x = ggml_reshape_3d(ctx, x, head_size, head_count, sequence_length);
+    x = ggml_norm(ctx, x, 64e-5f);
+    // Convert back to a regular vector.
+    x = ggml_reshape_2d(ctx, x, n_embed, sequence_length);
+    x = ggml_add(ctx, ggml_mul(ctx, x, layer.att_ln_x_weight), layer.att_ln_x_bias);
+
+    x = ggml_add(ctx, x,
+        ggml_reshape_2d(ctx,
+            ggml_mul(ctx, v, ggml_sum_rows(ctx, ggml_mul(ctx, ggml_mul(ctx, k, r), layer.att_r_k))),
+            n_embed, sequence_length
+        )
+    );
+
+    x = ggml_mul(ctx, x, g);
+
+    return ggml_mul_mat(ctx, layer.att_output, x);
+}
+
 static struct ggml_tensor * rwkv_ffn_v4_v5(struct ggml_context * ctx, struct ggml_tensor * x, struct rwkv_layer layer, struct rwkv_layer_state & state) {
     struct ggml_tensor * x_prev;
     rwkv_carry_x(ctx, layer.ln2_weight, layer.ln2_bias, x, x_prev, state.ffn_xx, true);

@@ -437,6 +534,18 @@ static struct ggml_tensor * rwkv_ffn_v6(struct ggml_context * ctx, struct ggml_t
     return ggml_mul(ctx, r, ggml_mul_mat(ctx, layer.ffn_value, k));
 }

+static struct ggml_tensor * rwkv_ffn_v7(struct ggml_context * ctx, struct ggml_tensor * x, struct rwkv_layer layer, struct rwkv_layer_state & state) {
+    struct ggml_tensor * x_prev;
+    rwkv_carry_x(ctx, layer.ln2_weight, layer.ln2_bias, x, x_prev, state.ffn_xx, true);
+    x_prev = ggml_sub(ctx, x_prev, x);
+
+    struct ggml_tensor * xk = ggml_add(ctx, ggml_mul(ctx, x_prev, layer.ffn_x_k), x);
+
+    struct ggml_tensor * k = ggml_sqr(ctx, ggml_relu(ctx, ggml_mul_mat(ctx, layer.ffn_key, xk)));
+
+    return ggml_mul_mat(ctx, layer.ffn_value, k);
+}
+
 static void rwkv_create_input_and_output_views(
     struct ggml_context * ctx,
     struct rwkv_layer_state * inputs,

@@ -543,6 +652,9 @@ static bool rwkv_build_serial_graph(struct rwkv_model & model, struct rwkv_compu
     ggml_set_name(output, "state.out");
     ggml_set_input(graph.tokens);

+    // For v7.
+    struct ggml_tensor * v_first = NULL;
+
     // x = self.w.emb.weight[token]
     struct ggml_tensor * x = ggml_get_rows(ctx, model.emb, graph.tokens);

@@ -556,7 +668,8 @@ static bool rwkv_build_serial_graph(struct rwkv_model & model, struct rwkv_compu

         switch (model.arch_version_major) {
             case 7:
-
+                x = ggml_add(ctx, x, rwkv_att_v7(ctx, x, v_first, layer, state, model.head_count, model.head_size));
+                x = ggml_add(ctx, x, rwkv_ffn_v7(ctx, x, layer, state));
                 break;
             case 6:
                 x = ggml_add(ctx, x, rwkv_att_v6(ctx, x, layer, state, model.head_count, model.head_size));

@@ -671,6 +784,9 @@ static bool rwkv_build_sequential_graph(struct rwkv_model & model, struct rwkv_c
     ggml_set_name(output, "state.out");
     ggml_set_input(graph.tokens);

+    // For v7.
+    struct ggml_tensor * v_first = NULL;
+
     // x = self.w.emb.weight[token]
     struct ggml_tensor * x = ggml_get_rows(ctx, model.emb, graph.tokens);

@@ -684,7 +800,7 @@ static bool rwkv_build_sequential_graph(struct rwkv_model & model, struct rwkv_c

         switch (model.arch_version_major) {
             case 7:
-
+                x = ggml_add(ctx, x, rwkv_att_v7(ctx, x, v_first, layer, state, model.head_count, model.head_size));
                 break;
             case 6:
                 x = ggml_add(ctx, x, rwkv_att_v6(ctx, x, layer, state, model.head_count, model.head_size));

@@ -703,7 +819,7 @@ static bool rwkv_build_sequential_graph(struct rwkv_model & model, struct rwkv_c
         // TODO Can we skip ffn for all but the last token, the same way we skip unembedding?
         switch (model.arch_version_major) {
             case 7:
-
+                x = ggml_add(ctx, x, rwkv_ffn_v7(ctx, x, layer, state));
                 break;
             case 6:
                 x = ggml_add(ctx, x, rwkv_ffn_v6(ctx, x, layer, state));
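Review note: taken together, these hunks add the v7 time-mixing block (rwkv_att_v7), the v7 channel-mixing block (rwkv_ffn_v7), and the v_first plumbing through both graph builders. v_first implements the v7 value residual: the first layer records its value activations, and every later layer nudges its own v toward v_first under a learned sigmoid gate. As a sketch, the per-head recurrence that the rwkv_wkv_v7 call is expected to realize (the kernel itself lives in rwkv_operators_wkv_v7.inc, which this excerpt does not show; the form below follows the published RWKV-7 formulation, with \hat{\kappa}_t the L2-normalized key from rwkv_l2norm and a_t the sigmoid gate computed above):

    S_t = S_{t-1} \bigl( \mathrm{diag}(w_t) - \hat{\kappa}_t (\hat{\kappa}_t \odot a_t)^\top \bigr) + v_t k_t^\top,
    \qquad y_t = S_t r_t

This is why rwkv_wkv_v7 is passed ggml_neg(ctx, kk) and ggml_mul(ctx, kk, a): they are the two factors of the rank-one term in the state transition. The decay line encodes w_t = \exp(-e^{-1/2} \, \sigma(\tilde{w}_t)); the constant 0.606531 is e^{-1/2}, so each channel's decay stays roughly in (0.545, 1). The v7 FFN drops the v6 receptance gate and reduces to a squared-ReLU MLP:

    y = W_{\text{value}} \, \bigl( \mathrm{relu}(W_{\text{key}} \, x_k) \bigr)^{2}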

rwkv_model_loading.inc

Lines changed: 3 additions & 1 deletion

@@ -170,14 +170,16 @@ static bool rwkv_set_params(struct rwkv_model & model, F callback, const uint32_
             }

             RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.r_k"), buffer), layer.att_r_k, offload_layer));
-            RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.k_k"), buffer), layer.att_k_k, offload_layer));
+            // Somehow offloading this tensor makes the model output NaN after several iterations.
+            RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.k_k"), buffer), layer.att_k_k, false));
             RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.k_a"), buffer), layer.att_k_a, offload_layer));

             RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.key.weight"), buffer), layer.att_key, offload_layer));
             RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.value.weight"), buffer), layer.att_value, offload_layer));
             RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.receptance.weight"), buffer), layer.att_receptance, offload_layer));
             RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.output.weight"), buffer), layer.att_output, offload_layer));

+            // The ln_x parameters below are likewise kept on the CPU.
             RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.ln_x.weight"), buffer), layer.att_ln_x_weight, false));
             RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.ln_x.bias"), buffer), layer.att_ln_x_bias, false));
             break;

rwkv_operators.inc

Lines changed: 50 additions & 1 deletion

@@ -1,4 +1,4 @@
-// #include "rwkv_operators_wkv_v5.inc"
+#include "rwkv_operators_wkv_v7.inc"

 #define SUPPRESS_UNUSED_WARNINGS_IN_CUSTOM_OP() { (void) ith; (void) nth; (void) userdata; }

@@ -36,11 +36,60 @@ static void rwkv_max_impl(
     SUPPRESS_UNUSED_WARNINGS_IN_CUSTOM_OP();
 }

+// TODO: Upstream to ggml.
+static void rwkv_l2norm_impl(
+    struct ggml_tensor * dst,
+    const struct ggml_tensor * src0,
+    int ith,
+    int nth,
+    void * userdata
+) {
+    GGML_ASSERT(dst->type == GGML_TYPE_F32);
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(ggml_is_contiguous(dst));
+    GGML_ASSERT(ggml_is_contiguous(src0));
+    GGML_ASSERT(ggml_are_same_shape(src0, dst));
+
+    GGML_TENSOR_UNARY_OP_LOCALS
+
+    const float eps = 1e-12f;
+
+    // TODO: Optimize.
+    for (int64_t i03 = 0; i03 < ne03; i03++) {
+        for (int64_t i02 = 0; i02 < ne02; i02++) {
+            // Rows are split across threads: thread ith handles every nth row.
+            for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
+                const float * x = (const float *) ((const char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
+
+                float sum = 0.0f;
+                for (int64_t i00 = 0; i00 < ne00; i00++) {
+                    const float v = x[i00];
+                    sum += v*v;
+                }
+
+                float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
+
+                const float scale = 1.0f/fmaxf(sqrtf(sum), eps);
+
+                // Write the normalized row into dst.
+                for (int64_t i00 = 0; i00 < ne00; i00++) {
+                    y[i00] = x[i00] * scale;
+                }
+            }
+        }
+    }
+
+    SUPPRESS_UNUSED_WARNINGS_IN_CUSTOM_OP();
+}
+
 // Element-wise max(x, y)
 struct ggml_tensor * rwkv_max(struct ggml_context * ctx, struct ggml_tensor * x, struct ggml_tensor * y) {
     return ggml_map_custom2(ctx, x, y, rwkv_max_impl, 1, NULL);
 }

+// L2-normalizes x along its innermost dimension: y = x / max(||x||_2, 1e-12) per row.
+struct ggml_tensor * rwkv_l2norm(struct ggml_context * ctx, struct ggml_tensor * x) {
+    return ggml_map_custom1(ctx, x, rwkv_l2norm_impl, 1, NULL);
+}
+
 struct ggml_tensor * rwkv_layer_norm(struct ggml_context * ctx, struct ggml_tensor * x, struct ggml_tensor * weight, struct ggml_tensor * bias) {
     // LayerNorm in RWKV is `x = (x - mean(x)) / sqrt(variance(x) + 1e-5) * weight + bias`
     // Looks like ggml_norm does the first part, we only need to apply weight & bias.