mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-02-23 22:54:10 +00:00
Add MTP decoding support for GLM-4.x MoE (#1270)
* wip: port MTP architecture Ports the Multi-Token Prediction (MTP) architecture to the older `llama.cpp` codebase used by `ikllama`. Changes include: - Updating `llama_batch` to support `mtp_params`. - Modifying `llama_decode_internal` (and `encode`) to handle MTP operations (Warmup, Update, Draft). - Adding public APIs for MTP state management (`llama_set_draft_input_hidden_state`). - Adapting the embedding extraction logic to skip MTP update passes. * Refactors `server_slot` to support generic speculative decoding (MTP or Draft Model). * core: enable hybrid outputs (logits + embeddings) for MTP support * fix(mtp): correct KV-cache slot finding for updates * fix(mtp): persist hidden states to prevent context corruption during drafting * refactor(mtp): clean unused code * fix(mtp): update server to new functions name * fix(mtp): fix graph and save hidden state * mtp: refactor integration, context params and kv cache search * mtp: fix hidden state extraction and speculative acceptance flow * server: fix MTP warmup for long prompts and reset token buffer * llama: refactor MTP operation state to context parameters * server: fix n_past calculation in MTP acceptance * llama: fix mtp enable flags * speculative: refactor MTP to use common_speculative interface * context: remove unused signatures * clip: fix deprecated enum-enum conversion warning * common: fix format string crash in help message * context: fix mtp activation logic
This commit is contained in:
committed by
GitHub
parent
cbf7fc7e2f
commit
09a88c9ae5
@@ -1463,6 +1463,14 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
|
||||
params.cuda_params = argv[i];
|
||||
return true;
|
||||
}
|
||||
if (arg == "-mtp" || arg == "--multi-token-prediction") {
|
||||
params.has_mtp = true;
|
||||
return true;
|
||||
}
|
||||
if (arg == "-no-mtp" || arg == "--no-multi-token-prediction") {
|
||||
params.has_mtp = false;
|
||||
return true;
|
||||
}
|
||||
if (arg == "-draft" || arg == "--draft-params") {
|
||||
CHECK_ARG
|
||||
params.speculative.params = argv[i];
|
||||
@@ -2475,6 +2483,8 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param
|
||||
options.push_back({ "*", "-hfr, --hf-repo REPO", "Hugging Face model repository (default: unused)" });
|
||||
options.push_back({ "*", "-hff, --hf-file FILE", "Hugging Face model file (default: unused)" });
|
||||
options.push_back({ "*", "-hft, --hf-token TOKEN", "Hugging Face access token (default: value from HF_TOKEN environment variable)" });
|
||||
options.push_back({ "*", "-mtp, --multi-token-prediction", "whether to use multi-token-prediction (if supported) (default: %s)", params.has_mtp ? "true" : "false" });
|
||||
options.push_back({ "*", "-no-mtp, --no-multi-token-prediction", "whether to use multi-token-prediction (if supported) (default: %s)", !params.has_mtp ? "true" : "false" });
|
||||
options.push_back({ "*", "--draft-max, --draft, --draft-n N",
|
||||
"number of tokens to draft for speculative decoding (default: %d)", params.speculative.n_max });
|
||||
options.push_back({ "*", "--draft-min, --draft-n-min N", "minimum number of draft tokens to use for speculative decoding" });
|
||||
@@ -3207,6 +3217,7 @@ struct llama_model_params common_model_params_to_llama(const gpt_params & params
|
||||
mparams.validate_quants = params.validate_quants;
|
||||
mparams.merge_qkv = params.merge_qkv;
|
||||
mparams.merge_up_gate_exps = params.merge_up_gate_exps;
|
||||
mparams.mtp = params.has_mtp;
|
||||
if (params.kv_overrides.empty()) {
|
||||
mparams.kv_overrides = NULL;
|
||||
} else {
|
||||
@@ -3329,6 +3340,8 @@ struct llama_context_params common_context_params_to_llama(const gpt_params & pa
|
||||
cparams.thresh_experts = params.thresh_experts;
|
||||
cparams.only_active_experts = params.only_active_exps;
|
||||
cparams.max_extra_alloc = params.max_extra_alloc_MiB;
|
||||
cparams.mtp = params.has_mtp;
|
||||
cparams.mtp_op_type = MTP_OP_NONE;
|
||||
|
||||
cparams.type_k = kv_cache_type_from_str(params.cache_type_k);
|
||||
cparams.type_v = kv_cache_type_from_str(params.cache_type_v);
|
||||
|
||||
Reference in New Issue
Block a user