Give the user the option to override where model weights are stored

This commit is contained in:
Iwan Kawrakow
2025-02-24 16:02:31 +02:00
parent 547eee81d9
commit 2572a6de3c
5 changed files with 781 additions and 621 deletions

View File

@@ -265,6 +265,9 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
params.kv_overrides.emplace_back();
params.kv_overrides.back().key[0] = 0;
}
if (!params.tensor_buft_overrides.empty()) {
params.tensor_buft_overrides.push_back({nullptr, nullptr});
}
return true;
}
@@ -287,6 +290,40 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
return true;
}
namespace {
// Parses the value of a tensor buffer type override argument and appends the
// resulting entries to `overrides`.
//
// `value` is a comma-separated list of `<tensor name>=<buffer type name>`
// entries. Valid buffer type names are the names of the default buffer types
// of all registered ggml backends.
//
// Returns true when every entry was parsed. On the first malformed entry or
// unknown buffer type a diagnostic is printed to stderr and false is returned;
// entries parsed before the failing one remain in `overrides`.
//
// The caller is responsible for appending the {nullptr, nullptr} terminator
// once all arguments have been processed.
//
// NOTE: each pattern is duplicated with strdup() and never freed here. The
// override entry only stores a raw `const char *`, so the copy has to outlive
// this call.
bool parse_buft_overrides(const std::string& value, std::vector<llama_model_tensor_buft_override>& overrides) {
    // Buffer type name -> buffer type. The map is rebuilt on every call
    // (it is intentionally not static), so no "already filled" check is needed.
    std::map<std::string, ggml_backend_buffer_type_t> buft_list;
    for (size_t i = 0; i < ggml_backend_reg_get_count(); ++i) {
        auto * buft = ggml_backend_reg_get_default_buffer_type(i);
        if (buft) {
            buft_list[ggml_backend_buft_name(buft)] = buft;
        }
    }

    for (const auto & entry : string_split<std::string>(value, ',')) {
        const std::string::size_type pos = entry.find('=');
        if (pos == std::string::npos) {
            // Report the offending entry as well as the full argument it came from.
            fprintf(stderr, "Invalid buft override '%s' in argument '%s' (expected <tensor name>=<buffer type>)\n",
                    entry.c_str(), value.c_str());
            return false;
        }
        const std::string tensor_name = entry.substr(0, pos);
        const std::string buffer_type = entry.substr(pos + 1);

        // Single lookup: the iterator is reused for the value below.
        const auto found = buft_list.find(buffer_type);
        if (found == buft_list.end()) {
            fprintf(stderr, "Unknown buffer type '%s' in override '%s'\n", buffer_type.c_str(), entry.c_str());
            fprintf(stderr, "Available buffer types:\n");
            for (const auto & it : buft_list) {
                fprintf(stderr, " %s\n", ggml_backend_buft_name(it.second));
            }
            return false;
        }
        overrides.push_back({strdup(tensor_name.c_str()), found->second});
    }
    return true;
}
}
#define CHECK_ARG if (++i >= argc) { invalid_param = true; return true; }
bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_params & params, int & i, bool & invalid_param) {
@@ -1120,6 +1157,14 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
}
return true;
}
if (arg == "--override-tensor" || arg == "-ot") {
CHECK_ARG
if (!parse_buft_overrides(std::string{argv[i]}, params.tensor_buft_overrides)) {
fprintf(stderr, "error: Invalid tensor buffer type override: %s\n", argv[i]);
invalid_param = true;
}
return true;
}
if (arg == "--host") {
CHECK_ARG
params.hostname = argv[i];
@@ -2238,6 +2283,12 @@ struct llama_model_params llama_model_params_from_gpt_params(const gpt_params &
GGML_ASSERT(params.kv_overrides.back().key[0] == 0 && "KV overrides not terminated with empty key");
mparams.kv_overrides = params.kv_overrides.data();
}
if (params.tensor_buft_overrides.empty()) {
mparams.tensor_buft_overrides = NULL;
} else {
GGML_ASSERT(params.tensor_buft_overrides.back().pattern == nullptr && "Tensor buffer overrides not terminated with empty pattern");
mparams.tensor_buft_overrides = params.tensor_buft_overrides.data();
}
return mparams;
}

View File

