From 7286b83d3f1b1df627ca25097d4d67cd604ed484 Mon Sep 17 00:00:00 2001 From: Jared Van Bortel Date: Tue, 6 Feb 2024 17:03:12 -0500 Subject: [PATCH 01/16] BERT WIP --- convert-hf-to-gguf.py | 53 ++++++++ gguf-py/gguf/constants.py | 42 ++++--- gguf-py/gguf/gguf_writer.py | 3 + gguf-py/gguf/tensor_mapping.py | 13 +- llama.cpp | 221 ++++++++++++++++++++++++++++++++- 5 files changed, 310 insertions(+), 22 deletions(-) diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py index 5e343742d2484..971e0fcb2f662 100755 --- a/convert-hf-to-gguf.py +++ b/convert-hf-to-gguf.py @@ -205,6 +205,8 @@ def from_model_architecture(model_architecture): return OrionModel if model_architecture == "InternLM2ForCausalLM": return InternLM2Model + if model_architecture == "BertModel": + return BertModel return Model def _is_model_safetensors(self) -> bool: @@ -258,6 +260,8 @@ def _get_model_architecture(self) -> gguf.MODEL_ARCH: return gguf.MODEL_ARCH.ORION if arch == "InternLM2ForCausalLM": return gguf.MODEL_ARCH.INTERNLM2 + if arch == "BertModel": + return gguf.MODEL_ARCH.BERT raise NotImplementedError(f'Architecture "{arch}" not supported!') @@ -1521,6 +1525,55 @@ def write_tensors(self): self.post_write_tensors(tensor_map, name, data_torch) +class BertModel(Model): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.block_count = self.hparams["num_hidden_layers"] + + def set_gguf_parameters(self): + # TODO(cebtenzzre): merge with parent class + self.gguf_writer.add_name(self.dir_model.name) + self.gguf_writer.add_context_length(self.hparams["max_position_embeddings"]) + self.gguf_writer.add_embedding_length(self.hparams["hidden_size"]) + self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"]) + self.gguf_writer.add_block_count(self.block_count) + self.gguf_writer.add_head_count(self.hparams["num_attention_heads"]) + self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_eps"]) + self.gguf_writer.add_file_type(self.ftype) + + def write_tensors(self): + tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count) + tensors = dict(self.get_tensors()) + for name, data_torch in tensors.items(): + # we are only using BERT for embeddings so we don't need the pooling layer + if name in ("embeddings.position_ids", "pooler.dense.weight", "pooler.dense.bias"): + continue # we don't need these + + # map tensor names + new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias")) + if new_name is None: + print(f"Can not map tensor {name!r}") + sys.exit() + + data = data_torch.squeeze().numpy() + n_dims = len(data.shape) + new_dtype: type[np.floating[Any]] + + if self.ftype == 1 and name.endswith(".weight") and n_dims == 2: + # if f16 desired, convert any float32 2-dim weight tensors to float16 + new_dtype = np.float16 + else: + # if f32 desired, convert any float16 to float32 + new_dtype = np.float32 + + print(f"{new_name}, n_dims = {n_dims}, {data_torch.dtype} --> {new_dtype}") + + if data.dtype != new_dtype: + data = data.astype(new_dtype) + + self.gguf_writer.add_tensor(new_name, data) + + ###### CONVERSION LOGIC ###### diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index ed8e26f83e6a9..4b42e488cfa56 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -60,22 +60,23 @@ class Rope: SCALING_FINETUNED = "{arch}.rope.scaling.finetuned" class Tokenizer: - MODEL = "tokenizer.ggml.model" - LIST = "tokenizer.ggml.tokens" - TOKEN_TYPE = "tokenizer.ggml.token_type" - SCORES = "tokenizer.ggml.scores" - MERGES = 
"tokenizer.ggml.merges" - BOS_ID = "tokenizer.ggml.bos_token_id" - EOS_ID = "tokenizer.ggml.eos_token_id" - UNK_ID = "tokenizer.ggml.unknown_token_id" - SEP_ID = "tokenizer.ggml.seperator_token_id" - PAD_ID = "tokenizer.ggml.padding_token_id" - ADD_BOS = "tokenizer.ggml.add_bos_token" - ADD_EOS = "tokenizer.ggml.add_eos_token" - ADD_PREFIX = "tokenizer.ggml.add_space_prefix" - HF_JSON = "tokenizer.huggingface.json" - RWKV = "tokenizer.rwkv.world" - CHAT_TEMPLATE = "tokenizer.chat_template" + MODEL = "tokenizer.ggml.model" + LIST = "tokenizer.ggml.tokens" + TOKEN_TYPE = "tokenizer.ggml.token_type" + TOKEN_TYPE_COUNT = "tokenizer.ggml.token_type_count" # for BERT-style token types + SCORES = "tokenizer.ggml.scores" + MERGES = "tokenizer.ggml.merges" + BOS_ID = "tokenizer.ggml.bos_token_id" + EOS_ID = "tokenizer.ggml.eos_token_id" + UNK_ID = "tokenizer.ggml.unknown_token_id" + SEP_ID = "tokenizer.ggml.seperator_token_id" + PAD_ID = "tokenizer.ggml.padding_token_id" + ADD_BOS = "tokenizer.ggml.add_bos_token" + ADD_EOS = "tokenizer.ggml.add_eos_token" + ADD_PREFIX = "tokenizer.ggml.add_space_prefix" + HF_JSON = "tokenizer.huggingface.json" + RWKV = "tokenizer.rwkv.world" + CHAT_TEMPLATE = "tokenizer.chat_template" # @@ -121,6 +122,7 @@ class MODEL_TENSOR(IntEnum): ATTN_OUT = auto() ATTN_NORM = auto() ATTN_NORM_2 = auto() + ATTN_OUT_NORM = auto() ATTN_ROT_EMBD = auto() FFN_GATE_INP = auto() FFN_NORM = auto() @@ -133,6 +135,7 @@ class MODEL_TENSOR(IntEnum): FFN_UP_EXP = auto() ATTN_Q_NORM = auto() ATTN_K_NORM = auto() + LAYER_OUT_NORM = auto() MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = { @@ -176,6 +179,7 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.ATTN_ROT_EMBD: "blk.{bid}.attn_rot_embd", MODEL_TENSOR.ATTN_Q_NORM: "blk.{bid}.attn_q_norm", MODEL_TENSOR.ATTN_K_NORM: "blk.{bid}.attn_k_norm", + MODEL_TENSOR.ATTN_OUT_NORM: "blk.{bid}.attn_output_norm", MODEL_TENSOR.FFN_GATE_INP: "blk.{bid}.ffn_gate_inp", MODEL_TENSOR.FFN_NORM: "blk.{bid}.ffn_norm", MODEL_TENSOR.FFN_GATE: "blk.{bid}.ffn_gate", @@ -185,6 +189,7 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.FFN_GATE_EXP: "blk.{bid}.ffn_gate.{xid}", MODEL_TENSOR.FFN_DOWN_EXP: "blk.{bid}.ffn_down.{xid}", MODEL_TENSOR.FFN_UP_EXP: "blk.{bid}.ffn_up.{xid}", + MODEL_TENSOR.LAYER_OUT_NORM: "blk.{bid}.layer_output_norm", } MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { @@ -260,17 +265,18 @@ class MODEL_TENSOR(IntEnum): ], MODEL_ARCH.BERT: [ MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.TOKEN_EMBD_NORM, MODEL_TENSOR.TOKEN_TYPES, MODEL_TENSOR.POS_EMBD, MODEL_TENSOR.OUTPUT_NORM, - MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_OUT_NORM, MODEL_TENSOR.ATTN_Q, MODEL_TENSOR.ATTN_K, MODEL_TENSOR.ATTN_V, MODEL_TENSOR.ATTN_OUT, - MODEL_TENSOR.FFN_NORM, MODEL_TENSOR.FFN_DOWN, MODEL_TENSOR.FFN_UP, + MODEL_TENSOR.LAYER_OUT_NORM, ], MODEL_ARCH.MPT: [ MODEL_TENSOR.TOKEN_EMBD, diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py index 16808196e769d..20d2e5cff9564 100644 --- a/gguf-py/gguf/gguf_writer.py +++ b/gguf-py/gguf/gguf_writer.py @@ -387,6 +387,9 @@ def add_token_merges(self, merges: Sequence[str] | Sequence[bytes] | Sequence[by def add_token_types(self, types: Sequence[TokenType] | Sequence[int]) -> None: self.add_array(Keys.Tokenizer.TOKEN_TYPE, types) + def add_token_type_count(self, value: int) -> None: + self.add_uint32(Keys.Tokenizer.TOKEN_TYPE_COUNT, value) + def add_token_scores(self, scores: Sequence[float]) -> None: self.add_array(Keys.Tokenizer.SCORES, scores) diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index 
4f16d85044693..c7ba1420e0453 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -30,6 +30,7 @@ class TensorNameMap: # Normalization of token embeddings MODEL_TENSOR.TOKEN_EMBD_NORM: ( "word_embeddings_layernorm", # bloom + "embeddings.LayerNorm", # bert ), # Position embeddings @@ -54,7 +55,6 @@ class TensorNameMap: "transformer.ln_f", # gpt2 gpt-j falcon "model.norm", # llama-hf baichuan internlm2 "norm", # llama-pth - "embeddings.LayerNorm", # bert "transformer.norm_f", # mpt "ln_f", # refact bloom qwen gpt2 "language_model.encoder.final_layernorm", # persimmon @@ -79,7 +79,6 @@ class TensorNameMap: "transformer.h.{bid}.ln_mlp", # falcon40b "model.layers.{bid}.input_layernorm", # llama-hf "layers.{bid}.attention_norm", # llama-pth - "encoder.layer.{bid}.attention.output.LayerNorm", # bert "language_model.encoder.layers.{bid}.input_layernorm", # persimmon "model.layers.{bid}.ln1", # yi "h.{bid}.ln_1", # gpt2 @@ -155,6 +154,11 @@ class TensorNameMap: "model.layers.{bid}.attention.wo", # internlm2 ), + # Attention output norm + MODEL_TENSOR.ATTN_OUT_NORM: ( + "encoder.layer.{bid}.attention.output.LayerNorm", # bert + ), + # Rotary embeddings MODEL_TENSOR.ATTN_ROT_EMBD: ( "model.layers.{bid}.self_attn.rotary_emb.inv_freq", # llama-hf @@ -171,7 +175,6 @@ class TensorNameMap: "transformer.blocks.{bid}.norm_2", # mpt "model.layers.{bid}.post_attention_layernorm", # llama-hf "layers.{bid}.ffn_norm", # llama-pth - "encoder.layer.{bid}.output.LayerNorm", # bert "language_model.encoder.layers.{bid}.post_attention_layernorm", # persimmon "model.layers.{bid}.ln2", # yi "h.{bid}.ln_2", # gpt2 @@ -266,6 +269,10 @@ class TensorNameMap: MODEL_TENSOR.ROPE_FREQS: ( "language_model.encoder.layers.{bid}.self_attention.rotary_emb.inv_freq", # persimmon ), + + MODEL_TENSOR.LAYER_OUT_NORM: ( + "encoder.layer.{bid}.output.LayerNorm", # bert + ) } mapping: dict[str, tuple[MODEL_TENSOR, str]] diff --git a/llama.cpp b/llama.cpp index 65e399adca60e..575b60f563d38 100644 --- a/llama.cpp +++ b/llama.cpp @@ -196,6 +196,7 @@ enum llm_arch { LLM_ARCH_STARCODER, LLM_ARCH_PERSIMMON, LLM_ARCH_REFACT, + LLM_ARCH_BERT, LLM_ARCH_BLOOM, LLM_ARCH_STABLELM, LLM_ARCH_QWEN, @@ -219,6 +220,7 @@ static std::map LLM_ARCH_NAMES = { { LLM_ARCH_STARCODER, "starcoder" }, { LLM_ARCH_PERSIMMON, "persimmon" }, { LLM_ARCH_REFACT, "refact" }, + { LLM_ARCH_BERT, "bert" }, { LLM_ARCH_BLOOM, "bloom" }, { LLM_ARCH_STABLELM, "stablelm" }, { LLM_ARCH_QWEN, "qwen" }, @@ -271,6 +273,7 @@ enum llm_kv { LLM_KV_TOKENIZER_MODEL, LLM_KV_TOKENIZER_LIST, LLM_KV_TOKENIZER_TOKEN_TYPE, + LLM_KV_TOKENIZER_TOKEN_TYPE_COUNT, LLM_KV_TOKENIZER_SCORES, LLM_KV_TOKENIZER_MERGES, LLM_KV_TOKENIZER_BOS_ID, @@ -326,6 +329,7 @@ static std::map LLM_KV_NAMES = { { LLM_KV_TOKENIZER_MODEL, "tokenizer.ggml.model" }, { LLM_KV_TOKENIZER_LIST, "tokenizer.ggml.tokens" }, { LLM_KV_TOKENIZER_TOKEN_TYPE, "tokenizer.ggml.token_type" }, + { LLM_KV_TOKENIZER_TOKEN_TYPE_COUNT, "tokenizer.ggml.token_type_count" }, { LLM_KV_TOKENIZER_SCORES, "tokenizer.ggml.scores" }, { LLM_KV_TOKENIZER_MERGES, "tokenizer.ggml.merges" }, { LLM_KV_TOKENIZER_BOS_ID, "tokenizer.ggml.bos_token_id" }, @@ -534,6 +538,24 @@ static std::map> LLM_TENSOR_NAMES = { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, }, }, + { + LLM_ARCH_BERT, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" }, + { LLM_TENSOR_TOKEN_TYPES, "token_types" }, + { LLM_TENSOR_POS_EMBD, "position_embd" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, 
"blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_ATTN_OUT_NORM, "blk.%d.attn_output_norm" }, + { LLM_TENSOR_LAYER_OUT_NORM, "blk.%d.layer_output_norm" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + }, + }, { LLM_ARCH_BLOOM, { @@ -1388,6 +1410,9 @@ static llama_state g_state; // available llama models enum e_model { MODEL_UNKNOWN, + MODEL_22M, + MODEL_33M, + MODEL_109M, MODEL_0_5B, MODEL_1B, MODEL_3B, @@ -1428,6 +1453,7 @@ struct llama_hparams { uint32_t n_ff; uint32_t n_expert = 0; uint32_t n_expert_used = 0; + uint32_t n_vocab_type = 0; // for BERT-style token types float f_norm_eps; float f_norm_rms_eps; @@ -1532,6 +1558,8 @@ struct llama_layer { struct ggml_tensor * bqkv; // normalization + struct ggml_tensor * attn_out_norm; + struct ggml_tensor * attn_out_norm_b; struct ggml_tensor * ffn_norm; struct ggml_tensor * ffn_norm_b; @@ -1550,6 +1578,10 @@ struct llama_layer { struct ggml_tensor * ffn_down_b; // b2 struct ggml_tensor * ffn_up_b; // b3 struct ggml_tensor * ffn_act; + + // normalization + struct ggml_tensor * layer_out_norm; + struct ggml_tensor * layer_out_norm_b; }; struct llama_kv_cell { @@ -1667,6 +1699,7 @@ struct llama_model { llama_vocab vocab; struct ggml_tensor * tok_embd; + struct ggml_tensor * type_embd; struct ggml_tensor * pos_embd; struct ggml_tensor * tok_norm; struct ggml_tensor * tok_norm_b; @@ -1791,8 +1824,10 @@ struct llama_context { struct ggml_tensor * inp_tokens; // I32 [n_batch] struct ggml_tensor * inp_embd; // F32 [n_embd, n_batch] struct ggml_tensor * inp_pos; // I32 [n_batch] + struct ggml_tensor * inp_type; // I32 [n_batch] struct ggml_tensor * inp_KQ_mask; // F32 [n_ctx, n_batch] struct ggml_tensor * inp_K_shift; // I32 [n_ctx] + struct ggml_tensor * inp_sum; // F32 [1, n_batch] #ifdef GGML_USE_MPI ggml_mpi_context * ctx_mpi = NULL; @@ -2933,6 +2968,21 @@ static void llm_load_hparams( default: model.type = e_model::MODEL_UNKNOWN; } } break; + case LLM_ARCH_BERT: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + ml.get_key(LLM_KV_TOKEN_TYPE_COUNT, hparams.n_vocab_type); + + switch (hparams.n_embd) { + case 384: // MiniLM + switch (hparams.n_layer) { + case 6: model.type = e_model::MODEL_22M; break; + case 12: model.type = e_model::MODEL_33M; break; + } break; + case 768: // BERT-Base + model.type = e_model::MODEL_109M; break; + } + } break; case LLM_ARCH_BLOOM: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); @@ -3719,6 +3769,45 @@ static bool llm_load_tensors( layer.attn_k_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K_NORM, "bias", i), {64}); } } break; + case LLM_ARCH_BERT: + { + model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); + model.type_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_TYPES, "weight"), {n_vocab_type, n_embd}); + model.pos_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_POS_EMBD, "weight"), {n_embd, hparams.n_ctx_train}); + model.tok_norm = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}); + model.tok_norm_b = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias"), {n_embd}); + + for (int i = 0; i < n_layer; ++i) { + ggml_context * ctx_layer = ctx_for_layer(i); + ggml_context * ctx_split = ctx_for_layer_split(i); + + auto & layer = model.layers[i]; + + layer.attn_out_norm = ml.create_tensor(ctx_layer, 
tn(LLM_TENSOR_ATTN_OUT_NORM, "weight", i), {n_embd}); + layer.attn_out_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT_NORM, "bias", i), {n_embd}); + + layer.layer_out_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_LAYER_OUT_NORM, "weight", i), {n_embd}); + layer.layer_out_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_LAYER_OUT_NORM, "bias", i), {n_embd}); + + layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}); + layer.bq = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}); + + layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}); + layer.bk = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}); + + layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}); + layer.bv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}); + + layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}); + layer.bo = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}); + + layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}); + layer.ffn_up_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}); + + layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}); + layer.ffn_down_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}); + } + } break; case LLM_ARCH_BLOOM: { model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); @@ -5561,6 +5650,113 @@ struct llm_build_context { return gf; } + struct ggml_cgraph * build_bert() { + struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false); + + const int64_t n_embd_head = hparams.n_embd_head_v; + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_embd_head == hparams.n_rot); + + struct ggml_tensor * cur; + struct ggml_tensor * inpL; + + inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, lctx.inp_tokens, lctx.inp_embd, cb); + + // inp_pos - contains the positions + struct ggml_tensor * inp_pos = ggml_view_1d(ctx0, lctx.inp_pos, n_tokens, 0); + + inpL = ggml_add(ctx0, + ggml_get_rows(ctx0, model.type_embd, lctx.inp_type), + inpL); + inpL = ggml_add(ctx0, + ggml_get_rows(ctx0, model.pos_embd, lctx.inp_pos), + inpL); + + inpL = llm_build_norm(ctx0, inpL, hparams, + model.tok_norm, + model.tok_norm_b, + LLM_NORM, cb, -1); + + for (int il = 0; il < n_layer; ++il) { + struct ggml_tensor * cur = inpL; + + // self-attention + { + // compute Q and K + struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur); + Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + struct ggml_tensor * Q = ggml_permute(ctx0, Qcur, 0, 2, 1, 3); + + struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur); + Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head, n_tokens); + struct ggml_tensor * K = ggml_permute(ctx0, Kcur, 0, 2, 1, 3); + + struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur); + Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head, n_tokens); + struct ggml_tensor * V = ggml_permute(ctx0, Vcur, 0, 2, 1, 3); + + struct ggml_tensor *KQ = ggml_mul_mat(ctx0, K, Q); + // KQ = 
soft_max(KQ / sqrt(head width)) + KQ = ggml_soft_max(ctx0, + ggml_scale(ctx0, KQ, 1.0f / sqrt((float)n_embd_head))); + + V = ggml_cont(ctx0, ggml_transpose(ctx0, V)); + struct ggml_tensor *KQV = ggml_mul_mat(ctx0, V, KQ); + KQV = ggml_permute(ctx0, KQV, 0, 2, 1, 3); + + cur = ggml_reshape_2d(ctx0, ggml_cont(ctx0, KQV), n_embd, N); + + // attention output + cur = ggml_add(ctx0, + model.layers[il].bo, + ggml_mul_mat(ctx0, model.layers[il].wo, cur)); + } + + // re-add the layer input + cur = ggml_add(ctx0, cur, inpSA); + + // attention layer norm + cur = llm_build_norm(ctx0, cur, hparams, + model.layers[il].attn_out_norm, + model.layers[il].attn_out_norm_b, + LLM_NORM, cb, il); + + struct ggml_tensor * ffn_inp = cur; + + // feed-forward network + cur = llm_build_ffn(ctx0, cur, + model.layers[il].ffn_up, model.layers[il].ffn_up_b, + NULL, NULL, + model.layers[il].ffn_down, model.layers[il].ffn_down_b, + NULL, + LLM_FFN_GELU, LLM_FFN_SEQ, cb, il); + + // attentions bypass the intermediate layer + cur = ggml_add(ctx0, cur, ffn_inp); + + // output layer norm + cur = llm_build_norm(ctx0, cur, hparams, + model.layers[il].layer_out_norm, + model.layers[il].layer_out_norm_b, + LLM_NORM, cb, il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + // pooling (sum = [L, 1, B]) + cur = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, cur)), lctx.inp_sum); // [E, 1, B] + + ggml_build_forward_expand(gf, cur); + + return gf; + } + struct ggml_cgraph * build_bloom() { struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false); @@ -6835,6 +7031,12 @@ static struct ggml_cgraph * llama_build_graph( ggml_backend_tensor_set(lctx.inp_pos, batch.pos, 0, n_tokens*ggml_element_size(lctx.inp_pos)); } + { + // for embedding models, token type is always zero ("sentence A") + GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_type->buffer)); + memset(lctx.inp_type->data, 0, batch.n_tokens * ggml_element_size(lctx.inp_type)); + } + { const int64_t n_kv = llm.n_kv; const int64_t n_tokens = batch.n_tokens; @@ -6870,6 +7072,15 @@ static struct ggml_cgraph * llama_build_graph( data[i] = lctx.kv_self.cells[i].delta; } } + + { + GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_inp_sum->buffer)); + float * data = (float *) lctx.inp_sum->data; + + for (int i = 0; i < batch.n_tokens; ++i) { + data[i] = 1.0f/float(batch.n_tokens); + } + } } llm.init(); @@ -6899,6 +7110,10 @@ static struct ggml_cgraph * llama_build_graph( { result = llm.build_refact(); } break; + case LLM_ARCH_BERT: + { + result = llm.build_bert(); + } break; case LLM_ARCH_BLOOM: { result = llm.build_bloom(); @@ -10574,7 +10789,7 @@ struct llama_context * llama_new_context_with_model( // graph inputs { ggml_init_params init_params = { - /* .mem_size */ ggml_tensor_overhead()*5, + /* .mem_size */ ggml_tensor_overhead()*7, /* .mem_buffer */ nullptr, /* .no_alloc */ true, }; @@ -10583,14 +10798,18 @@ struct llama_context * llama_new_context_with_model( ctx->inp_tokens = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_I32, cparams.n_batch); ctx->inp_embd = ggml_new_tensor_2d(ctx->ctx_input, GGML_TYPE_F32, hparams.n_embd, cparams.n_batch); ctx->inp_pos = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_I32, cparams.n_batch); + ctx->inp_type = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_I32, cparams.n_batch); ctx->inp_KQ_mask = ggml_new_tensor_2d(ctx->ctx_input, GGML_TYPE_F32, cparams.n_ctx, cparams.n_batch); ctx->inp_K_shift = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_I32, cparams.n_ctx); + ctx->inp_sum = 
ggml_new_tensor_2d(ctx->ctx_input, GGML_TYPE_F32, 1, cparams.n_batch); ggml_set_name(ctx->inp_tokens, "inp_tokens"); ggml_set_name(ctx->inp_embd, "inp_embd"); ggml_set_name(ctx->inp_pos, "inp_pos"); + ggml_set_name(ctx->inp_type, "inp_type"); ggml_set_name(ctx->inp_KQ_mask, "inp_KQ_mask"); ggml_set_name(ctx->inp_K_shift, "inp_K_shift"); + ggml_set_name(ctx->inp_sum, "inp_sum"); ctx->buf_input = ggml_backend_alloc_ctx_tensors_from_buft(ctx->ctx_input, llama_default_buffer_type_cpu(true)); From 0051c82d52de7ce81fb4aa427af4f181e61091cd Mon Sep 17 00:00:00 2001 From: Douglas Hanley Date: Thu, 8 Feb 2024 01:02:10 -0600 Subject: [PATCH 02/16] it runs; tokenization is messed up; pooling is wrong for multi batches --- convert-hf-to-gguf.py | 27 ++++++++ llama.cpp | 144 ++++++++++++++++++++---------------------- 2 files changed, 94 insertions(+), 77 deletions(-) diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py index 88b9b912b9f6e..a952a5282277c 100755 --- a/convert-hf-to-gguf.py +++ b/convert-hf-to-gguf.py @@ -1590,6 +1590,33 @@ def set_gguf_parameters(self): self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_eps"]) self.gguf_writer.add_file_type(self.ftype) + def set_vocab(self): + path = self.dir_model + added_tokens_path = self.dir_model if self.dir_model.exists() else None + + vocab = HfVocab(path, added_tokens_path) + tokens, scores, toktypes = zip(*vocab.all_tokens()) + + assert len(tokens) == vocab.vocab_size + + # for some reason set(toktypes) = {1, 3} so we need to compress it + all_types, toktypes1 = np.unique(toktypes, return_inverse=True) + n_token_types, toktypes1 = len(all_types), toktypes1.tolist() + self.gguf_writer.add_uint32("tokenizer.ggml.token_type_count", n_token_types) + + # convert tokens to SPM style + tokens = [ + (t[2:] if t.startswith(b"##") else b"\xe2\x96\x81" + t) for t in tokens + ] + + self.gguf_writer.add_tokenizer_model("llama") + self.gguf_writer.add_token_list(tokens) + self.gguf_writer.add_token_scores(scores) + self.gguf_writer.add_token_types(toktypes) # ignore types for now (all zero) + + special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens)) + special_vocab.add_to_gguf(self.gguf_writer) + def write_tensors(self): tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count) tensors = dict(self.get_tensors()) diff --git a/llama.cpp b/llama.cpp index 18ff8c547133b..1118a291bcb0e 100644 --- a/llama.cpp +++ b/llama.cpp @@ -359,6 +359,7 @@ struct LLM_KV { enum llm_tensor { LLM_TENSOR_TOKEN_EMBD, LLM_TENSOR_TOKEN_EMBD_NORM, + LLM_TENSOR_TOKEN_TYPES, LLM_TENSOR_POS_EMBD, LLM_TENSOR_OUTPUT, LLM_TENSOR_OUTPUT_NORM, @@ -547,13 +548,12 @@ static std::map> LLM_TENSOR_NAMES = { LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" }, { LLM_TENSOR_TOKEN_TYPES, "token_types" }, { LLM_TENSOR_POS_EMBD, "position_embd" }, - { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_output_norm" }, { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, - { LLM_TENSOR_ATTN_OUT_NORM, "blk.%d.attn_output_norm" }, - { LLM_TENSOR_LAYER_OUT_NORM, "blk.%d.layer_output_norm" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.layer_output_norm" }, { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, }, @@ -1519,6 +1519,8 @@ struct llama_hparams { float f_clamp_kqv; float f_max_alibi_bias; + bool causal_attn = true; + bool operator!=(const llama_hparams & other) const { if (this->vocab_only != 
other.vocab_only) return true; @@ -1611,8 +1613,6 @@ struct llama_layer { struct ggml_tensor * bqkv; // normalization - struct ggml_tensor * attn_out_norm; - struct ggml_tensor * attn_out_norm_b; struct ggml_tensor * ffn_norm; struct ggml_tensor * ffn_norm_b; @@ -1631,10 +1631,6 @@ struct llama_layer { struct ggml_tensor * ffn_down_b; // b2 struct ggml_tensor * ffn_up_b; // b3 struct ggml_tensor * ffn_act; - - // normalization - struct ggml_tensor * layer_out_norm; - struct ggml_tensor * layer_out_norm_b; }; struct llama_kv_cell { @@ -3036,7 +3032,7 @@ static void llm_load_hparams( case LLM_ARCH_BERT: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); - ml.get_key(LLM_KV_TOKEN_TYPE_COUNT, hparams.n_vocab_type); + ml.get_key(LLM_KV_TOKENIZER_TOKEN_TYPE_COUNT, hparams.n_vocab_type); switch (hparams.n_embd) { case 384: // MiniLM @@ -3279,7 +3275,11 @@ static void llm_load_vocab( // determine the newline token: LLaMA "<0x0A>" == 10 == '\n', Falcon 193 == '\n' if (vocab.type == LLAMA_VOCAB_TYPE_SPM) { - vocab.linefeed_id = llama_byte_to_token(vocab, '\n'); + try { + vocab.linefeed_id = llama_byte_to_token(vocab, '\n'); + } catch (const std::exception & e) { + LLAMA_LOG_WARN("%s: model vocab missing newline token: %s\n", __func__, e.what()); + } } else { const std::vector ids = llama_tokenize_internal(vocab, "\u010A", false); GGML_ASSERT(!ids.empty() && "model vocab missing newline token"); @@ -3617,6 +3617,7 @@ static bool llm_load_tensors( const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(); const int64_t n_embd_gqa = n_embd_v_gqa; const int64_t n_vocab = hparams.n_vocab; + const int64_t n_vocab_type = hparams.n_vocab_type; const int64_t n_ff = hparams.n_ff; GGML_ASSERT(n_embd_gqa == n_embd_k_gqa); @@ -3834,7 +3835,7 @@ static bool llm_load_tensors( case LLM_ARCH_BERT: { model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); - model.type_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_TYPES, "weight"), {n_vocab_type, n_embd}); + model.type_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_TYPES, "weight"), {n_embd, n_vocab_type}); model.pos_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_POS_EMBD, "weight"), {n_embd, hparams.n_ctx_train}); model.tok_norm = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}); model.tok_norm_b = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias"), {n_embd}); @@ -3845,11 +3846,11 @@ static bool llm_load_tensors( auto & layer = model.layers[i]; - layer.attn_out_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT_NORM, "weight", i), {n_embd}); - layer.attn_out_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT_NORM, "bias", i), {n_embd}); + layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); + layer.attn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}); - layer.layer_out_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_LAYER_OUT_NORM, "weight", i), {n_embd}); - layer.layer_out_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_LAYER_OUT_NORM, "bias", i), {n_embd}); + layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}); + layer.ffn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}); layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}); layer.bq = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}); @@ -4826,6 +4827,7 @@ struct 
llm_build_context { const int32_t n_orig_ctx; const bool do_rope_shift; + const bool causal_attn; const llm_build_cb & cb; @@ -4869,6 +4871,7 @@ struct llm_build_context { kv_head (worst_case ? n_ctx - n_tokens : kv_self.head), n_orig_ctx (cparams.n_yarn_orig_ctx), do_rope_shift (worst_case || kv_self.has_shift), + causal_attn (hparams.causal_attn), cb (cb), buf_compute_meta (lctx.buf_compute_meta) { // all initializations should be done in init() @@ -5722,69 +5725,50 @@ struct llm_build_context { struct ggml_tensor * cur; struct ggml_tensor * inpL; + // construct input embeddings (token, type, position) inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, lctx.inp_tokens, lctx.inp_embd, cb); - - // inp_pos - contains the positions + struct ggml_tensor * inp_type = ggml_view_1d(ctx0, lctx.inp_type, n_tokens, 0); + inpL = ggml_add(ctx0, ggml_get_rows(ctx0, model.type_embd, inp_type), inpL); struct ggml_tensor * inp_pos = ggml_view_1d(ctx0, lctx.inp_pos, n_tokens, 0); + inpL = ggml_add(ctx0, ggml_get_rows(ctx0, model.pos_embd, inp_pos), inpL); + cb(inpL, "inp_embd", -1); - inpL = ggml_add(ctx0, - ggml_get_rows(ctx0, model.type_embd, lctx.inp_type), - inpL); - inpL = ggml_add(ctx0, - ggml_get_rows(ctx0, model.pos_embd, lctx.inp_pos), - inpL); + // embed layer norm + inpL = llm_build_norm(ctx0, inpL, hparams, model.tok_norm, model.tok_norm_b, LLM_NORM, cb, -1); - inpL = llm_build_norm(ctx0, inpL, hparams, - model.tok_norm, - model.tok_norm_b, - LLM_NORM, cb, -1); + // KQ_mask (mask for 1 head, it will be broadcasted to all heads) + struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0); + cb(KQ_mask, "KQ_mask", -1); // [n_kv, n_tokens] + // iterate layers for (int il = 0; il < n_layer; ++il) { struct ggml_tensor * cur = inpL; // self-attention { - // compute Q and K - struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur); - Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - struct ggml_tensor * Q = ggml_permute(ctx0, Qcur, 0, 2, 1, 3); - - struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur); - Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head, n_tokens); - struct ggml_tensor * K = ggml_permute(ctx0, Kcur, 0, 2, 1, 3); - - struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur); - Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head, n_tokens); - struct ggml_tensor * V = ggml_permute(ctx0, Vcur, 0, 2, 1, 3); + struct ggml_tensor * Qcur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].wq, cur), model.layers[il].bq); + cb(Qcur, "Qcur", il); - struct ggml_tensor *KQ = ggml_mul_mat(ctx0, K, Q); - // KQ = soft_max(KQ / sqrt(head width)) - KQ = ggml_soft_max(ctx0, - ggml_scale(ctx0, KQ, 1.0f / sqrt((float)n_embd_head))); + struct ggml_tensor * Kcur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].wk, cur), model.layers[il].bk); + cb(Kcur, "Kcur", il); - V = ggml_cont(ctx0, ggml_transpose(ctx0, V)); - struct ggml_tensor *KQV = ggml_mul_mat(ctx0, V, KQ); - KQV = ggml_permute(ctx0, KQV, 0, 2, 1, 3); + struct ggml_tensor * Vcur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].wv, cur), model.layers[il].bv); + cb(Vcur, "Vcur", il); - cur = ggml_reshape_2d(ctx0, ggml_cont(ctx0, KQV), n_embd, N); + // seems like we just need to do this for Q? 
+ Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - // attention output - cur = ggml_add(ctx0, - model.layers[il].bo, - ggml_mul_mat(ctx0, model.layers[il].wo, cur)); + cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, + model.layers[il].wo, model.layers[il].bo, + Kcur, Vcur, Qcur, KQ_mask, n_ctx, n_tokens, kv_head, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il); + cb(cur, "kqv_out", il); } // re-add the layer input - cur = ggml_add(ctx0, cur, inpSA); + cur = ggml_add(ctx0, cur, inpL); // attention layer norm - cur = llm_build_norm(ctx0, cur, hparams, - model.layers[il].attn_out_norm, - model.layers[il].attn_out_norm_b, - LLM_NORM, cb, il); + cur = llm_build_norm(ctx0, cur, hparams, model.layers[il].attn_norm, model.layers[il].attn_norm_b, LLM_NORM, cb, il); struct ggml_tensor * ffn_inp = cur; @@ -5800,19 +5784,19 @@ struct llm_build_context { cur = ggml_add(ctx0, cur, ffn_inp); // output layer norm - cur = llm_build_norm(ctx0, cur, hparams, - model.layers[il].layer_out_norm, - model.layers[il].layer_out_norm_b, - LLM_NORM, cb, il); + cur = llm_build_norm(ctx0, cur, hparams, model.layers[il].ffn_norm, model.layers[il].ffn_norm_b, LLM_NORM, cb, il); // input for next layer inpL = cur; } + // final output cur = inpL; - // pooling (sum = [L, 1, B]) - cur = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, cur)), lctx.inp_sum); // [E, 1, B] + // pooling + struct ggml_tensor * inp_sum = ggml_view_1d(ctx0, lctx.inp_sum, n_tokens, 0); + cur = ggml_mul_mat(ctx0, inp_sum, ggml_cont(ctx0, ggml_transpose(ctx0, cur))); + cb(cur, "result_embed", -1); ggml_build_forward_expand(gf, cur); @@ -7260,7 +7244,8 @@ static struct ggml_cgraph * llama_build_graph( for (int i = 0; i < n_kv; ++i) { float f; - if (!lctx.kv_self.cells[i].has_seq_id(seq_id) || lctx.kv_self.cells[i].pos > pos) { + if (!lctx.kv_self.cells[i].has_seq_id(seq_id) || + (llm.causal_attn && lctx.kv_self.cells[i].pos > pos)) { f = -INFINITY; } else { f = 0; @@ -7283,7 +7268,7 @@ static struct ggml_cgraph * llama_build_graph( } { - GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_inp_sum->buffer)); + GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_sum->buffer)); float * data = (float *) lctx.inp_sum->data; for (int i = 0; i < batch.n_tokens; ++i) { @@ -7482,13 +7467,18 @@ static int llama_decode_internal( // the output is always the last tensor in the graph struct ggml_tensor * res = gf->nodes[gf->n_nodes - 1]; - GGML_ASSERT(strcmp(res->name, "result_output") == 0); - - // the embeddings could be the second to last tensor, or the third to last tensor - struct ggml_tensor * embeddings = gf->nodes[gf->n_nodes - 2]; - if (strcmp(embeddings->name, "result_norm") != 0) { - embeddings = gf->nodes[gf->n_nodes - 3]; - GGML_ASSERT(strcmp(embeddings->name, "result_norm") == 0); + struct ggml_tensor * embeddings = nullptr; + if (strcmp(res->name, "result_embed") == 0) { + embeddings = res; + res = nullptr; + } else { + // the embeddings could be the second to last tensor, or the third to last tensor + GGML_ASSERT(strcmp(res->name, "result_output") == 0); + embeddings = gf->nodes[gf->n_nodes - 2]; + if (strcmp(embeddings->name, "result_norm") != 0) { + embeddings = gf->nodes[gf->n_nodes - 3]; + GGML_ASSERT(strcmp(embeddings->name, "result_norm") == 0); + } } // LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs); @@ -7555,7 +7545,7 @@ static int llama_decode_internal( // extract logits // TODO: do not compute and extract logits if only 
embeddings are needed // need to update the graphs to skip "result_output" - { + if (res) { auto & logits_out = lctx.logits; #ifndef NDEBUG @@ -7601,7 +7591,7 @@ static int llama_decode_internal( embedding_out.resize(n_embd); ggml_backend_t embeddings_backend = ggml_backend_sched_get_node_backend(lctx.sched, embeddings); - ggml_backend_tensor_get_async(embeddings_backend, embeddings, embedding_out.data(), (n_embd*(n_tokens - 1))*sizeof(float), n_embd*sizeof(float)); + ggml_backend_tensor_get_async(embeddings_backend, embeddings, embedding_out.data(), 0, n_embd*sizeof(float)); ggml_backend_synchronize(embeddings_backend); } From 59c1829b0c39a3b74817c9a14b7deb0c74f0ad5e Mon Sep 17 00:00:00 2001 From: Douglas Hanley Date: Thu, 8 Feb 2024 13:09:36 -0600 Subject: [PATCH 03/16] add in wordpiece tokenizer --- convert-hf-to-gguf.py | 31 ++-- examples/embedding/embedding.cpp | 12 +- llama.cpp | 254 ++++++++++++++++++++++++++++++- llama.h | 1 + 4 files changed, 279 insertions(+), 19 deletions(-) diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py index a952a5282277c..e047a424cd60c 100755 --- a/convert-hf-to-gguf.py +++ b/convert-hf-to-gguf.py @@ -1594,26 +1594,37 @@ def set_vocab(self): path = self.dir_model added_tokens_path = self.dir_model if self.dir_model.exists() else None + # use huggingface vocab to get all tokens vocab = HfVocab(path, added_tokens_path) tokens, scores, toktypes = zip(*vocab.all_tokens()) - assert len(tokens) == vocab.vocab_size - # for some reason set(toktypes) = {1, 3} so we need to compress it - all_types, toktypes1 = np.unique(toktypes, return_inverse=True) - n_token_types, toktypes1 = len(all_types), toktypes1.tolist() + # we need this to validate the size of the token_type embeddings + # though currently we are passing all zeros to the token_type embeddings + n_token_types = len(set(toktypes)) self.gguf_writer.add_uint32("tokenizer.ggml.token_type_count", n_token_types) - # convert tokens to SPM style - tokens = [ - (t[2:] if t.startswith(b"##") else b"\xe2\x96\x81" + t) for t in tokens - ] + # convert to phantom space vocab + def phantom(tok, typ): + if tok.startswith(b'[') and tok.endswith(b']'): + return tok + elif tok.startswith(b"##"): + return tok[2:] + else: + return b"\xe2\x96\x81" + tok + tokens = [phantom(t, y) for t, y in zip(tokens, toktypes)] - self.gguf_writer.add_tokenizer_model("llama") + # set up bos and eos tokens (cls and sep) + self.gguf_writer.add_bos_token_id(vocab.tokenizer.cls_token_id) + self.gguf_writer.add_eos_token_id(vocab.tokenizer.sep_token_id) + + # add vocab to gguf + self.gguf_writer.add_tokenizer_model("bert") self.gguf_writer.add_token_list(tokens) self.gguf_writer.add_token_scores(scores) - self.gguf_writer.add_token_types(toktypes) # ignore types for now (all zero) + self.gguf_writer.add_token_types(toktypes) + # handle special tokens special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens)) special_vocab.add_to_gguf(self.gguf_writer) diff --git a/examples/embedding/embedding.cpp b/examples/embedding/embedding.cpp index 3295cd2400ac3..27376c8f09fdc 100644 --- a/examples/embedding/embedding.cpp +++ b/examples/embedding/embedding.cpp @@ -87,7 +87,17 @@ int main(int argc, char ** argv) { } const int n_embd = llama_n_embd(model); - const auto * embeddings = llama_get_embeddings(ctx); + auto * embeddings = llama_get_embeddings(ctx); + + // l2-normalize embeddings + float norm = 0; + for (int i = 0; i < n_embd; i++) { + norm += embeddings[i] * embeddings[i]; + } + norm = sqrt(norm); + for (int i = 0; i < n_embd; i++) 
{ + embeddings[i] /= norm; + } for (int i = 0; i < n_embd; i++) { printf("%f ", embeddings[i]); diff --git a/llama.cpp b/llama.cpp index 1118a291bcb0e..5d97241f336a6 100644 --- a/llama.cpp +++ b/llama.cpp @@ -2860,6 +2860,7 @@ static const char * llama_model_vocab_type_name(enum llama_vocab_type type){ switch (type) { case LLAMA_VOCAB_TYPE_SPM: return "SPM"; case LLAMA_VOCAB_TYPE_BPE: return "BPE"; + case LLAMA_VOCAB_TYPE_WPM: return "WPM"; default: return "unknown"; } } @@ -3033,6 +3034,7 @@ static void llm_load_hparams( { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); ml.get_key(LLM_KV_TOKENIZER_TOKEN_TYPE_COUNT, hparams.n_vocab_type); + hparams.causal_attn = false; switch (hparams.n_embd) { case 384: // MiniLM @@ -3248,6 +3250,16 @@ static void llm_load_vocab( vocab.special_unk_id = -1; vocab.special_sep_id = -1; vocab.special_pad_id = -1; + } else if (tokenizer_name == "bert") { + vocab.type = LLAMA_VOCAB_TYPE_WPM; + + // default special tokens + vocab.special_bos_id = 101; + vocab.special_eos_id = 102; + vocab.special_unk_id = 100; + vocab.special_sep_id = -1; + vocab.special_pad_id = -1; + vocab.add_space_prefix = false; } else { LLAMA_LOG_WARN("%s: unknown tokenizer: '%s'", __func__, tokenizer_name.c_str()); LLAMA_LOG_WARN("%s: using default tokenizer: 'llama'", __func__); @@ -3275,11 +3287,9 @@ static void llm_load_vocab( // determine the newline token: LLaMA "<0x0A>" == 10 == '\n', Falcon 193 == '\n' if (vocab.type == LLAMA_VOCAB_TYPE_SPM) { - try { - vocab.linefeed_id = llama_byte_to_token(vocab, '\n'); - } catch (const std::exception & e) { - LLAMA_LOG_WARN("%s: model vocab missing newline token: %s\n", __func__, e.what()); - } + vocab.linefeed_id = llama_byte_to_token(vocab, '\n'); + } else if (vocab.type == LLAMA_VOCAB_TYPE_WPM) { + vocab.linefeed_id = vocab.special_pad_id; } else { const std::vector ids = llama_tokenize_internal(vocab, "\u010A", false); GGML_ASSERT(!ids.empty() && "model vocab missing newline token"); @@ -5725,11 +5735,14 @@ struct llm_build_context { struct ggml_tensor * cur; struct ggml_tensor * inpL; + // get input vectors with right size + struct ggml_tensor * inp_type = ggml_view_1d(ctx0, lctx.inp_type, n_tokens, 0); + struct ggml_tensor * inp_pos = ggml_view_1d(ctx0, lctx.inp_pos, n_tokens, 0); + struct ggml_tensor * inp_sum = ggml_view_1d(ctx0, lctx.inp_sum, n_tokens, 0); + // construct input embeddings (token, type, position) inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, lctx.inp_tokens, lctx.inp_embd, cb); - struct ggml_tensor * inp_type = ggml_view_1d(ctx0, lctx.inp_type, n_tokens, 0); inpL = ggml_add(ctx0, ggml_get_rows(ctx0, model.type_embd, inp_type), inpL); - struct ggml_tensor * inp_pos = ggml_view_1d(ctx0, lctx.inp_pos, n_tokens, 0); inpL = ggml_add(ctx0, ggml_get_rows(ctx0, model.pos_embd, inp_pos), inpL); cb(inpL, "inp_embd", -1); @@ -5794,7 +5807,6 @@ struct llm_build_context { cur = inpL; // pooling - struct ggml_tensor * inp_sum = ggml_view_1d(ctx0, lctx.inp_sum, n_tokens, 0); cur = ggml_mul_mat(ctx0, inp_sum, ggml_cont(ctx0, ggml_transpose(ctx0, cur))); cb(cur, "result_embed", -1); @@ -7655,6 +7667,9 @@ static uint8_t llama_token_to_byte(const llama_vocab& vocab, llama_token id) { GGML_ASSERT(false); return unicode_to_bytes_bpe(token_data.text); } + case LLAMA_VOCAB_TYPE_WPM: { + GGML_ASSERT(false); + } default: GGML_ASSERT(false); } @@ -7667,6 +7682,7 @@ static llama_token llama_byte_to_token(const llama_vocab & vocab, uint8_t ch) { const char buf[7] = { '<', '0', 'x', hex[ch >> 4], hex[ch & 15], '>', 
0 }; return vocab.token_to_id.at(buf); } + case LLAMA_VOCAB_TYPE_WPM: case LLAMA_VOCAB_TYPE_BPE: { return vocab.token_to_id.at(bytes_to_unicode_bpe(ch)); } @@ -8137,6 +8153,207 @@ struct llm_tokenizer_bpe { llm_bigram_bpe::queue work_queue; }; +struct llm_tokenizer_wpm { + llm_tokenizer_wpm(const llama_vocab & vocab): vocab(vocab) {} + + void tokenize(const std::string & text, std::vector & output) { + auto * token_map = &vocab.token_to_id; + + // normalize and split by whitespace + std::vector words = preprocess(text); + + // bos token prepended already + + // find the longest tokens that form the words + for (const std::string &word : words) { + // skip empty words + if (word.size() == 0) continue; + + // prepend phantom space + std::string word1 = "\xe2\x96\x81" + word; + int n = word1.size(); + + // we're at the start of a new word + int i = 0; + bool match_any = false; + + // move through character position in word + while (i < n) { + // loop through possible match length + bool match = false; + for (int j = n; j > i; j--) { + auto it = token_map->find(word1.substr(i, j - i)); + if (it != token_map->end()) { + output.push_back(it->second); + match = true; + match_any = true; + i = j; + break; + } + } + + // must be an unknown character + if (!match) i++; + } + + // we didn't find any matches for this word + if (!match_any) { + output.push_back(vocab.special_unk_id); + } + } + + // append eos token + output.push_back(vocab.special_eos_id); + } + + std::vector preprocess(const std::string & text) { + std::string ori_str = text; + ori_str = normalize(ori_str); + uint64_t ori_size = ori_str.size(); + + // single punct / single symbol / single digit + // baseline: add whitespace on the left and right of punct and chinese characters + std::vector words; + std::string new_str = ""; + uint64_t i = 0; + while (i < ori_size) { + int utf_char_len = utf8_len(ori_str[i]); + if ((utf_char_len == 1) && ispunct(ori_str[i])) { + new_str += " "; + new_str += ori_str[i]; + new_str += " "; + i += 1; + } + else if ((utf_char_len == 3) && is_chinese_char(ori_str.substr(i, 3))) { + new_str += " "; + new_str += ori_str.substr(i, 3); + new_str += " "; + i += 3; + } + else { + new_str += ori_str[i]; + i += 1; + } + } + + // split by whitespace + uint64_t l = 0; + uint64_t r = 0; + while (r < new_str.size()) { + // if is whitespace + if (isspace(new_str[r])) { + if (r > l) words.push_back(new_str.substr(l, (r - l))); + l = r + 1; + r = l; + } + else { + r += 1; + } + } + if (r > l) { + words.push_back(new_str.substr(l, (r - l))); + } + return words; + } + + std::string normalize(const std::string &text) { + // TODO: handle chinese characters? 
https://github.com/huggingface/tokenizers/blob/ef5f50605ddf9f8caef1598c0e4853862b9707a7/tokenizers/src/normalizers/bert.rs#L98 + std::string text2 = strip_accents(text); + for (size_t i = 0; i < text2.size(); i += utf8_len(text2[i])) + { + char c = text2[i]; + if (c >= 'A' && c <= 'Z') + text2[i] = c - 'A' + 'a'; + } + return text2; + } + + bool is_chinese_char(const std::string& str) { + int len = str.length(); + unsigned int codepoint = 0; + int num_bytes = 0; + int i = 0; + unsigned char ch = static_cast(str[i]); + if (ch <= 0x7f) { + codepoint = ch; + num_bytes = 1; + } else if ((ch >> 5) == 0x06) { + codepoint = ch & 0x1f; + num_bytes = 2; + } else if ((ch >> 4) == 0x0e) { + codepoint = ch & 0x0f; + num_bytes = 3; + } else if ((ch >> 3) == 0x1e) { + codepoint = ch & 0x07; + num_bytes = 4; + } + for (int j = 1; j < num_bytes; ++j) { + if (i + j >= len) { + return false; // incomplete UTF-8 character + } + unsigned char next_ch = static_cast(str[i + j]); + if ((next_ch >> 6) != 0x02) { + return false; // invalid trailing byte + } + codepoint = (codepoint << 6) | (next_ch & 0x3f); + } + if ((codepoint >= 0x4E00 && codepoint <= 0x9FFF) || + (codepoint >= 0x3400 && codepoint <= 0x4DBF) || + (codepoint >= 0x20000 && codepoint <= 0x2A6DF) || + (codepoint >= 0x2A700 && codepoint <= 0x2B73F) || + (codepoint >= 0x2B740 && codepoint <= 0x2B81F) || + (codepoint >= 0x2B920 && codepoint <= 0x2CEAF) || // this should be 0x2B820 but in hf rust code it is 0x2B920 + (codepoint >= 0xF900 && codepoint <= 0xFAFF) || + (codepoint >= 0x2F800 && codepoint <= 0x2FA1F) || + (codepoint >= 0x3000 && codepoint <= 0x303F) || + (codepoint >= 0xFF00 && codepoint <= 0xFFEF)) { + return true; + } + return false; + } + + std::string strip_accents(const std::string &inputString) { + std::string resultString; + std::map accentMap = { + {"À", 'A'}, {"Á", 'A'}, {"Â", 'A'}, {"Ã", 'A'}, {"Ä", 'A'}, {"Å", 'A'}, + {"à", 'a'}, {"á", 'a'}, {"â", 'a'}, {"ã", 'a'}, {"ä", 'a'}, {"å", 'a'}, + {"È", 'E'}, {"É", 'E'}, {"Ê", 'E'}, {"Ë", 'E'}, {"è", 'e'}, {"é", 'e'}, + {"ê", 'e'}, {"ë", 'e'}, {"Ì", 'I'}, {"Í", 'I'}, {"Î", 'I'}, {"Ï", 'I'}, + {"ì", 'i'}, {"í", 'i'}, {"î", 'i'}, {"ï", 'i'}, {"Ò", 'O'}, {"Ó", 'O'}, + {"Ô", 'O'}, {"Õ", 'O'}, {"Ö", 'O'}, {"ò", 'o'}, {"ó", 'o'}, {"ô", 'o'}, + {"õ", 'o'}, {"ö", 'o'}, {"Ù", 'U'}, {"Ú", 'U'}, {"Û", 'U'}, {"Ü", 'U'}, + {"ù", 'u'}, {"ú", 'u'}, {"û", 'u'}, {"ü", 'u'}, {"Ý", 'Y'}, {"ý", 'y'}, + {"Ç", 'C'}, {"ç", 'c'}, {"Ñ", 'N'},{ "ñ", 'n'}, + }; + + for (size_t i = 0; i < inputString.length();) + { + int len = utf8_len(inputString[i]); + std::string curChar = inputString.substr(i, len); + auto iter = accentMap.find(curChar); + if (iter != accentMap.end()) + { + resultString += iter->second; + } + else + { + resultString += curChar; + } + i += len; + } + + return resultString; + } + + static size_t utf8_len(char src) { + const size_t lookup[] = {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4}; + uint8_t highbits = static_cast(src) >> 4; + return lookup[highbits]; + } + + const llama_vocab & vocab; +}; + typedef enum FRAGMENT_BUFFER_VARIANT_TYPE{ FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN, FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT @@ -8341,6 +8558,26 @@ static std::vector llama_tokenize_internal(const llama_vocab & } } } break; + case LLAMA_VOCAB_TYPE_WPM: + { + for (const auto & fragment: fragment_buffer) + { + if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) + { + auto raw_text = fragment.raw_text.substr(fragment.offset, fragment.length); + +#ifdef PRETOKENIZERDEBUG + LLAMA_LOG_WARN("TT: (%ld %ld 
%ld) '%s'\n", raw_text.length(), fragment.offset, fragment.length, raw_text.c_str()); +#endif + llm_tokenizer_wpm tokenizer(vocab); + tokenizer.tokenize(raw_text, output); + } + else // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN) + { + output.push_back(fragment.token); + } + } + } break; } return output; @@ -11947,6 +12184,7 @@ static std::string llama_decode_text(const std::string & text) { int32_t llama_token_to_piece(const struct llama_model * model, llama_token token, char * buf, int32_t length) { if (0 <= token && token < llama_n_vocab(model)) { switch (llama_vocab_get_type(model->vocab)) { + case LLAMA_VOCAB_TYPE_WPM: case LLAMA_VOCAB_TYPE_SPM: { // NOTE: we accept all unsupported token types, // suppressing them like CONTROL tokens. diff --git a/llama.h b/llama.h index cec4158bc8e80..367e8f1a105a5 100644 --- a/llama.h +++ b/llama.h @@ -61,6 +61,7 @@ extern "C" { enum llama_vocab_type { LLAMA_VOCAB_TYPE_SPM = 0, // SentencePiece LLAMA_VOCAB_TYPE_BPE = 1, // Byte Pair Encoding + LLAMA_VOCAB_TYPE_WPM = 2, // WordPiece }; enum llama_token_type { From 5f1c21d0b609ea835e755c24fe3a13c90c81f030 Mon Sep 17 00:00:00 2001 From: Douglas Hanley Date: Thu, 8 Feb 2024 13:28:25 -0600 Subject: [PATCH 04/16] put causal_attn flag in gguf --- convert-hf-to-gguf.py | 1 + llama.cpp | 8 ++++++-- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py index e047a424cd60c..de41968e24bca 100755 --- a/convert-hf-to-gguf.py +++ b/convert-hf-to-gguf.py @@ -1588,6 +1588,7 @@ def set_gguf_parameters(self): self.gguf_writer.add_block_count(self.block_count) self.gguf_writer.add_head_count(self.hparams["num_attention_heads"]) self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_eps"]) + self.gguf_writer.add_bool("bert.attention.causal", False) self.gguf_writer.add_file_type(self.ftype) def set_vocab(self): diff --git a/llama.cpp b/llama.cpp index 5d97241f336a6..49ed16e241356 100644 --- a/llama.cpp +++ b/llama.cpp @@ -263,6 +263,7 @@ enum llm_kv { LLM_KV_ATTENTION_VALUE_LENGTH, LLM_KV_ATTENTION_LAYERNORM_EPS, LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, + LLM_KV_ATTENTION_CAUSAL, LLM_KV_ROPE_DIMENSION_COUNT, LLM_KV_ROPE_FREQ_BASE, @@ -319,6 +320,7 @@ static std::map LLM_KV_NAMES = { { LLM_KV_ATTENTION_VALUE_LENGTH, "%s.attention.value_length" }, { LLM_KV_ATTENTION_LAYERNORM_EPS, "%s.attention.layer_norm_epsilon" }, { LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, "%s.attention.layer_norm_rms_epsilon" }, + { LLM_KV_ATTENTION_CAUSAL, "%s.attention.causal" }, { LLM_KV_ROPE_DIMENSION_COUNT, "%s.rope.dimension_count" }, { LLM_KV_ROPE_FREQ_BASE, "%s.rope.freq_base" }, @@ -3033,8 +3035,8 @@ static void llm_load_hparams( case LLM_ARCH_BERT: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn); ml.get_key(LLM_KV_TOKENIZER_TOKEN_TYPE_COUNT, hparams.n_vocab_type); - hparams.causal_attn = false; switch (hparams.n_embd) { case 384: // MiniLM @@ -7601,9 +7603,11 @@ static int llama_decode_internal( if (!lctx.embedding.empty()) { auto & embedding_out = lctx.embedding; + const int64_t embed_pos = res ? 
n_embd * (n_tokens-1) : 0; + embedding_out.resize(n_embd); ggml_backend_t embeddings_backend = ggml_backend_sched_get_node_backend(lctx.sched, embeddings); - ggml_backend_tensor_get_async(embeddings_backend, embeddings, embedding_out.data(), 0, n_embd*sizeof(float)); + ggml_backend_tensor_get_async(embeddings_backend, embeddings, embedding_out.data(), embed_pos*sizeof(float), n_embd*sizeof(float)); ggml_backend_synchronize(embeddings_backend); } From e3efcf13c82df44c612742c8b91a01dfa2b40934 Mon Sep 17 00:00:00 2001 From: Douglas Hanley Date: Thu, 8 Feb 2024 17:33:14 -0500 Subject: [PATCH 05/16] Update convert-hf-to-gguf.py Co-authored-by: Jared Van Bortel --- convert-hf-to-gguf.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py index fe94804c9890a..a6edd34f4b94b 100755 --- a/convert-hf-to-gguf.py +++ b/convert-hf-to-gguf.py @@ -1666,12 +1666,11 @@ def set_vocab(self): # convert to phantom space vocab def phantom(tok, typ): - if tok.startswith(b'[') and tok.endswith(b']'): + if tok.startswith(b"[") and tok.endswith(b"]"): return tok - elif tok.startswith(b"##"): + if tok.startswith(b"##"): return tok[2:] - else: - return b"\xe2\x96\x81" + tok + return b"\xe2\x96\x81" + tok tokens = [phantom(t, y) for t, y in zip(tokens, toktypes)] # set up bos and eos tokens (cls and sep) From 96d37f8d55b1565ea4c73016e3975ee411cb9e3e Mon Sep 17 00:00:00 2001 From: Douglas Hanley Date: Thu, 8 Feb 2024 16:41:05 -0600 Subject: [PATCH 06/16] add causal attention gguf key --- convert-hf-to-gguf.py | 4 ++-- gguf-py/gguf/constants.py | 1 + gguf-py/gguf/gguf_writer.py | 3 +++ 3 files changed, 6 insertions(+), 2 deletions(-) diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py index a6edd34f4b94b..6fb84d04996e0 100755 --- a/convert-hf-to-gguf.py +++ b/convert-hf-to-gguf.py @@ -1647,7 +1647,7 @@ def set_gguf_parameters(self): self.gguf_writer.add_block_count(self.block_count) self.gguf_writer.add_head_count(self.hparams["num_attention_heads"]) self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_eps"]) - self.gguf_writer.add_bool("bert.attention.causal", False) + self.gguf_writer.add_causal_attention(False) self.gguf_writer.add_file_type(self.ftype) def set_vocab(self): @@ -1662,7 +1662,7 @@ def set_vocab(self): # we need this to validate the size of the token_type embeddings # though currently we are passing all zeros to the token_type embeddings n_token_types = len(set(toktypes)) - self.gguf_writer.add_uint32("tokenizer.ggml.token_type_count", n_token_types) + self.gguf_writer.add_token_type_count(n_token_types) # convert to phantom space vocab def phantom(tok, typ): diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index c5208a9414da3..a9c13dd3826b8 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -50,6 +50,7 @@ class Attention: VALUE_LENGTH = "{arch}.attention.value_length" LAYERNORM_EPS = "{arch}.attention.layer_norm_epsilon" LAYERNORM_RMS_EPS = "{arch}.attention.layer_norm_rms_epsilon" + CAUSAL = "{arch}.attention.causal" class Rope: DIMENSION_COUNT = "{arch}.rope.dimension_count" diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py index 20d2e5cff9564..7af58a46c2cb7 100644 --- a/gguf-py/gguf/gguf_writer.py +++ b/gguf-py/gguf/gguf_writer.py @@ -357,6 +357,9 @@ def add_layer_norm_eps(self, value: float) -> None: def add_layer_norm_rms_eps(self, value: float) -> None: self.add_float32(Keys.Attention.LAYERNORM_RMS_EPS.format(arch=self.arch), value) + def 
add_causal_attention(self, value: bool) -> None: + self.add_bool(Keys.Attention.CAUSAL.format(arch=self.arch), value) + def add_rope_dimension_count(self, count: int) -> None: self.add_uint32(Keys.Rope.DIMENSION_COUNT.format(arch=self.arch), count) From e78388d39a353ca9cb61bdc64c6938ab41f3dec2 Mon Sep 17 00:00:00 2001 From: Jared Van Bortel Date: Thu, 8 Feb 2024 21:17:57 -0500 Subject: [PATCH 07/16] use ctx_output for tok_norm of BERT and BLOOM --- llama.cpp | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/llama.cpp b/llama.cpp index 1d4ea55d64ef3..8a0189c5d5578 100644 --- a/llama.cpp +++ b/llama.cpp @@ -3851,8 +3851,8 @@ static bool llm_load_tensors( model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); model.type_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_TYPES, "weight"), {n_embd, n_vocab_type}); model.pos_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_POS_EMBD, "weight"), {n_embd, hparams.n_ctx_train}); - model.tok_norm = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}); - model.tok_norm_b = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias"), {n_embd}); + model.tok_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}); + model.tok_norm_b = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias"), {n_embd}); for (int i = 0; i < n_layer; ++i) { ggml_context * ctx_layer = ctx_for_layer(i); @@ -3887,9 +3887,9 @@ static bool llm_load_tensors( } break; case LLM_ARCH_BLOOM: { - model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); - model.tok_norm = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}); - model.tok_norm_b = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias"), {n_embd}); + model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); + model.tok_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}); + model.tok_norm_b = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias"), {n_embd}); // output { From b14c457fb46b25c6cc8554610264b0c234edc9b7 Mon Sep 17 00:00:00 2001 From: Jared Van Bortel Date: Thu, 8 Feb 2024 21:22:23 -0500 Subject: [PATCH 08/16] bert : add some missing graph callbacks --- llama.cpp | 3 +++ 1 file changed, 3 insertions(+) diff --git a/llama.cpp b/llama.cpp index 8a0189c5d5578..38a5fb87607ec 100644 --- a/llama.cpp +++ b/llama.cpp @@ -5752,6 +5752,7 @@ struct llm_build_context { // embed layer norm inpL = llm_build_norm(ctx0, inpL, hparams, model.tok_norm, model.tok_norm_b, LLM_NORM, cb, -1); + cb(inpL, "inp_norm", -1); // KQ_mask (mask for 1 head, it will be broadcasted to all heads) struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0); @@ -5788,6 +5789,7 @@ struct llm_build_context { cur = llm_build_norm(ctx0, cur, hparams, model.layers[il].attn_norm, model.layers[il].attn_norm_b, LLM_NORM, cb, il); struct ggml_tensor * ffn_inp = cur; + cb(ffn_inp, "ffn_inp", il); // feed-forward network cur = llm_build_ffn(ctx0, cur, @@ -5796,6 +5798,7 @@ struct llm_build_context { model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, LLM_FFN_GELU, LLM_FFN_SEQ, cb, il); + cb(cur, "ffn_out", il); // attentions bypass the intermediate layer cur = ggml_add(ctx0, cur, ffn_inp); From 68758083d65b51799281fee30f45b9327e755470 Mon Sep 17 
00:00:00 2001 From: Douglas Hanley Date: Thu, 8 Feb 2024 22:43:26 -0600 Subject: [PATCH 09/16] fix up model sizing and result acquisition --- llama.cpp | 46 +++++++++++++++++++++++----------------------- 1 file changed, 23 insertions(+), 23 deletions(-) diff --git a/llama.cpp b/llama.cpp index 1d4ea55d64ef3..f58eb7c7b65dd 100644 --- a/llama.cpp +++ b/llama.cpp @@ -1464,9 +1464,11 @@ static llama_state g_state; // available llama models enum e_model { MODEL_UNKNOWN, + MODEL_17M, MODEL_22M, MODEL_33M, MODEL_109M, + MODEL_335M, MODEL_0_5B, MODEL_1B, MODEL_2B, @@ -3040,14 +3042,18 @@ static void llm_load_hparams( ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn); ml.get_key(LLM_KV_TOKENIZER_TOKEN_TYPE_COUNT, hparams.n_vocab_type); - switch (hparams.n_embd) { - case 384: // MiniLM - switch (hparams.n_layer) { - case 6: model.type = e_model::MODEL_22M; break; - case 12: model.type = e_model::MODEL_33M; break; + switch (hparams.n_layer) { + case 3: + model.type = e_model::MODEL_17M; break; // bge-micro + case 6: + model.type = e_model::MODEL_22M; break; // MiniLM-L6 + case 12: + switch (hparams.n_embd) { + case 384: model.type = e_model::MODEL_33M; break; // MiniLM-L12, bge-small + case 768: model.type = e_model::MODEL_109M; break; // bge-base } break; - case 768: // BERT-Base - model.type = e_model::MODEL_109M; break; + case 24: + model.type = e_model::MODEL_335M; break; // bge-large } } break; case LLM_ARCH_BLOOM: @@ -3851,8 +3857,8 @@ static bool llm_load_tensors( model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); model.type_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_TYPES, "weight"), {n_embd, n_vocab_type}); model.pos_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_POS_EMBD, "weight"), {n_embd, hparams.n_ctx_train}); - model.tok_norm = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}); - model.tok_norm_b = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias"), {n_embd}); + model.tok_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}); + model.tok_norm_b = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias"), {n_embd}); for (int i = 0; i < n_layer; ++i) { ggml_context * ctx_layer = ctx_for_layer(i); @@ -7481,21 +7487,15 @@ static int llama_decode_internal( ggml_cgraph * gf = llama_build_graph(lctx, batch); - // the output is always the last tensor in the graph - struct ggml_tensor * res = gf->nodes[gf->n_nodes - 1]; - struct ggml_tensor * embeddings = nullptr; - if (strcmp(res->name, "result_embed") == 0) { - embeddings = res; - res = nullptr; - } else { - // the embeddings could be the second to last tensor, or the third to last tensor - GGML_ASSERT(strcmp(res->name, "result_output") == 0); - embeddings = gf->nodes[gf->n_nodes - 2]; - if (strcmp(embeddings->name, "result_norm") != 0) { - embeddings = gf->nodes[gf->n_nodes - 3]; - GGML_ASSERT(strcmp(embeddings->name, "result_norm") == 0); - } + // get logits and embeddings + struct ggml_tensor * res = ggml_graph_get_tensor(gf, "result_output"); + struct ggml_tensor * embeddings = ggml_graph_get_tensor(gf, "result_norm"); + + // if logits are none we must be doing embeddings + if (res == nullptr) { + embeddings = ggml_graph_get_tensor(gf, "result_embed"); } + GGML_ASSERT(res || embeddings); // LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs); From d080bebcc6b46570f5ea40549ab4ef51dff7f3ee Mon Sep 17 
00:00:00 2001 From: Douglas Hanley Date: Thu, 8 Feb 2024 23:03:33 -0600 Subject: [PATCH 10/16] hard-code token_type = 0 --- llama.cpp | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/llama.cpp b/llama.cpp index f58eb7c7b65dd..13966a8425301 100644 --- a/llama.cpp +++ b/llama.cpp @@ -5746,14 +5746,13 @@ struct llm_build_context { struct ggml_tensor * inpL; // get input vectors with right size - struct ggml_tensor * inp_type = ggml_view_1d(ctx0, lctx.inp_type, n_tokens, 0); struct ggml_tensor * inp_pos = ggml_view_1d(ctx0, lctx.inp_pos, n_tokens, 0); struct ggml_tensor * inp_sum = ggml_view_1d(ctx0, lctx.inp_sum, n_tokens, 0); // construct input embeddings (token, type, position) inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, lctx.inp_tokens, lctx.inp_embd, cb); - inpL = ggml_add(ctx0, ggml_get_rows(ctx0, model.type_embd, inp_type), inpL); - inpL = ggml_add(ctx0, ggml_get_rows(ctx0, model.pos_embd, inp_pos), inpL); + inpL = ggml_add(ctx0, inpL, ggml_view_1d(ctx0, model.type_embd, n_embd, 0)); // since type = 0 + inpL = ggml_add(ctx0, inpL, ggml_get_rows(ctx0, model.pos_embd, inp_pos)); cb(inpL, "inp_embd", -1); // embed layer norm From 961e98f24549ef58f5f172bfe2edb92ac0f2fd8f Mon Sep 17 00:00:00 2001 From: Douglas Hanley Date: Fri, 9 Feb 2024 11:53:17 -0600 Subject: [PATCH 11/16] style fixes --- llama.cpp | 63 +++++++++++++++++++++---------------------------------- 1 file changed, 24 insertions(+), 39 deletions(-) diff --git a/llama.cpp b/llama.cpp index 35770a1ab56d6..e4498c704c636 100644 --- a/llama.cpp +++ b/llama.cpp @@ -8175,7 +8175,9 @@ struct llm_tokenizer_wpm { // find the longest tokens that form the words for (const std::string &word : words) { // skip empty words - if (word.size() == 0) continue; + if (word.size() == 0) { + continue; + } // prepend phantom space std::string word1 = "\xe2\x96\x81" + word; @@ -8201,7 +8203,9 @@ struct llm_tokenizer_wpm { } // must be an unknown character - if (!match) i++; + if (!match) { + i++; + } } // we didn't find any matches for this word @@ -8215,8 +8219,7 @@ struct llm_tokenizer_wpm { } std::vector preprocess(const std::string & text) { - std::string ori_str = text; - ori_str = normalize(ori_str); + std::string ori_str = normalize(text); uint64_t ori_size = ori_str.size(); // single punct / single symbol / single digit @@ -8267,8 +8270,7 @@ struct llm_tokenizer_wpm { std::string normalize(const std::string &text) { // TODO: handle chinese characters? 
https://github.com/huggingface/tokenizers/blob/ef5f50605ddf9f8caef1598c0e4853862b9707a7/tokenizers/src/normalizers/bert.rs#L98 std::string text2 = strip_accents(text); - for (size_t i = 0; i < text2.size(); i += utf8_len(text2[i])) - { + for (size_t i = 0; i < text2.size(); i += utf8_len(text2[i])) { char c = text2[i]; if (c >= 'A' && c <= 'Z') text2[i] = c - 'A' + 'a'; @@ -8331,20 +8333,16 @@ struct llm_tokenizer_wpm { {"Ô", 'O'}, {"Õ", 'O'}, {"Ö", 'O'}, {"ò", 'o'}, {"ó", 'o'}, {"ô", 'o'}, {"õ", 'o'}, {"ö", 'o'}, {"Ù", 'U'}, {"Ú", 'U'}, {"Û", 'U'}, {"Ü", 'U'}, {"ù", 'u'}, {"ú", 'u'}, {"û", 'u'}, {"ü", 'u'}, {"Ý", 'Y'}, {"ý", 'y'}, - {"Ç", 'C'}, {"ç", 'c'}, {"Ñ", 'N'},{ "ñ", 'n'}, + {"Ç", 'C'}, {"ç", 'c'}, {"Ñ", 'N'}, {"ñ", 'n'}, }; - for (size_t i = 0; i < inputString.length();) - { + for (size_t i = 0; i < inputString.length();) { int len = utf8_len(inputString[i]); std::string curChar = inputString.substr(i, len); auto iter = accentMap.find(curChar); - if (iter != accentMap.end()) - { + if (iter != accentMap.end()) { resultString += iter->second; - } - else - { + } else { resultString += curChar; } i += len; @@ -8362,12 +8360,12 @@ struct llm_tokenizer_wpm { const llama_vocab & vocab; }; -typedef enum FRAGMENT_BUFFER_VARIANT_TYPE{ +typedef enum FRAGMENT_BUFFER_VARIANT_TYPE { FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN, FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT } FRAGMENT_BUFFER_VARIANT_TYPE; -struct fragment_buffer_variant{ +struct fragment_buffer_variant { fragment_buffer_variant(llama_vocab::id _token) : type(FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN), @@ -8397,8 +8395,7 @@ struct fragment_buffer_variant{ // #define PRETOKENIZERDEBUG -static void tokenizer_st_partition(const llama_vocab & vocab, std::forward_list & buffer) -{ +static void tokenizer_st_partition(const llama_vocab & vocab, std::forward_list & buffer) { // for each special token for (const auto & st: vocab.special_tokens_cache) { const auto & special_token = st.first; @@ -8516,10 +8513,8 @@ static std::vector llama_tokenize_internal(const llama_vocab & switch (vocab.type) { case LLAMA_VOCAB_TYPE_SPM: { - for (const auto & fragment: fragment_buffer) - { - if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) - { + for (const auto & fragment: fragment_buffer) { + if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) { // without adding this leading whitespace, we do not get the same results as the original tokenizer // TODO: It's likely possible to get rid of this string copy entirely @@ -8539,19 +8534,15 @@ static std::vector llama_tokenize_internal(const llama_vocab & llm_tokenizer_spm tokenizer(vocab); llama_escape_whitespace(raw_text); tokenizer.tokenize(raw_text, output); - } - else // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN) - { + } else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN) output.push_back(fragment.token); } } } break; case LLAMA_VOCAB_TYPE_BPE: { - for (const auto & fragment: fragment_buffer) - { - if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) - { + for (const auto & fragment: fragment_buffer) { + if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) { auto raw_text = fragment.raw_text.substr(fragment.offset, fragment.length); #ifdef PRETOKENIZERDEBUG @@ -8559,19 +8550,15 @@ static std::vector llama_tokenize_internal(const llama_vocab & #endif llm_tokenizer_bpe tokenizer(vocab); tokenizer.tokenize(raw_text, output); - } - else // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN) - { + } else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN) 
output.push_back(fragment.token); } } } break; case LLAMA_VOCAB_TYPE_WPM: { - for (const auto & fragment: fragment_buffer) - { - if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) - { + for (const auto & fragment: fragment_buffer) { + if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) { auto raw_text = fragment.raw_text.substr(fragment.offset, fragment.length); #ifdef PRETOKENIZERDEBUG @@ -8579,9 +8566,7 @@ static std::vector llama_tokenize_internal(const llama_vocab & #endif llm_tokenizer_wpm tokenizer(vocab); tokenizer.tokenize(raw_text, output); - } - else // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN) - { + } else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN) output.push_back(fragment.token); } } From 56afb2f60e47a80231ee4c6b275efeca164a626e Mon Sep 17 00:00:00 2001 From: Douglas Hanley Date: Fri, 9 Feb 2024 12:00:41 -0600 Subject: [PATCH 12/16] undo attempted type_embd simplify --- llama.cpp | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/llama.cpp b/llama.cpp index e4498c704c636..42756b48f2315 100644 --- a/llama.cpp +++ b/llama.cpp @@ -5746,13 +5746,14 @@ struct llm_build_context { struct ggml_tensor * inpL; // get input vectors with right size + struct ggml_tensor * inp_type = ggml_view_1d(ctx0, lctx.inp_type, n_tokens, 0); struct ggml_tensor * inp_pos = ggml_view_1d(ctx0, lctx.inp_pos, n_tokens, 0); struct ggml_tensor * inp_sum = ggml_view_1d(ctx0, lctx.inp_sum, n_tokens, 0); // construct input embeddings (token, type, position) inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, lctx.inp_tokens, lctx.inp_embd, cb); - inpL = ggml_add(ctx0, inpL, ggml_view_1d(ctx0, model.type_embd, n_embd, 0)); // since type = 0 - inpL = ggml_add(ctx0, inpL, ggml_get_rows(ctx0, model.pos_embd, inp_pos)); + inpL = ggml_add(ctx0, ggml_get_rows(ctx0, model.type_embd, inp_type), inpL); + inpL = ggml_add(ctx0, ggml_get_rows(ctx0, model.pos_embd, inp_pos), inpL); cb(inpL, "inp_embd", -1); // embed layer norm From ab49e9ee45b73ec1808a3df9bafb74146a6af56f Mon Sep 17 00:00:00 2001 From: Jared Van Bortel Date: Fri, 9 Feb 2024 13:44:32 -0500 Subject: [PATCH 13/16] bert : simplify token type embedding access --- convert-hf-to-gguf.py | 5 ++++- llama.cpp | 14 +++----------- 2 files changed, 7 insertions(+), 12 deletions(-) diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py index 6fb84d04996e0..cae1551a236b0 100755 --- a/convert-hf-to-gguf.py +++ b/convert-hf-to-gguf.py @@ -1705,7 +1705,10 @@ def write_tensors(self): n_dims = len(data.shape) new_dtype: type[np.floating[Any]] - if self.ftype == 1 and name.endswith(".weight") and n_dims == 2: + if ( + self.ftype == 1 and name.endswith(".weight") and n_dims == 2 + and name != "embeddings.token_type_embeddings.weight" # not used with get_rows, must be F32 + ): # if f16 desired, convert any float32 2-dim weight tensors to float16 new_dtype = np.float16 else: diff --git a/llama.cpp b/llama.cpp index 42756b48f2315..f64571dc81538 100644 --- a/llama.cpp +++ b/llama.cpp @@ -1881,7 +1881,6 @@ struct llama_context { struct ggml_tensor * inp_tokens; // I32 [n_batch] struct ggml_tensor * inp_embd; // F32 [n_embd, n_batch] struct ggml_tensor * inp_pos; // I32 [n_batch] - struct ggml_tensor * inp_type; // I32 [n_batch] struct ggml_tensor * inp_KQ_mask; // F32 [n_ctx, n_batch] struct ggml_tensor * inp_K_shift; // I32 [n_ctx] struct ggml_tensor * inp_sum; // F32 [1, n_batch] @@ -5746,13 +5745,14 @@ struct llm_build_context { struct ggml_tensor * inpL; // get input vectors with right size - 
struct ggml_tensor * inp_type = ggml_view_1d(ctx0, lctx.inp_type, n_tokens, 0); struct ggml_tensor * inp_pos = ggml_view_1d(ctx0, lctx.inp_pos, n_tokens, 0); struct ggml_tensor * inp_sum = ggml_view_1d(ctx0, lctx.inp_sum, n_tokens, 0); // construct input embeddings (token, type, position) inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, lctx.inp_tokens, lctx.inp_embd, cb); - inpL = ggml_add(ctx0, ggml_get_rows(ctx0, model.type_embd, inp_type), inpL); + // token types are hardcoded to zero ("Sentence A") + struct ggml_tensor * type_row0 = ggml_view_1d(ctx0, model.type_embd, n_embd, 0); + inpL = ggml_add(ctx0, inpL, type_row0); inpL = ggml_add(ctx0, ggml_get_rows(ctx0, model.pos_embd, inp_pos), inpL); cb(inpL, "inp_embd", -1); @@ -7249,12 +7249,6 @@ static struct ggml_cgraph * llama_build_graph( ggml_backend_tensor_set(lctx.inp_pos, batch.pos, 0, n_tokens*ggml_element_size(lctx.inp_pos)); } - { - // for embedding models, token type is always zero ("sentence A") - GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_type->buffer)); - memset(lctx.inp_type->data, 0, batch.n_tokens * ggml_element_size(lctx.inp_type)); - } - { const int64_t n_kv = llm.n_kv; const int64_t n_tokens = batch.n_tokens; @@ -11240,7 +11234,6 @@ struct llama_context * llama_new_context_with_model( ctx->inp_tokens = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_I32, cparams.n_batch); ctx->inp_embd = ggml_new_tensor_2d(ctx->ctx_input, GGML_TYPE_F32, hparams.n_embd, cparams.n_batch); ctx->inp_pos = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_I32, cparams.n_batch); - ctx->inp_type = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_I32, cparams.n_batch); ctx->inp_KQ_mask = ggml_new_tensor_2d(ctx->ctx_input, GGML_TYPE_F32, cparams.n_ctx, cparams.n_batch); ctx->inp_K_shift = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_I32, cparams.n_ctx); ctx->inp_sum = ggml_new_tensor_2d(ctx->ctx_input, GGML_TYPE_F32, 1, cparams.n_batch); @@ -11248,7 +11241,6 @@ struct llama_context * llama_new_context_with_model( ggml_set_name(ctx->inp_tokens, "inp_tokens"); ggml_set_name(ctx->inp_embd, "inp_embd"); ggml_set_name(ctx->inp_pos, "inp_pos"); - ggml_set_name(ctx->inp_type, "inp_type"); ggml_set_name(ctx->inp_KQ_mask, "inp_KQ_mask"); ggml_set_name(ctx->inp_K_shift, "inp_K_shift"); ggml_set_name(ctx->inp_sum, "inp_sum"); From 6972e7e90ec4b4e0c91b6a51a0c0f81703ff12fe Mon Sep 17 00:00:00 2001 From: Jared Van Bortel Date: Fri, 9 Feb 2024 13:50:31 -0500 Subject: [PATCH 14/16] flake8 : add W503 to ignore list --- .flake8 | 1 + 1 file changed, 1 insertion(+) diff --git a/.flake8 b/.flake8 index 113ca5fd3cb17..18fba2c1574a6 100644 --- a/.flake8 +++ b/.flake8 @@ -1,2 +1,3 @@ [flake8] max-line-length = 125 +ignore = W503 From 8fbefed1480dcc92727e45b68f28765f65434650 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 11 Feb 2024 12:59:59 +0200 Subject: [PATCH 15/16] minor : code style normalization --- llama.cpp | 33 +++++++++++++++++---------------- 1 file changed, 17 insertions(+), 16 deletions(-) diff --git a/llama.cpp b/llama.cpp index f64571dc81538..d5c39dc5c35ab 100644 --- a/llama.cpp +++ b/llama.cpp @@ -8262,18 +8262,19 @@ struct llm_tokenizer_wpm { return words; } - std::string normalize(const std::string &text) { + std::string normalize(const std::string & text) { // TODO: handle chinese characters? 
https://github.com/huggingface/tokenizers/blob/ef5f50605ddf9f8caef1598c0e4853862b9707a7/tokenizers/src/normalizers/bert.rs#L98 std::string text2 = strip_accents(text); for (size_t i = 0; i < text2.size(); i += utf8_len(text2[i])) { char c = text2[i]; - if (c >= 'A' && c <= 'Z') + if (c >= 'A' && c <= 'Z') { text2[i] = c - 'A' + 'a'; + } } return text2; } - bool is_chinese_char(const std::string& str) { + bool is_chinese_char(const std::string & str) { int len = str.length(); unsigned int codepoint = 0; int num_bytes = 0; @@ -8302,24 +8303,24 @@ struct llm_tokenizer_wpm { } codepoint = (codepoint << 6) | (next_ch & 0x3f); } - if ((codepoint >= 0x4E00 && codepoint <= 0x9FFF) || - (codepoint >= 0x3400 && codepoint <= 0x4DBF) || + if ((codepoint >= 0x4E00 && codepoint <= 0x9FFF) || + (codepoint >= 0x3400 && codepoint <= 0x4DBF) || (codepoint >= 0x20000 && codepoint <= 0x2A6DF) || (codepoint >= 0x2A700 && codepoint <= 0x2B73F) || (codepoint >= 0x2B740 && codepoint <= 0x2B81F) || (codepoint >= 0x2B920 && codepoint <= 0x2CEAF) || // this should be 0x2B820 but in hf rust code it is 0x2B920 - (codepoint >= 0xF900 && codepoint <= 0xFAFF) || + (codepoint >= 0xF900 && codepoint <= 0xFAFF) || (codepoint >= 0x2F800 && codepoint <= 0x2FA1F) || - (codepoint >= 0x3000 && codepoint <= 0x303F) || - (codepoint >= 0xFF00 && codepoint <= 0xFFEF)) { - return true; + (codepoint >= 0x3000 && codepoint <= 0x303F) || + (codepoint >= 0xFF00 && codepoint <= 0xFFEF)) { + return true; // NOLINT } return false; } - std::string strip_accents(const std::string &inputString) { + std::string strip_accents(const std::string & input_string) { std::string resultString; - std::map accentMap = { + std::map accent_map = { {"À", 'A'}, {"Á", 'A'}, {"Â", 'A'}, {"Ã", 'A'}, {"Ä", 'A'}, {"Å", 'A'}, {"à", 'a'}, {"á", 'a'}, {"â", 'a'}, {"ã", 'a'}, {"ä", 'a'}, {"å", 'a'}, {"È", 'E'}, {"É", 'E'}, {"Ê", 'E'}, {"Ë", 'E'}, {"è", 'e'}, {"é", 'e'}, @@ -8331,11 +8332,11 @@ struct llm_tokenizer_wpm { {"Ç", 'C'}, {"ç", 'c'}, {"Ñ", 'N'}, {"ñ", 'n'}, }; - for (size_t i = 0; i < inputString.length();) { - int len = utf8_len(inputString[i]); - std::string curChar = inputString.substr(i, len); - auto iter = accentMap.find(curChar); - if (iter != accentMap.end()) { + for (size_t i = 0; i < input_string.length();) { + int len = utf8_len(input_string[i]); + std::string curChar = input_string.substr(i, len); + auto iter = accent_map.find(curChar); + if (iter != accent_map.end()) { resultString += iter->second; } else { resultString += curChar; From e379e8c10b11f20ca6a2fdeb19b45a05a9247128 Mon Sep 17 00:00:00 2001 From: Douglas Hanley Date: Sun, 11 Feb 2024 09:50:18 -0600 Subject: [PATCH 16/16] avoid use of ggml_graph_get_tensor --- llama.cpp | 22 ++++++++++++++-------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/llama.cpp b/llama.cpp index f64571dc81538..51c2264dbc125 100644 --- a/llama.cpp +++ b/llama.cpp @@ -7484,15 +7484,21 @@ static int llama_decode_internal( ggml_cgraph * gf = llama_build_graph(lctx, batch); - // get logits and embeddings - struct ggml_tensor * res = ggml_graph_get_tensor(gf, "result_output"); - struct ggml_tensor * embeddings = ggml_graph_get_tensor(gf, "result_norm"); - - // if logits are none we must be doing embeddings - if (res == nullptr) { - embeddings = ggml_graph_get_tensor(gf, "result_embed"); + // the output is always the last tensor in the graph + struct ggml_tensor * res = gf->nodes[gf->n_nodes - 1]; + struct ggml_tensor * embeddings = gf->nodes[gf->n_nodes - 2]; + if (strcmp(res->name, "result_output") 
== 0) { + // the embeddings could be the second to last tensor, or the third to last tensor + if (strcmp(embeddings->name, "result_norm") != 0) { + embeddings = gf->nodes[gf->n_nodes - 3]; + GGML_ASSERT(strcmp(embeddings->name, "result_norm") == 0); + } + } else if (strcmp(res->name, "result_embed") == 0) { + embeddings = res; + res = nullptr; + } else { + GGML_ASSERT(false); } - GGML_ASSERT(res || embeddings); // LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs);
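Some of the changes above are easier to follow outside diff context; the short sketches below restate them, with any names that do not come from the patches flagged as illustrative. The conversion-side changes in patches 05 and 06 amount to two things: WordPiece vocab entries are rewritten into the phantom-space convention before being written to GGUF, and two new metadata keys record the token-type count and the non-causal attention flag. A condensed sketch of that logic, assuming tokens/toktypes come from the HF tokenizer and writer is a gguf.GGUFWriter built against the patched gguf-py (the helper names here are illustrative, not taken from the patch):

    from typing import Sequence
    import gguf  # the patched gguf-py from this series

    # illustrative condensation of BertModel.set_vocab / set_gguf_parameters; not code from the patch itself
    def to_phantom_space(tok: bytes) -> bytes:
        if tok.startswith(b"[") and tok.endswith(b"]"):
            return tok                # special tokens like [CLS]/[SEP] pass through unchanged
        if tok.startswith(b"##"):
            return tok[2:]            # continuation pieces drop the "##" marker
        return b"\xe2\x96\x81" + tok  # word-initial pieces get the U+2581 phantom-space prefix

    def write_bert_vocab_metadata(writer: "gguf.GGUFWriter",
                                  tokens: Sequence[bytes],
                                  toktypes: Sequence[int]) -> None:
        writer.add_token_list([to_phantom_space(t) for t in tokens])
        writer.add_token_types(toktypes)
        writer.add_token_type_count(len(set(toktypes)))  # tokenizer.ggml.token_type_count
        writer.add_causal_attention(False)               # bert.attention.causal = false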
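Patch 09 re-keys the model-size heuristic on n_layer first and n_embd second, which covers more of the common embedding checkpoints. The same mapping, expressed as a small Python lookup for reference (an illustrative rendering of the C++ switch, not part of the patch):

    def guess_bert_model_size(n_layer: int, n_embd: int) -> str:
        # mirrors the switch in llm_load_hparams for LLM_ARCH_BERT
        if n_layer == 3:
            return "17M"    # bge-micro
        if n_layer == 6:
            return "22M"    # MiniLM-L6
        if n_layer == 12:
            return {384: "33M", 768: "109M"}.get(n_embd, "unknown")  # MiniLM-L12/bge-small, bge-base
        if n_layer == 24:
            return "335M"   # bge-large
        return "unknown"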
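Patches 10, 12 and 13 converge on hard-coding the token type to zero ("Sentence A"): rather than feeding an inp_type tensor through ggml_get_rows, the graph adds a 1-D view of row 0 of the token-type embedding matrix to every position, which is also why the converter keeps embeddings.token_type_embeddings.weight in F32. Numerically the input construction is the usual BERT sum with a constant segment term; a minimal numpy sketch of what the graph computes (row-major arrays here for clarity, whereas ggml stores the tensors as [n_embd, n_*]):

    import numpy as np

    # illustrative numpy sketch of the graph's input construction; not part of the patch
    def bert_input_embeddings(tok_embd, type_embd, pos_embd, token_ids):
        # tok_embd: (n_vocab, n_embd), type_embd: (n_vocab_type, n_embd), pos_embd: (n_ctx_train, n_embd)
        token_ids = np.asarray(token_ids)
        x = tok_embd[token_ids]                      # token embeddings
        x = x + type_embd[0]                         # token type hard-coded to 0 ("Sentence A")
        x = x + pos_embd[np.arange(len(token_ids))]  # learned absolute position embeddings
        return x  # the graph then applies the embedding LayerNorm (tok_norm / tok_norm_b)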
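The WordPiece pre-tokenizer touched in patches 11 and 15 normalizes text before matching: accents are folded away with a hand-written lookup table and ASCII letters are lower-cased. A loose Python approximation of that step, using unicodedata NFD decomposition instead of the explicit accent map, so it is not byte-for-byte identical to llm_tokenizer_wpm::normalize:

    import unicodedata

    # approximation of llm_tokenizer_wpm::normalize; the patch uses an explicit accent table
    def wpm_normalize(text: str) -> str:
        stripped = "".join(
            c for c in unicodedata.normalize("NFD", text)
            if not unicodedata.combining(c)
        )
        # lower-case only ASCII letters, as the C++ code does
        return "".join(chr(ord(c) + 32) if "A" <= c <= "Z" else c for c in stripped)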
