server : support MTP with multimodal prompts (#1758)

Synchronize MTP state after mtmd decode batches so multimodal prompt chunks do not desync the draft context.
This commit is contained in:
Lingfeng Ren
2026-05-10 23:51:07 -07:00
committed by GitHub
parent 23127139cb
commit 35845dd975
8 changed files with 185 additions and 73 deletions

View File

@@ -115,14 +115,15 @@ static void apply_slot_mtp_accept(
return;
}
llama_context * mtp_ctx = get_slot_mtp_ctx(slot, ctx);
if (slot.use_gemma4_external_mtp) {
cache_and_sync_slot_mtp_hidden_from_rows(slot, ctx, mtp_hidden_state, n_embd);
return;
}
slot.mtp_hidden_state = mtp_hidden_state;
llama_set_draft_input_hidden_state(get_slot_mtp_ctx(slot, ctx), slot.mtp_hidden_state.data());
mtp_accept_tokens(get_slot_mtp_ctx(slot, ctx), ids, mtp_n_past_base, slot.id);
llama_set_draft_input_hidden_state(mtp_ctx, slot.mtp_hidden_state.data());
mtp_accept_tokens(mtp_ctx, ids, mtp_n_past_base, slot.id);
}
static void set_external_mtp_hidden(server_slot & slot, llama_context * ctx, const float * hidden, int n_embd) {
@@ -133,8 +134,53 @@ static void set_external_mtp_hidden(server_slot & slot, llama_context * ctx, con
cache_and_sync_slot_mtp_hidden(slot, ctx, hidden, n_embd);
}
static void set_external_mtp_hidden_from_rows(server_slot & slot, llama_context * ctx, const std::vector<float> & rows, int n_embd) {
cache_and_sync_slot_mtp_hidden_from_rows(slot, ctx, rows, n_embd);
struct server_mtp_warmup {
llama_context * ctx_tgt;
server_slot * slot;
};
static int32_t server_mtp_warmup_batch(
llama_context * ctx_tgt,
llama_context * ctx_mtp,
const llama_batch * batch,
server_slot & slot) {
if (!ctx_tgt || !ctx_mtp || !batch || batch->n_tokens <= 0) {
return 0;
}
const float * emb = llama_get_embeddings(ctx_tgt);
const int n_embd_src = get_ctx_mtp_n_embd(ctx_tgt);
const int n_embd_dst = get_ctx_mtp_n_embd(ctx_mtp);
if (emb == nullptr || n_embd_src <= 0 || n_embd_dst <= 0) {
return -1;
}
if (n_embd_src != n_embd_dst) {
LOG_ERROR("MTP warmup hidden state width mismatch", {
{"n_embd_src", n_embd_src},
{"n_embd_dst", n_embd_dst},
});
return -1;
}
const float * last_hidden = emb + (batch->n_tokens - 1) * n_embd_src;
if (slot.use_gemma4_external_mtp) {
cache_and_sync_slot_mtp_hidden(slot, ctx_tgt, last_hidden, n_embd_dst);
return 0;
}
cache_slot_mtp_hidden(slot, last_hidden, n_embd_dst);
llama_set_draft_input_hidden_state(ctx_mtp, emb);
return mtp_update_kv_cache(ctx_mtp, *batch, true);
}
static int32_t server_mtp_media_warmup_callback(void * user_data, const llama_batch * batch) {
auto * data = static_cast<server_mtp_warmup *>(user_data);
if (data == nullptr || data->slot == nullptr) {
return 0;
}
return server_mtp_warmup_batch(data->ctx_tgt, get_slot_mtp_ctx(*data->slot, data->ctx_tgt), batch, *data->slot);
}
void server_speculative_checkpoint::clear() {
@@ -156,7 +202,8 @@ static void discard_speculative_checkpoint(server_slot & slot, llama_context * c
static bool save_speculative_checkpoint(server_slot & slot, llama_model * model, llama_context * ctx, int ckpt_mode) {
slot.spec_ckpt.clear();
slot.spec_ckpt.n_past = slot.n_past - (int32_t)(slot.drafted.size() + 1);
const int32_t n_pre_spec_tokens = slot.cache_tokens.n_tokens() - (int32_t)(slot.drafted.size() + 1);
slot.spec_ckpt.n_past = slot.cache_tokens.pos_next(n_pre_spec_tokens);
slot.spec_ckpt.sampled = slot.sampled;
const int max_tokens = (int)slot.drafted.size() + 1;
@@ -266,7 +313,8 @@ bool server_context::load_model(const gpt_params& params_) {
LOG_ERROR("%s\n", "err: speculative decode is not supported by multimodal");
return false;
}
if (params_base.speculative.type != COMMON_SPECULATIVE_TYPE_NONE) {
if (params_base.speculative.type != COMMON_SPECULATIVE_TYPE_NONE &&
params_base.speculative.type != COMMON_SPECULATIVE_TYPE_MTP) {
params_base.speculative.type = COMMON_SPECULATIVE_TYPE_NONE;
SRV_WRN("%s\n", "speculative decoding is not supported by multimodal, it will be disabled");
}
@@ -417,7 +465,7 @@ void server_context::init() {
if (can_spec) {
slot.spec = common_speculative_init(params_base.speculative, slot.ctx);
if (slot.spec) {
if (mctx) {
if (mctx && !slot.has_mtp) {
SRV_ERR("%s\n", "speculative decoding is not supported with multimodal");
return;
}
@@ -3366,12 +3414,15 @@ void server_context::add_sampled_tokens() {
// perform the speculative drafting for all sequences at the same time in a single batch
const int n_draft_max_pre = slot.get_n_draft_max();
if (n_draft_max_pre > 0) {
if (mctx) {
if (mctx && !slot.has_mtp) {
// we should never reach this, as speculative is automatically disabled if mmproj is loaded
GGML_ABORT("not supported by multimodal");
}
const llama_tokens & cached_text_tokens = slot.cache_tokens.get_text_tokens();
static const llama_tokens empty_prompt;
const llama_tokens & cached_text_tokens = slot.has_mtp
? empty_prompt
: slot.cache_tokens.get_text_tokens();
auto & params_spec = slot.params.speculative;
@@ -3797,7 +3848,15 @@ void server_context::batch_pending_prompt(const int32_t n_ubatch, const int32_t
// process the image
size_t n_tokens_out = 0;
llama_pos p1 = slot.cache_tokens.pos_next() + slot.n_past_prompt - slot.n_past; // add offset to prompt
int32_t res = slot.prompt_tokens.process_chunk(ctx, mctx, slot.n_past_prompt, p1, slot.id, n_tokens_out);
server_mtp_warmup mtp_media_warmup {
ctx,
slot.has_mtp && slot.spec ? &slot : nullptr,
};
mtmd_helper_eval_batch_callback mtp_media_callback =
mtp_media_warmup.slot ? server_mtp_media_warmup_callback : nullptr;
int32_t res = slot.prompt_tokens.process_chunk(
ctx, mctx, slot.n_past_prompt, p1, slot.id, n_tokens_out,
mtp_media_callback, &mtp_media_warmup);
if (res != 0) {
LLAMA_LOG_ERROR("failed to process image, res = %d\n", res);
slot.release();
@@ -4000,8 +4059,9 @@ static void restore_speculative_checkpoint(
if (slot.use_gemma4_external_mtp) {
cache_and_sync_slot_mtp_hidden_from_rows(slot, ctx, slot.mtp_hidden_state, n_embd);
} else {
llama_set_draft_input_hidden_state(get_slot_mtp_ctx(slot, ctx), slot.mtp_hidden_state.data());
mtp_accept_tokens(get_slot_mtp_ctx(slot, ctx), ids, slot.spec_ckpt.n_past, slot.id);
llama_context * mtp_ctx = get_slot_mtp_ctx(slot, ctx);
llama_set_draft_input_hidden_state(mtp_ctx, slot.mtp_hidden_state.data());
mtp_accept_tokens(mtp_ctx, ids, slot.spec_ckpt.n_past, slot.id);
if (n_accepted > 1) {
memmove(slot.mtp_hidden_state.data(),
@@ -4063,7 +4123,8 @@ void server_context::speculative_decoding_accept() {
int32_t mtp_n_past_base = 0;
std::vector<float> mtp_hidden_state_pre;
if (slot.has_mtp) {
mtp_n_past_base = slot.n_past - (slot.drafted.size() + 1);
const int32_t n_pre_spec_tokens = slot.cache_tokens.n_tokens() - (int32_t)(slot.drafted.size() + 1);
mtp_n_past_base = slot.cache_tokens.pos_next(n_pre_spec_tokens);
const int n_embd = get_ctx_mtp_n_embd(ctx);
if (!ids.empty()) {
@@ -4101,7 +4162,9 @@ void server_context::speculative_decoding_accept() {
slot.cache_tokens.keep_first(slot.cache_tokens.n_tokens() - n_draft);
// add accepted tokens to the prompt
slot.cache_tokens.insert({ ids.begin(), ids.end() - 1 });
for (auto it = ids.begin(); it != ids.end() - 1; ++it) {
slot.cache_tokens.push_back(*it);
}
slot.sampled = ids.back(); // last accepted token
slot.n_past = slot.cache_tokens.n_tokens();
@@ -4484,42 +4547,24 @@ void server_context::process_batch_tokens(int32_t & n_batch) {
continue; // continue loop of n_batch
}
bool mtp_warmup_needed = false;
llama_context * batch_mtp_target = nullptr;
std::vector<float> batch_mtp_hidden_state;
server_slot * mtp_warmup_slot = nullptr;
if (params_base.has_mtp) {
for (auto & slot : slots) {
if (slot.spec && slot.has_mtp) {
llama_context * mc = common_speculative_get_mtp_ctx(slot.spec);
if (mc) {
batch_mtp_target = mc;
break;
}
}
}
for (auto& slot : slots) {
if ((slot.state == SLOT_STATE_PROCESSING && slot.n_decoded == 0) ||
(slot.state == SLOT_STATE_IDLE && slot.command == SLOT_COMMAND_LOAD_PROMPT)) {
bool has_tokens_for_slot = (batch_view.n_tokens > 0 && batch_view.n_seq_id[0] > 0 && batch_view.seq_id[0][0] == slot.id);
if (has_tokens_for_slot) {
mtp_warmup_needed = true;
mtp_warmup_slot = &slot;
break;
}
}
}
if (mtp_warmup_needed) {
llama_context * mtp_target = batch_mtp_target ? batch_mtp_target : ctx;
const int n_embd_src = get_ctx_mtp_n_embd(ctx);
const int n_embd_dst = get_ctx_mtp_n_embd(mtp_target);
const int n_toks = batch_view.n_tokens;
batch_mtp_hidden_state.assign(n_toks * n_embd_dst, 0.0f);
for (int t = 0; t < n_toks; t++) {
const float* emb_t = llama_get_embeddings_ith(ctx, t);
if (emb_t) {
const int n_copy = std::min(n_embd_src, n_embd_dst);
memcpy(batch_mtp_hidden_state.data() + t * n_embd_dst, emb_t, n_copy * sizeof(float));
}
}
}
if (mtp_warmup_slot && mtp_warmup_slot->spec && mtp_warmup_slot->has_mtp) {
llama_context * mtp_ctx = get_slot_mtp_ctx(*mtp_warmup_slot, ctx);
if (server_mtp_warmup_batch(ctx, mtp_ctx, &batch_view, *mtp_warmup_slot) != 0) {
LOG_ERROR("%s\n", "failed to warm up MTP state from prompt batch");
}
}
@@ -4547,7 +4592,11 @@ void server_context::process_batch_tokens(int32_t & n_batch) {
}
if (slot.n_decoded == 0 && slot.can_speculate()) {
common_speculative_begin(slot.spec, slot.cache_tokens.get_text_tokens());
static const llama_tokens empty_prompt;
const llama_tokens & spec_prompt = slot.has_mtp
? empty_prompt
: slot.cache_tokens.get_text_tokens();
common_speculative_begin(slot.spec, spec_prompt);
}
if (slot.i_batch_dft.size() > 0) {
@@ -4568,7 +4617,7 @@ void server_context::process_batch_tokens(int32_t & n_batch) {
completion_token_output result;
const int tok_idx = slot.i_batch - i;
if (params_base.has_mtp && slot.n_decoded == 0) {
if (slot.has_mtp && slot.n_decoded == 0) {
const float* emb_i = llama_get_embeddings_ith(ctx, tok_idx);
if (emb_i) {
const int n_embd = get_ctx_mtp_n_embd(ctx);
@@ -4641,20 +4690,6 @@ void server_context::process_batch_tokens(int32_t & n_batch) {
slot.i_batch = -1;
}
if (mtp_warmup_needed && !batch_mtp_hidden_state.empty()) {
if (params_use_gemma4_external_mtp(params_base)) {
for (auto & slot : slots) {
if (slot.spec && slot.has_mtp && !slot.mtp_hidden_state.empty()) {
sync_slot_mtp_hidden(slot, ctx);
}
}
} else {
llama_context * mtp_target = batch_mtp_target ? batch_mtp_target : ctx;
llama_set_draft_input_hidden_state(mtp_target, batch_mtp_hidden_state.data());
mtp_update_kv_cache(mtp_target, batch_view, true);
}
}
// speculative decoding - main model sample and accept
speculative_decoding_accept();
}