server: enable checkpoint for recurrent models (#1310)

* server: enable checkpoint for recurrent models

create checkpoint after cancel

fix ban string handling and remove context during rewind

add checkpoint interval

only save recurrent cache

* save checkpoint during prompt processing (pp)

---------

Co-authored-by: firecoperana <firecoperana>
Author:    firecoperana <firecoperana>
Date:      2026-02-25 23:51:18 -06:00
Committed: GitHub
Parent:    216f44363f
Commit:    3fac78c48b

11 changed files with 204 additions and 111 deletions

@@ -671,7 +671,7 @@ static inline uint32_t llama_kv_v_row_embd(
         uint32_t il) {
     // qwen3next recurrent state is stored in a dedicated V-cache tail (per sequence),
     // so per-token V rows include only attention values.
-    if (model.arch == LLM_ARCH_QWEN3NEXT || model.arch == LLM_ARCH_QWEN35MOE) {
+    if (llm_arch_is_hybrid(model.arch)) {
         return hparams.n_embd_v_gqa(il);
     }
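
The hunks in this commit replace explicit QWEN3NEXT/QWEN35MOE checks with llm_arch_is_hybrid() throughout. The predicate's body is not part of this diff; a minimal sketch inferred from the conditions it replaces (at minimum it must cover these two architectures):

    // Sketch only -- inferred from the checks this commit removes; the real
    // definition of llm_arch_is_hybrid() lives elsewhere in the tree.
    static bool llm_arch_is_hybrid(llm_arch arch) {
        switch (arch) {
            case LLM_ARCH_QWEN3NEXT:
            case LLM_ARCH_QWEN35MOE:
                return true;
            default:
                return false;
        }
    }
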
@@ -732,7 +732,7 @@ static bool llama_kv_cache_init(
     cache.hybrid = llm_arch_is_hybrid(model.arch);
     // qwen3next uses hybrid recurrent+attention cache semantics. Keep V rows in
     // standard layout to match the mainline hybrid path when flash attention is off.
-    cache.v_trans = !cache.recurrent && !cparams.flash_attn && model.arch != LLM_ARCH_QWEN3NEXT && model.arch != LLM_ARCH_QWEN35MOE;
+    cache.v_trans = !cache.recurrent && !cparams.flash_attn && !llm_arch_is_hybrid(model.arch);
     cache.head = 0;
     cache.size = kv_size;
@@ -744,7 +744,7 @@ static bool llama_kv_cache_init(
     cache.cells.clear();
     cache.cells.resize(kv_size);
-    if (cache.recurrent || model.arch == LLM_ARCH_QWEN3NEXT || model.arch == LLM_ARCH_QWEN35MOE) {
+    if (cache.recurrent || llm_arch_is_hybrid(model.arch)) {
         // init state copy sources
         for (uint32_t i = 0; i < cache.size; ++i) {
             cache.cells[i].src = i;
@@ -829,7 +829,7 @@ static bool llama_kv_cache_init(
     std::vector<size_t> mem_split(model.splits.size(), 0);
     const uint32_t qnext_state_slots = llama_qwen3next_state_slots(cparams, kv_size);
-    if ((model.arch == LLM_ARCH_QWEN3NEXT || model.arch == LLM_ARCH_QWEN35MOE) && qnext_state_slots < std::max<uint32_t>(1, cparams.n_seq_max)) {
+    if (llm_arch_is_hybrid(model.arch) && qnext_state_slots < std::max<uint32_t>(1, cparams.n_seq_max)) {
         LLAMA_LOG_WARN("%s: reducing qwen3next state slots from %u to %u to fit KV cache size\n",
                 __func__, std::max<uint32_t>(1, cparams.n_seq_max), qnext_state_slots);
     }
@@ -1398,6 +1398,19 @@ static llama_pos llama_kv_cache_seq_pos_max(struct llama_kv_cache & cache, llama
     return result;
 }
 
+static llama_pos llama_kv_cache_seq_pos_min(struct llama_kv_cache & cache, llama_seq_id seq_id) {
+    llama_pos result = -1;
+
+    for (uint32_t i = 0; i < cache.size; ++i) {
+        if (cache.cells[i].has_seq_id(seq_id)) {
+            result = cache.cells[i].pos;
+            break;
+        }
+    }
+
+    return result;
+}
+
 static void llama_kv_cache_defrag(struct llama_kv_cache & cache) {
     cache.do_defrag = true;
 }
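
Unlike the max scan above it, the new helper stops at the first cell belonging to the sequence, implicitly treating cell index order as position order for that sequence. For contrast, a sketch of the full-scan max counterpart (the real llama_kv_cache_seq_pos_max is defined earlier in this file and is not shown in this diff):

    // Sketch: a max-position scan must visit every cell, since the
    // largest position for a sequence can sit at any index.
    static llama_pos seq_pos_max_sketch(const struct llama_kv_cache & cache, llama_seq_id seq_id) {
        llama_pos result = -1;
        for (uint32_t i = 0; i < cache.size; ++i) {
            if (cache.cells[i].has_seq_id(seq_id)) {
                result = std::max(result, cache.cells[i].pos);
            }
        }
        return result;
    }
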
@@ -3227,7 +3240,7 @@ static int llama_decode_internal(
         auto tim1 = ggml_time_us();
 #endif
         uint32_t n_tokens = std::min(n_ubatch, n_tokens_all - cur_token);
-        if ((model.arch == LLM_ARCH_QWEN3NEXT || model.arch == LLM_ARCH_QWEN35MOE) &&
+        if (llm_arch_is_hybrid(model.arch) &&
             n_tokens > 1 &&
             batch_all.n_seq_id != nullptr &&
             batch_all.seq_id != nullptr) {
@@ -5735,6 +5748,13 @@ void llama_kv_cache_seq_div(struct llama_context * ctx, llama_seq_id seq_id, lla
     llama_kv_cache_seq_div(ctx->kv_self, seq_id, p0, p1, d);
 }
 
+llama_pos llama_kv_cache_seq_pos_min(struct llama_context * ctx, llama_seq_id seq_id) {
+    if (ctx->kv_self.hybrid || ctx->kv_self.recurrent) {
+        return llama_kv_cache_seq_pos_max(ctx->kv_self, seq_id);
+    }
+    return llama_kv_cache_seq_pos_min(ctx->kv_self, seq_id);
+}
+
 llama_pos llama_kv_cache_seq_pos_max(struct llama_context * ctx, llama_seq_id seq_id) {
     return llama_kv_cache_seq_pos_max(ctx->kv_self, seq_id);
 }
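
For hybrid or recurrent caches the minimum is deliberately aliased to the maximum: a recurrent state exists only at its latest position, so nothing earlier can be rewound to in place. A hypothetical caller-side check built on the new accessor (checkpoint_reachable and checkpoint_pos are illustrative names, not part of this commit):

    // Hypothetical sketch: can a checkpoint taken at checkpoint_pos be
    // reused without a full reprocess? For hybrid/recurrent caches
    // pos_min == pos_max, so only the most recent state qualifies.
    static bool checkpoint_reachable(struct llama_context * ctx, llama_seq_id seq_id, llama_pos checkpoint_pos) {
        return checkpoint_pos >= llama_kv_cache_seq_pos_min(ctx, seq_id);
    }
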
@@ -5876,10 +5896,11 @@ struct llama_data_write {
         }
     }
 
-    void write_kv_cache_data(const struct llama_context * ctx, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id = -1) {
+    void write_kv_cache_data(const struct llama_context * ctx, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id = -1,
+            llama_state_seq_flags flags = 0) {
         const struct llama_kv_cache & kv_self = ctx->kv_self;
         const struct llama_hparams & hparams = ctx->model.hparams;
+        bool need_kv = (flags & LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY) == 0;
 
         // v_state: 0 -> not transposed V cache
         //          1 -> transposed V cache
         //          2 -> no V cache (as it may be the case with MLA)
@@ -5895,7 +5916,7 @@ struct llama_data_write {
             const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
             const uint32_t n_embd_head_qk_rope = hparams.n_rot;
             const uint32_t kv_lora_rank = hparams.n_lora_kv;
-            const bool has_k_cache = kv_self.k_l[il] != nullptr;
+            const bool has_k_cache = kv_self.k_l[il] != nullptr && need_kv;
 
             // Write key type
             const int32_t k_type_i = has_k_cache ? (int32_t) kv_self.k_l[il]->type : -1;
@@ -5924,7 +5945,7 @@ struct llama_data_write {
         if (v_state == 0) {
             for (uint32_t il = 0; il < n_layer; ++il) {
                 const uint32_t n_embd_v_gqa = llama_kv_v_row_embd(ctx->model, hparams, il);
-                const bool has_v_cache = kv_self.v_l[il] != nullptr;
+                const bool has_v_cache = kv_self.v_l[il] != nullptr && need_kv;
 
                 // Write value type
                 const int32_t v_type_i = has_v_cache ? (int32_t) kv_self.v_l[il]->type : -1;
@@ -5951,7 +5972,7 @@ struct llama_data_write {
             const uint32_t kv_size = kv_self.size;
             for (uint32_t il = 0; il < n_layer; ++il) {
                 const uint32_t n_embd_v_gqa = llama_kv_v_row_embd(ctx->model, hparams, il);
-                const bool has_v_cache = kv_self.v_l[il] != nullptr;
+                const bool has_v_cache = kv_self.v_l[il] != nullptr && need_kv;
 
                 // Write value type
                 const int32_t v_type_i = has_v_cache ? (int32_t) kv_self.v_l[il]->type : -1;
@@ -6019,7 +6040,7 @@ struct llama_data_write {
         }
     }
 
-    void write_kv_cache(const struct llama_context * ctx, llama_seq_id seq_id = -1) {
+    void write_kv_cache(const struct llama_context * ctx, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) {
         const struct llama_kv_cache & kv_self = ctx->kv_self;
         std::vector<std::pair<uint32_t, uint32_t>> cell_ranges; // ranges, from inclusive, to exclusive
         uint32_t cell_count = 0;
@@ -6055,7 +6076,7 @@ struct llama_data_write {
         write(&cell_count, sizeof(cell_count));
 
         write_kv_cache_meta(kv_self, cell_ranges, seq_id);
-        write_kv_cache_data(ctx, cell_ranges, seq_id);
+        write_kv_cache_data(ctx, cell_ranges, seq_id, flags);
     }
 };
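
With flags threaded through write_kv_cache, a dump taken with LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY still writes cell metadata and per-layer type headers but marks the attention K/V payloads absent. A hedged sketch probing the size difference via the sized query added later in this diff (the ratio depends on model and context fill):

    // Sketch: compare a full dump against a recurrent-only checkpoint.
    // For hybrid models the partial blob skips the attention K/V payloads,
    // so it should be much smaller once the context fills up.
    static void log_checkpoint_sizes(struct llama_context * ctx, llama_seq_id seq_id) {
        const size_t full    = llama_state_seq_get_size(ctx, seq_id, 0);
        const size_t partial = llama_state_seq_get_size(ctx, seq_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
        LLAMA_LOG_INFO("seq %d: full state %zu bytes, recurrent-only %zu bytes\n", (int) seq_id, full, partial);
    }
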
@@ -6266,10 +6287,10 @@ struct llama_data_read {
         GGML_ASSERT(sum_split_row_size == row_size);
     }
 
-    bool read_kv_cache_data(struct llama_context * ctx, uint32_t cell_count, llama_seq_id seq_id = -1) {
+    bool read_kv_cache_data(struct llama_context * ctx, uint32_t cell_count, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) {
         const struct llama_hparams & hparams = ctx->model.hparams;
         struct llama_kv_cache & kv_self = ctx->kv_self;
+        bool need_kv = (flags & LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY) == 0;
 
         // v_state: 0 -> not transposed V cache
         //          1 -> transposed V cache
         //          2 -> no V cache (as it may be the case with MLA)
@@ -6298,7 +6319,7 @@ struct llama_data_read {
             const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
             const uint32_t n_embd_head_qk_rope = hparams.n_rot;
             const uint32_t kv_lora_rank = hparams.n_lora_kv;
-            const bool has_k_cache = kv_self.k_l[il] != nullptr;
+            const bool has_k_cache = kv_self.k_l[il] != nullptr && need_kv;
 
             // Read type of key
@@ -6346,7 +6367,7 @@ struct llama_data_read {
         if (v_state == 0) {
             for (uint32_t il = 0; il < n_layer; ++il) {
                 const uint32_t n_embd_v_gqa = llama_kv_v_row_embd(ctx->model, hparams, il);
-                const bool has_v_cache = kv_self.v_l[il] != nullptr;
+                const bool has_v_cache = kv_self.v_l[il] != nullptr && need_kv;
 
                 // Read type of value
                 int32_t v_type_i_ref;
@@ -6394,7 +6415,7 @@ struct llama_data_read {
             // For each layer, read the values for each cell (transposed)
             for (uint32_t il = 0; il < n_layer; ++il) {
                 const uint32_t n_embd_v_gqa = llama_kv_v_row_embd(ctx->model, hparams, il);
-                const bool has_v_cache = kv_self.v_l[il] != nullptr;
+                const bool has_v_cache = kv_self.v_l[il] != nullptr && need_kv;
 
                 // Read type of value
                 int32_t v_type_i_ref;
@@ -6529,11 +6550,11 @@ struct llama_data_read {
         return true;
     }
 
-    void read_kv_cache(struct llama_context * ctx, llama_seq_id seq_id = -1) {
+    void read_kv_cache(struct llama_context * ctx, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) {
         uint32_t cell_count;
         read_to(&cell_count, sizeof(cell_count));
 
-        bool res = read_kv_cache_meta(ctx, cell_count, seq_id) && read_kv_cache_data(ctx, cell_count, seq_id);
+        bool res = read_kv_cache_meta(ctx, cell_count, seq_id) && read_kv_cache_data(ctx, cell_count, seq_id, flags);
 
         if (!res) {
             if (seq_id == -1) {
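
Because the write side emits -1 type headers for skipped K/V tensors, a stream saved with LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY has to be restored with the same flag, otherwise the per-layer type headers read back will no longer match what the cache expects. A minimal round-trip sketch under that assumption:

    // Sketch: save and later restore a recurrent-only checkpoint,
    // passing the same flag on both sides.
    std::vector<uint8_t> buf(llama_state_seq_get_size(ctx, seq_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY));
    llama_state_seq_get_data(ctx, buf.data(), buf.size(), seq_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
    // ... generation continues; on cancel/rewind, reload the recurrent state:
    llama_state_seq_set_data(ctx, buf.data(), buf.size(), seq_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
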
@@ -6895,41 +6916,41 @@ bool llama_state_save_file(struct llama_context * ctx, const char * path_session
     }
 }
 
-static size_t llama_state_seq_get_data_internal(struct llama_context * ctx, llama_data_write & data_ctx, llama_seq_id seq_id) {
+static size_t llama_state_seq_get_data_internal(struct llama_context * ctx, llama_data_write & data_ctx, llama_seq_id seq_id, llama_state_seq_flags flags) {
     llama_synchronize(ctx);
 
-    data_ctx.write_kv_cache(ctx, seq_id);
+    data_ctx.write_kv_cache(ctx, seq_id, flags);
 
     return data_ctx.get_size_written();
 }
 
-size_t llama_state_seq_get_size(struct llama_context * ctx, llama_seq_id seq_id) {
+size_t llama_state_seq_get_size(struct llama_context * ctx, llama_seq_id seq_id, llama_state_seq_flags flags) {
     llama_data_write_dummy data_ctx;
-    return llama_state_seq_get_data_internal(ctx, data_ctx, seq_id);
+    return llama_state_seq_get_data_internal(ctx, data_ctx, seq_id, flags);
 }
 
-size_t llama_state_seq_get_data(struct llama_context * ctx, uint8_t * dst, size_t size, llama_seq_id seq_id) {
+size_t llama_state_seq_get_data(struct llama_context * ctx, uint8_t * dst, size_t size, llama_seq_id seq_id, llama_state_seq_flags flags) {
     llama_data_write_buffer data_ctx(dst, size, ctx->model);
     try {
-        return llama_state_seq_get_data_internal(ctx, data_ctx, seq_id);
+        return llama_state_seq_get_data_internal(ctx, data_ctx, seq_id, flags);
     } catch (const std::exception & err) {
         LLAMA_LOG_ERROR("%s: error saving sequence state: %s\n", __func__, err.what());
         return 0;
     }
 }
 
-static size_t llama_state_seq_set_data_internal(struct llama_context * ctx, llama_data_read & data_ctx, llama_seq_id dest_seq_id) {
+static size_t llama_state_seq_set_data_internal(struct llama_context * ctx, llama_data_read & data_ctx, llama_seq_id dest_seq_id, llama_state_seq_flags flags) {
     llama_synchronize(ctx);
 
-    data_ctx.read_kv_cache(ctx, dest_seq_id);
+    data_ctx.read_kv_cache(ctx, dest_seq_id, flags);
 
     return data_ctx.get_size_read();
 }
 
-size_t llama_state_seq_set_data(struct llama_context * ctx, const uint8_t * src, size_t size, llama_seq_id dest_seq_id) {
+size_t llama_state_seq_set_data(struct llama_context * ctx, const uint8_t * src, size_t size, llama_seq_id dest_seq_id, llama_state_seq_flags flags) {
     llama_data_read_buffer data_ctx(src, size);
     try {
-        return llama_state_seq_set_data_internal(ctx, data_ctx, dest_seq_id);
+        return llama_state_seq_set_data_internal(ctx, data_ctx, dest_seq_id, flags);
     } catch (const std::exception & err) {
         LLAMA_LOG_ERROR("%s: error loading sequence state: %s\n", __func__, err.what());
         return 0;
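
These buffer-level entry points are what a server checkpoint would use; the file-based session routines below pass flags = 0 and therefore always store the full state. A hypothetical server-side record built on them (the struct and function names are illustrative, not from this commit):

    // Hypothetical sketch of a per-sequence checkpoint record.
    struct seq_checkpoint {
        llama_pos            pos;  // position the state was captured at
        std::vector<uint8_t> data; // recurrent-only state blob
    };

    static seq_checkpoint take_checkpoint(struct llama_context * ctx, llama_seq_id seq_id, llama_pos pos) {
        seq_checkpoint cp;
        cp.pos = pos;
        cp.data.resize(llama_state_seq_get_size(ctx, seq_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY));
        llama_state_seq_get_data(ctx, cp.data.data(), cp.data.size(), seq_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
        return cp;
    }
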
@@ -6948,7 +6969,7 @@ static size_t llama_state_seq_save_file_internal(struct llama_context * ctx, con
     // save the context state using stream saving
     llama_data_write_file data_ctx(&file, ctx->model);
-    llama_state_seq_get_data_internal(ctx, data_ctx, seq_id);
+    llama_state_seq_get_data_internal(ctx, data_ctx, seq_id, 0);
 
     const size_t res = file.tell();
     GGML_ASSERT(res == sizeof(uint32_t) * 3 + sizeof(llama_token) * n_token_count + data_ctx.get_size_written());
@@ -6986,7 +7007,7 @@ static size_t llama_state_seq_load_file_internal(struct llama_context * ctx, con
     {
         const size_t state_size = file.size() - file.tell();
         llama_data_read_file data_ctx(&file);
-        const size_t nread = llama_state_seq_set_data_internal(ctx, data_ctx, dest_seq_id);
+        const size_t nread = llama_state_seq_set_data_internal(ctx, data_ctx, dest_seq_id, 0);
         if (!nread) {
             LLAMA_LOG_ERROR("%s: failed to restore sequence state\n", __func__);
             return 0;