@@ -135,6 +135,7 @@ struct gpt_params {
std::vector<std::string> in_files; // all input files
std::vector<std::string> antiprompt; // strings upon which more user input is prompted (a.k.a. reverse prompts)
std::vector<llama_model_kv_override> kv_overrides;
std::vector<llama_model_tensor_buft_override> tensor_buft_overrides;
bool lora_init_without_apply = false; // only load lora to memory, but do not apply it to ctx (user can manually apply lora later using llama_lora_adapter_apply)
std::vector<llama_lora_adapter_info> lora_adapters; // lora adapter path with user defined scale

View File

@@ -236,6 +236,7 @@ struct cmd_params {
std::vector<std::vector<float>> tensor_split;
std::vector<bool> use_mmap;
std::vector<bool> embeddings;
std::vector<llama_model_tensor_buft_override> buft_overrides;
ggml_numa_strategy numa;
int reps;
bool verbose;
@@ -267,6 +268,7 @@ static const cmd_params cmd_params_defaults = {
/* tensor_split */ {std::vector<float>(llama_max_devices(), 0.0f)},
/* use_mmap */ {true},
/* embeddings */ {false},
/* buft_overrides */ {},
/* numa */ GGML_NUMA_STRATEGY_DISABLED,
/* reps */ 5,
/* verbose */ false,
@@ -309,6 +311,7 @@ static void print_usage(int /* argc */, char ** argv) {
printf(" -v, --verbose (default: %s)\n", cmd_params_defaults.verbose ? "1" : "0");
printf(" -w, --warmup <0|1> (default: %s)\n", cmd_params_defaults.warmup ? "1" : "0");
printf(" -rtr, --run-time-repack <0|1> (default: %s)\n", cmd_params_defaults.repack ? "1" : "0");
printf(" -ot, --override-tensor pattern (default: none)\n");
printf(" -fmoe, --fused-moe <0|1> (default: %s)\n", cmd_params_defaults.fmoe? "1" : "0");
printf("\n");
printf("Multiple values can be given for each parameter by separating them with ',' or by specifying the parameter multiple times.\n");
@@ -349,6 +352,39 @@ static ggml_type ggml_type_from_name(const std::string & s) {
return GGML_TYPE_COUNT;
}
namespace {
// Parses the value of a `-ot` / `--override-tensor` argument and appends the
// resulting entries to `overrides`.
//
// `value` is a comma-separated list of `<tensor name>=<buffer type name>`
// entries. Valid buffer type names are the names of the default buffer types
// of all registered ggml backends.
//
// Returns true when every entry was parsed. On the first malformed entry or
// unknown buffer type a diagnostic is printed to stderr and false is returned;
// entries parsed before the failing one remain in `overrides`.
//
// The caller is responsible for appending the {nullptr, nullptr} terminator
// once all arguments have been processed.
//
// NOTE: each pattern is duplicated with strdup() and never freed here. The
// override entry only stores a raw `const char *`, so the copy has to outlive
// this call.
bool parse_buft_overrides(const std::string& value, std::vector<llama_model_tensor_buft_override>& overrides) {
    // Buffer type name -> buffer type. The map is rebuilt on every call
    // (it is intentionally not static), so no "already filled" check is needed.
    std::map<std::string, ggml_backend_buffer_type_t> buft_list;
    for (size_t i = 0; i < ggml_backend_reg_get_count(); ++i) {
        auto * buft = ggml_backend_reg_get_default_buffer_type(i);
        if (buft) {
            buft_list[ggml_backend_buft_name(buft)] = buft;
        }
    }

    for (const auto & entry : string_split<std::string>(value, ',')) {
        const std::string::size_type pos = entry.find('=');
        if (pos == std::string::npos) {
            // Report the offending entry as well as the full argument it came from.
            fprintf(stderr, "Invalid buft override '%s' in argument '%s' (expected <tensor name>=<buffer type>)\n",
                    entry.c_str(), value.c_str());
            return false;
        }
        const std::string tensor_name = entry.substr(0, pos);
        const std::string buffer_type = entry.substr(pos + 1);

        // Single lookup: the iterator is reused for the value below.
        const auto found = buft_list.find(buffer_type);
        if (found == buft_list.end()) {
            fprintf(stderr, "Unknown buffer type '%s' in override '%s'\n", buffer_type.c_str(), entry.c_str());
            fprintf(stderr, "Available buffer types:\n");
            for (const auto & it : buft_list) {
                fprintf(stderr, " %s\n", ggml_backend_buft_name(it.second));
            }
            return false;
        }
        overrides.push_back({strdup(tensor_name.c_str()), found->second});
    }
    return true;
}
}
static cmd_params parse_cmd_params(int argc, char ** argv) {
cmd_params params;
@@ -616,6 +652,16 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
break;
}
params.fmoe = std::stoi(argv[i]);
} else if (arg == "-ot" || arg == "--override-tensor") {
if (++i >= argc) {
invalid_param = true;
break;
}
if (!parse_buft_overrides(std::string{argv[i]}, params.buft_overrides)) {
fprintf(stderr, "error: Invalid tensor buffer type override: %s\n", argv[i]);
invalid_param = true;
break;
}
} else {
invalid_param = true;
break;
@@ -648,6 +694,7 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
if (params.use_mmap.empty()) { params.use_mmap = cmd_params_defaults.use_mmap; }
if (params.embeddings.empty()) { params.embeddings = cmd_params_defaults.embeddings; }
if (params.n_threads.empty()) { params.n_threads = cmd_params_defaults.n_threads; }
if (!params.buft_overrides.empty()) params.buft_overrides.emplace_back(llama_model_tensor_buft_override{nullptr, nullptr});
return params;
}
@@ -685,6 +732,7 @@ struct cmd_params_instance {
bool embeddings;
bool repack = false;
bool fmoe = false;
const llama_model_tensor_buft_override* buft_overrides;
llama_model_params to_llama_mparams() const {
llama_model_params mparams = llama_model_default_params();
@@ -698,6 +746,7 @@ struct cmd_params_instance {
mparams.tensor_split = tensor_split.data();
mparams.use_mmap = use_mmap;
mparams.repack_tensors = repack;
mparams.tensor_buft_overrides = buft_overrides;
return mparams;
}
@@ -777,6 +826,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
/* .embeddings = */ embd,
/* .repack = */ params.repack,
/* .fmoe = */ params.fmoe,
/* .buft_overrides=*/ params.buft_overrides.data(),
};
instances.push_back(instance);
}
@@ -807,6 +857,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
/* .embeddings = */ embd,
/* .repack = */ params.repack,
/* .fmoe = */ params.fmoe,
/* .buft_overrides=*/ params.buft_overrides.data(),
};
instances.push_back(instance);
}
@@ -837,6 +888,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
/* .embeddings = */ embd,
/* .repack = */ params.repack,
/* .fmoe = */ params.fmoe,
/* .buft_overrides=*/ params.buft_overrides.data(),
};
instances.push_back(instance);
}
@@ -867,6 +919,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
/* .embeddings = */ embd,
/* .repack = */ params.repack,
/* .fmoe = */ params.fmoe,
/* .buft_overrides=*/ params.buft_overrides.data(),
};
instances.push_back(instance);
}

