Mirror of https://github.com/ikawrakow/ik_llama.cpp.git (synced 2026-02-27 00:24:11 +00:00)
server: enable checkpoint for recurrent models (#1310)
* server: enable checkpoint for recurrent models
  create checkpoint after cancel
  fix ban string and rm context during rewind
  add checkpoint interval
  only save recurrent cache
* save checkpoint during pp (prompt processing)

Co-authored-by: firecoperana <firecoperana>
@@ -315,7 +315,7 @@ void server_context::init() {
 void server_slot::prompt_save(server_prompt_cache& prompt_cache) const {
     assert(server_cached_prompt.data.size() == 0);
 
-    const size_t cur_size = llama_state_seq_get_size(ctx, id);
+    const size_t cur_size = llama_state_seq_get_size(ctx, id, 0);
 
     LLAMA_LOG_INFO(" - saving prompt with length %d, total state size = %.3f MiB\n",
         (int)server_cached_prompt.tokens.size(), cur_size / (1024.0 * 1024.0));
@@ -325,7 +325,7 @@ void server_slot::prompt_save(server_prompt_cache& prompt_cache) const {
         return;
     }
 
-    llama_state_seq_get_data(ctx, cur->data.data(), cur_size, id);
+    llama_state_seq_get_data(ctx, cur->data.data(), cur_size, id, 0);
 }
 
 void server_slot::prompt_load(server_prompt_cache& prompt_cache, const server_tokens& tokens) {
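Both prompt_save() calls above gain a trailing flags argument; 0 requests the full per-sequence state, while the checkpoint code further down passes LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY. A minimal sketch of saving one sequence with the extended signatures as they appear in this patch (the helper name save_full_sequence_state is illustrative, not part of the patch):

```cpp
#include "llama.h"

#include <cstdint>
#include <vector>

// Sketch: serialize the complete state of sequence `id` using the extended
// llama_state_seq_* signatures from this patch (trailing flags = 0 is assumed
// to mean "full state").
static std::vector<uint8_t> save_full_sequence_state(llama_context * ctx, llama_seq_id id) {
    // ask how many bytes the serialized state needs
    const size_t size = llama_state_seq_get_size(ctx, id, 0);

    // copy the state into a caller-owned buffer
    std::vector<uint8_t> buf(size);
    llama_state_seq_get_data(ctx, buf.data(), size, id, 0);
    return buf;
}
```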
@@ -361,7 +361,7 @@ void server_slot::reset() {
     rewind_status = false;
 
     generated_token_probs.clear();
 
+    checkpoint_pos = 0;
 
     // Reset speculative decoding stats
     n_draft_total = 0;
@@ -1246,7 +1246,22 @@ bool server_context::launch_slot_with_task(server_slot& slot, server_task& task)
         slot.ban_phrases_bias = json_value(data, "banned_bias", params_base.ban_phrases_bias);
         slot.banned_n = json_value(data, "banned_n", params_base.banned_n);
     }
-
+    if (llama_model_has_recurrent(llama_get_model(slot.ctx))) {
+        params_base.can_ban_phrases = false;
+        bool do_checkpoint = params_base.ctx_checkpoints_n > 0;
+        // make checkpoints only for completion tasks
+        do_checkpoint = do_checkpoint && task.type == SERVER_TASK_TYPE_COMPLETION;
+        // make a checkpoint of the parts of the memory that cannot be rolled back.
+        // checkpoints are created only if:
+        //  - the model architecture is marked as recurrent or hybrid
+        //
+        // TODO: try to make this conditional on the context or the memory module, instead of the model type
+        // do_checkpoint = do_checkpoint && llama_model_has_recurrent(model);
+        params_base.do_checkpoint = do_checkpoint;
+        if (slot.n_buffer != 0) {
+            LLAMA_LOG_WARN("Recurrent model does not support banned strings.\n");
+        }
+    }
     {
         const auto& stop = data.find("stop");
         if (stop != data.end() && stop->is_array()) {
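The block above turns string banning off for recurrent models (can_ban_phrases = false) because banning works by rewinding the context once a banned string shows up, and recurrent state cannot be truncated the way an attention KV cache can. A hedged sketch of that distinction; the rollback() helper and its is_recurrent flag are illustrative, only llama_kv_cache_seq_rm comes from this codebase:

```cpp
#include "llama.h"

// Illustrative only: roll a sequence back so that positions [0, n_keep) remain.
// An attention KV cache can simply drop the tail; a recurrent or hybrid model
// keeps one opaque state blob per sequence, so the only way back is to restore
// a previously saved checkpoint and re-decode from there.
static void rollback(llama_context * ctx, llama_seq_id id, llama_pos n_keep, bool is_recurrent) {
    if (!is_recurrent) {
        // attention KV cache: remove every entry from position n_keep onwards
        llama_kv_cache_seq_rm(ctx, id, n_keep, -1);
    } else {
        // recurrent state: no partial removal is possible; a checkpoint taken at
        // or before n_keep would have to be restored instead (not shown here)
    }
}
```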
@@ -2142,7 +2157,7 @@ void server_context::process_single_task(server_task&& task) {
 
             // Erase token cache
             const size_t n_erased = slot->cache_tokens.size();
-            llama_kv_cache_seq_rm(ctx, slot->id + 1, -1, -1);
+            llama_kv_cache_seq_rm(ctx, slot->id, -1, -1);
             slot->cache_tokens.clear();
 
             server_task_result result;
@@ -2552,6 +2567,7 @@ void server_context::context_shift() {
 
 void server_context::add_sampled_tokens() {
     for (auto& slot : slots) {
+        slot.released = false;
         if (slot.state == SLOT_STATE_IDLE) {
             continue;
         }
@@ -2626,15 +2642,22 @@ void server_context::add_sampled_tokens() {
     }
 }
 
+void server_context::create_checkpoint_at_interval(server_slot & slot, const gpt_params & params_base) {
+    if (params_base.do_checkpoint && params_base.ctx_checkpoints_interval > 0) {
+        auto pos = llama_kv_cache_seq_pos_max(slot.ctx, slot.id);
+        if (slot.checkpoint_pos + params_base.ctx_checkpoints_interval <= 1 + pos) {
+            create_checkpoint(slot);
+            slot.checkpoint_pos = pos;
+        }
+    }
+}
+
 void server_context::apply_checkpoint(server_slot & slot) {
     const auto pos_min_thold = std::max(0, slot.n_past - 1);
     if (!mctx && slot.n_past > 0 && slot.n_past < slot.cache_tokens.n_tokens()) {
-        int32_t pos_min = 0;
-        if (llama_model_is_hybrid(llama_get_model(slot.ctx)) || llama_model_is_recurrent(llama_get_model(slot.ctx))) {
-            pos_min = llama_kv_cache_seq_pos_max(slot.ctx, slot.id);
-        }
+        int32_t pos_min = llama_kv_cache_seq_pos_min(slot.ctx, slot.id);
 
-        if (pos_min > pos_min_thold+2) {
+        if (pos_min > pos_min_thold) {
             // TODO: support can be added in the future when corresponding vision models get released
             GGML_ASSERT(!slot.cache_tokens.has_mtmd);
 
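The interval test in create_checkpoint_at_interval() above is plain position arithmetic: with checkpoint_pos recording where the last checkpoint was taken, a new one is due once the cache has grown by at least ctx_checkpoints_interval positions. A self-contained sketch of that predicate (the helper itself is illustrative):

```cpp
#include <cstdint>

// Mirrors the condition in create_checkpoint_at_interval(): a checkpoint is due
// once the cache has grown by at least `interval` positions since the last
// checkpoint, which was taken at `last_checkpoint_pos`.
static bool checkpoint_due(int32_t last_checkpoint_pos, int32_t interval, int32_t pos_max) {
    if (interval <= 0) {
        return false; // an interval of 0 disables periodic checkpoints
    }
    // pos_max is the highest cached position, so pos_max + 1 tokens are stored
    return last_checkpoint_pos + interval <= pos_max + 1;
}

// Example: interval = 64, last checkpoint at position 128, pos_max = 191
// -> 128 + 64 <= 192, so a new checkpoint is taken and checkpoint_pos becomes 191.
```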
@@ -2654,8 +2677,9 @@ void server_context::apply_checkpoint(server_slot & slot) {
 
             if (!do_reset) {
                 // restore the context checkpoint
+                const int64_t t_start = ggml_time_us();
                 const size_t checkpoint_size = it->data.size();
-                const size_t n = llama_state_seq_set_data(ctx, it->data.data(), checkpoint_size, slot.id);
+                const size_t n = llama_state_seq_set_data(ctx, it->data.data(), checkpoint_size, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
 
                 if (n != checkpoint_size) {
                     SLT_ERR(slot, "failed to restore context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", it->pos_min, it->pos_max, (float)checkpoint_size / 1024 / 1024);
@@ -2663,7 +2687,8 @@ void server_context::apply_checkpoint(server_slot & slot) {
                     //printf("[DEBUG] `do_reset` was set to `true` after failing to restore a checkpoint");
                 } else {
                     slot.n_past = std::min(slot.n_past, std::max(it->pos_min + 1, it->pos_max));
-                    SLT_WRN(slot, "restored context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", it->pos_min, it->pos_max, (float)checkpoint_size / 1024 / 1024);
+                    slot.n_past_prompt = std::min(slot.n_past_prompt, std::max(it->pos_min_prompt + 1, it->pos_max_prompt));
+                    SLT_WRN(slot, "restored context checkpoint took %.2f ms (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", (ggml_time_us() - t_start) / 1000.0, it->pos_min, it->pos_max, (float)checkpoint_size / 1024 / 1024);
                 }
             }
 
@@ -2691,56 +2716,44 @@ void server_context::apply_checkpoint(server_slot & slot) {
 }
 
 void server_context::create_checkpoint(server_slot & slot) {
-    //bool do_checkpoint = params_base.n_ctx_checkpoints > 0;
+    bool do_checkpoint = true;
+    int32_t pos_min = llama_kv_cache_seq_pos_min(slot.ctx, slot.id);
+    const auto pos_max = llama_kv_cache_seq_pos_max(slot.ctx, slot.id);
 
-    //// make checkpoints only for completion tasks
-    //do_checkpoint = do_checkpoint && slot.task->type == SERVER_TASK_TYPE_COMPLETION;
+    // no need for empty or small checkpoints
+    do_checkpoint = do_checkpoint && (pos_min >= 0 && pos_max >= 16);
 
-    //// make a checkpoint of the parts of the memory that cannot be rolled back.
-    //// checkpoints are created only if:
-    ////  - the model architecture is marked as recurrent or hybrid
-    ////
-    //// TODO: try to make this conditional on the context or the memory module, instead of the model type
-    //do_checkpoint = do_checkpoint && (
-    //        llama_model_is_recurrent(model) ||
-    //        llama_model_is_hybrid(model)
-    //    );
-    //int32_t pos_min = 0;
-    //if (llama_model_is_recurrent(model) || llama_model_is_hybrid(model)) {
-    //    pos_min = llama_kv_cache_seq_pos_max(slot.ctx, slot.id);
-    //}
-    //const auto pos_max = llama_kv_cache_seq_pos_max(slot.ctx, slot.id);
+    // no need to create checkpoints that are too close together
+    do_checkpoint = do_checkpoint && (slot.server_cached_prompt.checkpoints.empty() || pos_max > slot.server_cached_prompt.checkpoints.back().pos_max);
 
-    //// no need for empty or small checkpoints
-    //do_checkpoint = do_checkpoint && (pos_min >= 0 && pos_max >= 5);
+    if (do_checkpoint) {
+        const int64_t t_start = ggml_time_us();
+        while (slot.server_cached_prompt.checkpoints.size() >= (size_t)params_base.ctx_checkpoints_n) {
+            // make room for the new checkpoint, if needed
+            const auto & cur = slot.server_cached_prompt.checkpoints.front();
 
-    //// no need to create checkpoints that are too close together
-    //do_checkpoint = do_checkpoint && (slot.server_cached_prompt.checkpoints.empty() || pos_max > slot.server_cached_prompt.checkpoints.back().pos_max + 64);
+            SLT_WRN(slot, "erasing old context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n",
+                cur.pos_min, cur.pos_max, (float)cur.data.size() / 1024 / 1024);
 
-    //if (do_checkpoint) {
-    //    while (slot.server_cached_prompt.checkpoints.size() >= (size_t)params_base.n_ctx_checkpoints) {
-    //        // make room for the new checkpoint, if needed
-    //        const auto & cur = slot.server_cached_prompt.checkpoints.front();
+            slot.server_cached_prompt.checkpoints.erase(slot.server_cached_prompt.checkpoints.begin());
+        }
 
-    //        SLT_WRN(slot, "erasing old context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n",
-    //            cur.pos_min, cur.pos_max, (float)cur.data.size() / 1024 / 1024);
+        const size_t checkpoint_size = llama_state_seq_get_size(ctx, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
 
-    //        slot.server_cached_prompt.checkpoints.erase(slot.server_cached_prompt.checkpoints.begin());
-    //    }
+        auto & cur = slot.server_cached_prompt.checkpoints.emplace_back(server_prompt_checkpoint{
+            /*.pos_min        = */ pos_min,
+            /*.pos_max        = */ pos_max,
+            /*.pos_min_prompt = */ pos_min + slot.n_past_offset,
+            /*.pos_max_prompt = */ pos_max + slot.n_past_offset,
+            /*.data           = */ std::vector<uint8_t>(checkpoint_size),
+        });
 
-    //    const size_t checkpoint_size = llama_state_seq_get_size(ctx, slot.id);
+        llama_state_seq_get_data(ctx, cur.data.data(), checkpoint_size, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
 
-    //    auto & cur = slot.server_cached_prompt.checkpoints.emplace_back(server_prompt_checkpoint{
-    //        /*.pos_min = */ pos_min,
-    //        /*.pos_max = */ pos_max,
-    //        /*.data = */ std::vector<uint8_t>(checkpoint_size),
-    //    });
 
-    //    llama_state_seq_get_data(ctx, cur.data.data(), checkpoint_size, slot.id);
 
-    //    SLT_WRN(slot, "created context checkpoint %d of %d (pos_min = %d, pos_max = %d, size = %.3f MiB)\n",
-    //        (int)slot.server_cached_prompt.checkpoints.size(), params_base.n_ctx_checkpoints, cur.pos_min, cur.pos_max, (float)cur.data.size() / 1024 / 1024);
-    //}
+        SLT_WRN(slot, "created context checkpoint %d of %d (pos_min = %d, pos_max = %d, size = %.3f MiB, took %.2f ms)\n",
+            (int)slot.server_cached_prompt.checkpoints.size(), params_base.ctx_checkpoints_n, cur.pos_min, cur.pos_max, (float)cur.data.size() / 1024 / 1024,
+            (ggml_time_us() - t_start) / 1000.0);
+    }
 }
 
 void server_context::batch_pending_prompt(const int32_t n_ubatch, const int32_t n_batch, int32_t & batch_type) {
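As the commit message says, only the recurrent cache is saved: both the save in create_checkpoint() and the restore in apply_checkpoint() pass LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY, so a checkpoint holds just the part of the state that cannot be rolled back and stays comparatively small. A hedged sketch of that round trip using the calls as they appear in the patch; the checkpoint struct and helper names are illustrative:

```cpp
#include "llama.h"

#include <cstdint>
#include <vector>

// Illustrative per-sequence checkpoint, loosely mirroring server_prompt_checkpoint.
struct checkpoint {
    llama_pos            pos_min;
    llama_pos            pos_max;
    std::vector<uint8_t> data;
};

// Save only the non-rollbackable (recurrent) part of the sequence state.
static checkpoint take_checkpoint(llama_context * ctx, llama_seq_id id) {
    checkpoint cp;
    cp.pos_min = llama_kv_cache_seq_pos_min(ctx, id);
    cp.pos_max = llama_kv_cache_seq_pos_max(ctx, id);
    cp.data.resize(llama_state_seq_get_size(ctx, id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY));
    llama_state_seq_get_data(ctx, cp.data.data(), cp.data.size(), id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
    return cp;
}

// Restore it later. set_data returns the number of bytes consumed, so anything
// other than the buffer size signals a failed restore (same check as apply_checkpoint()).
static bool restore_checkpoint(llama_context * ctx, llama_seq_id id, const checkpoint & cp) {
    const size_t n = llama_state_seq_set_data(ctx, cp.data.data(), cp.data.size(), id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
    return n == cp.data.size();
}
```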
@@ -2798,8 +2811,6 @@ void server_context::batch_pending_prompt(const int32_t n_ubatch, const int32_t
             }
 
             slot.n_past = 0;
-            slot.n_buffer = 0;
-            slot.token_buffer.clear();
             slot.n_prompt_tokens = prompt_tokens.size();
 
             LOG_VERBOSE("prompt tokenized", {
@@ -2900,6 +2911,8 @@ void server_context::batch_pending_prompt(const int32_t n_ubatch, const int32_t
             }
             slot.n_past = prefix.first;
+            slot.n_past_prompt = prefix.second;
+            slot.n_past_offset = slot.n_past_prompt - slot.n_past;
 
             if (slot.n_past != slot.n_past_prompt) {
                 LLAMA_LOG_INFO("Mistokenization found and handled successfully.\n");
             }
@@ -3074,8 +3087,6 @@ void server_context::batch_pending_prompt(const int32_t n_ubatch, const int32_t
             slot.n_decoded = 0;
             slot.i_batch = batch.n_tokens - 1;
 
-            //create_checkpoint(slot);
-
             LOG_VERBOSE("prompt done", {
                 {"id_slot", slot.id},
                 {"n_past", slot.n_past},
@@ -3187,14 +3198,11 @@ void server_context::speculative_decoding_accept() {
             populate_token_probs(slot, result, slot.params.post_sampling_probs, params_base.special, i);
         }
 
-        if (slot.n_buffer == 0 || llama_model_is_hybrid(llama_get_model(slot.ctx)) || llama_model_is_recurrent(llama_get_model(slot.ctx))) {
+        if (slot.n_buffer == 0 || !params_base.can_ban_phrases) {
             if (!process_token(result, slot)) {
                 // release slot because of stop condition
                 send_final_response(slot);
-                //create_checkpoint(slot);
-                slot.release();
-                slot.print_timings();
-                metrics.on_prediction(slot);
+                release_slot_after_final_response(slot);
                 break;
             }
         } else {
@@ -3218,6 +3226,15 @@ bool server_context::accept_special_token(const server_slot& slot, const llama_
     return params_base.special || slot.sparams.preserved_tokens.find(token) != slot.sparams.preserved_tokens.end();
 }
 
+void server_context::release_slot_after_final_response(server_slot & slot) {
+    slot.print_timings();
+    if (params_base.do_checkpoint) {
+        create_checkpoint(slot);
+    }
+    slot.release();
+    slot.released = true;
+    metrics.on_prediction(slot);
+}
 
 void server_context::send_token_results(completion_token_outputs& results, server_slot& slot, int32_t n) {
     int count = 0;
@@ -3226,10 +3243,7 @@ void server_context::send_token_results(completion_token_outputs& results, serve
         count++;
         if (!has_next) {
             send_final_response(slot);
-            //create_checkpoint(slot);
-            slot.release();
-            slot.print_timings();
-            metrics.on_prediction(slot);
+            release_slot_after_final_response(slot);
             break;
         }
         if (n > 0 && count >= n) {
@@ -3266,7 +3280,7 @@ inline int32_t check_ban_phrase(const server_slot& slot) {
     }
     if (found) {
         std::vector<size_t> unused;
-        LLAMA_LOG_DEBUG("Banned string dectected: %s\n ", string_buffer.substr(start).c_str());
+        LLAMA_LOG_DEBUG("Banned string dectected: %s\n", string_buffer.substr(start).c_str());
         n = find_n_tokens_from_string(slot.ctx, tokens, start, 0, unused);
         n_rewind = (int32_t) slot.token_buffer.size() - (int32_t) n;
     }
@@ -3299,6 +3313,8 @@ inline void rewind_context(server_slot& slot, int32_t n_rewind) {
     size_t n_keep = slot.cache_tokens.size() - n_rewind;
+    slot.sampled = slot.cache_tokens[n_keep];
     slot.cache_tokens.keep_first(n_keep);
+    llama_kv_cache_seq_rm(slot.ctx, slot.id, n_keep, -1);
 
 }
 
 void server_context::buffer_and_check_string_ban(server_slot & slot, completion_token_output & result) {
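The two added lines make the ban-string rewind consistent: after dropping a banned string of n_rewind tokens, n_keep tokens remain, the token at index n_keep is re-used as slot.sampled, and the KV cache entries from position n_keep onward are removed. A toy version of that bookkeeping with illustrative numbers:

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Toy model of the bookkeeping in rewind_context(); a plain vector stands in
// for slot.cache_tokens and all values are illustrative.
int main() {
    std::vector<int32_t> cache_tokens(100);  // 100 cached tokens, positions 0..99
    const int32_t n_rewind = 3;              // the banned string spans 3 tokens

    const size_t  n_keep  = cache_tokens.size() - n_rewind;  // 97
    const int32_t sampled = cache_tokens[n_keep];            // read before truncating, as the patch does
    cache_tokens.resize(n_keep);                              // keep_first(n_keep): positions 0..96 remain
    // llama_kv_cache_seq_rm(ctx, id, n_keep, -1) would now drop KV entries for positions 97..99

    assert(cache_tokens.size() == 97);
    (void)sampled;
    return 0;
}
```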
@@ -3397,6 +3413,9 @@ void server_context::process_batch_tokens(int32_t & n_batch) {
 
         for (auto& slot : slots) {
             if (slot.state != SLOT_STATE_PROCESSING || slot.i_batch < (int)i || slot.i_batch >= (int)(i + n_tokens)) {
+                if (slot.command == SLOT_COMMAND_LOAD_PROMPT) {
+                    create_checkpoint_at_interval(slot, params_base);
+                }
                 continue; // continue loop of slots
             }
 
@@ -3440,6 +3459,14 @@ void server_context::process_batch_tokens(int32_t & n_batch) {
                 slot.t_start_generation = ggml_time_us();
                 slot.t_prompt_processing = (slot.t_start_generation - slot.t_start_process_prompt) / 1e3;
                 metrics.on_prompt_eval(slot);
+                if (params_base.do_checkpoint) {
+                    create_checkpoint(slot);
+                }
             }
+
+            // save checkpoint during generation
+            if (slot.n_decoded > 1) {
+                create_checkpoint_at_interval(slot, params_base);
+            }
 
             slot.t_token_generation = std::max<int64_t>(1, t_current - slot.t_start_generation) / 1e3;
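Taken together, checkpoints for a recurrent model are now created at three points: when prompt processing finishes (and, via the SLOT_COMMAND_LOAD_PROMPT branch above, at every interval while a long prompt is still being decoded), roughly every ctx_checkpoints_interval positions during generation, and once more when the slot is released, which also covers cancelled requests. A small simulation of that schedule under assumed parameters (interval = 64; all numbers illustrative):

```cpp
#include <cstdint>
#include <cstdio>

int main() {
    const int32_t interval   = 64;   // assumed ctx_checkpoints_interval
    const int32_t n_prompt   = 200;  // prompt length (illustrative)
    const int32_t n_generate = 150;  // tokens generated before the client stops

    int32_t checkpoint_pos = 0;

    // 1) unconditional checkpoint at the end of prompt processing
    checkpoint_pos = n_prompt - 1;
    std::printf("checkpoint after prompt, pos_max = %d\n", checkpoint_pos);

    // 2) interval checkpoints during generation (same rule as create_checkpoint_at_interval)
    for (int32_t pos = n_prompt; pos < n_prompt + n_generate; ++pos) {
        if (checkpoint_pos + interval <= pos + 1) {
            std::printf("interval checkpoint, pos_max = %d\n", pos);
            checkpoint_pos = pos;
        }
    }

    // 3) final checkpoint when the slot is released (also taken on cancel)
    std::printf("checkpoint on release, pos_max = %d\n", n_prompt + n_generate - 1);
    return 0;
}
```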
@@ -3452,7 +3479,7 @@ void server_context::process_batch_tokens(int32_t & n_batch) {
             }
 
             // no ban string for recurrent/hybrid model
-            if (slot.n_buffer == 0 || llama_model_is_hybrid(llama_get_model(slot.ctx)) || llama_model_is_recurrent(llama_get_model(slot.ctx))) {
+            if (slot.n_buffer == 0 || !params_base.can_ban_phrases) {
                slot.token_buffer = { result };
                send_token_results(slot.token_buffer, slot);
            } else {
@@ -3503,7 +3530,7 @@ void server_context::update_slots() {
     // apply context-shift if needed
     // TODO: simplify and improve
     context_shift();
 
 
     // start populating the batch for this iteration
     common_batch_clear(batch);
 
 