mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-02-27 08:34:09 +00:00
Option to use MLA without a transposed cache
The `-mla` command line option turns into an int from a bool. mla = 0: use standard attention mla = 1: use MLA with transposed cache mla > 1: use MLA without transposed cache
This commit is contained in:
@@ -232,7 +232,7 @@ struct cmd_params {
|
||||
std::vector<int> main_gpu;
|
||||
std::vector<bool> no_kv_offload;
|
||||
std::vector<bool> flash_attn;
|
||||
std::vector<bool> mla_attn;
|
||||
std::vector<int> mla_attn;
|
||||
std::vector<std::vector<float>> tensor_split;
|
||||
std::vector<bool> use_mmap;
|
||||
std::vector<bool> embeddings;
|
||||
@@ -264,7 +264,7 @@ static const cmd_params cmd_params_defaults = {
|
||||
/* main_gpu */ {0},
|
||||
/* no_kv_offload */ {false},
|
||||
/* flash_attn */ {false},
|
||||
/* mla_attn */ {false},
|
||||
/* mla_attn */ {0},
|
||||
/* tensor_split */ {std::vector<float>(llama_max_devices(), 0.0f)},
|
||||
/* use_mmap */ {true},
|
||||
/* embeddings */ {false},
|
||||
@@ -300,7 +300,7 @@ static void print_usage(int /* argc */, char ** argv) {
|
||||
printf(" -mg, --main-gpu <i> (default: %s)\n", join(cmd_params_defaults.main_gpu, ",").c_str());
|
||||
printf(" -nkvo, --no-kv-offload <0|1> (default: %s)\n", join(cmd_params_defaults.no_kv_offload, ",").c_str());
|
||||
printf(" -fa, --flash-attn <0|1> (default: %s)\n", join(cmd_params_defaults.flash_attn, ",").c_str());
|
||||
printf(" -mla, --mla-attn <0|1> (default: %s)\n", join(cmd_params_defaults.mla_attn, ",").c_str());
|
||||
printf(" -mla, --mla-attn <0|1|2> (default: %s)\n", join(cmd_params_defaults.mla_attn, ",").c_str());
|
||||
printf(" -mmp, --mmap <0|1> (default: %s)\n", join(cmd_params_defaults.use_mmap, ",").c_str());
|
||||
printf(" --numa <distribute|isolate|numactl> (default: disabled)\n");
|
||||
printf(" -embd, --embeddings <0|1> (default: %s)\n", join(cmd_params_defaults.embeddings, ",").c_str());
|
||||
@@ -576,7 +576,7 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
|
||||
invalid_param = true;
|
||||
break;
|
||||
}
|
||||
auto p = string_split<bool>(argv[i], split_delim);
|
||||
auto p = string_split<int>(argv[i], split_delim);
|
||||
params.mla_attn.insert(params.mla_attn.end(), p.begin(), p.end());
|
||||
} else if (arg == "-mmp" || arg == "--mmap") {
|
||||
if (++i >= argc) {
|
||||
@@ -726,7 +726,7 @@ struct cmd_params_instance {
|
||||
int main_gpu;
|
||||
bool no_kv_offload;
|
||||
bool flash_attn;
|
||||
bool mla_attn;
|
||||
int mla_attn;
|
||||
std::vector<float> tensor_split;
|
||||
bool use_mmap;
|
||||
bool embeddings;
|
||||
@@ -955,7 +955,7 @@ struct test {
|
||||
int main_gpu;
|
||||
bool no_kv_offload;
|
||||
bool flash_attn;
|
||||
bool mla_attn;
|
||||
int mla_attn;
|
||||
std::vector<float> tensor_split;
|
||||
bool use_mmap;
|
||||
bool embeddings;
|
||||
@@ -1097,13 +1097,13 @@ struct test {
|
||||
field == "n_threads" ||
|
||||
field == "model_size" || field == "model_n_params" ||
|
||||
field == "n_gpu_layers" || field == "main_gpu" ||
|
||||
field == "n_prompt" || field == "n_gen" ||
|
||||
field == "n_prompt" || field == "n_gen" || field == "mla_attn" ||
|
||||
field == "avg_ns" || field == "stddev_ns") {
|
||||
return INT;
|
||||
}
|
||||
if (field == "cuda" || field == "vulkan" || field == "kompute" || field == "metal" ||
|
||||
field == "gpu_blas" || field == "blas" || field == "sycl" ||field == "f16_kv" || field == "no_kv_offload" ||
|
||||
field == "flash_attn" || field == "mla_attn" || field == "use_mmap" || field == "embeddings" || field == "repack" ||
|
||||
field == "flash_attn" || field == "use_mmap" || field == "embeddings" || field == "repack" ||
|
||||
field == "fused_moe") {
|
||||
return BOOL;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user