save checkpoint during prompt processing (pp)

firecoperana
2026-02-25 19:05:54 -06:00
parent 233898704c
commit 7962e9a4b3
4 changed files with 47 additions and 28 deletions


@@ -361,7 +361,7 @@ void server_slot::reset() {
rewind_status = false;
generated_token_probs.clear();
checkpoint_pos = 0;
// Reset speculative decoding stats
n_draft_total = 0;
@@ -1248,6 +1248,16 @@ bool server_context::launch_slot_with_task(server_slot& slot, server_task& task)
}
if (llama_model_has_recurrent(llama_get_model(slot.ctx))) {
params_base.can_ban_phrases = false;
bool do_checkpoint = params_base.ctx_checkpoints_n > 0;
// make checkpoints only for completion tasks
do_checkpoint = do_checkpoint && task.type == SERVER_TASK_TYPE_COMPLETION;
// make a checkpoint of the parts of the memory that cannot be rolled back.
// checkpoints are created only if:
// - the model architecture is marked as recurrent or hybrid
//
// TODO: try to make this conditional on the context or the memory module, instead of the model type
// do_checkpoint = do_checkpoint && llama_model_has_recurrent(model);
params_base.do_checkpoint = do_checkpoint;
if (slot.n_buffer != 0) {
LLAMA_LOG_WARN("Recurrent model does not support banned strings.\n");
}
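With this change the eligibility check runs once per task at slot launch and is cached in params_base.do_checkpoint, rather than being re-derived on every create_checkpoint call (see the removal in the -2696 hunk below). A minimal sketch of the gate as a pure function, with stand-in names for the server's fields (not real API):

```cpp
// Sketch only: the per-task checkpoint gate from the hunk above.
// task_type and its values are stand-ins for SERVER_TASK_TYPE_*.
#include <cstdio>

enum task_type { TASK_COMPLETION, TASK_OTHER };

static bool checkpoint_gate(int ctx_checkpoints_n, task_type type) {
    bool ok = ctx_checkpoints_n > 0;     // checkpointing enabled at all
    ok = ok && type == TASK_COMPLETION;  // checkpoints only for completion tasks
    // The recurrent/hybrid model check stays TODO'd out in the diff:
    // ok = ok && model_is_recurrent;
    return ok;
}

int main() {
    printf("%d\n", checkpoint_gate(4, TASK_COMPLETION)); // 1: gate open
    printf("%d\n", checkpoint_gate(4, TASK_OTHER));      // 0: not a completion task
    return 0;
}
```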
@@ -2632,6 +2642,16 @@ void server_context::add_sampled_tokens() {
}
}
void server_context::create_checkpoint_at_interval(server_slot & slot, const gpt_params & params_base) {
if (params_base.do_checkpoint && params_base.ctx_checkpoints_interval > 0) {
auto pos = llama_kv_cache_seq_pos_max(slot.ctx, slot.id);
if (slot.checkpoint_pos + params_base.ctx_checkpoints_interval <= 1 + pos) {
create_checkpoint(slot);
slot.checkpoint_pos = pos;
}
}
}
void server_context::apply_checkpoint(server_slot & slot) {
const auto pos_min_thold = std::max(0, slot.n_past - 1);
if (!mctx && slot.n_past > 0 && slot.n_past < slot.cache_tokens.n_tokens()) {
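For intuition on the new trigger in create_checkpoint_at_interval: pos is the 0-based maximum KV-cache position for the sequence, so 1 + pos is the number of cached positions, and a checkpoint fires once the cache has grown ctx_checkpoints_interval positions past the last checkpoint. A hypothetical standalone trace (not server code):

```cpp
// Sketch only: simulates the trigger in create_checkpoint_at_interval above,
// assuming interval = 256 and a freshly reset slot.
#include <cstdint>
#include <cstdio>

int main() {
    const int32_t interval       = 256; // params_base.ctx_checkpoints_interval
    int32_t       checkpoint_pos = 0;   // slot.checkpoint_pos after server_slot::reset()

    for (int32_t pos = 0; pos < 1024; ++pos) {
        if (checkpoint_pos + interval <= 1 + pos) {
            printf("checkpoint at pos = %d\n", pos); // fires at 255, 510, 765, 1020
            checkpoint_pos = pos;
        }
    }
    return 0;
}
```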
@@ -2696,20 +2716,7 @@ void server_context::apply_checkpoint(server_slot & slot) {
}
void server_context::create_checkpoint(server_slot & slot) {
bool do_checkpoint = params_base.ctx_checkpoints_n > 0;
// make checkpoints only for completion tasks
do_checkpoint = do_checkpoint && slot.task->type == SERVER_TASK_TYPE_COMPLETION;
// make a checkpoint of the parts of the memory that cannot be rolled back.
// checkpoints are created only if:
// - the model architecture is marked as recurrent or hybrid
//
// TODO: try to make this conditional on the context or the memory module, instead of the model type
do_checkpoint = do_checkpoint && llama_model_has_recurrent(model);
if (!do_checkpoint) {
return;
}
bool do_checkpoint = true;
int32_t pos_min = llama_kv_cache_seq_pos_min(slot.ctx, slot.id);
const auto pos_max = llama_kv_cache_seq_pos_max(slot.ctx, slot.id);
@@ -2743,9 +2750,9 @@ void server_context::create_checkpoint(server_slot & slot) {
llama_state_seq_get_data(ctx, cur.data.data(), checkpoint_size, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
SLT_WRN(slot, "created context checkpoint %d of %d took %.2f ms (pos_min = %d, pos_max = %d, size = %.3f MiB)\n",
(int)slot.server_cached_prompt.checkpoints.size(), params_base.ctx_checkpoints_n,
(ggml_time_us() - t_start) / 1000.0, cur.pos_min, cur.pos_max, (float)cur.data.size() / 1024 / 1024);
SLT_WRN(slot, "created context checkpoint %d of %d (pos_min = %d, pos_max = %d, size = %.3f MiB, took %.2f ms)\n",
(int)slot.server_cached_prompt.checkpoints.size(), params_base.ctx_checkpoints_n, cur.pos_min, cur.pos_max, (float)cur.data.size() / 1024 / 1024,
(ggml_time_us() - t_start) / 1000.0);
}
}
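The "%d of %d" accounting in the log line counts checkpoints.size() against ctx_checkpoints_n, which implies a bounded per-slot checkpoint list. One plausible eviction policy, as a self-contained sketch (an assumption for illustration; the actual policy is outside this hunk):

```cpp
// Sketch only: bounded checkpoint list with oldest-first eviction, mirroring
// the "%d of %d" accounting above. The struct fields follow what the diff
// reads back (pos_min, pos_max, data); everything else is hypothetical.
#include <cstdint>
#include <deque>
#include <utility>
#include <vector>

struct ctx_checkpoint {
    int32_t              pos_min;
    int32_t              pos_max;
    std::vector<uint8_t> data; // serialized per-sequence state
};

static void push_checkpoint(std::deque<ctx_checkpoint> & checkpoints,
                            ctx_checkpoint cp, size_t n_max) {
    if (n_max > 0 && checkpoints.size() >= n_max) {
        checkpoints.pop_front(); // drop the oldest to stay within ctx_checkpoints_n
    }
    checkpoints.push_back(std::move(cp));
}

int main() {
    std::deque<ctx_checkpoint> checkpoints;
    for (int32_t i = 0; i < 5; ++i) {
        push_checkpoint(checkpoints, {i * 256, i * 256 + 255, {}}, /*n_max=*/3);
    }
    return 0; // holds the 3 most recent checkpoints
}
```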
@@ -3221,7 +3228,9 @@ bool server_context::accept_special_token(const server_slot& slot, const llama_
void server_context::release_slot_after_final_response(server_slot & slot) {
slot.print_timings();
create_checkpoint(slot);
if (params_base.do_checkpoint) {
create_checkpoint(slot);
}
slot.release();
slot.released = true;
metrics.on_prediction(slot);
@@ -3404,6 +3413,9 @@ void server_context::process_batch_tokens(int32_t & n_batch) {
for (auto& slot : slots) {
if (slot.state != SLOT_STATE_PROCESSING || slot.i_batch < (int)i || slot.i_batch >= (int)(i + n_tokens)) {
if (slot.command == SLOT_COMMAND_LOAD_PROMPT) {
create_checkpoint_at_interval(slot, params_base);
}
continue; // continue loop of slots
}
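This hook is the headline change of the commit: while a slot is still in SLOT_COMMAND_LOAD_PROMPT, i.e. ingesting the prompt batch by batch, it now runs the same interval check, so long prompts accumulate checkpoints as the KV cache fills instead of only once generation starts (see the trigger trace after the create_checkpoint_at_interval hunk above).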
@@ -3447,14 +3459,16 @@ void server_context::process_batch_tokens(int32_t & n_batch) {
slot.t_start_generation = ggml_time_us();
slot.t_prompt_processing = (slot.t_start_generation - slot.t_start_process_prompt) / 1e3;
metrics.on_prompt_eval(slot);
}
// save checkpoint during generation
if (params_base.ctx_checkpoints_interval > 0) {
if (slot.n_decoded % params_base.ctx_checkpoints_interval == 0) {
if (params_base.do_checkpoint) {
create_checkpoint(slot);
}
}
// save checkpoint during generation
if (slot.n_decoded > 1) {
create_checkpoint_at_interval(slot, params_base);
}
slot.t_token_generation = std::max<int64_t>(1, t_current - slot.t_start_generation) / 1e3;
result.tok = id;
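The removed trigger fired on multiples of n_decoded, which counts generated tokens only, so a long prompt contributed nothing toward the first checkpoint; the replacement measures total cache growth and is shared with the prompt-phase hook above. A hypothetical side-by-side trace of the two triggers (standalone, not server code):

```cpp
// Sketch only: contrasts the removed modulo trigger with the position-based
// one, assuming a 1000-token prompt and interval = 256. In the real flow,
// prompt-phase checkpoints would already have advanced checkpoint_pos.
#include <cstdio>

int main() {
    const int interval = 256;
    const int n_prompt = 1000;
    int checkpoint_pos = 0;

    for (int n_decoded = 1; n_decoded <= 512; ++n_decoded) {
        const int pos = n_prompt + n_decoded - 1; // 0-based max cache position

        if (n_decoded % interval == 0) {
            printf("old: n_decoded = %d (pos = %d)\n", n_decoded, pos); // 256, 512
        }
        if (n_decoded > 1 && checkpoint_pos + interval <= 1 + pos) {
            printf("new: pos = %d (n_decoded = %d)\n", pos, n_decoded); // 1001, 1256, 1511
            checkpoint_pos = pos;
        }
    }
    return 0;
}
```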
@@ -3516,7 +3530,7 @@ void server_context::update_slots() {
// apply context-shift if needed
// TODO: simplify and improve
context_shift();
// start populating the batch for this iteration
common_batch_clear(batch);