mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-02-25 07:34:10 +00:00
Give the user the option to override where model weights are stored
This commit is contained in:
@@ -265,6 +265,9 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
|
||||
params.kv_overrides.emplace_back();
|
||||
params.kv_overrides.back().key[0] = 0;
|
||||
}
|
||||
if (!params.tensor_buft_overrides.empty()) {
|
||||
params.tensor_buft_overrides.push_back({nullptr, nullptr});
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
@@ -287,6 +290,40 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
|
||||
return true;
|
||||
}
|
||||
|
||||
namespace {
|
||||
bool parse_buft_overrides(const std::string& value, std::vector<llama_model_tensor_buft_override>& overrides) {
|
||||
/* static */ std::map<std::string, ggml_backend_buffer_type_t> buft_list;
|
||||
if (buft_list.empty()) {
|
||||
// enumerate all the devices and add their buffer types to the list
|
||||
for (size_t i = 0; i < ggml_backend_reg_get_count(); ++i) {
|
||||
//auto * dev = ggml_backend_reg_get_name(i);
|
||||
auto * buft = ggml_backend_reg_get_default_buffer_type(i);
|
||||
if (buft) {
|
||||
buft_list[ggml_backend_buft_name(buft)] = buft;
|
||||
}
|
||||
}
|
||||
}
|
||||
for (const auto & override : string_split<std::string>(value, ',')) {
|
||||
std::string::size_type pos = override.find('=');
|
||||
if (pos == std::string::npos) {
|
||||
fprintf(stderr, "Invalid buft override argument %s\n", value.c_str());
|
||||
return false;
|
||||
}
|
||||
std::string tensor_name = override.substr(0, pos);
|
||||
std::string buffer_type = override.substr(pos + 1);
|
||||
if (buft_list.find(buffer_type) == buft_list.end()) {
|
||||
fprintf(stderr, "Available buffer types:\n");
|
||||
for (const auto & it : buft_list) {
|
||||
fprintf(stderr, " %s\n", ggml_backend_buft_name(it.second));
|
||||
}
|
||||
return false;
|
||||
}
|
||||
overrides.push_back({strdup(tensor_name.c_str()), buft_list.at(buffer_type)});
|
||||
}
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
// Advances `i` to the next command-line argument. If there is none (i >= argc after the
// increment), sets `invalid_param` to true and returns true from the ENCLOSING function.
// Expects `i`, `argc` and `invalid_param` to be in scope where the macro is expanded.
#define CHECK_ARG if (++i >= argc) { invalid_param = true; return true; }
|
||||
|
||||
bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_params & params, int & i, bool & invalid_param) {
|
||||
@@ -1120,6 +1157,14 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
|
||||
}
|
||||
return true;
|
||||
}
|
||||
if (arg == "--override-tensor" || arg == "-ot") {
|
||||
CHECK_ARG
|
||||
if (!parse_buft_overrides(std::string{argv[i]}, params.tensor_buft_overrides)) {
|
||||
fprintf(stderr, "error: Invalid tensor buffer type override: %s\n", argv[i]);
|
||||
invalid_param = true;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
if (arg == "--host") {
|
||||
CHECK_ARG
|
||||
params.hostname = argv[i];
|
||||
@@ -2238,6 +2283,12 @@ struct llama_model_params llama_model_params_from_gpt_params(const gpt_params &
|
||||
GGML_ASSERT(params.kv_overrides.back().key[0] == 0 && "KV overrides not terminated with empty key");
|
||||
mparams.kv_overrides = params.kv_overrides.data();
|
||||
}
|
||||
if (params.tensor_buft_overrides.empty()) {
|
||||
mparams.tensor_buft_overrides = NULL;
|
||||
} else {
|
||||
GGML_ASSERT(params.tensor_buft_overrides.back().pattern == nullptr && "Tensor buffer overrides not terminated with empty pattern");
|
||||
mparams.tensor_buft_overrides = params.tensor_buft_overrides.data();
|
||||
}
|
||||
|
||||
return mparams;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user