Add MTP decoding support for GLM-4.x MoE (#1270)

* wip: port MTP architecture

Ports the Multi-Token Prediction (MTP) architecture to the older `llama.cpp` codebase used by `ikllama`.

Changes include:
- Updating `llama_batch` to support `mtp_params`.
- Modifying `llama_decode_internal` (and `encode`) to handle MTP operations (Warmup, Update, Draft).
- Adding public APIs for MTP state management (`llama_set_draft_input_hidden_state`).
- Adapting the embedding extraction logic to skip MTP update passes.

* Refactors `server_slot` to support generic speculative decoding (MTP or Draft Model).

* core: enable hybrid outputs (logits + embeddings) for MTP support

* fix(mtp): correct KV-cache slot finding for updates

* fix(mtp): persist hidden states to prevent context corruption during drafting

* refactor(mtp): clean unused code

* fix(mtp): update server to use the new function names

* fix(mtp): fix graph and save hidden state

* mtp: refactor integration, context params and kv cache search

* mtp: fix hidden state extraction and speculative acceptance flow

* server: fix MTP warmup for long prompts and reset token buffer

* llama: refactor MTP operation state to context parameters

* server: fix n_past calculation in MTP acceptance

* llama: fix mtp enable flags

* speculative: refactor MTP to use common_speculative interface

* context: remove unused signatures

* clip: fix deprecated enum-enum conversion warning

* common: fix format string crash in help message

* context: fix mtp activation logic
Authored by Samuel Oliveira Alves on 2026-02-22 14:14:39 -03:00, committed by GitHub
parent cbf7fc7e2f · commit 09a88c9ae5
16 changed files with 820 additions and 206 deletions


@@ -279,6 +279,12 @@ extern "C" {
         LLAMA_SPLIT_MODE_GRAPH = 3, // splits computations across GPUs
     };
 
+    enum llama_mtp_op_type {
+        MTP_OP_NONE            = 0,
+        MTP_OP_WARMUP          = 1,
+        MTP_OP_UPDATE_ACCEPTED = 2,
+        MTP_OP_DRAFT_GEN       = 3,
+    };
 
     typedef struct llama_token_data {
         llama_token id; // token id
@@ -394,6 +400,7 @@ extern "C" {
         bool validate_quants;    // if true, check for NaNs while loading the model
         bool merge_qkv;          // if true, merge separate Q, K, V tensors into a single, contiguous tensor
         bool merge_up_gate_exps; // if true, merge ffn_up_exps and ffn_gate_exps tensors into a single, contiguous tensor
+        bool mtp;                // if true, load MTP layers if present
     };
 
     // NOTE: changing the default values of parameters marked as [EXPERIMENTAL] may cause crashes or incorrect results in certain configurations
@@ -449,6 +456,8 @@ extern "C" {
         bool split_mode_graph_scheduling; // if true, force split mode graph scheduling
         //bool split_mode_f16; // if true, cast intermediate results to f16 before copying to other GPUs
         bool scheduler_async; // if true, with split mode "graph" graph evaluation will be done using multiple threads
+        bool mtp; // Activate MTP if supported
+        enum llama_mtp_op_type mtp_op_type;
 
         // Abort callback
         // if it returns true, execution of llama_decode() will be aborted
@@ -1463,6 +1472,17 @@ LLAMA_API struct llama_grammar* llama_sampler_init_grammar_lazy_patterns(
     LLAMA_API void llama_dump_timing_info_yaml(FILE * stream, const struct llama_context * ctx);
 
+    //
+    // MTP
+    //
+
+    LLAMA_API int32_t llama_model_n_nextn_layer(const struct llama_model * model);
+
+    // Set which, if any, MTP operation the context will use
+    LLAMA_API void llama_set_mtp_op_type(struct llama_context * ctx, enum llama_mtp_op_type mtp_op_type);
+
+    LLAMA_API void llama_set_draft_input_hidden_state(struct llama_context * ctx, const float * hidden_state);
+
 #ifdef __cplusplus
 }
 #endif