
Commit bd1eaef

fix: convert f64 to f32 and i64 to i32 when loading weights
1 parent ab835f7 commit bd1eaef
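
This commit teaches the weight loader to accept 64-bit safetensors payloads by narrowing them at load time: F64 data becomes F32 and I64 becomes I32, since ggml inference has no need for the extra width. As a hedged, standalone sketch (not code from this commit), the two casts involved behave like this:

// Standalone illustration of the narrowing conversions; not repository code.
#include <cstdint>
#include <cstdio>

int main() {
    double  f64_w[2] = {0.1, 3.141592653589793};
    int64_t i64_w[2] = {42, 5000000000LL};  // second value exceeds INT32_MAX

    float   f32_w[2];
    int32_t i32_w[2];
    for (int i = 0; i < 2; i++) {
        f32_w[i] = (float)f64_w[i];    // rounds to nearest float: small precision loss
        i32_w[i] = (int32_t)i64_w[i];  // values outside the int32 range wrap around
    }
    printf("%.9f %.9f\n", f32_w[0], f32_w[1]);
    printf("%d %d\n", i32_w[0], i32_w[1]);  // 42, then a wrapped value
    return 0;
}

For weights this narrowing is normally harmless; the wrap-around case only matters for integer tensors that genuinely hold values beyond 32 bits.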

File tree

model.cpp
model.h

2 files changed: +57 -5 lines changed


model.cpp

Lines changed: 49 additions & 5 deletions
@@ -815,13 +815,28 @@ void f8_e4m3_to_f16_vec(uint8_t* src, uint16_t* dst, int64_t n) {
         dst[i] = f8_e4m3_to_f16(src[i]);
     }
 }
+
 void f8_e5m2_to_f16_vec(uint8_t* src, uint16_t* dst, int64_t n) {
     // support inplace op
     for (int64_t i = n - 1; i >= 0; i--) {
         dst[i] = f8_e5m2_to_f16(src[i]);
     }
 }

+void f64_to_f32_vec(double* src, float* dst, int64_t n) {
+    // support inplace op
+    for (int64_t i = 0; i < n; i++) {
+        dst[i] = (float)src[i];
+    }
+}
+
+void i64_to_i32_vec(int64_t* src, int32_t* dst, int64_t n) {
+    // support inplace op
+    for (int64_t i = 0; i < n; i++) {
+        dst[i] = (int32_t)src[i];
+    }
+}
+
 void convert_tensor(void* src,
                     ggml_type src_type,
                     void* dst,
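
One subtlety in this hunk: the new f64/i64 loops iterate forward while the f8 loops above them iterate backward, although all four carry the same "support inplace op" comment. Both directions are correct for their case. When narrowing in place, dst[i] occupies bytes [4i, 4i+4) of the shared buffer while src[i] occupies [8i, 8i+8), so a front-to-back pass only overwrites bytes it has already read; widening f8 to f16 is the mirror image and must run back to front. A hedged standalone demonstration (not repository code):

#include <cstdint>
#include <cstdio>
#include <cstring>

// Same shape as the commit's helper: narrows doubles to floats,
// tolerating dst aliasing the front of src.
void f64_to_f32_vec(double* src, float* dst, int64_t n) {
    for (int64_t i = 0; i < n; i++) {
        dst[i] = (float)src[i];
    }
}

int main() {
    double buf[4] = {1.5, -2.25, 3.0, 4.75};
    f64_to_f32_vec(buf, (float*)buf, 4);  // one allocation serves as input and output
    float out[4];
    memcpy(out, buf, sizeof(out));        // copy out to avoid aliased reads
    printf("%g %g %g %g\n", out[0], out[1], out[2], out[3]);  // 1.5 -2.25 3 4.75
    return 0;
}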
@@ -1057,13 +1072,13 @@ ggml_type str_to_ggml_type(const std::string& dtype) {
     } else if (dtype == "F32") {
         ttype = GGML_TYPE_F32;
     } else if (dtype == "F64") {
-        ttype = GGML_TYPE_F64;
+        ttype = GGML_TYPE_F32;
     } else if (dtype == "F8_E4M3") {
         ttype = GGML_TYPE_F16;
     } else if (dtype == "F8_E5M2") {
         ttype = GGML_TYPE_F16;
     } else if (dtype == "I64") {
-        ttype = GGML_TYPE_I64;
+        ttype = GGML_TYPE_I32;
     }
     return ttype;
 }
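
After this change the 64-bit ggml types never escape the loader: an F64 or I64 dtype resolves directly to the 32-bit type the tensor will hold once converted, so every downstream allocation keyed on tensor_storage.type is sized for the post-conversion data. A hedged sanity check (assuming the repository's headers):

GGML_ASSERT(str_to_ggml_type("F64") == GGML_TYPE_F32);
GGML_ASSERT(str_to_ggml_type("I64") == GGML_TYPE_I32);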
@@ -1185,6 +1200,14 @@ bool ModelLoader::init_from_safetensors_file(const std::string& file_path, const
             tensor_storage.is_f8_e5m2 = true;
             // f8 -> f16
             GGML_ASSERT(tensor_storage.nbytes() == tensor_data_size * 2);
+        } else if (dtype == "F64") {
+            tensor_storage.is_f64 = true;
+            // f64 -> f32
+            GGML_ASSERT(tensor_storage.nbytes() * 2 == tensor_data_size);
+        } else if (dtype == "I64") {
+            tensor_storage.is_i64 = true;
+            // i64 -> i32
+            GGML_ASSERT(tensor_storage.nbytes() * 2 == tensor_data_size);
         } else {
             GGML_ASSERT(tensor_storage.nbytes() == tensor_data_size);
         }
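
Concrete numbers make the flipped assertion clear: a 1,000-element I64 tensor occupies tensor_data_size = 8,000 bytes in the file, while tensor_storage.nbytes() is computed from the resolved GGML_TYPE_I32 and is therefore 4,000 bytes, so the correct invariant is nbytes() * 2 == tensor_data_size. The f8 branches above are the mirror case: 1,000 bytes on disk widen to 2,000 bytes of f16, hence nbytes() == tensor_data_size * 2.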
@@ -1945,7 +1968,12 @@ bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb, ggml_backend
             // for the CPU and Metal backend, we can copy directly into the tensor
             if (tensor_storage.type == dst_tensor->type) {
                 GGML_ASSERT(ggml_nbytes(dst_tensor) == tensor_storage.nbytes());
-                read_data(tensor_storage, (char*)dst_tensor->data, nbytes_to_read);
+                if (tensor_storage.is_f64 || tensor_storage.is_i64) {
+                    read_buffer.resize(tensor_storage.nbytes_to_read());
+                    read_data(tensor_storage, (char*)read_buffer.data(), nbytes_to_read);
+                } else {
+                    read_data(tensor_storage, (char*)dst_tensor->data, nbytes_to_read);
+                }

                 if (tensor_storage.is_bf16) {
                     // inplace op
@@ -1956,9 +1984,13 @@ bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb, ggml_backend
                 } else if (tensor_storage.is_f8_e5m2) {
                     // inplace op
                     f8_e5m2_to_f16_vec((uint8_t*)dst_tensor->data, (uint16_t*)dst_tensor->data, tensor_storage.nelements());
+                } else if (tensor_storage.is_f64) {
+                    f64_to_f32_vec((double*)read_buffer.data(), (float*)dst_tensor->data, tensor_storage.nelements());
+                } else if (tensor_storage.is_i64) {
+                    i64_to_i32_vec((int64_t*)read_buffer.data(), (int32_t*)dst_tensor->data, tensor_storage.nelements());
                 }
             } else {
-                read_buffer.resize(tensor_storage.nbytes());
+                read_buffer.resize(std::max(tensor_storage.nbytes(), tensor_storage.nbytes_to_read()));
                 read_data(tensor_storage, (char*)read_buffer.data(), nbytes_to_read);

                 if (tensor_storage.is_bf16) {
@@ -1970,13 +2002,19 @@ bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb, ggml_backend
                 } else if (tensor_storage.is_f8_e5m2) {
                     // inplace op
                     f8_e5m2_to_f16_vec((uint8_t*)read_buffer.data(), (uint16_t*)read_buffer.data(), tensor_storage.nelements());
+                } else if (tensor_storage.is_f64) {
+                    // inplace op
+                    f64_to_f32_vec((double*)read_buffer.data(), (float*)read_buffer.data(), tensor_storage.nelements());
+                } else if (tensor_storage.is_i64) {
+                    // inplace op
+                    i64_to_i32_vec((int64_t*)read_buffer.data(), (int32_t*)read_buffer.data(), tensor_storage.nelements());
                 }

                 convert_tensor((void*)read_buffer.data(), tensor_storage.type, dst_tensor->data,
                                dst_tensor->type, (int)tensor_storage.nelements() / (int)tensor_storage.ne[0], (int)tensor_storage.ne[0]);
             }
         } else {
-            read_buffer.resize(tensor_storage.nbytes());
+            read_buffer.resize(std::max(tensor_storage.nbytes(), tensor_storage.nbytes_to_read()));
             read_data(tensor_storage, (char*)read_buffer.data(), nbytes_to_read);

             if (tensor_storage.is_bf16) {
@@ -1988,6 +2026,12 @@ bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb, ggml_backend
             } else if (tensor_storage.is_f8_e5m2) {
                 // inplace op
                 f8_e5m2_to_f16_vec((uint8_t*)read_buffer.data(), (uint16_t*)read_buffer.data(), tensor_storage.nelements());
+            } else if (tensor_storage.is_f64) {
+                // inplace op
+                f64_to_f32_vec((double*)read_buffer.data(), (float*)read_buffer.data(), tensor_storage.nelements());
+            } else if (tensor_storage.is_i64) {
+                // inplace op
+                i64_to_i32_vec((int64_t*)read_buffer.data(), (int32_t*)read_buffer.data(), tensor_storage.nelements());
             }

             if (tensor_storage.type == dst_tensor->type) {
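
Two patterns recur through these load_tensors hunks. First, the direct-copy fast path can no longer stream f64/i64 bytes straight into dst_tensor->data, because the raw payload is twice the size of the destination tensor; those dtypes are now staged in read_buffer and narrowed into the tensor afterwards. Second, every read_buffer.resize() takes std::max(nbytes(), nbytes_to_read()): for bf16/f8 sources the raw read is smaller than the converted tensor, while for f64/i64 it is larger, and the scratch buffer must hold whichever is bigger before the in-place conversion runs.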

model.h

Lines changed: 8 additions & 0 deletions
@@ -102,6 +102,8 @@ struct TensorStorage {
     bool is_bf16 = false;
     bool is_f8_e4m3 = false;
     bool is_f8_e5m2 = false;
+    bool is_f64 = false;
+    bool is_i64 = false;
     int64_t ne[SD_MAX_DIMS] = {1, 1, 1, 1, 1};
     int n_dims = 0;

@@ -133,6 +135,8 @@ struct TensorStorage {
     int64_t nbytes_to_read() const {
         if (is_bf16 || is_f8_e4m3 || is_f8_e5m2) {
             return nbytes() / 2;
+        } else if (is_f64 || is_i64) {
+            return nbytes() * 2;
         } else {
             return nbytes();
         }
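
A worked example of the new branch: a 4,096-element tensor flagged is_i64 resolves to GGML_TYPE_I32, so nbytes() is 16,384 bytes, and nbytes_to_read() correctly reports the 32,768 bytes of raw i64 data in the file. An is_f8_e4m3 tensor of the same element count goes the other way: it resolves to GGML_TYPE_F16, so nbytes() is 8,192 bytes while only 4,096 bytes are read from disk.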
@@ -183,6 +187,10 @@ struct TensorStorage {
             type_name = "f8_e4m3";
         } else if (is_f8_e5m2) {
             type_name = "f8_e5m2";
+        } else if (is_f64) {
+            type_name = "f64";
+        } else if (is_i64) {
+            type_name = "i64";
         }
         ss << name << " | " << type_name << " | ";
         ss << n_dims << " [";
