From 09a88c9ae5167d708aa4151b439dc223b4e322f9 Mon Sep 17 00:00:00 2001 From: Samuel Oliveira Alves <107287165+SamuelOliveirads@users.noreply.github.com> Date: Sun, 22 Feb 2026 14:14:39 -0300 Subject: [PATCH] Add MTP decoding support for GLM-4.x MoE (#1270) * wip: port MTP architecture Ports the Multi-Token Prediction (MTP) architecture to the older `llama.cpp` codebase used by `ikllama`. Changes include: - Updating `llama_batch` to support `mtp_params`. - Modifying `llama_decode_internal` (and `encode`) to handle MTP operations (Warmup, Update, Draft). - Adding public APIs for MTP state management (`llama_set_draft_input_hidden_state`). - Adapting the embedding extraction logic to skip MTP update passes. * Refactors `server_slot` to support generic speculative decoding (MTP or Draft Model). * core: enable hybrid outputs (logits + embeddings) for MTP support * fix(mtp): correct KV-cache slot finding for updates * fix(mtp): persist hidden states to prevent context corruption during drafting * refactor(mtp): clean unused code * fix(mtp): update server to new functions name * fix(mtp): fix graph and save hidden state * mtp: refactor integration, context params and kv cache search * mtp: fix hidden state extraction and speculative acceptance flow * server: fix MTP warmup for long prompts and reset token buffer * llama: refactor MTP operation state to context parameters * server: fix n_past calculation in MTP acceptance * llama: fix mtp enable flags * speculative: refactor MTP to use common_speculative interface * context: remove unused signatures * clip: fix deprecated enum-enum conversion warning * common: fix format string crash in help message * context: fix mtp activation logic --- common/common.cpp | 13 + common/common.h | 2 + common/speculative.cpp | 174 +++++++++++++ common/speculative.h | 19 ++ examples/mtmd/clip.cpp | 2 +- examples/server/server-context.cpp | 97 +++++++- examples/server/server-context.h | 3 + include/llama.h | 20 ++ src/llama-build-context.cpp | 386 +++++++++++++++++++++-------- src/llama-build-context.h | 10 + src/llama-context.h | 7 +- src/llama-cparams.h | 2 + src/llama-hparams.cpp | 6 + src/llama-load-tensors.cpp | 12 +- src/llama-model.h | 2 + src/llama.cpp | 271 ++++++++++++++------ 16 files changed, 820 insertions(+), 206 deletions(-) diff --git a/common/common.cpp b/common/common.cpp index 4fc0b9e5..6e38a69b 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -1463,6 +1463,14 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa params.cuda_params = argv[i]; return true; } + if (arg == "-mtp" || arg == "--multi-token-prediction") { + params.has_mtp = true; + return true; + } + if (arg == "-no-mtp" || arg == "--no-multi-token-prediction") { + params.has_mtp = false; + return true; + } if (arg == "-draft" || arg == "--draft-params") { CHECK_ARG params.speculative.params = argv[i]; @@ -2475,6 +2483,8 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param options.push_back({ "*", "-hfr, --hf-repo REPO", "Hugging Face model repository (default: unused)" }); options.push_back({ "*", "-hff, --hf-file FILE", "Hugging Face model file (default: unused)" }); options.push_back({ "*", "-hft, --hf-token TOKEN", "Hugging Face access token (default: value from HF_TOKEN environment variable)" }); + options.push_back({ "*", "-mtp, --multi-token-prediction", "whether to use multi-token-prediction (if supported) (default: %s)", params.has_mtp ? "true" : "false" }); + options.push_back({ "*", "-no-mtp, --no-multi-token-prediction", "whether to use multi-token-prediction (if supported) (default: %s)", !params.has_mtp ? "true" : "false" }); options.push_back({ "*", "--draft-max, --draft, --draft-n N", "number of tokens to draft for speculative decoding (default: %d)", params.speculative.n_max }); options.push_back({ "*", "--draft-min, --draft-n-min N", "minimum number of draft tokens to use for speculative decoding" }); @@ -3207,6 +3217,7 @@ struct llama_model_params common_model_params_to_llama(const gpt_params & params mparams.validate_quants = params.validate_quants; mparams.merge_qkv = params.merge_qkv; mparams.merge_up_gate_exps = params.merge_up_gate_exps; + mparams.mtp = params.has_mtp; if (params.kv_overrides.empty()) { mparams.kv_overrides = NULL; } else { @@ -3329,6 +3340,8 @@ struct llama_context_params common_context_params_to_llama(const gpt_params & pa cparams.thresh_experts = params.thresh_experts; cparams.only_active_experts = params.only_active_exps; cparams.max_extra_alloc = params.max_extra_alloc_MiB; + cparams.mtp = params.has_mtp; + cparams.mtp_op_type = MTP_OP_NONE; cparams.type_k = kv_cache_type_from_str(params.cache_type_k); cparams.type_v = kv_cache_type_from_str(params.cache_type_v); diff --git a/common/common.h b/common/common.h index 8a958cdd..26cd520e 100644 --- a/common/common.h +++ b/common/common.h @@ -139,6 +139,7 @@ thinking_tokens thinking_tokens_from_string(const std::string& format); enum common_speculative_type { COMMON_SPECULATIVE_TYPE_NONE, // no speculative decoding COMMON_SPECULATIVE_TYPE_DRAFT, // draft model + COMMON_SPECULATIVE_TYPE_MTP, // MTP model COMMON_SPECULATIVE_TYPE_EAGLE3, // eagle draft model COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE, // simple self-speculative decoding COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K, // self-speculative decoding with n-gram keys only @@ -356,6 +357,7 @@ struct gpt_params { bool split_mode_graph_scheduling = false; // if true, force split mode graph scheduling //bool split_mode_f16 = true; // if true, intermediate results will be cast to f16 before copying to other GPUs to perform reduce ops bool scheduler_async = false; // if true, in split mode graph the scheduler will use multiple threads to evaluate the graph + bool has_mtp = false; // enable MTP if supported by the model std::string cache_type_k = "f16"; // KV cache data type for the K std::string cache_type_v = "f16"; // KV cache data type for the V diff --git a/common/speculative.cpp b/common/speculative.cpp index 03d2208b..c130be24 100644 --- a/common/speculative.cpp +++ b/common/speculative.cpp @@ -20,6 +20,7 @@ const std::vector common_speculative_types = { COMMON_SPECULATIVE_TYPE_NONE, COMMON_SPECULATIVE_TYPE_DRAFT, + COMMON_SPECULATIVE_TYPE_MTP, COMMON_SPECULATIVE_TYPE_EAGLE3, COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE, COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K, @@ -31,6 +32,7 @@ const std::vector common_speculative_types = { const std::map common_speculative_type_from_name_map = { {"none", COMMON_SPECULATIVE_TYPE_NONE}, {"draft", COMMON_SPECULATIVE_TYPE_DRAFT}, + {"mtp", COMMON_SPECULATIVE_TYPE_MTP}, {"eagle3", COMMON_SPECULATIVE_TYPE_EAGLE3}, {"ngram_simple", COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE}, {"ngram_map_k", COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K}, @@ -144,6 +146,58 @@ struct common_speculative_state { virtual void accept(uint16_t n_accepted) = 0; }; +struct common_speculative_state_mtp : public common_speculative_state { + llama_context * ctx_tgt; + common_sampler * smpl; + + common_speculative_state_mtp( + enum common_speculative_type type, + llama_context * ctx_tgt) + : common_speculative_state(type) + , ctx_tgt(ctx_tgt) + { + struct common_params_sampling params; + params.samplers_sequence = { + llama_sampler_type::DIST, + }; + smpl = common_sampler_init(llama_get_model(ctx_tgt), params); + } + + ~common_speculative_state_mtp() override { + common_sampler_free(smpl); + } + + void begin(const llama_tokens & prompt) override { + GGML_UNUSED(prompt); + } + + void draft( + const common_params_speculative & params, + const llama_tokens & prompt_tgt, + llama_token id_last, + llama_tokens & result) override { + + int32_t n_past = (int32_t)prompt_tgt.size(); + + llama_seq_id seq_id = 0; + + result = mtp_speculative_gen_draft( + smpl, + ctx_tgt, + params.n_max, + params.p_min, + id_last, + n_past, + seq_id + ); + } + + void accept(uint16_t n_accepted) override { + GGML_UNUSED(n_accepted); + } +}; + + struct common_speculative_state_draft : public common_speculative_state { llama_context * ctx_tgt; // only used for retokenizing from ctx_dft llama_context * ctx_dft; @@ -760,6 +814,7 @@ std::string common_speculative_type_to_str(enum common_speculative_type type) { switch (type) { case COMMON_SPECULATIVE_TYPE_NONE: return "none"; case COMMON_SPECULATIVE_TYPE_DRAFT: return "draft"; + case COMMON_SPECULATIVE_TYPE_MTP: return "mtp"; case COMMON_SPECULATIVE_TYPE_EAGLE3: return "eagle3"; case COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE: return "ngram_simple"; case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K: return "ngram_map_k"; @@ -828,6 +883,7 @@ common_speculative * common_speculative_init( { bool has_draft = !params.mparams_dft.path.empty(); bool has_draft_eagle3 = false; // TODO PR-18039: if params.speculative.eagle3 + bool has_mtp = (params.type == COMMON_SPECULATIVE_TYPE_MTP); bool has_ngram_cache = (params.type == COMMON_SPECULATIVE_TYPE_NGRAM_CACHE); bool has_ngram_simple = (params.type == COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE); @@ -867,6 +923,9 @@ common_speculative * common_speculative_init( if (has_ngram_cache) { configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_NGRAM_CACHE, params)); } + if (has_mtp) { + configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_MTP, params)); + } if (has_draft) { configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_DRAFT, params)); } @@ -890,6 +949,12 @@ common_speculative * common_speculative_init( )); break; } + case COMMON_SPECULATIVE_TYPE_MTP: { + impls.push_back(std::make_unique(config.type, + /* .ctx_tgt = */ ctx_tgt + )); + break; + } case COMMON_SPECULATIVE_TYPE_EAGLE3: { impls.push_back(std::make_unique(config.type)); break; @@ -1047,3 +1112,112 @@ void common_speculative_print_stats(const common_speculative * spec) { str_perf.c_str()); } } + +// ---------------------------------------------------------------------------- +// MTP +// ---------------------------------------------------------------------------- +std::vector mtp_speculative_gen_draft( + struct common_sampler * smpl, + struct llama_context * ctx, + int n_draft, + float p_min, + llama_token id_last, + int32_t n_past, + llama_seq_id seq_id) { + + llama_tokens drafts; + drafts.reserve(n_draft); + + if (!smpl) return drafts; + + common_sampler_reset(smpl); + + llama_batch mtp_batch = llama_batch_init(1, 0, 1); + llama_set_mtp_op_type(ctx, MTP_OP_DRAFT_GEN); + + llama_token current_input_id = id_last; + int32_t current_n_past = n_past; + + for (int i = 0; i < n_draft; ++i) { + mtp_batch.n_tokens = 0; + common_batch_add(mtp_batch, current_input_id, current_n_past, {seq_id}, true); + + if (llama_decode(ctx, mtp_batch) != 0) { + break; + } + + common_sampler_sample(smpl, ctx, 0, true); + + const auto * cur_p = common_sampler_get_candidates(smpl, true); + + if (!cur_p || cur_p->size == 0) { + break; + } + + const llama_token id_next = cur_p->data[0].id; + const float prob = cur_p->data[0].p; + + common_sampler_accept(smpl, nullptr, id_next, true); + + if (prob < p_min) { + break; + } + + drafts.push_back(id_next); + + current_input_id = id_next; + current_n_past++; + } + llama_batch_free(mtp_batch); + llama_set_mtp_op_type(ctx, MTP_OP_NONE); + + // Purge the metadata for the draft tokens. + // This prevents cache state corruption where two cells map to the same logical position. + if (!drafts.empty()) { + llama_kv_cache_seq_rm(ctx, seq_id, n_past, current_n_past); + } + + return drafts; +} + + +void mtp_update_kv_cache(struct llama_context * ctx, const llama_batch& batch, bool is_prompt_warmup) { + if (batch.n_tokens == 0) { + return; + } + + LOG_DBG("[MTP-UPDATE|%s] Updating %d tokens...\n", is_prompt_warmup ? "PROMPT_WARMUP" : "GEN_ACCEPTED", batch.n_tokens); + + llama_batch mtp_batch = batch; + if (is_prompt_warmup) { + llama_set_mtp_op_type(ctx, MTP_OP_WARMUP); + } else { + llama_set_mtp_op_type(ctx, MTP_OP_UPDATE_ACCEPTED); + } + + for (int i = 0; i < mtp_batch.n_tokens; ++i) { + mtp_batch.logits[i] = true; + } + llama_decode(ctx, mtp_batch); + llama_set_mtp_op_type(ctx, MTP_OP_NONE); +} + +void mtp_accept_tokens( + struct llama_context * ctx, + const std::vector & ids, + int32_t n_past_base, + llama_seq_id seq_id +) { + if (ids.empty()) { + return; + } + + llama_batch accepted_batch = llama_batch_init(ids.size(), 0, 1); + for (size_t i = 0; i < ids.size(); ++i) { + common_batch_add(accepted_batch, ids[i], n_past_base + i, { seq_id }, true); + } + + mtp_update_kv_cache(ctx, accepted_batch, false); + + llama_batch_free(accepted_batch); +} \ No newline at end of file diff --git a/common/speculative.h b/common/speculative.h index 876cde3d..fdaee241 100644 --- a/common/speculative.h +++ b/common/speculative.h @@ -39,3 +39,22 @@ void common_speculative_accept(common_speculative * spec, uint16_t n_accepted); // print statistics about the speculative decoding void common_speculative_print_stats(const common_speculative * spec); + +// Generates speculative draft tokens using the Multi-Token Prediction (MTP) architecture. +std::vector mtp_speculative_gen_draft( + struct common_sampler * smpl, + struct llama_context * ctx, + int n_draft, + float p_min, + llama_token id_last, + int32_t n_past, + llama_seq_id seq_id); + +void mtp_update_kv_cache(struct llama_context * ctx, const llama_batch& batch, bool is_prompt_warmup); + +void mtp_accept_tokens( + struct llama_context * ctx, + const std::vector & ids, + int32_t n_past_base, + llama_seq_id seq_id +); diff --git a/examples/mtmd/clip.cpp b/examples/mtmd/clip.cpp index ce627687..2eb3fdf6 100644 --- a/examples/mtmd/clip.cpp +++ b/examples/mtmd/clip.cpp @@ -35,7 +35,7 @@ #include #include -#define DEFAULT_INTERPOLATION_MODE (GGML_SCALE_MODE_BILINEAR | GGML_SCALE_FLAG_ALIGN_CORNERS) +#define DEFAULT_INTERPOLATION_MODE ((int)GGML_SCALE_MODE_BILINEAR | (int)GGML_SCALE_FLAG_ALIGN_CORNERS) // TODO: allow to pass callback from user code struct clip_logger_state g_logger_state = {GGML_LOG_LEVEL_CONT, clip_log_callback_default, NULL}; diff --git a/examples/server/server-context.cpp b/examples/server/server-context.cpp index 104cb793..6d79cce4 100644 --- a/examples/server/server-context.cpp +++ b/examples/server/server-context.cpp @@ -152,12 +152,17 @@ bool server_context::load_model(const gpt_params& params_) { LOG_ERROR("failed to load draft model", { {"model", params_base.speculative.model} }); return false; } + cparams_dft = common_context_params_to_llama(params_dft); params_base.speculative.model_dft = model_dft; params_base.speculative.cparams_dft = cparams_dft; } + else if (params_base.has_mtp && llama_model_n_nextn_layer(model) == 0) { + LOG_WARNING("WARNING: -mtp flag provided, but model has 0 NextN layers. MTP will be disabled.\n", {}); + params_base.has_mtp = false; + } return true; } @@ -209,12 +214,35 @@ void server_context::init() { slot.sparams = params_base.sparams; + if (params_base.has_mtp) { + if (llama_model_n_nextn_layer(model) > 0) { + SRV_INF("%s\n", "MTP detected, configuring for speculative decoding..."); + + params_base.speculative.type = COMMON_SPECULATIVE_TYPE_MTP; + + slot.has_mtp = true; + slot.params.speculative.type = COMMON_SPECULATIVE_TYPE_MTP; + slot.params.speculative.n_min = 0; + + slot.batch_spec = llama_batch_init(slot.params.speculative.n_max + 1, 0, 1); + SLT_DBG(slot, "batch_spec contains %d tokens\n", slot.batch_spec.n_tokens); + + SRV_INF("%s\n", "MTP needs embeddings on decode, enabling"); + llama_set_embeddings(ctx, true); + } + else { + SRV_WRN("%s\n", "MTP enabled via flag, but model has 0 NextN layers. Disabling speculative."); + params_base.speculative.type = COMMON_SPECULATIVE_TYPE_NONE; + slot.has_mtp = false; + } + } + const bool can_spec = common_speculative_is_compat(ctx); if (!can_spec) { SRV_WRN("%s", "speculative decoding not supported by this context\n"); } // try speculative decoding - if (can_spec){ + if (can_spec) { slot.spec = common_speculative_init(params_base.speculative, slot.ctx); if (slot.spec) { if (mctx) { @@ -223,9 +251,14 @@ void server_context::init() { } SLT_INF(slot, "%s", "speculative decoding context initialized\n"); } else { - SLT_INF(slot, "%s", "speculative decoding context not initialized\n"); + if (slot.has_mtp) { + SRV_ERR("%s", "failed to initialize MTP speculative context\n"); + } else { + SLT_INF(slot, "%s", "speculative decoding context not initialized\n"); + } } } + slot.reset(); slots.push_back(std::move(slot)); @@ -380,7 +413,7 @@ void server_slot::add_token_string(const completion_token_output& token) { } bool server_slot::can_speculate() const { - return !!spec; + return (!!spec || has_mtp); } int server_slot::get_n_draft_max() const { @@ -2533,6 +2566,15 @@ void server_context::add_sampled_tokens() { const auto & params_spec = slot.params.speculative; + if (slot.has_mtp) { + if (!slot.mtp_hidden_state.empty()) { + llama_set_draft_input_hidden_state(ctx, slot.mtp_hidden_state.data()); + } else { + LOG_ERROR("MTP hidden state is empty during speculation", {}); + llama_set_draft_input_hidden_state(ctx, llama_get_embeddings_ith(ctx, -1)); + } + } + llama_tokens draft = common_speculative_draft(slot.spec, params_spec, cached_text_tokens, slot.sampled); if (draft.size() > (size_t)n_draft_max) { @@ -2540,13 +2582,6 @@ void server_context::add_sampled_tokens() { draft.resize(n_draft_max); } - /*struct llama_speculative_params params_spec; - params_spec.n_draft = n_draft_max; - params_spec.n_reuse = llama_n_ctx(slot.ctx_dft) - slot.params.speculative.n_max; - params_spec.p_min = slot.params.speculative.p_min; - const llama_tokens& cached_text_tokens = slot.cache_tokens.get_text_tokens(); - llama_tokens draft = llama_speculative_gen_draft(slot.spec, params_spec, cached_text_tokens, slot.sampled);*/ - // add the sampled token to the batch slot.i_batch_dft.push_back(batch.n_tokens); common_batch_add(batch, slot.sampled, slot.cache_tokens.pos_next(), { slot.id }, true); @@ -2759,6 +2794,8 @@ void server_context::batch_pending_prompt(const int32_t n_ubatch, const int32_t } slot.n_past = 0; + slot.n_buffer = 0; + slot.token_buffer.clear(); slot.n_prompt_tokens = prompt_tokens.size(); LOG_VERBOSE("prompt tokenized", { @@ -3092,6 +3129,25 @@ void server_context::speculative_decoding_accept() { // the accepted tokens from the speculation const auto ids = common_sampler_sample_and_accept_n(slot.ctx_sampling, ctx, slot.i_batch_dft, slot.drafted); + + if (slot.has_mtp) { + const int n_embd = llama_model_n_embd(llama_get_model(ctx)); + if (!ids.empty()) { + const float* emb = llama_get_embeddings_ith(ctx, ids.size() - 1); + if (emb) { + slot.mtp_hidden_state.resize(n_embd); + memcpy(slot.mtp_hidden_state.data(), emb, n_embd * sizeof(float)); + } + } + else { + llama_set_draft_input_hidden_state(ctx, llama_get_embeddings_ith(ctx, 0)); + } + llama_set_draft_input_hidden_state(ctx, slot.mtp_hidden_state.data()); + + int32_t n_past_base = slot.n_past - (slot.drafted.size() + 1); + mtp_accept_tokens(ctx, ids, n_past_base, slot.id); + } + slot.i_batch_dft.clear(); slot.drafted.clear(); @@ -3362,6 +3418,17 @@ void server_context::process_batch_tokens(int32_t & n_batch) { common_sampler_accept(slot.ctx_sampling, ctx, id, true); + if (params_base.has_mtp && slot.n_decoded == 0) { + if (batch_view.n_seq_id[0] > 0 && batch_view.seq_id[0][0] == slot.id) { + mtp_update_kv_cache(ctx, batch_view, true); + const float* emb = llama_get_embeddings_ith(ctx, -1); + if (emb) { + const int n_embd = llama_model_n_embd(llama_get_model(ctx)); + slot.mtp_hidden_state.resize(n_embd); + memcpy(slot.mtp_hidden_state.data(), emb, n_embd * sizeof(float)); + } + } + } slot.n_decoded += 1; const int64_t t_current = ggml_time_us(); @@ -3394,7 +3461,15 @@ void server_context::process_batch_tokens(int32_t & n_batch) { slot.i_batch = -1; } - + if (params_base.has_mtp) { + for (auto& slot : slots) { + if (slot.n_past < slot.n_prompt_tokens) { + if (batch_view.n_seq_id[0] > 0 && batch_view.seq_id[0][0] == slot.id) { + mtp_update_kv_cache(ctx, batch_view, true); + } + } + } + } // speculative decoding - main model sample and accept speculative_decoding_accept(); } diff --git a/examples/server/server-context.h b/examples/server/server-context.h index 1c16cc35..8ff64d1a 100644 --- a/examples/server/server-context.h +++ b/examples/server/server-context.h @@ -134,6 +134,9 @@ struct server_slot { struct common_params_sampling sparams; common_sampler * ctx_sampling = nullptr; + bool has_mtp = false; + std::vector mtp_hidden_state; + // speculative decoding stats int32_t n_draft_total = 0; // Total draft tokens generated int32_t n_draft_accepted = 0; // Draft tokens actually accepted diff --git a/include/llama.h b/include/llama.h index 5332d810..6ab628e3 100644 --- a/include/llama.h +++ b/include/llama.h @@ -279,6 +279,12 @@ extern "C" { LLAMA_SPLIT_MODE_GRAPH = 3, // splits computations across GPUs }; + enum llama_mtp_op_type { + MTP_OP_NONE = 0, + MTP_OP_WARMUP = 1, + MTP_OP_UPDATE_ACCEPTED = 2, + MTP_OP_DRAFT_GEN = 3, + }; typedef struct llama_token_data { llama_token id; // token id @@ -394,6 +400,7 @@ extern "C" { bool validate_quants; // if true, check for NaNs while loading the model bool merge_qkv; // if true, merge separate Q, K, V tensors into a single, contiguous tensor bool merge_up_gate_exps; // if true, merge ffn_up_exps and ffn_gate_exps tensors into a single, contiguous tensor + bool mtp; // if true, load MTP layers if present }; // NOTE: changing the default values of parameters marked as [EXPERIMENTAL] may cause crashes or incorrect results in certain configurations @@ -449,6 +456,8 @@ extern "C" { bool split_mode_graph_scheduling; // if true, force split mode graph scheduling //bool split_mode_f16; // if true, cast intermediate results to f16 before copying to other GPUs bool scheduler_async; // if true, with split mode "graph" graph evaluation will be done using multiple threads + bool mtp; // Activate MTP if supported + enum llama_mtp_op_type mtp_op_type; // Abort callback // if it returns true, execution of llama_decode() will be aborted @@ -1463,6 +1472,17 @@ LLAMA_API struct llama_grammar* llama_sampler_init_grammar_lazy_patterns( LLAMA_API void llama_dump_timing_info_yaml(FILE * stream, const struct llama_context * ctx); + // + // MTP + // + + LLAMA_API int32_t llama_model_n_nextn_layer(const struct llama_model * model); + + // Set which, if any, MTP operation the context will use + LLAMA_API void llama_set_mtp_op_type(struct llama_context * ctx, enum llama_mtp_op_type mtp_op_type); + + LLAMA_API void llama_set_draft_input_hidden_state(struct llama_context * ctx, const float * hidden_state); + #ifdef __cplusplus } #endif diff --git a/src/llama-build-context.cpp b/src/llama-build-context.cpp index 5031a1be..707ba1a8 100644 --- a/src/llama-build-context.cpp +++ b/src/llama-build-context.cpp @@ -303,6 +303,25 @@ ggml_cgraph * llm_build_context::build_defrag(const std::vector & ids) return gf; } +struct ggml_tensor * llm_build_context::build_inp_embd_mtp(struct ggml_tensor * mtp_tok_embd) { + struct ggml_tensor * cur = nullptr; + + if (batch.token) { + lctx.inp_tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, batch.n_tokens); + + cb(lctx.inp_tokens, "inp_tokens", -1); + ggml_set_input(lctx.inp_tokens); + + cur = ggml_get_rows(ctx0, mtp_tok_embd, lctx.inp_tokens); + } else { + return nullptr; + } + + cb(cur, "inp_embd", -1); + + return cur; +} + ggml_tensor * llm_build_context::build_inp_pos() { int n_pos_per_embd = hparams.rope_type == LLAMA_ROPE_TYPE_MROPE || hparams.rope_type == LLAMA_ROPE_TYPE_IMROPE ? 4 : 1; lctx.inp_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, int64_t(n_tokens)*n_pos_per_embd); @@ -415,7 +434,10 @@ ggml_cgraph * llm_build_context::append_pooling(struct ggml_cgraph * gf) { struct ggml_tensor * inp = nullptr; for (int i = gf->n_nodes - 1; i >= 0; --i) { inp = gf->nodes[i]; - if (strcmp(inp->name, "result_norm") == 0 || strcmp(inp->name, "result_embd") == 0) { + + if (strcmp(inp->name, "result_norm") == 0 || + strcmp(inp->name, "result_embd") == 0 || + strcmp(inp->name, "output_normed") == 0) { break; } inp = nullptr; @@ -7372,138 +7394,281 @@ ggml_cgraph * llm_build_context::build_glm4_moe() { const int64_t n_embd_head = hparams.n_embd_head_v; GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); - struct ggml_tensor * cur; - struct ggml_tensor * inpL; - - // input embeddings - inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb); + ggml_tensor * cur; // position embeddings struct ggml_tensor * inp_pos = build_inp_pos(); - // attention KV cache input - //auto * inp_attn = build_attn_inp_kv_unified(); - - struct ggml_tensor * KQ_mask = build_inp_KQ_mask(); - - // output token IDs (for last layer cropping) - struct ggml_tensor * inp_out_ids = n_tokens > 1 ? build_inp_out_ids() : nullptr; - auto rope_cache = model.split_mode != LLAMA_SPLIT_MODE_GRAPH && cparams.rope_cache && (rope_type == LLAMA_ROPE_TYPE_NEOX || rope_type == LLAMA_ROPE_TYPE_NORM) ? ggml_rope_cache(ctx0, inp_pos, nullptr, n_embd_head, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow) : nullptr; - float kq_scale = 1.0f/sqrtf(float(n_embd_head)); + if (cparams.mtp_op_type != MTP_OP_NONE) { + ggml_tensor* hidden_states_from_main_model; - // Only process up to last layer (skip final NextN layer) - // Final layer tensors are loaded but not processed in forward pass - const int n_transformer_layers = n_layer - hparams.nextn_predict_layers; - for (int il = 0; il < n_transformer_layers; ++il) { - struct ggml_tensor * inpSA = inpL; - - // self-attention - if (rope_cache == nullptr) { - cur = build_std_attention(gf, model.layers[il].attn_norm, inpL, - inp_pos, il == n_transformer_layers - 1 ? inp_out_ids : nullptr, nullptr, - KQ_mask, nullptr, nullptr, kq_scale, 0.0f, 0, il, true, false, true); + if (cparams.mtp_op_type == MTP_OP_WARMUP || cparams.mtp_op_type == MTP_OP_UPDATE_ACCEPTED) { + hidden_states_from_main_model = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, hparams.n_embd, n_tokens); } else { - // Pre-attention norm - cur = llm_build_norm(ctx0, inpL, hparams, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, cb, il); - cb(cur, "attn_norm", il); + hidden_states_from_main_model = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, hparams.n_embd); + } + ggml_set_name(hidden_states_from_main_model, "result_embd_pooled"); + ggml_set_input(hidden_states_from_main_model); - auto [Qcur, Kcur, Vcur] = llm_build_mul_mat_qkv(gf, cur, - model.layers[il].wqkv, model.layers[il].bqkv, - model.layers[il].wqk, model.layers[il].bqk, - model.layers[il].wq, model.layers[il].bq, - model.layers[il].wk, model.layers[il].bk, - model.layers[il].wv, model.layers[il].bv, - model.layers[il].attn_q_norm, model.layers[il].attn_k_norm, 0.f, il); + lctx.inp_mtp_states = hidden_states_from_main_model; - // apply RoPE - if (rope_cache) { - Qcur = ggml_rope_fast(ctx0, Qcur, rope_cache); - Kcur = ggml_rope_fast(ctx0, Kcur, rope_cache); + const int il_mtp = hparams.n_layer - 1; + const auto & mtp_layer = model.layers[il_mtp]; + + cur = build_mtp_tail(mtp_layer, hidden_states_from_main_model, n_embd_head, gf, inp_pos, rope_cache); + + } else { + struct ggml_tensor * inpL; + + // input embeddings + inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb); + + struct ggml_tensor * KQ_mask = build_inp_KQ_mask(); + + // output token IDs (for last layer cropping) + struct ggml_tensor * inp_out_ids = n_tokens > 1 ? build_inp_out_ids() : nullptr; + + float kq_scale = 1.0f/sqrtf(float(n_embd_head)); + + // Only process up to last layer (skip final NextN layer) + // Final layer tensors are loaded but not processed in forward pass + const int n_transformer_layers = n_layer - hparams.nextn_predict_layers; + for (int il = 0; il < n_transformer_layers; ++il) { + struct ggml_tensor * inpSA = inpL; + + // self-attention + if (rope_cache == nullptr) { + cur = build_std_attention(gf, model.layers[il].attn_norm, inpL, + inp_pos, il == n_transformer_layers - 1 ? inp_out_ids : nullptr, nullptr, + KQ_mask, nullptr, nullptr, kq_scale, 0.0f, 0, il, true, false, true); } else { - Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, - ext_factor, attn_factor, beta_fast, beta_slow); - Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, - ext_factor, attn_factor, beta_fast, beta_slow); - } - cb(Qcur, "Qcur", il); - cb(Kcur, "Kcur", il); - cb(Vcur, "Vcur", il); + // Pre-attention norm + cur = llm_build_norm(ctx0, inpL, hparams, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, cb, il); + cb(cur, "attn_norm", il); - // build attention KV (no unified cache) - cur = llm_build_kv(ctx0, lctx, kv_self, gf, - model.layers[il].wo, NULL, - Kcur, Vcur, Qcur, KQ_mask, - n_tokens, kv_head, n_kv, - 1.0f/sqrtf(float(n_embd_head)), cb, il); + auto [Qcur, Kcur, Vcur] = llm_build_mul_mat_qkv(gf, cur, + model.layers[il].wqkv, model.layers[il].bqkv, + model.layers[il].wqk, model.layers[il].bqk, + model.layers[il].wq, model.layers[il].bq, + model.layers[il].wk, model.layers[il].bk, + model.layers[il].wv, model.layers[il].bv, + model.layers[il].attn_q_norm, model.layers[il].attn_k_norm, 0.f, il); - if (il == n_transformer_layers - 1 && inp_out_ids) { - // skip computing output for unused tokens - cur = ggml_get_rows(ctx0, cur, inp_out_ids); + // apply RoPE if (rope_cache) { - inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + Qcur = ggml_rope_fast(ctx0, Qcur, rope_cache); + Kcur = ggml_rope_fast(ctx0, Kcur, rope_cache); + } else { + Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow); + Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow); + } + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + // build attention KV (no unified cache) + cur = llm_build_kv(ctx0, lctx, kv_self, gf, + model.layers[il].wo, NULL, + Kcur, Vcur, Qcur, KQ_mask, + n_tokens, kv_head, n_kv, + 1.0f/sqrtf(float(n_embd_head)), cb, il); + + if (il == n_transformer_layers - 1 && inp_out_ids) { + // skip computing output for unused tokens + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + if (rope_cache) { + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } } } + + // crop output on last layer + + // residual connection for attention output + ggml_tensor * ffn_inp; + if (rope_cache) { + ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + } else { + ffn_inp = cur; + } + + if ((uint32_t) il < hparams.n_layer_dense_lead) { + // dense FFN + cur = llm_build_ffn(ctx0, lctx, model.layers[il].ffn_norm, ffn_inp, + model.layers[il].ffn_up, NULL, NULL, + model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, cb, il, gf, true); + cb(cur, "ffn_out", il); + } else { + cur = llm_build_std_moe_ffn(ctx0, lctx, model.layers[il].ffn_norm, ffn_inp, + model.layers[il].ffn_gate_inp, model.layers[il].ffn_gate_inp_b, + model.layers[il].ffn_up_exps, model.layers[il].ffn_up_exps_b, + model.layers[il].ffn_gate_exps, model.layers[il].ffn_gate_exps_b, + model.layers[il].ffn_down_exps, model.layers[il].ffn_down_exps_b, + model.layers[il].ffn_exp_probs_b, + model.layers[il].ffn_up_shexp, nullptr, // we don't have shared expert biases? + model.layers[il].ffn_gate_shexp, nullptr, + model.layers[il].ffn_down_shexp, nullptr, + n_expert, n_expert_used, + LLM_FFN_SILU, hparams.expert_weights_norm, true, hparams.expert_weights_scale, + (llm_expert_gating_func_type) hparams.expert_gating_func, + LLM_FFN_SILU, cb, il, gf, true, model.layers[il].ffn_up_gate_exps); + } + + // residual and context vector + //cur = ggml_add(ctx0, cur, ffn_inp); + cur = lctx.cvec.apply_to(ctx0, cur, il); + cb(cur, "l_out", il); + + // prepare next layer input + inpL = cur; } + cur = inpL; - // crop output on last layer - - // residual connection for attention output - ggml_tensor * ffn_inp; - if (rope_cache) { - ffn_inp = ggml_add(ctx0, cur, inpSA); - cb(ffn_inp, "ffn_inp", il); - } else { - ffn_inp = cur; - } - - if ((uint32_t) il < hparams.n_layer_dense_lead) { - // dense FFN - cur = llm_build_ffn(ctx0, lctx, model.layers[il].ffn_norm, ffn_inp, - model.layers[il].ffn_up, NULL, NULL, - model.layers[il].ffn_gate, NULL, NULL, - model.layers[il].ffn_down, NULL, NULL, - NULL, - LLM_FFN_SILU, LLM_FFN_PAR, cb, il, gf, true); - cb(cur, "ffn_out", il); - } else { - cur = llm_build_std_moe_ffn(ctx0, lctx, model.layers[il].ffn_norm, ffn_inp, - model.layers[il].ffn_gate_inp, model.layers[il].ffn_gate_inp_b, - model.layers[il].ffn_up_exps, model.layers[il].ffn_up_exps_b, - model.layers[il].ffn_gate_exps, model.layers[il].ffn_gate_exps_b, - model.layers[il].ffn_down_exps, model.layers[il].ffn_down_exps_b, - model.layers[il].ffn_exp_probs_b, - model.layers[il].ffn_up_shexp, nullptr, // we don't have shared expert biases? - model.layers[il].ffn_gate_shexp, nullptr, - model.layers[il].ffn_down_shexp, nullptr, - n_expert, n_expert_used, - LLM_FFN_SILU, hparams.expert_weights_norm, true, hparams.expert_weights_scale, - (llm_expert_gating_func_type) hparams.expert_gating_func, - LLM_FFN_SILU, cb, il, gf, true, model.layers[il].ffn_up_gate_exps); - } - - // residual and context vector - //cur = ggml_add(ctx0, cur, ffn_inp); - cur = lctx.cvec.apply_to(ctx0, cur, il); - cb(cur, "l_out", il); - - // prepare next layer input - inpL = cur; + // lm head + cur = build_output(lctx, ctx0, cur, model.output, model.output_norm, cb); + cb(cur, "result_output", -1); } - cur = inpL; - - // lm head - cur = build_output(lctx, ctx0, cur, model.output, model.output_norm, cb); - cb(cur, "result_output", -1); - ggml_build_forward_expand(gf, cur); return gf; } +struct ggml_tensor * llm_build_context::build_mtp_tail( + const llama_layer & mtp_layer, + struct ggml_tensor * prev_embeddings, + int64_t n_embd_head, + struct ggml_cgraph * gf, + struct ggml_tensor * inp_pos, + struct ggml_tensor * rope_cache +) { + const int il = hparams.n_layer - 1; + + struct ggml_tensor * KQ_mask = build_inp_KQ_mask(); + + struct ggml_tensor * inp_out_ids = build_inp_out_ids(); + + // If nextn.embed_tokens is missing (GLM-4.6), use model.tok_embd + ggml_tensor * mtp_embd_weights = mtp_layer.nextn.embed_tokens; + if (mtp_embd_weights == nullptr) { + mtp_embd_weights = model.tok_embd; + } + ggml_tensor * token_emb = build_inp_embd_mtp(mtp_embd_weights); + + ggml_tensor * token_emb_norm = llm_build_norm(ctx0, token_emb, hparams, mtp_layer.nextn.enorm, NULL, LLM_NORM_RMS, cb, il); + ggml_tensor * hidden_state_norm = llm_build_norm(ctx0, prev_embeddings, hparams, mtp_layer.nextn.hnorm, NULL, LLM_NORM_RMS, cb, il); + + ggml_tensor * combined = ggml_concat(ctx0, token_emb_norm, hidden_state_norm, 0); + cb(combined, "mtp_concat", il); + ggml_tensor* cur = llm_build_lora_mm(lctx, ctx0, mtp_layer.nextn.eh_proj, combined); + + struct ggml_tensor * inpSA = cur; + + cur = llm_build_norm(ctx0, cur, hparams, mtp_layer.attn_norm, NULL, LLM_NORM_RMS, cb, il); + cb(cur, "attn_norm", il); + + // Self-Attention + { + auto [Qcur, Kcur, Vcur] = llm_build_mul_mat_qkv(gf, cur, + nullptr, nullptr, // wqkv, bqkv (not used in GLM usually?) + nullptr, nullptr, // wqk, bqk + mtp_layer.wq, mtp_layer.bq, + mtp_layer.wk, mtp_layer.bk, + mtp_layer.wv, mtp_layer.bv, + mtp_layer.attn_q_norm, mtp_layer.attn_k_norm, + 0.f, il); + + // RoPE + if (rope_cache) { + Qcur = ggml_rope_fast(ctx0, Qcur, rope_cache); + Kcur = ggml_rope_fast(ctx0, Kcur, rope_cache); + } else { + Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow); + Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow); + } + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + // KV Cache & Attention + cur = llm_build_kv(ctx0, lctx, kv_self, gf, + model.layers[il].wo, NULL, + Kcur, Vcur, Qcur, KQ_mask, + n_tokens, kv_head, n_kv, + 1.0f/sqrtf(float(n_embd_head)), cb, il); + } + + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "mtp_ffn_inp", il); + + cur = llm_build_norm(ctx0, ffn_inp, hparams, mtp_layer.attn_post_norm, NULL, LLM_NORM_RMS, cb, il); + cb(cur, "attn_post_norm", il); + + // moe ffn for nextn block + { + // Routed Experts + ggml_tensor * routed_out = llm_build_std_moe_ffn(ctx0, lctx, + NULL, // Norm handled above + cur, // Input (Normed) + mtp_layer.ffn_gate_inp, NULL, + mtp_layer.ffn_up_exps, NULL, + mtp_layer.ffn_gate_exps, NULL, + mtp_layer.ffn_down_exps, NULL, + mtp_layer.ffn_exp_probs_b, + nullptr, nullptr, // we don't have shared expert biases? + nullptr, nullptr, + nullptr, nullptr, + n_expert, n_expert_used, + LLM_FFN_SILU, hparams.expert_weights_norm, true, hparams.expert_weights_scale, + (llm_expert_gating_func_type) hparams.expert_gating_func, + LLM_FFN_SILU, cb, il, gf, true, mtp_layer.ffn_up_gate_exps); + cb(routed_out, "ffn_moe_out", il); + + // Shared Expert FFN + ggml_tensor * shared_out = llm_build_ffn(ctx0, lctx, + NULL, // Norm handled above + cur, // Input + mtp_layer.ffn_up_shexp, NULL, NULL, + mtp_layer.ffn_gate_shexp, NULL, NULL, + mtp_layer.ffn_down_shexp, NULL, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, cb, il, gf, true); + cb(shared_out, "ffn_shexp_out", il); + + // Sum and Residual + cur = ggml_add(ctx0, routed_out, shared_out); + cb(cur, "ffn_out", il); + + cur = ggml_add(ctx0, cur, ffn_inp); + cb(cur, "mtp_ffn_out_resid", il); + } + cur = llm_build_norm(ctx0, cur, hparams, mtp_layer.nextn.shared_head_norm, NULL, LLM_NORM_RMS, cb, il); + + if (inp_out_ids) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + } + + // If nextn.shared_head_head is missing (GLM-4.6), use model.output (Main LM Head) + ggml_tensor * mtp_head_weights = mtp_layer.nextn.shared_head_head; + if (mtp_head_weights == nullptr) { + mtp_head_weights = model.output; + } + cur = llm_build_lora_mm(lctx, ctx0, mtp_head_weights, cur); + cb(cur, "result_output", -1); + + return cur; +} + ggml_cgraph * llm_build_context::build_bitnet() { struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false); @@ -9836,7 +10001,8 @@ ggml_cgraph * llm_build_context::llama_build_graph( } // add on pooling layer - if (lctx.cparams.embeddings) { + if (lctx.cparams.mtp_op_type == MTP_OP_NONE && (lctx.cparams.embeddings || + (lctx.model.hparams.nextn_predict_layers > 0 || lctx.model.mtp))) { result = llm.append_pooling(result); } @@ -10178,3 +10344,7 @@ ggml_tensor * llm_build_context::build_std_attention(ggml_cgraph * gf, ggml_tens return cur; } + +int32_t llama_model_n_nextn_layer(const llama_model * model) { + return model->hparams.nextn_predict_layers; +} diff --git a/src/llama-build-context.h b/src/llama-build-context.h index 9508c5c6..0810605b 100644 --- a/src/llama-build-context.h +++ b/src/llama-build-context.h @@ -114,6 +114,8 @@ struct llm_build_context { ggml_cgraph * build_defrag(const std::vector & ids); + struct ggml_tensor * build_inp_embd_mtp(struct ggml_tensor * mtp_tok_embd); + ggml_tensor * build_inp_pos(); ggml_tensor * build_input_scale(int n_tokens); @@ -430,4 +432,12 @@ llm_expert_gating_func_type gating_op, bool is_multi = false); static uint32_t llama_kv_qnext_state_slots(const llama_kv_cache & kv_self); + struct ggml_tensor * build_mtp_tail( + const struct llama_layer & mtp_layer, + struct ggml_tensor * prev_embeddings, + int64_t n_embd_head, + struct ggml_cgraph * gf, + struct ggml_tensor * inp_pos, + struct ggml_tensor * rope_cache + ); }; diff --git a/src/llama-context.h b/src/llama-context.h index 9f4255fd..4acee93f 100644 --- a/src/llama-context.h +++ b/src/llama-context.h @@ -191,6 +191,8 @@ struct llama_context { ggml_abort_callback abort_callback = nullptr; void * abort_callback_data = nullptr; + const float * draft_input_hidden_state = nullptr; + // input tensors struct ggml_tensor * inp_tokens; // I32 [n_batch] struct ggml_tensor * inp_embd; // F32 [n_embd, n_batch] @@ -209,6 +211,7 @@ struct llama_context { struct ggml_tensor * inp_embd_enc; // F32 [n_embd, n_outputs_enc] struct ggml_tensor * inp_KQ_mask_cross; // F32 [n_outputs_enc, n_batch] struct ggml_tensor * inp_scale = nullptr; // F32 [n_tokens] + struct ggml_tensor * inp_mtp_states = nullptr; ggml_backend_t ggml_backend_by_name(const char * name); @@ -225,5 +228,7 @@ struct llama_context { std::vector cache_copies; bool update_cache_copies(); - + bool prepare_mtp_graph_inputs( + struct llama_context & lctx); + void set_mtp_op_type(llama_mtp_op_type value); }; diff --git a/src/llama-cparams.h b/src/llama-cparams.h index 3ee26c55..b178059f 100644 --- a/src/llama-cparams.h +++ b/src/llama-cparams.h @@ -45,9 +45,11 @@ struct llama_cparams { bool scheduler_async; int min_experts; float thresh_experts; + bool mtp; enum ggml_type reduce_type; enum llama_pooling_type pooling_type; + enum llama_mtp_op_type mtp_op_type; ggml_backend_sched_eval_callback cb_eval; void * cb_eval_user_data; diff --git a/src/llama-hparams.cpp b/src/llama-hparams.cpp index fa5a34c9..52cec617 100644 --- a/src/llama-hparams.cpp +++ b/src/llama-hparams.cpp @@ -903,6 +903,12 @@ void llm_load_hparams( } // NextN/MTP parameters + if (model.mtp) { + hparams.n_layer_kv_from_start = hparams.n_layer; + } + else { + hparams.n_layer_kv_from_start = hparams.n_layer - hparams.nextn_predict_layers; + } ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.nextn_predict_layers, false); switch (hparams.n_layer) { diff --git a/src/llama-load-tensors.cpp b/src/llama-load-tensors.cpp index b4c64b31..7e60f156 100644 --- a/src/llama-load-tensors.cpp +++ b/src/llama-load-tensors.cpp @@ -2278,9 +2278,12 @@ bool create_tensors_helper::create_glm4_moe_tensors(const LLM_TN & tn) { ggml_context * ctx_split = ctx_for_layer_split(i); int flags = 0; - if (hparams.nextn_predict_layers > 0 && static_cast(i) >= n_layer - hparams.nextn_predict_layers) { - // skip all tensors in the NextN layers - flags |= llama_model_loader::TENSOR_SKIP; + // Skip loading MTP layers if the feature is disabled + if (!model.mtp) { + if (hparams.nextn_predict_layers > 0 && static_cast(i) >= n_layer - hparams.nextn_predict_layers) { + // skip all tensors in the NextN layers + flags |= llama_model_loader::TENSOR_SKIP; + } } auto & layer = model.layers[i]; @@ -3481,7 +3484,8 @@ bool create_tensors_helper::create_tensors() { throw std::runtime_error("unknown architecture"); } if (model.split_mode == LLAMA_SPLIT_MODE_GRAPH || model.split_mode == LLAMA_SPLIT_MODE_ATTN) { - const int n_layer = model.layers.size() - model.hparams.nextn_predict_layers; + const int n_layer = model.mtp ? model.layers.size() + : model.layers.size() - model.hparams.nextn_predict_layers; LLAMA_LOG_INFO("================================ max_gpu = %d\n", model.max_gpu); std::vector mem_used(model.splits.size(), 0); const auto & hparams = model.hparams; diff --git a/src/llama-model.h b/src/llama-model.h index 2b415c54..02358fb5 100644 --- a/src/llama-model.h +++ b/src/llama-model.h @@ -374,6 +374,8 @@ struct llama_model { int max_gpu = 0; // max. number of GPUs to use per layer for aplit mode "graph" int n_gpu_layers; + bool mtp; // use mtp if is supported by the Model + std::vector rpc_servers; std::vector devices; diff --git a/src/llama.cpp b/src/llama.cpp index 3d22de1e..a21b3518 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -546,6 +546,7 @@ struct llama_context::Prev { int all_seq_id; int n_outputs; int n_kv; + llama_mtp_op_type mtp_op_type; ggml_cgraph * graph; }; @@ -563,11 +564,13 @@ bool llama_context::can_reuse_graph(const llama_batch & u_batch) { kv_self.head > 0 && kv_self.n == prev->n_kv && n_outputs == prev->n_outputs && + cparams.mtp_op_type == prev->mtp_op_type && update_cache_copies(); } bool llama_context::update_cache_copies() { - int n_layer = model.hparams.n_layer - model.hparams.nextn_predict_layers; //cache_copies.size()/2; + const int n_layer = model.mtp ? model.hparams.n_layer + : model.hparams.n_layer - model.hparams.nextn_predict_layers; //cache_copies.size()/2; auto layer_has_attention_kv = [&](int il) { return !((model.arch == LLM_ARCH_QWEN3NEXT || model.arch == LLM_ARCH_QWEN35MOE) && model.hparams.is_recurrent(il)); }; @@ -638,6 +641,12 @@ llama_context::llama_context(const llama_model & model) } } +void llama_context::set_mtp_op_type(llama_mtp_op_type value) { + LLAMA_LOG_DEBUG("%s: value = %d\n", __func__, value); + + cparams.mtp_op_type = value; +} + llama_context::~llama_context() { ggml_backend_sched_free(sched); @@ -716,7 +725,8 @@ static bool llama_kv_cache_init( const struct llama_hparams & hparams = model.hparams; - const int64_t n_layer = hparams.n_layer - hparams.nextn_predict_layers; + const int64_t n_layer = model.mtp ? hparams.n_layer + : hparams.n_layer - hparams.nextn_predict_layers; cache.has_shift = false; @@ -993,7 +1003,8 @@ static bool llama_kv_cache_init( // to the first cell of the slot. static bool llama_kv_cache_find_slot( struct llama_kv_cache & cache, - const struct llama_batch & batch) { + const struct llama_batch & batch, + enum llama_mtp_op_type op_type) { const uint32_t n_tokens = batch.n_tokens; if (cache.recurrent) { @@ -1044,6 +1055,45 @@ static bool llama_kv_cache_find_slot( } // otherwise, one cell per token. + bool is_mtp_special_op = (op_type == MTP_OP_WARMUP || + op_type == MTP_OP_UPDATE_ACCEPTED); + if (is_mtp_special_op) { + const llama_pos target_pos = batch.pos[0]; + const llama_seq_id target_seq = batch.seq_id[0][0]; + + bool found = false; + + if (cache.head < cache.size && + cache.cells[cache.head].pos == target_pos && + cache.cells[cache.head].has_seq_id(target_seq)) { + found = true; + } + else { + for (uint32_t i = 0; i < cache.size; ++i) { + if (cache.cells[i].pos == target_pos && + cache.cells[i].has_seq_id(target_seq)) { + + cache.head = i; + found = true; + break; + } + } + } + + if (!found) { + LLAMA_LOG_ERROR("%s: MTP Update failed - slot for seq %d pos %d not found\n", + __func__, target_seq, target_pos); + return false; + } + + if (cache.head + n_tokens > cache.size) { + LLAMA_LOG_ERROR("%s: MTP Update out of bounds\n", __func__); + return false; + } + + return true; + } + if (n_tokens > cache.size) { LLAMA_LOG_ERROR("%s: n_tokens=%d > cache.size=%d\n", __func__, n_tokens, cache.size); return false; @@ -1893,6 +1943,7 @@ static bool llm_load_tensors( const float * tensor_split, bool use_mlock, bool validate_quants, + bool mtp, llama_progress_callback progress_callback, void * progress_callback_user_data) { model.t_start_us = ggml_time_us(); @@ -1921,6 +1972,7 @@ static bool llm_load_tensors( model.main_gpu = main_gpu; model.max_gpu = max_gpu; model.n_gpu_layers = n_gpu_layers; + model.mtp = mtp; const int n_layer = hparams.n_layer; const int i_gpu_start = std::max((int) hparams.n_layer - n_gpu_layers, (int) 0); @@ -2300,7 +2352,7 @@ static int llama_model_load(const std::string & fname, llama_model & model, llam if (!llm_load_tensors( ml, model, params.n_gpu_layers, params.mla, params.split_mode, params.main_gpu, params.max_gpu, params.tensor_split, - params.use_mlock, params.validate_quants, + params.use_mlock, params.validate_quants, params.mtp, params.progress_callback, params.progress_callback_user_data )) { return -2; @@ -2969,8 +3021,9 @@ static size_t llama_output_reserve(llama_context & lctx, size_t n_outputs) { const auto n_embd = hparams.n_embd; // TODO: use a per-batch flag for logits presence instead - const bool has_logits = !cparams.embeddings; - const bool has_embd = lctx.is_encoding || (cparams.embeddings && (cparams.pooling_type == LLAMA_POOLING_TYPE_NONE)); + const bool has_mtp = lctx.model.hparams.nextn_predict_layers > 0 && lctx.cparams.mtp; + const bool has_logits = !cparams.embeddings || has_mtp; + const bool has_embd = lctx.is_encoding || (cparams.embeddings && (cparams.pooling_type == LLAMA_POOLING_TYPE_NONE)) || has_mtp; const size_t logits_size = has_logits ? n_vocab*n_outputs_max : 0; const size_t embd_size = has_embd ? n_embd*n_outputs_max : 0; @@ -3049,6 +3102,24 @@ static void llama_graph_compute( // fprintf(stderr, "splits: %d\n", ggml_backend_sched_get_n_splits(lctx.sched)); } +static bool prepare_mtp_graph_inputs(struct llama_context & lctx) { + ggml_tensor * dst = lctx.inp_mtp_states; + const float * src = nullptr; + if (lctx.cparams.mtp_op_type == MTP_OP_WARMUP || lctx.cparams.mtp_op_type == MTP_OP_UPDATE_ACCEPTED) { + src = lctx.embd; + } else { + src = lctx.draft_input_hidden_state; + } + + if (!src) { + LLAMA_LOG_ERROR("%s: Source hidden state is null\n", __func__); + return false; + } + + ggml_backend_tensor_set(dst, src, 0, ggml_nbytes(dst)); + return true; +} + // decode a batch of tokens by evaluating the transformer // // - lctx: llama context @@ -3260,7 +3331,7 @@ static int llama_decode_internal( kv_self.head = 0; } - if (!llama_kv_cache_find_slot(kv_self, u_batch)) { + if (!llama_kv_cache_find_slot(kv_self, u_batch, cparams.mtp_op_type)) { return 1; } @@ -3322,37 +3393,50 @@ static int llama_decode_internal( #endif if (u_batch.n_tokens == 1 && u_batch.embd == nullptr && lctx.cparams.graph_reuse) { lctx.prev = std::make_unique(llama_context::Prev{ - (int)u_batch.all_seq_id, (int)lctx.n_outputs, (int)lctx.kv_self.n, gf}); + (int)u_batch.all_seq_id, (int)lctx.n_outputs, (int)lctx.kv_self.n, + cparams.mtp_op_type, gf}); } } else { //printf("Reusing graph\n"); gf = lctx.prev->graph; } + if (cparams.mtp_op_type != MTP_OP_NONE) { + if (!prepare_mtp_graph_inputs(lctx)) { + return GGML_STATUS_FAILED; + } + } + // the output is always the last tensor in the graph struct ggml_tensor * res = gf->nodes[gf->n_nodes - 1]; - struct ggml_tensor * embd = gf->nodes[gf->n_nodes - 2]; + struct ggml_tensor * embd = nullptr; if (lctx.n_outputs == 0) { // no output - res = nullptr; - embd = nullptr; - } else if (cparams.embeddings) { - res = nullptr; // do not extract logits for embedding case - embd = nullptr; - for (int i = gf->n_nodes - 1; i >= 0; --i) { - if (strcmp(gf->nodes[i]->name, "result_embd_pooled") == 0) { - embd = gf->nodes[i]; - break; + res = nullptr; + } + else { + const bool has_mtp = lctx.model.hparams.nextn_predict_layers > 0 && lctx.model.mtp; + if (cparams.embeddings || has_mtp) { + for (int i = gf->n_nodes - 1; i >= 0; --i) { + if (strcmp(gf->nodes[i]->name, "result_embd_pooled") == 0) { + embd = gf->nodes[i]; + break; + } + if (strcmp(gf->nodes[i]->name, "result_norm") == 0) { + embd = gf->nodes[i]; + } + } + } + if (cparams.embeddings && lctx.model.hparams.nextn_predict_layers == 0) { + res = nullptr; // do not extract logits for embedding case + } else { + if (!embd) { // do not extract embeddings when not needed + GGML_ASSERT(strcmp(res->name, "result_output") == 0 && "missing result_output tensor"); } } - GGML_ASSERT(embd != nullptr && "missing embeddings tensor"); - } else { - embd = nullptr; // do not extract embeddings when not needed - GGML_ASSERT(strcmp(res->name, "result_output") == 0 && "missing result_output tensor"); } // LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs); - #if IK_PRINT_TIMING == 1 tim1 = ggml_time_us(); #endif @@ -3392,17 +3476,21 @@ static int llama_decode_internal( #if IK_PRINT_TIMING tim1 = ggml_time_us(); #endif - ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(lctx.sched, res); - GGML_ASSERT(backend_res != nullptr); - GGML_ASSERT(lctx.logits != nullptr); + // Do not process logits if MTP is only updating the KV cache. + if (cparams.mtp_op_type != MTP_OP_WARMUP && + cparams.mtp_op_type != MTP_OP_UPDATE_ACCEPTED) { + ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(lctx.sched, res); + GGML_ASSERT(backend_res != nullptr); + GGML_ASSERT(lctx.logits != nullptr); - float * logits_out = lctx.logits + n_outputs_prev*n_vocab; - const int32_t n_outputs_new = lctx.n_outputs; + float * logits_out = lctx.logits + n_outputs_prev*n_vocab; + const int32_t n_outputs_new = lctx.n_outputs; - if (n_outputs_new) { - GGML_ASSERT( n_outputs_prev + n_outputs_new <= n_outputs); - GGML_ASSERT((n_outputs_prev + n_outputs_new)*n_vocab <= (int64_t) lctx.logits_size); - ggml_backend_tensor_get_async(backend_res, res, logits_out, 0, n_outputs_new*n_vocab*sizeof(float)); + if (n_outputs_new) { + GGML_ASSERT( n_outputs_prev + n_outputs_new <= n_outputs); + GGML_ASSERT((n_outputs_prev + n_outputs_new)*n_vocab <= (int64_t) lctx.logits_size); + ggml_backend_tensor_get_async(backend_res, res, logits_out, 0, n_outputs_new*n_vocab*sizeof(float)); + } } #if IK_PRINT_TIMING tim2 = ggml_time_us(); @@ -3411,7 +3499,7 @@ static int llama_decode_internal( } // extract embeddings - if (embd) { + if (embd && cparams.mtp_op_type == MTP_OP_NONE) { #if IK_PRINT_TIMING tim1 = ggml_time_us(); #endif @@ -3617,57 +3705,59 @@ static int llama_encode_internal( // extract embeddings if (embd) { - ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(lctx.sched, embd); - GGML_ASSERT(backend_embd != nullptr); + if (cparams.mtp_op_type == MTP_OP_NONE) { + ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(lctx.sched, embd); + GGML_ASSERT(backend_embd != nullptr); - if (llama_model_has_decoder(&lctx.model)) { - lctx.embd_enc.resize(n_tokens*n_embd); - float * embd_out = lctx.embd_enc.data(); + if (llama_model_has_decoder(&lctx.model)) { + lctx.embd_enc.resize(n_tokens*n_embd); + float * embd_out = lctx.embd_enc.data(); - ggml_backend_tensor_get_async(backend_embd, embd, embd_out, 0, n_tokens*n_embd*sizeof(float)); + ggml_backend_tensor_get_async(backend_embd, embd, embd_out, 0, n_tokens*n_embd*sizeof(float)); - // remember the sequence ids used during the encoding - needed for cross attention later - lctx.seq_ids_enc.resize(n_tokens); - for (uint32_t i = 0; i < n_tokens; i++) { - for (int s = 0; s < batch.n_seq_id[i]; s++) { - llama_seq_id seq_id = batch.seq_id[i][s]; - lctx.seq_ids_enc[i].insert(seq_id); - } - } - } else { - GGML_ASSERT(lctx.embd != nullptr); - - switch (cparams.pooling_type) { - case LLAMA_POOLING_TYPE_NONE: - { - // extract token embeddings - GGML_ASSERT(lctx.embd != nullptr); - float * embd_out = lctx.embd; - - GGML_ASSERT(n_tokens*n_embd <= (int64_t) lctx.embd_size); - ggml_backend_tensor_get_async(backend_embd, embd, embd_out, 0, n_tokens*n_embd*sizeof(float)); - } break; - case LLAMA_POOLING_TYPE_MEAN: - case LLAMA_POOLING_TYPE_CLS: - case LLAMA_POOLING_TYPE_LAST: - { - // extract sequence embeddings - auto & embd_seq_out = lctx.embd_seq; - embd_seq_out.clear(); - - for (uint32_t i = 0; i < n_tokens; i++) { - const llama_seq_id seq_id = batch.seq_id[i][0]; - if (embd_seq_out.find(seq_id) != embd_seq_out.end()) { - continue; - } - embd_seq_out[seq_id].resize(n_embd); - ggml_backend_tensor_get_async(backend_embd, embd, embd_seq_out[seq_id].data(), (n_embd*seq_id)*sizeof(float), n_embd*sizeof(float)); - } - } break; - case LLAMA_POOLING_TYPE_UNSPECIFIED: - { - GGML_ABORT("unknown pooling type"); + // remember the sequence ids used during the encoding - needed for cross attention later + lctx.seq_ids_enc.resize(n_tokens); + for (uint32_t i = 0; i < n_tokens; i++) { + for (int s = 0; s < batch.n_seq_id[i]; s++) { + llama_seq_id seq_id = batch.seq_id[i][s]; + lctx.seq_ids_enc[i].insert(seq_id); } + } + } else { + GGML_ASSERT(lctx.embd != nullptr); + + switch (cparams.pooling_type) { + case LLAMA_POOLING_TYPE_NONE: + { + // extract token embeddings + GGML_ASSERT(lctx.embd != nullptr); + float * embd_out = lctx.embd; + + GGML_ASSERT(n_tokens*n_embd <= (int64_t) lctx.embd_size); + ggml_backend_tensor_get_async(backend_embd, embd, embd_out, 0, n_tokens*n_embd*sizeof(float)); + } break; + case LLAMA_POOLING_TYPE_MEAN: + case LLAMA_POOLING_TYPE_CLS: + case LLAMA_POOLING_TYPE_LAST: + { + // extract sequence embeddings + auto & embd_seq_out = lctx.embd_seq; + embd_seq_out.clear(); + + for (uint32_t i = 0; i < n_tokens; i++) { + const llama_seq_id seq_id = batch.seq_id[i][0]; + if (embd_seq_out.find(seq_id) != embd_seq_out.end()) { + continue; + } + embd_seq_out[seq_id].resize(n_embd); + ggml_backend_tensor_get_async(backend_embd, embd, embd_seq_out[seq_id].data(), (n_embd*seq_id)*sizeof(float), n_embd*sizeof(float)); + } + } break; + case LLAMA_POOLING_TYPE_UNSPECIFIED: + { + GGML_ABORT("unknown pooling type"); + } + } } } } @@ -4223,6 +4313,7 @@ struct llama_model_params llama_model_default_params() { /*.validate_quants =*/ false, /*.merge_qkv =*/ false, /*.merge_up_gate_exps =*/ false, + /*.mtp =*/ false, }; #ifdef GGML_USE_METAL @@ -4278,6 +4369,8 @@ struct llama_context_params llama_context_default_params() { /*.split_mode_graph_scheduling =*/ false, // /*.split_mode_f16 =*/ true, /*.scheduler_async =*/ false, + /*.mtp =*/ false, + /*.mtp_op_type =*/ MTP_OP_NONE, /*.abort_callback =*/ nullptr, /*.abort_callback_data =*/ nullptr, /*.offload_policy =*/ nullptr, @@ -4648,6 +4741,7 @@ struct llama_context * llama_init_from_model( cparams.min_experts = params.min_experts; cparams.thresh_experts = params.thresh_experts; cparams.cuda_params = params.cuda_params; + cparams.mtp = params.mtp; cparams.reduce_type = params.type_reduce; cparams.pooling_type = params.pooling_type; @@ -4725,6 +4819,12 @@ struct llama_context * llama_init_from_model( } } + if (model->arch != LLM_ARCH_GLM4_MOE && cparams.mtp != 0) { + cparams.mtp = 0; + } + + cparams.mtp_op_type = params.mtp_op_type; + LLAMA_LOG_INFO("%s: n_ctx = %u\n", __func__, cparams.n_ctx); LLAMA_LOG_INFO("%s: n_batch = %u\n", __func__, cparams.n_batch); LLAMA_LOG_INFO("%s: n_ubatch = %u\n", __func__, cparams.n_ubatch); @@ -6058,7 +6158,7 @@ struct llama_data_read { batch.n_seq_id[i] = 1; batch.seq_id[i][0] = dest_seq_id; } - if (!llama_kv_cache_find_slot(kv_self, batch)) { + if (!llama_kv_cache_find_slot(kv_self, batch, ctx->cparams.mtp_op_type)) { llama_batch_free(batch); LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__); return false; @@ -7003,6 +7103,10 @@ int32_t llama_decode( return ret; } +void llama_set_mtp_op_type(llama_context * ctx, llama_mtp_op_type mtp_op_type) { + ctx->set_mtp_op_type(mtp_op_type); +} + void llama_synchronize(struct llama_context * ctx) { ggml_backend_sched_synchronize(ctx->sched); @@ -8333,3 +8437,8 @@ void llama_set_offload_policy(struct llama_context * lctx, int op, bool on_or_of printf("XXXXXXXXXXXXXXXXXXXXXXXXXXXX offload(%s) = %d\n", op_name, on_or_off); ggml_backend_sched_set_op_offload(lctx->sched, ggml_op(op), on_or_off); } + +void llama_set_draft_input_hidden_state(struct llama_context * ctx, const float * hidden_state) { + ctx->draft_input_hidden_state = hidden_state; +} +