mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-04-25 08:59:30 +00:00
Add MTP decoding support for GLM-4.x MoE (#1270)
* wip: port MTP architecture Ports the Multi-Token Prediction (MTP) architecture to the older `llama.cpp` codebase used by `ikllama`. Changes include: - Updating `llama_batch` to support `mtp_params`. - Modifying `llama_decode_internal` (and `encode`) to handle MTP operations (Warmup, Update, Draft). - Adding public APIs for MTP state management (`llama_set_draft_input_hidden_state`). - Adapting the embedding extraction logic to skip MTP update passes. * Refactors `server_slot` to support generic speculative decoding (MTP or Draft Model). * core: enable hybrid outputs (logits + embeddings) for MTP support * fix(mtp): correct KV-cache slot finding for updates * fix(mtp): persist hidden states to prevent context corruption during drafting * refactor(mtp): clean unused code * fix(mtp): update server to new functions name * fix(mtp): fix graph and save hidden state * mtp: refactor integration, context params and kv cache search * mtp: fix hidden state extraction and speculative acceptance flow * server: fix MTP warmup for long prompts and reset token buffer * llama: refactor MTP operation state to context parameters * server: fix n_past calculation in MTP acceptance * llama: fix mtp enable flags * speculative: refactor MTP to use common_speculative interface * context: remove unused signatures * clip: fix deprecated enum-enum conversion warning * common: fix format string crash in help message * context: fix mtp activation logic
This commit is contained in:
committed by
GitHub
parent
cbf7fc7e2f
commit
09a88c9ae5
@@ -1463,6 +1463,14 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
|
||||
params.cuda_params = argv[i];
|
||||
return true;
|
||||
}
|
||||
if (arg == "-mtp" || arg == "--multi-token-prediction") {
|
||||
params.has_mtp = true;
|
||||
return true;
|
||||
}
|
||||
if (arg == "-no-mtp" || arg == "--no-multi-token-prediction") {
|
||||
params.has_mtp = false;
|
||||
return true;
|
||||
}
|
||||
if (arg == "-draft" || arg == "--draft-params") {
|
||||
CHECK_ARG
|
||||
params.speculative.params = argv[i];
|
||||
@@ -2475,6 +2483,8 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param
|
||||
options.push_back({ "*", "-hfr, --hf-repo REPO", "Hugging Face model repository (default: unused)" });
|
||||
options.push_back({ "*", "-hff, --hf-file FILE", "Hugging Face model file (default: unused)" });
|
||||
options.push_back({ "*", "-hft, --hf-token TOKEN", "Hugging Face access token (default: value from HF_TOKEN environment variable)" });
|
||||
options.push_back({ "*", "-mtp, --multi-token-prediction", "whether to use multi-token-prediction (if supported) (default: %s)", params.has_mtp ? "true" : "false" });
|
||||
options.push_back({ "*", "-no-mtp, --no-multi-token-prediction", "whether to use multi-token-prediction (if supported) (default: %s)", !params.has_mtp ? "true" : "false" });
|
||||
options.push_back({ "*", "--draft-max, --draft, --draft-n N",
|
||||
"number of tokens to draft for speculative decoding (default: %d)", params.speculative.n_max });
|
||||
options.push_back({ "*", "--draft-min, --draft-n-min N", "minimum number of draft tokens to use for speculative decoding" });
|
||||
@@ -3207,6 +3217,7 @@ struct llama_model_params common_model_params_to_llama(const gpt_params & params
|
||||
mparams.validate_quants = params.validate_quants;
|
||||
mparams.merge_qkv = params.merge_qkv;
|
||||
mparams.merge_up_gate_exps = params.merge_up_gate_exps;
|
||||
mparams.mtp = params.has_mtp;
|
||||
if (params.kv_overrides.empty()) {
|
||||
mparams.kv_overrides = NULL;
|
||||
} else {
|
||||
@@ -3329,6 +3340,8 @@ struct llama_context_params common_context_params_to_llama(const gpt_params & pa
|
||||
cparams.thresh_experts = params.thresh_experts;
|
||||
cparams.only_active_experts = params.only_active_exps;
|
||||
cparams.max_extra_alloc = params.max_extra_alloc_MiB;
|
||||
cparams.mtp = params.has_mtp;
|
||||
cparams.mtp_op_type = MTP_OP_NONE;
|
||||
|
||||
cparams.type_k = kv_cache_type_from_str(params.cache_type_k);
|
||||
cparams.type_v = kv_cache_type_from_str(params.cache_type_v);
|
||||
|
||||
@@ -139,6 +139,7 @@ thinking_tokens thinking_tokens_from_string(const std::string& format);
|
||||
enum common_speculative_type {
|
||||
COMMON_SPECULATIVE_TYPE_NONE, // no speculative decoding
|
||||
COMMON_SPECULATIVE_TYPE_DRAFT, // draft model
|
||||
COMMON_SPECULATIVE_TYPE_MTP, // MTP model
|
||||
COMMON_SPECULATIVE_TYPE_EAGLE3, // eagle draft model
|
||||
COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE, // simple self-speculative decoding
|
||||
COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K, // self-speculative decoding with n-gram keys only
|
||||
@@ -356,6 +357,7 @@ struct gpt_params {
|
||||
bool split_mode_graph_scheduling = false; // if true, force split mode graph scheduling
|
||||
//bool split_mode_f16 = true; // if true, intermediate results will be cast to f16 before copying to other GPUs to perform reduce ops
|
||||
bool scheduler_async = false; // if true, in split mode graph the scheduler will use multiple threads to evaluate the graph
|
||||
bool has_mtp = false; // enable MTP if supported by the model
|
||||
|
||||
std::string cache_type_k = "f16"; // KV cache data type for the K
|
||||
std::string cache_type_v = "f16"; // KV cache data type for the V
|
||||
|
||||
@@ -20,6 +20,7 @@
|
||||
const std::vector<enum common_speculative_type> common_speculative_types = {
|
||||
COMMON_SPECULATIVE_TYPE_NONE,
|
||||
COMMON_SPECULATIVE_TYPE_DRAFT,
|
||||
COMMON_SPECULATIVE_TYPE_MTP,
|
||||
COMMON_SPECULATIVE_TYPE_EAGLE3,
|
||||
COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE,
|
||||
COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K,
|
||||
@@ -31,6 +32,7 @@ const std::vector<enum common_speculative_type> common_speculative_types = {
|
||||
const std::map<std::string, enum common_speculative_type> common_speculative_type_from_name_map = {
|
||||
{"none", COMMON_SPECULATIVE_TYPE_NONE},
|
||||
{"draft", COMMON_SPECULATIVE_TYPE_DRAFT},
|
||||
{"mtp", COMMON_SPECULATIVE_TYPE_MTP},
|
||||
{"eagle3", COMMON_SPECULATIVE_TYPE_EAGLE3},
|
||||
{"ngram_simple", COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE},
|
||||
{"ngram_map_k", COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K},
|
||||
@@ -144,6 +146,58 @@ struct common_speculative_state {
|
||||
virtual void accept(uint16_t n_accepted) = 0;
|
||||
};
|
||||
|
||||
struct common_speculative_state_mtp : public common_speculative_state {
|
||||
llama_context * ctx_tgt;
|
||||
common_sampler * smpl;
|
||||
|
||||
common_speculative_state_mtp(
|
||||
enum common_speculative_type type,
|
||||
llama_context * ctx_tgt)
|
||||
: common_speculative_state(type)
|
||||
, ctx_tgt(ctx_tgt)
|
||||
{
|
||||
struct common_params_sampling params;
|
||||
params.samplers_sequence = {
|
||||
llama_sampler_type::DIST,
|
||||
};
|
||||
smpl = common_sampler_init(llama_get_model(ctx_tgt), params);
|
||||
}
|
||||
|
||||
~common_speculative_state_mtp() override {
|
||||
common_sampler_free(smpl);
|
||||
}
|
||||
|
||||
void begin(const llama_tokens & prompt) override {
|
||||
GGML_UNUSED(prompt);
|
||||
}
|
||||
|
||||
void draft(
|
||||
const common_params_speculative & params,
|
||||
const llama_tokens & prompt_tgt,
|
||||
llama_token id_last,
|
||||
llama_tokens & result) override {
|
||||
|
||||
int32_t n_past = (int32_t)prompt_tgt.size();
|
||||
|
||||
llama_seq_id seq_id = 0;
|
||||
|
||||
result = mtp_speculative_gen_draft(
|
||||
smpl,
|
||||
ctx_tgt,
|
||||
params.n_max,
|
||||
params.p_min,
|
||||
id_last,
|
||||
n_past,
|
||||
seq_id
|
||||
);
|
||||
}
|
||||
|
||||
void accept(uint16_t n_accepted) override {
|
||||
GGML_UNUSED(n_accepted);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
struct common_speculative_state_draft : public common_speculative_state {
|
||||
llama_context * ctx_tgt; // only used for retokenizing from ctx_dft
|
||||
llama_context * ctx_dft;
|
||||
@@ -760,6 +814,7 @@ std::string common_speculative_type_to_str(enum common_speculative_type type) {
|
||||
switch (type) {
|
||||
case COMMON_SPECULATIVE_TYPE_NONE: return "none";
|
||||
case COMMON_SPECULATIVE_TYPE_DRAFT: return "draft";
|
||||
case COMMON_SPECULATIVE_TYPE_MTP: return "mtp";
|
||||
case COMMON_SPECULATIVE_TYPE_EAGLE3: return "eagle3";
|
||||
case COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE: return "ngram_simple";
|
||||
case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K: return "ngram_map_k";
|
||||
@@ -828,6 +883,7 @@ common_speculative * common_speculative_init(
|
||||
{
|
||||
bool has_draft = !params.mparams_dft.path.empty();
|
||||
bool has_draft_eagle3 = false; // TODO PR-18039: if params.speculative.eagle3
|
||||
bool has_mtp = (params.type == COMMON_SPECULATIVE_TYPE_MTP);
|
||||
|
||||
bool has_ngram_cache = (params.type == COMMON_SPECULATIVE_TYPE_NGRAM_CACHE);
|
||||
bool has_ngram_simple = (params.type == COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE);
|
||||
@@ -867,6 +923,9 @@ common_speculative * common_speculative_init(
|
||||
if (has_ngram_cache) {
|
||||
configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_NGRAM_CACHE, params));
|
||||
}
|
||||
if (has_mtp) {
|
||||
configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_MTP, params));
|
||||
}
|
||||
if (has_draft) {
|
||||
configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_DRAFT, params));
|
||||
}
|
||||
@@ -890,6 +949,12 @@ common_speculative * common_speculative_init(
|
||||
));
|
||||
break;
|
||||
}
|
||||
case COMMON_SPECULATIVE_TYPE_MTP: {
|
||||
impls.push_back(std::make_unique<common_speculative_state_mtp>(config.type,
|
||||
/* .ctx_tgt = */ ctx_tgt
|
||||
));
|
||||
break;
|
||||
}
|
||||
case COMMON_SPECULATIVE_TYPE_EAGLE3: {
|
||||
impls.push_back(std::make_unique<common_speculative_state_eagle3>(config.type));
|
||||
break;
|
||||
@@ -1047,3 +1112,112 @@ void common_speculative_print_stats(const common_speculative * spec) {
|
||||
str_perf.c_str());
|
||||
}
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// MTP
|
||||
// ----------------------------------------------------------------------------
|
||||
std::vector<llama_token> mtp_speculative_gen_draft(
|
||||
struct common_sampler * smpl,
|
||||
struct llama_context * ctx,
|
||||
int n_draft,
|
||||
float p_min,
|
||||
llama_token id_last,
|
||||
int32_t n_past,
|
||||
llama_seq_id seq_id) {
|
||||
|
||||
llama_tokens drafts;
|
||||
drafts.reserve(n_draft);
|
||||
|
||||
if (!smpl) return drafts;
|
||||
|
||||
common_sampler_reset(smpl);
|
||||
|
||||
llama_batch mtp_batch = llama_batch_init(1, 0, 1);
|
||||
llama_set_mtp_op_type(ctx, MTP_OP_DRAFT_GEN);
|
||||
|
||||
llama_token current_input_id = id_last;
|
||||
int32_t current_n_past = n_past;
|
||||
|
||||
for (int i = 0; i < n_draft; ++i) {
|
||||
mtp_batch.n_tokens = 0;
|
||||
common_batch_add(mtp_batch, current_input_id, current_n_past, {seq_id}, true);
|
||||
|
||||
if (llama_decode(ctx, mtp_batch) != 0) {
|
||||
break;
|
||||
}
|
||||
|
||||
common_sampler_sample(smpl, ctx, 0, true);
|
||||
|
||||
const auto * cur_p = common_sampler_get_candidates(smpl, true);
|
||||
|
||||
if (!cur_p || cur_p->size == 0) {
|
||||
break;
|
||||
}
|
||||
|
||||
const llama_token id_next = cur_p->data[0].id;
|
||||
const float prob = cur_p->data[0].p;
|
||||
|
||||
common_sampler_accept(smpl, nullptr, id_next, true);
|
||||
|
||||
if (prob < p_min) {
|
||||
break;
|
||||
}
|
||||
|
||||
drafts.push_back(id_next);
|
||||
|
||||
current_input_id = id_next;
|
||||
current_n_past++;
|
||||
}
|
||||
llama_batch_free(mtp_batch);
|
||||
llama_set_mtp_op_type(ctx, MTP_OP_NONE);
|
||||
|
||||
// Purge the metadata for the draft tokens.
|
||||
// This prevents cache state corruption where two cells map to the same logical position.
|
||||
if (!drafts.empty()) {
|
||||
llama_kv_cache_seq_rm(ctx, seq_id, n_past, current_n_past);
|
||||
}
|
||||
|
||||
return drafts;
|
||||
}
|
||||
|
||||
|
||||
void mtp_update_kv_cache(struct llama_context * ctx, const llama_batch& batch, bool is_prompt_warmup) {
|
||||
if (batch.n_tokens == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
LOG_DBG("[MTP-UPDATE|%s] Updating %d tokens...\n", is_prompt_warmup ? "PROMPT_WARMUP" : "GEN_ACCEPTED", batch.n_tokens);
|
||||
|
||||
llama_batch mtp_batch = batch;
|
||||
if (is_prompt_warmup) {
|
||||
llama_set_mtp_op_type(ctx, MTP_OP_WARMUP);
|
||||
} else {
|
||||
llama_set_mtp_op_type(ctx, MTP_OP_UPDATE_ACCEPTED);
|
||||
}
|
||||
|
||||
for (int i = 0; i < mtp_batch.n_tokens; ++i) {
|
||||
mtp_batch.logits[i] = true;
|
||||
}
|
||||
llama_decode(ctx, mtp_batch);
|
||||
llama_set_mtp_op_type(ctx, MTP_OP_NONE);
|
||||
}
|
||||
|
||||
void mtp_accept_tokens(
|
||||
struct llama_context * ctx,
|
||||
const std::vector<llama_token> & ids,
|
||||
int32_t n_past_base,
|
||||
llama_seq_id seq_id
|
||||
) {
|
||||
if (ids.empty()) {
|
||||
return;
|
||||
}
|
||||
|
||||
llama_batch accepted_batch = llama_batch_init(ids.size(), 0, 1);
|
||||
for (size_t i = 0; i < ids.size(); ++i) {
|
||||
common_batch_add(accepted_batch, ids[i], n_past_base + i, { seq_id }, true);
|
||||
}
|
||||
|
||||
mtp_update_kv_cache(ctx, accepted_batch, false);
|
||||
|
||||
llama_batch_free(accepted_batch);
|
||||
}
|
||||
@@ -39,3 +39,22 @@ void common_speculative_accept(common_speculative * spec, uint16_t n_accepted);
|
||||
|
||||
// print statistics about the speculative decoding
|
||||
void common_speculative_print_stats(const common_speculative * spec);
|
||||
|
||||
// Generates speculative draft tokens using the Multi-Token Prediction (MTP) architecture.
|
||||
std::vector<llama_token> mtp_speculative_gen_draft(
|
||||
struct common_sampler * smpl,
|
||||
struct llama_context * ctx,
|
||||
int n_draft,
|
||||
float p_min,
|
||||
llama_token id_last,
|
||||
int32_t n_past,
|
||||
llama_seq_id seq_id);
|
||||
|
||||
void mtp_update_kv_cache(struct llama_context * ctx, const llama_batch& batch, bool is_prompt_warmup);
|
||||
|
||||
void mtp_accept_tokens(
|
||||
struct llama_context * ctx,
|
||||
const std::vector<llama_token> & ids,
|
||||
int32_t n_past_base,
|
||||
llama_seq_id seq_id
|
||||
);
|
||||
|
||||
Reference in New Issue
Block a user