diff --git a/README.md b/README.md index ef84c199..7b41a621 100644 --- a/README.md +++ b/README.md @@ -130,10 +130,10 @@ This option would require a little more manual work, but you can use it with any ```commandline # Windows -python rwkv\convert_pytorch_to_ggml.py C:\RWKV-4-Pile-169M-20220807-8023.pth C:\rwkv.cpp-169M.bin float16 +python rwkv\convert_pytorch_to_ggml.py C:\RWKV-4-Pile-169M-20220807-8023.pth C:\rwkv.cpp-169M.bin FP16 # Linux / MacOS -python rwkv/convert_pytorch_to_ggml.py ~/Downloads/RWKV-4-Pile-169M-20220807-8023.pth ~/Downloads/rwkv.cpp-169M.bin float16 +python rwkv/convert_pytorch_to_ggml.py ~/Downloads/RWKV-4-Pile-169M-20220807-8023.pth ~/Downloads/rwkv.cpp-169M.bin FP16 ``` **Optionally**, quantize the model into one of quantized formats from the table above: diff --git a/rwkv.cpp b/rwkv.cpp index b99ac6a0..c90f07a6 100644 --- a/rwkv.cpp +++ b/rwkv.cpp @@ -174,8 +174,8 @@ bool rwkv_fwrite_data(FILE * file, const void * data, const size_t length) { #define TYPE_UNKNOWN TYPE_COUNT enum rwkv_type { - TYPE_F32, - TYPE_F16, + TYPE_FP32, + TYPE_FP16, TYPE_Q4_0, TYPE_Q4_1, TYPE_Q4_1_O, // Unsupported @@ -204,8 +204,8 @@ extern const enum ggml_type rwkv_type_to_ggml[TYPE_COUNT + 1] = { }; extern const enum rwkv_type rwkv_type_from_ggml[GGML_TYPE_COUNT + 1] = { - TYPE_F32, /* F32 */ - TYPE_F16, /* F16 */ + TYPE_FP32, /* FP32 */ + TYPE_FP16, /* FP16 */ TYPE_Q4_0, /* Q4_0 */ TYPE_Q4_1, /* Q4_1 */ TYPE_Q4_2, /* Q4_2 */ @@ -220,7 +220,7 @@ extern const enum rwkv_type rwkv_type_from_ggml[GGML_TYPE_COUNT + 1] = { TYPE_COUNT, /* COUNT */ }; -extern const char * rwkv_type_to_string[TYPE_COUNT + 1] = {"float32", "float16", "Q4_0", "Q4_1", "Q4_1_O", "Q4_2", "Q4_3", "Q5_0", "Q5_1", "Q8_0", "unknown"}; +extern const char * rwkv_type_to_string[TYPE_COUNT + 1] = {"FP32", "FP16", "Q4_0", "Q4_1", "Q4_1_O", "Q4_2", "Q4_3", "Q5_0", "Q5_1", "Q8_0", "unknown"}; enum rwkv_type rwkv_type_from_string(const char * str) { for (int ord = 0; ord < TYPE_COUNT; ord++) { @@ -363,27 +363,82 @@ bool rwkv_fread_tensor(FILE * file, struct rwkv_tensor & output, void * buffer = return true; } -bool rwkv_fread_ggml_tensor_data(FILE * file, const struct rwkv_tensor_header & header, struct ggml_context * ctx, std::string & name, struct ggml_tensor *& tensor) { +// Returns true if a tensor with the specified type, name and dimension count should be quantized to target_type. +// Returns false if the tensor should be left as-is. +bool rwkv_should_be_quantized(const ggml_type source_type, const ggml_type target_type, const std::string & name, const uint32_t dim_count) { + // Quantize only 2D tensors, except embedding and head matrices. + // Embedding and head take little space, especially in bigger models; + // but they significantly increase perplexity when quantized.
+ return (source_type == GGML_TYPE_F32 || source_type == GGML_TYPE_F16) && + target_type != GGML_TYPE_COUNT && + target_type != source_type && + ggml_is_quantized(target_type) && + dim_count == 2 && + name != "emb.weight" && + name != "head.weight"; +} + +bool rwkv_fread_ggml_tensor_data( + FILE * file, + const struct rwkv_tensor_header & header, + struct ggml_context * ctx, + std::string & name, + struct ggml_tensor *& tensor, + const ggml_type target_type +) { RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_FILE_READ, rwkv_fread_string(file, header.key_length, name), "Failed to read tensor name"); enum ggml_type ggml_type = rwkv_type_to_ggml[header.data_type]; RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_UNSUPPORTED, ggml_type != GGML_TYPE_UNKNOWN, "Unsupported tensor data type %s from %s", rwkv_type_to_string[header.data_type], name.c_str()); - tensor = header.dim_count == 1 - ? ggml_new_tensor_1d(ctx, ggml_type, header.width) - : ggml_new_tensor_2d(ctx, ggml_type, header.width, header.height); + if (rwkv_should_be_quantized(ggml_type, target_type, name, header.dim_count)) { + size_t buffer_size_bytes = header.dim_count == 1 + ? rwkv_tensor_size(ggml_type, header.width) + : rwkv_tensor_size(ggml_type, header.width, header.height); + + tensor = header.dim_count == 1 + ? ggml_new_tensor_1d(ctx, target_type, header.width) + : ggml_new_tensor_2d(ctx, target_type, header.width, header.height); + + RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_ALLOC, tensor, "Failed to allocate tensor"); + ggml_set_name(tensor, name.c_str()); + + std::unique_ptr buffer(new(std::nothrow) char[buffer_size_bytes]); + + RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_FILE_READ, rwkv_fread_data(file, buffer_size_bytes, buffer.get()), "Failed to read tensor data from %s", name.c_str()); + + // Quantization works only with FP32 values + if (header.data_type == TYPE_FP16) { + std::unique_ptr float_buffer(new(std::nothrow) char[buffer_size_bytes * 2]); + + ggml_fp16_to_fp32_row((const ggml_fp16_t *) buffer.get(), (float *) float_buffer.get(), ggml_nelements(tensor)); + + buffer.reset(float_buffer.release()); + } + + int64_t histogram[16] {}; + + ggml_quantize_chunk(target_type, (const float *) buffer.get(), tensor->data, 0, ggml_nelements(tensor), histogram); + + buffer.reset(); + } else { + tensor = header.dim_count == 1 + ? 
ggml_new_tensor_1d(ctx, ggml_type, header.width) + : ggml_new_tensor_2d(ctx, ggml_type, header.width, header.height); + + RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_ALLOC, tensor, "Failed to allocate tensor"); + ggml_set_name(tensor, name.c_str()); - RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_ALLOC, tensor, "Failed to allocate tensor"); - ggml_set_name(tensor, name.c_str()); + RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_FILE_READ, rwkv_fread_data(file, ggml_nbytes(tensor), tensor->data), "Failed to read tensor data from %s", name.c_str()); + } - RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_FILE_READ, rwkv_fread_data(file, ggml_nbytes(tensor), tensor->data), "Failed to read tensor data from %s", name.c_str()); return true; } -bool rwkv_fread_ggml_tensor(FILE * file, struct ggml_context * ctx, std::string & name, struct ggml_tensor *& tensor) { +bool rwkv_fread_ggml_tensor(FILE * file, struct ggml_context * ctx, std::string & name, struct ggml_tensor *& tensor, const ggml_type target_type) { struct rwkv_tensor_header header; RWKV_ENSURE_OR_FALSE_MSG(rwkv_fread_tensor_header(file, header), "Invalid tensor header"); - return rwkv_fread_ggml_tensor_data(file, header, ctx, name, tensor); + return rwkv_fread_ggml_tensor_data(file, header, ctx, name, tensor, target_type); } bool rwkv_fwrite_tensor(FILE * file, const struct rwkv_tensor & tensor) { @@ -429,7 +484,7 @@ struct rwkv_model { struct ggml_tensor * ln0_weight; struct ggml_tensor * ln0_bias; - std::unique_ptr layers; + std::unique_ptr layers; struct ggml_tensor * ln_out_weight; struct ggml_tensor * ln_out_bias; @@ -580,9 +635,9 @@ struct rwkv_context { // Reused by all graphs. struct rwkv_ggml_context ctx; struct ggml_tensor * input_state; - std::unique_ptr input_layers; + std::unique_ptr input_layers; struct ggml_tensor * output_state; - std::unique_ptr output_layers; + std::unique_ptr output_layers; struct ggml_tensor * logits; uint32_t n_threads; @@ -610,7 +665,7 @@ bool rwkv_set_params(struct rwkv_model & model, F callback) { RWKV_ENSURE_OR_FALSE(callback("blocks.0.ln0.bias", model.ln0_bias)); uint32_t n_layer = model.header.n_layer; - std::unique_ptr layers(new(std::nothrow) struct rwkv_layer [n_layer]); + std::unique_ptr layers(new(std::nothrow) struct rwkv_layer[n_layer]); RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_ALLOC, layers.get(), "Failed to allocate model layers"); model.layers = std::move(layers); @@ -1037,7 +1092,7 @@ bool rwkv_build_sequence_graph( struct ggml_tensor * x = ggml_get_rows(ctx, model.emb, tokens); x = rwkv_layer_norm(ctx, x, ggml_repeat(ctx, model.ln0_weight, x), ggml_repeat(ctx, model.ln0_bias, x)); - + for (size_t i = 0; i < model.header.n_layer; i++) { struct rwkv_layer & layer = model.layers[i]; struct rwkv_layer_state state = inputs[i]; @@ -1115,7 +1170,7 @@ struct rwkv_file { } }; -bool rwkv_instance_from_file(const char * file_path, struct rwkv_instance & instance) { +bool rwkv_instance_from_file(const char * file_path, struct rwkv_instance & instance, const ggml_type target_type) { struct stat file_stat; struct rwkv_model model; struct rwkv_ggml_context ctx; @@ -1140,7 +1195,13 @@ bool rwkv_instance_from_file(const char * file_path, struct rwkv_instance & inst RWKV_ASSERT_NULL_MSG(RWKV_ERROR_MODEL_PARAMS, rwkv_fread_string(file.file, tensor_header.key_length, name), "Failed to read tensor name"); RWKV_ASSERT_NULL_MSG(RWKV_ERROR_FILE | RWKV_ERROR_FILE_READ, fseek(file.file, rwkv_tensor_size(tensor_header), SEEK_CUR) == 0, "Failed to read tensor data"); - rwkv_ctx_size_add_tensor(ctx_size, 1, 0, tensor_header); + enum ggml_type source_type = 
rwkv_type_to_ggml[tensor_header.data_type]; + + enum ggml_type in_memory_type = rwkv_should_be_quantized(source_type, target_type, name, tensor_header.dim_count) + ? target_type + : source_type; + + rwkv_ctx_size_add_tensor(ctx_size, 1, 0, in_memory_type, tensor_header.width, tensor_header.height); if (ffn_key_size == 0 && name == "blocks.0.ffn.key.weight") { ffn_key_size = tensor_header.height; @@ -1156,13 +1217,13 @@ bool rwkv_instance_from_file(const char * file_path, struct rwkv_instance & inst struct ggml_tensor * tensor; while ((size_t) ftell(file.file) < (size_t) file_stat.st_size) { - RWKV_ASSERT_NULL_MSG(RWKV_ERROR_MODEL_PARAMS, rwkv_fread_ggml_tensor(file.file, ctx.ctx, name, tensor), "Failed to read model params"); + RWKV_ASSERT_NULL_MSG(RWKV_ERROR_MODEL_PARAMS, rwkv_fread_ggml_tensor(file.file, ctx.ctx, name, tensor, target_type), "Failed to read model params"); parameters[std::move(name)] = tensor; } } std::unordered_map & parameters_ref = parameters; - RWKV_ASSERT_NULL(RWKV_ERROR_MODEL_PARAMS | RWKV_ERROR_PARAM_MISSING, rwkv_set_params(model, [&](const char * key, struct ggml_tensor *& dest) { + RWKV_ASSERT_NULL(RWKV_ERROR_MODEL_PARAMS | RWKV_ERROR_PARAM_MISSING, rwkv_set_params(model,[&](const char * key, struct ggml_tensor *& dest) { struct ggml_tensor * tensor = parameters_ref[key]; RWKV_ENSURE_OR_FALSE_MSG(tensor, "Model parameter %s not found", key); dest = tensor; @@ -1203,11 +1264,11 @@ struct rwkv_context * rwkv_new_context_impl(std::shared_ptr inputs(new(std::nothrow) struct rwkv_layer_state [n_layer]); + std::unique_ptr inputs(new(std::nothrow) struct rwkv_layer_state[n_layer]); RWKV_ASSERT_NULL_MSG(RWKV_ERROR_ALLOC, inputs.get(), "Failed to allocate input state parts"); // We collect parts of output state here. Each part is (n_embed) vector. 
- std::unique_ptr outputs(new(std::nothrow) struct rwkv_layer_state [n_layer]); + std::unique_ptr outputs(new(std::nothrow) struct rwkv_layer_state[n_layer]); RWKV_ASSERT_NULL_MSG(RWKV_ERROR_ALLOC, outputs.get(), "Failed to allocate output state parts"); for (size_t i = 0; i < n_layer; i++) { @@ -1258,12 +1319,45 @@ struct rwkv_context * rwkv_new_context_impl(std::shared_ptr instance(new(std::nothrow) struct rwkv_instance()); RWKV_ASSERT_NULL_MSG(RWKV_ERROR_CTX | RWKV_ERROR_ALLOC, instance, "Failed to allocate instance"); - RWKV_ENSURE_OR_NULL(rwkv_instance_from_file(file_path, *instance.get())); + RWKV_ENSURE_OR_NULL(rwkv_instance_from_file(file_path, *instance.get(), target_type)); return rwkv_new_context_impl(instance, n_threads); } @@ -1438,10 +1532,10 @@ void rwkv_free(struct rwkv_context * ctx) { std::unique_ptr rwkv_ctx(ctx); } -bool rwkv_quantize_model_file(const char * in_path, const char * out_path, const char * type_name) { +bool rwkv_quantize_model_file(const char * in_path, const char * out_path, const char * format_name) { global_last_error = RWKV_ERROR_NONE; - enum ggml_type out_type = rwkv_type_to_ggml[rwkv_type_from_string(type_name)]; + enum ggml_type out_type = rwkv_type_to_ggml[rwkv_type_from_string(format_name)]; RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_ARGS | RWKV_ERROR_DATA_TYPE, ggml_is_quantized(out_type), "Unsupported output data type (%s)", rwkv_type_to_string[rwkv_type_from_ggml[out_type]]); RWKV_MSG("Loading model from '%s'\n", in_path); @@ -1464,7 +1558,7 @@ bool rwkv_quantize_model_file(const char * in_path, const char * out_path, const RWKV_ASSERT_FALSE_MSG( RWKV_ERROR_FILE, in_type == GGML_TYPE_F32 || in_type == GGML_TYPE_F16, - "Unsupported input data type (%s); needs to be F32 or F16", + "Unsupported input data type (%s); needs to be FP32 or FP16", rwkv_type_to_string[rwkv_type_from_ggml[in_type]] ); @@ -1477,7 +1571,7 @@ bool rwkv_quantize_model_file(const char * in_path, const char * out_path, const size_t orig_total_size = 0; size_t new_total_size = 0; - // Required to init the fp16 tables + // Required to init the F16 tables // Doesn't crash if ggml_init fails ggml_free(ggml_init({ 0, NULL, true })); @@ -1496,7 +1590,7 @@ bool rwkv_quantize_model_file(const char * in_path, const char * out_path, const } // f16 type tensors get relocated to out and then converted into f32 at in - if (header.data_type == TYPE_F16) { + if (header.data_type == TYPE_FP16) { if (in_size > max_out_size) { max_out_size = in_size; } @@ -1524,7 +1618,7 @@ bool rwkv_quantize_model_file(const char * in_path, const char * out_path, const // This is a histogram of quantized values. If it shows single 1.0, then all 0.0, something went very wrong! int64_t hist_all[16] {}; - std::unique_ptr scratch(new(std::nothrow) uint8_t [max_in_size + max_out_size]); + std::unique_ptr scratch(new(std::nothrow) uint8_t[max_in_size + max_out_size]); RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_ALLOC, scratch.get(), "Failed to allocate buffer"); uint8_t * in_buf = scratch.get(); @@ -1542,19 +1636,16 @@ bool rwkv_quantize_model_file(const char * in_path, const char * out_path, const const char * name_str = name.c_str(); RWKV_MSG("%*s - [%5" PRId32 ", %5" PRId32 "], type = %6s ", (int) max_key_length, name_str, header.width, header.height, rwkv_type_to_string[header.data_type]); - data = header.data_type == TYPE_F16 ? out_buf : in_buf; + data = header.data_type == TYPE_FP16 ? 
out_buf : in_buf; size_t orig_size = rwkv_tensor_size(header), new_size = orig_size; RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_MODEL_PARAMS, rwkv_fread_data(in_file.file, orig_size, data), "\nFailed to read tensor data of %s", name_str); - // Quantize only 2D tensors, except embedding and head matrices. - // Embedding and head take not too much space, especially in bigger models; - // but they significantly increase perplexity when quantized. - if ((header.data_type == TYPE_F32 || header.data_type == TYPE_F16) && header.dim_count == 2 && name != "emb.weight" && name != "head.weight") { + if (rwkv_should_be_quantized(rwkv_type_to_ggml[header.data_type], out_type, name, header.dim_count)) { RWKV_MSG("quantizing... "); size_t nelements = (size_t) header.width * (size_t) header.height; - if (header.data_type == TYPE_F16) { + if (header.data_type == TYPE_FP16) { ggml_fp16_to_fp32_row((const ggml_fp16_t *) out_buf, (float *) in_buf, nelements); } diff --git a/rwkv.h b/rwkv.h index 8327425e..3ea021e4 100644 --- a/rwkv.h +++ b/rwkv.h @@ -83,10 +83,53 @@ extern "C" { // - ctx: the context the retrieve the error for, or NULL for the global error. RWKV_API enum rwkv_error_flags rwkv_get_last_error(struct rwkv_context * ctx); + enum rwkv_init_from_file_option_key { + // Sets the target format of model parameters. + // + // If an FP16 or FP32 model is being loaded, and this option is set, + // parameters will be quantized just-in-time into the specified format. + // If an already quantized model is being loaded, the value of this option is ignored. + // The function will not read the whole model file at once, but will do quantization tensor-by-tensor; + // it is safe to load big models that will fit into RAM once quantized. + // Use of this option will introduce a significant one-time delay when loading the model. + // + // The intended use-case is to have only an FP16 model on disk, while not wasting + // disk space on copies of the model in all available quantized formats. + // + // Allowed values: + // - Q4_0 + // - Q4_1 + // - Q5_0 + // - Q5_1 + // - Q8_0 + RWKV_INIT_FROM_FILE_OPTION_TARGET_FORMAT_NAME, + // Do not use this as an actual option key. + RWKV_INIT_FROM_FILE_OPTION_COUNT + }; + + struct rwkv_init_from_file_option { + // Key of the option. + enum rwkv_init_from_file_option_key key; + // Value of the option as a NULL-terminated, UTF-8 encoded string. + char * value; + }; + // Loads the model from a file and prepares it for inference. + // Loading behavior can be customized with options, but none of them are required. + // Function behavior when multiple options with the same key are specified is undefined. // Returns NULL on any error. // - model_file_path: path to model file in ggml format. // - n_threads: count of threads to use, must be positive. + // - options: array of options. Passing NULL is the same as setting option_count to 0. + // - option_count: size of the options array. + RWKV_API struct rwkv_context * rwkv_init_from_file_ex( + const char * model_file_path, + const uint32_t n_threads, + const struct rwkv_init_from_file_option * options, + const size_t option_count + ); + + // Same as rwkv_init_from_file_ex, but passing an empty array of options. RWKV_API struct rwkv_context * rwkv_init_from_file(const char * model_file_path, const uint32_t n_threads); // Creates a new context from an existing one.
diff --git a/rwkv/convert_pytorch_to_ggml.py b/rwkv/convert_pytorch_to_ggml.py index 2ea4a48d..18debea4 100644 --- a/rwkv/convert_pytorch_to_ggml.py +++ b/rwkv/convert_pytorch_to_ggml.py @@ -1,5 +1,5 @@ # Converts an RWKV model checkpoint in PyTorch format to an rwkv.cpp compatible file. -# Usage: python convert_pytorch_to_ggml.py C:\RWKV-4-Pile-169M-20220807-8023.pth C:\rwkv.cpp-169M.bin float32 +# Usage: python convert_pytorch_to_ggml.py C:\RWKV-4-Pile-169M-20220807-8023.pth C:\rwkv.cpp-169M.bin FP16 # Get model checkpoints from https://huggingface.co/BlinkDL # See FILE_FORMAT.md for the documentation on the file format. @@ -12,7 +12,7 @@ def parse_args(): parser = argparse.ArgumentParser(description='Convert an RWKV model checkpoint in PyTorch format to an rwkv.cpp compatible file') parser.add_argument('src_path', help='Path to PyTorch checkpoint file') parser.add_argument('dest_path', help='Path to rwkv.cpp checkpoint file, will be overwritten') - parser.add_argument('data_type', help='Data type, float16 or float32', type=str, choices=['float16', 'float32'], default='float32') + parser.add_argument('data_type', help='Data type, FP16 or FP32', type=str, choices=['FP16', 'FP32'], default='FP16') return parser.parse_args() def get_layer_count(state_dict: Dict[str, torch.Tensor]) -> int: @@ -26,6 +26,8 @@ def get_layer_count(state_dict: Dict[str, torch.Tensor]) -> int: return n_layer def write_state_dict(state_dict: Dict[str, torch.Tensor], dest_path: str, data_type: str) -> None: + is_FP16 = data_type == 'FP16' or data_type == 'float16' + emb_weight: torch.Tensor = state_dict['emb.weight'] n_layer = get_layer_count(state_dict) @@ -42,7 +44,7 @@ def write_state_dict(state_dict: Dict[str, torch.Tensor], dest_path: str, data_t n_vocab, n_embed, n_layer, - 1 if data_type == 'float16' else 0 + 1 if is_FP16 else 0 )) for k in state_dict.keys(): @@ -56,8 +58,8 @@ def write_state_dict(state_dict: Dict[str, torch.Tensor], dest_path: str, data_t if '.time_decay' in k: tensor = -torch.exp(tensor) - # Keep 1-dim vectors in fp32 - if data_type == 'float16' and len(tensor.shape) > 1: + # Keep 1-dim vectors in FP32 + if is_FP16 and len(tensor.shape) > 1: tensor = tensor.half() shape = tensor.shape diff --git a/rwkv/convert_pytorch_to_ggml.test.py b/rwkv/convert_pytorch_to_ggml.test.py index 9ced1d05..501a85ef 100644 --- a/rwkv/convert_pytorch_to_ggml.test.py +++ b/rwkv/convert_pytorch_to_ggml.test.py @@ -13,7 +13,7 @@ def test() -> None: 'blocks.0.ln1.weight': torch.tensor([1], dtype=torch.float32) } - convert_pytorch_to_ggml.write_state_dict(state_dict, dest_path=test_file_path, data_type='float32') + convert_pytorch_to_ggml.write_state_dict(state_dict, dest_path=test_file_path, data_type='FP32') with open(test_file_path, 'rb') as input: actual_bytes: bytes = input.read() diff --git a/rwkv/rwkv_cpp_model.py b/rwkv/rwkv_cpp_model.py index b38c7ce2..fd775731 100644 --- a/rwkv/rwkv_cpp_model.py +++ b/rwkv/rwkv_cpp_model.py @@ -2,7 +2,7 @@ import torch import multiprocessing import rwkv_cpp_shared_library -from typing import Tuple, Optional +from typing import Dict, Tuple, Optional class RWKVModel: """ @@ -15,6 +15,7 @@ def __init__( model_path: str, thread_count: int = max(1, multiprocessing.cpu_count() // 2), gpu_layers_count: int = 0, + options: Optional[Dict[rwkv_cpp_shared_library.RWKVInitFromFileOptionKey, str]] = None ): """ Loads the model and prepares it for inference. @@ -28,6 +29,8 @@ def __init__( Path to RWKV model file in ggml format. thread_count : int Thread count to use. 
If not set, defaults to CPU count / 2. + options : Optional[Dict[RWKVInitFromFileOptionKey, str]] + Options passed to rwkv_init_from_file_ex. """ assert os.path.isfile(model_path), f'{model_path} is not a file' @@ -36,7 +39,7 @@ def __init__( self._library = shared_library - self._ctx = self._library.rwkv_init_from_file(model_path, thread_count) + self._ctx = self._library.rwkv_init_from_file(model_path, thread_count, options) if gpu_layers_count > 0: self._library.rwkv_gpu_offload_layers(self._ctx, gpu_layers_count) diff --git a/rwkv/rwkv_cpp_shared_library.py b/rwkv/rwkv_cpp_shared_library.py index a38cbbb2..562bf7a1 100644 --- a/rwkv/rwkv_cpp_shared_library.py +++ b/rwkv/rwkv_cpp_shared_library.py @@ -2,7 +2,8 @@ import sys import ctypes import pathlib -from typing import Optional +import enum +from typing import Dict, Optional QUANTIZED_FORMAT_NAMES = ( 'Q4_0', @@ -14,6 +15,29 @@ P_FLOAT = ctypes.POINTER(ctypes.c_float) +class RWKVInitFromFileOptionKey(enum.Enum): + # Sets the target format of model parameters. + # + # If an FP16 or FP32 model is being loaded, and this option is set, + # parameters will be quantized just-in-time into the specified format. + # If an already quantized model is being loaded, the value of this option is ignored. + # The function will not read the whole model file at once, but will do quantization tensor-by-tensor; + # it is safe to load big models that will fit into RAM once quantized. + # Use of this option will introduce a significant one-time delay when loading the model. + # + # The intended use-case is to have only an FP16 model on disk, while not wasting + # disk space on copies of the model in all available quantized formats. + # + # For allowed values, see QUANTIZED_FORMAT_NAMES. + TARGET_FORMAT_NAME = 0 + +class RWKVInitFromFileOption(ctypes.Structure): + + _fields_ = [ + ('key', ctypes.c_int), + ('value', ctypes.c_char_p) + ] + class RWKVContext: def __init__(self, ptr: ctypes.pointer): @@ -37,8 +61,8 @@ def __init__(self, shared_library_path: str): self.library = ctypes.cdll.LoadLibrary(shared_library_path) - self.library.rwkv_init_from_file.argtypes = [ctypes.c_char_p, ctypes.c_uint32] - self.library.rwkv_init_from_file.restype = ctypes.c_void_p + self.library.rwkv_init_from_file_ex.argtypes = [ctypes.c_char_p, ctypes.c_uint32, ctypes.POINTER(RWKVInitFromFileOption), ctypes.c_size_t] + self.library.rwkv_init_from_file_ex.restype = ctypes.c_void_p self.library.rwkv_gpu_offload_layers.argtypes = [ctypes.c_void_p, ctypes.c_uint32] self.library.rwkv_gpu_offload_layers.restype = ctypes.c_bool @@ -70,9 +94,10 @@ def __init__(self, shared_library_path: str): self.library.rwkv_get_system_info_string.argtypes = [] self.library.rwkv_get_system_info_string.restype = ctypes.c_char_p - def rwkv_init_from_file(self, model_file_path: str, thread_count: int) -> RWKVContext: + def rwkv_init_from_file(self, model_file_path: str, thread_count: int, options: Optional[Dict[RWKVInitFromFileOptionKey, str]] = None) -> RWKVContext: """ Loads the model from a file and prepares it for inference. + Loading behavior can be customized with options, but none of them are required. Throws an exception in case of any error. Error messages would be printed to stderr. Parameters ---------- model_file_path : str Path to model file in ggml format. thread_count : int Count of threads to use, must be positive. + options : Optional[Dict[RWKVInitFromFileOptionKey, str]] + Options passed to rwkv_init_from_file_ex. 
""" - ptr = self.library.rwkv_init_from_file(model_file_path.encode('utf-8'), ctypes.c_uint32(thread_count)) + options_count = 0 + options_ptr = None + + if options is not None and len(options) > 0: + options_count = len(options) + options_ptr = (RWKVInitFromFileOption * options_count)() + + i = 0 + for k, v in options.items(): + options_ptr[i].key = k.value + options_ptr[i].value = v.encode('utf-8') + + i += 1 + + ptr = self.library.rwkv_init_from_file_ex(model_file_path.encode('utf-8'), ctypes.c_uint32(thread_count), options_ptr, options_count) assert ptr is not None, 'rwkv_init_from_file failed, check stderr' diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index d176f7bd..4090c9b7 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -12,6 +12,7 @@ file(COPY tiny-rwkv-660K-FP32.bin DESTINATION ${CMAKE_CURRENT_BINARY_DIR}) file(COPY tiny-rwkv-660K-FP16.bin DESTINATION ${CMAKE_CURRENT_BINARY_DIR}) file(COPY expected_logits.bin DESTINATION ${CMAKE_CURRENT_BINARY_DIR}) -rwkv_add_test(test_ggml_basics.c) -rwkv_add_test(test_tiny_rwkv.c) -rwkv_add_test(test_context_cloning.c) +file(GLOB tests *.c) +foreach (test ${tests}) + rwkv_add_test(${test}) +endforeach() diff --git a/tests/test_context_cloning.c b/tests/test_context_cloning.c index eb0f7c4c..6632475b 100644 --- a/tests/test_context_cloning.c +++ b/tests/test_context_cloning.c @@ -1,4 +1,6 @@ -#include +// Tests that evaluation after context cloning gives identical results. + +#include "rwkv.h" #include #include diff --git a/tests/test_quantization_on_the_fly.c b/tests/test_quantization_on_the_fly.c new file mode 100644 index 00000000..d63ba8b0 --- /dev/null +++ b/tests/test_quantization_on_the_fly.c @@ -0,0 +1,91 @@ +// Tests that results from an on-the-fly quantized model are identical to results from a pre-quantized model.
+ +#include "ggml.h" +#include "rwkv.h" + +#include <stdio.h> +#include <stdlib.h> +#include <string.h> + +#define N_THREADS 2 + +int main(void) { + rwkv_quantize_model_file("tiny-rwkv-660K-FP32.bin", "tiny-rwkv-660K-FP32-Q5_1.bin", "Q5_1"); + + struct rwkv_context * prequantized_ctx = rwkv_init_from_file("tiny-rwkv-660K-FP32-Q5_1.bin", N_THREADS); + + if (!prequantized_ctx) { + enum rwkv_error_flags error = rwkv_get_last_error(NULL); + fprintf(stderr, "Unexpected error 0x%.8X\n", error); + return EXIT_FAILURE; + } + + // --- + + struct rwkv_init_from_file_option option = {RWKV_INIT_FROM_FILE_OPTION_TARGET_FORMAT_NAME, "Q5_1"}; + + struct rwkv_context * on_the_fly_quantized_ctx = rwkv_init_from_file_ex("tiny-rwkv-660K-FP32.bin", N_THREADS, &option, 1); + + if (!on_the_fly_quantized_ctx) { + enum rwkv_error_flags error = rwkv_get_last_error(NULL); + fprintf(stderr, "Unexpected error 0x%.8X\n", error); + return EXIT_FAILURE; + } + + // --- + + float * state = calloc(rwkv_get_state_len(prequantized_ctx), sizeof(float)); + + if (!state) { + fprintf(stderr, "Failed to allocate state\n"); + return EXIT_FAILURE; + } + + float * expected_logits = calloc(rwkv_get_logits_len(prequantized_ctx), sizeof(float)); + + if (!expected_logits) { + fprintf(stderr, "Failed to allocate logits\n"); + return EXIT_FAILURE; + } + + const unsigned char prompt[12] = "hello world"; + + rwkv_eval(prequantized_ctx, prompt[0], NULL, state, expected_logits); + + for (int i = 1; prompt[i] != 0; i++) { + rwkv_eval(prequantized_ctx, prompt[i], state, state, expected_logits); + } + + // --- + + float * actual_logits = calloc(rwkv_get_logits_len(on_the_fly_quantized_ctx), sizeof(float)); + + if (!actual_logits) { + fprintf(stderr, "Failed to allocate logits\n"); + return EXIT_FAILURE; + } + + rwkv_eval(on_the_fly_quantized_ctx, prompt[0], NULL, state, actual_logits); + + for (int i = 1; prompt[i] != 0; i++) { + rwkv_eval(on_the_fly_quantized_ctx, prompt[i], state, state, actual_logits); + } + + // --- + + if (memcmp(expected_logits, actual_logits, rwkv_get_logits_len(on_the_fly_quantized_ctx) * sizeof(float))) { + fprintf(stderr, "Results not identical :(\n"); + return EXIT_FAILURE; + } else { + fprintf(stdout, "Results identical, success!\n"); + } + + rwkv_free(on_the_fly_quantized_ctx); + rwkv_free(prequantized_ctx); + + free(expected_logits); + free(actual_logits); + free(state); + + return 0; +}
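
Below is a minimal usage sketch (not part of the diff itself) of the new on-the-fly quantization option as seen from the Python wrapper. The model path and thread count are placeholders, and `load_rwkv_shared_library()` is assumed to be the existing helper in `rwkv_cpp_shared_library.py`, which is not shown in this diff:

```python
# Hypothetical usage sketch; file names and thread count are placeholders.
import rwkv_cpp_model
import rwkv_cpp_shared_library

# Assumed existing helper that locates and loads the compiled rwkv.cpp shared library.
library = rwkv_cpp_shared_library.load_rwkv_shared_library()

# Keep only the FP16 model on disk and quantize it to Q5_1 while loading,
# instead of storing a separate pre-quantized copy.
model = rwkv_cpp_model.RWKVModel(
    library,
    'rwkv.cpp-169M.bin',  # placeholder path to an FP16 model file
    thread_count=4,
    options={
        rwkv_cpp_shared_library.RWKVInitFromFileOptionKey.TARGET_FORMAT_NAME: 'Q5_1'
    }
)
```

Since quantization happens tensor-by-tensor during loading, this trades a one-time startup delay for not having to keep quantized copies of the model on disk.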