View File

@@ -305,6 +305,11 @@ extern "C" {
};
};
// One entry of a tensor buffer type override list: tensors selected by `pattern`
// are associated with the backend buffer type `buft`.
// A list of these entries is terminated by an entry whose `pattern` is a null pointer.
// NOTE(review): how `pattern` is matched against tensor names (exact, substring, regex, ...)
// is decided by the model loader, which is not visible here - confirm before relying on it.
struct llama_model_tensor_buft_override {
// tensor name pattern; the struct does not own the string, so it must stay valid
// for as long as the entry is in use. A null pointer marks the end of the list.
const char * pattern;
// buffer type paired with `pattern`
ggml_backend_buffer_type_t buft;
};
struct llama_model_params {
int32_t n_gpu_layers; // number of layers to store in VRAM
enum llama_split_mode split_mode; // how to split the model across multiple GPUs
@@ -332,6 +337,8 @@ extern "C" {
// override key-value pairs of the model meta data
const struct llama_model_kv_override * kv_overrides;
const struct llama_model_tensor_buft_override * tensor_buft_overrides;
// Keep the booleans together to avoid misalignment during copy-by-value.
bool vocab_only; // only load the vocabulary, no weights
bool use_mmap; // use mmap if possible

File diff suppressed because it is too large Load Diff