MTP tweaks (#1741)

Kawrakow authored on 2026-05-06 08:35:11 +03:00 · committed by GitHub
parent 8b56d813a9 · commit e722f0bb73

4 changed files with 68 additions and 17 deletions


```diff
@@ -1472,13 +1472,19 @@ void mtp_update_kv_cache(struct llama_context * ctx, const llama_batch& batch, b
     llama_batch mtp_batch = batch;
     if (is_prompt_warmup) {
         llama_set_mtp_op_type(ctx, MTP_OP_WARMUP);
+        // We don't need the logits when doing warmup
+        for (int i = 0; i < mtp_batch.n_tokens; ++i) {
+            mtp_batch.logits[i] = false;
+        }
+        // This is just in case to not run into empty tensor issues
+        mtp_batch.logits[mtp_batch.n_tokens-1] = true;
     } else {
         llama_set_mtp_op_type(ctx, MTP_OP_UPDATE_ACCEPTED);
+        for (int i = 0; i < mtp_batch.n_tokens; ++i) {
+            mtp_batch.logits[i] = true;
+        }
     }
-    for (int i = 0; i < mtp_batch.n_tokens; ++i) {
-        mtp_batch.logits[i] = true;
-    }
     llama_decode(ctx, mtp_batch);
     llama_set_mtp_op_type(ctx, MTP_OP_NONE);
 }
```
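For context: the warmup pass only exists to populate the MTP layer's KV cache, so the n_vocab-wide output projection can be skipped for almost every token. The pattern generalizes beyond MTP; a minimal sketch using the standard `llama_batch` fields (the helper name is hypothetical, not from the commit):

```cpp
// Request logits for only the final token of a batch. Rows with
// logits[i] == false never reach the output projection, and keeping
// exactly one row avoids zero-size output tensors downstream.
static void request_last_logits_only(llama_batch & batch) {
    for (int32_t i = 0; i < batch.n_tokens; ++i) {
        batch.logits[i] = false;
    }
    if (batch.n_tokens > 0) {
        batch.logits[batch.n_tokens - 1] = true;
    }
}
```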


```diff
@@ -128,9 +128,11 @@ ggml_cgraph * llm_build_context::build_qwen35() {
     }
     if (lctx.cparams.mtp) {
-        struct ggml_tensor * embd_copy = ggml_dup(ctx0, inpL);
-        cb(embd_copy, "result_mtp_embd", -1);
-        ggml_set_output(embd_copy);
+        //struct ggml_tensor * embd_copy = ggml_dup(ctx0, inpL);
+        //cb(embd_copy, "result_mtp_embd", -1);
+        //ggml_set_output(embd_copy);
+        cb(inpL, "result_mtp_embd", -1);
+        ggml_set_output(inpL);
     }
     cur = build_output(lctx, ctx0, inpL, model.output, model.output_norm, cb);
```
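The point of this change: `ggml_dup` materialized a full `n_tokens × n_embd` copy of `inpL` just to have a named, flagged tensor, while calling `ggml_set_output` on `inpL` itself achieves the same thing for free (the output flag stops the graph allocator from reusing the tensor's memory before it is read). How the tagged tensor is presumably consumed later, sketched with the stock ggml API (`gf` is the built graph; this lookup is an assumption, not code from the commit):

```cpp
// The name set via cb() is what makes the tensor retrievable after the
// graph is built; the output flag guarantees its data survives execution.
struct ggml_tensor * mtp_embd = ggml_graph_get_tensor(gf, "result_mtp_embd");
// mtp_embd now aliases inpL directly: one n_tokens x n_embd copy per graph
// build saved relative to the ggml_dup() version.
```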
```diff
@@ -153,7 +155,7 @@ struct ggml_tensor * llm_build_context::build_qwen35_mtp(
     struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
-    struct ggml_tensor * inp_out_ids = (n_outputs < n_tokens) ? build_inp_out_ids() : nullptr;
+    struct ggml_tensor * inp_out_ids = (n_tokens > 1 && n_outputs < n_tokens) ? build_inp_out_ids() : nullptr;
     ggml_tensor * token_emb = build_inp_embd_mtp(model.tok_embd);
```
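`inp_out_ids` exists so that only rows whose logits were requested get gathered before the output head. A sketch of the usual llama.cpp pattern it feeds (assumed, not copied from this file):

```cpp
// Gather the requested rows: [n_embd, n_tokens] -> [n_embd, n_outputs].
if (inp_out_ids) {
    cur = ggml_get_rows(ctx0, cur, inp_out_ids);
}
```

With `n_tokens == 1` the single row is trivially the output, so the added `n_tokens > 1` guard skips building the extra input tensor for plain single-token decode.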
```diff
@@ -210,10 +212,12 @@ struct ggml_tensor * llm_build_context::build_qwen35_mtp(
     cur = lctx.cvec.apply_to(ctx0, cur, il);
     cb(cur, "ffn_out", il);
-    cur = llm_build_norm(ctx0, cur, hparams, mtp_layer.nextn.shared_head_norm, NULL, LLM_NORM_RMS, cb, il);
+    // As far as I can tell this was wrong. We need the FFN output, and not the normalized result.
+    //cur = llm_build_norm(ctx0, cur, hparams, mtp_layer.nextn.shared_head_norm, NULL, LLM_NORM_RMS, cb, il);
     cb(cur, "result_norm", -1);
-    cur = build_output(lctx, ctx0, cur, model.output, nullptr, cb);
+    //cur = build_output(lctx, ctx0, cur, model.output, nullptr, cb);
+    cur = build_output(lctx, ctx0, cur, model.output, mtp_layer.nextn.shared_head_norm, cb);
     cb(cur, "result_output", -1);
     return cur;
```
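Note that the logits math is unchanged by this hunk: the shared-head RMS norm still runs, just inside `build_output` instead of before it. What moves is the tensor tagged `result_norm`, which is now the raw FFN output rather than the normalized one. The effective expansion, assuming `build_output` applies the supplied norm before the projection as the main graph's head does:

```cpp
// Illustrative expansion of the new tail (names from the surrounding code).
cb(cur, "result_norm", -1);                    // raw ffn_out, fed back as the MTP hidden state
struct ggml_tensor * h = llm_build_norm(ctx0, cur, hparams,
        mtp_layer.nextn.shared_head_norm, NULL, LLM_NORM_RMS, cb, il);
cur = ggml_mul_mat(ctx0, model.output, h);     // same logits as before the change
cb(cur, "result_output", -1);
```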


```diff
@@ -70,7 +70,8 @@ delta_net::delta_net(llama_context & _lctx, const llama_batch & _batch) : lctx(_
         GGML_ASSERT((uint32_t) s < qnext_state_slots);
     }
-    save_per_step_states = lctx.kv_self.save_per_step_ssm && batch.n_tokens > 1;
+    int max_per_step = lctx.kv_self.save_per_step_ssm ? std::min<int>(8, lctx.kv_self.ckpt.per_step_max_allocated) : 0;
+    save_per_step_states = lctx.kv_self.save_per_step_ssm && batch.n_tokens > 1 && batch.n_tokens <= max_per_step;
 }

 delta_net::~delta_net() = default;
```
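Each per-step state is a full delta-net recurrent state, so saving one per token is only affordable for the short batches an MTP draft step produces, not for prompt processing. A standalone restatement of the new gate (the cap of 8 and the `per_step_max_allocated` bound are from the diff; the helper itself is hypothetical):

```cpp
#include <algorithm>

static bool want_per_step_states(bool save_per_step_ssm, int n_tokens,
                                 int per_step_max_allocated) {
    // Never save more states than were allocated, and never more than 8.
    const int max_per_step = save_per_step_ssm
        ? std::min(8, per_step_max_allocated) : 0;
    // Multi-token batches only, and only small ones (e.g. MTP drafting);
    // a large prompt batch would otherwise save one state per token.
    return save_per_step_ssm && n_tokens > 1 && n_tokens <= max_per_step;
}
```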


```diff
@@ -4165,14 +4165,22 @@ static size_t llama_output_reserve(llama_context & lctx, size_t n_outputs) {
     // set all ids as invalid (negative)
     std::fill(lctx.output_ids.begin(), lctx.output_ids.end(), -1);

-    if (has_mtp) {
-        // MTP uses a large output footprint, clear only the active region.
-        const size_t clear_size = (logits_size + embd_size) * sizeof(float);
-        if (clear_size > 0 && output_base) {
-            memset(output_base, 0, clear_size);
-        }
-    } else {
-        ggml_backend_buffer_clear(lctx.buf_output, 0);
-    }
+    if (false) {
+        // What is the purpose of clearing the output buffer?
+        // When we are getting embeddings for models with large vocabularies this
+        // costs a non-negligible amount of time.
+        // The output buffer will get populated with meaningful results in llama_decode.
+        // If it doesn't, the solution is not to just blindly zero the buffer
+        // but to fix the bug that causes meaningless results.
+        if (has_mtp) {
+            // MTP uses a large output footprint, clear only the active region.
+            const size_t clear_size = (logits_size + embd_size) * sizeof(float);
+            if (clear_size > 0 && output_base) {
+                memset(output_base, 0, clear_size);
+            }
+        } else {
+            ggml_backend_buffer_clear(lctx.buf_output, 0);
+        }
+    }

     lctx.n_outputs = 0;
```
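To put a rough number on the cost being disabled (illustrative figures, not from the commit): with a 150k-entry vocabulary and `n_batch = 2048`, the logits region alone is 150000 × 2048 × 4 B ≈ 1.2 GB, so zeroing the output buffer on every `llama_output_reserve` call is far from free, which is exactly the argument the new comment makes.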
```diff
@@ -6390,6 +6398,38 @@ struct llama_context * llama_init_from_model(
         }
     }

+    if (cparams.mtp && hparams.nextn_predict_layers > 0) {
+        const auto n_batch = cparams.n_batch;
+        const auto n_vocab = hparams.n_vocab;
+        const auto n_embd  = hparams.n_embd;
+        const size_t logits_size = n_vocab*n_batch;
+        const size_t embd_size   = n_embd*n_batch;
+        if (ctx->output_ids.empty()) {
+            // init, never resized afterwards
+            ctx->output_ids.resize(n_batch);
+        }
+        const size_t prev_size = ctx->buf_output ? ggml_backend_buffer_get_size(ctx->buf_output) : 0;
+        const size_t new_size  = (logits_size + embd_size) * sizeof(float);
+        // alloc only when more than the current capacity is required
+        if (!ctx->buf_output || prev_size < new_size) {
+            if (ctx->buf_output) {
+                ggml_backend_buffer_free(ctx->buf_output);
+                ctx->buf_output = nullptr;
+                ctx->logits = nullptr;
+                ctx->embd   = nullptr;
+            }
+            ctx->buf_output = ggml_backend_buft_alloc_buffer(llama_default_buffer_type_cpu(true), new_size);
+            if (ctx->buf_output == nullptr) {
+                LLAMA_LOG_ERROR("%s: failed to allocate output buffer of size %.2f MiB\n", __func__, new_size / (1024.0 * 1024.0));
+            }
+        }
+    }
+
     return ctx;
 }
```
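This pre-allocation sizes the output buffer for the worst case once, at context creation, instead of letting `llama_output_reserve` grow it in the middle of decoding. A worked size example with rough Qwen-like numbers (illustrative values, not from the commit):

```cpp
// new_size = (logits + embeddings) for a full batch, as in the code above.
const size_t n_vocab = 152064, n_embd = 4096, n_batch = 512;  // example values
const size_t new_size = (n_vocab*n_batch + n_embd*n_batch) * sizeof(float);
// (152064 + 4096) * 512 * 4 bytes ~= 305 MiB, host-allocated once up front.
```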