save checkpoint during pp
@@ -2247,7 +2247,7 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param
     options.push_back({ "*", "-cd, --ctx-size-draft N", "size of the prompt context for the draft model (default: %d, 0 = loaded from model)", params.speculative.n_ctx });
 
     options.push_back({ "*", "--ctx-checkpoints N", "max number of context checkpoints to create per slot (default: %d)", params.ctx_checkpoints_n });
-    options.push_back({ "*", "--ctx-checkpoints-interval N", "number of tokens between each context checkpoint. (default: %d, <=0 disable)", params.ctx_checkpoints_interval });
+    options.push_back({ "*", "--ctx-checkpoints-interval N", "minimum number of tokens between each context checkpoint. (default: %d, <=0 disable)", params.ctx_checkpoints_interval });
     options.push_back({ "*", "-cram, --cache-ram N", "set the maximum cache size in MiB (default: %d, -1 - no limit, 0 - disable)", params.cache_ram_mib });
     options.push_back({ "*", "-crs, --cache-ram-similarity N", "max of similarity of prompt tokens to cache tokens that triggers prompt cache (default: %.2f).", params.cache_ram_similarity });
     options.push_back({ "*", "-cram-n-min --cache-ram-n-min N", "minimum number of the cached tokens that triggers prompt cache (default: %d).", params.cache_ram_n_min });
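Note: the only change in this hunk is the help text for --ctx-checkpoints-interval, which now describes the value as a minimum spacing between checkpoints rather than an exact stride. That matches the new position-based trigger introduced further down: a server started with, say, --ctx-checkpoints 8 --ctx-checkpoints-interval 512 keeps at most 8 checkpoints per slot, spaced at least 512 tokens apart.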
@@ -280,7 +280,8 @@ struct gpt_params {
     std::vector<std::string> ban_phrases; // strings that are banned in generation
     int32_t banned_n = 1;                 // number of tokens that are banned in the phrase
     size_t n_buffer = 0;                  // number of token buffers for string ban
     bool can_ban_phrases = true;          // whether to ban strings
+    bool do_checkpoint = false;           // do checkpoint for recurrent models only
 
     std::vector<llama_model_kv_override> kv_overrides;
     std::vector<llama_model_tensor_buft_override> tensor_buft_overrides;
@@ -420,7 +421,7 @@ struct gpt_params {
     float slot_prompt_similarity = 0.1f;
 
     int32_t ctx_checkpoints_n = 8;          // max number of context checkpoints per slot
-    int32_t ctx_checkpoints_interval = 0;   // number of tokens between each context checkpoints
+    int32_t ctx_checkpoints_interval = 512; // minimum number of tokens between context checkpoints
     int32_t cache_ram_mib = 8192;           // -1 = no limit, 0 - disable, 1 = 1 MiB, etc.
     int32_t cache_ram_n_min = 0;            // min number of tokens required to save in the ram
     float cache_ram_similarity = 0.5f;      // similarity of tokens to cached tokens
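Note: besides the comment fix, the default interval flips from 0 (which the help text above defines as disabled) to 512 tokens, so interval checkpointing is active out of the box on any slot that arms do_checkpoint.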
@@ -361,7 +361,7 @@ void server_slot::reset() {
     rewind_status = false;
 
     generated_token_probs.clear();
 
+    checkpoint_pos = 0;
 
     // Reset speculative decoding stats
     n_draft_total = 0;
@@ -1248,6 +1248,16 @@ bool server_context::launch_slot_with_task(server_slot& slot, server_task& task)
     }
     if (llama_model_has_recurrent(llama_get_model(slot.ctx))) {
         params_base.can_ban_phrases = false;
+        bool do_checkpoint = params_base.ctx_checkpoints_n > 0;
+        // make checkpoints only for completion tasks
+        do_checkpoint = do_checkpoint && task.type == SERVER_TASK_TYPE_COMPLETION;
+        // make a checkpoint of the parts of the memory that cannot be rolled back.
+        // checkpoints are created only if:
+        //   - the model architecture is marked as recurrent or hybrid
+        //
+        // TODO: try to make this conditional on the context or the memory module, instead of the model type
+        // do_checkpoint = do_checkpoint && llama_model_has_recurrent(model);
+        params_base.do_checkpoint = do_checkpoint;
         if (slot.n_buffer != 0) {
             LLAMA_LOG_WARN("Recurrent model does not support banned strings.\n");
         }
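The block moved into launch_slot_with_task() computes the checkpoint gate once per task and stores it in params_base.do_checkpoint, instead of re-deriving it on every create_checkpoint() call. Condensed, the gate amounts to the following predicate (a sketch that would live in the server translation unit where these types are in scope; the helper name should_checkpoint is hypothetical, not part of the diff):

    // Hypothetical condensation of the gate set in launch_slot_with_task():
    // checkpoints are armed only when the user allows them, the task is a
    // completion, and the model carries recurrent state that cannot be
    // rolled back token by token.
    static bool should_checkpoint(const gpt_params & params, const server_task & task, const llama_model * model) {
        return params.ctx_checkpoints_n > 0
            && task.type == SERVER_TASK_TYPE_COMPLETION
            && llama_model_has_recurrent(model);
    }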
@@ -2632,6 +2642,16 @@ void server_context::add_sampled_tokens() {
     }
 }
 
+void server_context::create_checkpoint_at_interval(server_slot & slot, const gpt_params & params_base) {
+    if (params_base.do_checkpoint && params_base.ctx_checkpoints_interval > 0) {
+        auto pos = llama_kv_cache_seq_pos_max(slot.ctx, slot.id);
+        if (slot.checkpoint_pos + params_base.ctx_checkpoints_interval <= 1 + pos) {
+            create_checkpoint(slot);
+            slot.checkpoint_pos = pos;
+        }
+    }
+}
+
 void server_context::apply_checkpoint(server_slot & slot) {
     const auto pos_min_thold = std::max(0, slot.n_past - 1);
     if (!mctx && slot.n_past > 0 && slot.n_past < slot.cache_tokens.n_tokens()) {
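The helper fires once the newest cached position has advanced at least ctx_checkpoints_interval tokens past the previous checkpoint, then snaps checkpoint_pos to the current position, so checkpoints are spaced by at least the interval rather than landing on exact multiples. A self-contained sketch of the trigger arithmetic, with illustrative values and none of the server types:

    #include <cstdint>
    #include <cstdio>

    int main() {
        const int32_t interval = 512;   // stands in for params.ctx_checkpoints_interval
        int64_t checkpoint_pos = 0;     // stands in for slot.checkpoint_pos

        // simulate pos_max advancing in uneven, batch-sized steps
        for (int32_t pos = 0; pos < 2048; pos += 300) {
            if (checkpoint_pos + interval <= 1 + pos) {
                std::printf("checkpoint at pos %d (previous at %lld)\n", pos, (long long) checkpoint_pos);
                checkpoint_pos = pos;   // snap to the current position
            }
        }
        return 0;
    }

With these numbers the sketch checkpoints at positions 600, 1200, and 1800: always at least 512 tokens apart, even though the 300-token steps never land on a multiple of 512.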
@@ -2696,20 +2716,7 @@ void server_context::apply_checkpoint(server_slot & slot) {
 }
 
 void server_context::create_checkpoint(server_slot & slot) {
-    bool do_checkpoint = params_base.ctx_checkpoints_n > 0;
-
-    // make checkpoints only for completion tasks
-    do_checkpoint = do_checkpoint && slot.task->type == SERVER_TASK_TYPE_COMPLETION;
-
-    // make a checkpoint of the parts of the memory that cannot be rolled back.
-    // checkpoints are created only if:
-    //   - the model architecture is marked as recurrent or hybrid
-    //
-    // TODO: try to make this conditional on the context or the memory module, instead of the model type
-    do_checkpoint = do_checkpoint && llama_model_has_recurrent(model);
-    if (!do_checkpoint) {
-        return;
-    }
+    bool do_checkpoint = true;
     int32_t pos_min = llama_kv_cache_seq_pos_min(slot.ctx, slot.id);
     const auto pos_max = llama_kv_cache_seq_pos_max(slot.ctx, slot.id);
@@ -2743,9 +2750,9 @@ void server_context::create_checkpoint(server_slot & slot) {
 
     llama_state_seq_get_data(ctx, cur.data.data(), checkpoint_size, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
 
-    SLT_WRN(slot, "created context checkpoint %d of %d took %.2f ms (pos_min = %d, pos_max = %d, size = %.3f MiB)\n",
-            (int)slot.server_cached_prompt.checkpoints.size(), params_base.ctx_checkpoints_n,
-            (ggml_time_us() - t_start) / 1000.0, cur.pos_min, cur.pos_max, (float)cur.data.size() / 1024 / 1024);
+    SLT_WRN(slot, "created context checkpoint %d of %d (pos_min = %d, pos_max = %d, size = %.3f MiB, took %.2f ms)\n",
+            (int)slot.server_cached_prompt.checkpoints.size(), params_base.ctx_checkpoints_n, cur.pos_min, cur.pos_max, (float)cur.data.size() / 1024 / 1024,
+            (ggml_time_us() - t_start) / 1000.0);
     }
 }
@@ -3221,7 +3228,9 @@ bool server_context::accept_special_token(const server_slot& slot, const llama_
 
 void server_context::release_slot_after_final_response(server_slot & slot) {
     slot.print_timings();
-    create_checkpoint(slot);
+    if (params_base.do_checkpoint) {
+        create_checkpoint(slot);
+    }
     slot.release();
     slot.released = true;
     metrics.on_prediction(slot);
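Note: create_checkpoint() no longer gates itself (the removed block above was replaced by bool do_checkpoint = true;), so every caller is now responsible for consulting params_base.do_checkpoint first; this guard adds exactly that check to the final-response path.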
@@ -3404,6 +3413,9 @@ void server_context::process_batch_tokens(int32_t & n_batch) {
 
     for (auto& slot : slots) {
         if (slot.state != SLOT_STATE_PROCESSING || slot.i_batch < (int)i || slot.i_batch >= (int)(i + n_tokens)) {
+            if (slot.command == SLOT_COMMAND_LOAD_PROMPT) {
+                create_checkpoint_at_interval(slot, params_base);
+            }
             continue; // continue loop of slots
         }
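This hunk is the one named by the commit title: a slot that is still loading its prompt (SLOT_COMMAND_LOAD_PROMPT) never reaches the sampled-token path below, so previously a long prompt produced no checkpoints until processing finished. Calling create_checkpoint_at_interval() here snapshots the recurrent state periodically during prompt processing as well, which presumably lets apply_checkpoint() resume from a nearby position instead of recomputing the whole prefix after an interruption or a partial prompt-cache hit.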
@@ -3447,14 +3459,16 @@ void server_context::process_batch_tokens(int32_t & n_batch) {
             slot.t_start_generation = ggml_time_us();
             slot.t_prompt_processing = (slot.t_start_generation - slot.t_start_process_prompt) / 1e3;
             metrics.on_prompt_eval(slot);
         }
 
-        // save checkpoint during generation
-        if (params_base.ctx_checkpoints_interval > 0) {
-            if (slot.n_decoded % params_base.ctx_checkpoints_interval == 0) {
-                if (params_base.do_checkpoint) {
-                    create_checkpoint(slot);
-                }
-            }
-        }
+        // save checkpoint during generation
+        if (slot.n_decoded > 1) {
+            create_checkpoint_at_interval(slot, params_base);
+        }
 
         slot.t_token_generation = std::max<int64_t>(1, t_current - slot.t_start_generation) / 1e3;
 
         result.tok = id;
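Note: the old modulo trigger fired only when n_decoded landed exactly on a multiple of the interval, so any step that advances the counter by more than one token (for example, a burst of accepted draft tokens in speculative mode, with n_decoded going from 510 to 513 across a 512 boundary) skipped the checkpoint entirely, and the do_checkpoint test sat two branches deep. The threshold comparison inside create_checkpoint_at_interval() measures accumulated distance rather than equality, so it cannot be jumped over; the n_decoded > 1 guard presumably avoids taking a redundant checkpoint on the very first generated token, right after the prompt-phase checkpoints.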
@@ -3516,7 +3530,7 @@ void server_context::update_slots() {
         // apply context-shift if needed
         // TODO: simplify and improve
         context_shift();
 
         // start populating the batch for this iteration
         common_batch_clear(batch);
 
@@ -104,6 +104,8 @@ struct server_slot {
 
     void prompt_load(server_prompt_cache& prompt_cache, const server_tokens& tokens);
 
+    size_t checkpoint_pos = 0;
+
     // sampling
     llama_token sampled; // in speculative mode, this is the last accepted token
     llama_tokens drafted;
@@ -358,5 +360,7 @@ struct server_context {
 
     void apply_checkpoint(server_slot & slot);
 
+    void create_checkpoint_at_interval(server_slot & slot, const gpt_params & params_base);
+
     void release_slot_after_final_response(server_slot & slot);
 };