From 66323b92f7e46491d975f7a0866125fe109a870c Mon Sep 17 00:00:00 2001 From: firecoperana <18252262+firecoperana@users.noreply.github.com> Date: Sat, 21 Feb 2026 11:24:12 -0600 Subject: [PATCH] Qwen3.5-MoE: fix regenerating message error (#1295) Co-authored-by: firecoperana --- common/common.h | 2 + examples/server/server-context.cpp | 128 ++++++++++++++++++++++++++++- examples/server/server-context.h | 3 + examples/server/server-task.h | 1 + include/llama.h | 6 ++ src/llama-arch.cpp | 20 +++++ src/llama-arch.h | 3 + src/llama-context.h | 1 + src/llama-model.cpp | 8 ++ src/llama-model.h | 1 + src/llama.cpp | 3 +- 11 files changed, 172 insertions(+), 4 deletions(-) diff --git a/common/common.h b/common/common.h index 227eb243..8a958cdd 100644 --- a/common/common.h +++ b/common/common.h @@ -414,6 +414,8 @@ struct gpt_params { std::string sqlite_zstd_ext_file; float slot_prompt_similarity = 0.1f; + + int32_t n_ctx_checkpoints = 8; // max number of context checkpoints per slot int32_t cache_ram_mib = 8192; // -1 = no limit, 0 - disable, 1 = 1 MiB, etc. 
int32_t cache_ram_n_min = 0; // min number of tokens required to save in the ram float cache_ram_similarity = 0.5f; // similarity of tokens to cached tokens diff --git a/examples/server/server-context.cpp b/examples/server/server-context.cpp index 73eba614..104cb793 100644 --- a/examples/server/server-context.cpp +++ b/examples/server/server-context.cpp @@ -2587,6 +2587,123 @@ void server_context::add_sampled_tokens() { } } +void server_context::apply_checkpoint(server_slot & slot) { + const auto pos_min_thold = std::max(0, slot.n_past - 1); + if (!mctx && slot.n_past > 0 && slot.n_past < slot.cache_tokens.n_tokens()) { + int32_t pos_min = 0; + if (llama_model_is_hybrid(llama_get_model(slot.ctx)) || llama_model_is_recurrent(llama_get_model(slot.ctx))) { + pos_min = llama_kv_cache_seq_pos_max(slot.ctx, slot.id); + } + + if (pos_min > pos_min_thold+2) { + // TODO: support can be added in the future when corresponding vision models get released + GGML_ASSERT(!slot.cache_tokens.has_mtmd); + + SLT_WRN(slot, "n_past = %d, slot.prompt.tokens.size() = %d, seq_id = %d, pos_min = %d\n", slot.n_past, (int)slot.cache_tokens.size(), slot.id, pos_min); + + // search for a context checkpoint + const auto it = std::find_if( + slot.server_cached_prompt.checkpoints.rbegin(), + slot.server_cached_prompt.checkpoints.rend(), + [&](const auto & cur) { + // guarantee that a checkpoint will result in at least one token being processed [TAG_PROMPT_LOGITS] + return cur.pos_min < pos_min_thold; + } + ); + + bool do_reset = it == slot.server_cached_prompt.checkpoints.rend(); + + if (!do_reset) { + // restore the context checkpoint + const size_t checkpoint_size = it->data.size(); + const size_t n = llama_state_seq_set_data(ctx, it->data.data(), checkpoint_size, slot.id); + + if (n != checkpoint_size) { + SLT_ERR(slot, "failed to restore context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", it->pos_min, it->pos_max, (float)checkpoint_size / 1024 / 1024); + do_reset = true; + 
//printf("[DEBUG] `do_reset` was set to `true` after failing to restore a checkpoint"); + } else { + slot.n_past = std::min(slot.n_past, std::max(it->pos_min + 1, it->pos_max)); + SLT_WRN(slot, "restored context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", it->pos_min, it->pos_max, (float)checkpoint_size / 1024 / 1024); + } + } + + if (do_reset) { + SLT_WRN(slot, "forcing full prompt re-processing due to lack of cache data (likely due to SWA or hybrid/recurrent memory, see %s)\n", + "https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055"); + slot.n_past = 0; + slot.n_past_prompt = 0; + } + } + } + + { + // erase any checkpoints with pos_min > pos_min_thold + for (auto it = slot.server_cached_prompt.checkpoints.begin(); it != slot.server_cached_prompt.checkpoints.end();) { + const auto & cur = *it; + if (cur.pos_min > pos_min_thold) { + SLT_WRN(slot, "erased invalidated context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", cur.pos_min, cur.pos_max, (float)cur.data.size() / 1024 / 1024); + it = slot.server_cached_prompt.checkpoints.erase(it); + } else { + ++it; + } + } + } +} + +void server_context::create_checkpoint(server_slot & slot) { + //bool do_checkpoint = params_base.n_ctx_checkpoints > 0; + + //// make checkpoints only for completion tasks + //do_checkpoint = do_checkpoint && slot.task->type == SERVER_TASK_TYPE_COMPLETION; + + //// make a checkpoint of the parts of the memory that cannot be rolled back. 
+ //// checkpoints are created only if: + //// - the model architecture is marked as recurrent or hybrid + //// + //// TODO: try to make this conditional on the context or the memory module, instead of the model type + //do_checkpoint = do_checkpoint && ( + // llama_model_is_recurrent(model) || + // llama_model_is_hybrid(model) + // ); + //int32_t pos_min = 0; + //if (llama_model_is_recurrent(model) || llama_model_is_hybrid(model)) { + // pos_min = llama_kv_cache_seq_pos_max(slot.ctx, slot.id); + //} + //const auto pos_max = llama_kv_cache_seq_pos_max(slot.ctx, slot.id); + + //// no need for empty or small checkpoints + //do_checkpoint = do_checkpoint && (pos_min >= 0 && pos_max >= 5); + + //// no need to create checkpoints that are too close together + //do_checkpoint = do_checkpoint && (slot.server_cached_prompt.checkpoints.empty() || pos_max > slot.server_cached_prompt.checkpoints.back().pos_max + 64); + + //if (do_checkpoint) { + // while (slot.server_cached_prompt.checkpoints.size() >= (size_t)params_base.n_ctx_checkpoints) { + // // make room for the new checkpoint, if needed + // const auto & cur = slot.server_cached_prompt.checkpoints.front(); + + // SLT_WRN(slot, "erasing old context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", + // cur.pos_min, cur.pos_max, (float)cur.data.size() / 1024 / 1024); + + // slot.server_cached_prompt.checkpoints.erase(slot.server_cached_prompt.checkpoints.begin()); + // } + + // const size_t checkpoint_size = llama_state_seq_get_size(ctx, slot.id); + + // auto & cur = slot.server_cached_prompt.checkpoints.emplace_back(server_prompt_checkpoint{ + // /*.pos_min = */ pos_min, + // /*.pos_max = */ pos_max, + // /*.data = */ std::vector(checkpoint_size), + // }); + + // llama_state_seq_get_data(ctx, cur.data.data(), checkpoint_size, slot.id); + + // SLT_WRN(slot, "created context checkpoint %d of %d (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", + // (int)slot.server_cached_prompt.checkpoints.size(), 
params_base.n_ctx_checkpoints, cur.pos_min, cur.pos_max, (float)cur.data.size() / 1024 / 1024); + //} +} + void server_context::batch_pending_prompt(const int32_t n_ubatch, const int32_t n_batch, int32_t & batch_type) { if (params_base.cont_batching || batch.n_tokens == 0) { for (auto& slot : slots) { @@ -2760,7 +2877,7 @@ void server_context::batch_pending_prompt(const int32_t n_ubatch, const int32_t } } } - + apply_checkpoint(slot); if (slot.n_past_prompt == slot.n_prompt_tokens && slot.n_past_prompt > 0) { // we have to evaluate at least 1 token to generate logits. LOG_INFO("we have to evaluate at least 1 token to generate logits", { @@ -2916,6 +3033,8 @@ void server_context::batch_pending_prompt(const int32_t n_ubatch, const int32_t slot.n_decoded = 0; slot.i_batch = batch.n_tokens - 1; + //create_checkpoint(slot); + LOG_VERBOSE("prompt done", { {"id_slot", slot.id}, {"n_past", slot.n_past}, @@ -3008,10 +3127,11 @@ void server_context::speculative_decoding_accept() { populate_token_probs(slot, result, slot.params.post_sampling_probs, params_base.special, i); } - if (slot.n_buffer == 0) { + if (slot.n_buffer == 0 || llama_model_is_hybrid(llama_get_model(slot.ctx)) || llama_model_is_recurrent(llama_get_model(slot.ctx))) { if (!process_token(result, slot)) { // release slot because of stop condition send_final_response(slot); + //create_checkpoint(slot); slot.release(); slot.print_timings(); metrics.on_prediction(slot); @@ -3046,6 +3166,7 @@ void server_context::send_token_results(completion_token_outputs& results, serve count++; if (!has_next) { send_final_response(slot); + //create_checkpoint(slot); slot.release(); slot.print_timings(); metrics.on_prediction(slot); @@ -3259,7 +3380,8 @@ void server_context::process_batch_tokens(int32_t & n_batch) { populate_token_probs(slot, result, slot.params.post_sampling_probs, params_base.special, tok_idx); } - if (slot.n_buffer == 0) { + // no ban string for recurrent/hybrid model + if (slot.n_buffer == 0 || 
llama_model_is_hybrid(llama_get_model(slot.ctx)) || llama_model_is_recurrent(llama_get_model(slot.ctx))) { slot.token_buffer = { result }; send_token_results(slot.token_buffer, slot); } else { diff --git a/examples/server/server-context.h b/examples/server/server-context.h index 2153b44b..1c16cc35 100644 --- a/examples/server/server-context.h +++ b/examples/server/server-context.h @@ -349,4 +349,7 @@ struct server_context { // Re-aggregates all active vectors and updates the model state bool apply_control_vectors_internal(); + void create_checkpoint(server_slot & slot); + + void apply_checkpoint(server_slot & slot); }; diff --git a/examples/server/server-task.h b/examples/server/server-task.h index b62f5a7d..cbb2ce16 100644 --- a/examples/server/server-task.h +++ b/examples/server/server-task.h @@ -368,6 +368,7 @@ struct server_prompt { int n_tokens() const { return tokens.size(); } + }; struct server_prompt_cache { diff --git a/include/llama.h b/include/llama.h index b8cbde9b..5332d810 100644 --- a/include/llama.h +++ b/include/llama.h @@ -627,6 +627,12 @@ extern "C" { // to the decoder to start generating output sequence. For other models, it returns -1. LLAMA_API llama_token llama_model_decoder_start_token(const struct llama_model * model); + // Returns true if the model is recurrent (like Mamba, RWKV, etc.) + LLAMA_API bool llama_model_is_recurrent(const struct llama_model * model); + + // Returns true if the model is hybrid (like Jamba, Granite, etc.) 
+ LLAMA_API bool llama_model_is_hybrid(const struct llama_model * model); + // Returns 0 on success LLAMA_API uint32_t llama_model_quantize( const char * fname_inp, diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index fd748a57..86332a66 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -246,3 +246,23 @@ const char * llama_model_arch_name(llm_arch arch) { } return it->second; } + +bool llm_arch_is_recurrent(const llm_arch & arch) { + switch (arch) { + case LLM_ARCH_MAMBA: + return true; + default: + return false; + } +} + +bool llm_arch_is_hybrid(const llm_arch & arch) { + switch (arch) { + case LLM_ARCH_QWEN3NEXT: + case LLM_ARCH_QWEN35MOE: + return true; + default: + return false; + } +} + diff --git a/src/llama-arch.h b/src/llama-arch.h index e447261a..fbfba4ac 100644 --- a/src/llama-arch.h +++ b/src/llama-arch.h @@ -342,3 +342,6 @@ enum llm_tensor { llm_arch llm_arch_from_string(const std::string & name); const char * llama_model_arch_name(llm_arch arch); + +bool llm_arch_is_recurrent(const llm_arch & arch); +bool llm_arch_is_hybrid(const llm_arch & arch); diff --git a/src/llama-context.h b/src/llama-context.h index b2b755ee..9f4255fd 100644 --- a/src/llama-context.h +++ b/src/llama-context.h @@ -37,6 +37,7 @@ struct llama_kv_cache { bool do_defrag = false; bool do_copy = false; bool recurrent = false; // with recurrent state models, a cell can hold the state for more than one past token + bool hybrid = false; bool v_trans = true; // the value tensor is transposed // Note: The value of head isn't only used to optimize searching diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 7436684f..4cda5330 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -1731,3 +1731,11 @@ const char * llama_model_type_name(e_model type) { default: return "?B"; } } + +bool llama_model_is_recurrent(const llama_model * model) { + return llm_arch_is_recurrent(model->arch); +} + +bool llama_model_is_hybrid(const llama_model * model) { + return
llm_arch_is_hybrid(model->arch); +} diff --git a/src/llama-model.h b/src/llama-model.h index 53600194..2b415c54 100644 --- a/src/llama-model.h +++ b/src/llama-model.h @@ -507,3 +507,4 @@ struct LLM_TN { std::string llama_model_ftype_name(llama_ftype ftype); const char * llama_model_type_name(e_model type); + diff --git a/src/llama.cpp b/src/llama.cpp index 28ffdc4b..3d22de1e 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -721,7 +721,8 @@ static bool llama_kv_cache_init( cache.has_shift = false; // TODO: find a nicer way to add other recurrent model architectures - cache.recurrent = model.arch == LLM_ARCH_MAMBA; + cache.recurrent = llm_arch_is_recurrent(model.arch); + cache.hybrid = llm_arch_is_hybrid(model.arch); // qwen3next uses hybrid recurrent+attention cache semantics. Keep V rows in // standard layout to match the mainline hybrid path when flash attention is off. cache.v_trans = !cache.recurrent && !cparams.flash_attn && model.arch != LLM_ARCH_QWEN3NEXT && model.arch != LLM_ARCH_QWEN35MOE;