Add MTP decoding support for GLM-4.x MoE (#1270)

* wip: port MTP architecture

Ports the Multi-Token Prediction (MTP) architecture to the older `llama.cpp` codebase used by `ikllama`.

Changes include (a rough usage sketch follows the list):
- Updating `llama_batch` to support `mtp_params`.
- Modifying `llama_decode_internal` (and `encode`) to handle MTP operations (Warmup, Update, Draft).
- Adding public APIs for MTP state management (`llama_set_draft_input_hidden_state`).
- Adapting the embedding extraction logic to skip MTP update passes.
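
A rough sketch of how these pieces fit together at generation time, condensed from the server changes further below. The calls shown (`llama_set_draft_input_hidden_state`, `common_speculative_draft`, `mtp_accept_tokens`) are the ones this PR adds or wires up; the surrounding variables (`slot`, `ids`, `cached_text_tokens`) are illustrative server-side state, not a drop-in snippet:

```cpp
// Illustrative only: per-iteration MTP flow on the server side.
if (!slot.mtp_hidden_state.empty()) {
    // seed the MTP (NextN) head with the hidden state captured from the
    // last verified token of the main model's previous pass
    llama_set_draft_input_hidden_state(ctx, slot.mtp_hidden_state.data());
}

// draft tokens through the shared speculative interface (MTP-backed here)
llama_tokens draft = common_speculative_draft(slot.spec, slot.params.speculative,
                                              cached_text_tokens, slot.sampled);

// ... the main model verifies the draft in a single decode batch ...

// after sampling, report the accepted tokens so the MTP layers' KV cache
// stays in sync with the main model's
int32_t n_past_base = slot.n_past - (slot.drafted.size() + 1);
mtp_accept_tokens(ctx, ids, n_past_base, slot.id);
```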

* Refactors `server_slot` to support generic speculative decoding (MTP or Draft Model).
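
At the slot level this means the server no longer assumes a separate draft model: drafts can come either from a draft-model context or from the main model's NextN (MTP) layers, and the scheduling code only asks whether the slot can speculate at all. A condensed, illustrative view of that dispatch (the actual change is the `can_speculate()` hunk further below):

```cpp
// Illustrative excerpt of the per-request slot state after the refactor.
struct server_slot_excerpt {
    common_speculative * spec    = nullptr; // draft-model or MTP speculative context
    bool                 has_mtp = false;   // main model exposes NextN (MTP) layers

    // either backend is enough to enable speculative decoding for this slot
    bool can_speculate() const { return spec != nullptr || has_mtp; }
};
```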

* core: enable hybrid outputs (logits + embeddings) for MTP support
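
Since the MTP head consumes the main model's last hidden state, a normal decode step now has to return embeddings in addition to logits. A minimal consumer-side sketch of what that hybrid output means, using the public calls touched by this PR plus the standard logits accessor (illustrative, not the core change itself):

```cpp
// Ask the context for hidden states alongside logits, then read both for the
// last token after a decode call.
llama_set_embeddings(ctx, true);  // done once when MTP is enabled

// ... llama_decode(ctx, batch) ...

const float * logits = llama_get_logits_ith(ctx, -1);      // sampling input, as before
const float * hidden = llama_get_embeddings_ith(ctx, -1);  // last hidden state for the MTP head
if (hidden) {
    // new API from this PR: stage the hidden state as the draft input
    llama_set_draft_input_hidden_state(ctx, hidden);
}
```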

* fix(mtp): correct KV-cache slot finding for updates

* fix(mtp): persist hidden states to prevent context corruption during drafting

* refactor(mtp): clean unused code

* fix(mtp): update server to new function names

* fix(mtp): fix graph and save hidden state

* mtp: refactor integration, context params and kv cache search

* mtp: fix hidden state extraction and speculative acceptance flow

* server: fix MTP warmup for long prompts and reset token buffer

* llama: refactor MTP operation state to context parameters

* server: fix n_past calculation in MTP acceptance

* llama: fix mtp enable flags

* speculative: refactor MTP to use common_speculative interface

* context: remove unused signatures

* clip: fix deprecated enum-enum conversion warning

* common: fix format string crash in help message

* context: fix mtp activation logic

Author:  Samuel Oliveira Alves
Date:    2026-02-22 14:14:39 -03:00 (committed by GitHub)
Parent:  cbf7fc7e2f
Commit:  09a88c9ae5
16 changed files with 820 additions and 206 deletions


@@ -152,12 +152,17 @@ bool server_context::load_model(const gpt_params& params_) {
            LOG_ERROR("failed to load draft model", { {"model", params_base.speculative.model} });
            return false;
        }

        cparams_dft = common_context_params_to_llama(params_dft);

        params_base.speculative.model_dft = model_dft;
        params_base.speculative.cparams_dft = cparams_dft;
    }
    else if (params_base.has_mtp && llama_model_n_nextn_layer(model) == 0) {
        LOG_WARNING("WARNING: -mtp flag provided, but model has 0 NextN layers. MTP will be disabled.\n", {});
        params_base.has_mtp = false;
    }

    return true;
}
@@ -209,12 +214,35 @@ void server_context::init() {
        slot.sparams = params_base.sparams;

        if (params_base.has_mtp) {
            if (llama_model_n_nextn_layer(model) > 0) {
                SRV_INF("%s\n", "MTP detected, configuring for speculative decoding...");
                params_base.speculative.type = COMMON_SPECULATIVE_TYPE_MTP;
                slot.has_mtp = true;
                slot.params.speculative.type = COMMON_SPECULATIVE_TYPE_MTP;
                slot.params.speculative.n_min = 0;
                slot.batch_spec = llama_batch_init(slot.params.speculative.n_max + 1, 0, 1);
                SLT_DBG(slot, "batch_spec contains %d tokens\n", slot.batch_spec.n_tokens);
                SRV_INF("%s\n", "MTP needs embeddings on decode, enabling");
                llama_set_embeddings(ctx, true);
            }
            else {
                SRV_WRN("%s\n", "MTP enabled via flag, but model has 0 NextN layers. Disabling speculative.");
                params_base.speculative.type = COMMON_SPECULATIVE_TYPE_NONE;
                slot.has_mtp = false;
            }
        }

        const bool can_spec = common_speculative_is_compat(ctx);
        if (!can_spec) {
            SRV_WRN("%s", "speculative decoding not supported by this context\n");
        }

        // try speculative decoding
        if (can_spec) {
            slot.spec = common_speculative_init(params_base.speculative, slot.ctx);
            if (slot.spec) {
                if (mctx) {
@@ -223,9 +251,14 @@
                }
                SLT_INF(slot, "%s", "speculative decoding context initialized\n");
            } else {
                if (slot.has_mtp) {
                    SRV_ERR("%s", "failed to initialize MTP speculative context\n");
                } else {
                    SLT_INF(slot, "%s", "speculative decoding context not initialized\n");
                }
            }
        }

        slot.reset();
        slots.push_back(std::move(slot));
@@ -380,7 +413,7 @@ void server_slot::add_token_string(const completion_token_output& token) {
}

bool server_slot::can_speculate() const {
    return (!!spec || has_mtp);
}

int server_slot::get_n_draft_max() const {
@@ -2533,6 +2566,15 @@ void server_context::add_sampled_tokens() {
        const auto & params_spec = slot.params.speculative;

        if (slot.has_mtp) {
            if (!slot.mtp_hidden_state.empty()) {
                llama_set_draft_input_hidden_state(ctx, slot.mtp_hidden_state.data());
            } else {
                LOG_ERROR("MTP hidden state is empty during speculation", {});
                llama_set_draft_input_hidden_state(ctx, llama_get_embeddings_ith(ctx, -1));
            }
        }

        llama_tokens draft = common_speculative_draft(slot.spec, params_spec, cached_text_tokens, slot.sampled);

        if (draft.size() > (size_t)n_draft_max) {
@@ -2540,13 +2582,6 @@
            draft.resize(n_draft_max);
        }

        // add the sampled token to the batch
        slot.i_batch_dft.push_back(batch.n_tokens);
        common_batch_add(batch, slot.sampled, slot.cache_tokens.pos_next(), { slot.id }, true);
@@ -2759,6 +2794,8 @@ void server_context::batch_pending_prompt(const int32_t n_ubatch, const int32_t
            }

            slot.n_past = 0;
            slot.n_buffer = 0;
            slot.token_buffer.clear();
            slot.n_prompt_tokens = prompt_tokens.size();

            LOG_VERBOSE("prompt tokenized", {
@@ -3092,6 +3129,25 @@ void server_context::speculative_decoding_accept() {
        // the accepted tokens from the speculation
        const auto ids = common_sampler_sample_and_accept_n(slot.ctx_sampling, ctx, slot.i_batch_dft, slot.drafted);

        if (slot.has_mtp) {
            const int n_embd = llama_model_n_embd(llama_get_model(ctx));
            if (!ids.empty()) {
                const float* emb = llama_get_embeddings_ith(ctx, ids.size() - 1);
                if (emb) {
                    slot.mtp_hidden_state.resize(n_embd);
                    memcpy(slot.mtp_hidden_state.data(), emb, n_embd * sizeof(float));
                }
            }
            else {
                llama_set_draft_input_hidden_state(ctx, llama_get_embeddings_ith(ctx, 0));
            }

            llama_set_draft_input_hidden_state(ctx, slot.mtp_hidden_state.data());

            int32_t n_past_base = slot.n_past - (slot.drafted.size() + 1);
            mtp_accept_tokens(ctx, ids, n_past_base, slot.id);
        }

        slot.i_batch_dft.clear();
        slot.drafted.clear();
@@ -3362,6 +3418,17 @@ void server_context::process_batch_tokens(int32_t & n_batch) {
            common_sampler_accept(slot.ctx_sampling, ctx, id, true);

            if (params_base.has_mtp && slot.n_decoded == 0) {
                if (batch_view.n_seq_id[0] > 0 && batch_view.seq_id[0][0] == slot.id) {
                    mtp_update_kv_cache(ctx, batch_view, true);

                    const float* emb = llama_get_embeddings_ith(ctx, -1);
                    if (emb) {
                        const int n_embd = llama_model_n_embd(llama_get_model(ctx));
                        slot.mtp_hidden_state.resize(n_embd);
                        memcpy(slot.mtp_hidden_state.data(), emb, n_embd * sizeof(float));
                    }
                }
            }

            slot.n_decoded += 1;

            const int64_t t_current = ggml_time_us();
@@ -3394,7 +3461,15 @@ void server_context::process_batch_tokens(int32_t & n_batch) {
            slot.i_batch = -1;
        }

        if (params_base.has_mtp) {
            for (auto& slot : slots) {
                if (slot.n_past < slot.n_prompt_tokens) {
                    if (batch_view.n_seq_id[0] > 0 && batch_view.seq_id[0][0] == slot.id) {
                        mtp_update_kv_cache(ctx, batch_view, true);
                    }
                }
            }
        }

        // speculative decoding - main model sample and accept
        speculative_decoding_accept();
    }