@@ -388,6 +388,103 @@ static struct ggml_tensor * rwkv_att_v6(
    return ggml_mul_mat(ctx, layer.att_output, x);
}

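+// RWKV v7 time mixing (attention). v_first is shared across all layers of one
+// graph build: the first layer stores its value vectors there, and later layers
+// blend their own values toward it (see the v_first branch below).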
+static struct ggml_tensor * rwkv_att_v7(
+    struct ggml_context * ctx,
+    struct ggml_tensor * x,
+    struct ggml_tensor *& v_first,
+    struct rwkv_layer layer,
+    struct rwkv_layer_state & state,
+    const int64_t head_count,
+    const int64_t head_size
+) {
+    size_t n_embed = x->ne[0];
+    size_t sequence_length = x->ne[1];
+
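+    // Layer norm + token shift: x_prev holds each token's predecessor, with the
+    // first slot coming from state.att_xx; the last token of x is saved back into
+    // state.att_xx for the next chunk.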
+    struct ggml_tensor * x_prev;
+    rwkv_carry_x(ctx, layer.ln1_weight, layer.ln1_bias, x, x_prev, state.att_xx);
+    state.att_xx = ggml_view_1d(ctx, x, n_embed, n_embed * (sequence_length - 1) * sizeof(float));
+
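+    // Broadcast the shift delta to six planes and interpolate x toward x_prev with
+    // the learned per-channel mixing coefficients for r, w, k, v, a and g.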
+    // sx = x_prev - x
+    struct ggml_tensor * sx = ggml_sub(ctx, x_prev, x);
+    struct ggml_tensor * dummy = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, n_embed, sequence_length, 6);
+    sx = ggml_repeat(ctx, sx, dummy);
+    struct ggml_tensor * xxx = ggml_add(ctx, ggml_mul(ctx, sx, layer.att_x_rwkvag), x);
+
+    struct ggml_tensor * xr = ggml_view_2d(ctx, xxx, n_embed, sequence_length, xxx->nb[1], 0);
+    struct ggml_tensor * xw = ggml_view_2d(ctx, xxx, n_embed, sequence_length, xxx->nb[1], n_embed * sequence_length * sizeof(float));
+    struct ggml_tensor * xk = ggml_view_2d(ctx, xxx, n_embed, sequence_length, xxx->nb[1], n_embed * sequence_length * 2 * sizeof(float));
+    struct ggml_tensor * xv = ggml_view_2d(ctx, xxx, n_embed, sequence_length, xxx->nb[1], n_embed * sequence_length * 3 * sizeof(float));
+    struct ggml_tensor * xa = ggml_view_2d(ctx, xxx, n_embed, sequence_length, xxx->nb[1], n_embed * sequence_length * 4 * sizeof(float));
+    struct ggml_tensor * xg = ggml_view_2d(ctx, xxx, n_embed, sequence_length, xxx->nb[1], n_embed * sequence_length * 5 * sizeof(float));
+
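+    // Receptance r (viewed per head), output gate g through a low-rank sigmoid
+    // bottleneck, and the key-mixing gate a = sigmoid(a0 + a2 * (a1 * xa)).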
+    struct ggml_tensor * r = ggml_reshape_3d(ctx, ggml_mul_mat(ctx, layer.att_receptance, xr), head_size, head_count, sequence_length);
+    struct ggml_tensor * g = ggml_mul_mat(ctx, layer.att_g2, ggml_sigmoid(ctx, ggml_mul_mat(ctx, layer.att_g1, xg)));
+    struct ggml_tensor * a = ggml_sigmoid(ctx,
+        ggml_add(
+            ctx,
+            ggml_mul_mat(ctx, layer.att_a2, ggml_mul_mat(ctx, layer.att_a1, xa)),
+            layer.att_a0
+        )
+    );
+
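+    // Per-channel decay: w = exp(-0.606531 * sigmoid(w0 + w2 * tanh(w1 * xw))).
+    // 0.606531 ~= e^-0.5, so each decay factor stays roughly within (0.545, 1).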
+    struct ggml_tensor * w = ggml_add(
+        ctx,
+        ggml_mul_mat(ctx, layer.att_w2, ggml_tanh(ctx, ggml_mul_mat(ctx, layer.att_w1, xw))),
+        layer.att_w0
+    );
+    w = ggml_exp(ctx, ggml_scale(ctx, ggml_cast(ctx, ggml_sigmoid(ctx, w), GGML_TYPE_F32), -0.606531));
+
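+    // Keys: kk is the per-head L2-normalized k * k_k, fed to the WKV kernel below
+    // as -kk and kk * a; k itself is shifted by (a - 1) * (k * k_a).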
+    struct ggml_tensor * k = ggml_mul_mat(ctx, layer.att_key, xk);
+    struct ggml_tensor * kk = ggml_reshape_3d(ctx, ggml_mul(ctx, k, layer.att_k_k), head_size, head_count, sequence_length);
+    kk = rwkv_l2norm(ctx, kk);
+    struct ggml_tensor * ka = ggml_mul(ctx, k, layer.att_k_a);
+    k = ggml_add(ctx, k, ggml_sub(ctx, ggml_mul(ctx, a, ka), ka));
+
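+    // Value residual: the first layer keeps its value vectors in v_first; every
+    // later layer blends v toward v_first with a learned sigmoid gate.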
+    struct ggml_tensor * v = ggml_mul_mat(ctx, layer.att_value, xv);
+    if (v_first == NULL) {
+        v_first = v;
+    } else {
+        v = ggml_add(ctx, v, ggml_mul(ctx,
+            ggml_sub(ctx, v_first, v),
+            ggml_sigmoid(ctx,
+                ggml_add(ctx,
+                    ggml_mul_mat(ctx, layer.att_v2, ggml_mul_mat(ctx, layer.att_v1, xv)),
+                    layer.att_v0
+                )
+            )
+        ));
+    }
+
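+    // Reshape everything to (head_size, head_count, tokens) for the per-head kernel.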
+    w = ggml_reshape_3d(ctx, w, head_size, head_count, sequence_length);
+    k = ggml_reshape_3d(ctx, k, head_size, head_count, sequence_length);
+    v = ggml_reshape_3d(ctx, v, head_size, head_count, sequence_length);
+    a = ggml_reshape_3d(ctx, a, head_size, head_count, sequence_length);
+
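+    // Run the v7 WKV recurrence; wkv_out packs the per-token outputs first and the
+    // updated per-head state right after them.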
+    struct ggml_tensor * wkv_out = rwkv_wkv_v7(ctx, state.att_heads, r, w, k, v, ggml_neg(ctx, kk), ggml_mul(ctx, kk, a));
+    x = ggml_view_1d(ctx, wkv_out, n_embed * sequence_length, 0);
+
+    state.att_heads = ggml_view_1d(ctx, wkv_out, n_embed * head_size, n_embed * sequence_length * sizeof(float));
+
+    // Group norm with head_count groups.
+    x = ggml_reshape_3d(ctx, x, head_size, head_count, sequence_length);
+    x = ggml_norm(ctx, x, 64e-5f);
+    // Convert back to a regular vector.
+    x = ggml_reshape_2d(ctx, x, n_embed, sequence_length);
+    x = ggml_add(ctx, ggml_mul(ctx, x, layer.att_ln_x_weight), layer.att_ln_x_bias);
+
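+    // Per-head bonus term: scale v by sum(r * k * r_k) over each head's channels,
+    // flatten back to (n_embed, tokens), and add it to the normalized output.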
+    x = ggml_add(ctx, x,
+        ggml_reshape_2d(ctx,
+            ggml_mul(ctx, v, ggml_sum_rows(ctx, ggml_mul(ctx, ggml_mul(ctx, k, r), layer.att_r_k))),
+            n_embed, sequence_length
+        )
+    );
+
+    x = ggml_mul(ctx, x, g);
+
+    return ggml_mul_mat(ctx, layer.att_output, x);
+}
+

static struct ggml_tensor * rwkv_ffn_v4_v5(struct ggml_context * ctx, struct ggml_tensor * x, struct rwkv_layer layer, struct rwkv_layer_state & state) {
    struct ggml_tensor * x_prev;
    rwkv_carry_x(ctx, layer.ln2_weight, layer.ln2_bias, x, x_prev, state.ffn_xx, true);
@@ -437,6 +534,18 @@ static struct ggml_tensor * rwkv_ffn_v6(struct ggml_context * ctx, struct ggml_t
    return ggml_mul(ctx, r, ggml_mul_mat(ctx, layer.ffn_value, k));
}

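+// RWKV v7 channel mixing (FFN): a squared-ReLU two-layer MLP. Unlike v4-v6 there
+// is no receptance gate on the output.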
+static struct ggml_tensor * rwkv_ffn_v7(struct ggml_context * ctx, struct ggml_tensor * x, struct rwkv_layer layer, struct rwkv_layer_state & state) {
+    struct ggml_tensor * x_prev;
+    rwkv_carry_x(ctx, layer.ln2_weight, layer.ln2_bias, x, x_prev, state.ffn_xx, true);
+    x_prev = ggml_sub(ctx, x_prev, x);
+
+    struct ggml_tensor * xk = ggml_add(ctx, ggml_mul(ctx, x_prev, layer.ffn_x_k), x);
+
+    struct ggml_tensor * k = ggml_sqr(ctx, ggml_relu(ctx, ggml_mul_mat(ctx, layer.ffn_key, xk)));
+
+    return ggml_mul_mat(ctx, layer.ffn_value, k);
+}
+

static void rwkv_create_input_and_output_views(
    struct ggml_context * ctx,
    struct rwkv_layer_state * inputs,
@@ -543,6 +652,9 @@ static bool rwkv_build_serial_graph(struct rwkv_model & model, struct rwkv_compu
    ggml_set_name(output, "state.out");
    ggml_set_input(graph.tokens);

+    // For v7.
+    struct ggml_tensor * v_first = NULL;
+
    // x = self.w.emb.weight[token]
    struct ggml_tensor * x = ggml_get_rows(ctx, model.emb, graph.tokens);

@@ -556,7 +668,8 @@ static bool rwkv_build_serial_graph(struct rwkv_model & model, struct rwkv_compu

        switch (model.arch_version_major) {
            case 7:
-
+                x = ggml_add(ctx, x, rwkv_att_v7(ctx, x, v_first, layer, state, model.head_count, model.head_size));
+                x = ggml_add(ctx, x, rwkv_ffn_v7(ctx, x, layer, state));
                break;
            case 6:
                x = ggml_add(ctx, x, rwkv_att_v6(ctx, x, layer, state, model.head_count, model.head_size));
@@ -671,6 +784,9 @@ static bool rwkv_build_sequential_graph(struct rwkv_model & model, struct rwkv_c
    ggml_set_name(output, "state.out");
    ggml_set_input(graph.tokens);

+    // For v7.
+    struct ggml_tensor * v_first = NULL;
+
    // x = self.w.emb.weight[token]
    struct ggml_tensor * x = ggml_get_rows(ctx, model.emb, graph.tokens);

@@ -684,7 +800,7 @@ static bool rwkv_build_sequential_graph(struct rwkv_model & model, struct rwkv_c

        switch (model.arch_version_major) {
            case 7:
-
+                x = ggml_add(ctx, x, rwkv_att_v7(ctx, x, v_first, layer, state, model.head_count, model.head_size));
                break;
            case 6:
                x = ggml_add(ctx, x, rwkv_att_v6(ctx, x, layer, state, model.head_count, model.head_size));
@@ -703,7 +819,7 @@ static bool rwkv_build_sequential_graph(struct rwkv_model & model, struct rwkv_c
        // TODO Can we skip ffn for all but the last token, the same way we skip unembedding?
        switch (model.arch_version_major) {
            case 7:
-
+                x = ggml_add(ctx, x, rwkv_ffn_v7(ctx, x, layer, state));
                break;
            case 6:
                x = ggml_add(ctx, x, rwkv_ffn_v6(ctx, x, layer, state));