Qwen3.5-MoE: fix regenerating message error (#1295)
Co-authored-by: firecoperana <firecoperana>
@@ -414,6 +414,8 @@ struct gpt_params {
     std::string sqlite_zstd_ext_file;

     float slot_prompt_similarity = 0.1f;
+
+    int32_t n_ctx_checkpoints = 8; // max number of context checkpoints per slot
     int32_t cache_ram_mib = 8192;  // -1 = no limit, 0 - disable, 1 = 1 MiB, etc.
     int32_t cache_ram_n_min = 0;   // min number of tokens required to save in the ram
     float cache_ram_similarity = 0.5f; // similarity of tokens to cached tokens
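For illustration only, a minimal sketch (not part of the commit) of how the new per-slot checkpoint budget sits next to the existing RAM-cache knobs when filling in gpt_params; the field names come from the hunk above, the values are arbitrary.

    gpt_params params;
    params.n_ctx_checkpoints    = 4;     // keep at most 4 context checkpoints per slot
    params.cache_ram_mib        = 2048;  // cap the RAM prompt cache at 2 GiB
    params.cache_ram_similarity = 0.5f;  // reuse cached tokens only above this similarity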
@@ -2587,6 +2587,123 @@ void server_context::add_sampled_tokens() {
     }
 }

+void server_context::apply_checkpoint(server_slot & slot) {
+    const auto pos_min_thold = std::max(0, slot.n_past - 1);
+    if (!mctx && slot.n_past > 0 && slot.n_past < slot.cache_tokens.n_tokens()) {
+        int32_t pos_min = 0;
+        if (llama_model_is_hybrid(llama_get_model(slot.ctx)) || llama_model_is_recurrent(llama_get_model(slot.ctx))) {
+            pos_min = llama_kv_cache_seq_pos_max(slot.ctx, slot.id);
+        }
+
+        if (pos_min > pos_min_thold+2) {
+            // TODO: support can be added in the future when corresponding vision models get released
+            GGML_ASSERT(!slot.cache_tokens.has_mtmd);
+
+            SLT_WRN(slot, "n_past = %d, slot.prompt.tokens.size() = %d, seq_id = %d, pos_min = %d\n", slot.n_past, (int)slot.cache_tokens.size(), slot.id, pos_min);
+
+            // search for a context checkpoint
+            const auto it = std::find_if(
+                slot.server_cached_prompt.checkpoints.rbegin(),
+                slot.server_cached_prompt.checkpoints.rend(),
+                [&](const auto & cur) {
+                    // guarantee that a checkpoint will result in at least one token being processed [TAG_PROMPT_LOGITS]
+                    return cur.pos_min < pos_min_thold;
+                }
+            );
+
+            bool do_reset = it == slot.server_cached_prompt.checkpoints.rend();
+
+            if (!do_reset) {
+                // restore the context checkpoint
+                const size_t checkpoint_size = it->data.size();
+                const size_t n = llama_state_seq_set_data(ctx, it->data.data(), checkpoint_size, slot.id);
+
+                if (n != checkpoint_size) {
+                    SLT_ERR(slot, "failed to restore context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", it->pos_min, it->pos_max, (float)checkpoint_size / 1024 / 1024);
+                    do_reset = true;
+                    //printf("[DEBUG] `do_reset` was set to `true` after failing to restore a checkpoint");
+                } else {
+                    slot.n_past = std::min(slot.n_past, std::max(it->pos_min + 1, it->pos_max));
+                    SLT_WRN(slot, "restored context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", it->pos_min, it->pos_max, (float)checkpoint_size / 1024 / 1024);
+                }
+            }
+
+            if (do_reset) {
+                SLT_WRN(slot, "forcing full prompt re-processing due to lack of cache data (likely due to SWA or hybrid/recurrent memory, see %s)\n",
+                    "https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055");
+                slot.n_past = 0;
+                slot.n_past_prompt = 0;
+            }
+        }
+    }
+
+    {
+        // erase any checkpoints with pos_min > pos_min_thold
+        for (auto it = slot.server_cached_prompt.checkpoints.begin(); it != slot.server_cached_prompt.checkpoints.end();) {
+            const auto & cur = *it;
+            if (cur.pos_min > pos_min_thold) {
+                SLT_WRN(slot, "erased invalidated context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", cur.pos_min, cur.pos_max, (float)cur.data.size() / 1024 / 1024);
+                it = slot.server_cached_prompt.checkpoints.erase(it);
+            } else {
+                ++it;
+            }
+        }
+    }
+}
+
+void server_context::create_checkpoint(server_slot & slot) {
+    //bool do_checkpoint = params_base.n_ctx_checkpoints > 0;
+
+    //// make checkpoints only for completion tasks
+    //do_checkpoint = do_checkpoint && slot.task->type == SERVER_TASK_TYPE_COMPLETION;
+
+    //// make a checkpoint of the parts of the memory that cannot be rolled back.
+    //// checkpoints are created only if:
+    //// - the model architecture is marked as recurrent or hybrid
+    ////
+    //// TODO: try to make this conditional on the context or the memory module, instead of the model type
+    //do_checkpoint = do_checkpoint && (
+    //        llama_model_is_recurrent(model) ||
+    //        llama_model_is_hybrid(model)
+    //    );
+    //int32_t pos_min = 0;
+    //if (llama_model_is_recurrent(model) || llama_model_is_hybrid(model)) {
+    //    pos_min = llama_kv_cache_seq_pos_max(slot.ctx, slot.id);
+    //}
+    //const auto pos_max = llama_kv_cache_seq_pos_max(slot.ctx, slot.id);
+
+    //// no need for empty or small checkpoints
+    //do_checkpoint = do_checkpoint && (pos_min >= 0 && pos_max >= 5);
+
+    //// no need to create checkpoints that are too close together
+    //do_checkpoint = do_checkpoint && (slot.server_cached_prompt.checkpoints.empty() || pos_max > slot.server_cached_prompt.checkpoints.back().pos_max + 64);
+
+    //if (do_checkpoint) {
+    //    while (slot.server_cached_prompt.checkpoints.size() >= (size_t)params_base.n_ctx_checkpoints) {
+    //        // make room for the new checkpoint, if needed
+    //        const auto & cur = slot.server_cached_prompt.checkpoints.front();
+
+    //        SLT_WRN(slot, "erasing old context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n",
+    //                cur.pos_min, cur.pos_max, (float)cur.data.size() / 1024 / 1024);
+
+    //        slot.server_cached_prompt.checkpoints.erase(slot.server_cached_prompt.checkpoints.begin());
+    //    }
+
+    //    const size_t checkpoint_size = llama_state_seq_get_size(ctx, slot.id);
+
+    //    auto & cur = slot.server_cached_prompt.checkpoints.emplace_back(server_prompt_checkpoint{
+    //        /*.pos_min = */ pos_min,
+    //        /*.pos_max = */ pos_max,
+    //        /*.data    = */ std::vector<uint8_t>(checkpoint_size),
+    //    });
+
+    //    llama_state_seq_get_data(ctx, cur.data.data(), checkpoint_size, slot.id);
+
+    //    SLT_WRN(slot, "created context checkpoint %d of %d (pos_min = %d, pos_max = %d, size = %.3f MiB)\n",
+    //            (int)slot.server_cached_prompt.checkpoints.size(), params_base.n_ctx_checkpoints, cur.pos_min, cur.pos_max, (float)cur.data.size() / 1024 / 1024);
+    //}
+}
+
 void server_context::batch_pending_prompt(const int32_t n_ubatch, const int32_t n_batch, int32_t & batch_type) {
     if (params_base.cont_batching || batch.n_tokens == 0) {
         for (auto& slot : slots) {
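To make the restore logic easier to follow, here is a standalone sketch of the reverse search that apply_checkpoint() performs: it walks the checkpoint list from newest to oldest and keeps the first one whose pos_min still precedes the restore threshold, so at least one prompt token remains to be evaluated. The `checkpoint` struct and helper name below are illustrative, not the server's own types.

    #include <algorithm>
    #include <cstdint>
    #include <vector>

    struct checkpoint { int32_t pos_min; int32_t pos_max; };

    // Returns the newest usable checkpoint, or nullptr if a full re-process is needed.
    const checkpoint * pick_checkpoint(const std::vector<checkpoint> & cps, int32_t pos_min_thold) {
        auto it = std::find_if(cps.rbegin(), cps.rend(),
                [&](const checkpoint & c) { return c.pos_min < pos_min_thold; });
        return it == cps.rend() ? nullptr : &*it;
    }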
@@ -2760,7 +2877,7 @@ void server_context::batch_pending_prompt(const int32_t n_ubatch, const int32_t
                 }
             }
         }
-
+        apply_checkpoint(slot);
         if (slot.n_past_prompt == slot.n_prompt_tokens && slot.n_past_prompt > 0) {
             // we have to evaluate at least 1 token to generate logits.
             LOG_INFO("we have to evaluate at least 1 token to generate logits", {
@@ -2916,6 +3033,8 @@ void server_context::batch_pending_prompt(const int32_t n_ubatch, const int32_t
         slot.n_decoded = 0;
         slot.i_batch = batch.n_tokens - 1;

+        //create_checkpoint(slot);
+
         LOG_VERBOSE("prompt done", {
             {"id_slot", slot.id},
             {"n_past", slot.n_past},
@@ -3008,10 +3127,11 @@ void server_context::speculative_decoding_accept() {
         populate_token_probs(slot, result, slot.params.post_sampling_probs, params_base.special, i);
     }

-    if (slot.n_buffer == 0) {
+    if (slot.n_buffer == 0 || llama_model_is_hybrid(llama_get_model(slot.ctx)) || llama_model_is_recurrent(llama_get_model(slot.ctx))) {
         if (!process_token(result, slot)) {
             // release slot because of stop condition
             send_final_response(slot);
+            //create_checkpoint(slot);
             slot.release();
             slot.print_timings();
             metrics.on_prediction(slot);
@@ -3046,6 +3166,7 @@ void server_context::send_token_results(completion_token_outputs& results, serve
     count++;
     if (!has_next) {
         send_final_response(slot);
+        //create_checkpoint(slot);
         slot.release();
         slot.print_timings();
         metrics.on_prediction(slot);
@@ -3259,7 +3380,8 @@ void server_context::process_batch_tokens(int32_t & n_batch) {
         populate_token_probs(slot, result, slot.params.post_sampling_probs, params_base.special, tok_idx);
     }

-    if (slot.n_buffer == 0) {
+    // no ban string for recurrent/hybrid model
+    if (slot.n_buffer == 0 || llama_model_is_hybrid(llama_get_model(slot.ctx)) || llama_model_is_recurrent(llama_get_model(slot.ctx))) {
         slot.token_buffer = { result };
         send_token_results(slot.token_buffer, slot);
     } else {
@@ -349,4 +349,7 @@ struct server_context {
     // Re-aggregates all active vectors and updates the model state
     bool apply_control_vectors_internal();

+    void create_checkpoint(server_slot & slot);
+
+    void apply_checkpoint(server_slot & slot);
 };
@@ -368,6 +368,7 @@ struct server_prompt {
     int n_tokens() const {
         return tokens.size();
     }
+
 };

 struct server_prompt_cache {
@@ -627,6 +627,12 @@ extern "C" {
     // to the decoder to start generating output sequence. For other models, it returns -1.
     LLAMA_API llama_token llama_model_decoder_start_token(const struct llama_model * model);

+    // Returns true if the model is recurrent (like Mamba, RWKV, etc.)
+    LLAMA_API bool llama_model_is_recurrent(const struct llama_model * model);
+
+    // Returns true if the model is hybrid (like Jamba, Granite, etc.)
+    LLAMA_API bool llama_model_is_hybrid(const struct llama_model * model);
+
     // Returns 0 on success
     LLAMA_API uint32_t llama_model_quantize(
             const char * fname_inp,
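These two predicates are the public entry points the server changes above rely on. A hypothetical caller might use them as sketched below; llama_load_model_from_file, llama_model_default_params, and llama_free_model are assumed to be the loader/teardown functions this API already exposes, and the model path is a placeholder.

    // Hypothetical usage of the new predicates (assumed loader functions, placeholder path).
    llama_model * model = llama_load_model_from_file("model.gguf", llama_model_default_params());
    if (model != nullptr) {
        // Recurrent/hybrid state cannot be rewound token-by-token, which is why the
        // server falls back to whole-state checkpoints for these architectures.
        const bool needs_checkpoints = llama_model_is_recurrent(model) || llama_model_is_hybrid(model);
        if (needs_checkpoints) {
            // enable context checkpoints / disable token buffering here
        }
        llama_free_model(model);
    }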
@@ -246,3 +246,23 @@ const char * llama_model_arch_name(llm_arch arch) {
     }
     return it->second;
 }
+
+bool llm_arch_is_recurrent(const llm_arch & arch) {
+    switch (arch) {
+        case LLM_ARCH_MAMBA:
+            return true;
+        default:
+            return false;
+    }
+}
+
+bool llm_arch_is_hybrid(const llm_arch & arch) {
+    switch (arch) {
+        case LLM_ARCH_QWEN3NEXT:
+        case LLM_ARCH_QWEN3MOE:
+            return true;
+        default:
+            return false;
+    }
+}
@@ -342,3 +342,6 @@ enum llm_tensor {
 llm_arch llm_arch_from_string(const std::string & name);

 const char * llama_model_arch_name(llm_arch arch);
+
+bool llm_arch_is_recurrent(const llm_arch & arch);
+bool llm_arch_is_hybrid(const llm_arch & arch);
@@ -37,6 +37,7 @@ struct llama_kv_cache {
     bool do_defrag = false;
     bool do_copy = false;
     bool recurrent = false; // with recurrent state models, a cell can hold the state for more than one past token
+    bool hybrid = false;
     bool v_trans = true; // the value tensor is transposed

     // Note: The value of head isn't only used to optimize searching
@@ -1731,3 +1731,11 @@ const char * llama_model_type_name(e_model type) {
         default: return "?B";
     }
 }
+
+bool llama_model_is_recurrent(const llama_model * model) {
+    return llm_arch_is_recurrent(model->arch);
+}
+
+bool llama_model_is_hybrid(const llama_model * model) {
+    return llm_arch_is_hybrid(model->arch);
+}
@@ -507,3 +507,4 @@ struct LLM_TN {
 std::string llama_model_ftype_name(llama_ftype ftype);

 const char * llama_model_type_name(e_model type);
+
@@ -721,7 +721,8 @@ static bool llama_kv_cache_init(
     cache.has_shift = false;

     // TODO: find a nicer way to add other recurrent model architectures
-    cache.recurrent = model.arch == LLM_ARCH_MAMBA;
+    cache.recurrent = llm_arch_is_recurrent(model.arch);
+    cache.hybrid = llm_arch_is_hybrid(model.arch);
     // qwen3next uses hybrid recurrent+attention cache semantics. Keep V rows in
     // standard layout to match the mainline hybrid path when flash attention is off.
     cache.v_trans = !cache.recurrent && !cparams.flash_attn && model.arch != LLM_ARCH_QWEN3NEXT && model.arch != LLM_ARCH_QWEN35MOE;
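The v_trans expression itself is unchanged, but it now sits next to the new hybrid flag: recurrent caches, flash attention, and the qwen3next / Qwen3.5-MoE hybrid paths all keep V in standard (non-transposed) layout. A toy truth-table sketch, with plain booleans standing in for the model and cparams checks above:

    #include <cassert>

    // Mirrors the cache.v_trans expression; `hybrid_arch` stands in for the explicit
    // LLM_ARCH_QWEN3NEXT / LLM_ARCH_QWEN35MOE exclusions in the real code.
    static bool v_trans_for(bool recurrent, bool flash_attn, bool hybrid_arch) {
        return !recurrent && !flash_attn && !hybrid_arch;
    }

    int main() {
        assert(v_trans_for(false, false, true)  == false); // hybrid arch: standard V layout
        assert(v_trans_for(false, false, false) == true);  // plain attention, no FA: transposed V
        assert(v_trans_for(true,  false, false) == false); // recurrent cache: standard V layout
        return 0;
    }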