mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-05-11 08:30:19 +00:00
MTP tweaks (#1741)
This commit is contained in:
@@ -1472,13 +1472,19 @@ void mtp_update_kv_cache(struct llama_context * ctx, const llama_batch& batch, b
|
||||
llama_batch mtp_batch = batch;
|
||||
if (is_prompt_warmup) {
|
||||
llama_set_mtp_op_type(ctx, MTP_OP_WARMUP);
|
||||
// We don't need the logits when doing warmup
|
||||
for (int i = 0; i < mtp_batch.n_tokens; ++i) {
|
||||
mtp_batch.logits[i] = false;
|
||||
}
|
||||
// This is just in case to not run into empty tensor issues
|
||||
mtp_batch.logits[mtp_batch.n_tokens-1] = true;
|
||||
} else {
|
||||
llama_set_mtp_op_type(ctx, MTP_OP_UPDATE_ACCEPTED);
|
||||
for (int i = 0; i < mtp_batch.n_tokens; ++i) {
|
||||
mtp_batch.logits[i] = true;
|
||||
}
|
||||
}
|
||||
|
||||
for (int i = 0; i < mtp_batch.n_tokens; ++i) {
|
||||
mtp_batch.logits[i] = true;
|
||||
}
|
||||
llama_decode(ctx, mtp_batch);
|
||||
llama_set_mtp_op_type(ctx, MTP_OP_NONE);
|
||||
}
|
||||
|
||||
@@ -128,9 +128,11 @@ ggml_cgraph * llm_build_context::build_qwen35() {
|
||||
}
|
||||
|
||||
if (lctx.cparams.mtp) {
|
||||
struct ggml_tensor * embd_copy = ggml_dup(ctx0, inpL);
|
||||
cb(embd_copy, "result_mtp_embd", -1);
|
||||
ggml_set_output(embd_copy);
|
||||
//struct ggml_tensor * embd_copy = ggml_dup(ctx0, inpL);
|
||||
//cb(embd_copy, "result_mtp_embd", -1);
|
||||
//ggml_set_output(embd_copy);
|
||||
cb(inpL, "result_mtp_embd", -1);
|
||||
ggml_set_output(inpL);
|
||||
}
|
||||
|
||||
cur = build_output(lctx, ctx0, inpL, model.output, model.output_norm, cb);
|
||||
@@ -153,7 +155,7 @@ struct ggml_tensor * llm_build_context::build_qwen35_mtp(
|
||||
|
||||
struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
|
||||
|
||||
struct ggml_tensor * inp_out_ids = (n_outputs < n_tokens) ? build_inp_out_ids() : nullptr;
|
||||
struct ggml_tensor * inp_out_ids = (n_tokens > 1 && n_outputs < n_tokens) ? build_inp_out_ids() : nullptr;
|
||||
|
||||
ggml_tensor * token_emb = build_inp_embd_mtp(model.tok_embd);
|
||||
|
||||
@@ -210,10 +212,12 @@ struct ggml_tensor * llm_build_context::build_qwen35_mtp(
|
||||
cur = lctx.cvec.apply_to(ctx0, cur, il);
|
||||
cb(cur, "ffn_out", il);
|
||||
|
||||
cur = llm_build_norm(ctx0, cur, hparams, mtp_layer.nextn.shared_head_norm, NULL, LLM_NORM_RMS, cb, il);
|
||||
// As far as I can tell this was wrong. We need the FFN output, and not the normalized result.
|
||||
//cur = llm_build_norm(ctx0, cur, hparams, mtp_layer.nextn.shared_head_norm, NULL, LLM_NORM_RMS, cb, il);
|
||||
cb(cur, "result_norm", -1);
|
||||
|
||||
cur = build_output(lctx, ctx0, cur, model.output, nullptr, cb);
|
||||
//cur = build_output(lctx, ctx0, cur, model.output, nullptr, cb);
|
||||
cur = build_output(lctx, ctx0, cur, model.output, mtp_layer.nextn.shared_head_norm, cb);
|
||||
cb(cur, "result_output", -1);
|
||||
|
||||
return cur;
|
||||
|
||||
@@ -70,7 +70,8 @@ delta_net::delta_net(llama_context & _lctx, const llama_batch & _batch) : lctx(_
|
||||
GGML_ASSERT((uint32_t) s < qnext_state_slots);
|
||||
}
|
||||
|
||||
save_per_step_states = lctx.kv_self.save_per_step_ssm && batch.n_tokens > 1;
|
||||
int max_per_step = lctx.kv_self.save_per_step_ssm ? std::min<int>(8, lctx.kv_self.ckpt.per_step_max_allocated) : 0;
|
||||
save_per_step_states = lctx.kv_self.save_per_step_ssm && batch.n_tokens > 1 && batch.n_tokens <= max_per_step;
|
||||
}
|
||||
|
||||
delta_net::~delta_net() = default;
|
||||
|
||||
@@ -4165,14 +4165,22 @@ static size_t llama_output_reserve(llama_context & lctx, size_t n_outputs) {
|
||||
// set all ids as invalid (negative)
|
||||
std::fill(lctx.output_ids.begin(), lctx.output_ids.end(), -1);
|
||||
|
||||
if (has_mtp) {
|
||||
// MTP uses a large output footprint, clear only the active region.
|
||||
const size_t clear_size = (logits_size + embd_size) * sizeof(float);
|
||||
if (clear_size > 0 && output_base) {
|
||||
memset(output_base, 0, clear_size);
|
||||
if (false) {
|
||||
// What is the purpose of clearing the output buffer?
|
||||
// When we are getting embeddings for models with large vocabularies this
|
||||
// costs a non-negligible amount of time.
|
||||
// The output buffer will get populated with meaningful results in llama_decode
|
||||
// If it doesn't, the solution is not to just blindly zero the buffer
|
||||
// but to fix the bug that causes meaningless results.
|
||||
if (has_mtp) {
|
||||
// MTP uses a large output footprint, clear only the active region.
|
||||
const size_t clear_size = (logits_size + embd_size) * sizeof(float);
|
||||
if (clear_size > 0 && output_base) {
|
||||
memset(output_base, 0, clear_size);
|
||||
}
|
||||
} else {
|
||||
ggml_backend_buffer_clear(lctx.buf_output, 0);
|
||||
}
|
||||
} else {
|
||||
ggml_backend_buffer_clear(lctx.buf_output, 0);
|
||||
}
|
||||
|
||||
lctx.n_outputs = 0;
|
||||
@@ -6390,6 +6398,38 @@ struct llama_context * llama_init_from_model(
|
||||
}
|
||||
}
|
||||
|
||||
if (cparams.mtp && hparams.nextn_predict_layers > 0) {
|
||||
const auto n_batch = cparams.n_batch;
|
||||
const auto n_vocab = hparams.n_vocab;
|
||||
const auto n_embd = hparams.n_embd;
|
||||
|
||||
const size_t logits_size = n_vocab*n_batch;
|
||||
const size_t embd_size = n_embd*n_batch;
|
||||
|
||||
if (ctx->output_ids.empty()) {
|
||||
// init, never resized afterwards
|
||||
ctx->output_ids.resize(n_batch);
|
||||
}
|
||||
|
||||
const size_t prev_size = ctx->buf_output ? ggml_backend_buffer_get_size(ctx->buf_output) : 0;
|
||||
const size_t new_size = (logits_size + embd_size) * sizeof(float);
|
||||
|
||||
// alloc only when more than the current capacity is required
|
||||
if (!ctx->buf_output || prev_size < new_size) {
|
||||
if (ctx->buf_output) {
|
||||
ggml_backend_buffer_free(ctx->buf_output);
|
||||
ctx->buf_output = nullptr;
|
||||
ctx->logits = nullptr;
|
||||
ctx->embd = nullptr;
|
||||
}
|
||||
|
||||
ctx->buf_output = ggml_backend_buft_alloc_buffer(llama_default_buffer_type_cpu(true), new_size);
|
||||
if (ctx->buf_output == nullptr) {
|
||||
LLAMA_LOG_ERROR("%s: failed to allocate output buffer of size %.2f MiB\n", __func__, new_size / (1024.0 * 1024.0));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return ctx;
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user