mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-02-23 22:54:10 +00:00
Add MTP decoding support for GLM-4.x MoE (#1270)
* wip: port MTP architecture Ports the Multi-Token Prediction (MTP) architecture to the older `llama.cpp` codebase used by `ikllama`. Changes include: - Updating `llama_batch` to support `mtp_params`. - Modifying `llama_decode_internal` (and `encode`) to handle MTP operations (Warmup, Update, Draft). - Adding public APIs for MTP state management (`llama_set_draft_input_hidden_state`). - Adapting the embedding extraction logic to skip MTP update passes. * Refactors `server_slot` to support generic speculative decoding (MTP or Draft Model). * core: enable hybrid outputs (logits + embeddings) for MTP support * fix(mtp): correct KV-cache slot finding for updates * fix(mtp): persist hidden states to prevent context corruption during drafting * refactor(mtp): clean unused code * fix(mtp): update server to new functions name * fix(mtp): fix graph and save hidden state * mtp: refactor integration, context params and kv cache search * mtp: fix hidden state extraction and speculative acceptance flow * server: fix MTP warmup for long prompts and reset token buffer * llama: refactor MTP operation state to context parameters * server: fix n_past calculation in MTP acceptance * llama: fix mtp enable flags * speculative: refactor MTP to use common_speculative interface * context: remove unused signatures * clip: fix deprecated enum-enum conversion warning * common: fix format string crash in help message * context: fix mtp activation logic
This commit is contained in:
committed by
GitHub
parent
cbf7fc7e2f
commit
09a88c9ae5
@@ -20,6 +20,7 @@
|
||||
const std::vector<enum common_speculative_type> common_speculative_types = {
|
||||
COMMON_SPECULATIVE_TYPE_NONE,
|
||||
COMMON_SPECULATIVE_TYPE_DRAFT,
|
||||
COMMON_SPECULATIVE_TYPE_MTP,
|
||||
COMMON_SPECULATIVE_TYPE_EAGLE3,
|
||||
COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE,
|
||||
COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K,
|
||||
@@ -31,6 +32,7 @@ const std::vector<enum common_speculative_type> common_speculative_types = {
|
||||
const std::map<std::string, enum common_speculative_type> common_speculative_type_from_name_map = {
|
||||
{"none", COMMON_SPECULATIVE_TYPE_NONE},
|
||||
{"draft", COMMON_SPECULATIVE_TYPE_DRAFT},
|
||||
{"mtp", COMMON_SPECULATIVE_TYPE_MTP},
|
||||
{"eagle3", COMMON_SPECULATIVE_TYPE_EAGLE3},
|
||||
{"ngram_simple", COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE},
|
||||
{"ngram_map_k", COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K},
|
||||
@@ -144,6 +146,58 @@ struct common_speculative_state {
|
||||
virtual void accept(uint16_t n_accepted) = 0;
|
||||
};
|
||||
|
||||
struct common_speculative_state_mtp : public common_speculative_state {
|
||||
llama_context * ctx_tgt;
|
||||
common_sampler * smpl;
|
||||
|
||||
common_speculative_state_mtp(
|
||||
enum common_speculative_type type,
|
||||
llama_context * ctx_tgt)
|
||||
: common_speculative_state(type)
|
||||
, ctx_tgt(ctx_tgt)
|
||||
{
|
||||
struct common_params_sampling params;
|
||||
params.samplers_sequence = {
|
||||
llama_sampler_type::DIST,
|
||||
};
|
||||
smpl = common_sampler_init(llama_get_model(ctx_tgt), params);
|
||||
}
|
||||
|
||||
~common_speculative_state_mtp() override {
|
||||
common_sampler_free(smpl);
|
||||
}
|
||||
|
||||
void begin(const llama_tokens & prompt) override {
|
||||
GGML_UNUSED(prompt);
|
||||
}
|
||||
|
||||
void draft(
|
||||
const common_params_speculative & params,
|
||||
const llama_tokens & prompt_tgt,
|
||||
llama_token id_last,
|
||||
llama_tokens & result) override {
|
||||
|
||||
int32_t n_past = (int32_t)prompt_tgt.size();
|
||||
|
||||
llama_seq_id seq_id = 0;
|
||||
|
||||
result = mtp_speculative_gen_draft(
|
||||
smpl,
|
||||
ctx_tgt,
|
||||
params.n_max,
|
||||
params.p_min,
|
||||
id_last,
|
||||
n_past,
|
||||
seq_id
|
||||
);
|
||||
}
|
||||
|
||||
void accept(uint16_t n_accepted) override {
|
||||
GGML_UNUSED(n_accepted);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
struct common_speculative_state_draft : public common_speculative_state {
|
||||
llama_context * ctx_tgt; // only used for retokenizing from ctx_dft
|
||||
llama_context * ctx_dft;
|
||||
@@ -760,6 +814,7 @@ std::string common_speculative_type_to_str(enum common_speculative_type type) {
|
||||
switch (type) {
|
||||
case COMMON_SPECULATIVE_TYPE_NONE: return "none";
|
||||
case COMMON_SPECULATIVE_TYPE_DRAFT: return "draft";
|
||||
case COMMON_SPECULATIVE_TYPE_MTP: return "mtp";
|
||||
case COMMON_SPECULATIVE_TYPE_EAGLE3: return "eagle3";
|
||||
case COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE: return "ngram_simple";
|
||||
case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K: return "ngram_map_k";
|
||||
@@ -828,6 +883,7 @@ common_speculative * common_speculative_init(
|
||||
{
|
||||
bool has_draft = !params.mparams_dft.path.empty();
|
||||
bool has_draft_eagle3 = false; // TODO PR-18039: if params.speculative.eagle3
|
||||
bool has_mtp = (params.type == COMMON_SPECULATIVE_TYPE_MTP);
|
||||
|
||||
bool has_ngram_cache = (params.type == COMMON_SPECULATIVE_TYPE_NGRAM_CACHE);
|
||||
bool has_ngram_simple = (params.type == COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE);
|
||||
@@ -867,6 +923,9 @@ common_speculative * common_speculative_init(
|
||||
if (has_ngram_cache) {
|
||||
configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_NGRAM_CACHE, params));
|
||||
}
|
||||
if (has_mtp) {
|
||||
configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_MTP, params));
|
||||
}
|
||||
if (has_draft) {
|
||||
configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_DRAFT, params));
|
||||
}
|
||||
@@ -890,6 +949,12 @@ common_speculative * common_speculative_init(
|
||||
));
|
||||
break;
|
||||
}
|
||||
case COMMON_SPECULATIVE_TYPE_MTP: {
|
||||
impls.push_back(std::make_unique<common_speculative_state_mtp>(config.type,
|
||||
/* .ctx_tgt = */ ctx_tgt
|
||||
));
|
||||
break;
|
||||
}
|
||||
case COMMON_SPECULATIVE_TYPE_EAGLE3: {
|
||||
impls.push_back(std::make_unique<common_speculative_state_eagle3>(config.type));
|
||||
break;
|
||||
@@ -1047,3 +1112,112 @@ void common_speculative_print_stats(const common_speculative * spec) {
|
||||
str_perf.c_str());
|
||||
}
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// MTP
|
||||
// ----------------------------------------------------------------------------
|
||||
std::vector<llama_token> mtp_speculative_gen_draft(
|
||||
struct common_sampler * smpl,
|
||||
struct llama_context * ctx,
|
||||
int n_draft,
|
||||
float p_min,
|
||||
llama_token id_last,
|
||||
int32_t n_past,
|
||||
llama_seq_id seq_id) {
|
||||
|
||||
llama_tokens drafts;
|
||||
drafts.reserve(n_draft);
|
||||
|
||||
if (!smpl) return drafts;
|
||||
|
||||
common_sampler_reset(smpl);
|
||||
|
||||
llama_batch mtp_batch = llama_batch_init(1, 0, 1);
|
||||
llama_set_mtp_op_type(ctx, MTP_OP_DRAFT_GEN);
|
||||
|
||||
llama_token current_input_id = id_last;
|
||||
int32_t current_n_past = n_past;
|
||||
|
||||
for (int i = 0; i < n_draft; ++i) {
|
||||
mtp_batch.n_tokens = 0;
|
||||
common_batch_add(mtp_batch, current_input_id, current_n_past, {seq_id}, true);
|
||||
|
||||
if (llama_decode(ctx, mtp_batch) != 0) {
|
||||
break;
|
||||
}
|
||||
|
||||
common_sampler_sample(smpl, ctx, 0, true);
|
||||
|
||||
const auto * cur_p = common_sampler_get_candidates(smpl, true);
|
||||
|
||||
if (!cur_p || cur_p->size == 0) {
|
||||
break;
|
||||
}
|
||||
|
||||
const llama_token id_next = cur_p->data[0].id;
|
||||
const float prob = cur_p->data[0].p;
|
||||
|
||||
common_sampler_accept(smpl, nullptr, id_next, true);
|
||||
|
||||
if (prob < p_min) {
|
||||
break;
|
||||
}
|
||||
|
||||
drafts.push_back(id_next);
|
||||
|
||||
current_input_id = id_next;
|
||||
current_n_past++;
|
||||
}
|
||||
llama_batch_free(mtp_batch);
|
||||
llama_set_mtp_op_type(ctx, MTP_OP_NONE);
|
||||
|
||||
// Purge the metadata for the draft tokens.
|
||||
// This prevents cache state corruption where two cells map to the same logical position.
|
||||
if (!drafts.empty()) {
|
||||
llama_kv_cache_seq_rm(ctx, seq_id, n_past, current_n_past);
|
||||
}
|
||||
|
||||
return drafts;
|
||||
}
|
||||
|
||||
|
||||
void mtp_update_kv_cache(struct llama_context * ctx, const llama_batch& batch, bool is_prompt_warmup) {
|
||||
if (batch.n_tokens == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
LOG_DBG("[MTP-UPDATE|%s] Updating %d tokens...\n", is_prompt_warmup ? "PROMPT_WARMUP" : "GEN_ACCEPTED", batch.n_tokens);
|
||||
|
||||
llama_batch mtp_batch = batch;
|
||||
if (is_prompt_warmup) {
|
||||
llama_set_mtp_op_type(ctx, MTP_OP_WARMUP);
|
||||
} else {
|
||||
llama_set_mtp_op_type(ctx, MTP_OP_UPDATE_ACCEPTED);
|
||||
}
|
||||
|
||||
for (int i = 0; i < mtp_batch.n_tokens; ++i) {
|
||||
mtp_batch.logits[i] = true;
|
||||
}
|
||||
llama_decode(ctx, mtp_batch);
|
||||
llama_set_mtp_op_type(ctx, MTP_OP_NONE);
|
||||
}
|
||||
|
||||
void mtp_accept_tokens(
|
||||
struct llama_context * ctx,
|
||||
const std::vector<llama_token> & ids,
|
||||
int32_t n_past_base,
|
||||
llama_seq_id seq_id
|
||||
) {
|
||||
if (ids.empty()) {
|
||||
return;
|
||||
}
|
||||
|
||||
llama_batch accepted_batch = llama_batch_init(ids.size(), 0, 1);
|
||||
for (size_t i = 0; i < ids.size(); ++i) {
|
||||
common_batch_add(accepted_batch, ids[i], n_past_base + i, { seq_id }, true);
|
||||
}
|
||||
|
||||
mtp_update_kv_cache(ctx, accepted_batch, false);
|
||||
|
||||
llama_batch_free(accepted_batch);
|
||||
}
|
||||
Reference in New Issue
Block a user