diff --git a/common/speculative.cpp b/common/speculative.cpp index 531ce010..8ca0b928 100644 --- a/common/speculative.cpp +++ b/common/speculative.cpp @@ -198,12 +198,14 @@ struct common_speculative_state_mtp : public common_speculative_state { llama_token id_last, llama_tokens & result) override { - int32_t n_past = (int32_t)prompt_tgt.size(); llama_seq_id seq_id = 0; llama_pos mtp_pos_max = llama_kv_cache_seq_pos_max(ctx_mtp, seq_id); - if (mtp_pos_max >= n_past) { - llama_kv_cache_seq_rm(ctx_mtp, seq_id, n_past, -1); + int32_t n_past = mtp_pos_max >= 0 ? (int32_t)mtp_pos_max + 1 : (int32_t)prompt_tgt.size(); + + if (!prompt_tgt.empty() && mtp_pos_max < (llama_pos)prompt_tgt.size() - 1) { + LOG_WRN("%s: MTP context not fully warmed up: pos_max = %d, expected = %d\n", + __func__, (int)mtp_pos_max, (int)prompt_tgt.size() - 1); } llama_context * ctx = ctx_mtp; @@ -1478,9 +1480,9 @@ std::vector mtp_speculative_gen_draft( } -void mtp_update_kv_cache(struct llama_context * ctx, const llama_batch& batch, bool is_prompt_warmup) { +int32_t mtp_update_kv_cache(struct llama_context * ctx, const llama_batch& batch, bool is_prompt_warmup) { if (batch.n_tokens == 0) { - return; + return 0; } llama_seq_id seq_id = batch.seq_id[0][0]; @@ -1509,8 +1511,9 @@ void mtp_update_kv_cache(struct llama_context * ctx, const llama_batch& batch, b } } - llama_decode(ctx, mtp_batch); + const int32_t ret = llama_decode(ctx, mtp_batch); llama_set_mtp_op_type(ctx, MTP_OP_NONE); + return ret; } void mtp_accept_tokens( @@ -1527,7 +1530,11 @@ void mtp_accept_tokens( common_batch_add(accepted_batch, ids[i], n_past_base + i, { seq_id }, true); } - mtp_update_kv_cache(ctx, accepted_batch, false); + if (mtp_update_kv_cache(ctx, accepted_batch, false) != 0) { + LOG_ERR("failed to update MTP KV cache for accepted tokens\n"); + llama_batch_free(accepted_batch); + return; + } auto & last = mtp_get_last_embd(ctx); auto embd = llama_get_embeddings_ith(ctx, ids.size() - 1); diff --git a/common/speculative.h b/common/speculative.h index 6061d133..58c4142e 100644 --- a/common/speculative.h +++ b/common/speculative.h @@ -63,7 +63,7 @@ std::vector mtp_speculative_gen_draft( llama_seq_id seq_id, bool constant_draft_positions = false); -void mtp_update_kv_cache(struct llama_context * ctx, const llama_batch& batch, bool is_prompt_warmup); +int32_t mtp_update_kv_cache(struct llama_context * ctx, const llama_batch& batch, bool is_prompt_warmup); void mtp_accept_tokens( struct llama_context * ctx, diff --git a/examples/mtmd/mtmd-helper.cpp b/examples/mtmd/mtmd-helper.cpp index b3b5225f..7a5d0ee5 100644 --- a/examples/mtmd/mtmd-helper.cpp +++ b/examples/mtmd/mtmd-helper.cpp @@ -164,8 +164,7 @@ struct decode_embd_batch { } }; -// Helper function for decoding an image whose embeddings have already been calculated -int32_t mtmd_helper_decode_image_chunk( +static int32_t mtmd_helper_decode_image_chunk_impl( mtmd_context * ctx, struct llama_context * lctx, const mtmd_input_chunk * chunk, @@ -173,7 +172,9 @@ int32_t mtmd_helper_decode_image_chunk( llama_pos n_past, llama_seq_id seq_id, int32_t n_batch, - llama_pos * new_n_past) { + llama_pos * new_n_past, + mtmd_helper_eval_batch_callback callback, + void * callback_user_data) { auto chunk_type = mtmd_input_chunk_get_type(chunk); const char * name = chunk_type == MTMD_INPUT_CHUNK_TYPE_IMAGE ? "image" : "audio"; if (chunk_type == MTMD_INPUT_CHUNK_TYPE_TEXT) { @@ -231,6 +232,15 @@ int32_t mtmd_helper_decode_image_chunk( LOG_INF("%s decoded (batch %d/%d) in %" PRId64 " ms\n", name, i_batch+1, n_img_batches, ggml_time_ms() - t1); + if (callback) { + int32_t callback_ret = callback(callback_user_data, &batch_embd_view); + if (callback_ret != 0) { + LOG_ERR("failed to process %s decode callback\n", name); + llama_set_causal_attn(lctx, true); // restore causal attn + return callback_ret; + } + } + i_batch++; } @@ -243,6 +253,20 @@ int32_t mtmd_helper_decode_image_chunk( return 0; } +// Helper function for decoding an image whose embeddings have already been calculated +int32_t mtmd_helper_decode_image_chunk( + mtmd_context * ctx, + struct llama_context * lctx, + const mtmd_input_chunk * chunk, + float * encoded_embd, + llama_pos n_past, + llama_seq_id seq_id, + int32_t n_batch, + llama_pos * new_n_past) { + return mtmd_helper_decode_image_chunk_impl( + ctx, lctx, chunk, encoded_embd, n_past, seq_id, n_batch, new_n_past, nullptr, nullptr); +} + int32_t mtmd_helper_eval_chunk_single(mtmd_context * ctx, struct llama_context * lctx, const mtmd_input_chunk * chunk, @@ -251,6 +275,20 @@ int32_t mtmd_helper_eval_chunk_single(mtmd_context * ctx, int32_t n_batch, bool logits_last, llama_pos * new_n_past) { + return mtmd_helper_eval_chunk_single_with_callback( + ctx, lctx, chunk, n_past, seq_id, n_batch, logits_last, new_n_past, nullptr, nullptr); +} + +int32_t mtmd_helper_eval_chunk_single_with_callback(mtmd_context * ctx, + struct llama_context * lctx, + const mtmd_input_chunk * chunk, + llama_pos n_past, + llama_seq_id seq_id, + int32_t n_batch, + bool logits_last, + llama_pos * new_n_past, + mtmd_helper_eval_batch_callback callback, + void * callback_user_data) { int32_t ret; llama_batch text_batch = llama_batch_init(n_batch, 0, 1); auto chunk_type = mtmd_input_chunk_get_type(chunk); @@ -282,6 +320,14 @@ int32_t mtmd_helper_eval_chunk_single(mtmd_context * ctx, llama_batch_free(text_batch); return ret; } + if (callback) { + int32_t callback_ret = callback(callback_user_data, &text_batch); + if (callback_ret != 0) { + LOG_ERR("failed to process text decode callback\n"); + llama_batch_free(text_batch); + return callback_ret; + } + } *new_n_past += text_batch.n_tokens; } @@ -301,7 +347,8 @@ int32_t mtmd_helper_eval_chunk_single(mtmd_context * ctx, LOG_INF("%s slice encoded in %" PRId64 " ms\n", name, ggml_time_ms() - t0); float * embd = mtmd_get_output_embd(ctx); - ret = mtmd_helper_decode_image_chunk(ctx, lctx, chunk, embd, n_past, seq_id, n_batch, new_n_past); + ret = mtmd_helper_decode_image_chunk_impl( + ctx, lctx, chunk, embd, n_past, seq_id, n_batch, new_n_past, callback, callback_user_data); if (ret != 0) { LOG_ERR("failed to decode %s\n", name); llama_batch_free(text_batch); diff --git a/examples/mtmd/mtmd-helper.h b/examples/mtmd/mtmd-helper.h index 5c0edc69..748e0c8f 100644 --- a/examples/mtmd/mtmd-helper.h +++ b/examples/mtmd/mtmd-helper.h @@ -20,6 +20,8 @@ extern "C" { // BREAKING CHANGES are expected. // +typedef int32_t (*mtmd_helper_eval_batch_callback)(void * user_data, const struct llama_batch * batch); + // helper function to construct a mtmd_bitmap from a file // it calls mtmd_helper_bitmap_init_from_buf() internally // returns nullptr on failure @@ -68,6 +70,19 @@ MTMD_API int32_t mtmd_helper_eval_chunk_single(mtmd_context * ctx, bool logits_last, llama_pos * new_n_past); +// works like mtmd_helper_eval_chunk_single(), and calls callback after each successful llama_decode() batch +// the batch pointer is only valid for the duration of the callback +MTMD_API int32_t mtmd_helper_eval_chunk_single_with_callback(mtmd_context * ctx, + struct llama_context * lctx, + const mtmd_input_chunk * chunk, + llama_pos n_past, + llama_seq_id seq_id, + int32_t n_batch, + bool logits_last, + llama_pos * new_n_past, + mtmd_helper_eval_batch_callback callback, + void * callback_user_data); + // helper function to decode an image whose embeddings have already been calculated // this helper will handle batching and pre/post decoding setup (for ex. gemma 3 requires non-causal attention) // ret 0 on success, -1 on chunk not being a valid image chunk, 1 on decode failure diff --git a/examples/server/server-common.cpp b/examples/server/server-common.cpp index d269e6dc..0aa26e96 100644 --- a/examples/server/server-common.cpp +++ b/examples/server/server-common.cpp @@ -2144,7 +2144,9 @@ int32_t server_tokens::process_chunk( size_t idx, llama_pos pos, int32_t seq_id, - size_t& n_tokens_out) const { + size_t& n_tokens_out, + mtmd_helper_eval_batch_callback callback, + void * callback_user_data) const { const auto& chunk = find_chunk(idx); const char* name = mtmd_input_chunk_get_type(chunk.get()) == MTMD_INPUT_CHUNK_TYPE_IMAGE ? "image" : "audio"; @@ -2152,13 +2154,15 @@ int32_t server_tokens::process_chunk( int32_t n_batch = llama_n_batch(ctx); int64_t t0 = ggml_time_ms(); llama_pos new_n_past; // unused for now - int32_t result = mtmd_helper_eval_chunk_single(mctx, ctx, + int32_t result = mtmd_helper_eval_chunk_single_with_callback(mctx, ctx, chunk.get(), pos, seq_id, n_batch, true, // logits last - &new_n_past); + &new_n_past, + callback, + callback_user_data); LLAMA_LOG_INFO("%s processed in %" PRId64 " ms\n", name, ggml_time_ms() - t0); if (result != 0) { LLAMA_LOG_ERROR("mtmd_helper_eval failed with status %d", result); diff --git a/examples/server/server-common.h b/examples/server/server-common.h index 211c2c28..e150ef28 100644 --- a/examples/server/server-common.h +++ b/examples/server/server-common.h @@ -440,7 +440,9 @@ public: size_t idx, llama_pos pos, int32_t seq_id, - size_t& n_tokens_out) const; + size_t& n_tokens_out, + mtmd_helper_eval_batch_callback callback = nullptr, + void * callback_user_data = nullptr) const; server_tokens clone() const; diff --git a/examples/server/server-context.cpp b/examples/server/server-context.cpp index dd0f0f4d..8f3e01e0 100644 --- a/examples/server/server-context.cpp +++ b/examples/server/server-context.cpp @@ -115,14 +115,15 @@ static void apply_slot_mtp_accept( return; } + llama_context * mtp_ctx = get_slot_mtp_ctx(slot, ctx); if (slot.use_gemma4_external_mtp) { cache_and_sync_slot_mtp_hidden_from_rows(slot, ctx, mtp_hidden_state, n_embd); return; } slot.mtp_hidden_state = mtp_hidden_state; - llama_set_draft_input_hidden_state(get_slot_mtp_ctx(slot, ctx), slot.mtp_hidden_state.data()); - mtp_accept_tokens(get_slot_mtp_ctx(slot, ctx), ids, mtp_n_past_base, slot.id); + llama_set_draft_input_hidden_state(mtp_ctx, slot.mtp_hidden_state.data()); + mtp_accept_tokens(mtp_ctx, ids, mtp_n_past_base, slot.id); } static void set_external_mtp_hidden(server_slot & slot, llama_context * ctx, const float * hidden, int n_embd) { @@ -133,8 +134,53 @@ static void set_external_mtp_hidden(server_slot & slot, llama_context * ctx, con cache_and_sync_slot_mtp_hidden(slot, ctx, hidden, n_embd); } -static void set_external_mtp_hidden_from_rows(server_slot & slot, llama_context * ctx, const std::vector & rows, int n_embd) { - cache_and_sync_slot_mtp_hidden_from_rows(slot, ctx, rows, n_embd); +struct server_mtp_warmup { + llama_context * ctx_tgt; + server_slot * slot; +}; + +static int32_t server_mtp_warmup_batch( + llama_context * ctx_tgt, + llama_context * ctx_mtp, + const llama_batch * batch, + server_slot & slot) { + if (!ctx_tgt || !ctx_mtp || !batch || batch->n_tokens <= 0) { + return 0; + } + + const float * emb = llama_get_embeddings(ctx_tgt); + const int n_embd_src = get_ctx_mtp_n_embd(ctx_tgt); + const int n_embd_dst = get_ctx_mtp_n_embd(ctx_mtp); + if (emb == nullptr || n_embd_src <= 0 || n_embd_dst <= 0) { + return -1; + } + + if (n_embd_src != n_embd_dst) { + LOG_ERROR("MTP warmup hidden state width mismatch", { + {"n_embd_src", n_embd_src}, + {"n_embd_dst", n_embd_dst}, + }); + return -1; + } + + const float * last_hidden = emb + (batch->n_tokens - 1) * n_embd_src; + if (slot.use_gemma4_external_mtp) { + cache_and_sync_slot_mtp_hidden(slot, ctx_tgt, last_hidden, n_embd_dst); + return 0; + } + + cache_slot_mtp_hidden(slot, last_hidden, n_embd_dst); + llama_set_draft_input_hidden_state(ctx_mtp, emb); + return mtp_update_kv_cache(ctx_mtp, *batch, true); +} + +static int32_t server_mtp_media_warmup_callback(void * user_data, const llama_batch * batch) { + auto * data = static_cast(user_data); + if (data == nullptr || data->slot == nullptr) { + return 0; + } + + return server_mtp_warmup_batch(data->ctx_tgt, get_slot_mtp_ctx(*data->slot, data->ctx_tgt), batch, *data->slot); } void server_speculative_checkpoint::clear() { @@ -156,7 +202,8 @@ static void discard_speculative_checkpoint(server_slot & slot, llama_context * c static bool save_speculative_checkpoint(server_slot & slot, llama_model * model, llama_context * ctx, int ckpt_mode) { slot.spec_ckpt.clear(); - slot.spec_ckpt.n_past = slot.n_past - (int32_t)(slot.drafted.size() + 1); + const int32_t n_pre_spec_tokens = slot.cache_tokens.n_tokens() - (int32_t)(slot.drafted.size() + 1); + slot.spec_ckpt.n_past = slot.cache_tokens.pos_next(n_pre_spec_tokens); slot.spec_ckpt.sampled = slot.sampled; const int max_tokens = (int)slot.drafted.size() + 1; @@ -266,7 +313,8 @@ bool server_context::load_model(const gpt_params& params_) { LOG_ERROR("%s\n", "err: speculative decode is not supported by multimodal"); return false; } - if (params_base.speculative.type != COMMON_SPECULATIVE_TYPE_NONE) { + if (params_base.speculative.type != COMMON_SPECULATIVE_TYPE_NONE && + params_base.speculative.type != COMMON_SPECULATIVE_TYPE_MTP) { params_base.speculative.type = COMMON_SPECULATIVE_TYPE_NONE; SRV_WRN("%s\n", "speculative decoding is not supported by multimodal, it will be disabled"); } @@ -417,7 +465,7 @@ void server_context::init() { if (can_spec) { slot.spec = common_speculative_init(params_base.speculative, slot.ctx); if (slot.spec) { - if (mctx) { + if (mctx && !slot.has_mtp) { SRV_ERR("%s\n", "speculative decoding is not supported with multimodal"); return; } @@ -3366,12 +3414,15 @@ void server_context::add_sampled_tokens() { // perform the speculative drafting for all sequences at the same time in a single batch const int n_draft_max_pre = slot.get_n_draft_max(); if (n_draft_max_pre > 0) { - if (mctx) { + if (mctx && !slot.has_mtp) { // we should never reach this, as speculative is automatically disabled if mmproj is loaded GGML_ABORT("not supported by multimodal"); } - const llama_tokens & cached_text_tokens = slot.cache_tokens.get_text_tokens(); + static const llama_tokens empty_prompt; + const llama_tokens & cached_text_tokens = slot.has_mtp + ? empty_prompt + : slot.cache_tokens.get_text_tokens(); auto & params_spec = slot.params.speculative; @@ -3797,7 +3848,15 @@ void server_context::batch_pending_prompt(const int32_t n_ubatch, const int32_t // process the image size_t n_tokens_out = 0; llama_pos p1 = slot.cache_tokens.pos_next() + slot.n_past_prompt - slot.n_past; // add offset to prompt - int32_t res = slot.prompt_tokens.process_chunk(ctx, mctx, slot.n_past_prompt, p1, slot.id, n_tokens_out); + server_mtp_warmup mtp_media_warmup { + ctx, + slot.has_mtp && slot.spec ? &slot : nullptr, + }; + mtmd_helper_eval_batch_callback mtp_media_callback = + mtp_media_warmup.slot ? server_mtp_media_warmup_callback : nullptr; + int32_t res = slot.prompt_tokens.process_chunk( + ctx, mctx, slot.n_past_prompt, p1, slot.id, n_tokens_out, + mtp_media_callback, &mtp_media_warmup); if (res != 0) { LLAMA_LOG_ERROR("failed to process image, res = %d\n", res); slot.release(); @@ -4000,8 +4059,9 @@ static void restore_speculative_checkpoint( if (slot.use_gemma4_external_mtp) { cache_and_sync_slot_mtp_hidden_from_rows(slot, ctx, slot.mtp_hidden_state, n_embd); } else { - llama_set_draft_input_hidden_state(get_slot_mtp_ctx(slot, ctx), slot.mtp_hidden_state.data()); - mtp_accept_tokens(get_slot_mtp_ctx(slot, ctx), ids, slot.spec_ckpt.n_past, slot.id); + llama_context * mtp_ctx = get_slot_mtp_ctx(slot, ctx); + llama_set_draft_input_hidden_state(mtp_ctx, slot.mtp_hidden_state.data()); + mtp_accept_tokens(mtp_ctx, ids, slot.spec_ckpt.n_past, slot.id); if (n_accepted > 1) { memmove(slot.mtp_hidden_state.data(), @@ -4063,7 +4123,8 @@ void server_context::speculative_decoding_accept() { int32_t mtp_n_past_base = 0; std::vector mtp_hidden_state_pre; if (slot.has_mtp) { - mtp_n_past_base = slot.n_past - (slot.drafted.size() + 1); + const int32_t n_pre_spec_tokens = slot.cache_tokens.n_tokens() - (int32_t)(slot.drafted.size() + 1); + mtp_n_past_base = slot.cache_tokens.pos_next(n_pre_spec_tokens); const int n_embd = get_ctx_mtp_n_embd(ctx); if (!ids.empty()) { @@ -4101,7 +4162,9 @@ void server_context::speculative_decoding_accept() { slot.cache_tokens.keep_first(slot.cache_tokens.n_tokens() - n_draft); // add accepted tokens to the prompt - slot.cache_tokens.insert({ ids.begin(), ids.end() - 1 }); + for (auto it = ids.begin(); it != ids.end() - 1; ++it) { + slot.cache_tokens.push_back(*it); + } slot.sampled = ids.back(); // last accepted token slot.n_past = slot.cache_tokens.n_tokens(); @@ -4484,42 +4547,24 @@ void server_context::process_batch_tokens(int32_t & n_batch) { continue; // continue loop of n_batch } - bool mtp_warmup_needed = false; - llama_context * batch_mtp_target = nullptr; - std::vector batch_mtp_hidden_state; + server_slot * mtp_warmup_slot = nullptr; if (params_base.has_mtp) { - for (auto & slot : slots) { - if (slot.spec && slot.has_mtp) { - llama_context * mc = common_speculative_get_mtp_ctx(slot.spec); - if (mc) { - batch_mtp_target = mc; - break; - } - } - } for (auto& slot : slots) { if ((slot.state == SLOT_STATE_PROCESSING && slot.n_decoded == 0) || (slot.state == SLOT_STATE_IDLE && slot.command == SLOT_COMMAND_LOAD_PROMPT)) { bool has_tokens_for_slot = (batch_view.n_tokens > 0 && batch_view.n_seq_id[0] > 0 && batch_view.seq_id[0][0] == slot.id); if (has_tokens_for_slot) { - mtp_warmup_needed = true; + mtp_warmup_slot = &slot; break; } } } - if (mtp_warmup_needed) { - llama_context * mtp_target = batch_mtp_target ? batch_mtp_target : ctx; - const int n_embd_src = get_ctx_mtp_n_embd(ctx); - const int n_embd_dst = get_ctx_mtp_n_embd(mtp_target); - const int n_toks = batch_view.n_tokens; - batch_mtp_hidden_state.assign(n_toks * n_embd_dst, 0.0f); - for (int t = 0; t < n_toks; t++) { - const float* emb_t = llama_get_embeddings_ith(ctx, t); - if (emb_t) { - const int n_copy = std::min(n_embd_src, n_embd_dst); - memcpy(batch_mtp_hidden_state.data() + t * n_embd_dst, emb_t, n_copy * sizeof(float)); - } - } + } + + if (mtp_warmup_slot && mtp_warmup_slot->spec && mtp_warmup_slot->has_mtp) { + llama_context * mtp_ctx = get_slot_mtp_ctx(*mtp_warmup_slot, ctx); + if (server_mtp_warmup_batch(ctx, mtp_ctx, &batch_view, *mtp_warmup_slot) != 0) { + LOG_ERROR("%s\n", "failed to warm up MTP state from prompt batch"); } } @@ -4547,7 +4592,11 @@ void server_context::process_batch_tokens(int32_t & n_batch) { } if (slot.n_decoded == 0 && slot.can_speculate()) { - common_speculative_begin(slot.spec, slot.cache_tokens.get_text_tokens()); + static const llama_tokens empty_prompt; + const llama_tokens & spec_prompt = slot.has_mtp + ? empty_prompt + : slot.cache_tokens.get_text_tokens(); + common_speculative_begin(slot.spec, spec_prompt); } if (slot.i_batch_dft.size() > 0) { @@ -4568,7 +4617,7 @@ void server_context::process_batch_tokens(int32_t & n_batch) { completion_token_output result; const int tok_idx = slot.i_batch - i; - if (params_base.has_mtp && slot.n_decoded == 0) { + if (slot.has_mtp && slot.n_decoded == 0) { const float* emb_i = llama_get_embeddings_ith(ctx, tok_idx); if (emb_i) { const int n_embd = get_ctx_mtp_n_embd(ctx); @@ -4641,20 +4690,6 @@ void server_context::process_batch_tokens(int32_t & n_batch) { slot.i_batch = -1; } - if (mtp_warmup_needed && !batch_mtp_hidden_state.empty()) { - if (params_use_gemma4_external_mtp(params_base)) { - for (auto & slot : slots) { - if (slot.spec && slot.has_mtp && !slot.mtp_hidden_state.empty()) { - sync_slot_mtp_hidden(slot, ctx); - } - } - } else { - llama_context * mtp_target = batch_mtp_target ? batch_mtp_target : ctx; - llama_set_draft_input_hidden_state(mtp_target, batch_mtp_hidden_state.data()); - mtp_update_kv_cache(mtp_target, batch_view, true); - } - } - // speculative decoding - main model sample and accept speculative_decoding_accept(); } diff --git a/src/llama-build-context.cpp b/src/llama-build-context.cpp index 0ab5b689..fca8c4b2 100644 --- a/src/llama-build-context.cpp +++ b/src/llama-build-context.cpp @@ -315,7 +315,9 @@ struct ggml_tensor * llm_build_context::build_inp_embd_mtp(struct ggml_tensor * cur = ggml_get_rows(ctx0, mtp_tok_embd, lctx.inp_tokens); } else { - return nullptr; + lctx.inp_embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, hparams.n_embd, batch.n_tokens); + ggml_set_input(lctx.inp_embd); + cur = lctx.inp_embd; } cb(cur, "inp_embd", -1);