Add MTP decoding support for GLM-4.x MoE (#1270)

* wip: port MTP architecture

Ports the Multi-Token Prediction (MTP) architecture to the older `llama.cpp` codebase used by `ikllama`.

Changes include:
- Updating `llama_batch` to support `mtp_params`.
- Modifying `llama_decode_internal` (and `encode`) to handle MTP operations (Warmup, Update, Draft).
- Adding public APIs for MTP state management (`llama_set_draft_input_hidden_state`); a usage sketch follows this list.
- Adapting the embedding extraction logic to skip MTP update passes.
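
A minimal sketch of how these pieces are intended to fit together (illustrative only: `MTP_OP_NONE` and the warmup/update/draft operations are named in this PR, but the remaining enum values, the signature of `llama_set_draft_input_hidden_state`, and the helper below are assumptions):

```cpp
// Illustrative sketch only -- not code from this PR.
// MTP_OP_NONE and the warmup/update/draft operations are named in the commit;
// the other enum values, the llama_set_draft_input_hidden_state signature,
// and mtp_draft_step are assumptions made for this example.
struct llama_context;                 // opaque handle, as in llama.h

enum mtp_op_type {
    MTP_OP_NONE,                      // plain decoding, MTP head untouched (the default in the context params)
    MTP_OP_WARMUP,                    // (assumed) populate MTP state from a freshly processed prompt
    MTP_OP_UPDATE,                    // (assumed) refresh MTP state after tokens are accepted
    MTP_OP_DRAFT,                     // (assumed) produce draft tokens from the MTP head
};

// assumed signature for the new public API mentioned above
void llama_set_draft_input_hidden_state(llama_context * ctx, const float * hidden_state);

// (assumed helper) before a draft pass, hand the main model's hidden state for
// the last accepted token to the MTP head; the subsequent decode call then runs
// with the MTP operation set to MTP_OP_DRAFT (handled inside llama_decode_internal).
static void mtp_draft_step(llama_context * ctx, const float * last_hidden) {
    llama_set_draft_input_hidden_state(ctx, last_hidden);
}
```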

* Refactors `server_slot` to support generic speculative decoding (MTP or Draft Model).

* core: enable hybrid outputs (logits + embeddings) for MTP support

* fix(mtp): correct KV-cache slot finding for updates

* fix(mtp): persist hidden states to prevent context corruption during drafting

* refactor(mtp): clean unused code

* fix(mtp): update server to new function names

* fix(mtp): fix graph and save hidden state

* mtp: refactor integration, context params and kv cache search

* mtp: fix hidden state extraction and speculative acceptance flow

* server: fix MTP warmup for long prompts and reset token buffer

* llama: refactor MTP operation state to context parameters

* server: fix n_past calculation in MTP acceptance (see the acceptance sketch after this list)

* llama: fix mtp enable flags

* speculative: refactor MTP to use common_speculative interface

* context: remove unused signatures

* clip: fix deprecated enum-enum conversion warning

* common: fix format string crash in help message

* context: fix mtp activation logic
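
The acceptance step behind the "speculative acceptance flow" and "n_past calculation" fixes above follows the usual greedy scheme. The sketch below is a generic illustration of that technique, not code from this PR; the names `accept_draft` and `accept_result` are hypothetical.

```cpp
// Generic illustration of greedy draft acceptance -- not code from this PR.
#include <cstdint>
#include <vector>

using llama_token = int32_t;

struct accept_result {
    size_t n_accepted;  // drafted tokens that matched the target model's own sampling
    int    n_past;      // KV-cache position after the verification batch
};

// The target model verifies all drafted tokens in a single batch and samples one
// token per position.  The longest prefix on which draft and target agree is kept,
// and n_past advances by the accepted tokens plus the one token the target itself
// emits at the first disagreement (or after the last draft token if all matched).
static accept_result accept_draft(const std::vector<llama_token> & drafted,
                                  const std::vector<llama_token> & target_sampled,
                                  int n_past_before) {
    size_t n = 0;
    while (n < drafted.size() && n < target_sampled.size() && drafted[n] == target_sampled[n]) {
        n++;
    }
    return { n, n_past_before + (int) n + 1 };
}
```

Tokens beyond the accepted prefix must be discarded from the KV cache before the next decoding step.
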
Author:       Samuel Oliveira Alves
Date:         2026-02-22 14:14:39 -03:00
Committed by: GitHub
Parent:       cbf7fc7e2f
Commit:       09a88c9ae5
16 changed files with 820 additions and 206 deletions


@@ -1463,6 +1463,14 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
         params.cuda_params = argv[i];
         return true;
     }
+    if (arg == "-mtp" || arg == "--multi-token-prediction") {
+        params.has_mtp = true;
+        return true;
+    }
+    if (arg == "-no-mtp" || arg == "--no-multi-token-prediction") {
+        params.has_mtp = false;
+        return true;
+    }
     if (arg == "-draft" || arg == "--draft-params") {
         CHECK_ARG
         params.speculative.params = argv[i];
@@ -2475,6 +2483,8 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param
options.push_back({ "*", "-hfr, --hf-repo REPO", "Hugging Face model repository (default: unused)" });
options.push_back({ "*", "-hff, --hf-file FILE", "Hugging Face model file (default: unused)" });
options.push_back({ "*", "-hft, --hf-token TOKEN", "Hugging Face access token (default: value from HF_TOKEN environment variable)" });
options.push_back({ "*", "-mtp, --multi-token-prediction", "whether to use multi-token-prediction (if supported) (default: %s)", params.has_mtp ? "true" : "false" });
options.push_back({ "*", "-no-mtp, --no-multi-token-prediction", "whether to use multi-token-prediction (if supported) (default: %s)", !params.has_mtp ? "true" : "false" });
options.push_back({ "*", "--draft-max, --draft, --draft-n N",
"number of tokens to draft for speculative decoding (default: %d)", params.speculative.n_max });
options.push_back({ "*", "--draft-min, --draft-n-min N", "minimum number of draft tokens to use for speculative decoding" });
@@ -3207,6 +3217,7 @@ struct llama_model_params common_model_params_to_llama(const gpt_params & params
     mparams.validate_quants = params.validate_quants;
     mparams.merge_qkv = params.merge_qkv;
     mparams.merge_up_gate_exps = params.merge_up_gate_exps;
+    mparams.mtp = params.has_mtp;
     if (params.kv_overrides.empty()) {
         mparams.kv_overrides = NULL;
     } else {
@@ -3329,6 +3340,8 @@ struct llama_context_params common_context_params_to_llama(const gpt_params & pa
     cparams.thresh_experts = params.thresh_experts;
     cparams.only_active_experts = params.only_active_exps;
     cparams.max_extra_alloc = params.max_extra_alloc_MiB;
+    cparams.mtp = params.has_mtp;
+    cparams.mtp_op_type = MTP_OP_NONE;
     cparams.type_k = kv_cache_type_from_str(params.cache_type_k);
     cparams.type_v = kv_cache_type_from_str(params.cache_type_v);
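
Usage note: with these changes, MTP can be toggled from the command line with `-mtp` / `--multi-token-prediction` and `-no-mtp` / `--no-multi-token-prediction`; the setting is propagated into both the model params (`mparams.mtp`) and the context params (`cparams.mtp`), with the per-call operation type initialized to `MTP_OP_NONE`.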