Add optional MLA (#188)

* Deepseek MLA Optimizations

Co-authored-by: Stanisław Szymczyk <sszymczy@gmail.com>

* Make MLA optional

* Remove some unnecessary copies in the MLA attention

* Deepseek MLA Optimizations V2 (#195)

* Avoid allocating MHA KV cache when MLA is turned on

* Added missing gguf-py file

* Added final optimizations

Co-authored-by: Stanisław Szymczyk <sszymczy@gmail.com>

* Make sure we do have wk_b and wv_b before enabling MLA

---------

Co-authored-by: Stanisław Szymczyk <sszymczy@gmail.com>
Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>

* Use type_k and type_v to set the types of the MLA caches

They were hard-coded at f16.
On my Ryzen-7950X with native bf16 support I get a fairly
significant PP performance boost with bf16 KV-cache:
PP-4096 = 320 t/s up from 292 t/s with fp16 KV-cache.

* Better gemm strategy when nth > nhead

It gives a ~10% PP performance boost for DeepSeek-Lite with 32 threads
(with or without MLA).
Before this commit, when nth > nhead heads were processed
sequentially with all nth threads participating in each
matrix multiplication. Now we ind the gcd of nhead and
nth and split threads into nth/gcd groups, each group
processing nhead/gcd heads.

---------

Co-authored-by: Saood Karim <saood05@gmail.com>
Co-authored-by: Stanisław Szymczyk <sszymczy@gmail.com>
Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
This commit is contained in:
Kawrakow
2025-02-09 19:48:44 +02:00
committed by GitHub
parent db7eabb111
commit 3e536b95b0
9 changed files with 380 additions and 75 deletions

View File

@@ -813,6 +813,10 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
params.flash_attn = true;
return true;
}
if (arg == "-mla" || arg == "--mla-use") {
params.mla_attn = true;
return true;
}
if (arg == "-co" || arg == "--color") {
params.use_color = true;
return true;
@@ -1452,6 +1456,7 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param
options.push_back({ "*", " --keep N", "number of tokens to keep from the initial prompt (default: %d, -1 = all)", params.n_keep });
options.push_back({ "*", " --chunks N", "max number of chunks to process (default: %d, -1 = all)", params.n_chunks });
options.push_back({ "*", "-fa, --flash-attn", "enable Flash Attention (default: %s)", params.flash_attn ? "enabled" : "disabled" });
options.push_back({ "*", "-mla, --mla-use", "enable MLA (default: %s)", params.mla_attn ? "enabled" : "disabled" });
options.push_back({ "*", "-p, --prompt PROMPT", "prompt to start generation with\n"
"in conversation mode, this will be used as system prompt\n"
"(default: '%s')", params.prompt.c_str() });
@@ -2283,6 +2288,7 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
cparams.cb_eval_user_data = params.cb_eval_user_data;
cparams.offload_kqv = !params.no_kv_offload;
cparams.flash_attn = params.flash_attn;
cparams.mla_attn = params.mla_attn;
cparams.type_k = kv_cache_type_from_str(params.cache_type_k);
cparams.type_v = kv_cache_type_from_str(params.cache_type_v);
@@ -3280,6 +3286,7 @@ void yaml_dump_non_result_info(FILE * stream, const gpt_params & params, const l
fprintf(stream, "simple_io: %s # default: false\n", params.simple_io ? "true" : "false");
fprintf(stream, "cont_batching: %s # default: false\n", params.cont_batching ? "true" : "false");
fprintf(stream, "flash_attn: %s # default: false\n", params.flash_attn ? "true" : "false");
fprintf(stream, "mla_attn: %s # default: false\n", params.mla_attn ? "true" : "false");
fprintf(stream, "temp: %f # default: 0.8\n", sparams.temp);
const std::vector<float> tensor_split_vector(params.tensor_split, params.tensor_split + llama_max_devices());

View File

@@ -174,6 +174,7 @@ struct gpt_params {
bool simple_io = false; // improves compatibility with subprocesses and limited consoles
bool cont_batching = true; // insert new sequences for decoding on-the-fly
bool flash_attn = false; // flash attention
bool mla_attn = false; // MLA
bool input_prefix_bos = false; // prefix BOS to user inputs, preceding input_prefix
bool ignore_eos = false; // ignore generated EOS tokens