mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-03-04 11:00:00 +00:00
save checkpoint during pp
This commit is contained in:
@@ -361,7 +361,7 @@ void server_slot::reset() {
|
||||
rewind_status = false;
|
||||
|
||||
generated_token_probs.clear();
|
||||
|
||||
checkpoint_pos = 0;
|
||||
|
||||
// Reset speculative decoding stats
|
||||
n_draft_total = 0;
|
||||
@@ -1248,6 +1248,16 @@ bool server_context::launch_slot_with_task(server_slot& slot, server_task& task)
|
||||
}
|
||||
if (llama_model_has_recurrent(llama_get_model(slot.ctx))) {
|
||||
params_base.can_ban_phrases = false;
|
||||
bool do_checkpoint = params_base.ctx_checkpoints_n > 0;
|
||||
// make checkpoints only for completion tasks
|
||||
do_checkpoint = do_checkpoint && task.type == SERVER_TASK_TYPE_COMPLETION;
|
||||
// make a checkpoint of the parts of the memory that cannot be rolled back.
|
||||
// checkpoints are created only if:
|
||||
// - the model architecture is marked as recurrent or hybrid
|
||||
//
|
||||
// TODO: try to make this conditional on the context or the memory module, instead of the model type
|
||||
// do_checkpoint = do_checkpoint && llama_model_has_recurrent(model);
|
||||
params_base.do_checkpoint = do_checkpoint;
|
||||
if (slot.n_buffer != 0) {
|
||||
LLAMA_LOG_WARN("Recurrent model does not support banned strings.\n");
|
||||
}
|
||||
@@ -2632,6 +2642,16 @@ void server_context::add_sampled_tokens() {
|
||||
}
|
||||
}
|
||||
|
||||
void server_context::create_checkpoint_at_interval(server_slot & slot, const gpt_params & params_base) {
|
||||
if (params_base.do_checkpoint && params_base.ctx_checkpoints_interval > 0) {
|
||||
auto pos = llama_kv_cache_seq_pos_max(slot.ctx, slot.id);
|
||||
if (slot.checkpoint_pos + params_base.ctx_checkpoints_interval <= 1 + pos) {
|
||||
create_checkpoint(slot);
|
||||
slot.checkpoint_pos = pos;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void server_context::apply_checkpoint(server_slot & slot) {
|
||||
const auto pos_min_thold = std::max(0, slot.n_past - 1);
|
||||
if (!mctx && slot.n_past > 0 && slot.n_past < slot.cache_tokens.n_tokens()) {
|
||||
@@ -2696,20 +2716,7 @@ void server_context::apply_checkpoint(server_slot & slot) {
|
||||
}
|
||||
|
||||
void server_context::create_checkpoint(server_slot & slot) {
|
||||
bool do_checkpoint = params_base.ctx_checkpoints_n > 0;
|
||||
|
||||
// make checkpoints only for completion tasks
|
||||
do_checkpoint = do_checkpoint && slot.task->type == SERVER_TASK_TYPE_COMPLETION;
|
||||
|
||||
// make a checkpoint of the parts of the memory that cannot be rolled back.
|
||||
// checkpoints are created only if:
|
||||
// - the model architecture is marked as recurrent or hybrid
|
||||
//
|
||||
// TODO: try to make this conditional on the context or the memory module, instead of the model type
|
||||
do_checkpoint = do_checkpoint && llama_model_has_recurrent(model);
|
||||
if (!do_checkpoint) {
|
||||
return;
|
||||
}
|
||||
bool do_checkpoint = true;
|
||||
int32_t pos_min = llama_kv_cache_seq_pos_min(slot.ctx, slot.id);
|
||||
const auto pos_max = llama_kv_cache_seq_pos_max(slot.ctx, slot.id);
|
||||
|
||||
@@ -2743,9 +2750,9 @@ void server_context::create_checkpoint(server_slot & slot) {
|
||||
|
||||
llama_state_seq_get_data(ctx, cur.data.data(), checkpoint_size, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
|
||||
|
||||
SLT_WRN(slot, "created context checkpoint %d of %d took %.2f ms (pos_min = %d, pos_max = %d, size = %.3f MiB)\n",
|
||||
(int)slot.server_cached_prompt.checkpoints.size(), params_base.ctx_checkpoints_n,
|
||||
(ggml_time_us() - t_start) / 1000.0, cur.pos_min, cur.pos_max, (float)cur.data.size() / 1024 / 1024);
|
||||
SLT_WRN(slot, "created context checkpoint %d of %d (pos_min = %d, pos_max = %d, size = %.3f MiB, took %.2f ms)\n",
|
||||
(int)slot.server_cached_prompt.checkpoints.size(), params_base.ctx_checkpoints_n, cur.pos_min, cur.pos_max, (float)cur.data.size() / 1024 / 1024,
|
||||
(ggml_time_us() - t_start) / 1000.0);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3221,7 +3228,9 @@ bool server_context::accept_special_token(const server_slot& slot, const llama_
|
||||
|
||||
void server_context::release_slot_after_final_response(server_slot & slot) {
|
||||
slot.print_timings();
|
||||
create_checkpoint(slot);
|
||||
if (params_base.do_checkpoint) {
|
||||
create_checkpoint(slot);
|
||||
}
|
||||
slot.release();
|
||||
slot.released = true;
|
||||
metrics.on_prediction(slot);
|
||||
@@ -3404,6 +3413,9 @@ void server_context::process_batch_tokens(int32_t & n_batch) {
|
||||
|
||||
for (auto& slot : slots) {
|
||||
if (slot.state != SLOT_STATE_PROCESSING || slot.i_batch < (int)i || slot.i_batch >= (int)(i + n_tokens)) {
|
||||
if (slot.command == SLOT_COMMAND_LOAD_PROMPT) {
|
||||
create_checkpoint_at_interval(slot, params_base);
|
||||
}
|
||||
continue; // continue loop of slots
|
||||
}
|
||||
|
||||
@@ -3447,14 +3459,16 @@ void server_context::process_batch_tokens(int32_t & n_batch) {
|
||||
slot.t_start_generation = ggml_time_us();
|
||||
slot.t_prompt_processing = (slot.t_start_generation - slot.t_start_process_prompt) / 1e3;
|
||||
metrics.on_prompt_eval(slot);
|
||||
}
|
||||
|
||||
// save checkpoint during generation
|
||||
if (params_base.ctx_checkpoints_interval > 0) {
|
||||
if (slot.n_decoded % params_base.ctx_checkpoints_interval == 0) {
|
||||
if (params_base.do_checkpoint) {
|
||||
create_checkpoint(slot);
|
||||
}
|
||||
}
|
||||
|
||||
// save checkpoint during generation
|
||||
if (slot.n_decoded > 1) {
|
||||
create_checkpoint_at_interval(slot, params_base);
|
||||
}
|
||||
|
||||
slot.t_token_generation = std::max<int64_t>(1, t_current - slot.t_start_generation) / 1e3;
|
||||
|
||||
result.tok = id;
|
||||
@@ -3516,7 +3530,7 @@ void server_context::update_slots() {
|
||||
// apply context-shift if needed
|
||||
// TODO: simplify and improve
|
||||
context_shift();
|
||||
|
||||
|
||||
// start populating the batch for this iteration
|
||||
common_batch_clear(batch);
|
||||
|
||||
|
||||
Reference in New Issue
Block a user