mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-02-23 14:44:09 +00:00
Add MTP decoding support for GLM-4.x MoE (#1270)
* wip: port MTP architecture. Ports the Multi-Token Prediction (MTP) architecture to the older `llama.cpp` codebase used by `ikllama`. Changes include: - Updating `llama_batch` to support `mtp_params`. - Modifying `llama_decode_internal` (and `encode`) to handle MTP operations (Warmup, Update, Draft). - Adding public APIs for MTP state management (`llama_set_draft_input_hidden_state`). - Adapting the embedding extraction logic to skip MTP update passes. * Refactors `server_slot` to support generic speculative decoding (MTP or Draft Model). * core: enable hybrid outputs (logits + embeddings) for MTP support * fix(mtp): correct KV-cache slot finding for updates * fix(mtp): persist hidden states to prevent context corruption during drafting * refactor(mtp): clean up unused code * fix(mtp): update server to the new function names * fix(mtp): fix graph and save hidden state * mtp: refactor integration, context params and kv cache search * mtp: fix hidden state extraction and speculative acceptance flow * server: fix MTP warmup for long prompts and reset token buffer * llama: refactor MTP operation state to context parameters * server: fix n_past calculation in MTP acceptance * llama: fix mtp enable flags * speculative: refactor MTP to use common_speculative interface * context: remove unused signatures * clip: fix deprecated enum-enum conversion warning * common: fix format string crash in help message * context: fix mtp activation logic
This commit is contained in:
committed by
GitHub
parent
cbf7fc7e2f
commit
09a88c9ae5
@@ -1463,6 +1463,14 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
|
||||
params.cuda_params = argv[i];
|
||||
return true;
|
||||
}
|
||||
if (arg == "-mtp" || arg == "--multi-token-prediction") {
|
||||
params.has_mtp = true;
|
||||
return true;
|
||||
}
|
||||
if (arg == "-no-mtp" || arg == "--no-multi-token-prediction") {
|
||||
params.has_mtp = false;
|
||||
return true;
|
||||
}
|
||||
if (arg == "-draft" || arg == "--draft-params") {
|
||||
CHECK_ARG
|
||||
params.speculative.params = argv[i];
|
||||
@@ -2475,6 +2483,8 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param
|
||||
options.push_back({ "*", "-hfr, --hf-repo REPO", "Hugging Face model repository (default: unused)" });
|
||||
options.push_back({ "*", "-hff, --hf-file FILE", "Hugging Face model file (default: unused)" });
|
||||
options.push_back({ "*", "-hft, --hf-token TOKEN", "Hugging Face access token (default: value from HF_TOKEN environment variable)" });
|
||||
options.push_back({ "*", "-mtp, --multi-token-prediction", "whether to use multi-token-prediction (if supported) (default: %s)", params.has_mtp ? "true" : "false" });
|
||||
options.push_back({ "*", "-no-mtp, --no-multi-token-prediction", "whether to use multi-token-prediction (if supported) (default: %s)", !params.has_mtp ? "true" : "false" });
|
||||
options.push_back({ "*", "--draft-max, --draft, --draft-n N",
|
||||
"number of tokens to draft for speculative decoding (default: %d)", params.speculative.n_max });
|
||||
options.push_back({ "*", "--draft-min, --draft-n-min N", "minimum number of draft tokens to use for speculative decoding" });
|
||||
@@ -3207,6 +3217,7 @@ struct llama_model_params common_model_params_to_llama(const gpt_params & params
|
||||
mparams.validate_quants = params.validate_quants;
|
||||
mparams.merge_qkv = params.merge_qkv;
|
||||
mparams.merge_up_gate_exps = params.merge_up_gate_exps;
|
||||
mparams.mtp = params.has_mtp;
|
||||
if (params.kv_overrides.empty()) {
|
||||
mparams.kv_overrides = NULL;
|
||||
} else {
|
||||
@@ -3329,6 +3340,8 @@ struct llama_context_params common_context_params_to_llama(const gpt_params & pa
|
||||
cparams.thresh_experts = params.thresh_experts;
|
||||
cparams.only_active_experts = params.only_active_exps;
|
||||
cparams.max_extra_alloc = params.max_extra_alloc_MiB;
|
||||
cparams.mtp = params.has_mtp;
|
||||
cparams.mtp_op_type = MTP_OP_NONE;
|
||||
|
||||
cparams.type_k = kv_cache_type_from_str(params.cache_type_k);
|
||||
cparams.type_v = kv_cache_type_from_str(params.cache_type_v);
|
||||
|
||||
@@ -139,6 +139,7 @@ thinking_tokens thinking_tokens_from_string(const std::string& format);
|
||||
enum common_speculative_type {
|
||||
COMMON_SPECULATIVE_TYPE_NONE, // no speculative decoding
|
||||
COMMON_SPECULATIVE_TYPE_DRAFT, // draft model
|
||||
COMMON_SPECULATIVE_TYPE_MTP, // MTP model
|
||||
COMMON_SPECULATIVE_TYPE_EAGLE3, // eagle draft model
|
||||
COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE, // simple self-speculative decoding
|
||||
COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K, // self-speculative decoding with n-gram keys only
|
||||
@@ -356,6 +357,7 @@ struct gpt_params {
|
||||
bool split_mode_graph_scheduling = false; // if true, force split mode graph scheduling
|
||||
//bool split_mode_f16 = true; // if true, intermediate results will be cast to f16 before copying to other GPUs to perform reduce ops
|
||||
bool scheduler_async = false; // if true, in split mode graph the scheduler will use multiple threads to evaluate the graph
|
||||
bool has_mtp = false; // enable MTP if supported by the model
|
||||
|
||||
std::string cache_type_k = "f16"; // KV cache data type for the K
|
||||
std::string cache_type_v = "f16"; // KV cache data type for the V
|
||||
|
||||
@@ -20,6 +20,7 @@
|
||||
const std::vector<enum common_speculative_type> common_speculative_types = {
|
||||
COMMON_SPECULATIVE_TYPE_NONE,
|
||||
COMMON_SPECULATIVE_TYPE_DRAFT,
|
||||
COMMON_SPECULATIVE_TYPE_MTP,
|
||||
COMMON_SPECULATIVE_TYPE_EAGLE3,
|
||||
COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE,
|
||||
COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K,
|
||||
@@ -31,6 +32,7 @@ const std::vector<enum common_speculative_type> common_speculative_types = {
|
||||
const std::map<std::string, enum common_speculative_type> common_speculative_type_from_name_map = {
|
||||
{"none", COMMON_SPECULATIVE_TYPE_NONE},
|
||||
{"draft", COMMON_SPECULATIVE_TYPE_DRAFT},
|
||||
{"mtp", COMMON_SPECULATIVE_TYPE_MTP},
|
||||
{"eagle3", COMMON_SPECULATIVE_TYPE_EAGLE3},
|
||||
{"ngram_simple", COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE},
|
||||
{"ngram_map_k", COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K},
|
||||
@@ -144,6 +146,58 @@ struct common_speculative_state {
|
||||
virtual void accept(uint16_t n_accepted) = 0;
|
||||
};
|
||||
|
||||
struct common_speculative_state_mtp : public common_speculative_state {
|
||||
llama_context * ctx_tgt;
|
||||
common_sampler * smpl;
|
||||
|
||||
common_speculative_state_mtp(
|
||||
enum common_speculative_type type,
|
||||
llama_context * ctx_tgt)
|
||||
: common_speculative_state(type)
|
||||
, ctx_tgt(ctx_tgt)
|
||||
{
|
||||
struct common_params_sampling params;
|
||||
params.samplers_sequence = {
|
||||
llama_sampler_type::DIST,
|
||||
};
|
||||
smpl = common_sampler_init(llama_get_model(ctx_tgt), params);
|
||||
}
|
||||
|
||||
~common_speculative_state_mtp() override {
|
||||
common_sampler_free(smpl);
|
||||
}
|
||||
|
||||
void begin(const llama_tokens & prompt) override {
|
||||
GGML_UNUSED(prompt);
|
||||
}
|
||||
|
||||
void draft(
|
||||
const common_params_speculative & params,
|
||||
const llama_tokens & prompt_tgt,
|
||||
llama_token id_last,
|
||||
llama_tokens & result) override {
|
||||
|
||||
int32_t n_past = (int32_t)prompt_tgt.size();
|
||||
|
||||
llama_seq_id seq_id = 0;
|
||||
|
||||
result = mtp_speculative_gen_draft(
|
||||
smpl,
|
||||
ctx_tgt,
|
||||
params.n_max,
|
||||
params.p_min,
|
||||
id_last,
|
||||
n_past,
|
||||
seq_id
|
||||
);
|
||||
}
|
||||
|
||||
void accept(uint16_t n_accepted) override {
|
||||
GGML_UNUSED(n_accepted);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
struct common_speculative_state_draft : public common_speculative_state {
|
||||
llama_context * ctx_tgt; // only used for retokenizing from ctx_dft
|
||||
llama_context * ctx_dft;
|
||||
@@ -760,6 +814,7 @@ std::string common_speculative_type_to_str(enum common_speculative_type type) {
|
||||
switch (type) {
|
||||
case COMMON_SPECULATIVE_TYPE_NONE: return "none";
|
||||
case COMMON_SPECULATIVE_TYPE_DRAFT: return "draft";
|
||||
case COMMON_SPECULATIVE_TYPE_MTP: return "mtp";
|
||||
case COMMON_SPECULATIVE_TYPE_EAGLE3: return "eagle3";
|
||||
case COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE: return "ngram_simple";
|
||||
case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K: return "ngram_map_k";
|
||||
@@ -828,6 +883,7 @@ common_speculative * common_speculative_init(
|
||||
{
|
||||
bool has_draft = !params.mparams_dft.path.empty();
|
||||
bool has_draft_eagle3 = false; // TODO PR-18039: if params.speculative.eagle3
|
||||
bool has_mtp = (params.type == COMMON_SPECULATIVE_TYPE_MTP);
|
||||
|
||||
bool has_ngram_cache = (params.type == COMMON_SPECULATIVE_TYPE_NGRAM_CACHE);
|
||||
bool has_ngram_simple = (params.type == COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE);
|
||||
@@ -867,6 +923,9 @@ common_speculative * common_speculative_init(
|
||||
if (has_ngram_cache) {
|
||||
configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_NGRAM_CACHE, params));
|
||||
}
|
||||
if (has_mtp) {
|
||||
configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_MTP, params));
|
||||
}
|
||||
if (has_draft) {
|
||||
configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_DRAFT, params));
|
||||
}
|
||||
@@ -890,6 +949,12 @@ common_speculative * common_speculative_init(
|
||||
));
|
||||
break;
|
||||
}
|
||||
case COMMON_SPECULATIVE_TYPE_MTP: {
|
||||
impls.push_back(std::make_unique<common_speculative_state_mtp>(config.type,
|
||||
/* .ctx_tgt = */ ctx_tgt
|
||||
));
|
||||
break;
|
||||
}
|
||||
case COMMON_SPECULATIVE_TYPE_EAGLE3: {
|
||||
impls.push_back(std::make_unique<common_speculative_state_eagle3>(config.type));
|
||||
break;
|
||||
@@ -1047,3 +1112,112 @@ void common_speculative_print_stats(const common_speculative * spec) {
|
||||
str_perf.c_str());
|
||||
}
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// MTP
|
||||
// ----------------------------------------------------------------------------
|
||||
std::vector<llama_token> mtp_speculative_gen_draft(
|
||||
struct common_sampler * smpl,
|
||||
struct llama_context * ctx,
|
||||
int n_draft,
|
||||
float p_min,
|
||||
llama_token id_last,
|
||||
int32_t n_past,
|
||||
llama_seq_id seq_id) {
|
||||
|
||||
llama_tokens drafts;
|
||||
drafts.reserve(n_draft);
|
||||
|
||||
if (!smpl) return drafts;
|
||||
|
||||
common_sampler_reset(smpl);
|
||||
|
||||
llama_batch mtp_batch = llama_batch_init(1, 0, 1);
|
||||
llama_set_mtp_op_type(ctx, MTP_OP_DRAFT_GEN);
|
||||
|
||||
llama_token current_input_id = id_last;
|
||||
int32_t current_n_past = n_past;
|
||||
|
||||
for (int i = 0; i < n_draft; ++i) {
|
||||
mtp_batch.n_tokens = 0;
|
||||
common_batch_add(mtp_batch, current_input_id, current_n_past, {seq_id}, true);
|
||||
|
||||
if (llama_decode(ctx, mtp_batch) != 0) {
|
||||
break;
|
||||
}
|
||||
|
||||
common_sampler_sample(smpl, ctx, 0, true);
|
||||
|
||||
const auto * cur_p = common_sampler_get_candidates(smpl, true);
|
||||
|
||||
if (!cur_p || cur_p->size == 0) {
|
||||
break;
|
||||
}
|
||||
|
||||
const llama_token id_next = cur_p->data[0].id;
|
||||
const float prob = cur_p->data[0].p;
|
||||
|
||||
common_sampler_accept(smpl, nullptr, id_next, true);
|
||||
|
||||
if (prob < p_min) {
|
||||
break;
|
||||
}
|
||||
|
||||
drafts.push_back(id_next);
|
||||
|
||||
current_input_id = id_next;
|
||||
current_n_past++;
|
||||
}
|
||||
llama_batch_free(mtp_batch);
|
||||
llama_set_mtp_op_type(ctx, MTP_OP_NONE);
|
||||
|
||||
// Purge the metadata for the draft tokens.
|
||||
// This prevents cache state corruption where two cells map to the same logical position.
|
||||
if (!drafts.empty()) {
|
||||
llama_kv_cache_seq_rm(ctx, seq_id, n_past, current_n_past);
|
||||
}
|
||||
|
||||
return drafts;
|
||||
}
|
||||
|
||||
|
||||
void mtp_update_kv_cache(struct llama_context * ctx, const llama_batch& batch, bool is_prompt_warmup) {
|
||||
if (batch.n_tokens == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
LOG_DBG("[MTP-UPDATE|%s] Updating %d tokens...\n", is_prompt_warmup ? "PROMPT_WARMUP" : "GEN_ACCEPTED", batch.n_tokens);
|
||||
|
||||
llama_batch mtp_batch = batch;
|
||||
if (is_prompt_warmup) {
|
||||
llama_set_mtp_op_type(ctx, MTP_OP_WARMUP);
|
||||
} else {
|
||||
llama_set_mtp_op_type(ctx, MTP_OP_UPDATE_ACCEPTED);
|
||||
}
|
||||
|
||||
for (int i = 0; i < mtp_batch.n_tokens; ++i) {
|
||||
mtp_batch.logits[i] = true;
|
||||
}
|
||||
llama_decode(ctx, mtp_batch);
|
||||
llama_set_mtp_op_type(ctx, MTP_OP_NONE);
|
||||
}
|
||||
|
||||
void mtp_accept_tokens(
|
||||
struct llama_context * ctx,
|
||||
const std::vector<llama_token> & ids,
|
||||
int32_t n_past_base,
|
||||
llama_seq_id seq_id
|
||||
) {
|
||||
if (ids.empty()) {
|
||||
return;
|
||||
}
|
||||
|
||||
llama_batch accepted_batch = llama_batch_init(ids.size(), 0, 1);
|
||||
for (size_t i = 0; i < ids.size(); ++i) {
|
||||
common_batch_add(accepted_batch, ids[i], n_past_base + i, { seq_id }, true);
|
||||
}
|
||||
|
||||
mtp_update_kv_cache(ctx, accepted_batch, false);
|
||||
|
||||
llama_batch_free(accepted_batch);
|
||||
}
|
||||
@@ -39,3 +39,22 @@ void common_speculative_accept(common_speculative * spec, uint16_t n_accepted);
|
||||
|
||||
// print statistics about the speculative decoding
|
||||
void common_speculative_print_stats(const common_speculative * spec);
|
||||
|
||||
// Generates speculative draft tokens using the Multi-Token Prediction (MTP) architecture.
|
||||
std::vector<llama_token> mtp_speculative_gen_draft(
|
||||
struct common_sampler * smpl,
|
||||
struct llama_context * ctx,
|
||||
int n_draft,
|
||||
float p_min,
|
||||
llama_token id_last,
|
||||
int32_t n_past,
|
||||
llama_seq_id seq_id);
|
||||
|
||||
void mtp_update_kv_cache(struct llama_context * ctx, const llama_batch& batch, bool is_prompt_warmup);
|
||||
|
||||
void mtp_accept_tokens(
|
||||
struct llama_context * ctx,
|
||||
const std::vector<llama_token> & ids,
|
||||
int32_t n_past_base,
|
||||
llama_seq_id seq_id
|
||||
);
|
||||
|
||||
@@ -35,7 +35,7 @@
|
||||
#include <array>
|
||||
#include <functional>
|
||||
|
||||
// Bilinear scaling combined with the align-corners flag. Each enumerator is converted
// to int before the bitwise OR so the expression does not combine two different
// enumeration types directly (that form triggers a deprecated enum-enum conversion warning).
#define DEFAULT_INTERPOLATION_MODE ((int)GGML_SCALE_MODE_BILINEAR | (int)GGML_SCALE_FLAG_ALIGN_CORNERS)
|
||||
|
||||
// TODO: allow to pass callback from user code
|
||||
struct clip_logger_state g_logger_state = {GGML_LOG_LEVEL_CONT, clip_log_callback_default, NULL};
|
||||
|
||||
@@ -152,12 +152,17 @@ bool server_context::load_model(const gpt_params& params_) {
|
||||
LOG_ERROR("failed to load draft model", { {"model", params_base.speculative.model} });
|
||||
return false;
|
||||
}
|
||||
|
||||
cparams_dft = common_context_params_to_llama(params_dft);
|
||||
|
||||
params_base.speculative.model_dft = model_dft;
|
||||
params_base.speculative.cparams_dft = cparams_dft;
|
||||
|
||||
}
|
||||
else if (params_base.has_mtp && llama_model_n_nextn_layer(model) == 0) {
|
||||
LOG_WARNING("WARNING: -mtp flag provided, but model has 0 NextN layers. MTP will be disabled.\n", {});
|
||||
params_base.has_mtp = false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
@@ -209,12 +214,35 @@ void server_context::init() {
|
||||
|
||||
slot.sparams = params_base.sparams;
|
||||
|
||||
if (params_base.has_mtp) {
|
||||
if (llama_model_n_nextn_layer(model) > 0) {
|
||||
SRV_INF("%s\n", "MTP detected, configuring for speculative decoding...");
|
||||
|
||||
params_base.speculative.type = COMMON_SPECULATIVE_TYPE_MTP;
|
||||
|
||||
slot.has_mtp = true;
|
||||
slot.params.speculative.type = COMMON_SPECULATIVE_TYPE_MTP;
|
||||
slot.params.speculative.n_min = 0;
|
||||
|
||||
slot.batch_spec = llama_batch_init(slot.params.speculative.n_max + 1, 0, 1);
|
||||
SLT_DBG(slot, "batch_spec contains %d tokens\n", slot.batch_spec.n_tokens);
|
||||
|
||||
SRV_INF("%s\n", "MTP needs embeddings on decode, enabling");
|
||||
llama_set_embeddings(ctx, true);
|
||||
}
|
||||
else {
|
||||
SRV_WRN("%s\n", "MTP enabled via flag, but model has 0 NextN layers. Disabling speculative.");
|
||||
params_base.speculative.type = COMMON_SPECULATIVE_TYPE_NONE;
|
||||
slot.has_mtp = false;
|
||||
}
|
||||
}
|
||||
|
||||
const bool can_spec = common_speculative_is_compat(ctx);
|
||||
if (!can_spec) {
|
||||
SRV_WRN("%s", "speculative decoding not supported by this context\n");
|
||||
}
|
||||
// try speculative decoding
|
||||
if (can_spec){
|
||||
if (can_spec) {
|
||||
slot.spec = common_speculative_init(params_base.speculative, slot.ctx);
|
||||
if (slot.spec) {
|
||||
if (mctx) {
|
||||
@@ -223,9 +251,14 @@ void server_context::init() {
|
||||
}
|
||||
SLT_INF(slot, "%s", "speculative decoding context initialized\n");
|
||||
} else {
|
||||
SLT_INF(slot, "%s", "speculative decoding context not initialized\n");
|
||||
if (slot.has_mtp) {
|
||||
SRV_ERR("%s", "failed to initialize MTP speculative context\n");
|
||||
} else {
|
||||
SLT_INF(slot, "%s", "speculative decoding context not initialized\n");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
slot.reset();
|
||||
|
||||
slots.push_back(std::move(slot));
|
||||
@@ -380,7 +413,7 @@ void server_slot::add_token_string(const completion_token_output& token) {
|
||||
}
|
||||
|
||||
bool server_slot::can_speculate() const {
|
||||
return !!spec;
|
||||
return (!!spec || has_mtp);
|
||||
}
|
||||
|
||||
int server_slot::get_n_draft_max() const {
|
||||
@@ -2533,6 +2566,15 @@ void server_context::add_sampled_tokens() {
|
||||
|
||||
const auto & params_spec = slot.params.speculative;
|
||||
|
||||
if (slot.has_mtp) {
|
||||
if (!slot.mtp_hidden_state.empty()) {
|
||||
llama_set_draft_input_hidden_state(ctx, slot.mtp_hidden_state.data());
|
||||
} else {
|
||||
LOG_ERROR("MTP hidden state is empty during speculation", {});
|
||||
llama_set_draft_input_hidden_state(ctx, llama_get_embeddings_ith(ctx, -1));
|
||||
}
|
||||
}
|
||||
|
||||
llama_tokens draft = common_speculative_draft(slot.spec, params_spec, cached_text_tokens, slot.sampled);
|
||||
|
||||
if (draft.size() > (size_t)n_draft_max) {
|
||||
@@ -2540,13 +2582,6 @@ void server_context::add_sampled_tokens() {
|
||||
draft.resize(n_draft_max);
|
||||
}
|
||||
|
||||
/*struct llama_speculative_params params_spec;
|
||||
params_spec.n_draft = n_draft_max;
|
||||
params_spec.n_reuse = llama_n_ctx(slot.ctx_dft) - slot.params.speculative.n_max;
|
||||
params_spec.p_min = slot.params.speculative.p_min;
|
||||
const llama_tokens& cached_text_tokens = slot.cache_tokens.get_text_tokens();
|
||||
llama_tokens draft = llama_speculative_gen_draft(slot.spec, params_spec, cached_text_tokens, slot.sampled);*/
|
||||
|
||||
// add the sampled token to the batch
|
||||
slot.i_batch_dft.push_back(batch.n_tokens);
|
||||
common_batch_add(batch, slot.sampled, slot.cache_tokens.pos_next(), { slot.id }, true);
|
||||
@@ -2759,6 +2794,8 @@ void server_context::batch_pending_prompt(const int32_t n_ubatch, const int32_t
|
||||
}
|
||||
|
||||
slot.n_past = 0;
|
||||
slot.n_buffer = 0;
|
||||
slot.token_buffer.clear();
|
||||
slot.n_prompt_tokens = prompt_tokens.size();
|
||||
|
||||
LOG_VERBOSE("prompt tokenized", {
|
||||
@@ -3092,6 +3129,25 @@ void server_context::speculative_decoding_accept() {
|
||||
|
||||
// the accepted tokens from the speculation
|
||||
const auto ids = common_sampler_sample_and_accept_n(slot.ctx_sampling, ctx, slot.i_batch_dft, slot.drafted);
|
||||
|
||||
if (slot.has_mtp) {
|
||||
const int n_embd = llama_model_n_embd(llama_get_model(ctx));
|
||||
if (!ids.empty()) {
|
||||
const float* emb = llama_get_embeddings_ith(ctx, ids.size() - 1);
|
||||
if (emb) {
|
||||
slot.mtp_hidden_state.resize(n_embd);
|
||||
memcpy(slot.mtp_hidden_state.data(), emb, n_embd * sizeof(float));
|
||||
}
|
||||
}
|
||||
else {
|
||||
llama_set_draft_input_hidden_state(ctx, llama_get_embeddings_ith(ctx, 0));
|
||||
}
|
||||
llama_set_draft_input_hidden_state(ctx, slot.mtp_hidden_state.data());
|
||||
|
||||
int32_t n_past_base = slot.n_past - (slot.drafted.size() + 1);
|
||||
mtp_accept_tokens(ctx, ids, n_past_base, slot.id);
|
||||
}
|
||||
|
||||
slot.i_batch_dft.clear();
|
||||
slot.drafted.clear();
|
||||
|
||||
@@ -3362,6 +3418,17 @@ void server_context::process_batch_tokens(int32_t & n_batch) {
|
||||
|
||||
common_sampler_accept(slot.ctx_sampling, ctx, id, true);
|
||||
|
||||
if (params_base.has_mtp && slot.n_decoded == 0) {
|
||||
if (batch_view.n_seq_id[0] > 0 && batch_view.seq_id[0][0] == slot.id) {
|
||||
mtp_update_kv_cache(ctx, batch_view, true);
|
||||
const float* emb = llama_get_embeddings_ith(ctx, -1);
|
||||
if (emb) {
|
||||
const int n_embd = llama_model_n_embd(llama_get_model(ctx));
|
||||
slot.mtp_hidden_state.resize(n_embd);
|
||||
memcpy(slot.mtp_hidden_state.data(), emb, n_embd * sizeof(float));
|
||||
}
|
||||
}
|
||||
}
|
||||
slot.n_decoded += 1;
|
||||
const int64_t t_current = ggml_time_us();
|
||||
|
||||
@@ -3394,7 +3461,15 @@ void server_context::process_batch_tokens(int32_t & n_batch) {
|
||||
|
||||
slot.i_batch = -1;
|
||||
}
|
||||
|
||||
if (params_base.has_mtp) {
|
||||
for (auto& slot : slots) {
|
||||
if (slot.n_past < slot.n_prompt_tokens) {
|
||||
if (batch_view.n_seq_id[0] > 0 && batch_view.seq_id[0][0] == slot.id) {
|
||||
mtp_update_kv_cache(ctx, batch_view, true);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// speculative decoding - main model sample and accept
|
||||
speculative_decoding_accept();
|
||||
}
|
||||
|
||||
@@ -134,6 +134,9 @@ struct server_slot {
|
||||
struct common_params_sampling sparams;
|
||||
common_sampler * ctx_sampling = nullptr;
|
||||
|
||||
bool has_mtp = false;
|
||||
std::vector<float> mtp_hidden_state;
|
||||
|
||||
// speculative decoding stats
|
||||
int32_t n_draft_total = 0; // Total draft tokens generated
|
||||
int32_t n_draft_accepted = 0; // Draft tokens actually accepted
|
||||
|
||||
@@ -279,6 +279,12 @@ extern "C" {
|
||||
LLAMA_SPLIT_MODE_GRAPH = 3, // splits computations across GPUs
|
||||
};
|
||||
|
||||
// Multi-Token Prediction (MTP) operation a context is asked to perform on decode.
// The numeric values are fixed explicitly.
enum llama_mtp_op_type {
    MTP_OP_NONE            = 0, // no MTP operation selected
    MTP_OP_WARMUP          = 1, // MTP pass over prompt tokens
    MTP_OP_UPDATE_ACCEPTED = 2, // MTP pass over tokens accepted after speculation
    MTP_OP_DRAFT_GEN       = 3, // MTP pass that generates draft tokens
};
|
||||
|
||||
typedef struct llama_token_data {
|
||||
llama_token id; // token id
|
||||
@@ -394,6 +400,7 @@ extern "C" {
|
||||
bool validate_quants; // if true, check for NaNs while loading the model
|
||||
bool merge_qkv; // if true, merge separate Q, K, V tensors into a single, contiguous tensor
|
||||
bool merge_up_gate_exps; // if true, merge ffn_up_exps and ffn_gate_exps tensors into a single, contiguous tensor
|
||||
bool mtp; // if true, load MTP layers if present
|
||||
};
|
||||
|
||||
// NOTE: changing the default values of parameters marked as [EXPERIMENTAL] may cause crashes or incorrect results in certain configurations
|
||||
@@ -449,6 +456,8 @@ extern "C" {
|
||||
bool split_mode_graph_scheduling; // if true, force split mode graph scheduling
|
||||
//bool split_mode_f16; // if true, cast intermediate results to f16 before copying to other GPUs
|
||||
bool scheduler_async; // if true, with split mode "graph" graph evaluation will be done using multiple threads
|
||||
bool mtp; // Activate MTP if supported
|
||||
enum llama_mtp_op_type mtp_op_type;
|
||||
|
||||
// Abort callback
|
||||
// if it returns true, execution of llama_decode() will be aborted
|
||||
@@ -1463,6 +1472,17 @@ LLAMA_API struct llama_grammar* llama_sampler_init_grammar_lazy_patterns(
|
||||
|
||||
LLAMA_API void llama_dump_timing_info_yaml(FILE * stream, const struct llama_context * ctx);
|
||||
|
||||
//
|
||||
// MTP
|
||||
//
|
||||
|
||||
LLAMA_API int32_t llama_model_n_nextn_layer(const struct llama_model * model);
|
||||
|
||||
// Set which, if any, MTP operation the context will use
|
||||
LLAMA_API void llama_set_mtp_op_type(struct llama_context * ctx, enum llama_mtp_op_type mtp_op_type);
|
||||
|
||||
LLAMA_API void llama_set_draft_input_hidden_state(struct llama_context * ctx, const float * hidden_state);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
@@ -303,6 +303,25 @@ ggml_cgraph * llm_build_context::build_defrag(const std::vector<uint32_t> & ids)
|
||||
return gf;
|
||||
}
|
||||
|
||||
struct ggml_tensor * llm_build_context::build_inp_embd_mtp(struct ggml_tensor * mtp_tok_embd) {
|
||||
struct ggml_tensor * cur = nullptr;
|
||||
|
||||
if (batch.token) {
|
||||
lctx.inp_tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, batch.n_tokens);
|
||||
|
||||
cb(lctx.inp_tokens, "inp_tokens", -1);
|
||||
ggml_set_input(lctx.inp_tokens);
|
||||
|
||||
cur = ggml_get_rows(ctx0, mtp_tok_embd, lctx.inp_tokens);
|
||||
} else {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
cb(cur, "inp_embd", -1);
|
||||
|
||||
return cur;
|
||||
}
|
||||
|
||||
ggml_tensor * llm_build_context::build_inp_pos() {
|
||||
int n_pos_per_embd = hparams.rope_type == LLAMA_ROPE_TYPE_MROPE || hparams.rope_type == LLAMA_ROPE_TYPE_IMROPE ? 4 : 1;
|
||||
lctx.inp_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, int64_t(n_tokens)*n_pos_per_embd);
|
||||
@@ -415,7 +434,10 @@ ggml_cgraph * llm_build_context::append_pooling(struct ggml_cgraph * gf) {
|
||||
struct ggml_tensor * inp = nullptr;
|
||||
for (int i = gf->n_nodes - 1; i >= 0; --i) {
|
||||
inp = gf->nodes[i];
|
||||
if (strcmp(inp->name, "result_norm") == 0 || strcmp(inp->name, "result_embd") == 0) {
|
||||
|
||||
if (strcmp(inp->name, "result_norm") == 0 ||
|
||||
strcmp(inp->name, "result_embd") == 0 ||
|
||||
strcmp(inp->name, "output_normed") == 0) {
|
||||
break;
|
||||
}
|
||||
inp = nullptr;
|
||||
@@ -7372,138 +7394,281 @@ ggml_cgraph * llm_build_context::build_glm4_moe() {
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v;
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
|
||||
|
||||
struct ggml_tensor * cur;
|
||||
struct ggml_tensor * inpL;
|
||||
|
||||
// input embeddings
|
||||
inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
|
||||
ggml_tensor * cur;
|
||||
|
||||
// position embeddings
|
||||
struct ggml_tensor * inp_pos = build_inp_pos();
|
||||
|
||||
// attention KV cache input
|
||||
//auto * inp_attn = build_attn_inp_kv_unified();
|
||||
|
||||
struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
|
||||
|
||||
// output token IDs (for last layer cropping)
|
||||
struct ggml_tensor * inp_out_ids = n_tokens > 1 ? build_inp_out_ids() : nullptr;
|
||||
|
||||
auto rope_cache = model.split_mode != LLAMA_SPLIT_MODE_GRAPH && cparams.rope_cache && (rope_type == LLAMA_ROPE_TYPE_NEOX || rope_type == LLAMA_ROPE_TYPE_NORM) ?
|
||||
ggml_rope_cache(ctx0, inp_pos, nullptr, n_embd_head, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow) : nullptr;
|
||||
|
||||
float kq_scale = 1.0f/sqrtf(float(n_embd_head));
|
||||
if (cparams.mtp_op_type != MTP_OP_NONE) {
|
||||
ggml_tensor* hidden_states_from_main_model;
|
||||
|
||||
// Only process up to last layer (skip final NextN layer)
|
||||
// Final layer tensors are loaded but not processed in forward pass
|
||||
const int n_transformer_layers = n_layer - hparams.nextn_predict_layers;
|
||||
for (int il = 0; il < n_transformer_layers; ++il) {
|
||||
struct ggml_tensor * inpSA = inpL;
|
||||
|
||||
// self-attention
|
||||
if (rope_cache == nullptr) {
|
||||
cur = build_std_attention(gf, model.layers[il].attn_norm, inpL,
|
||||
inp_pos, il == n_transformer_layers - 1 ? inp_out_ids : nullptr, nullptr,
|
||||
KQ_mask, nullptr, nullptr, kq_scale, 0.0f, 0, il, true, false, true);
|
||||
if (cparams.mtp_op_type == MTP_OP_WARMUP || cparams.mtp_op_type == MTP_OP_UPDATE_ACCEPTED) {
|
||||
hidden_states_from_main_model = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, hparams.n_embd, n_tokens);
|
||||
} else {
|
||||
// Pre-attention norm
|
||||
cur = llm_build_norm(ctx0, inpL, hparams, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, cb, il);
|
||||
cb(cur, "attn_norm", il);
|
||||
hidden_states_from_main_model = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, hparams.n_embd);
|
||||
}
|
||||
ggml_set_name(hidden_states_from_main_model, "result_embd_pooled");
|
||||
ggml_set_input(hidden_states_from_main_model);
|
||||
|
||||
auto [Qcur, Kcur, Vcur] = llm_build_mul_mat_qkv(gf, cur,
|
||||
model.layers[il].wqkv, model.layers[il].bqkv,
|
||||
model.layers[il].wqk, model.layers[il].bqk,
|
||||
model.layers[il].wq, model.layers[il].bq,
|
||||
model.layers[il].wk, model.layers[il].bk,
|
||||
model.layers[il].wv, model.layers[il].bv,
|
||||
model.layers[il].attn_q_norm, model.layers[il].attn_k_norm, 0.f, il);
|
||||
lctx.inp_mtp_states = hidden_states_from_main_model;
|
||||
|
||||
// apply RoPE
|
||||
if (rope_cache) {
|
||||
Qcur = ggml_rope_fast(ctx0, Qcur, rope_cache);
|
||||
Kcur = ggml_rope_fast(ctx0, Kcur, rope_cache);
|
||||
const int il_mtp = hparams.n_layer - 1;
|
||||
const auto & mtp_layer = model.layers[il_mtp];
|
||||
|
||||
cur = build_mtp_tail(mtp_layer, hidden_states_from_main_model, n_embd_head, gf, inp_pos, rope_cache);
|
||||
|
||||
} else {
|
||||
struct ggml_tensor * inpL;
|
||||
|
||||
// input embeddings
|
||||
inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
|
||||
|
||||
struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
|
||||
|
||||
// output token IDs (for last layer cropping)
|
||||
struct ggml_tensor * inp_out_ids = n_tokens > 1 ? build_inp_out_ids() : nullptr;
|
||||
|
||||
float kq_scale = 1.0f/sqrtf(float(n_embd_head));
|
||||
|
||||
// Only process up to last layer (skip final NextN layer)
|
||||
// Final layer tensors are loaded but not processed in forward pass
|
||||
const int n_transformer_layers = n_layer - hparams.nextn_predict_layers;
|
||||
for (int il = 0; il < n_transformer_layers; ++il) {
|
||||
struct ggml_tensor * inpSA = inpL;
|
||||
|
||||
// self-attention
|
||||
if (rope_cache == nullptr) {
|
||||
cur = build_std_attention(gf, model.layers[il].attn_norm, inpL,
|
||||
inp_pos, il == n_transformer_layers - 1 ? inp_out_ids : nullptr, nullptr,
|
||||
KQ_mask, nullptr, nullptr, kq_scale, 0.0f, 0, il, true, false, true);
|
||||
} else {
|
||||
Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow);
|
||||
Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow);
|
||||
}
|
||||
cb(Qcur, "Qcur", il);
|
||||
cb(Kcur, "Kcur", il);
|
||||
cb(Vcur, "Vcur", il);
|
||||
// Pre-attention norm
|
||||
cur = llm_build_norm(ctx0, inpL, hparams, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, cb, il);
|
||||
cb(cur, "attn_norm", il);
|
||||
|
||||
// build attention KV (no unified cache)
|
||||
cur = llm_build_kv(ctx0, lctx, kv_self, gf,
|
||||
model.layers[il].wo, NULL,
|
||||
Kcur, Vcur, Qcur, KQ_mask,
|
||||
n_tokens, kv_head, n_kv,
|
||||
1.0f/sqrtf(float(n_embd_head)), cb, il);
|
||||
auto [Qcur, Kcur, Vcur] = llm_build_mul_mat_qkv(gf, cur,
|
||||
model.layers[il].wqkv, model.layers[il].bqkv,
|
||||
model.layers[il].wqk, model.layers[il].bqk,
|
||||
model.layers[il].wq, model.layers[il].bq,
|
||||
model.layers[il].wk, model.layers[il].bk,
|
||||
model.layers[il].wv, model.layers[il].bv,
|
||||
model.layers[il].attn_q_norm, model.layers[il].attn_k_norm, 0.f, il);
|
||||
|
||||
if (il == n_transformer_layers - 1 && inp_out_ids) {
|
||||
// skip computing output for unused tokens
|
||||
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
|
||||
// apply RoPE
|
||||
if (rope_cache) {
|
||||
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
|
||||
Qcur = ggml_rope_fast(ctx0, Qcur, rope_cache);
|
||||
Kcur = ggml_rope_fast(ctx0, Kcur, rope_cache);
|
||||
} else {
|
||||
Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow);
|
||||
Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow);
|
||||
}
|
||||
cb(Qcur, "Qcur", il);
|
||||
cb(Kcur, "Kcur", il);
|
||||
cb(Vcur, "Vcur", il);
|
||||
|
||||
// build attention KV (no unified cache)
|
||||
cur = llm_build_kv(ctx0, lctx, kv_self, gf,
|
||||
model.layers[il].wo, NULL,
|
||||
Kcur, Vcur, Qcur, KQ_mask,
|
||||
n_tokens, kv_head, n_kv,
|
||||
1.0f/sqrtf(float(n_embd_head)), cb, il);
|
||||
|
||||
if (il == n_transformer_layers - 1 && inp_out_ids) {
|
||||
// skip computing output for unused tokens
|
||||
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
|
||||
if (rope_cache) {
|
||||
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// crop output on last layer
|
||||
|
||||
// residual connection for attention output
|
||||
ggml_tensor * ffn_inp;
|
||||
if (rope_cache) {
|
||||
ffn_inp = ggml_add(ctx0, cur, inpSA);
|
||||
cb(ffn_inp, "ffn_inp", il);
|
||||
} else {
|
||||
ffn_inp = cur;
|
||||
}
|
||||
|
||||
if ((uint32_t) il < hparams.n_layer_dense_lead) {
|
||||
// dense FFN
|
||||
cur = llm_build_ffn(ctx0, lctx, model.layers[il].ffn_norm, ffn_inp,
|
||||
model.layers[il].ffn_up, NULL, NULL,
|
||||
model.layers[il].ffn_gate, NULL, NULL,
|
||||
model.layers[il].ffn_down, NULL, NULL,
|
||||
NULL,
|
||||
LLM_FFN_SILU, LLM_FFN_PAR, cb, il, gf, true);
|
||||
cb(cur, "ffn_out", il);
|
||||
} else {
|
||||
cur = llm_build_std_moe_ffn(ctx0, lctx, model.layers[il].ffn_norm, ffn_inp,
|
||||
model.layers[il].ffn_gate_inp, model.layers[il].ffn_gate_inp_b,
|
||||
model.layers[il].ffn_up_exps, model.layers[il].ffn_up_exps_b,
|
||||
model.layers[il].ffn_gate_exps, model.layers[il].ffn_gate_exps_b,
|
||||
model.layers[il].ffn_down_exps, model.layers[il].ffn_down_exps_b,
|
||||
model.layers[il].ffn_exp_probs_b,
|
||||
model.layers[il].ffn_up_shexp, nullptr, // we don't have shared expert biases?
|
||||
model.layers[il].ffn_gate_shexp, nullptr,
|
||||
model.layers[il].ffn_down_shexp, nullptr,
|
||||
n_expert, n_expert_used,
|
||||
LLM_FFN_SILU, hparams.expert_weights_norm, true, hparams.expert_weights_scale,
|
||||
(llm_expert_gating_func_type) hparams.expert_gating_func,
|
||||
LLM_FFN_SILU, cb, il, gf, true, model.layers[il].ffn_up_gate_exps);
|
||||
}
|
||||
|
||||
// residual and context vector
|
||||
//cur = ggml_add(ctx0, cur, ffn_inp);
|
||||
cur = lctx.cvec.apply_to(ctx0, cur, il);
|
||||
cb(cur, "l_out", il);
|
||||
|
||||
// prepare next layer input
|
||||
inpL = cur;
|
||||
}
|
||||
cur = inpL;
|
||||
|
||||
// crop output on last layer
|
||||
|
||||
// residual connection for attention output
|
||||
ggml_tensor * ffn_inp;
|
||||
if (rope_cache) {
|
||||
ffn_inp = ggml_add(ctx0, cur, inpSA);
|
||||
cb(ffn_inp, "ffn_inp", il);
|
||||
} else {
|
||||
ffn_inp = cur;
|
||||
}
|
||||
|
||||
if ((uint32_t) il < hparams.n_layer_dense_lead) {
|
||||
// dense FFN
|
||||
cur = llm_build_ffn(ctx0, lctx, model.layers[il].ffn_norm, ffn_inp,
|
||||
model.layers[il].ffn_up, NULL, NULL,
|
||||
model.layers[il].ffn_gate, NULL, NULL,
|
||||
model.layers[il].ffn_down, NULL, NULL,
|
||||
NULL,
|
||||
LLM_FFN_SILU, LLM_FFN_PAR, cb, il, gf, true);
|
||||
cb(cur, "ffn_out", il);
|
||||
} else {
|
||||
cur = llm_build_std_moe_ffn(ctx0, lctx, model.layers[il].ffn_norm, ffn_inp,
|
||||
model.layers[il].ffn_gate_inp, model.layers[il].ffn_gate_inp_b,
|
||||
model.layers[il].ffn_up_exps, model.layers[il].ffn_up_exps_b,
|
||||
model.layers[il].ffn_gate_exps, model.layers[il].ffn_gate_exps_b,
|
||||
model.layers[il].ffn_down_exps, model.layers[il].ffn_down_exps_b,
|
||||
model.layers[il].ffn_exp_probs_b,
|
||||
model.layers[il].ffn_up_shexp, nullptr, // we don't have shared expert biases?
|
||||
model.layers[il].ffn_gate_shexp, nullptr,
|
||||
model.layers[il].ffn_down_shexp, nullptr,
|
||||
n_expert, n_expert_used,
|
||||
LLM_FFN_SILU, hparams.expert_weights_norm, true, hparams.expert_weights_scale,
|
||||
(llm_expert_gating_func_type) hparams.expert_gating_func,
|
||||
LLM_FFN_SILU, cb, il, gf, true, model.layers[il].ffn_up_gate_exps);
|
||||
}
|
||||
|
||||
// residual and context vector
|
||||
//cur = ggml_add(ctx0, cur, ffn_inp);
|
||||
cur = lctx.cvec.apply_to(ctx0, cur, il);
|
||||
cb(cur, "l_out", il);
|
||||
|
||||
// prepare next layer input
|
||||
inpL = cur;
|
||||
// lm head
|
||||
cur = build_output(lctx, ctx0, cur, model.output, model.output_norm, cb);
|
||||
cb(cur, "result_output", -1);
|
||||
}
|
||||
|
||||
cur = inpL;
|
||||
|
||||
// lm head
|
||||
cur = build_output(lctx, ctx0, cur, model.output, model.output_norm, cb);
|
||||
cb(cur, "result_output", -1);
|
||||
|
||||
ggml_build_forward_expand(gf, cur);
|
||||
return gf;
|
||||
}
|
||||
|
||||
struct ggml_tensor * llm_build_context::build_mtp_tail(
|
||||
const llama_layer & mtp_layer,
|
||||
struct ggml_tensor * prev_embeddings,
|
||||
int64_t n_embd_head,
|
||||
struct ggml_cgraph * gf,
|
||||
struct ggml_tensor * inp_pos,
|
||||
struct ggml_tensor * rope_cache
|
||||
) {
|
||||
const int il = hparams.n_layer - 1;
|
||||
|
||||
struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
|
||||
|
||||
struct ggml_tensor * inp_out_ids = build_inp_out_ids();
|
||||
|
||||
// If nextn.embed_tokens is missing (GLM-4.6), use model.tok_embd
|
||||
ggml_tensor * mtp_embd_weights = mtp_layer.nextn.embed_tokens;
|
||||
if (mtp_embd_weights == nullptr) {
|
||||
mtp_embd_weights = model.tok_embd;
|
||||
}
|
||||
ggml_tensor * token_emb = build_inp_embd_mtp(mtp_embd_weights);
|
||||
|
||||
ggml_tensor * token_emb_norm = llm_build_norm(ctx0, token_emb, hparams, mtp_layer.nextn.enorm, NULL, LLM_NORM_RMS, cb, il);
|
||||
ggml_tensor * hidden_state_norm = llm_build_norm(ctx0, prev_embeddings, hparams, mtp_layer.nextn.hnorm, NULL, LLM_NORM_RMS, cb, il);
|
||||
|
||||
ggml_tensor * combined = ggml_concat(ctx0, token_emb_norm, hidden_state_norm, 0);
|
||||
cb(combined, "mtp_concat", il);
|
||||
ggml_tensor* cur = llm_build_lora_mm(lctx, ctx0, mtp_layer.nextn.eh_proj, combined);
|
||||
|
||||
struct ggml_tensor * inpSA = cur;
|
||||
|
||||
cur = llm_build_norm(ctx0, cur, hparams, mtp_layer.attn_norm, NULL, LLM_NORM_RMS, cb, il);
|
||||
cb(cur, "attn_norm", il);
|
||||
|
||||
// Self-Attention
|
||||
{
|
||||
auto [Qcur, Kcur, Vcur] = llm_build_mul_mat_qkv(gf, cur,
|
||||
nullptr, nullptr, // wqkv, bqkv (not used in GLM usually?)
|
||||
nullptr, nullptr, // wqk, bqk
|
||||
mtp_layer.wq, mtp_layer.bq,
|
||||
mtp_layer.wk, mtp_layer.bk,
|
||||
mtp_layer.wv, mtp_layer.bv,
|
||||
mtp_layer.attn_q_norm, mtp_layer.attn_k_norm,
|
||||
0.f, il);
|
||||
|
||||
// RoPE
|
||||
if (rope_cache) {
|
||||
Qcur = ggml_rope_fast(ctx0, Qcur, rope_cache);
|
||||
Kcur = ggml_rope_fast(ctx0, Kcur, rope_cache);
|
||||
} else {
|
||||
Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
|
||||
Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
|
||||
}
|
||||
|
||||
cb(Qcur, "Qcur", il);
|
||||
cb(Kcur, "Kcur", il);
|
||||
cb(Vcur, "Vcur", il);
|
||||
|
||||
// KV Cache & Attention
|
||||
cur = llm_build_kv(ctx0, lctx, kv_self, gf,
|
||||
model.layers[il].wo, NULL,
|
||||
Kcur, Vcur, Qcur, KQ_mask,
|
||||
n_tokens, kv_head, n_kv,
|
||||
1.0f/sqrtf(float(n_embd_head)), cb, il);
|
||||
}
|
||||
|
||||
ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
|
||||
cb(ffn_inp, "mtp_ffn_inp", il);
|
||||
|
||||
cur = llm_build_norm(ctx0, ffn_inp, hparams, mtp_layer.attn_post_norm, NULL, LLM_NORM_RMS, cb, il);
|
||||
cb(cur, "attn_post_norm", il);
|
||||
|
||||
// moe ffn for nextn block
|
||||
{
|
||||
// Routed Experts
|
||||
ggml_tensor * routed_out = llm_build_std_moe_ffn(ctx0, lctx,
|
||||
NULL, // Norm handled above
|
||||
cur, // Input (Normed)
|
||||
mtp_layer.ffn_gate_inp, NULL,
|
||||
mtp_layer.ffn_up_exps, NULL,
|
||||
mtp_layer.ffn_gate_exps, NULL,
|
||||
mtp_layer.ffn_down_exps, NULL,
|
||||
mtp_layer.ffn_exp_probs_b,
|
||||
nullptr, nullptr, // we don't have shared expert biases?
|
||||
nullptr, nullptr,
|
||||
nullptr, nullptr,
|
||||
n_expert, n_expert_used,
|
||||
LLM_FFN_SILU, hparams.expert_weights_norm, true, hparams.expert_weights_scale,
|
||||
(llm_expert_gating_func_type) hparams.expert_gating_func,
|
||||
LLM_FFN_SILU, cb, il, gf, true, mtp_layer.ffn_up_gate_exps);
|
||||
cb(routed_out, "ffn_moe_out", il);
|
||||
|
||||
// Shared Expert FFN
|
||||
ggml_tensor * shared_out = llm_build_ffn(ctx0, lctx,
|
||||
NULL, // Norm handled above
|
||||
cur, // Input
|
||||
mtp_layer.ffn_up_shexp, NULL, NULL,
|
||||
mtp_layer.ffn_gate_shexp, NULL, NULL,
|
||||
mtp_layer.ffn_down_shexp, NULL, NULL,
|
||||
NULL,
|
||||
LLM_FFN_SILU, LLM_FFN_PAR, cb, il, gf, true);
|
||||
cb(shared_out, "ffn_shexp_out", il);
|
||||
|
||||
// Sum and Residual
|
||||
cur = ggml_add(ctx0, routed_out, shared_out);
|
||||
cb(cur, "ffn_out", il);
|
||||
|
||||
cur = ggml_add(ctx0, cur, ffn_inp);
|
||||
cb(cur, "mtp_ffn_out_resid", il);
|
||||
}
|
||||
cur = llm_build_norm(ctx0, cur, hparams, mtp_layer.nextn.shared_head_norm, NULL, LLM_NORM_RMS, cb, il);
|
||||
|
||||
if (inp_out_ids) {
|
||||
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
|
||||
}
|
||||
|
||||
// If nextn.shared_head_head is missing (GLM-4.6), use model.output (Main LM Head)
|
||||
ggml_tensor * mtp_head_weights = mtp_layer.nextn.shared_head_head;
|
||||
if (mtp_head_weights == nullptr) {
|
||||
mtp_head_weights = model.output;
|
||||
}
|
||||
cur = llm_build_lora_mm(lctx, ctx0, mtp_head_weights, cur);
|
||||
cb(cur, "result_output", -1);
|
||||
|
||||
return cur;
|
||||
}
|
||||
|
||||
ggml_cgraph * llm_build_context::build_bitnet() {
|
||||
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
|
||||
|
||||
@@ -9836,7 +10001,8 @@ ggml_cgraph * llm_build_context::llama_build_graph(
|
||||
}
|
||||
|
||||
// add on pooling layer
|
||||
if (lctx.cparams.embeddings) {
|
||||
if (lctx.cparams.mtp_op_type == MTP_OP_NONE && (lctx.cparams.embeddings ||
|
||||
(lctx.model.hparams.nextn_predict_layers > 0 || lctx.model.mtp))) {
|
||||
result = llm.append_pooling(result);
|
||||
}
|
||||
|
||||
@@ -10178,3 +10344,7 @@ ggml_tensor * llm_build_context::build_std_attention(ggml_cgraph * gf, ggml_tens
|
||||
|
||||
return cur;
|
||||
}
|
||||
|
||||
int32_t llama_model_n_nextn_layer(const llama_model * model) {
|
||||
return model->hparams.nextn_predict_layers;
|
||||
}
|
||||
|
||||
@@ -114,6 +114,8 @@ struct llm_build_context {
|
||||
|
||||
ggml_cgraph * build_defrag(const std::vector<uint32_t> & ids);
|
||||
|
||||
struct ggml_tensor * build_inp_embd_mtp(struct ggml_tensor * mtp_tok_embd);
|
||||
|
||||
ggml_tensor * build_inp_pos();
|
||||
|
||||
ggml_tensor * build_input_scale(int n_tokens);
|
||||
@@ -430,4 +432,12 @@ llm_expert_gating_func_type gating_op,
|
||||
bool is_multi = false);
|
||||
|
||||
static uint32_t llama_kv_qnext_state_slots(const llama_kv_cache & kv_self);
|
||||
struct ggml_tensor * build_mtp_tail(
|
||||
const struct llama_layer & mtp_layer,
|
||||
struct ggml_tensor * prev_embeddings,
|
||||
int64_t n_embd_head,
|
||||
struct ggml_cgraph * gf,
|
||||
struct ggml_tensor * inp_pos,
|
||||
struct ggml_tensor * rope_cache
|
||||
);
|
||||
};
|
||||
|
||||
@@ -191,6 +191,8 @@ struct llama_context {
|
||||
ggml_abort_callback abort_callback = nullptr;
|
||||
void * abort_callback_data = nullptr;
|
||||
|
||||
const float * draft_input_hidden_state = nullptr;
|
||||
|
||||
// input tensors
|
||||
struct ggml_tensor * inp_tokens; // I32 [n_batch]
|
||||
struct ggml_tensor * inp_embd; // F32 [n_embd, n_batch]
|
||||
@@ -209,6 +211,7 @@ struct llama_context {
|
||||
struct ggml_tensor * inp_embd_enc; // F32 [n_embd, n_outputs_enc]
|
||||
struct ggml_tensor * inp_KQ_mask_cross; // F32 [n_outputs_enc, n_batch]
|
||||
struct ggml_tensor * inp_scale = nullptr; // F32 [n_tokens]
|
||||
struct ggml_tensor * inp_mtp_states = nullptr;
|
||||
|
||||
ggml_backend_t ggml_backend_by_name(const char * name);
|
||||
|
||||
@@ -225,5 +228,7 @@ struct llama_context {
|
||||
std::vector<CacheCopy> cache_copies;
|
||||
|
||||
bool update_cache_copies();
|
||||
|
||||
bool prepare_mtp_graph_inputs(
|
||||
struct llama_context & lctx);
|
||||
void set_mtp_op_type(llama_mtp_op_type value);
|
||||
};
|
||||
|
||||
@@ -45,9 +45,11 @@ struct llama_cparams {
|
||||
bool scheduler_async;
|
||||
int min_experts;
|
||||
float thresh_experts;
|
||||
bool mtp;
|
||||
|
||||
enum ggml_type reduce_type;
|
||||
enum llama_pooling_type pooling_type;
|
||||
enum llama_mtp_op_type mtp_op_type;
|
||||
|
||||
ggml_backend_sched_eval_callback cb_eval;
|
||||
void * cb_eval_user_data;
|
||||
|
||||
@@ -903,6 +903,12 @@ void llm_load_hparams(
|
||||
}
|
||||
|
||||
// NextN/MTP parameters
|
||||
if (model.mtp) {
|
||||
hparams.n_layer_kv_from_start = hparams.n_layer;
|
||||
}
|
||||
else {
|
||||
hparams.n_layer_kv_from_start = hparams.n_layer - hparams.nextn_predict_layers;
|
||||
}
|
||||
ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.nextn_predict_layers, false);
|
||||
|
||||
switch (hparams.n_layer) {
|
||||
|
||||
@@ -2278,9 +2278,12 @@ bool create_tensors_helper::create_glm4_moe_tensors(const LLM_TN & tn) {
|
||||
ggml_context * ctx_split = ctx_for_layer_split(i);
|
||||
|
||||
int flags = 0;
|
||||
if (hparams.nextn_predict_layers > 0 && static_cast<uint32_t>(i) >= n_layer - hparams.nextn_predict_layers) {
|
||||
// skip all tensors in the NextN layers
|
||||
flags |= llama_model_loader::TENSOR_SKIP;
|
||||
// Skip loading MTP layers if the feature is disabled
|
||||
if (!model.mtp) {
|
||||
if (hparams.nextn_predict_layers > 0 && static_cast<uint32_t>(i) >= n_layer - hparams.nextn_predict_layers) {
|
||||
// skip all tensors in the NextN layers
|
||||
flags |= llama_model_loader::TENSOR_SKIP;
|
||||
}
|
||||
}
|
||||
|
||||
auto & layer = model.layers[i];
|
||||
@@ -3481,7 +3484,8 @@ bool create_tensors_helper::create_tensors() {
|
||||
throw std::runtime_error("unknown architecture");
|
||||
}
|
||||
if (model.split_mode == LLAMA_SPLIT_MODE_GRAPH || model.split_mode == LLAMA_SPLIT_MODE_ATTN) {
|
||||
const int n_layer = model.layers.size() - model.hparams.nextn_predict_layers;
|
||||
const int n_layer = model.mtp ? model.layers.size()
|
||||
: model.layers.size() - model.hparams.nextn_predict_layers;
|
||||
LLAMA_LOG_INFO("================================ max_gpu = %d\n", model.max_gpu);
|
||||
std::vector<size_t> mem_used(model.splits.size(), 0);
|
||||
const auto & hparams = model.hparams;
|
||||
|
||||
@@ -374,6 +374,8 @@ struct llama_model {
|
||||
int max_gpu = 0; // max. number of GPUs to use per layer for aplit mode "graph"
|
||||
int n_gpu_layers;
|
||||
|
||||
bool mtp; // use mtp if is supported by the Model
|
||||
|
||||
std::vector<rpc_device> rpc_servers;
|
||||
std::vector<int32_t> devices;
|
||||
|
||||
|
||||
271
src/llama.cpp
271
src/llama.cpp
@@ -546,6 +546,7 @@ struct llama_context::Prev {
|
||||
int all_seq_id;
|
||||
int n_outputs;
|
||||
int n_kv;
|
||||
llama_mtp_op_type mtp_op_type;
|
||||
ggml_cgraph * graph;
|
||||
};
|
||||
|
||||
@@ -563,11 +564,13 @@ bool llama_context::can_reuse_graph(const llama_batch & u_batch) {
|
||||
kv_self.head > 0 &&
|
||||
kv_self.n == prev->n_kv &&
|
||||
n_outputs == prev->n_outputs &&
|
||||
cparams.mtp_op_type == prev->mtp_op_type &&
|
||||
update_cache_copies();
|
||||
}
|
||||
|
||||
bool llama_context::update_cache_copies() {
|
||||
int n_layer = model.hparams.n_layer - model.hparams.nextn_predict_layers; //cache_copies.size()/2;
|
||||
const int n_layer = model.mtp ? model.hparams.n_layer
|
||||
: model.hparams.n_layer - model.hparams.nextn_predict_layers; //cache_copies.size()/2;
|
||||
auto layer_has_attention_kv = [&](int il) {
|
||||
return !((model.arch == LLM_ARCH_QWEN3NEXT || model.arch == LLM_ARCH_QWEN35MOE) && model.hparams.is_recurrent(il));
|
||||
};
|
||||
@@ -638,6 +641,12 @@ llama_context::llama_context(const llama_model & model)
|
||||
}
|
||||
}
|
||||
|
||||
// Logs the requested MTP operation type at debug level, then stores it in
// cparams.mtp_op_type.
void llama_context::set_mtp_op_type(llama_mtp_op_type value) {
    LLAMA_LOG_DEBUG("%s: value = %d\n", __func__, value);
    auto & params = cparams;
    params.mtp_op_type = value;
}
|
||||
|
||||
llama_context::~llama_context() {
|
||||
ggml_backend_sched_free(sched);
|
||||
|
||||
@@ -716,7 +725,8 @@ static bool llama_kv_cache_init(
|
||||
|
||||
const struct llama_hparams & hparams = model.hparams;
|
||||
|
||||
const int64_t n_layer = hparams.n_layer - hparams.nextn_predict_layers;
|
||||
const int64_t n_layer = model.mtp ? hparams.n_layer
|
||||
: hparams.n_layer - hparams.nextn_predict_layers;
|
||||
|
||||
cache.has_shift = false;
|
||||
|
||||
@@ -993,7 +1003,8 @@ static bool llama_kv_cache_init(
|
||||
// to the first cell of the slot.
|
||||
static bool llama_kv_cache_find_slot(
|
||||
struct llama_kv_cache & cache,
|
||||
const struct llama_batch & batch) {
|
||||
const struct llama_batch & batch,
|
||||
enum llama_mtp_op_type op_type) {
|
||||
const uint32_t n_tokens = batch.n_tokens;
|
||||
|
||||
if (cache.recurrent) {
|
||||
@@ -1044,6 +1055,45 @@ static bool llama_kv_cache_find_slot(
|
||||
}
|
||||
// otherwise, one cell per token.
|
||||
|
||||
bool is_mtp_special_op = (op_type == MTP_OP_WARMUP ||
|
||||
op_type == MTP_OP_UPDATE_ACCEPTED);
|
||||
if (is_mtp_special_op) {
|
||||
const llama_pos target_pos = batch.pos[0];
|
||||
const llama_seq_id target_seq = batch.seq_id[0][0];
|
||||
|
||||
bool found = false;
|
||||
|
||||
if (cache.head < cache.size &&
|
||||
cache.cells[cache.head].pos == target_pos &&
|
||||
cache.cells[cache.head].has_seq_id(target_seq)) {
|
||||
found = true;
|
||||
}
|
||||
else {
|
||||
for (uint32_t i = 0; i < cache.size; ++i) {
|
||||
if (cache.cells[i].pos == target_pos &&
|
||||
cache.cells[i].has_seq_id(target_seq)) {
|
||||
|
||||
cache.head = i;
|
||||
found = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (!found) {
|
||||
LLAMA_LOG_ERROR("%s: MTP Update failed - slot for seq %d pos %d not found\n",
|
||||
__func__, target_seq, target_pos);
|
||||
return false;
|
||||
}
|
||||
|
||||
if (cache.head + n_tokens > cache.size) {
|
||||
LLAMA_LOG_ERROR("%s: MTP Update out of bounds\n", __func__);
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
if (n_tokens > cache.size) {
|
||||
LLAMA_LOG_ERROR("%s: n_tokens=%d > cache.size=%d\n", __func__, n_tokens, cache.size);
|
||||
return false;
|
||||
@@ -1893,6 +1943,7 @@ static bool llm_load_tensors(
|
||||
const float * tensor_split,
|
||||
bool use_mlock,
|
||||
bool validate_quants,
|
||||
bool mtp,
|
||||
llama_progress_callback progress_callback,
|
||||
void * progress_callback_user_data) {
|
||||
model.t_start_us = ggml_time_us();
|
||||
@@ -1921,6 +1972,7 @@ static bool llm_load_tensors(
|
||||
model.main_gpu = main_gpu;
|
||||
model.max_gpu = max_gpu;
|
||||
model.n_gpu_layers = n_gpu_layers;
|
||||
model.mtp = mtp;
|
||||
|
||||
const int n_layer = hparams.n_layer;
|
||||
const int i_gpu_start = std::max((int) hparams.n_layer - n_gpu_layers, (int) 0);
|
||||
@@ -2300,7 +2352,7 @@ static int llama_model_load(const std::string & fname, llama_model & model, llam
|
||||
|
||||
if (!llm_load_tensors(
|
||||
ml, model, params.n_gpu_layers, params.mla, params.split_mode, params.main_gpu, params.max_gpu, params.tensor_split,
|
||||
params.use_mlock, params.validate_quants,
|
||||
params.use_mlock, params.validate_quants, params.mtp,
|
||||
params.progress_callback, params.progress_callback_user_data
|
||||
)) {
|
||||
return -2;
|
||||
@@ -2969,8 +3021,9 @@ static size_t llama_output_reserve(llama_context & lctx, size_t n_outputs) {
|
||||
const auto n_embd = hparams.n_embd;
|
||||
|
||||
// TODO: use a per-batch flag for logits presence instead
|
||||
const bool has_logits = !cparams.embeddings;
|
||||
const bool has_embd = lctx.is_encoding || (cparams.embeddings && (cparams.pooling_type == LLAMA_POOLING_TYPE_NONE));
|
||||
const bool has_mtp = lctx.model.hparams.nextn_predict_layers > 0 && lctx.cparams.mtp;
|
||||
const bool has_logits = !cparams.embeddings || has_mtp;
|
||||
const bool has_embd = lctx.is_encoding || (cparams.embeddings && (cparams.pooling_type == LLAMA_POOLING_TYPE_NONE)) || has_mtp;
|
||||
|
||||
const size_t logits_size = has_logits ? n_vocab*n_outputs_max : 0;
|
||||
const size_t embd_size = has_embd ? n_embd*n_outputs_max : 0;
|
||||
@@ -3049,6 +3102,24 @@ static void llama_graph_compute(
|
||||
// fprintf(stderr, "splits: %d\n", ggml_backend_sched_get_n_splits(lctx.sched));
|
||||
}
|
||||
|
||||
static bool prepare_mtp_graph_inputs(struct llama_context & lctx) {
|
||||
ggml_tensor * dst = lctx.inp_mtp_states;
|
||||
const float * src = nullptr;
|
||||
if (lctx.cparams.mtp_op_type == MTP_OP_WARMUP || lctx.cparams.mtp_op_type == MTP_OP_UPDATE_ACCEPTED) {
|
||||
src = lctx.embd;
|
||||
} else {
|
||||
src = lctx.draft_input_hidden_state;
|
||||
}
|
||||
|
||||
if (!src) {
|
||||
LLAMA_LOG_ERROR("%s: Source hidden state is null\n", __func__);
|
||||
return false;
|
||||
}
|
||||
|
||||
ggml_backend_tensor_set(dst, src, 0, ggml_nbytes(dst));
|
||||
return true;
|
||||
}
|
||||
|
||||
// decode a batch of tokens by evaluating the transformer
|
||||
//
|
||||
// - lctx: llama context
|
||||
@@ -3260,7 +3331,7 @@ static int llama_decode_internal(
|
||||
kv_self.head = 0;
|
||||
}
|
||||
|
||||
if (!llama_kv_cache_find_slot(kv_self, u_batch)) {
|
||||
if (!llama_kv_cache_find_slot(kv_self, u_batch, cparams.mtp_op_type)) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
@@ -3322,37 +3393,50 @@ static int llama_decode_internal(
|
||||
#endif
|
||||
if (u_batch.n_tokens == 1 && u_batch.embd == nullptr && lctx.cparams.graph_reuse) {
|
||||
lctx.prev = std::make_unique<llama_context::Prev>(llama_context::Prev{
|
||||
(int)u_batch.all_seq_id, (int)lctx.n_outputs, (int)lctx.kv_self.n, gf});
|
||||
(int)u_batch.all_seq_id, (int)lctx.n_outputs, (int)lctx.kv_self.n,
|
||||
cparams.mtp_op_type, gf});
|
||||
}
|
||||
} else {
|
||||
//printf("Reusing graph\n");
|
||||
gf = lctx.prev->graph;
|
||||
}
|
||||
|
||||
if (cparams.mtp_op_type != MTP_OP_NONE) {
|
||||
if (!prepare_mtp_graph_inputs(lctx)) {
|
||||
return GGML_STATUS_FAILED;
|
||||
}
|
||||
}
|
||||
|
||||
// the output is always the last tensor in the graph
|
||||
struct ggml_tensor * res = gf->nodes[gf->n_nodes - 1];
|
||||
struct ggml_tensor * embd = gf->nodes[gf->n_nodes - 2];
|
||||
struct ggml_tensor * embd = nullptr;
|
||||
|
||||
if (lctx.n_outputs == 0) {
|
||||
// no output
|
||||
res = nullptr;
|
||||
embd = nullptr;
|
||||
} else if (cparams.embeddings) {
|
||||
res = nullptr; // do not extract logits for embedding case
|
||||
embd = nullptr;
|
||||
for (int i = gf->n_nodes - 1; i >= 0; --i) {
|
||||
if (strcmp(gf->nodes[i]->name, "result_embd_pooled") == 0) {
|
||||
embd = gf->nodes[i];
|
||||
break;
|
||||
res = nullptr;
|
||||
}
|
||||
else {
|
||||
const bool has_mtp = lctx.model.hparams.nextn_predict_layers > 0 && lctx.model.mtp;
|
||||
if (cparams.embeddings || has_mtp) {
|
||||
for (int i = gf->n_nodes - 1; i >= 0; --i) {
|
||||
if (strcmp(gf->nodes[i]->name, "result_embd_pooled") == 0) {
|
||||
embd = gf->nodes[i];
|
||||
break;
|
||||
}
|
||||
if (strcmp(gf->nodes[i]->name, "result_norm") == 0) {
|
||||
embd = gf->nodes[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
if (cparams.embeddings && lctx.model.hparams.nextn_predict_layers == 0) {
|
||||
res = nullptr; // do not extract logits for embedding case
|
||||
} else {
|
||||
if (!embd) { // do not extract embeddings when not needed
|
||||
GGML_ASSERT(strcmp(res->name, "result_output") == 0 && "missing result_output tensor");
|
||||
}
|
||||
}
|
||||
GGML_ASSERT(embd != nullptr && "missing embeddings tensor");
|
||||
} else {
|
||||
embd = nullptr; // do not extract embeddings when not needed
|
||||
GGML_ASSERT(strcmp(res->name, "result_output") == 0 && "missing result_output tensor");
|
||||
}
|
||||
// LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs);
|
||||
|
||||
#if IK_PRINT_TIMING == 1
|
||||
tim1 = ggml_time_us();
|
||||
#endif
|
||||
@@ -3392,17 +3476,21 @@ static int llama_decode_internal(
|
||||
#if IK_PRINT_TIMING
|
||||
tim1 = ggml_time_us();
|
||||
#endif
|
||||
ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(lctx.sched, res);
|
||||
GGML_ASSERT(backend_res != nullptr);
|
||||
GGML_ASSERT(lctx.logits != nullptr);
|
||||
// Do not process logits if MTP is only updating the KV cache.
|
||||
if (cparams.mtp_op_type != MTP_OP_WARMUP &&
|
||||
cparams.mtp_op_type != MTP_OP_UPDATE_ACCEPTED) {
|
||||
ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(lctx.sched, res);
|
||||
GGML_ASSERT(backend_res != nullptr);
|
||||
GGML_ASSERT(lctx.logits != nullptr);
|
||||
|
||||
float * logits_out = lctx.logits + n_outputs_prev*n_vocab;
|
||||
const int32_t n_outputs_new = lctx.n_outputs;
|
||||
float * logits_out = lctx.logits + n_outputs_prev*n_vocab;
|
||||
const int32_t n_outputs_new = lctx.n_outputs;
|
||||
|
||||
if (n_outputs_new) {
|
||||
GGML_ASSERT( n_outputs_prev + n_outputs_new <= n_outputs);
|
||||
GGML_ASSERT((n_outputs_prev + n_outputs_new)*n_vocab <= (int64_t) lctx.logits_size);
|
||||
ggml_backend_tensor_get_async(backend_res, res, logits_out, 0, n_outputs_new*n_vocab*sizeof(float));
|
||||
if (n_outputs_new) {
|
||||
GGML_ASSERT( n_outputs_prev + n_outputs_new <= n_outputs);
|
||||
GGML_ASSERT((n_outputs_prev + n_outputs_new)*n_vocab <= (int64_t) lctx.logits_size);
|
||||
ggml_backend_tensor_get_async(backend_res, res, logits_out, 0, n_outputs_new*n_vocab*sizeof(float));
|
||||
}
|
||||
}
|
||||
#if IK_PRINT_TIMING
|
||||
tim2 = ggml_time_us();
|
||||
@@ -3411,7 +3499,7 @@ static int llama_decode_internal(
|
||||
}
|
||||
|
||||
// extract embeddings
|
||||
if (embd) {
|
||||
if (embd && cparams.mtp_op_type == MTP_OP_NONE) {
|
||||
#if IK_PRINT_TIMING
|
||||
tim1 = ggml_time_us();
|
||||
#endif
|
||||
@@ -3617,57 +3705,59 @@ static int llama_encode_internal(
|
||||
|
||||
// extract embeddings
|
||||
if (embd) {
|
||||
ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(lctx.sched, embd);
|
||||
GGML_ASSERT(backend_embd != nullptr);
|
||||
if (cparams.mtp_op_type == MTP_OP_NONE) {
|
||||
ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(lctx.sched, embd);
|
||||
GGML_ASSERT(backend_embd != nullptr);
|
||||
|
||||
if (llama_model_has_decoder(&lctx.model)) {
|
||||
lctx.embd_enc.resize(n_tokens*n_embd);
|
||||
float * embd_out = lctx.embd_enc.data();
|
||||
if (llama_model_has_decoder(&lctx.model)) {
|
||||
lctx.embd_enc.resize(n_tokens*n_embd);
|
||||
float * embd_out = lctx.embd_enc.data();
|
||||
|
||||
ggml_backend_tensor_get_async(backend_embd, embd, embd_out, 0, n_tokens*n_embd*sizeof(float));
|
||||
ggml_backend_tensor_get_async(backend_embd, embd, embd_out, 0, n_tokens*n_embd*sizeof(float));
|
||||
|
||||
// remember the sequence ids used during the encoding - needed for cross attention later
|
||||
lctx.seq_ids_enc.resize(n_tokens);
|
||||
for (uint32_t i = 0; i < n_tokens; i++) {
|
||||
for (int s = 0; s < batch.n_seq_id[i]; s++) {
|
||||
llama_seq_id seq_id = batch.seq_id[i][s];
|
||||
lctx.seq_ids_enc[i].insert(seq_id);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
GGML_ASSERT(lctx.embd != nullptr);
|
||||
|
||||
switch (cparams.pooling_type) {
|
||||
case LLAMA_POOLING_TYPE_NONE:
|
||||
{
|
||||
// extract token embeddings
|
||||
GGML_ASSERT(lctx.embd != nullptr);
|
||||
float * embd_out = lctx.embd;
|
||||
|
||||
GGML_ASSERT(n_tokens*n_embd <= (int64_t) lctx.embd_size);
|
||||
ggml_backend_tensor_get_async(backend_embd, embd, embd_out, 0, n_tokens*n_embd*sizeof(float));
|
||||
} break;
|
||||
case LLAMA_POOLING_TYPE_MEAN:
|
||||
case LLAMA_POOLING_TYPE_CLS:
|
||||
case LLAMA_POOLING_TYPE_LAST:
|
||||
{
|
||||
// extract sequence embeddings
|
||||
auto & embd_seq_out = lctx.embd_seq;
|
||||
embd_seq_out.clear();
|
||||
|
||||
for (uint32_t i = 0; i < n_tokens; i++) {
|
||||
const llama_seq_id seq_id = batch.seq_id[i][0];
|
||||
if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
|
||||
continue;
|
||||
}
|
||||
embd_seq_out[seq_id].resize(n_embd);
|
||||
ggml_backend_tensor_get_async(backend_embd, embd, embd_seq_out[seq_id].data(), (n_embd*seq_id)*sizeof(float), n_embd*sizeof(float));
|
||||
}
|
||||
} break;
|
||||
case LLAMA_POOLING_TYPE_UNSPECIFIED:
|
||||
{
|
||||
GGML_ABORT("unknown pooling type");
|
||||
// remember the sequence ids used during the encoding - needed for cross attention later
|
||||
lctx.seq_ids_enc.resize(n_tokens);
|
||||
for (uint32_t i = 0; i < n_tokens; i++) {
|
||||
for (int s = 0; s < batch.n_seq_id[i]; s++) {
|
||||
llama_seq_id seq_id = batch.seq_id[i][s];
|
||||
lctx.seq_ids_enc[i].insert(seq_id);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
GGML_ASSERT(lctx.embd != nullptr);
|
||||
|
||||
switch (cparams.pooling_type) {
|
||||
case LLAMA_POOLING_TYPE_NONE:
|
||||
{
|
||||
// extract token embeddings
|
||||
GGML_ASSERT(lctx.embd != nullptr);
|
||||
float * embd_out = lctx.embd;
|
||||
|
||||
GGML_ASSERT(n_tokens*n_embd <= (int64_t) lctx.embd_size);
|
||||
ggml_backend_tensor_get_async(backend_embd, embd, embd_out, 0, n_tokens*n_embd*sizeof(float));
|
||||
} break;
|
||||
case LLAMA_POOLING_TYPE_MEAN:
|
||||
case LLAMA_POOLING_TYPE_CLS:
|
||||
case LLAMA_POOLING_TYPE_LAST:
|
||||
{
|
||||
// extract sequence embeddings
|
||||
auto & embd_seq_out = lctx.embd_seq;
|
||||
embd_seq_out.clear();
|
||||
|
||||
for (uint32_t i = 0; i < n_tokens; i++) {
|
||||
const llama_seq_id seq_id = batch.seq_id[i][0];
|
||||
if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
|
||||
continue;
|
||||
}
|
||||
embd_seq_out[seq_id].resize(n_embd);
|
||||
ggml_backend_tensor_get_async(backend_embd, embd, embd_seq_out[seq_id].data(), (n_embd*seq_id)*sizeof(float), n_embd*sizeof(float));
|
||||
}
|
||||
} break;
|
||||
case LLAMA_POOLING_TYPE_UNSPECIFIED:
|
||||
{
|
||||
GGML_ABORT("unknown pooling type");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -4223,6 +4313,7 @@ struct llama_model_params llama_model_default_params() {
|
||||
/*.validate_quants =*/ false,
|
||||
/*.merge_qkv =*/ false,
|
||||
/*.merge_up_gate_exps =*/ false,
|
||||
/*.mtp =*/ false,
|
||||
};
|
||||
|
||||
#ifdef GGML_USE_METAL
|
||||
@@ -4278,6 +4369,8 @@ struct llama_context_params llama_context_default_params() {
|
||||
/*.split_mode_graph_scheduling =*/ false,
|
||||
// /*.split_mode_f16 =*/ true,
|
||||
/*.scheduler_async =*/ false,
|
||||
/*.mtp =*/ false,
|
||||
/*.mtp_op_type =*/ MTP_OP_NONE,
|
||||
/*.abort_callback =*/ nullptr,
|
||||
/*.abort_callback_data =*/ nullptr,
|
||||
/*.offload_policy =*/ nullptr,
|
||||
@@ -4648,6 +4741,7 @@ struct llama_context * llama_init_from_model(
|
||||
cparams.min_experts = params.min_experts;
|
||||
cparams.thresh_experts = params.thresh_experts;
|
||||
cparams.cuda_params = params.cuda_params;
|
||||
cparams.mtp = params.mtp;
|
||||
|
||||
cparams.reduce_type = params.type_reduce;
|
||||
cparams.pooling_type = params.pooling_type;
|
||||
@@ -4725,6 +4819,12 @@ struct llama_context * llama_init_from_model(
|
||||
}
|
||||
}
|
||||
|
||||
if (model->arch != LLM_ARCH_GLM4_MOE && cparams.mtp != 0) {
|
||||
cparams.mtp = 0;
|
||||
}
|
||||
|
||||
cparams.mtp_op_type = params.mtp_op_type;
|
||||
|
||||
LLAMA_LOG_INFO("%s: n_ctx = %u\n", __func__, cparams.n_ctx);
|
||||
LLAMA_LOG_INFO("%s: n_batch = %u\n", __func__, cparams.n_batch);
|
||||
LLAMA_LOG_INFO("%s: n_ubatch = %u\n", __func__, cparams.n_ubatch);
|
||||
@@ -6058,7 +6158,7 @@ struct llama_data_read {
|
||||
batch.n_seq_id[i] = 1;
|
||||
batch.seq_id[i][0] = dest_seq_id;
|
||||
}
|
||||
if (!llama_kv_cache_find_slot(kv_self, batch)) {
|
||||
if (!llama_kv_cache_find_slot(kv_self, batch, ctx->cparams.mtp_op_type)) {
|
||||
llama_batch_free(batch);
|
||||
LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__);
|
||||
return false;
|
||||
@@ -7003,6 +7103,10 @@ int32_t llama_decode(
|
||||
return ret;
|
||||
}
|
||||
|
||||
// Forwards the requested MTP operation type to the context's
// set_mtp_op_type() member.
void llama_set_mtp_op_type(llama_context * ctx, llama_mtp_op_type mtp_op_type) {
    llama_context & target = *ctx;
    target.set_mtp_op_type(mtp_op_type);
}
|
||||
|
||||
void llama_synchronize(struct llama_context * ctx) {
|
||||
ggml_backend_sched_synchronize(ctx->sched);
|
||||
|
||||
@@ -8333,3 +8437,8 @@ void llama_set_offload_policy(struct llama_context * lctx, int op, bool on_or_of
|
||||
printf("XXXXXXXXXXXXXXXXXXXXXXXXXXXX offload(%s) = %d\n", op_name, on_or_off);
|
||||
ggml_backend_sched_set_op_offload(lctx->sched, ggml_op(op), on_or_off);
|
||||
}
|
||||
|
||||
// Stores the given pointer in ctx->draft_input_hidden_state. Only the pointer
// is kept; the data it points to is not copied.
void llama_set_draft_input_hidden_state(struct llama_context * ctx, const float * hidden_state) {
    struct llama_context & target = *ctx;
    target.draft_input_hidden_state = hidden_state;
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user