mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-02-23 22:54:10 +00:00
Add MTP decoding support for GLM-4.x MoE (#1270)
* wip: port MTP architecture

  Ports the Multi-Token Prediction (MTP) architecture to the older `llama.cpp` codebase used by `ik_llama.cpp`. Changes include:
  - Updating `llama_batch` to support `mtp_params`.
  - Modifying `llama_decode_internal` (and `encode`) to handle MTP operations (Warmup, Update, Draft).
  - Adding public APIs for MTP state management (`llama_set_draft_input_hidden_state`).
  - Adapting the embedding extraction logic to skip MTP update passes.
* Refactors `server_slot` to support generic speculative decoding (MTP or Draft Model).
* core: enable hybrid outputs (logits + embeddings) for MTP support
* fix(mtp): correct KV-cache slot finding for updates
* fix(mtp): persist hidden states to prevent context corruption during drafting
* refactor(mtp): clean unused code
* fix(mtp): update server to new function names
* fix(mtp): fix graph and save hidden state
* mtp: refactor integration, context params and kv cache search
* mtp: fix hidden state extraction and speculative acceptance flow
* server: fix MTP warmup for long prompts and reset token buffer
* llama: refactor MTP operation state to context parameters
* server: fix n_past calculation in MTP acceptance
* llama: fix mtp enable flags
* speculative: refactor MTP to use common_speculative interface
* context: remove unused signatures
* clip: fix deprecated enum-enum conversion warning
* common: fix format string crash in help message
* context: fix mtp activation logic
This commit is contained in:
committed by
GitHub
parent
cbf7fc7e2f
commit
09a88c9ae5
@@ -152,12 +152,17 @@ bool server_context::load_model(const gpt_params& params_) {
|
||||
LOG_ERROR("failed to load draft model", { {"model", params_base.speculative.model} });
|
||||
return false;
|
||||
}
|
||||
|
||||
cparams_dft = common_context_params_to_llama(params_dft);
|
||||
|
||||
params_base.speculative.model_dft = model_dft;
|
||||
params_base.speculative.cparams_dft = cparams_dft;
|
||||
|
||||
}
|
||||
else if (params_base.has_mtp && llama_model_n_nextn_layer(model) == 0) {
|
||||
LOG_WARNING("WARNING: -mtp flag provided, but model has 0 NextN layers. MTP will be disabled.\n", {});
|
||||
params_base.has_mtp = false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
@@ -209,12 +214,35 @@ void server_context::init() {
|
||||
|
||||
slot.sparams = params_base.sparams;
|
||||
|
||||
if (params_base.has_mtp) {
|
||||
if (llama_model_n_nextn_layer(model) > 0) {
|
||||
SRV_INF("%s\n", "MTP detected, configuring for speculative decoding...");
|
||||
|
||||
params_base.speculative.type = COMMON_SPECULATIVE_TYPE_MTP;
|
||||
|
||||
slot.has_mtp = true;
|
||||
slot.params.speculative.type = COMMON_SPECULATIVE_TYPE_MTP;
|
||||
slot.params.speculative.n_min = 0;
|
||||
|
||||
slot.batch_spec = llama_batch_init(slot.params.speculative.n_max + 1, 0, 1);
|
||||
SLT_DBG(slot, "batch_spec contains %d tokens\n", slot.batch_spec.n_tokens);
|
||||
|
||||
SRV_INF("%s\n", "MTP needs embeddings on decode, enabling");
|
||||
llama_set_embeddings(ctx, true);
|
||||
}
|
||||
else {
|
||||
SRV_WRN("%s\n", "MTP enabled via flag, but model has 0 NextN layers. Disabling speculative.");
|
||||
params_base.speculative.type = COMMON_SPECULATIVE_TYPE_NONE;
|
||||
slot.has_mtp = false;
|
||||
}
|
||||
}
|
||||
|
||||
const bool can_spec = common_speculative_is_compat(ctx);
|
||||
if (!can_spec) {
|
||||
SRV_WRN("%s", "speculative decoding not supported by this context\n");
|
||||
}
|
||||
// try speculative decoding
|
||||
if (can_spec){
|
||||
if (can_spec) {
|
||||
slot.spec = common_speculative_init(params_base.speculative, slot.ctx);
|
||||
if (slot.spec) {
|
||||
if (mctx) {
|
||||
@@ -223,9 +251,14 @@ void server_context::init() {
|
||||
}
|
||||
SLT_INF(slot, "%s", "speculative decoding context initialized\n");
|
||||
} else {
|
||||
SLT_INF(slot, "%s", "speculative decoding context not initialized\n");
|
||||
if (slot.has_mtp) {
|
||||
SRV_ERR("%s", "failed to initialize MTP speculative context\n");
|
||||
} else {
|
||||
SLT_INF(slot, "%s", "speculative decoding context not initialized\n");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
slot.reset();
|
||||
|
||||
slots.push_back(std::move(slot));
|
||||
@@ -380,7 +413,7 @@ void server_slot::add_token_string(const completion_token_output& token) {
|
||||
}
|
||||
|
||||
bool server_slot::can_speculate() const {
|
||||
return !!spec;
|
||||
return (!!spec || has_mtp);
|
||||
}
|
||||
|
||||
int server_slot::get_n_draft_max() const {
|
||||
@@ -2533,6 +2566,15 @@ void server_context::add_sampled_tokens() {
|
||||
|
||||
const auto & params_spec = slot.params.speculative;
|
||||
|
||||
if (slot.has_mtp) {
|
||||
if (!slot.mtp_hidden_state.empty()) {
|
||||
llama_set_draft_input_hidden_state(ctx, slot.mtp_hidden_state.data());
|
||||
} else {
|
||||
LOG_ERROR("MTP hidden state is empty during speculation", {});
|
||||
llama_set_draft_input_hidden_state(ctx, llama_get_embeddings_ith(ctx, -1));
|
||||
}
|
||||
}
|
||||
|
||||
llama_tokens draft = common_speculative_draft(slot.spec, params_spec, cached_text_tokens, slot.sampled);
|
||||
|
||||
if (draft.size() > (size_t)n_draft_max) {
|
||||
@@ -2540,13 +2582,6 @@ void server_context::add_sampled_tokens() {
|
||||
draft.resize(n_draft_max);
|
||||
}
|
||||
|
||||
/*struct llama_speculative_params params_spec;
|
||||
params_spec.n_draft = n_draft_max;
|
||||
params_spec.n_reuse = llama_n_ctx(slot.ctx_dft) - slot.params.speculative.n_max;
|
||||
params_spec.p_min = slot.params.speculative.p_min;
|
||||
const llama_tokens& cached_text_tokens = slot.cache_tokens.get_text_tokens();
|
||||
llama_tokens draft = llama_speculative_gen_draft(slot.spec, params_spec, cached_text_tokens, slot.sampled);*/
|
||||
|
||||
// add the sampled token to the batch
|
||||
slot.i_batch_dft.push_back(batch.n_tokens);
|
||||
common_batch_add(batch, slot.sampled, slot.cache_tokens.pos_next(), { slot.id }, true);
|
||||
@@ -2759,6 +2794,8 @@ void server_context::batch_pending_prompt(const int32_t n_ubatch, const int32_t
|
||||
}
|
||||
|
||||
slot.n_past = 0;
|
||||
slot.n_buffer = 0;
|
||||
slot.token_buffer.clear();
|
||||
slot.n_prompt_tokens = prompt_tokens.size();
|
||||
|
||||
LOG_VERBOSE("prompt tokenized", {
|
||||
@@ -3092,6 +3129,25 @@ void server_context::speculative_decoding_accept() {
|
||||
|
||||
// the accepted tokens from the speculation
|
||||
const auto ids = common_sampler_sample_and_accept_n(slot.ctx_sampling, ctx, slot.i_batch_dft, slot.drafted);
|
||||
|
||||
if (slot.has_mtp) {
|
||||
const int n_embd = llama_model_n_embd(llama_get_model(ctx));
|
||||
if (!ids.empty()) {
|
||||
const float* emb = llama_get_embeddings_ith(ctx, ids.size() - 1);
|
||||
if (emb) {
|
||||
slot.mtp_hidden_state.resize(n_embd);
|
||||
memcpy(slot.mtp_hidden_state.data(), emb, n_embd * sizeof(float));
|
||||
}
|
||||
}
|
||||
else {
|
||||
llama_set_draft_input_hidden_state(ctx, llama_get_embeddings_ith(ctx, 0));
|
||||
}
|
||||
llama_set_draft_input_hidden_state(ctx, slot.mtp_hidden_state.data());
|
||||
|
||||
int32_t n_past_base = slot.n_past - (slot.drafted.size() + 1);
|
||||
mtp_accept_tokens(ctx, ids, n_past_base, slot.id);
|
||||
}
|
||||
|
||||
slot.i_batch_dft.clear();
|
||||
slot.drafted.clear();
|
||||
|
||||
@@ -3362,6 +3418,17 @@ void server_context::process_batch_tokens(int32_t & n_batch) {
|
||||
|
||||
common_sampler_accept(slot.ctx_sampling, ctx, id, true);
|
||||
|
||||
if (params_base.has_mtp && slot.n_decoded == 0) {
|
||||
if (batch_view.n_seq_id[0] > 0 && batch_view.seq_id[0][0] == slot.id) {
|
||||
mtp_update_kv_cache(ctx, batch_view, true);
|
||||
const float* emb = llama_get_embeddings_ith(ctx, -1);
|
||||
if (emb) {
|
||||
const int n_embd = llama_model_n_embd(llama_get_model(ctx));
|
||||
slot.mtp_hidden_state.resize(n_embd);
|
||||
memcpy(slot.mtp_hidden_state.data(), emb, n_embd * sizeof(float));
|
||||
}
|
||||
}
|
||||
}
|
||||
slot.n_decoded += 1;
|
||||
const int64_t t_current = ggml_time_us();
|
||||
|
||||
@@ -3394,7 +3461,15 @@ void server_context::process_batch_tokens(int32_t & n_batch) {
|
||||
|
||||
slot.i_batch = -1;
|
||||
}
|
||||
|
||||
if (params_base.has_mtp) {
|
||||
for (auto& slot : slots) {
|
||||
if (slot.n_past < slot.n_prompt_tokens) {
|
||||
if (batch_view.n_seq_id[0] > 0 && batch_view.seq_id[0][0] == slot.id) {
|
||||
mtp_update_kv_cache(ctx, batch_view, true);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// speculative decoding - main model sample and accept
|
||||
speculative_decoding_accept();
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user