Add MTP decoding support for GLM-4.x MoE (#1270)

* wip: port MTP architecture

Ports the Multi-Token Prediction (MTP) architecture to the older `llama.cpp` codebase used by `ikllama`.

Changes include:
- Updating `llama_batch` to support `mtp_params`.
- Modifying `llama_decode_internal` (and `encode`) to handle MTP operations (Warmup, Update, Draft).
- Adding public APIs for MTP state management (`llama_set_draft_input_hidden_state`).
- Adapting the embedding extraction logic to skip MTP update passes.

* Refactors `server_slot` to support generic speculative decoding (MTP or Draft Model).

* core: enable hybrid outputs (logits + embeddings) for MTP support

* fix(mtp): correct KV-cache slot finding for updates

* fix(mtp): persist hidden states to prevent context corruption during drafting

* refactor(mtp): clean unused code

* fix(mtp): update server to new function names

* fix(mtp): fix graph and save hidden state

* mtp: refactor integration, context params and kv cache search

* mtp: fix hidden state extraction and speculative acceptance flow

* server: fix MTP warmup for long prompts and reset token buffer

* llama: refactor MTP operation state to context parameters

* server: fix n_past calculation in MTP acceptance

* llama: fix mtp enable flags

* speculative: refactor MTP to use common_speculative interface

* context: remove unused signatures

* clip: fix deprecated enum-enum conversion warning

* common: fix format string crash in help message

* context: fix mtp activation logic
This commit is contained in:
Samuel Oliveira Alves
2026-02-22 14:14:39 -03:00
committed by GitHub
parent cbf7fc7e2f
commit 09a88c9ae5
16 changed files with 820 additions and 206 deletions

View File

@@ -546,6 +546,7 @@ struct llama_context::Prev {
int all_seq_id;
int n_outputs;
int n_kv;
llama_mtp_op_type mtp_op_type;
ggml_cgraph * graph;
};
@@ -563,11 +564,13 @@ bool llama_context::can_reuse_graph(const llama_batch & u_batch) {
kv_self.head > 0 &&
kv_self.n == prev->n_kv &&
n_outputs == prev->n_outputs &&
cparams.mtp_op_type == prev->mtp_op_type &&
update_cache_copies();
}
bool llama_context::update_cache_copies() {
int n_layer = model.hparams.n_layer - model.hparams.nextn_predict_layers; //cache_copies.size()/2;
const int n_layer = model.mtp ? model.hparams.n_layer
: model.hparams.n_layer - model.hparams.nextn_predict_layers; //cache_copies.size()/2;
auto layer_has_attention_kv = [&](int il) {
return !((model.arch == LLM_ARCH_QWEN3NEXT || model.arch == LLM_ARCH_QWEN35MOE) && model.hparams.is_recurrent(il));
};
@@ -638,6 +641,12 @@ llama_context::llama_context(const llama_model & model)
}
}
// Sets the current MTP (Multi-Token Prediction) operation type on the
// context parameters. The value is read by can_reuse_graph() (graph reuse is
// invalidated when the op type changes) and by the decode path to decide
// whether to prepare MTP graph inputs and whether to extract logits.
void llama_context::set_mtp_op_type(llama_mtp_op_type value) {
LLAMA_LOG_DEBUG("%s: value = %d\n", __func__, value);
cparams.mtp_op_type = value;
}
llama_context::~llama_context() {
ggml_backend_sched_free(sched);
@@ -716,7 +725,8 @@ static bool llama_kv_cache_init(
const struct llama_hparams & hparams = model.hparams;
const int64_t n_layer = hparams.n_layer - hparams.nextn_predict_layers;
const int64_t n_layer = model.mtp ? hparams.n_layer
: hparams.n_layer - hparams.nextn_predict_layers;
cache.has_shift = false;
@@ -993,7 +1003,8 @@ static bool llama_kv_cache_init(
// to the first cell of the slot.
static bool llama_kv_cache_find_slot(
struct llama_kv_cache & cache,
const struct llama_batch & batch) {
const struct llama_batch & batch,
enum llama_mtp_op_type op_type) {
const uint32_t n_tokens = batch.n_tokens;
if (cache.recurrent) {
@@ -1044,6 +1055,45 @@ static bool llama_kv_cache_find_slot(
}
// otherwise, one cell per token.
bool is_mtp_special_op = (op_type == MTP_OP_WARMUP ||
op_type == MTP_OP_UPDATE_ACCEPTED);
if (is_mtp_special_op) {
const llama_pos target_pos = batch.pos[0];
const llama_seq_id target_seq = batch.seq_id[0][0];
bool found = false;
if (cache.head < cache.size &&
cache.cells[cache.head].pos == target_pos &&
cache.cells[cache.head].has_seq_id(target_seq)) {
found = true;
}
else {
for (uint32_t i = 0; i < cache.size; ++i) {
if (cache.cells[i].pos == target_pos &&
cache.cells[i].has_seq_id(target_seq)) {
cache.head = i;
found = true;
break;
}
}
}
if (!found) {
LLAMA_LOG_ERROR("%s: MTP Update failed - slot for seq %d pos %d not found\n",
__func__, target_seq, target_pos);
return false;
}
if (cache.head + n_tokens > cache.size) {
LLAMA_LOG_ERROR("%s: MTP Update out of bounds\n", __func__);
return false;
}
return true;
}
if (n_tokens > cache.size) {
LLAMA_LOG_ERROR("%s: n_tokens=%d > cache.size=%d\n", __func__, n_tokens, cache.size);
return false;
@@ -1893,6 +1943,7 @@ static bool llm_load_tensors(
const float * tensor_split,
bool use_mlock,
bool validate_quants,
bool mtp,
llama_progress_callback progress_callback,
void * progress_callback_user_data) {
model.t_start_us = ggml_time_us();
@@ -1921,6 +1972,7 @@ static bool llm_load_tensors(
model.main_gpu = main_gpu;
model.max_gpu = max_gpu;
model.n_gpu_layers = n_gpu_layers;
model.mtp = mtp;
const int n_layer = hparams.n_layer;
const int i_gpu_start = std::max((int) hparams.n_layer - n_gpu_layers, (int) 0);
@@ -2300,7 +2352,7 @@ static int llama_model_load(const std::string & fname, llama_model & model, llam
if (!llm_load_tensors(
ml, model, params.n_gpu_layers, params.mla, params.split_mode, params.main_gpu, params.max_gpu, params.tensor_split,
params.use_mlock, params.validate_quants,
params.use_mlock, params.validate_quants, params.mtp,
params.progress_callback, params.progress_callback_user_data
)) {
return -2;
@@ -2969,8 +3021,9 @@ static size_t llama_output_reserve(llama_context & lctx, size_t n_outputs) {
const auto n_embd = hparams.n_embd;
// TODO: use a per-batch flag for logits presence instead
const bool has_logits = !cparams.embeddings;
const bool has_embd = lctx.is_encoding || (cparams.embeddings && (cparams.pooling_type == LLAMA_POOLING_TYPE_NONE));
const bool has_mtp = lctx.model.hparams.nextn_predict_layers > 0 && lctx.cparams.mtp;
const bool has_logits = !cparams.embeddings || has_mtp;
const bool has_embd = lctx.is_encoding || (cparams.embeddings && (cparams.pooling_type == LLAMA_POOLING_TYPE_NONE)) || has_mtp;
const size_t logits_size = has_logits ? n_vocab*n_outputs_max : 0;
const size_t embd_size = has_embd ? n_embd*n_outputs_max : 0;
@@ -3049,6 +3102,24 @@ static void llama_graph_compute(
// fprintf(stderr, "splits: %d\n", ggml_backend_sched_get_n_splits(lctx.sched));
}
// Copies the MTP input hidden state into the graph's inp_mtp_states tensor
// before computation. For MTP_OP_WARMUP / MTP_OP_UPDATE_ACCEPTED the source
// is the context's own embedding buffer (lctx.embd); otherwise (drafting) it
// is the externally supplied pointer set via
// llama_set_draft_input_hidden_state(). Returns false (and logs an error)
// when no source hidden state is available.
static bool prepare_mtp_graph_inputs(struct llama_context & lctx) {
ggml_tensor * dst = lctx.inp_mtp_states;
const float * src = nullptr;
if (lctx.cparams.mtp_op_type == MTP_OP_WARMUP || lctx.cparams.mtp_op_type == MTP_OP_UPDATE_ACCEPTED) {
src = lctx.embd;
} else {
src = lctx.draft_input_hidden_state;
}
if (!src) {
LLAMA_LOG_ERROR("%s: Source hidden state is null\n", __func__);
return false;
}
// NOTE(review): copies ggml_nbytes(dst) bytes from src — assumes the source
// buffer holds at least that many bytes; confirm for the draft-input path.
ggml_backend_tensor_set(dst, src, 0, ggml_nbytes(dst));
return true;
}
// decode a batch of tokens by evaluating the transformer
//
// - lctx: llama context
@@ -3260,7 +3331,7 @@ static int llama_decode_internal(
kv_self.head = 0;
}
if (!llama_kv_cache_find_slot(kv_self, u_batch)) {
if (!llama_kv_cache_find_slot(kv_self, u_batch, cparams.mtp_op_type)) {
return 1;
}
@@ -3322,37 +3393,50 @@ static int llama_decode_internal(
#endif
if (u_batch.n_tokens == 1 && u_batch.embd == nullptr && lctx.cparams.graph_reuse) {
lctx.prev = std::make_unique<llama_context::Prev>(llama_context::Prev{
(int)u_batch.all_seq_id, (int)lctx.n_outputs, (int)lctx.kv_self.n, gf});
(int)u_batch.all_seq_id, (int)lctx.n_outputs, (int)lctx.kv_self.n,
cparams.mtp_op_type, gf});
}
} else {
//printf("Reusing graph\n");
gf = lctx.prev->graph;
}
if (cparams.mtp_op_type != MTP_OP_NONE) {
if (!prepare_mtp_graph_inputs(lctx)) {
return GGML_STATUS_FAILED;
}
}
// the output is always the last tensor in the graph
struct ggml_tensor * res = gf->nodes[gf->n_nodes - 1];
struct ggml_tensor * embd = gf->nodes[gf->n_nodes - 2];
struct ggml_tensor * embd = nullptr;
if (lctx.n_outputs == 0) {
// no output
res = nullptr;
embd = nullptr;
} else if (cparams.embeddings) {
res = nullptr; // do not extract logits for embedding case
embd = nullptr;
for (int i = gf->n_nodes - 1; i >= 0; --i) {
if (strcmp(gf->nodes[i]->name, "result_embd_pooled") == 0) {
embd = gf->nodes[i];
break;
res = nullptr;
}
else {
const bool has_mtp = lctx.model.hparams.nextn_predict_layers > 0 && lctx.model.mtp;
if (cparams.embeddings || has_mtp) {
for (int i = gf->n_nodes - 1; i >= 0; --i) {
if (strcmp(gf->nodes[i]->name, "result_embd_pooled") == 0) {
embd = gf->nodes[i];
break;
}
if (strcmp(gf->nodes[i]->name, "result_norm") == 0) {
embd = gf->nodes[i];
}
}
}
if (cparams.embeddings && lctx.model.hparams.nextn_predict_layers == 0) {
res = nullptr; // do not extract logits for embedding case
} else {
if (!embd) { // do not extract embeddings when not needed
GGML_ASSERT(strcmp(res->name, "result_output") == 0 && "missing result_output tensor");
}
}
GGML_ASSERT(embd != nullptr && "missing embeddings tensor");
} else {
embd = nullptr; // do not extract embeddings when not needed
GGML_ASSERT(strcmp(res->name, "result_output") == 0 && "missing result_output tensor");
}
// LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs);
#if IK_PRINT_TIMING == 1
tim1 = ggml_time_us();
#endif
@@ -3392,17 +3476,21 @@ static int llama_decode_internal(
#if IK_PRINT_TIMING
tim1 = ggml_time_us();
#endif
ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(lctx.sched, res);
GGML_ASSERT(backend_res != nullptr);
GGML_ASSERT(lctx.logits != nullptr);
// Do not process logits if MTP is only updating the KV cache.
if (cparams.mtp_op_type != MTP_OP_WARMUP &&
cparams.mtp_op_type != MTP_OP_UPDATE_ACCEPTED) {
ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(lctx.sched, res);
GGML_ASSERT(backend_res != nullptr);
GGML_ASSERT(lctx.logits != nullptr);
float * logits_out = lctx.logits + n_outputs_prev*n_vocab;
const int32_t n_outputs_new = lctx.n_outputs;
float * logits_out = lctx.logits + n_outputs_prev*n_vocab;
const int32_t n_outputs_new = lctx.n_outputs;
if (n_outputs_new) {
GGML_ASSERT( n_outputs_prev + n_outputs_new <= n_outputs);
GGML_ASSERT((n_outputs_prev + n_outputs_new)*n_vocab <= (int64_t) lctx.logits_size);
ggml_backend_tensor_get_async(backend_res, res, logits_out, 0, n_outputs_new*n_vocab*sizeof(float));
if (n_outputs_new) {
GGML_ASSERT( n_outputs_prev + n_outputs_new <= n_outputs);
GGML_ASSERT((n_outputs_prev + n_outputs_new)*n_vocab <= (int64_t) lctx.logits_size);
ggml_backend_tensor_get_async(backend_res, res, logits_out, 0, n_outputs_new*n_vocab*sizeof(float));
}
}
#if IK_PRINT_TIMING
tim2 = ggml_time_us();
@@ -3411,7 +3499,7 @@ static int llama_decode_internal(
}
// extract embeddings
if (embd) {
if (embd && cparams.mtp_op_type == MTP_OP_NONE) {
#if IK_PRINT_TIMING
tim1 = ggml_time_us();
#endif
@@ -3617,57 +3705,59 @@ static int llama_encode_internal(
// extract embeddings
if (embd) {
ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(lctx.sched, embd);
GGML_ASSERT(backend_embd != nullptr);
if (cparams.mtp_op_type == MTP_OP_NONE) {
ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(lctx.sched, embd);
GGML_ASSERT(backend_embd != nullptr);
if (llama_model_has_decoder(&lctx.model)) {
lctx.embd_enc.resize(n_tokens*n_embd);
float * embd_out = lctx.embd_enc.data();
if (llama_model_has_decoder(&lctx.model)) {
lctx.embd_enc.resize(n_tokens*n_embd);
float * embd_out = lctx.embd_enc.data();
ggml_backend_tensor_get_async(backend_embd, embd, embd_out, 0, n_tokens*n_embd*sizeof(float));
ggml_backend_tensor_get_async(backend_embd, embd, embd_out, 0, n_tokens*n_embd*sizeof(float));
// remember the sequence ids used during the encoding - needed for cross attention later
lctx.seq_ids_enc.resize(n_tokens);
for (uint32_t i = 0; i < n_tokens; i++) {
for (int s = 0; s < batch.n_seq_id[i]; s++) {
llama_seq_id seq_id = batch.seq_id[i][s];
lctx.seq_ids_enc[i].insert(seq_id);
}
}
} else {
GGML_ASSERT(lctx.embd != nullptr);
switch (cparams.pooling_type) {
case LLAMA_POOLING_TYPE_NONE:
{
// extract token embeddings
GGML_ASSERT(lctx.embd != nullptr);
float * embd_out = lctx.embd;
GGML_ASSERT(n_tokens*n_embd <= (int64_t) lctx.embd_size);
ggml_backend_tensor_get_async(backend_embd, embd, embd_out, 0, n_tokens*n_embd*sizeof(float));
} break;
case LLAMA_POOLING_TYPE_MEAN:
case LLAMA_POOLING_TYPE_CLS:
case LLAMA_POOLING_TYPE_LAST:
{
// extract sequence embeddings
auto & embd_seq_out = lctx.embd_seq;
embd_seq_out.clear();
for (uint32_t i = 0; i < n_tokens; i++) {
const llama_seq_id seq_id = batch.seq_id[i][0];
if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
continue;
}
embd_seq_out[seq_id].resize(n_embd);
ggml_backend_tensor_get_async(backend_embd, embd, embd_seq_out[seq_id].data(), (n_embd*seq_id)*sizeof(float), n_embd*sizeof(float));
}
} break;
case LLAMA_POOLING_TYPE_UNSPECIFIED:
{
GGML_ABORT("unknown pooling type");
// remember the sequence ids used during the encoding - needed for cross attention later
lctx.seq_ids_enc.resize(n_tokens);
for (uint32_t i = 0; i < n_tokens; i++) {
for (int s = 0; s < batch.n_seq_id[i]; s++) {
llama_seq_id seq_id = batch.seq_id[i][s];
lctx.seq_ids_enc[i].insert(seq_id);
}
}
} else {
GGML_ASSERT(lctx.embd != nullptr);
switch (cparams.pooling_type) {
case LLAMA_POOLING_TYPE_NONE:
{
// extract token embeddings
GGML_ASSERT(lctx.embd != nullptr);
float * embd_out = lctx.embd;
GGML_ASSERT(n_tokens*n_embd <= (int64_t) lctx.embd_size);
ggml_backend_tensor_get_async(backend_embd, embd, embd_out, 0, n_tokens*n_embd*sizeof(float));
} break;
case LLAMA_POOLING_TYPE_MEAN:
case LLAMA_POOLING_TYPE_CLS:
case LLAMA_POOLING_TYPE_LAST:
{
// extract sequence embeddings
auto & embd_seq_out = lctx.embd_seq;
embd_seq_out.clear();
for (uint32_t i = 0; i < n_tokens; i++) {
const llama_seq_id seq_id = batch.seq_id[i][0];
if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
continue;
}
embd_seq_out[seq_id].resize(n_embd);
ggml_backend_tensor_get_async(backend_embd, embd, embd_seq_out[seq_id].data(), (n_embd*seq_id)*sizeof(float), n_embd*sizeof(float));
}
} break;
case LLAMA_POOLING_TYPE_UNSPECIFIED:
{
GGML_ABORT("unknown pooling type");
}
}
}
}
}
@@ -4223,6 +4313,7 @@ struct llama_model_params llama_model_default_params() {
/*.validate_quants =*/ false,
/*.merge_qkv =*/ false,
/*.merge_up_gate_exps =*/ false,
/*.mtp =*/ false,
};
#ifdef GGML_USE_METAL
@@ -4278,6 +4369,8 @@ struct llama_context_params llama_context_default_params() {
/*.split_mode_graph_scheduling =*/ false,
// /*.split_mode_f16 =*/ true,
/*.scheduler_async =*/ false,
/*.mtp =*/ false,
/*.mtp_op_type =*/ MTP_OP_NONE,
/*.abort_callback =*/ nullptr,
/*.abort_callback_data =*/ nullptr,
/*.offload_policy =*/ nullptr,
@@ -4648,6 +4741,7 @@ struct llama_context * llama_init_from_model(
cparams.min_experts = params.min_experts;
cparams.thresh_experts = params.thresh_experts;
cparams.cuda_params = params.cuda_params;
cparams.mtp = params.mtp;
cparams.reduce_type = params.type_reduce;
cparams.pooling_type = params.pooling_type;
@@ -4725,6 +4819,12 @@ struct llama_context * llama_init_from_model(
}
}
if (model->arch != LLM_ARCH_GLM4_MOE && cparams.mtp != 0) {
cparams.mtp = 0;
}
cparams.mtp_op_type = params.mtp_op_type;
LLAMA_LOG_INFO("%s: n_ctx = %u\n", __func__, cparams.n_ctx);
LLAMA_LOG_INFO("%s: n_batch = %u\n", __func__, cparams.n_batch);
LLAMA_LOG_INFO("%s: n_ubatch = %u\n", __func__, cparams.n_ubatch);
@@ -6058,7 +6158,7 @@ struct llama_data_read {
batch.n_seq_id[i] = 1;
batch.seq_id[i][0] = dest_seq_id;
}
if (!llama_kv_cache_find_slot(kv_self, batch)) {
if (!llama_kv_cache_find_slot(kv_self, batch, ctx->cparams.mtp_op_type)) {
llama_batch_free(batch);
LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__);
return false;
@@ -7003,6 +7103,10 @@ int32_t llama_decode(
return ret;
}
// Public API: selects the MTP operation (warmup / update-accepted / draft /
// none) that subsequent decode calls on this context will perform.
// Thin wrapper over llama_context::set_mtp_op_type().
void llama_set_mtp_op_type(llama_context * ctx, llama_mtp_op_type mtp_op_type) {
ctx->set_mtp_op_type(mtp_op_type);
}
void llama_synchronize(struct llama_context * ctx) {
ggml_backend_sched_synchronize(ctx->sched);
@@ -8333,3 +8437,8 @@ void llama_set_offload_policy(struct llama_context * lctx, int op, bool on_or_of
printf("XXXXXXXXXXXXXXXXXXXXXXXXXXXX offload(%s) = %d\n", op_name, on_or_off);
ggml_backend_sched_set_op_offload(lctx->sched, ggml_op(op), on_or_off);
}
// Public API: stores a non-owning pointer to the hidden state that the MTP
// draft pass will read (see prepare_mtp_graph_inputs). No copy is made here,
// so the buffer must remain valid until the draft decode has consumed it.
void llama_set_draft_input_hidden_state(struct llama_context * ctx, const float * hidden_state) {
ctx->draft_input_hidden_state = hidden_state;
}