mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-05-24 22:59:14 +00:00
server : support MTP with multimodal prompts (#1758)
Synchronize MTP state after mtmd decode batches so multimodal prompt chunks do not desync the draft context.
This commit is contained in:
@@ -115,14 +115,15 @@ static void apply_slot_mtp_accept(
|
||||
return;
|
||||
}
|
||||
|
||||
llama_context * mtp_ctx = get_slot_mtp_ctx(slot, ctx);
|
||||
if (slot.use_gemma4_external_mtp) {
|
||||
cache_and_sync_slot_mtp_hidden_from_rows(slot, ctx, mtp_hidden_state, n_embd);
|
||||
return;
|
||||
}
|
||||
|
||||
slot.mtp_hidden_state = mtp_hidden_state;
|
||||
llama_set_draft_input_hidden_state(get_slot_mtp_ctx(slot, ctx), slot.mtp_hidden_state.data());
|
||||
mtp_accept_tokens(get_slot_mtp_ctx(slot, ctx), ids, mtp_n_past_base, slot.id);
|
||||
llama_set_draft_input_hidden_state(mtp_ctx, slot.mtp_hidden_state.data());
|
||||
mtp_accept_tokens(mtp_ctx, ids, mtp_n_past_base, slot.id);
|
||||
}
|
||||
|
||||
static void set_external_mtp_hidden(server_slot & slot, llama_context * ctx, const float * hidden, int n_embd) {
|
||||
@@ -133,8 +134,53 @@ static void set_external_mtp_hidden(server_slot & slot, llama_context * ctx, con
|
||||
cache_and_sync_slot_mtp_hidden(slot, ctx, hidden, n_embd);
|
||||
}
|
||||
|
||||
static void set_external_mtp_hidden_from_rows(server_slot & slot, llama_context * ctx, const std::vector<float> & rows, int n_embd) {
|
||||
cache_and_sync_slot_mtp_hidden_from_rows(slot, ctx, rows, n_embd);
|
||||
// User-data payload for server_mtp_media_warmup_callback: bundles the target
// context with the slot whose MTP state should be updated after a decoded batch.
struct server_mtp_warmup {
|
||||
// Context forwarded as the `ctx_tgt` argument of server_mtp_warmup_batch.
llama_context * ctx_tgt;
|
||||
// Slot to warm up; the callback returns 0 without doing anything when this is null.
server_slot * slot;
|
||||
};
|
||||
|
||||
// Pushes the hidden states produced by the target context for `batch` into the
// MTP state of `slot`.
//
// ctx_tgt : context whose embeddings buffer is read via llama_get_embeddings().
// ctx_mtp : context that receives the draft-input hidden state and KV update.
// batch   : the batch that was just decoded; only n_tokens is inspected here.
// slot    : slot whose cached MTP hidden state is refreshed.
//
// Returns 0 when there is nothing to do (null context/batch or empty batch) and
// after the external-MTP path; -1 when no embeddings are available or the hidden
// state widths are non-positive or differ between the two contexts; otherwise
// the result of mtp_update_kv_cache().
static int32_t server_mtp_warmup_batch(
        llama_context * ctx_tgt,
        llama_context * ctx_mtp,
        const llama_batch * batch,
        server_slot & slot) {
    if (!ctx_tgt || !ctx_mtp || !batch || batch->n_tokens <= 0) {
        return 0;
    }

    const float * emb = llama_get_embeddings(ctx_tgt);
    const int n_embd_src = get_ctx_mtp_n_embd(ctx_tgt);
    const int n_embd_dst = get_ctx_mtp_n_embd(ctx_mtp);
    if (emb == nullptr || n_embd_src <= 0 || n_embd_dst <= 0) {
        return -1;
    }

    if (n_embd_src != n_embd_dst) {
        LOG_ERROR("MTP warmup hidden state width mismatch", {
            {"n_embd_src", n_embd_src},
            {"n_embd_dst", n_embd_dst},
        });
        return -1;
    }

    // Compute the row offset in size_t: the product of two ints could overflow
    // (undefined behaviour) for very large batch * width combinations.
    // NOTE(review): assumes the embeddings buffer holds one row per batch token,
    // in batch order -- confirm against the decode path that filled it.
    const size_t last_row    = static_cast<size_t>(batch->n_tokens) - 1;
    const float * last_hidden = emb + last_row * static_cast<size_t>(n_embd_src);

    if (slot.use_gemma4_external_mtp) {
        // External MTP only needs the last row cached and synced; no KV update here.
        cache_and_sync_slot_mtp_hidden(slot, ctx_tgt, last_hidden, n_embd_dst);
        return 0;
    }

    cache_slot_mtp_hidden(slot, last_hidden, n_embd_dst);
    // The MTP context is handed the start of the embeddings buffer (all rows),
    // not just the last row.
    llama_set_draft_input_hidden_state(ctx_mtp, emb);
    return mtp_update_kv_cache(ctx_mtp, *batch, true);
}
|
||||
|
||||
static int32_t server_mtp_media_warmup_callback(void * user_data, const llama_batch * batch) {
|
||||
auto * data = static_cast<server_mtp_warmup *>(user_data);
|
||||
if (data == nullptr || data->slot == nullptr) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
return server_mtp_warmup_batch(data->ctx_tgt, get_slot_mtp_ctx(*data->slot, data->ctx_tgt), batch, *data->slot);
|
||||
}
|
||||
|
||||
void server_speculative_checkpoint::clear() {
|
||||
@@ -156,7 +202,8 @@ static void discard_speculative_checkpoint(server_slot & slot, llama_context * c
|
||||
|
||||
static bool save_speculative_checkpoint(server_slot & slot, llama_model * model, llama_context * ctx, int ckpt_mode) {
|
||||
slot.spec_ckpt.clear();
|
||||
slot.spec_ckpt.n_past = slot.n_past - (int32_t)(slot.drafted.size() + 1);
|
||||
const int32_t n_pre_spec_tokens = slot.cache_tokens.n_tokens() - (int32_t)(slot.drafted.size() + 1);
|
||||
slot.spec_ckpt.n_past = slot.cache_tokens.pos_next(n_pre_spec_tokens);
|
||||
slot.spec_ckpt.sampled = slot.sampled;
|
||||
|
||||
const int max_tokens = (int)slot.drafted.size() + 1;
|
||||
@@ -266,7 +313,8 @@ bool server_context::load_model(const gpt_params& params_) {
|
||||
LOG_ERROR("%s\n", "err: speculative decode is not supported by multimodal");
|
||||
return false;
|
||||
}
|
||||
if (params_base.speculative.type != COMMON_SPECULATIVE_TYPE_NONE) {
|
||||
if (params_base.speculative.type != COMMON_SPECULATIVE_TYPE_NONE &&
|
||||
params_base.speculative.type != COMMON_SPECULATIVE_TYPE_MTP) {
|
||||
params_base.speculative.type = COMMON_SPECULATIVE_TYPE_NONE;
|
||||
SRV_WRN("%s\n", "speculative decoding is not supported by multimodal, it will be disabled");
|
||||
}
|
||||
@@ -417,7 +465,7 @@ void server_context::init() {
|
||||
if (can_spec) {
|
||||
slot.spec = common_speculative_init(params_base.speculative, slot.ctx);
|
||||
if (slot.spec) {
|
||||
if (mctx) {
|
||||
if (mctx && !slot.has_mtp) {
|
||||
SRV_ERR("%s\n", "speculative decoding is not supported with multimodal");
|
||||
return;
|
||||
}
|
||||
@@ -3366,12 +3414,15 @@ void server_context::add_sampled_tokens() {
|
||||
// perform the speculative drafting for all sequences at the same time in a single batch
|
||||
const int n_draft_max_pre = slot.get_n_draft_max();
|
||||
if (n_draft_max_pre > 0) {
|
||||
if (mctx) {
|
||||
if (mctx && !slot.has_mtp) {
|
||||
// we should never reach this, as speculative is automatically disabled if mmproj is loaded
|
||||
GGML_ABORT("not supported by multimodal");
|
||||
}
|
||||
|
||||
const llama_tokens & cached_text_tokens = slot.cache_tokens.get_text_tokens();
|
||||
static const llama_tokens empty_prompt;
|
||||
const llama_tokens & cached_text_tokens = slot.has_mtp
|
||||
? empty_prompt
|
||||
: slot.cache_tokens.get_text_tokens();
|
||||
|
||||
auto & params_spec = slot.params.speculative;
|
||||
|
||||
@@ -3797,7 +3848,15 @@ void server_context::batch_pending_prompt(const int32_t n_ubatch, const int32_t
|
||||
// process the image
|
||||
size_t n_tokens_out = 0;
|
||||
llama_pos p1 = slot.cache_tokens.pos_next() + slot.n_past_prompt - slot.n_past; // add offset to prompt
|
||||
int32_t res = slot.prompt_tokens.process_chunk(ctx, mctx, slot.n_past_prompt, p1, slot.id, n_tokens_out);
|
||||
server_mtp_warmup mtp_media_warmup {
|
||||
ctx,
|
||||
slot.has_mtp && slot.spec ? &slot : nullptr,
|
||||
};
|
||||
mtmd_helper_eval_batch_callback mtp_media_callback =
|
||||
mtp_media_warmup.slot ? server_mtp_media_warmup_callback : nullptr;
|
||||
int32_t res = slot.prompt_tokens.process_chunk(
|
||||
ctx, mctx, slot.n_past_prompt, p1, slot.id, n_tokens_out,
|
||||
mtp_media_callback, &mtp_media_warmup);
|
||||
if (res != 0) {
|
||||
LLAMA_LOG_ERROR("failed to process image, res = %d\n", res);
|
||||
slot.release();
|
||||
@@ -4000,8 +4059,9 @@ static void restore_speculative_checkpoint(
|
||||
if (slot.use_gemma4_external_mtp) {
|
||||
cache_and_sync_slot_mtp_hidden_from_rows(slot, ctx, slot.mtp_hidden_state, n_embd);
|
||||
} else {
|
||||
llama_set_draft_input_hidden_state(get_slot_mtp_ctx(slot, ctx), slot.mtp_hidden_state.data());
|
||||
mtp_accept_tokens(get_slot_mtp_ctx(slot, ctx), ids, slot.spec_ckpt.n_past, slot.id);
|
||||
llama_context * mtp_ctx = get_slot_mtp_ctx(slot, ctx);
|
||||
llama_set_draft_input_hidden_state(mtp_ctx, slot.mtp_hidden_state.data());
|
||||
mtp_accept_tokens(mtp_ctx, ids, slot.spec_ckpt.n_past, slot.id);
|
||||
|
||||
if (n_accepted > 1) {
|
||||
memmove(slot.mtp_hidden_state.data(),
|
||||
@@ -4063,7 +4123,8 @@ void server_context::speculative_decoding_accept() {
|
||||
int32_t mtp_n_past_base = 0;
|
||||
std::vector<float> mtp_hidden_state_pre;
|
||||
if (slot.has_mtp) {
|
||||
mtp_n_past_base = slot.n_past - (slot.drafted.size() + 1);
|
||||
const int32_t n_pre_spec_tokens = slot.cache_tokens.n_tokens() - (int32_t)(slot.drafted.size() + 1);
|
||||
mtp_n_past_base = slot.cache_tokens.pos_next(n_pre_spec_tokens);
|
||||
|
||||
const int n_embd = get_ctx_mtp_n_embd(ctx);
|
||||
if (!ids.empty()) {
|
||||
@@ -4101,7 +4162,9 @@ void server_context::speculative_decoding_accept() {
|
||||
slot.cache_tokens.keep_first(slot.cache_tokens.n_tokens() - n_draft);
|
||||
|
||||
// add accepted tokens to the prompt
|
||||
slot.cache_tokens.insert({ ids.begin(), ids.end() - 1 });
|
||||
for (auto it = ids.begin(); it != ids.end() - 1; ++it) {
|
||||
slot.cache_tokens.push_back(*it);
|
||||
}
|
||||
slot.sampled = ids.back(); // last accepted token
|
||||
slot.n_past = slot.cache_tokens.n_tokens();
|
||||
|
||||
@@ -4484,42 +4547,24 @@ void server_context::process_batch_tokens(int32_t & n_batch) {
|
||||
continue; // continue loop of n_batch
|
||||
}
|
||||
|
||||
bool mtp_warmup_needed = false;
|
||||
llama_context * batch_mtp_target = nullptr;
|
||||
std::vector<float> batch_mtp_hidden_state;
|
||||
server_slot * mtp_warmup_slot = nullptr;
|
||||
if (params_base.has_mtp) {
|
||||
for (auto & slot : slots) {
|
||||
if (slot.spec && slot.has_mtp) {
|
||||
llama_context * mc = common_speculative_get_mtp_ctx(slot.spec);
|
||||
if (mc) {
|
||||
batch_mtp_target = mc;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
for (auto& slot : slots) {
|
||||
if ((slot.state == SLOT_STATE_PROCESSING && slot.n_decoded == 0) ||
|
||||
(slot.state == SLOT_STATE_IDLE && slot.command == SLOT_COMMAND_LOAD_PROMPT)) {
|
||||
bool has_tokens_for_slot = (batch_view.n_tokens > 0 && batch_view.n_seq_id[0] > 0 && batch_view.seq_id[0][0] == slot.id);
|
||||
if (has_tokens_for_slot) {
|
||||
mtp_warmup_needed = true;
|
||||
mtp_warmup_slot = &slot;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (mtp_warmup_needed) {
|
||||
llama_context * mtp_target = batch_mtp_target ? batch_mtp_target : ctx;
|
||||
const int n_embd_src = get_ctx_mtp_n_embd(ctx);
|
||||
const int n_embd_dst = get_ctx_mtp_n_embd(mtp_target);
|
||||
const int n_toks = batch_view.n_tokens;
|
||||
batch_mtp_hidden_state.assign(n_toks * n_embd_dst, 0.0f);
|
||||
for (int t = 0; t < n_toks; t++) {
|
||||
const float* emb_t = llama_get_embeddings_ith(ctx, t);
|
||||
if (emb_t) {
|
||||
const int n_copy = std::min(n_embd_src, n_embd_dst);
|
||||
memcpy(batch_mtp_hidden_state.data() + t * n_embd_dst, emb_t, n_copy * sizeof(float));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (mtp_warmup_slot && mtp_warmup_slot->spec && mtp_warmup_slot->has_mtp) {
|
||||
llama_context * mtp_ctx = get_slot_mtp_ctx(*mtp_warmup_slot, ctx);
|
||||
if (server_mtp_warmup_batch(ctx, mtp_ctx, &batch_view, *mtp_warmup_slot) != 0) {
|
||||
LOG_ERROR("%s\n", "failed to warm up MTP state from prompt batch");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4547,7 +4592,11 @@ void server_context::process_batch_tokens(int32_t & n_batch) {
|
||||
}
|
||||
|
||||
if (slot.n_decoded == 0 && slot.can_speculate()) {
|
||||
common_speculative_begin(slot.spec, slot.cache_tokens.get_text_tokens());
|
||||
static const llama_tokens empty_prompt;
|
||||
const llama_tokens & spec_prompt = slot.has_mtp
|
||||
? empty_prompt
|
||||
: slot.cache_tokens.get_text_tokens();
|
||||
common_speculative_begin(slot.spec, spec_prompt);
|
||||
}
|
||||
|
||||
if (slot.i_batch_dft.size() > 0) {
|
||||
@@ -4568,7 +4617,7 @@ void server_context::process_batch_tokens(int32_t & n_batch) {
|
||||
completion_token_output result;
|
||||
const int tok_idx = slot.i_batch - i;
|
||||
|
||||
if (params_base.has_mtp && slot.n_decoded == 0) {
|
||||
if (slot.has_mtp && slot.n_decoded == 0) {
|
||||
const float* emb_i = llama_get_embeddings_ith(ctx, tok_idx);
|
||||
if (emb_i) {
|
||||
const int n_embd = get_ctx_mtp_n_embd(ctx);
|
||||
@@ -4641,20 +4690,6 @@ void server_context::process_batch_tokens(int32_t & n_batch) {
|
||||
|
||||
slot.i_batch = -1;
|
||||
}
|
||||
if (mtp_warmup_needed && !batch_mtp_hidden_state.empty()) {
|
||||
if (params_use_gemma4_external_mtp(params_base)) {
|
||||
for (auto & slot : slots) {
|
||||
if (slot.spec && slot.has_mtp && !slot.mtp_hidden_state.empty()) {
|
||||
sync_slot_mtp_hidden(slot, ctx);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
llama_context * mtp_target = batch_mtp_target ? batch_mtp_target : ctx;
|
||||
llama_set_draft_input_hidden_state(mtp_target, batch_mtp_hidden_state.data());
|
||||
mtp_update_kv_cache(mtp_target, batch_view, true);
|
||||
}
|
||||
}
|
||||
|
||||
// speculative decoding - main model sample and accept
|
||||
speculative_decoding_accept();
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user