Add MTP decoding support for GLM-4.x MoE (#1270)

* wip: port MTP architecture

Ports the Multi-Token Prediction (MTP) architecture to the older `llama.cpp` codebase used by `ikllama`.

Changes include:
- Updating `llama_batch` to support `mtp_params`.
- Modifying `llama_decode_internal` (and `encode`) to handle MTP operations (Warmup, Update, Draft).
- Adding public APIs for MTP state management (`llama_set_draft_input_hidden_state`).
- Adapting the embedding extraction logic to skip MTP update passes.

* Refactors `server_slot` to support generic speculative decoding (MTP or Draft Model).

* core: enable hybrid outputs (logits + embeddings) for MTP support

* fix(mtp): correct KV-cache slot finding for updates

* fix(mtp): persist hidden states to prevent context corruption during drafting

* refactor(mtp): clean unused code

* fix(mtp): update server to new function names

* fix(mtp): fix graph and save hidden state

* mtp: refactor integration, context params and kv cache search

* mtp: fix hidden state extraction and speculative acceptance flow

* server: fix MTP warmup for long prompts and reset token buffer

* llama: refactor MTP operation state to context parameters

* server: fix n_past calculation in MTP acceptance

* llama: fix mtp enable flags

* speculative: refactor MTP to use common_speculative interface

* context: remove unused signatures

* clip: fix deprecated enum-enum conversion warning

* common: fix format string crash in help message

* context: fix mtp activation logic
This commit is contained in:
Samuel Oliveira Alves
2026-02-22 14:14:39 -03:00
committed by GitHub
parent cbf7fc7e2f
commit 09a88c9ae5
16 changed files with 820 additions and 206 deletions

View File

@@ -20,6 +20,7 @@
const std::vector<enum common_speculative_type> common_speculative_types = {
COMMON_SPECULATIVE_TYPE_NONE,
COMMON_SPECULATIVE_TYPE_DRAFT,
COMMON_SPECULATIVE_TYPE_MTP,
COMMON_SPECULATIVE_TYPE_EAGLE3,
COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE,
COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K,
@@ -31,6 +32,7 @@ const std::vector<enum common_speculative_type> common_speculative_types = {
const std::map<std::string, enum common_speculative_type> common_speculative_type_from_name_map = {
{"none", COMMON_SPECULATIVE_TYPE_NONE},
{"draft", COMMON_SPECULATIVE_TYPE_DRAFT},
{"mtp", COMMON_SPECULATIVE_TYPE_MTP},
{"eagle3", COMMON_SPECULATIVE_TYPE_EAGLE3},
{"ngram_simple", COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE},
{"ngram_map_k", COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K},
@@ -144,6 +146,58 @@ struct common_speculative_state {
virtual void accept(uint16_t n_accepted) = 0;
};
// Speculative-decoding state for MTP (Multi-Token Prediction): drafts tokens
// with the target model's own MTP head instead of a separate draft model, so
// only the target context is needed.
struct common_speculative_state_mtp : public common_speculative_state {
llama_context * ctx_tgt; // target context hosting the MTP head (not owned)
common_sampler * smpl;   // sampler used to rank draft candidates (owned)
common_speculative_state_mtp(
enum common_speculative_type type,
llama_context * ctx_tgt)
: common_speculative_state(type)
, ctx_tgt(ctx_tgt)
{
// Drafting uses a bare DIST sampler chain; the confidence check in
// mtp_speculative_gen_draft reads the top candidate's probability.
struct common_params_sampling params;
params.samplers_sequence = {
llama_sampler_type::DIST,
};
smpl = common_sampler_init(llama_get_model(ctx_tgt), params);
}
~common_speculative_state_mtp() override {
common_sampler_free(smpl);
}
// No per-prompt priming needed: the MTP head draws on the target context's
// existing state rather than a separately fed draft context.
void begin(const llama_tokens & prompt) override {
GGML_UNUSED(prompt);
}
// Produces up to params.n_max draft tokens continuing prompt_tgt + id_last.
// NOTE(review): seq_id is fixed to 0 — presumably MTP drafting is
// single-sequence only; confirm before enabling with parallel sequences.
void draft(
const common_params_speculative & params,
const llama_tokens & prompt_tgt,
llama_token id_last,
llama_tokens & result) override {
// n_past is simply the target prompt length: id_last is decoded by the
// MTP step itself at that position.
int32_t n_past = (int32_t)prompt_tgt.size();
llama_seq_id seq_id = 0;
result = mtp_speculative_gen_draft(
smpl,
ctx_tgt,
params.n_max,
params.p_min,
id_last,
n_past,
seq_id
);
}
// Acceptance bookkeeping is handled elsewhere (see mtp_accept_tokens).
void accept(uint16_t n_accepted) override {
GGML_UNUSED(n_accepted);
}
};
struct common_speculative_state_draft : public common_speculative_state {
llama_context * ctx_tgt; // only used for retokenizing from ctx_dft
llama_context * ctx_dft;
@@ -760,6 +814,7 @@ std::string common_speculative_type_to_str(enum common_speculative_type type) {
switch (type) {
case COMMON_SPECULATIVE_TYPE_NONE: return "none";
case COMMON_SPECULATIVE_TYPE_DRAFT: return "draft";
case COMMON_SPECULATIVE_TYPE_MTP: return "mtp";
case COMMON_SPECULATIVE_TYPE_EAGLE3: return "eagle3";
case COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE: return "ngram_simple";
case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K: return "ngram_map_k";
@@ -828,6 +883,7 @@ common_speculative * common_speculative_init(
{
bool has_draft = !params.mparams_dft.path.empty();
bool has_draft_eagle3 = false; // TODO PR-18039: if params.speculative.eagle3
bool has_mtp = (params.type == COMMON_SPECULATIVE_TYPE_MTP);
bool has_ngram_cache = (params.type == COMMON_SPECULATIVE_TYPE_NGRAM_CACHE);
bool has_ngram_simple = (params.type == COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE);
@@ -867,6 +923,9 @@ common_speculative * common_speculative_init(
if (has_ngram_cache) {
configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_NGRAM_CACHE, params));
}
if (has_mtp) {
configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_MTP, params));
}
if (has_draft) {
configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_DRAFT, params));
}
@@ -890,6 +949,12 @@ common_speculative * common_speculative_init(
));
break;
}
case COMMON_SPECULATIVE_TYPE_MTP: {
impls.push_back(std::make_unique<common_speculative_state_mtp>(config.type,
/* .ctx_tgt = */ ctx_tgt
));
break;
}
case COMMON_SPECULATIVE_TYPE_EAGLE3: {
impls.push_back(std::make_unique<common_speculative_state_eagle3>(config.type));
break;
@@ -1047,3 +1112,112 @@ void common_speculative_print_stats(const common_speculative * spec) {
str_perf.c_str());
}
}
// ----------------------------------------------------------------------------
// MTP
// ----------------------------------------------------------------------------
// Generates up to n_draft speculative tokens using the target model's MTP
// (Multi-Token Prediction) head, autoregressively, one token per decode.
//
// smpl    - sampler used to rank candidates (reset on entry); may be null
// ctx     - target llama context; decodes are routed via MTP_OP_DRAFT_GEN
// n_draft - maximum number of draft tokens to produce
// p_min   - minimum top-candidate probability; drafting stops below it
// id_last - last accepted token, fed as the first MTP input
// n_past  - number of tokens already in the context for this sequence
// seq_id  - sequence the draft tokens belong to
//
// Returns the drafted tokens (possibly empty).
std::vector<llama_token> mtp_speculative_gen_draft(
struct common_sampler * smpl,
struct llama_context * ctx,
int n_draft,
float p_min,
llama_token id_last,
int32_t n_past,
llama_seq_id seq_id) {
llama_tokens drafts;
drafts.reserve(n_draft);
if (!smpl) return drafts;
common_sampler_reset(smpl);
// Single-token batch reused for each autoregressive MTP step.
llama_batch mtp_batch = llama_batch_init(1, 0, 1);
// Route the decodes below through the MTP draft-generation path.
llama_set_mtp_op_type(ctx, MTP_OP_DRAFT_GEN);
llama_token current_input_id = id_last;
int32_t current_n_past = n_past;
for (int i = 0; i < n_draft; ++i) {
mtp_batch.n_tokens = 0; // reuse the batch: clear, then add the next input
common_batch_add(mtp_batch, current_input_id, current_n_past, {seq_id}, true);
if (llama_decode(ctx, mtp_batch) != 0) {
break;
}
// Sample to populate the candidates list; the draft token itself is
// always the top candidate taken below, not the sampled id.
common_sampler_sample(smpl, ctx, 0, true);
const auto * cur_p = common_sampler_get_candidates(smpl, true);
if (!cur_p || cur_p->size == 0) {
break;
}
const llama_token id_next = cur_p->data[0].id;
const float prob = cur_p->data[0].p;
common_sampler_accept(smpl, nullptr, id_next, true);
// Stop as soon as the MTP head is no longer confident enough.
if (prob < p_min) {
break;
}
drafts.push_back(id_next);
current_input_id = id_next;
current_n_past++;
}
llama_batch_free(mtp_batch);
llama_set_mtp_op_type(ctx, MTP_OP_NONE);
// Purge the metadata for the draft tokens.
// This prevents cache state corruption where two cells map to the same logical position.
// NOTE(review): a decode that succeeds but whose token is then rejected
// (prob < p_min, or empty candidates) wrote a cell at current_n_past, which
// lies outside [n_past, current_n_past) — and no purge runs at all when
// drafts is empty. Confirm the MTP decode path does not leave a stray cell
// in those cases.
if (!drafts.empty()) {
llama_kv_cache_seq_rm(ctx, seq_id, n_past, current_n_past);
}
return drafts;
}
// Runs a forward pass through the MTP head to refresh its KV cache for the
// tokens in `batch`, either for the initial prompt warmup (MTP_OP_WARMUP) or
// for tokens just accepted during generation (MTP_OP_UPDATE_ACCEPTED).
//
// ctx              - target llama context hosting the MTP head
// batch            - tokens/positions to push through the MTP head
// is_prompt_warmup - selects warmup vs. accepted-token update mode
//
// The pass needs logits enabled for every token, but `batch` is const and
// copying a llama_batch copies only its pointers — writing through
// mtp_batch.logits would mutate the caller's array. The original flags are
// therefore saved before the decode and restored afterwards.
void mtp_update_kv_cache(struct llama_context * ctx, const llama_batch& batch, bool is_prompt_warmup) {
    if (batch.n_tokens == 0) {
        return;
    }

    LOG_DBG("[MTP-UPDATE|%s] Updating %d tokens...\n", is_prompt_warmup ? "PROMPT_WARMUP" : "GEN_ACCEPTED", batch.n_tokens);

    llama_batch mtp_batch = batch; // shallow copy: shares the caller's buffers

    if (is_prompt_warmup) {
        llama_set_mtp_op_type(ctx, MTP_OP_WARMUP);
    } else {
        llama_set_mtp_op_type(ctx, MTP_OP_UPDATE_ACCEPTED);
    }

    // Save the caller's logits flags, then request logits for all tokens.
    std::vector<int8_t> logits_saved(mtp_batch.logits, mtp_batch.logits + mtp_batch.n_tokens);
    for (int i = 0; i < mtp_batch.n_tokens; ++i) {
        mtp_batch.logits[i] = true;
    }

    llama_decode(ctx, mtp_batch);

    // Restore the flags so the caller's batch is left untouched.
    for (int i = 0; i < mtp_batch.n_tokens; ++i) {
        mtp_batch.logits[i] = logits_saved[i];
    }

    llama_set_mtp_op_type(ctx, MTP_OP_NONE);
}
void mtp_accept_tokens(
struct llama_context * ctx,
const std::vector<llama_token> & ids,
int32_t n_past_base,
llama_seq_id seq_id
) {
if (ids.empty()) {
return;
}
llama_batch accepted_batch = llama_batch_init(ids.size(), 0, 1);
for (size_t i = 0; i < ids.size(); ++i) {
common_batch_add(accepted_batch, ids[i], n_past_base + i, { seq_id }, true);
}
mtp_update_kv_cache(ctx, accepted_batch, false);
llama_batch_free(accepted_batch);
}