Add MTP decoding support for GLM-4.x MoE (#1270)

* wip: port MTP architecture

Ports the Multi-Token Prediction (MTP) architecture to the older `llama.cpp` codebase used by `ikllama`.

Changes include:
- Updating `llama_batch` to support `mtp_params`.
- Modifying `llama_decode_internal` (and `encode`) to handle MTP operations (Warmup, Update, Draft).
- Adding public APIs for MTP state management (`llama_set_draft_input_hidden_state`).
- Adapting the embedding extraction logic to skip MTP update passes.

* Refactors `server_slot` to support generic speculative decoding (MTP or Draft Model).

* core: enable hybrid outputs (logits + embeddings) for MTP support

* fix(mtp): correct KV-cache slot finding for updates

* fix(mtp): persist hidden states to prevent context corruption during drafting

* refactor(mtp): clean unused code

* fix(mtp): update server to new function names

* fix(mtp): fix graph and save hidden state

* mtp: refactor integration, context params and kv cache search

* mtp: fix hidden state extraction and speculative acceptance flow

* server: fix MTP warmup for long prompts and reset token buffer

* llama: refactor MTP operation state to context parameters

* server: fix n_past calculation in MTP acceptance

* llama: fix mtp enable flags

* speculative: refactor MTP to use common_speculative interface

* context: remove unused signatures

* clip: fix deprecated enum-enum conversion warning

* common: fix format string crash in help message

* context: fix mtp activation logic
This commit is contained in:
Samuel Oliveira Alves
2026-02-22 14:14:39 -03:00
committed by GitHub
parent cbf7fc7e2f
commit 09a88c9ae5
16 changed files with 820 additions and 206 deletions

View File

@@ -1463,6 +1463,14 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
params.cuda_params = argv[i];
return true;
}
if (arg == "-mtp" || arg == "--multi-token-prediction") {
params.has_mtp = true;
return true;
}
if (arg == "-no-mtp" || arg == "--no-multi-token-prediction") {
params.has_mtp = false;
return true;
}
if (arg == "-draft" || arg == "--draft-params") {
CHECK_ARG
params.speculative.params = argv[i];
@@ -2475,6 +2483,8 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param
options.push_back({ "*", "-hfr, --hf-repo REPO", "Hugging Face model repository (default: unused)" });
options.push_back({ "*", "-hff, --hf-file FILE", "Hugging Face model file (default: unused)" });
options.push_back({ "*", "-hft, --hf-token TOKEN", "Hugging Face access token (default: value from HF_TOKEN environment variable)" });
options.push_back({ "*", "-mtp, --multi-token-prediction", "whether to use multi-token-prediction (if supported) (default: %s)", params.has_mtp ? "true" : "false" });
options.push_back({ "*", "-no-mtp, --no-multi-token-prediction", "whether to use multi-token-prediction (if supported) (default: %s)", !params.has_mtp ? "true" : "false" });
options.push_back({ "*", "--draft-max, --draft, --draft-n N",
"number of tokens to draft for speculative decoding (default: %d)", params.speculative.n_max });
options.push_back({ "*", "--draft-min, --draft-n-min N", "minimum number of draft tokens to use for speculative decoding" });
@@ -3207,6 +3217,7 @@ struct llama_model_params common_model_params_to_llama(const gpt_params & params
mparams.validate_quants = params.validate_quants;
mparams.merge_qkv = params.merge_qkv;
mparams.merge_up_gate_exps = params.merge_up_gate_exps;
mparams.mtp = params.has_mtp;
if (params.kv_overrides.empty()) {
mparams.kv_overrides = NULL;
} else {
@@ -3329,6 +3340,8 @@ struct llama_context_params common_context_params_to_llama(const gpt_params & pa
cparams.thresh_experts = params.thresh_experts;
cparams.only_active_experts = params.only_active_exps;
cparams.max_extra_alloc = params.max_extra_alloc_MiB;
cparams.mtp = params.has_mtp;
cparams.mtp_op_type = MTP_OP_NONE;
cparams.type_k = kv_cache_type_from_str(params.cache_type_k);
cparams.type_v = kv_cache_type_from_str(params.cache_type_v);

View File

@@ -139,6 +139,7 @@ thinking_tokens thinking_tokens_from_string(const std::string& format);
enum common_speculative_type {
COMMON_SPECULATIVE_TYPE_NONE, // no speculative decoding
COMMON_SPECULATIVE_TYPE_DRAFT, // draft model
COMMON_SPECULATIVE_TYPE_MTP, // MTP model
COMMON_SPECULATIVE_TYPE_EAGLE3, // eagle draft model
COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE, // simple self-speculative decoding
COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K, // self-speculative decoding with n-gram keys only
@@ -356,6 +357,7 @@ struct gpt_params {
bool split_mode_graph_scheduling = false; // if true, force split mode graph scheduling
//bool split_mode_f16 = true; // if true, intermediate results will be cast to f16 before copying to other GPUs to perform reduce ops
bool scheduler_async = false; // if true, in split mode graph the scheduler will use multiple threads to evaluate the graph
bool has_mtp = false; // enable MTP if supported by the model
std::string cache_type_k = "f16"; // KV cache data type for the K
std::string cache_type_v = "f16"; // KV cache data type for the V

View File

@@ -20,6 +20,7 @@
const std::vector<enum common_speculative_type> common_speculative_types = {
COMMON_SPECULATIVE_TYPE_NONE,
COMMON_SPECULATIVE_TYPE_DRAFT,
COMMON_SPECULATIVE_TYPE_MTP,
COMMON_SPECULATIVE_TYPE_EAGLE3,
COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE,
COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K,
@@ -31,6 +32,7 @@ const std::vector<enum common_speculative_type> common_speculative_types = {
const std::map<std::string, enum common_speculative_type> common_speculative_type_from_name_map = {
{"none", COMMON_SPECULATIVE_TYPE_NONE},
{"draft", COMMON_SPECULATIVE_TYPE_DRAFT},
{"mtp", COMMON_SPECULATIVE_TYPE_MTP},
{"eagle3", COMMON_SPECULATIVE_TYPE_EAGLE3},
{"ngram_simple", COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE},
{"ngram_map_k", COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K},
@@ -144,6 +146,58 @@ struct common_speculative_state {
virtual void accept(uint16_t n_accepted) = 0;
};
struct common_speculative_state_mtp : public common_speculative_state {
llama_context * ctx_tgt;
common_sampler * smpl;
common_speculative_state_mtp(
enum common_speculative_type type,
llama_context * ctx_tgt)
: common_speculative_state(type)
, ctx_tgt(ctx_tgt)
{
struct common_params_sampling params;
params.samplers_sequence = {
llama_sampler_type::DIST,
};
smpl = common_sampler_init(llama_get_model(ctx_tgt), params);
}
~common_speculative_state_mtp() override {
common_sampler_free(smpl);
}
void begin(const llama_tokens & prompt) override {
GGML_UNUSED(prompt);
}
void draft(
const common_params_speculative & params,
const llama_tokens & prompt_tgt,
llama_token id_last,
llama_tokens & result) override {
int32_t n_past = (int32_t)prompt_tgt.size();
llama_seq_id seq_id = 0;
result = mtp_speculative_gen_draft(
smpl,
ctx_tgt,
params.n_max,
params.p_min,
id_last,
n_past,
seq_id
);
}
void accept(uint16_t n_accepted) override {
GGML_UNUSED(n_accepted);
}
};
struct common_speculative_state_draft : public common_speculative_state {
llama_context * ctx_tgt; // only used for retokenizing from ctx_dft
llama_context * ctx_dft;
@@ -760,6 +814,7 @@ std::string common_speculative_type_to_str(enum common_speculative_type type) {
switch (type) {
case COMMON_SPECULATIVE_TYPE_NONE: return "none";
case COMMON_SPECULATIVE_TYPE_DRAFT: return "draft";
case COMMON_SPECULATIVE_TYPE_MTP: return "mtp";
case COMMON_SPECULATIVE_TYPE_EAGLE3: return "eagle3";
case COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE: return "ngram_simple";
case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K: return "ngram_map_k";
@@ -828,6 +883,7 @@ common_speculative * common_speculative_init(
{
bool has_draft = !params.mparams_dft.path.empty();
bool has_draft_eagle3 = false; // TODO PR-18039: if params.speculative.eagle3
bool has_mtp = (params.type == COMMON_SPECULATIVE_TYPE_MTP);
bool has_ngram_cache = (params.type == COMMON_SPECULATIVE_TYPE_NGRAM_CACHE);
bool has_ngram_simple = (params.type == COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE);
@@ -867,6 +923,9 @@ common_speculative * common_speculative_init(
if (has_ngram_cache) {
configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_NGRAM_CACHE, params));
}
if (has_mtp) {
configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_MTP, params));
}
if (has_draft) {
configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_DRAFT, params));
}
@@ -890,6 +949,12 @@ common_speculative * common_speculative_init(
));
break;
}
case COMMON_SPECULATIVE_TYPE_MTP: {
impls.push_back(std::make_unique<common_speculative_state_mtp>(config.type,
/* .ctx_tgt = */ ctx_tgt
));
break;
}
case COMMON_SPECULATIVE_TYPE_EAGLE3: {
impls.push_back(std::make_unique<common_speculative_state_eagle3>(config.type));
break;
@@ -1047,3 +1112,112 @@ void common_speculative_print_stats(const common_speculative * spec) {
str_perf.c_str());
}
}
// ----------------------------------------------------------------------------
// MTP
// ----------------------------------------------------------------------------
// Generate up to n_draft speculative tokens using the model's MTP
// (Multi-Token Prediction) head.
//
// The context is temporarily switched to MTP_OP_DRAFT_GEN and restored to
// MTP_OP_NONE before returning. Drafting stops early if a decode fails, the
// sampler produces no candidates, or the top candidate's probability falls
// below p_min.
//
//  smpl    - sampler used to pick each draft token (reset on entry)
//  ctx     - target context whose MTP head performs the drafting
//  id_last - last accepted token; fed as the first MTP input
//  n_past  - sequence position of id_last
//  seq_id  - sequence the draft tokens belong to
// Returns the drafted tokens (possibly empty).
//
// NOTE(review): the caller is presumably expected to have seeded the MTP
// hidden-state input (llama_set_draft_input_hidden_state) beforehand —
// confirm against the server call sites.
std::vector<llama_token> mtp_speculative_gen_draft(
    struct common_sampler * smpl,
    struct llama_context * ctx,
    int n_draft,
    float p_min,
    llama_token id_last,
    int32_t n_past,
    llama_seq_id seq_id) {

    llama_tokens drafts;
    drafts.reserve(n_draft);

    if (!smpl) return drafts;

    common_sampler_reset(smpl);

    // single-token batch, reused for every drafting step
    llama_batch mtp_batch = llama_batch_init(1, 0, 1);

    llama_set_mtp_op_type(ctx, MTP_OP_DRAFT_GEN);

    llama_token current_input_id = id_last;
    int32_t current_n_past = n_past;

    for (int i = 0; i < n_draft; ++i) {
        mtp_batch.n_tokens = 0;
        common_batch_add(mtp_batch, current_input_id, current_n_past, {seq_id}, true);

        if (llama_decode(ctx, mtp_batch) != 0) {
            break;
        }

        common_sampler_sample(smpl, ctx, 0, true);
        const auto * cur_p = common_sampler_get_candidates(smpl, true);
        if (!cur_p || cur_p->size == 0) {
            break;
        }

        // take the top candidate; stop once its probability is too low
        const llama_token id_next = cur_p->data[0].id;
        const float prob = cur_p->data[0].p;
        common_sampler_accept(smpl, nullptr, id_next, true);

        if (prob < p_min) {
            break;
        }

        drafts.push_back(id_next);

        // the drafted token becomes the next step's input
        current_input_id = id_next;
        current_n_past++;
    }

    llama_batch_free(mtp_batch);
    llama_set_mtp_op_type(ctx, MTP_OP_NONE);

    // Purge the metadata for the draft tokens.
    // This prevents cache state corruption where two cells map to the same logical position.
    if (!drafts.empty()) {
        llama_kv_cache_seq_rm(ctx, seq_id, n_past, current_n_past);
    }

    return drafts;
}
// Run an MTP forward pass over `batch` to refresh the MTP layer's KV cache.
//
// is_prompt_warmup selects the operation: MTP_OP_WARMUP for the initial
// prompt pass, MTP_OP_UPDATE_ACCEPTED when re-processing tokens accepted
// from a draft. The op type is always restored to MTP_OP_NONE on return.
//
// NOTE(review): `mtp_batch` is a shallow copy of `batch`, so setting
// `logits[i]` below also mutates the caller's buffer — confirm that no
// caller relies on the original logits flags afterwards.
void mtp_update_kv_cache(struct llama_context * ctx, const llama_batch& batch, bool is_prompt_warmup) {
    if (batch.n_tokens == 0) {
        return;
    }

    LOG_DBG("[MTP-UPDATE|%s] Updating %d tokens...\n", is_prompt_warmup ? "PROMPT_WARMUP" : "GEN_ACCEPTED", batch.n_tokens);

    llama_batch mtp_batch = batch;

    if (is_prompt_warmup) {
        llama_set_mtp_op_type(ctx, MTP_OP_WARMUP);
    } else {
        llama_set_mtp_op_type(ctx, MTP_OP_UPDATE_ACCEPTED);
    }

    // the MTP pass needs an output for every token in the batch
    for (int i = 0; i < mtp_batch.n_tokens; ++i) {
        mtp_batch.logits[i] = true;
    }

    // fix: the decode result was previously discarded; surface failures so a
    // stale MTP KV cache can be diagnosed instead of silently corrupting drafts
    if (llama_decode(ctx, mtp_batch) != 0) {
        LOG_DBG("[MTP-UPDATE|%s] llama_decode failed; MTP KV cache may be stale\n", is_prompt_warmup ? "PROMPT_WARMUP" : "GEN_ACCEPTED");
    }

    llama_set_mtp_op_type(ctx, MTP_OP_NONE);
}
void mtp_accept_tokens(
struct llama_context * ctx,
const std::vector<llama_token> & ids,
int32_t n_past_base,
llama_seq_id seq_id
) {
if (ids.empty()) {
return;
}
llama_batch accepted_batch = llama_batch_init(ids.size(), 0, 1);
for (size_t i = 0; i < ids.size(); ++i) {
common_batch_add(accepted_batch, ids[i], n_past_base + i, { seq_id }, true);
}
mtp_update_kv_cache(ctx, accepted_batch, false);
llama_batch_free(accepted_batch);
}

View File

@@ -39,3 +39,22 @@ void common_speculative_accept(common_speculative * spec, uint16_t n_accepted);
// print statistics about the speculative decoding
void common_speculative_print_stats(const common_speculative * spec);
// Generates speculative draft tokens using the Multi-Token Prediction (MTP) architecture.
// Returns up to n_draft tokens drafted with ctx's MTP head, starting after
// id_last at position n_past of sequence seq_id; stops early when the top
// candidate's probability drops below p_min.
std::vector<llama_token> mtp_speculative_gen_draft(
    struct common_sampler * smpl,
    struct llama_context * ctx,
    int n_draft,
    float p_min,
    llama_token id_last,
    int32_t n_past,
    llama_seq_id seq_id);

// Runs an MTP forward pass over `batch` to refresh the MTP layer's KV cache
// (MTP_OP_WARMUP when is_prompt_warmup, MTP_OP_UPDATE_ACCEPTED otherwise).
void mtp_update_kv_cache(struct llama_context * ctx, const llama_batch& batch, bool is_prompt_warmup);

// Syncs the MTP KV cache with tokens accepted from a draft: ids[i] is
// processed at position n_past_base + i of sequence seq_id.
void mtp_accept_tokens(
    struct llama_context * ctx,
    const std::vector<llama_token> & ids,
    int32_t n_past_base,
    llama_seq_id seq_id
);

View File

@@ -35,7 +35,7 @@
#include <array>
#include <functional>
#define DEFAULT_INTERPOLATION_MODE (GGML_SCALE_MODE_BILINEAR | GGML_SCALE_FLAG_ALIGN_CORNERS)
#define DEFAULT_INTERPOLATION_MODE ((int)GGML_SCALE_MODE_BILINEAR | (int)GGML_SCALE_FLAG_ALIGN_CORNERS)
// TODO: allow to pass callback from user code
struct clip_logger_state g_logger_state = {GGML_LOG_LEVEL_CONT, clip_log_callback_default, NULL};

View File

@@ -152,12 +152,17 @@ bool server_context::load_model(const gpt_params& params_) {
LOG_ERROR("failed to load draft model", { {"model", params_base.speculative.model} });
return false;
}
cparams_dft = common_context_params_to_llama(params_dft);
params_base.speculative.model_dft = model_dft;
params_base.speculative.cparams_dft = cparams_dft;
}
else if (params_base.has_mtp && llama_model_n_nextn_layer(model) == 0) {
LOG_WARNING("WARNING: -mtp flag provided, but model has 0 NextN layers. MTP will be disabled.\n", {});
params_base.has_mtp = false;
}
return true;
}
@@ -209,12 +214,35 @@ void server_context::init() {
slot.sparams = params_base.sparams;
if (params_base.has_mtp) {
if (llama_model_n_nextn_layer(model) > 0) {
SRV_INF("%s\n", "MTP detected, configuring for speculative decoding...");
params_base.speculative.type = COMMON_SPECULATIVE_TYPE_MTP;
slot.has_mtp = true;
slot.params.speculative.type = COMMON_SPECULATIVE_TYPE_MTP;
slot.params.speculative.n_min = 0;
slot.batch_spec = llama_batch_init(slot.params.speculative.n_max + 1, 0, 1);
SLT_DBG(slot, "batch_spec contains %d tokens\n", slot.batch_spec.n_tokens);
SRV_INF("%s\n", "MTP needs embeddings on decode, enabling");
llama_set_embeddings(ctx, true);
}
else {
SRV_WRN("%s\n", "MTP enabled via flag, but model has 0 NextN layers. Disabling speculative.");
params_base.speculative.type = COMMON_SPECULATIVE_TYPE_NONE;
slot.has_mtp = false;
}
}
const bool can_spec = common_speculative_is_compat(ctx);
if (!can_spec) {
SRV_WRN("%s", "speculative decoding not supported by this context\n");
}
// try speculative decoding
if (can_spec){
if (can_spec) {
slot.spec = common_speculative_init(params_base.speculative, slot.ctx);
if (slot.spec) {
if (mctx) {
@@ -223,9 +251,14 @@ void server_context::init() {
}
SLT_INF(slot, "%s", "speculative decoding context initialized\n");
} else {
SLT_INF(slot, "%s", "speculative decoding context not initialized\n");
if (slot.has_mtp) {
SRV_ERR("%s", "failed to initialize MTP speculative context\n");
} else {
SLT_INF(slot, "%s", "speculative decoding context not initialized\n");
}
}
}
slot.reset();
slots.push_back(std::move(slot));
@@ -380,7 +413,7 @@ void server_slot::add_token_string(const completion_token_output& token) {
}
bool server_slot::can_speculate() const {
return !!spec;
return (!!spec || has_mtp);
}
int server_slot::get_n_draft_max() const {
@@ -2533,6 +2566,15 @@ void server_context::add_sampled_tokens() {
const auto & params_spec = slot.params.speculative;
if (slot.has_mtp) {
if (!slot.mtp_hidden_state.empty()) {
llama_set_draft_input_hidden_state(ctx, slot.mtp_hidden_state.data());
} else {
LOG_ERROR("MTP hidden state is empty during speculation", {});
llama_set_draft_input_hidden_state(ctx, llama_get_embeddings_ith(ctx, -1));
}
}
llama_tokens draft = common_speculative_draft(slot.spec, params_spec, cached_text_tokens, slot.sampled);
if (draft.size() > (size_t)n_draft_max) {
@@ -2540,13 +2582,6 @@ void server_context::add_sampled_tokens() {
draft.resize(n_draft_max);
}
/*struct llama_speculative_params params_spec;
params_spec.n_draft = n_draft_max;
params_spec.n_reuse = llama_n_ctx(slot.ctx_dft) - slot.params.speculative.n_max;
params_spec.p_min = slot.params.speculative.p_min;
const llama_tokens& cached_text_tokens = slot.cache_tokens.get_text_tokens();
llama_tokens draft = llama_speculative_gen_draft(slot.spec, params_spec, cached_text_tokens, slot.sampled);*/
// add the sampled token to the batch
slot.i_batch_dft.push_back(batch.n_tokens);
common_batch_add(batch, slot.sampled, slot.cache_tokens.pos_next(), { slot.id }, true);
@@ -2759,6 +2794,8 @@ void server_context::batch_pending_prompt(const int32_t n_ubatch, const int32_t
}
slot.n_past = 0;
slot.n_buffer = 0;
slot.token_buffer.clear();
slot.n_prompt_tokens = prompt_tokens.size();
LOG_VERBOSE("prompt tokenized", {
@@ -3092,6 +3129,25 @@ void server_context::speculative_decoding_accept() {
// the accepted tokens from the speculation
const auto ids = common_sampler_sample_and_accept_n(slot.ctx_sampling, ctx, slot.i_batch_dft, slot.drafted);
if (slot.has_mtp) {
const int n_embd = llama_model_n_embd(llama_get_model(ctx));
if (!ids.empty()) {
const float* emb = llama_get_embeddings_ith(ctx, ids.size() - 1);
if (emb) {
slot.mtp_hidden_state.resize(n_embd);
memcpy(slot.mtp_hidden_state.data(), emb, n_embd * sizeof(float));
}
}
else {
llama_set_draft_input_hidden_state(ctx, llama_get_embeddings_ith(ctx, 0));
}
llama_set_draft_input_hidden_state(ctx, slot.mtp_hidden_state.data());
int32_t n_past_base = slot.n_past - (slot.drafted.size() + 1);
mtp_accept_tokens(ctx, ids, n_past_base, slot.id);
}
slot.i_batch_dft.clear();
slot.drafted.clear();
@@ -3362,6 +3418,17 @@ void server_context::process_batch_tokens(int32_t & n_batch) {
common_sampler_accept(slot.ctx_sampling, ctx, id, true);
if (params_base.has_mtp && slot.n_decoded == 0) {
if (batch_view.n_seq_id[0] > 0 && batch_view.seq_id[0][0] == slot.id) {
mtp_update_kv_cache(ctx, batch_view, true);
const float* emb = llama_get_embeddings_ith(ctx, -1);
if (emb) {
const int n_embd = llama_model_n_embd(llama_get_model(ctx));
slot.mtp_hidden_state.resize(n_embd);
memcpy(slot.mtp_hidden_state.data(), emb, n_embd * sizeof(float));
}
}
}
slot.n_decoded += 1;
const int64_t t_current = ggml_time_us();
@@ -3394,7 +3461,15 @@ void server_context::process_batch_tokens(int32_t & n_batch) {
slot.i_batch = -1;
}
if (params_base.has_mtp) {
for (auto& slot : slots) {
if (slot.n_past < slot.n_prompt_tokens) {
if (batch_view.n_seq_id[0] > 0 && batch_view.seq_id[0][0] == slot.id) {
mtp_update_kv_cache(ctx, batch_view, true);
}
}
}
}
// speculative decoding - main model sample and accept
speculative_decoding_accept();
}

View File

@@ -134,6 +134,9 @@ struct server_slot {
struct common_params_sampling sparams;
common_sampler * ctx_sampling = nullptr;
bool has_mtp = false;
std::vector<float> mtp_hidden_state;
// speculative decoding stats
int32_t n_draft_total = 0; // Total draft tokens generated
int32_t n_draft_accepted = 0; // Draft tokens actually accepted

View File

@@ -279,6 +279,12 @@ extern "C" {
LLAMA_SPLIT_MODE_GRAPH = 3, // splits computations across GPUs
};
enum llama_mtp_op_type {
MTP_OP_NONE = 0,
MTP_OP_WARMUP = 1,
MTP_OP_UPDATE_ACCEPTED = 2,
MTP_OP_DRAFT_GEN = 3,
};
typedef struct llama_token_data {
llama_token id; // token id
@@ -394,6 +400,7 @@ extern "C" {
bool validate_quants; // if true, check for NaNs while loading the model
bool merge_qkv; // if true, merge separate Q, K, V tensors into a single, contiguous tensor
bool merge_up_gate_exps; // if true, merge ffn_up_exps and ffn_gate_exps tensors into a single, contiguous tensor
bool mtp; // if true, load MTP layers if present
};
// NOTE: changing the default values of parameters marked as [EXPERIMENTAL] may cause crashes or incorrect results in certain configurations
@@ -449,6 +456,8 @@ extern "C" {
bool split_mode_graph_scheduling; // if true, force split mode graph scheduling
//bool split_mode_f16; // if true, cast intermediate results to f16 before copying to other GPUs
bool scheduler_async; // if true, with split mode "graph" graph evaluation will be done using multiple threads
bool mtp; // Activate MTP if supported
enum llama_mtp_op_type mtp_op_type;
// Abort callback
// if it returns true, execution of llama_decode() will be aborted
@@ -1463,6 +1472,17 @@ LLAMA_API struct llama_grammar* llama_sampler_init_grammar_lazy_patterns(
LLAMA_API void llama_dump_timing_info_yaml(FILE * stream, const struct llama_context * ctx);
//
// MTP
//
LLAMA_API int32_t llama_model_n_nextn_layer(const struct llama_model * model);
// Set which, if any, MTP operation the context will use
LLAMA_API void llama_set_mtp_op_type(struct llama_context * ctx, enum llama_mtp_op_type mtp_op_type);
LLAMA_API void llama_set_draft_input_hidden_state(struct llama_context * ctx, const float * hidden_state);
#ifdef __cplusplus
}
#endif

View File

@@ -303,6 +303,25 @@ ggml_cgraph * llm_build_context::build_defrag(const std::vector<uint32_t> & ids)
return gf;
}
// Build the token-embedding input for the MTP head by looking up rows of
// `mtp_tok_embd` for the batch's token ids. Returns nullptr when the batch
// carries no token ids (embedding input instead).
struct ggml_tensor * llm_build_context::build_inp_embd_mtp(struct ggml_tensor * mtp_tok_embd) {
    if (!batch.token) {
        return nullptr;
    }

    lctx.inp_tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, batch.n_tokens);
    cb(lctx.inp_tokens, "inp_tokens", -1);
    ggml_set_input(lctx.inp_tokens);

    struct ggml_tensor * embd = ggml_get_rows(ctx0, mtp_tok_embd, lctx.inp_tokens);
    cb(embd, "inp_embd", -1);

    return embd;
}
ggml_tensor * llm_build_context::build_inp_pos() {
int n_pos_per_embd = hparams.rope_type == LLAMA_ROPE_TYPE_MROPE || hparams.rope_type == LLAMA_ROPE_TYPE_IMROPE ? 4 : 1;
lctx.inp_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, int64_t(n_tokens)*n_pos_per_embd);
@@ -415,7 +434,10 @@ ggml_cgraph * llm_build_context::append_pooling(struct ggml_cgraph * gf) {
struct ggml_tensor * inp = nullptr;
for (int i = gf->n_nodes - 1; i >= 0; --i) {
inp = gf->nodes[i];
if (strcmp(inp->name, "result_norm") == 0 || strcmp(inp->name, "result_embd") == 0) {
if (strcmp(inp->name, "result_norm") == 0 ||
strcmp(inp->name, "result_embd") == 0 ||
strcmp(inp->name, "output_normed") == 0) {
break;
}
inp = nullptr;
@@ -7372,138 +7394,281 @@ ggml_cgraph * llm_build_context::build_glm4_moe() {
const int64_t n_embd_head = hparams.n_embd_head_v;
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
struct ggml_tensor * cur;
struct ggml_tensor * inpL;
// input embeddings
inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
ggml_tensor * cur;
// position embeddings
struct ggml_tensor * inp_pos = build_inp_pos();
// attention KV cache input
//auto * inp_attn = build_attn_inp_kv_unified();
struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
// output token IDs (for last layer cropping)
struct ggml_tensor * inp_out_ids = n_tokens > 1 ? build_inp_out_ids() : nullptr;
auto rope_cache = model.split_mode != LLAMA_SPLIT_MODE_GRAPH && cparams.rope_cache && (rope_type == LLAMA_ROPE_TYPE_NEOX || rope_type == LLAMA_ROPE_TYPE_NORM) ?
ggml_rope_cache(ctx0, inp_pos, nullptr, n_embd_head, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow) : nullptr;
float kq_scale = 1.0f/sqrtf(float(n_embd_head));
if (cparams.mtp_op_type != MTP_OP_NONE) {
ggml_tensor* hidden_states_from_main_model;
// Only process up to last layer (skip final NextN layer)
// Final layer tensors are loaded but not processed in forward pass
const int n_transformer_layers = n_layer - hparams.nextn_predict_layers;
for (int il = 0; il < n_transformer_layers; ++il) {
struct ggml_tensor * inpSA = inpL;
// self-attention
if (rope_cache == nullptr) {
cur = build_std_attention(gf, model.layers[il].attn_norm, inpL,
inp_pos, il == n_transformer_layers - 1 ? inp_out_ids : nullptr, nullptr,
KQ_mask, nullptr, nullptr, kq_scale, 0.0f, 0, il, true, false, true);
if (cparams.mtp_op_type == MTP_OP_WARMUP || cparams.mtp_op_type == MTP_OP_UPDATE_ACCEPTED) {
hidden_states_from_main_model = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, hparams.n_embd, n_tokens);
} else {
// Pre-attention norm
cur = llm_build_norm(ctx0, inpL, hparams, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, cb, il);
cb(cur, "attn_norm", il);
hidden_states_from_main_model = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, hparams.n_embd);
}
ggml_set_name(hidden_states_from_main_model, "result_embd_pooled");
ggml_set_input(hidden_states_from_main_model);
auto [Qcur, Kcur, Vcur] = llm_build_mul_mat_qkv(gf, cur,
model.layers[il].wqkv, model.layers[il].bqkv,
model.layers[il].wqk, model.layers[il].bqk,
model.layers[il].wq, model.layers[il].bq,
model.layers[il].wk, model.layers[il].bk,
model.layers[il].wv, model.layers[il].bv,
model.layers[il].attn_q_norm, model.layers[il].attn_k_norm, 0.f, il);
lctx.inp_mtp_states = hidden_states_from_main_model;
// apply RoPE
if (rope_cache) {
Qcur = ggml_rope_fast(ctx0, Qcur, rope_cache);
Kcur = ggml_rope_fast(ctx0, Kcur, rope_cache);
const int il_mtp = hparams.n_layer - 1;
const auto & mtp_layer = model.layers[il_mtp];
cur = build_mtp_tail(mtp_layer, hidden_states_from_main_model, n_embd_head, gf, inp_pos, rope_cache);
} else {
struct ggml_tensor * inpL;
// input embeddings
inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
// output token IDs (for last layer cropping)
struct ggml_tensor * inp_out_ids = n_tokens > 1 ? build_inp_out_ids() : nullptr;
float kq_scale = 1.0f/sqrtf(float(n_embd_head));
// Only process up to last layer (skip final NextN layer)
// Final layer tensors are loaded but not processed in forward pass
const int n_transformer_layers = n_layer - hparams.nextn_predict_layers;
for (int il = 0; il < n_transformer_layers; ++il) {
struct ggml_tensor * inpSA = inpL;
// self-attention
if (rope_cache == nullptr) {
cur = build_std_attention(gf, model.layers[il].attn_norm, inpL,
inp_pos, il == n_transformer_layers - 1 ? inp_out_ids : nullptr, nullptr,
KQ_mask, nullptr, nullptr, kq_scale, 0.0f, 0, il, true, false, true);
} else {
Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow);
Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow);
}
cb(Qcur, "Qcur", il);
cb(Kcur, "Kcur", il);
cb(Vcur, "Vcur", il);
// Pre-attention norm
cur = llm_build_norm(ctx0, inpL, hparams, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, cb, il);
cb(cur, "attn_norm", il);
// build attention KV (no unified cache)
cur = llm_build_kv(ctx0, lctx, kv_self, gf,
model.layers[il].wo, NULL,
Kcur, Vcur, Qcur, KQ_mask,
n_tokens, kv_head, n_kv,
1.0f/sqrtf(float(n_embd_head)), cb, il);
auto [Qcur, Kcur, Vcur] = llm_build_mul_mat_qkv(gf, cur,
model.layers[il].wqkv, model.layers[il].bqkv,
model.layers[il].wqk, model.layers[il].bqk,
model.layers[il].wq, model.layers[il].bq,
model.layers[il].wk, model.layers[il].bk,
model.layers[il].wv, model.layers[il].bv,
model.layers[il].attn_q_norm, model.layers[il].attn_k_norm, 0.f, il);
if (il == n_transformer_layers - 1 && inp_out_ids) {
// skip computing output for unused tokens
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
// apply RoPE
if (rope_cache) {
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
Qcur = ggml_rope_fast(ctx0, Qcur, rope_cache);
Kcur = ggml_rope_fast(ctx0, Kcur, rope_cache);
} else {
Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow);
Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow);
}
cb(Qcur, "Qcur", il);
cb(Kcur, "Kcur", il);
cb(Vcur, "Vcur", il);
// build attention KV (no unified cache)
cur = llm_build_kv(ctx0, lctx, kv_self, gf,
model.layers[il].wo, NULL,
Kcur, Vcur, Qcur, KQ_mask,
n_tokens, kv_head, n_kv,
1.0f/sqrtf(float(n_embd_head)), cb, il);
if (il == n_transformer_layers - 1 && inp_out_ids) {
// skip computing output for unused tokens
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
if (rope_cache) {
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
}
}
}
// crop output on last layer
// residual connection for attention output
ggml_tensor * ffn_inp;
if (rope_cache) {
ffn_inp = ggml_add(ctx0, cur, inpSA);
cb(ffn_inp, "ffn_inp", il);
} else {
ffn_inp = cur;
}
if ((uint32_t) il < hparams.n_layer_dense_lead) {
// dense FFN
cur = llm_build_ffn(ctx0, lctx, model.layers[il].ffn_norm, ffn_inp,
model.layers[il].ffn_up, NULL, NULL,
model.layers[il].ffn_gate, NULL, NULL,
model.layers[il].ffn_down, NULL, NULL,
NULL,
LLM_FFN_SILU, LLM_FFN_PAR, cb, il, gf, true);
cb(cur, "ffn_out", il);
} else {
cur = llm_build_std_moe_ffn(ctx0, lctx, model.layers[il].ffn_norm, ffn_inp,
model.layers[il].ffn_gate_inp, model.layers[il].ffn_gate_inp_b,
model.layers[il].ffn_up_exps, model.layers[il].ffn_up_exps_b,
model.layers[il].ffn_gate_exps, model.layers[il].ffn_gate_exps_b,
model.layers[il].ffn_down_exps, model.layers[il].ffn_down_exps_b,
model.layers[il].ffn_exp_probs_b,
model.layers[il].ffn_up_shexp, nullptr, // we don't have shared expert biases?
model.layers[il].ffn_gate_shexp, nullptr,
model.layers[il].ffn_down_shexp, nullptr,
n_expert, n_expert_used,
LLM_FFN_SILU, hparams.expert_weights_norm, true, hparams.expert_weights_scale,
(llm_expert_gating_func_type) hparams.expert_gating_func,
LLM_FFN_SILU, cb, il, gf, true, model.layers[il].ffn_up_gate_exps);
}
// residual and context vector
//cur = ggml_add(ctx0, cur, ffn_inp);
cur = lctx.cvec.apply_to(ctx0, cur, il);
cb(cur, "l_out", il);
// prepare next layer input
inpL = cur;
}
cur = inpL;
// crop output on last layer
// residual connection for attention output
ggml_tensor * ffn_inp;
if (rope_cache) {
ffn_inp = ggml_add(ctx0, cur, inpSA);
cb(ffn_inp, "ffn_inp", il);
} else {
ffn_inp = cur;
}
if ((uint32_t) il < hparams.n_layer_dense_lead) {
// dense FFN
cur = llm_build_ffn(ctx0, lctx, model.layers[il].ffn_norm, ffn_inp,
model.layers[il].ffn_up, NULL, NULL,
model.layers[il].ffn_gate, NULL, NULL,
model.layers[il].ffn_down, NULL, NULL,
NULL,
LLM_FFN_SILU, LLM_FFN_PAR, cb, il, gf, true);
cb(cur, "ffn_out", il);
} else {
cur = llm_build_std_moe_ffn(ctx0, lctx, model.layers[il].ffn_norm, ffn_inp,
model.layers[il].ffn_gate_inp, model.layers[il].ffn_gate_inp_b,
model.layers[il].ffn_up_exps, model.layers[il].ffn_up_exps_b,
model.layers[il].ffn_gate_exps, model.layers[il].ffn_gate_exps_b,
model.layers[il].ffn_down_exps, model.layers[il].ffn_down_exps_b,
model.layers[il].ffn_exp_probs_b,
model.layers[il].ffn_up_shexp, nullptr, // we don't have shared expert biases?
model.layers[il].ffn_gate_shexp, nullptr,
model.layers[il].ffn_down_shexp, nullptr,
n_expert, n_expert_used,
LLM_FFN_SILU, hparams.expert_weights_norm, true, hparams.expert_weights_scale,
(llm_expert_gating_func_type) hparams.expert_gating_func,
LLM_FFN_SILU, cb, il, gf, true, model.layers[il].ffn_up_gate_exps);
}
// residual and context vector
//cur = ggml_add(ctx0, cur, ffn_inp);
cur = lctx.cvec.apply_to(ctx0, cur, il);
cb(cur, "l_out", il);
// prepare next layer input
inpL = cur;
// lm head
cur = build_output(lctx, ctx0, cur, model.output, model.output_norm, cb);
cb(cur, "result_output", -1);
}
cur = inpL;
// lm head
cur = build_output(lctx, ctx0, cur, model.output, model.output_norm, cb);
cb(cur, "result_output", -1);
ggml_build_forward_expand(gf, cur);
return gf;
}
// Builds the NextN/MTP (Multi-Token Prediction) tail graph for GLM-4.x MoE:
// RMS-normalizes the draft token's embedding and the previous pass's hidden
// state, concatenates them, projects through eh_proj, runs one extra
// transformer block (self-attention + MoE FFN with shared expert), and
// produces draft logits through the shared-head (or main) LM head.
//
// mtp_layer       - weights of the NextN layer (occupies the last layer slot)
// prev_embeddings - hidden states produced by the preceding model pass
// n_embd_head     - per-head embedding size, used only for attention scaling
// gf              - graph that KV-cache / attention ops are expanded into
// inp_pos         - token positions; used when no rope_cache is supplied
// rope_cache      - optional precomputed RoPE cache (enables ggml_rope_fast)
struct ggml_tensor * llm_build_context::build_mtp_tail(
    const llama_layer & mtp_layer,
    struct ggml_tensor * prev_embeddings,
    int64_t n_embd_head,
    struct ggml_cgraph * gf,
    struct ggml_tensor * inp_pos,
    struct ggml_tensor * rope_cache
) {
    // The NextN block is addressed as the model's last layer; its KV-cache
    // entries and callback labels use this layer index.
    const int il = hparams.n_layer - 1;

    struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
    struct ggml_tensor * inp_out_ids = build_inp_out_ids();

    // If nextn.embed_tokens is missing (GLM-4.6), use model.tok_embd
    ggml_tensor * mtp_embd_weights = mtp_layer.nextn.embed_tokens;
    if (mtp_embd_weights == nullptr) {
        mtp_embd_weights = model.tok_embd;
    }
    ggml_tensor * token_emb = build_inp_embd_mtp(mtp_embd_weights);

    // Normalize token embedding and incoming hidden state separately, then
    // concatenate along the embedding dimension (dim 0) and project back to
    // the model width via eh_proj.
    ggml_tensor * token_emb_norm = llm_build_norm(ctx0, token_emb, hparams, mtp_layer.nextn.enorm, NULL, LLM_NORM_RMS, cb, il);
    ggml_tensor * hidden_state_norm = llm_build_norm(ctx0, prev_embeddings, hparams, mtp_layer.nextn.hnorm, NULL, LLM_NORM_RMS, cb, il);
    ggml_tensor * combined = ggml_concat(ctx0, token_emb_norm, hidden_state_norm, 0);
    cb(combined, "mtp_concat", il);
    ggml_tensor* cur = llm_build_lora_mm(lctx, ctx0, mtp_layer.nextn.eh_proj, combined);

    struct ggml_tensor * inpSA = cur; // residual input for the attention block

    cur = llm_build_norm(ctx0, cur, hparams, mtp_layer.attn_norm, NULL, LLM_NORM_RMS, cb, il);
    cb(cur, "attn_norm", il);

    // Self-Attention
    {
        auto [Qcur, Kcur, Vcur] = llm_build_mul_mat_qkv(gf, cur,
            nullptr, nullptr, // wqkv, bqkv — fused QKV not used; separate Q/K/V weights below
            nullptr, nullptr, // wqk, bqk
            mtp_layer.wq, mtp_layer.bq,
            mtp_layer.wk, mtp_layer.bk,
            mtp_layer.wv, mtp_layer.bv,
            mtp_layer.attn_q_norm, mtp_layer.attn_k_norm,
            0.f, il);
        // RoPE: reuse the precomputed rotation cache when provided, otherwise
        // apply the full rotary embedding from the raw positions.
        if (rope_cache) {
            Qcur = ggml_rope_fast(ctx0, Qcur, rope_cache);
            Kcur = ggml_rope_fast(ctx0, Kcur, rope_cache);
        } else {
            Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
            Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
        }
        cb(Qcur, "Qcur", il);
        cb(Kcur, "Kcur", il);
        cb(Vcur, "Vcur", il);
        // KV Cache & Attention
        // NOTE(review): output projection uses model.layers[il].wo rather than
        // mtp_layer.wo; with il == n_layer - 1 these presumably alias the same
        // layer — confirm, and prefer mtp_layer.wo for symmetry if so.
        cur = llm_build_kv(ctx0, lctx, kv_self, gf,
            model.layers[il].wo, NULL,
            Kcur, Vcur, Qcur, KQ_mask,
            n_tokens, kv_head, n_kv,
            1.0f/sqrtf(float(n_embd_head)), cb, il);
    }

    // Residual connection around attention.
    ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
    cb(ffn_inp, "mtp_ffn_inp", il);

    cur = llm_build_norm(ctx0, ffn_inp, hparams, mtp_layer.attn_post_norm, NULL, LLM_NORM_RMS, cb, il);
    cb(cur, "attn_post_norm", il);

    // moe ffn for nextn block
    {
        // Routed Experts
        ggml_tensor * routed_out = llm_build_std_moe_ffn(ctx0, lctx,
            NULL, // Norm handled above
            cur,  // Input (Normed)
            mtp_layer.ffn_gate_inp, NULL,
            mtp_layer.ffn_up_exps, NULL,
            mtp_layer.ffn_gate_exps, NULL,
            mtp_layer.ffn_down_exps, NULL,
            mtp_layer.ffn_exp_probs_b,
            nullptr, nullptr, // shared-expert slots unused here; shared expert is built separately below
            nullptr, nullptr,
            nullptr, nullptr,
            n_expert, n_expert_used,
            LLM_FFN_SILU, hparams.expert_weights_norm, true, hparams.expert_weights_scale,
            (llm_expert_gating_func_type) hparams.expert_gating_func,
            LLM_FFN_SILU, cb, il, gf, true, mtp_layer.ffn_up_gate_exps);
        cb(routed_out, "ffn_moe_out", il);
        // Shared Expert FFN (dense path, always active)
        ggml_tensor * shared_out = llm_build_ffn(ctx0, lctx,
            NULL, // Norm handled above
            cur,  // Input
            mtp_layer.ffn_up_shexp, NULL, NULL,
            mtp_layer.ffn_gate_shexp, NULL, NULL,
            mtp_layer.ffn_down_shexp, NULL, NULL,
            NULL,
            LLM_FFN_SILU, LLM_FFN_PAR, cb, il, gf, true);
        cb(shared_out, "ffn_shexp_out", il);
        // Sum routed + shared expert outputs, then residual around the FFN.
        cur = ggml_add(ctx0, routed_out, shared_out);
        cb(cur, "ffn_out", il);
        cur = ggml_add(ctx0, cur, ffn_inp);
        cb(cur, "mtp_ffn_out_resid", il);
    }

    // Final norm before the head; optionally gather only the rows whose
    // outputs were requested.
    cur = llm_build_norm(ctx0, cur, hparams, mtp_layer.nextn.shared_head_norm, NULL, LLM_NORM_RMS, cb, il);
    if (inp_out_ids) {
        cur = ggml_get_rows(ctx0, cur, inp_out_ids);
    }
    // If nextn.shared_head_head is missing (GLM-4.6), use model.output (Main LM Head)
    ggml_tensor * mtp_head_weights = mtp_layer.nextn.shared_head_head;
    if (mtp_head_weights == nullptr) {
        mtp_head_weights = model.output;
    }
    cur = llm_build_lora_mm(lctx, ctx0, mtp_head_weights, cur);
    cb(cur, "result_output", -1);
    return cur;
}
ggml_cgraph * llm_build_context::build_bitnet() {
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
@@ -9836,7 +10001,8 @@ ggml_cgraph * llm_build_context::llama_build_graph(
}
// add on pooling layer
if (lctx.cparams.embeddings) {
if (lctx.cparams.mtp_op_type == MTP_OP_NONE && (lctx.cparams.embeddings ||
(lctx.model.hparams.nextn_predict_layers > 0 || lctx.model.mtp))) {
result = llm.append_pooling(result);
}
@@ -10178,3 +10344,7 @@ ggml_tensor * llm_build_context::build_std_attention(ggml_cgraph * gf, ggml_tens
return cur;
}
// Returns the number of NextN/MTP prediction layers declared by the model's
// hyperparameters (0 when the model carries no MTP head).
int32_t llama_model_n_nextn_layer(const llama_model * model) {
    const auto & hp = model->hparams;
    return hp.nextn_predict_layers;
}

View File

@@ -114,6 +114,8 @@ struct llm_build_context {
ggml_cgraph * build_defrag(const std::vector<uint32_t> & ids);
struct ggml_tensor * build_inp_embd_mtp(struct ggml_tensor * mtp_tok_embd);
ggml_tensor * build_inp_pos();
ggml_tensor * build_input_scale(int n_tokens);
@@ -430,4 +432,12 @@ llm_expert_gating_func_type gating_op,
bool is_multi = false);
static uint32_t llama_kv_qnext_state_slots(const llama_kv_cache & kv_self);
struct ggml_tensor * build_mtp_tail(
const struct llama_layer & mtp_layer,
struct ggml_tensor * prev_embeddings,
int64_t n_embd_head,
struct ggml_cgraph * gf,
struct ggml_tensor * inp_pos,
struct ggml_tensor * rope_cache
);
};

View File

@@ -191,6 +191,8 @@ struct llama_context {
ggml_abort_callback abort_callback = nullptr;
void * abort_callback_data = nullptr;
const float * draft_input_hidden_state = nullptr;
// input tensors
struct ggml_tensor * inp_tokens; // I32 [n_batch]
struct ggml_tensor * inp_embd; // F32 [n_embd, n_batch]
@@ -209,6 +211,7 @@ struct llama_context {
struct ggml_tensor * inp_embd_enc; // F32 [n_embd, n_outputs_enc]
struct ggml_tensor * inp_KQ_mask_cross; // F32 [n_outputs_enc, n_batch]
struct ggml_tensor * inp_scale = nullptr; // F32 [n_tokens]
struct ggml_tensor * inp_mtp_states = nullptr;
ggml_backend_t ggml_backend_by_name(const char * name);
@@ -225,5 +228,7 @@ struct llama_context {
std::vector<CacheCopy> cache_copies;
bool update_cache_copies();
bool prepare_mtp_graph_inputs(
struct llama_context & lctx);
void set_mtp_op_type(llama_mtp_op_type value);
};

View File

@@ -45,9 +45,11 @@ struct llama_cparams {
bool scheduler_async;
int min_experts;
float thresh_experts;
bool mtp;
enum ggml_type reduce_type;
enum llama_pooling_type pooling_type;
enum llama_mtp_op_type mtp_op_type;
ggml_backend_sched_eval_callback cb_eval;
void * cb_eval_user_data;

View File

@@ -903,6 +903,12 @@ void llm_load_hparams(
}
// NextN/MTP parameters
if (model.mtp) {
hparams.n_layer_kv_from_start = hparams.n_layer;
}
else {
hparams.n_layer_kv_from_start = hparams.n_layer - hparams.nextn_predict_layers;
}
ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.nextn_predict_layers, false);
switch (hparams.n_layer) {

View File

@@ -2278,9 +2278,12 @@ bool create_tensors_helper::create_glm4_moe_tensors(const LLM_TN & tn) {
ggml_context * ctx_split = ctx_for_layer_split(i);
int flags = 0;
if (hparams.nextn_predict_layers > 0 && static_cast<uint32_t>(i) >= n_layer - hparams.nextn_predict_layers) {
// skip all tensors in the NextN layers
flags |= llama_model_loader::TENSOR_SKIP;
// Skip loading MTP layers if the feature is disabled
if (!model.mtp) {
if (hparams.nextn_predict_layers > 0 && static_cast<uint32_t>(i) >= n_layer - hparams.nextn_predict_layers) {
// skip all tensors in the NextN layers
flags |= llama_model_loader::TENSOR_SKIP;
}
}
auto & layer = model.layers[i];
@@ -3481,7 +3484,8 @@ bool create_tensors_helper::create_tensors() {
throw std::runtime_error("unknown architecture");
}
if (model.split_mode == LLAMA_SPLIT_MODE_GRAPH || model.split_mode == LLAMA_SPLIT_MODE_ATTN) {
const int n_layer = model.layers.size() - model.hparams.nextn_predict_layers;
const int n_layer = model.mtp ? model.layers.size()
: model.layers.size() - model.hparams.nextn_predict_layers;
LLAMA_LOG_INFO("================================ max_gpu = %d\n", model.max_gpu);
std::vector<size_t> mem_used(model.splits.size(), 0);
const auto & hparams = model.hparams;

View File

@@ -374,6 +374,8 @@ struct llama_model {
int max_gpu = 0; // max. number of GPUs to use per layer for split mode "graph"
int n_gpu_layers;
bool mtp; // use mtp if is supported by the Model
std::vector<rpc_device> rpc_servers;
std::vector<int32_t> devices;

View File

@@ -546,6 +546,7 @@ struct llama_context::Prev {
int all_seq_id;
int n_outputs;
int n_kv;
llama_mtp_op_type mtp_op_type;
ggml_cgraph * graph;
};
@@ -563,11 +564,13 @@ bool llama_context::can_reuse_graph(const llama_batch & u_batch) {
kv_self.head > 0 &&
kv_self.n == prev->n_kv &&
n_outputs == prev->n_outputs &&
cparams.mtp_op_type == prev->mtp_op_type &&
update_cache_copies();
}
bool llama_context::update_cache_copies() {
int n_layer = model.hparams.n_layer - model.hparams.nextn_predict_layers; //cache_copies.size()/2;
const int n_layer = model.mtp ? model.hparams.n_layer
: model.hparams.n_layer - model.hparams.nextn_predict_layers; //cache_copies.size()/2;
auto layer_has_attention_kv = [&](int il) {
return !((model.arch == LLM_ARCH_QWEN3NEXT || model.arch == LLM_ARCH_QWEN35MOE) && model.hparams.is_recurrent(il));
};
@@ -638,6 +641,12 @@ llama_context::llama_context(const llama_model & model)
}
}
// Selects the MTP operation mode (none / warmup / update-accepted / draft)
// for subsequent decode calls. The value is stored in cparams and is
// consulted by graph building, KV-cache slot search, and output extraction.
void llama_context::set_mtp_op_type(llama_mtp_op_type value) {
    LLAMA_LOG_DEBUG("%s: value = %d\n", __func__, value);
    cparams.mtp_op_type = value;
}
llama_context::~llama_context() {
ggml_backend_sched_free(sched);
@@ -716,7 +725,8 @@ static bool llama_kv_cache_init(
const struct llama_hparams & hparams = model.hparams;
const int64_t n_layer = hparams.n_layer - hparams.nextn_predict_layers;
const int64_t n_layer = model.mtp ? hparams.n_layer
: hparams.n_layer - hparams.nextn_predict_layers;
cache.has_shift = false;
@@ -993,7 +1003,8 @@ static bool llama_kv_cache_init(
// to the first cell of the slot.
static bool llama_kv_cache_find_slot(
struct llama_kv_cache & cache,
const struct llama_batch & batch) {
const struct llama_batch & batch,
enum llama_mtp_op_type op_type) {
const uint32_t n_tokens = batch.n_tokens;
if (cache.recurrent) {
@@ -1044,6 +1055,45 @@ static bool llama_kv_cache_find_slot(
}
// otherwise, one cell per token.
bool is_mtp_special_op = (op_type == MTP_OP_WARMUP ||
op_type == MTP_OP_UPDATE_ACCEPTED);
if (is_mtp_special_op) {
const llama_pos target_pos = batch.pos[0];
const llama_seq_id target_seq = batch.seq_id[0][0];
bool found = false;
if (cache.head < cache.size &&
cache.cells[cache.head].pos == target_pos &&
cache.cells[cache.head].has_seq_id(target_seq)) {
found = true;
}
else {
for (uint32_t i = 0; i < cache.size; ++i) {
if (cache.cells[i].pos == target_pos &&
cache.cells[i].has_seq_id(target_seq)) {
cache.head = i;
found = true;
break;
}
}
}
if (!found) {
LLAMA_LOG_ERROR("%s: MTP Update failed - slot for seq %d pos %d not found\n",
__func__, target_seq, target_pos);
return false;
}
if (cache.head + n_tokens > cache.size) {
LLAMA_LOG_ERROR("%s: MTP Update out of bounds\n", __func__);
return false;
}
return true;
}
if (n_tokens > cache.size) {
LLAMA_LOG_ERROR("%s: n_tokens=%d > cache.size=%d\n", __func__, n_tokens, cache.size);
return false;
@@ -1893,6 +1943,7 @@ static bool llm_load_tensors(
const float * tensor_split,
bool use_mlock,
bool validate_quants,
bool mtp,
llama_progress_callback progress_callback,
void * progress_callback_user_data) {
model.t_start_us = ggml_time_us();
@@ -1921,6 +1972,7 @@ static bool llm_load_tensors(
model.main_gpu = main_gpu;
model.max_gpu = max_gpu;
model.n_gpu_layers = n_gpu_layers;
model.mtp = mtp;
const int n_layer = hparams.n_layer;
const int i_gpu_start = std::max((int) hparams.n_layer - n_gpu_layers, (int) 0);
@@ -2300,7 +2352,7 @@ static int llama_model_load(const std::string & fname, llama_model & model, llam
if (!llm_load_tensors(
ml, model, params.n_gpu_layers, params.mla, params.split_mode, params.main_gpu, params.max_gpu, params.tensor_split,
params.use_mlock, params.validate_quants,
params.use_mlock, params.validate_quants, params.mtp,
params.progress_callback, params.progress_callback_user_data
)) {
return -2;
@@ -2969,8 +3021,9 @@ static size_t llama_output_reserve(llama_context & lctx, size_t n_outputs) {
const auto n_embd = hparams.n_embd;
// TODO: use a per-batch flag for logits presence instead
const bool has_logits = !cparams.embeddings;
const bool has_embd = lctx.is_encoding || (cparams.embeddings && (cparams.pooling_type == LLAMA_POOLING_TYPE_NONE));
const bool has_mtp = lctx.model.hparams.nextn_predict_layers > 0 && lctx.cparams.mtp;
const bool has_logits = !cparams.embeddings || has_mtp;
const bool has_embd = lctx.is_encoding || (cparams.embeddings && (cparams.pooling_type == LLAMA_POOLING_TYPE_NONE)) || has_mtp;
const size_t logits_size = has_logits ? n_vocab*n_outputs_max : 0;
const size_t embd_size = has_embd ? n_embd*n_outputs_max : 0;
@@ -3049,6 +3102,24 @@ static void llama_graph_compute(
// fprintf(stderr, "splits: %d\n", ggml_backend_sched_get_n_splits(lctx.sched));
}
// Uploads the hidden-state input required by an MTP graph into
// lctx.inp_mtp_states. WARMUP and UPDATE_ACCEPTED passes consume the hidden
// states the main model just produced (lctx.embd); other MTP ops (drafting)
// consume the buffer set via llama_set_draft_input_hidden_state().
// Returns false (with an error log) when no source buffer is available.
static bool prepare_mtp_graph_inputs(struct llama_context & lctx) {
    ggml_tensor * dst = lctx.inp_mtp_states;
    const float * src = nullptr;
    if (lctx.cparams.mtp_op_type == MTP_OP_WARMUP || lctx.cparams.mtp_op_type == MTP_OP_UPDATE_ACCEPTED) {
        src = lctx.embd;
    } else {
        src = lctx.draft_input_hidden_state;
    }
    if (!src) {
        LLAMA_LOG_ERROR("%s: Source hidden state is null\n", __func__);
        return false;
    }
    // NOTE(review): copies ggml_nbytes(dst) bytes from src — assumes the
    // source buffer is at least as large as the input tensor; confirm the
    // callers size lctx.embd / the draft buffer accordingly.
    ggml_backend_tensor_set(dst, src, 0, ggml_nbytes(dst));
    return true;
}
// decode a batch of tokens by evaluating the transformer
//
// - lctx: llama context
@@ -3260,7 +3331,7 @@ static int llama_decode_internal(
kv_self.head = 0;
}
if (!llama_kv_cache_find_slot(kv_self, u_batch)) {
if (!llama_kv_cache_find_slot(kv_self, u_batch, cparams.mtp_op_type)) {
return 1;
}
@@ -3322,37 +3393,50 @@ static int llama_decode_internal(
#endif
if (u_batch.n_tokens == 1 && u_batch.embd == nullptr && lctx.cparams.graph_reuse) {
lctx.prev = std::make_unique<llama_context::Prev>(llama_context::Prev{
(int)u_batch.all_seq_id, (int)lctx.n_outputs, (int)lctx.kv_self.n, gf});
(int)u_batch.all_seq_id, (int)lctx.n_outputs, (int)lctx.kv_self.n,
cparams.mtp_op_type, gf});
}
} else {
//printf("Reusing graph\n");
gf = lctx.prev->graph;
}
if (cparams.mtp_op_type != MTP_OP_NONE) {
if (!prepare_mtp_graph_inputs(lctx)) {
return GGML_STATUS_FAILED;
}
}
// the output is always the last tensor in the graph
struct ggml_tensor * res = gf->nodes[gf->n_nodes - 1];
struct ggml_tensor * embd = gf->nodes[gf->n_nodes - 2];
struct ggml_tensor * embd = nullptr;
if (lctx.n_outputs == 0) {
// no output
res = nullptr;
embd = nullptr;
} else if (cparams.embeddings) {
res = nullptr; // do not extract logits for embedding case
embd = nullptr;
for (int i = gf->n_nodes - 1; i >= 0; --i) {
if (strcmp(gf->nodes[i]->name, "result_embd_pooled") == 0) {
embd = gf->nodes[i];
break;
res = nullptr;
}
else {
const bool has_mtp = lctx.model.hparams.nextn_predict_layers > 0 && lctx.model.mtp;
if (cparams.embeddings || has_mtp) {
for (int i = gf->n_nodes - 1; i >= 0; --i) {
if (strcmp(gf->nodes[i]->name, "result_embd_pooled") == 0) {
embd = gf->nodes[i];
break;
}
if (strcmp(gf->nodes[i]->name, "result_norm") == 0) {
embd = gf->nodes[i];
}
}
}
if (cparams.embeddings && lctx.model.hparams.nextn_predict_layers == 0) {
res = nullptr; // do not extract logits for embedding case
} else {
if (!embd) { // do not extract embeddings when not needed
GGML_ASSERT(strcmp(res->name, "result_output") == 0 && "missing result_output tensor");
}
}
GGML_ASSERT(embd != nullptr && "missing embeddings tensor");
} else {
embd = nullptr; // do not extract embeddings when not needed
GGML_ASSERT(strcmp(res->name, "result_output") == 0 && "missing result_output tensor");
}
// LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs);
#if IK_PRINT_TIMING == 1
tim1 = ggml_time_us();
#endif
@@ -3392,17 +3476,21 @@ static int llama_decode_internal(
#if IK_PRINT_TIMING
tim1 = ggml_time_us();
#endif
ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(lctx.sched, res);
GGML_ASSERT(backend_res != nullptr);
GGML_ASSERT(lctx.logits != nullptr);
// Do not process logits if MTP is only updating the KV cache.
if (cparams.mtp_op_type != MTP_OP_WARMUP &&
cparams.mtp_op_type != MTP_OP_UPDATE_ACCEPTED) {
ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(lctx.sched, res);
GGML_ASSERT(backend_res != nullptr);
GGML_ASSERT(lctx.logits != nullptr);
float * logits_out = lctx.logits + n_outputs_prev*n_vocab;
const int32_t n_outputs_new = lctx.n_outputs;
float * logits_out = lctx.logits + n_outputs_prev*n_vocab;
const int32_t n_outputs_new = lctx.n_outputs;
if (n_outputs_new) {
GGML_ASSERT( n_outputs_prev + n_outputs_new <= n_outputs);
GGML_ASSERT((n_outputs_prev + n_outputs_new)*n_vocab <= (int64_t) lctx.logits_size);
ggml_backend_tensor_get_async(backend_res, res, logits_out, 0, n_outputs_new*n_vocab*sizeof(float));
if (n_outputs_new) {
GGML_ASSERT( n_outputs_prev + n_outputs_new <= n_outputs);
GGML_ASSERT((n_outputs_prev + n_outputs_new)*n_vocab <= (int64_t) lctx.logits_size);
ggml_backend_tensor_get_async(backend_res, res, logits_out, 0, n_outputs_new*n_vocab*sizeof(float));
}
}
#if IK_PRINT_TIMING
tim2 = ggml_time_us();
@@ -3411,7 +3499,7 @@ static int llama_decode_internal(
}
// extract embeddings
if (embd) {
if (embd && cparams.mtp_op_type == MTP_OP_NONE) {
#if IK_PRINT_TIMING
tim1 = ggml_time_us();
#endif
@@ -3617,57 +3705,59 @@ static int llama_encode_internal(
// extract embeddings
if (embd) {
ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(lctx.sched, embd);
GGML_ASSERT(backend_embd != nullptr);
if (cparams.mtp_op_type == MTP_OP_NONE) {
ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(lctx.sched, embd);
GGML_ASSERT(backend_embd != nullptr);
if (llama_model_has_decoder(&lctx.model)) {
lctx.embd_enc.resize(n_tokens*n_embd);
float * embd_out = lctx.embd_enc.data();
if (llama_model_has_decoder(&lctx.model)) {
lctx.embd_enc.resize(n_tokens*n_embd);
float * embd_out = lctx.embd_enc.data();
ggml_backend_tensor_get_async(backend_embd, embd, embd_out, 0, n_tokens*n_embd*sizeof(float));
ggml_backend_tensor_get_async(backend_embd, embd, embd_out, 0, n_tokens*n_embd*sizeof(float));
// remember the sequence ids used during the encoding - needed for cross attention later
lctx.seq_ids_enc.resize(n_tokens);
for (uint32_t i = 0; i < n_tokens; i++) {
for (int s = 0; s < batch.n_seq_id[i]; s++) {
llama_seq_id seq_id = batch.seq_id[i][s];
lctx.seq_ids_enc[i].insert(seq_id);
}
}
} else {
GGML_ASSERT(lctx.embd != nullptr);
switch (cparams.pooling_type) {
case LLAMA_POOLING_TYPE_NONE:
{
// extract token embeddings
GGML_ASSERT(lctx.embd != nullptr);
float * embd_out = lctx.embd;
GGML_ASSERT(n_tokens*n_embd <= (int64_t) lctx.embd_size);
ggml_backend_tensor_get_async(backend_embd, embd, embd_out, 0, n_tokens*n_embd*sizeof(float));
} break;
case LLAMA_POOLING_TYPE_MEAN:
case LLAMA_POOLING_TYPE_CLS:
case LLAMA_POOLING_TYPE_LAST:
{
// extract sequence embeddings
auto & embd_seq_out = lctx.embd_seq;
embd_seq_out.clear();
for (uint32_t i = 0; i < n_tokens; i++) {
const llama_seq_id seq_id = batch.seq_id[i][0];
if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
continue;
}
embd_seq_out[seq_id].resize(n_embd);
ggml_backend_tensor_get_async(backend_embd, embd, embd_seq_out[seq_id].data(), (n_embd*seq_id)*sizeof(float), n_embd*sizeof(float));
}
} break;
case LLAMA_POOLING_TYPE_UNSPECIFIED:
{
GGML_ABORT("unknown pooling type");
// remember the sequence ids used during the encoding - needed for cross attention later
lctx.seq_ids_enc.resize(n_tokens);
for (uint32_t i = 0; i < n_tokens; i++) {
for (int s = 0; s < batch.n_seq_id[i]; s++) {
llama_seq_id seq_id = batch.seq_id[i][s];
lctx.seq_ids_enc[i].insert(seq_id);
}
}
} else {
GGML_ASSERT(lctx.embd != nullptr);
switch (cparams.pooling_type) {
case LLAMA_POOLING_TYPE_NONE:
{
// extract token embeddings
GGML_ASSERT(lctx.embd != nullptr);
float * embd_out = lctx.embd;
GGML_ASSERT(n_tokens*n_embd <= (int64_t) lctx.embd_size);
ggml_backend_tensor_get_async(backend_embd, embd, embd_out, 0, n_tokens*n_embd*sizeof(float));
} break;
case LLAMA_POOLING_TYPE_MEAN:
case LLAMA_POOLING_TYPE_CLS:
case LLAMA_POOLING_TYPE_LAST:
{
// extract sequence embeddings
auto & embd_seq_out = lctx.embd_seq;
embd_seq_out.clear();
for (uint32_t i = 0; i < n_tokens; i++) {
const llama_seq_id seq_id = batch.seq_id[i][0];
if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
continue;
}
embd_seq_out[seq_id].resize(n_embd);
ggml_backend_tensor_get_async(backend_embd, embd, embd_seq_out[seq_id].data(), (n_embd*seq_id)*sizeof(float), n_embd*sizeof(float));
}
} break;
case LLAMA_POOLING_TYPE_UNSPECIFIED:
{
GGML_ABORT("unknown pooling type");
}
}
}
}
}
@@ -4223,6 +4313,7 @@ struct llama_model_params llama_model_default_params() {
/*.validate_quants =*/ false,
/*.merge_qkv =*/ false,
/*.merge_up_gate_exps =*/ false,
/*.mtp =*/ false,
};
#ifdef GGML_USE_METAL
@@ -4278,6 +4369,8 @@ struct llama_context_params llama_context_default_params() {
/*.split_mode_graph_scheduling =*/ false,
// /*.split_mode_f16 =*/ true,
/*.scheduler_async =*/ false,
/*.mtp =*/ false,
/*.mtp_op_type =*/ MTP_OP_NONE,
/*.abort_callback =*/ nullptr,
/*.abort_callback_data =*/ nullptr,
/*.offload_policy =*/ nullptr,
@@ -4648,6 +4741,7 @@ struct llama_context * llama_init_from_model(
cparams.min_experts = params.min_experts;
cparams.thresh_experts = params.thresh_experts;
cparams.cuda_params = params.cuda_params;
cparams.mtp = params.mtp;
cparams.reduce_type = params.type_reduce;
cparams.pooling_type = params.pooling_type;
@@ -4725,6 +4819,12 @@ struct llama_context * llama_init_from_model(
}
}
if (model->arch != LLM_ARCH_GLM4_MOE && cparams.mtp != 0) {
cparams.mtp = 0;
}
cparams.mtp_op_type = params.mtp_op_type;
LLAMA_LOG_INFO("%s: n_ctx = %u\n", __func__, cparams.n_ctx);
LLAMA_LOG_INFO("%s: n_batch = %u\n", __func__, cparams.n_batch);
LLAMA_LOG_INFO("%s: n_ubatch = %u\n", __func__, cparams.n_ubatch);
@@ -6058,7 +6158,7 @@ struct llama_data_read {
batch.n_seq_id[i] = 1;
batch.seq_id[i][0] = dest_seq_id;
}
if (!llama_kv_cache_find_slot(kv_self, batch)) {
if (!llama_kv_cache_find_slot(kv_self, batch, ctx->cparams.mtp_op_type)) {
llama_batch_free(batch);
LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__);
return false;
@@ -7003,6 +7103,10 @@ int32_t llama_decode(
return ret;
}
// Public API: selects which MTP operation (warmup / update-accepted / draft /
// none) the next llama_decode call on this context performs.
void llama_set_mtp_op_type(llama_context * ctx, llama_mtp_op_type mtp_op_type) {
    ctx->set_mtp_op_type(mtp_op_type);
}
void llama_synchronize(struct llama_context * ctx) {
ggml_backend_sched_synchronize(ctx->sched);
@@ -8333,3 +8437,8 @@ void llama_set_offload_policy(struct llama_context * lctx, int op, bool on_or_of
printf("XXXXXXXXXXXXXXXXXXXXXXXXXXXX offload(%s) = %d\n", op_name, on_or_off);
ggml_backend_sched_set_op_offload(lctx->sched, ggml_op(op), on_or_off);
}
// Public API: points the context at an externally-owned hidden-state buffer
// used as the input for MTP draft passes. The context does not copy or free
// the buffer here; the caller must keep it alive (and unmodified) until the
// draft decode that consumes it has completed.
void llama_set_draft_input_hidden_state(struct llama_context * ctx, const float * hidden_state) {
    ctx->draft_input_hidden_state = hidden_state;
}