mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-02-10 08:20:09 +00:00
Be able to read uint32_t and bool arrays from GGUFs (#1252)
This commit is contained in:
@@ -22,6 +22,7 @@
|
||||
#include <array>
|
||||
#include <future>
|
||||
#include <regex>
|
||||
#include <algorithm>
|
||||
|
||||
#if defined(_WIN32)
|
||||
#define WIN32_LEAN_AND_MEAN
|
||||
@@ -551,15 +552,21 @@ bool llama_model_loader::get_arr(const std::string & key, std::vector<T> & resul
|
||||
|
||||
switch (arr_info.gt) {
|
||||
case GGUF_TYPE_FLOAT32: GGML_ASSERT((std::is_same<T, float>::value)); break;
|
||||
case GGUF_TYPE_INT32: GGML_ASSERT(
|
||||
(std::is_same<T, int32_t>::value) ||
|
||||
(std::is_same<T, uint32_t>::value)); break;
|
||||
case GGUF_TYPE_UINT32:
|
||||
case GGUF_TYPE_BOOL:
|
||||
case GGUF_TYPE_INT32: GGML_ASSERT((std::is_same_v<T, int32_t>) || (std::is_same_v<T, uint32_t>)); break;
|
||||
default:
|
||||
throw std::runtime_error(format("%s is not a float32, int32 array", key.c_str()));
|
||||
throw std::runtime_error(format("%s is not a float32, int32, uint32 or bool array", key.c_str()));
|
||||
}
|
||||
|
||||
result.resize(arr_info.length);
|
||||
result.assign((const T*)arr_info.data, (const T *)arr_info.data + arr_info.length);
|
||||
if (arr_info.gt == GGUF_TYPE_BOOL) {
|
||||
std::transform((const bool *)arr_info.data, (const bool *)arr_info.data + arr_info.length, result.begin(),
|
||||
[] (bool x) { return static_cast<T>(x); });
|
||||
|
||||
} else {
|
||||
result.assign((const T*)arr_info.data, (const T *)arr_info.data + arr_info.length);
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
@@ -579,19 +586,24 @@ bool llama_model_loader::get_arr(const std::string & key, std::array<T, N_MAX> &
|
||||
GGUFMeta::GKV<GGUFMeta::ArrayInfo>::get_kv(meta, kid);
|
||||
|
||||
switch (arr_info.gt) {
|
||||
case GGUF_TYPE_FLOAT32: GGML_ASSERT((std::is_same<T, float>::value)); break;
|
||||
case GGUF_TYPE_INT32: GGML_ASSERT(
|
||||
(std::is_same<T, int32_t>::value) ||
|
||||
(std::is_same<T, uint32_t>::value)); break;
|
||||
case GGUF_TYPE_FLOAT32: GGML_ASSERT((std::is_same_v<T, float>)); break;
|
||||
case GGUF_TYPE_UINT32:
|
||||
case GGUF_TYPE_BOOL:
|
||||
case GGUF_TYPE_INT32: GGML_ASSERT((std::is_same_v<T, int32_t>) || (std::is_same_v<T, uint32_t>)); break;
|
||||
default:
|
||||
throw std::runtime_error(format("%s is not a float32, int32 array", key.c_str()));
|
||||
throw std::runtime_error(format("%s is not a float32, int32 array", key.c_str()));
|
||||
}
|
||||
|
||||
if (arr_info.length > N_MAX) {
|
||||
throw std::runtime_error(format("array length %u for key %s exceeds max %u", (uint32_t) arr_info.length, key.c_str(), (uint32_t) N_MAX));
|
||||
}
|
||||
|
||||
std::copy((const T*)arr_info.data, (const T *)arr_info.data + arr_info.length, result.begin());
|
||||
if (arr_info.gt == GGUF_TYPE_BOOL) {
|
||||
std::transform((const bool *)arr_info.data, (const bool *)arr_info.data + arr_info.length, result.begin(),
|
||||
[] (bool x) { return static_cast<T>(x); });
|
||||
} else {
|
||||
std::copy((const T*)arr_info.data, (const T *)arr_info.data + arr_info.length, result.begin());
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user