Mirror of https://github.com/ikawrakow/ik_llama.cpp.git
Synced 2026-02-28 17:14:17 +00:00

server: enable checkpoint for recurrent models (#1310)

* server: enable checkpoint for recurrent models
  - create checkpoint after cancel
  - fix ban string and remove context during rewind
  - add checkpoint interval
  - only save recurrent cache
* save checkpoint during pp

Co-authored-by: firecoperana <firecoperana>
This commit is contained in:
@@ -671,7 +671,7 @@ static inline uint32_t llama_kv_v_row_embd(
|
||||
uint32_t il) {
|
||||
// qwen3next recurrent state is stored in a dedicated V-cache tail (per sequence),
|
||||
// so per-token V rows include only attention values.
|
||||
if (model.arch == LLM_ARCH_QWEN3NEXT || model.arch == LLM_ARCH_QWEN35MOE) {
|
||||
if (llm_arch_is_hybrid(model.arch)) {
|
||||
return hparams.n_embd_v_gqa(il);
|
||||
}
|
||||
|
||||
@@ -732,7 +732,7 @@ static bool llama_kv_cache_init(
|
||||
cache.hybrid = llm_arch_is_hybrid(model.arch);
|
||||
// qwen3next uses hybrid recurrent+attention cache semantics. Keep V rows in
|
||||
// standard layout to match the mainline hybrid path when flash attention is off.
|
||||
cache.v_trans = !cache.recurrent && !cparams.flash_attn && model.arch != LLM_ARCH_QWEN3NEXT && model.arch != LLM_ARCH_QWEN35MOE;
|
||||
cache.v_trans = !cache.recurrent && !cparams.flash_attn && !llm_arch_is_hybrid(model.arch);
|
||||
|
||||
cache.head = 0;
|
||||
cache.size = kv_size;
|
||||
@@ -744,7 +744,7 @@ static bool llama_kv_cache_init(
|
||||
cache.cells.clear();
|
||||
cache.cells.resize(kv_size);
|
||||
|
||||
if (cache.recurrent || model.arch == LLM_ARCH_QWEN3NEXT || model.arch == LLM_ARCH_QWEN35MOE) {
|
||||
if (cache.recurrent || llm_arch_is_hybrid(model.arch)) {
|
||||
// init state copy sources
|
||||
for (uint32_t i = 0; i < cache.size; ++i) {
|
||||
cache.cells[i].src = i;
|
||||
@@ -829,7 +829,7 @@ static bool llama_kv_cache_init(
|
||||
std::vector<size_t> mem_split(model.splits.size(), 0);
|
||||
|
||||
const uint32_t qnext_state_slots = llama_qwen3next_state_slots(cparams, kv_size);
|
||||
if ((model.arch == LLM_ARCH_QWEN3NEXT || model.arch == LLM_ARCH_QWEN35MOE) && qnext_state_slots < std::max<uint32_t>(1, cparams.n_seq_max)) {
|
||||
if (llm_arch_is_hybrid(model.arch) && qnext_state_slots < std::max<uint32_t>(1, cparams.n_seq_max)) {
|
||||
LLAMA_LOG_WARN("%s: reducing qwen3next state slots from %u to %u to fit KV cache size\n",
|
||||
__func__, std::max<uint32_t>(1, cparams.n_seq_max), qnext_state_slots);
|
||||
}
|
||||
@@ -1398,6 +1398,19 @@ static llama_pos llama_kv_cache_seq_pos_max(struct llama_kv_cache & cache, llama
|
||||
return result;
|
||||
}
|
||||
|
||||
// Return the position of the first KV-cache cell that belongs to `seq_id`,
// or -1 when the sequence has no cells in the cache.
static llama_pos llama_kv_cache_seq_pos_min(struct llama_kv_cache & cache, llama_seq_id seq_id) {
|
||||
// -1 signals "sequence not present in the cache".
llama_pos result = -1;
|
||||
|
||||
for (uint32_t i = 0; i < cache.size; ++i) {
|
||||
if (cache.cells[i].has_seq_id(seq_id)) {
|
||||
// NOTE(review): this takes the FIRST matching cell and stops, which equals the
// minimum position only if cells are stored in non-decreasing pos order for the
// sequence. The sibling llama_kv_cache_seq_pos_max scans every cell instead of
// breaking early — confirm the ordering invariant, or scan with std::min here.
result = cache.cells[i].pos;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
// Request a defragmentation pass on the KV cache. This only raises a flag;
// the actual defragmentation is performed later by whichever code consumes
// `do_defrag` (not visible in this excerpt).
static void llama_kv_cache_defrag(struct llama_kv_cache & cache) {
|
||||
cache.do_defrag = true;
|
||||
}
|
||||
@@ -3227,7 +3240,7 @@ static int llama_decode_internal(
|
||||
auto tim1 = ggml_time_us();
|
||||
#endif
|
||||
uint32_t n_tokens = std::min(n_ubatch, n_tokens_all - cur_token);
|
||||
if ((model.arch == LLM_ARCH_QWEN3NEXT || model.arch == LLM_ARCH_QWEN35MOE) &&
|
||||
if (llm_arch_is_hybrid(model.arch) &&
|
||||
n_tokens > 1 &&
|
||||
batch_all.n_seq_id != nullptr &&
|
||||
batch_all.seq_id != nullptr) {
|
||||
@@ -5735,6 +5748,13 @@ void llama_kv_cache_seq_div(struct llama_context * ctx, llama_seq_id seq_id, lla
|
||||
llama_kv_cache_seq_div(ctx->kv_self, seq_id, p0, p1, d);
|
||||
}
|
||||
|
||||
// Public wrapper: smallest position stored for `seq_id` in this context's KV cache.
llama_pos llama_kv_cache_seq_pos_min(struct llama_context * ctx, llama_seq_id seq_id) {
|
||||
// For hybrid/recurrent caches the per-sequence state is a single slot, so the
// only restorable position is the latest one — presumably why the max is
// returned here instead of the min (checkpoint/rewind logic relies on this).
// TODO(review): confirm this intent against the server checkpoint callers.
if (ctx->kv_self.hybrid || ctx->kv_self.recurrent) {
|
||||
return llama_kv_cache_seq_pos_max(ctx->kv_self, seq_id);
|
||||
}
|
||||
// Standard attention cache: delegate to the cache-level min scan.
return llama_kv_cache_seq_pos_min(ctx->kv_self, seq_id);
|
||||
}
|
||||
|
||||
// Public wrapper: largest position stored for `seq_id` in this context's KV cache.
// Thin pass-through to the cache-level implementation.
llama_pos llama_kv_cache_seq_pos_max(struct llama_context * ctx, llama_seq_id seq_id) {
|
||||
return llama_kv_cache_seq_pos_max(ctx->kv_self, seq_id);
|
||||
}
|
||||
@@ -5876,10 +5896,11 @@ struct llama_data_write {
|
||||
}
|
||||
}
|
||||
|
||||
void write_kv_cache_data(const struct llama_context * ctx, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id = -1) {
|
||||
void write_kv_cache_data(const struct llama_context * ctx, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id = -1,
|
||||
llama_state_seq_flags flags = 0) {
|
||||
const struct llama_kv_cache & kv_self = ctx->kv_self;
|
||||
const struct llama_hparams & hparams = ctx->model.hparams;
|
||||
|
||||
bool need_kv = (flags & LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY) == 0;
|
||||
// v_state: 0 -> not transposed V cache
|
||||
// 1 -> transposed V cache
|
||||
// 2 -> no V cache (as it may be the case with MLA)
|
||||
@@ -5895,7 +5916,7 @@ struct llama_data_write {
|
||||
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
|
||||
const uint32_t n_embd_head_qk_rope = hparams.n_rot;
|
||||
const uint32_t kv_lora_rank = hparams.n_lora_kv;
|
||||
const bool has_k_cache = kv_self.k_l[il] != nullptr;
|
||||
const bool has_k_cache = kv_self.k_l[il] != nullptr && need_kv;
|
||||
|
||||
// Write key type
|
||||
const int32_t k_type_i = has_k_cache ? (int32_t) kv_self.k_l[il]->type : -1;
|
||||
@@ -5924,7 +5945,7 @@ struct llama_data_write {
|
||||
if (v_state == 0) {
|
||||
for (uint32_t il = 0; il < n_layer; ++il) {
|
||||
const uint32_t n_embd_v_gqa = llama_kv_v_row_embd(ctx->model, hparams, il);
|
||||
const bool has_v_cache = kv_self.v_l[il] != nullptr;
|
||||
const bool has_v_cache = kv_self.v_l[il] != nullptr && need_kv;
|
||||
|
||||
// Write value type
|
||||
const int32_t v_type_i = has_v_cache ? (int32_t) kv_self.v_l[il]->type : -1;
|
||||
@@ -5951,7 +5972,7 @@ struct llama_data_write {
|
||||
const uint32_t kv_size = kv_self.size;
|
||||
for (uint32_t il = 0; il < n_layer; ++il) {
|
||||
const uint32_t n_embd_v_gqa = llama_kv_v_row_embd(ctx->model, hparams, il);
|
||||
const bool has_v_cache = kv_self.v_l[il] != nullptr;
|
||||
const bool has_v_cache = kv_self.v_l[il] != nullptr && need_kv;
|
||||
|
||||
// Write value type
|
||||
const int32_t v_type_i = has_v_cache ? (int32_t) kv_self.v_l[il]->type : -1;
|
||||
@@ -6019,7 +6040,7 @@ struct llama_data_write {
|
||||
}
|
||||
}
|
||||
|
||||
void write_kv_cache(const struct llama_context * ctx, llama_seq_id seq_id = -1) {
|
||||
void write_kv_cache(const struct llama_context * ctx, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) {
|
||||
const struct llama_kv_cache & kv_self = ctx->kv_self;
|
||||
std::vector<std::pair<uint32_t, uint32_t>> cell_ranges; // ranges, from inclusive, to exclusive
|
||||
uint32_t cell_count = 0;
|
||||
@@ -6055,7 +6076,7 @@ struct llama_data_write {
|
||||
write(&cell_count, sizeof(cell_count));
|
||||
|
||||
write_kv_cache_meta(kv_self, cell_ranges, seq_id);
|
||||
write_kv_cache_data(ctx, cell_ranges, seq_id);
|
||||
write_kv_cache_data(ctx, cell_ranges, seq_id, flags);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -6266,10 +6287,10 @@ struct llama_data_read {
|
||||
GGML_ASSERT(sum_split_row_size == row_size);
|
||||
}
|
||||
|
||||
bool read_kv_cache_data(struct llama_context * ctx, uint32_t cell_count, llama_seq_id seq_id = -1) {
|
||||
bool read_kv_cache_data(struct llama_context * ctx, uint32_t cell_count, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) {
|
||||
const struct llama_hparams & hparams = ctx->model.hparams;
|
||||
struct llama_kv_cache & kv_self = ctx->kv_self;
|
||||
|
||||
bool need_kv = (flags & LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY) == 0;
|
||||
// v_state: 0 -> not transposed V cache
|
||||
// 1 -> transposed V cache
|
||||
// 2 -> no V cache (as it may be the case with MLA)
|
||||
@@ -6298,7 +6319,7 @@ struct llama_data_read {
|
||||
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
|
||||
const uint32_t n_embd_head_qk_rope = hparams.n_rot;
|
||||
const uint32_t kv_lora_rank = hparams.n_lora_kv;
|
||||
const bool has_k_cache = kv_self.k_l[il] != nullptr;
|
||||
const bool has_k_cache = kv_self.k_l[il] != nullptr && need_kv;
|
||||
|
||||
|
||||
// Read type of key
|
||||
@@ -6346,7 +6367,7 @@ struct llama_data_read {
|
||||
if (v_state == 0) {
|
||||
for (uint32_t il = 0; il < n_layer; ++il) {
|
||||
const uint32_t n_embd_v_gqa = llama_kv_v_row_embd(ctx->model, hparams, il);
|
||||
const bool has_v_cache = kv_self.v_l[il] != nullptr;
|
||||
const bool has_v_cache = kv_self.v_l[il] != nullptr && need_kv;
|
||||
|
||||
// Read type of value
|
||||
int32_t v_type_i_ref;
|
||||
@@ -6394,7 +6415,7 @@ struct llama_data_read {
|
||||
// For each layer, read the values for each cell (transposed)
|
||||
for (uint32_t il = 0; il < n_layer; ++il) {
|
||||
const uint32_t n_embd_v_gqa = llama_kv_v_row_embd(ctx->model, hparams, il);
|
||||
const bool has_v_cache = kv_self.v_l[il] != nullptr;
|
||||
const bool has_v_cache = kv_self.v_l[il] != nullptr && need_kv;
|
||||
|
||||
// Read type of value
|
||||
int32_t v_type_i_ref;
|
||||
@@ -6529,11 +6550,11 @@ struct llama_data_read {
|
||||
return true;
|
||||
}
|
||||
|
||||
void read_kv_cache(struct llama_context * ctx, llama_seq_id seq_id = -1) {
|
||||
void read_kv_cache(struct llama_context * ctx, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) {
|
||||
uint32_t cell_count;
|
||||
read_to(&cell_count, sizeof(cell_count));
|
||||
|
||||
bool res = read_kv_cache_meta(ctx, cell_count, seq_id) && read_kv_cache_data(ctx, cell_count, seq_id);
|
||||
bool res = read_kv_cache_meta(ctx, cell_count, seq_id) && read_kv_cache_data(ctx, cell_count, seq_id, flags);
|
||||
|
||||
if (!res) {
|
||||
if (seq_id == -1) {
|
||||
@@ -6895,41 +6916,41 @@ bool llama_state_save_file(struct llama_context * ctx, const char * path_session
|
||||
}
|
||||
}
|
||||
|
||||
static size_t llama_state_seq_get_data_internal(struct llama_context * ctx, llama_data_write & data_ctx, llama_seq_id seq_id) {
|
||||
static size_t llama_state_seq_get_data_internal(struct llama_context * ctx, llama_data_write & data_ctx, llama_seq_id seq_id, llama_state_seq_flags flags) {
|
||||
llama_synchronize(ctx);
|
||||
|
||||
data_ctx.write_kv_cache(ctx, seq_id);
|
||||
data_ctx.write_kv_cache(ctx, seq_id, flags);
|
||||
|
||||
return data_ctx.get_size_written();
|
||||
}
|
||||
|
||||
size_t llama_state_seq_get_size(struct llama_context * ctx, llama_seq_id seq_id) {
|
||||
size_t llama_state_seq_get_size(struct llama_context * ctx, llama_seq_id seq_id, llama_state_seq_flags flags) {
|
||||
llama_data_write_dummy data_ctx;
|
||||
return llama_state_seq_get_data_internal(ctx, data_ctx, seq_id);
|
||||
return llama_state_seq_get_data_internal(ctx, data_ctx, seq_id, flags);
|
||||
}
|
||||
|
||||
size_t llama_state_seq_get_data(struct llama_context * ctx, uint8_t * dst, size_t size, llama_seq_id seq_id) {
|
||||
size_t llama_state_seq_get_data(struct llama_context * ctx, uint8_t * dst, size_t size, llama_seq_id seq_id, llama_state_seq_flags flags) {
|
||||
llama_data_write_buffer data_ctx(dst, size, ctx->model);
|
||||
try {
|
||||
return llama_state_seq_get_data_internal(ctx, data_ctx, seq_id);
|
||||
return llama_state_seq_get_data_internal(ctx, data_ctx, seq_id, flags);
|
||||
} catch (const std::exception & err) {
|
||||
LLAMA_LOG_ERROR("%s: error saving sequence state: %s\n", __func__, err.what());
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
|
||||
static size_t llama_state_seq_set_data_internal(struct llama_context * ctx, llama_data_read & data_ctx, llama_seq_id dest_seq_id) {
|
||||
static size_t llama_state_seq_set_data_internal(struct llama_context * ctx, llama_data_read & data_ctx, llama_seq_id dest_seq_id, llama_state_seq_flags flags) {
|
||||
llama_synchronize(ctx);
|
||||
|
||||
data_ctx.read_kv_cache(ctx, dest_seq_id);
|
||||
data_ctx.read_kv_cache(ctx, dest_seq_id, flags);
|
||||
|
||||
return data_ctx.get_size_read();
|
||||
}
|
||||
|
||||
size_t llama_state_seq_set_data(struct llama_context * ctx, const uint8_t * src, size_t size, llama_seq_id dest_seq_id) {
|
||||
size_t llama_state_seq_set_data(struct llama_context * ctx, const uint8_t * src, size_t size, llama_seq_id dest_seq_id, llama_state_seq_flags flags) {
|
||||
llama_data_read_buffer data_ctx(src, size);
|
||||
try {
|
||||
return llama_state_seq_set_data_internal(ctx, data_ctx, dest_seq_id);
|
||||
return llama_state_seq_set_data_internal(ctx, data_ctx, dest_seq_id, flags);
|
||||
} catch (const std::exception & err) {
|
||||
LLAMA_LOG_ERROR("%s: error loading sequence state: %s\n", __func__, err.what());
|
||||
return 0;
|
||||
@@ -6948,7 +6969,7 @@ static size_t llama_state_seq_save_file_internal(struct llama_context * ctx, con
|
||||
|
||||
// save the context state using stream saving
|
||||
llama_data_write_file data_ctx(&file, ctx->model);
|
||||
llama_state_seq_get_data_internal(ctx, data_ctx, seq_id);
|
||||
llama_state_seq_get_data_internal(ctx, data_ctx, seq_id, 0);
|
||||
|
||||
const size_t res = file.tell();
|
||||
GGML_ASSERT(res == sizeof(uint32_t) * 3 + sizeof(llama_token) * n_token_count + data_ctx.get_size_written());
|
||||
@@ -6986,7 +7007,7 @@ static size_t llama_state_seq_load_file_internal(struct llama_context * ctx, con
|
||||
{
|
||||
const size_t state_size = file.size() - file.tell();
|
||||
llama_data_read_file data_ctx(&file);
|
||||
const size_t nread = llama_state_seq_set_data_internal(ctx, data_ctx, dest_seq_id);
|
||||
const size_t nread = llama_state_seq_set_data_internal(ctx, data_ctx, dest_seq_id, 0);
|
||||
if (!nread) {
|
||||
LLAMA_LOG_ERROR("%s: failed to restore sequence state\n", __func__);
|
||||
return 0;
|
||||
|
||||
Reference in New Issue
Block a user