mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-02-26 16:14:10 +00:00
server: enable checkpoint for recurrent models (#1310)
* server: enable checkpoint for recurrent models create checkpoint after cancel fix ban string and rm context during rewind add checkpoint interval only save recurrent cache * save checkpoint during pp --------- Co-authored-by: firecoperana <firecoperana>
This commit is contained in:
@@ -2041,6 +2041,16 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
|
||||
}
|
||||
return true;
|
||||
}
|
||||
if (arg == "--ctx-checkpoints") {
|
||||
CHECK_ARG
|
||||
params.ctx_checkpoints_n = std::stoi(argv[i]);
|
||||
return true;
|
||||
}
|
||||
if (arg == "--ctx-checkpoints-interval") {
|
||||
CHECK_ARG
|
||||
params.ctx_checkpoints_interval = std::stoi(argv[i]);
|
||||
return true;
|
||||
}
|
||||
if (arg == "-cram" || arg == "--cache-ram") {
|
||||
CHECK_ARG
|
||||
params.cache_ram_mib = std::stoi(argv[i]);
|
||||
@@ -2235,7 +2245,10 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param
|
||||
|
||||
options.push_back({ "*", "-c, --ctx-size N", "size of the prompt context (default: %d, 0 = loaded from model)", params.n_ctx });
|
||||
options.push_back({ "*", "-cd, --ctx-size-draft N", "size of the prompt context for the draft model (default: %d, 0 = loaded from model)", params.speculative.n_ctx });
|
||||
options.push_back({ "*", "-cram, --cache-ram N", "set the maximum cache size in MiB (default: %d, -1 - no limit, 0 - disable)",params.cache_ram_mib });
|
||||
|
||||
options.push_back({ "*", "--ctx-checkpoints N", "max number of context checkpoints to create per slot (default: %d)",params.ctx_checkpoints_n});
|
||||
options.push_back({ "*", "--ctx-checkpoints-interval N", "minimum number of tokens between each context checkpoint. (default: %d, <=0 disable)",params.ctx_checkpoints_interval});
|
||||
options.push_back({ "*", "-cram, --cache-ram N", "set the maximum cache size in MiB (default: %d, -1 - no limit, 0 - disable)",params.cache_ram_mib });
|
||||
options.push_back({ "*", "-crs, --cache-ram-similarity N", "max of similarity of prompt tokens to cache tokens that triggers prompt cache (default: %.2f).",params.cache_ram_similarity });
|
||||
options.push_back({ "*", "-cram-n-min --cache-ram-n-min N", "minimum number of the cached tokens that triggers prompt cache (default: %d).", params.cache_ram_n_min });
|
||||
options.push_back({ "*", "-n, --predict N", "number of tokens to predict (default: %d, -1 = infinity, -2 = until context filled)", params.n_predict });
|
||||
|
||||
@@ -280,6 +280,8 @@ struct gpt_params {
|
||||
std::vector<std::string> ban_phrases; // strings that are banned in generation
|
||||
int32_t banned_n = 1; // number of tokens that are banned in the phrase
|
||||
size_t n_buffer = 0; // number of token buffers for string ban
|
||||
bool can_ban_phrases = true; // whether to ban strings
|
||||
bool do_checkpoint = false; // do checkpoint for recurrent models only
|
||||
|
||||
std::vector<llama_model_kv_override> kv_overrides;
|
||||
std::vector<llama_model_tensor_buft_override> tensor_buft_overrides;
|
||||
@@ -418,7 +420,8 @@ struct gpt_params {
|
||||
|
||||
float slot_prompt_similarity = 0.1f;
|
||||
|
||||
int32_t n_ctx_checkpoints = 8; // max number of context checkpoints per slot
|
||||
int32_t ctx_checkpoints_n = 8; // max number of context checkpoints per slot
|
||||
int32_t ctx_checkpoints_interval = 512; // minimum number of tokens between each context checkpoints
|
||||
int32_t cache_ram_mib = 8192; // -1 = no limit, 0 - disable, 1 = 1 MiB, etc.
|
||||
int32_t cache_ram_n_min = 0; // min number of tokens required to save in the ram
|
||||
float cache_ram_similarity = 0.5f; // similarity of tokens to cached tokens
|
||||
|
||||
@@ -188,8 +188,8 @@ int main(int argc, char ** argv) {
|
||||
// save seq 0 and load into seq 1
|
||||
{
|
||||
// save kv of seq 0
|
||||
std::vector<uint8_t> seq_store(llama_state_seq_get_size(ctx3, 0));
|
||||
const size_t ncopy = llama_state_seq_get_data(ctx3, seq_store.data(), seq_store.size(), 0);
|
||||
std::vector<uint8_t> seq_store(llama_state_seq_get_size(ctx3, 0, 0));
|
||||
const size_t ncopy = llama_state_seq_get_data(ctx3, seq_store.data(), seq_store.size(), 0, 0);
|
||||
if (ncopy != seq_store.size()) {
|
||||
fprintf(stderr, "\n%s : seq copy data length %zd does not match expected length %zd\n", __func__, ncopy, seq_store.size());
|
||||
llama_free(ctx3);
|
||||
@@ -203,7 +203,7 @@ int main(int argc, char ** argv) {
|
||||
fprintf(stderr, "%s : kv cache cleared\n", __func__);
|
||||
|
||||
// restore kv into seq 1
|
||||
const size_t nset = llama_state_seq_set_data(ctx3, seq_store.data(), seq_store.size(), 1);
|
||||
const size_t nset = llama_state_seq_set_data(ctx3, seq_store.data(), seq_store.size(), 1, 0);
|
||||
if (nset != seq_store.size()) {
|
||||
fprintf(stderr, "\n%s : seq set data length %zd does not match expected length %zd\n", __func__, nset, seq_store.size());
|
||||
llama_free(ctx3);
|
||||
|
||||
@@ -315,7 +315,7 @@ void server_context::init() {
|
||||
void server_slot::prompt_save(server_prompt_cache& prompt_cache) const {
|
||||
assert(server_cached_prompt.data.size() == 0);
|
||||
|
||||
const size_t cur_size = llama_state_seq_get_size(ctx, id);
|
||||
const size_t cur_size = llama_state_seq_get_size(ctx, id, 0);
|
||||
|
||||
LLAMA_LOG_INFO(" - saving prompt with length %d, total state size = %.3f MiB\n",
|
||||
(int)server_cached_prompt.tokens.size(), cur_size / (1024.0 * 1024.0));
|
||||
@@ -325,7 +325,7 @@ void server_slot::prompt_save(server_prompt_cache& prompt_cache) const {
|
||||
return;
|
||||
}
|
||||
|
||||
llama_state_seq_get_data(ctx, cur->data.data(), cur_size, id);
|
||||
llama_state_seq_get_data(ctx, cur->data.data(), cur_size, id, 0);
|
||||
}
|
||||
|
||||
void server_slot::prompt_load(server_prompt_cache& prompt_cache, const server_tokens& tokens) {
|
||||
@@ -361,7 +361,7 @@ void server_slot::reset() {
|
||||
rewind_status = false;
|
||||
|
||||
generated_token_probs.clear();
|
||||
|
||||
checkpoint_pos = 0;
|
||||
|
||||
// Reset speculative decoding stats
|
||||
n_draft_total = 0;
|
||||
@@ -1246,7 +1246,22 @@ bool server_context::launch_slot_with_task(server_slot& slot, server_task& task)
|
||||
slot.ban_phrases_bias = json_value(data, "banned_bias", params_base.ban_phrases_bias);
|
||||
slot.banned_n = json_value(data, "banned_n", params_base.banned_n);
|
||||
}
|
||||
|
||||
if (llama_model_has_recurrent(llama_get_model(slot.ctx))) {
|
||||
params_base.can_ban_phrases = false;
|
||||
bool do_checkpoint = params_base.ctx_checkpoints_n > 0;
|
||||
// make checkpoints only for completion tasks
|
||||
do_checkpoint = do_checkpoint && task.type == SERVER_TASK_TYPE_COMPLETION;
|
||||
// make a checkpoint of the parts of the memory that cannot be rolled back.
|
||||
// checkpoints are created only if:
|
||||
// - the model architecture is marked as recurrent or hybrid
|
||||
//
|
||||
// TODO: try to make this conditional on the context or the memory module, instead of the model type
|
||||
// do_checkpoint = do_checkpoint && llama_model_has_recurrent(model);
|
||||
params_base.do_checkpoint = do_checkpoint;
|
||||
if (slot.n_buffer != 0) {
|
||||
LLAMA_LOG_WARN("Recurrent model does not support banned strings.\n");
|
||||
}
|
||||
}
|
||||
{
|
||||
const auto& stop = data.find("stop");
|
||||
if (stop != data.end() && stop->is_array()) {
|
||||
@@ -2142,7 +2157,7 @@ void server_context::process_single_task(server_task&& task) {
|
||||
|
||||
// Erase token cache
|
||||
const size_t n_erased = slot->cache_tokens.size();
|
||||
llama_kv_cache_seq_rm(ctx, slot->id + 1, -1, -1);
|
||||
llama_kv_cache_seq_rm(ctx, slot->id, -1, -1);
|
||||
slot->cache_tokens.clear();
|
||||
|
||||
server_task_result result;
|
||||
@@ -2552,6 +2567,7 @@ void server_context::context_shift() {
|
||||
|
||||
void server_context::add_sampled_tokens() {
|
||||
for (auto& slot : slots) {
|
||||
slot.released = false;
|
||||
if (slot.state == SLOT_STATE_IDLE) {
|
||||
continue;
|
||||
}
|
||||
@@ -2626,15 +2642,22 @@ void server_context::add_sampled_tokens() {
|
||||
}
|
||||
}
|
||||
|
||||
void server_context::create_checkpoint_at_interval(server_slot & slot, const gpt_params & params_base) {
|
||||
if (params_base.do_checkpoint && params_base.ctx_checkpoints_interval > 0) {
|
||||
auto pos = llama_kv_cache_seq_pos_max(slot.ctx, slot.id);
|
||||
if (slot.checkpoint_pos + params_base.ctx_checkpoints_interval <= 1 + pos) {
|
||||
create_checkpoint(slot);
|
||||
slot.checkpoint_pos = pos;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void server_context::apply_checkpoint(server_slot & slot) {
|
||||
const auto pos_min_thold = std::max(0, slot.n_past - 1);
|
||||
if (!mctx && slot.n_past > 0 && slot.n_past < slot.cache_tokens.n_tokens()) {
|
||||
int32_t pos_min = 0;
|
||||
if (llama_model_is_hybrid(llama_get_model(slot.ctx)) || llama_model_is_recurrent(llama_get_model(slot.ctx))) {
|
||||
pos_min = llama_kv_cache_seq_pos_max(slot.ctx, slot.id);
|
||||
}
|
||||
int32_t pos_min = llama_kv_cache_seq_pos_min(slot.ctx, slot.id);
|
||||
|
||||
if (pos_min > pos_min_thold+2) {
|
||||
if (pos_min > pos_min_thold) {
|
||||
// TODO: support can be added in the future when corresponding vision models get released
|
||||
GGML_ASSERT(!slot.cache_tokens.has_mtmd);
|
||||
|
||||
@@ -2654,8 +2677,9 @@ void server_context::apply_checkpoint(server_slot & slot) {
|
||||
|
||||
if (!do_reset) {
|
||||
// restore the context checkpoint
|
||||
const int64_t t_start = ggml_time_us();
|
||||
const size_t checkpoint_size = it->data.size();
|
||||
const size_t n = llama_state_seq_set_data(ctx, it->data.data(), checkpoint_size, slot.id);
|
||||
const size_t n = llama_state_seq_set_data(ctx, it->data.data(), checkpoint_size, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
|
||||
|
||||
if (n != checkpoint_size) {
|
||||
SLT_ERR(slot, "failed to restore context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", it->pos_min, it->pos_max, (float)checkpoint_size / 1024 / 1024);
|
||||
@@ -2663,7 +2687,8 @@ void server_context::apply_checkpoint(server_slot & slot) {
|
||||
//printf("[DEBUG] `do_reset` was set to `true` after failing to restore a checkpoint");
|
||||
} else {
|
||||
slot.n_past = std::min(slot.n_past, std::max(it->pos_min + 1, it->pos_max));
|
||||
SLT_WRN(slot, "restored context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", it->pos_min, it->pos_max, (float)checkpoint_size / 1024 / 1024);
|
||||
slot.n_past_prompt = std::min(slot.n_past_prompt, std::max(it->pos_min_prompt + 1, it->pos_max_prompt));
|
||||
SLT_WRN(slot, "restored context checkpoint took %.2f ms (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", (ggml_time_us() - t_start) / 1000.0, it->pos_min, it->pos_max, (float)checkpoint_size / 1024 / 1024);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2691,56 +2716,44 @@ void server_context::apply_checkpoint(server_slot & slot) {
|
||||
}
|
||||
|
||||
void server_context::create_checkpoint(server_slot & slot) {
|
||||
//bool do_checkpoint = params_base.n_ctx_checkpoints > 0;
|
||||
bool do_checkpoint = true;
|
||||
int32_t pos_min = llama_kv_cache_seq_pos_min(slot.ctx, slot.id);
|
||||
const auto pos_max = llama_kv_cache_seq_pos_max(slot.ctx, slot.id);
|
||||
|
||||
//// make checkpoints only for completion tasks
|
||||
//do_checkpoint = do_checkpoint && slot.task->type == SERVER_TASK_TYPE_COMPLETION;
|
||||
// no need for empty or small checkpoints
|
||||
do_checkpoint = do_checkpoint && (pos_min >= 0 && pos_max >= 16);
|
||||
|
||||
//// make a checkpoint of the parts of the memory that cannot be rolled back.
|
||||
//// checkpoints are created only if:
|
||||
//// - the model architecture is marked as recurrent or hybrid
|
||||
////
|
||||
//// TODO: try to make this conditional on the context or the memory module, instead of the model type
|
||||
//do_checkpoint = do_checkpoint && (
|
||||
// llama_model_is_recurrent(model) ||
|
||||
// llama_model_is_hybrid(model)
|
||||
// );
|
||||
//int32_t pos_min = 0;
|
||||
//if (llama_model_is_recurrent(model) || llama_model_is_hybrid(model)) {
|
||||
// pos_min = llama_kv_cache_seq_pos_max(slot.ctx, slot.id);
|
||||
//}
|
||||
//const auto pos_max = llama_kv_cache_seq_pos_max(slot.ctx, slot.id);
|
||||
// no need to create checkpoints that are too close together
|
||||
do_checkpoint = do_checkpoint && (slot.server_cached_prompt.checkpoints.empty() || pos_max > slot.server_cached_prompt.checkpoints.back().pos_max);
|
||||
|
||||
//// no need for empty or small checkpoints
|
||||
//do_checkpoint = do_checkpoint && (pos_min >= 0 && pos_max >= 5);
|
||||
if (do_checkpoint) {
|
||||
const int64_t t_start = ggml_time_us();
|
||||
while (slot.server_cached_prompt.checkpoints.size() >= (size_t)params_base.ctx_checkpoints_n) {
|
||||
// make room for the new checkpoint, if needed
|
||||
const auto & cur = slot.server_cached_prompt.checkpoints.front();
|
||||
|
||||
//// no need to create checkpoints that are too close together
|
||||
//do_checkpoint = do_checkpoint && (slot.server_cached_prompt.checkpoints.empty() || pos_max > slot.server_cached_prompt.checkpoints.back().pos_max + 64);
|
||||
SLT_WRN(slot, "erasing old context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n",
|
||||
cur.pos_min, cur.pos_max, (float)cur.data.size() / 1024 / 1024);
|
||||
|
||||
//if (do_checkpoint) {
|
||||
// while (slot.server_cached_prompt.checkpoints.size() >= (size_t)params_base.n_ctx_checkpoints) {
|
||||
// // make room for the new checkpoint, if needed
|
||||
// const auto & cur = slot.server_cached_prompt.checkpoints.front();
|
||||
slot.server_cached_prompt.checkpoints.erase(slot.server_cached_prompt.checkpoints.begin());
|
||||
}
|
||||
|
||||
// SLT_WRN(slot, "erasing old context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n",
|
||||
// cur.pos_min, cur.pos_max, (float)cur.data.size() / 1024 / 1024);
|
||||
const size_t checkpoint_size = llama_state_seq_get_size(ctx, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
|
||||
|
||||
// slot.server_cached_prompt.checkpoints.erase(slot.server_cached_prompt.checkpoints.begin());
|
||||
// }
|
||||
auto & cur = slot.server_cached_prompt.checkpoints.emplace_back(server_prompt_checkpoint{
|
||||
/*.pos_min = */ pos_min,
|
||||
/*.pos_max = */ pos_max,
|
||||
/*.pos_min_prompt = */ pos_min + slot.n_past_offset,
|
||||
/*.pos_max_prompt = */ pos_max + slot.n_past_offset ,
|
||||
/*.data = */ std::vector<uint8_t>(checkpoint_size),
|
||||
});
|
||||
|
||||
// const size_t checkpoint_size = llama_state_seq_get_size(ctx, slot.id);
|
||||
llama_state_seq_get_data(ctx, cur.data.data(), checkpoint_size, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
|
||||
|
||||
// auto & cur = slot.server_cached_prompt.checkpoints.emplace_back(server_prompt_checkpoint{
|
||||
// /*.pos_min = */ pos_min,
|
||||
// /*.pos_max = */ pos_max,
|
||||
// /*.data = */ std::vector<uint8_t>(checkpoint_size),
|
||||
// });
|
||||
|
||||
// llama_state_seq_get_data(ctx, cur.data.data(), checkpoint_size, slot.id);
|
||||
|
||||
// SLT_WRN(slot, "created context checkpoint %d of %d (pos_min = %d, pos_max = %d, size = %.3f MiB)\n",
|
||||
// (int)slot.server_cached_prompt.checkpoints.size(), params_base.n_ctx_checkpoints, cur.pos_min, cur.pos_max, (float)cur.data.size() / 1024 / 1024);
|
||||
//}
|
||||
SLT_WRN(slot, "created context checkpoint %d of %d (pos_min = %d, pos_max = %d, size = %.3f MiB, took %.2f ms)\n",
|
||||
(int)slot.server_cached_prompt.checkpoints.size(), params_base.ctx_checkpoints_n, cur.pos_min, cur.pos_max, (float)cur.data.size() / 1024 / 1024,
|
||||
(ggml_time_us() - t_start) / 1000.0);
|
||||
}
|
||||
}
|
||||
|
||||
void server_context::batch_pending_prompt(const int32_t n_ubatch, const int32_t n_batch, int32_t & batch_type) {
|
||||
@@ -2798,8 +2811,6 @@ void server_context::batch_pending_prompt(const int32_t n_ubatch, const int32_t
|
||||
}
|
||||
|
||||
slot.n_past = 0;
|
||||
slot.n_buffer = 0;
|
||||
slot.token_buffer.clear();
|
||||
slot.n_prompt_tokens = prompt_tokens.size();
|
||||
|
||||
LOG_VERBOSE("prompt tokenized", {
|
||||
@@ -2900,6 +2911,8 @@ void server_context::batch_pending_prompt(const int32_t n_ubatch, const int32_t
|
||||
}
|
||||
slot.n_past = prefix.first;
|
||||
slot.n_past_prompt = prefix.second;
|
||||
slot.n_past_offset = slot.n_past_prompt - slot.n_past;
|
||||
|
||||
if (slot.n_past != slot.n_past_prompt) {
|
||||
LLAMA_LOG_INFO("Mistokenization found and handled successfully.\n");
|
||||
}
|
||||
@@ -3074,8 +3087,6 @@ void server_context::batch_pending_prompt(const int32_t n_ubatch, const int32_t
|
||||
slot.n_decoded = 0;
|
||||
slot.i_batch = batch.n_tokens - 1;
|
||||
|
||||
//create_checkpoint(slot);
|
||||
|
||||
LOG_VERBOSE("prompt done", {
|
||||
{"id_slot", slot.id},
|
||||
{"n_past", slot.n_past},
|
||||
@@ -3187,14 +3198,11 @@ void server_context::speculative_decoding_accept() {
|
||||
populate_token_probs(slot, result, slot.params.post_sampling_probs, params_base.special, i);
|
||||
}
|
||||
|
||||
if (slot.n_buffer == 0 || llama_model_is_hybrid(llama_get_model(slot.ctx)) || llama_model_is_recurrent(llama_get_model(slot.ctx))) {
|
||||
if (slot.n_buffer == 0 || !params_base.can_ban_phrases) {
|
||||
if (!process_token(result, slot)) {
|
||||
// release slot because of stop condition
|
||||
send_final_response(slot);
|
||||
//create_checkpoint(slot);
|
||||
slot.release();
|
||||
slot.print_timings();
|
||||
metrics.on_prediction(slot);
|
||||
release_slot_after_final_response(slot);
|
||||
break;
|
||||
}
|
||||
} else {
|
||||
@@ -3218,6 +3226,15 @@ bool server_context::accept_special_token(const server_slot& slot, const llama_
|
||||
return params_base.special || slot.sparams.preserved_tokens.find(token) != slot.sparams.preserved_tokens.end();
|
||||
}
|
||||
|
||||
void server_context::release_slot_after_final_response(server_slot & slot) {
|
||||
slot.print_timings();
|
||||
if (params_base.do_checkpoint) {
|
||||
create_checkpoint(slot);
|
||||
}
|
||||
slot.release();
|
||||
slot.released = true;
|
||||
metrics.on_prediction(slot);
|
||||
}
|
||||
|
||||
void server_context::send_token_results(completion_token_outputs& results, server_slot& slot, int32_t n) {
|
||||
int count = 0;
|
||||
@@ -3226,10 +3243,7 @@ void server_context::send_token_results(completion_token_outputs& results, serve
|
||||
count++;
|
||||
if (!has_next) {
|
||||
send_final_response(slot);
|
||||
//create_checkpoint(slot);
|
||||
slot.release();
|
||||
slot.print_timings();
|
||||
metrics.on_prediction(slot);
|
||||
release_slot_after_final_response(slot);
|
||||
break;
|
||||
}
|
||||
if (n > 0 && count >= n) {
|
||||
@@ -3266,7 +3280,7 @@ inline int32_t check_ban_phrase(const server_slot& slot) {
|
||||
}
|
||||
if (found) {
|
||||
std::vector<size_t> unused;
|
||||
LLAMA_LOG_DEBUG("Banned string dectected: %s\n ", string_buffer.substr(start).c_str());
|
||||
LLAMA_LOG_DEBUG("Banned string dectected: %s\n", string_buffer.substr(start).c_str());
|
||||
n = find_n_tokens_from_string(slot.ctx, tokens, start, 0, unused);
|
||||
n_rewind = (int32_t) slot.token_buffer.size() - (int32_t) n;
|
||||
}
|
||||
@@ -3299,6 +3313,8 @@ inline void rewind_context(server_slot& slot, int32_t n_rewind) {
|
||||
size_t n_keep = slot.cache_tokens.size() - n_rewind;
|
||||
slot.sampled = slot.cache_tokens[n_keep];
|
||||
slot.cache_tokens.keep_first(n_keep);
|
||||
llama_kv_cache_seq_rm(slot.ctx, slot.id, n_keep, -1);
|
||||
|
||||
}
|
||||
|
||||
void server_context::buffer_and_check_string_ban(server_slot & slot, completion_token_output & result) {
|
||||
@@ -3397,6 +3413,9 @@ void server_context::process_batch_tokens(int32_t & n_batch) {
|
||||
|
||||
for (auto& slot : slots) {
|
||||
if (slot.state != SLOT_STATE_PROCESSING || slot.i_batch < (int)i || slot.i_batch >= (int)(i + n_tokens)) {
|
||||
if (slot.command == SLOT_COMMAND_LOAD_PROMPT) {
|
||||
create_checkpoint_at_interval(slot, params_base);
|
||||
}
|
||||
continue; // continue loop of slots
|
||||
}
|
||||
|
||||
@@ -3440,6 +3459,14 @@ void server_context::process_batch_tokens(int32_t & n_batch) {
|
||||
slot.t_start_generation = ggml_time_us();
|
||||
slot.t_prompt_processing = (slot.t_start_generation - slot.t_start_process_prompt) / 1e3;
|
||||
metrics.on_prompt_eval(slot);
|
||||
if (params_base.do_checkpoint) {
|
||||
create_checkpoint(slot);
|
||||
}
|
||||
}
|
||||
|
||||
// save checkpoint during generation
|
||||
if (slot.n_decoded > 1) {
|
||||
create_checkpoint_at_interval(slot, params_base);
|
||||
}
|
||||
|
||||
slot.t_token_generation = std::max<int64_t>(1, t_current - slot.t_start_generation) / 1e3;
|
||||
@@ -3452,7 +3479,7 @@ void server_context::process_batch_tokens(int32_t & n_batch) {
|
||||
}
|
||||
|
||||
// no ban string for recurrent/hybrid model
|
||||
if (slot.n_buffer == 0 || llama_model_is_hybrid(llama_get_model(slot.ctx)) || llama_model_is_recurrent(llama_get_model(slot.ctx))) {
|
||||
if (slot.n_buffer == 0 || !params_base.can_ban_phrases) {
|
||||
slot.token_buffer = { result };
|
||||
send_token_results(slot.token_buffer, slot);
|
||||
} else {
|
||||
@@ -3503,7 +3530,7 @@ void server_context::update_slots() {
|
||||
// apply context-shift if needed
|
||||
// TODO: simplify and improve
|
||||
context_shift();
|
||||
|
||||
|
||||
// start populating the batch for this iteration
|
||||
common_batch_clear(batch);
|
||||
|
||||
|
||||
@@ -32,6 +32,7 @@ struct server_slot {
|
||||
llama_batch batch_spec = {};
|
||||
llama_context * ctx_dft = nullptr;
|
||||
|
||||
bool released = false;
|
||||
slot_state state = SLOT_STATE_IDLE;
|
||||
slot_command command = SLOT_COMMAND_NONE;
|
||||
|
||||
@@ -45,6 +46,7 @@ struct server_slot {
|
||||
int32_t n_ctx = 0; // context size per slot
|
||||
int32_t n_past = 0;
|
||||
int32_t n_past_prompt = 0;
|
||||
int32_t n_past_offset = 0;
|
||||
int32_t n_decoded = 0;
|
||||
int32_t n_remaining = -1;
|
||||
int32_t n_discarded_prompt = 0;
|
||||
@@ -102,6 +104,8 @@ struct server_slot {
|
||||
|
||||
void prompt_load(server_prompt_cache& prompt_cache, const server_tokens& tokens);
|
||||
|
||||
size_t checkpoint_pos = 0;
|
||||
|
||||
// sampling
|
||||
llama_token sampled; // in speculative mode, this is the last accepted token
|
||||
llama_tokens drafted;
|
||||
@@ -355,4 +359,8 @@ struct server_context {
|
||||
void create_checkpoint(server_slot & slot);
|
||||
|
||||
void apply_checkpoint(server_slot & slot);
|
||||
|
||||
void create_checkpoint_at_interval(server_slot & slot, const gpt_params & params_base);
|
||||
|
||||
void release_slot_after_final_response(server_slot & slot);
|
||||
};
|
||||
|
||||
@@ -1117,7 +1117,7 @@ bool server_prompt_cache::load(server_prompt& prompt, const server_tokens& token
|
||||
if (it_best != states.end()) {
|
||||
LLAMA_LOG_INFO(" - found better prompt with f_keep = %.3f, sim = %.3f, n_keep = %d, n_discarded_prompt = %d\n", f_keep_best, sim_best, it_best->n_kept_prompt, it_best->n_discarded_prompt);
|
||||
const size_t size = it_best->data.size();
|
||||
const size_t n = llama_state_seq_set_data(ctx, it_best->data.data(), size, id_slot);
|
||||
const size_t n = llama_state_seq_set_data(ctx, it_best->data.data(), size, id_slot, 0);
|
||||
if (n != size) {
|
||||
LLAMA_LOG_INFO("failed to restore state with size %zu\n", size);
|
||||
return false;
|
||||
|
||||
@@ -344,6 +344,8 @@ using server_task_result_ptr = std::unique_ptr<server_task_result>;
|
||||
struct server_prompt_checkpoint {
|
||||
llama_pos pos_min;
|
||||
llama_pos pos_max;
|
||||
llama_pos pos_min_prompt;
|
||||
llama_pos pos_max_prompt;
|
||||
|
||||
std::vector<uint8_t> data;
|
||||
|
||||
|
||||
@@ -645,6 +645,8 @@ extern "C" {
|
||||
// Returns true if the model is hybrid (like Jamba, Granite, etc.)
|
||||
LLAMA_API bool llama_model_is_hybrid(const struct llama_model * model);
|
||||
|
||||
LLAMA_API bool llama_model_has_recurrent(const struct llama_model * model);
|
||||
|
||||
// Returns 0 on success
|
||||
LLAMA_API uint32_t llama_model_quantize(
|
||||
const char * fname_inp,
|
||||
@@ -735,6 +737,11 @@ extern "C" {
|
||||
llama_seq_id * cells_sequences;
|
||||
};
|
||||
|
||||
// work only with partial states, such as recurrent cache (e.g. Mamba)
|
||||
#define LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY 1
|
||||
|
||||
typedef uint32_t llama_state_seq_flags;
|
||||
|
||||
// Create an empty KV cache view. (use only for debugging purposes)
|
||||
LLAMA_API struct llama_kv_cache_view llama_kv_cache_view_init(const struct llama_context * ctx, int32_t n_seq_max);
|
||||
|
||||
@@ -813,6 +820,11 @@ extern "C" {
|
||||
struct llama_context * ctx,
|
||||
llama_seq_id seq_id);
|
||||
|
||||
// Returns the smallest position present in the KV cache for the specified sequence
|
||||
LLAMA_API llama_pos llama_kv_cache_seq_pos_min(
|
||||
struct llama_context * ctx,
|
||||
llama_seq_id seq_id);
|
||||
|
||||
// Defragment the KV cache
|
||||
// This will be applied:
|
||||
// - lazily on next llama_decode()
|
||||
@@ -889,14 +901,16 @@ extern "C" {
|
||||
// Get the exact size needed to copy the KV cache of a single sequence
|
||||
LLAMA_API size_t llama_state_seq_get_size(
|
||||
struct llama_context * ctx,
|
||||
llama_seq_id seq_id);
|
||||
llama_seq_id seq_id,
|
||||
llama_state_seq_flags flags);
|
||||
|
||||
// Copy the KV cache of a single sequence into the specified buffer
|
||||
LLAMA_API size_t llama_state_seq_get_data(
|
||||
struct llama_context * ctx,
|
||||
uint8_t * dst,
|
||||
size_t size,
|
||||
llama_seq_id seq_id);
|
||||
llama_seq_id seq_id,
|
||||
llama_state_seq_flags flags);
|
||||
|
||||
// Copy the sequence data (originally copied with `llama_state_seq_get_data`) into the specified sequence
|
||||
// Returns:
|
||||
@@ -906,7 +920,8 @@ extern "C" {
|
||||
struct llama_context * ctx,
|
||||
const uint8_t * src,
|
||||
size_t size,
|
||||
llama_seq_id dest_seq_id);
|
||||
llama_seq_id dest_seq_id,
|
||||
llama_state_seq_flags flags);
|
||||
|
||||
LLAMA_API size_t llama_state_seq_save_file(
|
||||
struct llama_context * ctx,
|
||||
|
||||
@@ -142,7 +142,7 @@ ggml_cgraph * llm_build_context::build_k_shift() {
|
||||
ggml_set_input(lctx.inp_K_shift);
|
||||
|
||||
for (int il = 0; il < n_layer; ++il) {
|
||||
if ((model.arch == LLM_ARCH_QWEN3NEXT || model.arch == LLM_ARCH_QWEN35MOE) && hparams.is_recurrent(il)) {
|
||||
if (llm_arch_is_hybrid(model.arch) && hparams.is_recurrent(il)) {
|
||||
continue;
|
||||
}
|
||||
if (kv_self.k_l[il] == nullptr) {
|
||||
@@ -241,7 +241,7 @@ ggml_cgraph * llm_build_context::build_defrag(const std::vector<uint32_t> & ids)
|
||||
}
|
||||
|
||||
for (int il = 0; il < n_layer; ++il) {
|
||||
if ((model.arch == LLM_ARCH_QWEN3NEXT || model.arch == LLM_ARCH_QWEN35MOE) && hparams.is_recurrent(il)) {
|
||||
if (llm_arch_is_hybrid(model.arch) && hparams.is_recurrent(il)) {
|
||||
continue;
|
||||
}
|
||||
if (kv_self.k_l[il] == nullptr) {
|
||||
|
||||
@@ -1739,3 +1739,7 @@ bool llama_model_is_recurrent(const llama_model * model) {
|
||||
bool llama_model_is_hybrid(const llama_model * model) {
|
||||
return llm_arch_is_hybrid(model->arch);
|
||||
}
|
||||
|
||||
bool llama_model_has_recurrent(const llama_model * model) {
|
||||
return llm_arch_is_hybrid(model->arch) || llm_arch_is_recurrent(model->arch);
|
||||
}
|
||||
|
||||
@@ -671,7 +671,7 @@ static inline uint32_t llama_kv_v_row_embd(
|
||||
uint32_t il) {
|
||||
// qwen3next recurrent state is stored in a dedicated V-cache tail (per sequence),
|
||||
// so per-token V rows include only attention values.
|
||||
if (model.arch == LLM_ARCH_QWEN3NEXT || model.arch == LLM_ARCH_QWEN35MOE) {
|
||||
if (llm_arch_is_hybrid(model.arch)) {
|
||||
return hparams.n_embd_v_gqa(il);
|
||||
}
|
||||
|
||||
@@ -732,7 +732,7 @@ static bool llama_kv_cache_init(
|
||||
cache.hybrid = llm_arch_is_hybrid(model.arch);
|
||||
// qwen3next uses hybrid recurrent+attention cache semantics. Keep V rows in
|
||||
// standard layout to match the mainline hybrid path when flash attention is off.
|
||||
cache.v_trans = !cache.recurrent && !cparams.flash_attn && model.arch != LLM_ARCH_QWEN3NEXT && model.arch != LLM_ARCH_QWEN35MOE;
|
||||
cache.v_trans = !cache.recurrent && !cparams.flash_attn && !llm_arch_is_hybrid(model.arch);
|
||||
|
||||
cache.head = 0;
|
||||
cache.size = kv_size;
|
||||
@@ -744,7 +744,7 @@ static bool llama_kv_cache_init(
|
||||
cache.cells.clear();
|
||||
cache.cells.resize(kv_size);
|
||||
|
||||
if (cache.recurrent || model.arch == LLM_ARCH_QWEN3NEXT || model.arch == LLM_ARCH_QWEN35MOE) {
|
||||
if (cache.recurrent || llm_arch_is_hybrid(model.arch)) {
|
||||
// init state copy sources
|
||||
for (uint32_t i = 0; i < cache.size; ++i) {
|
||||
cache.cells[i].src = i;
|
||||
@@ -829,7 +829,7 @@ static bool llama_kv_cache_init(
|
||||
std::vector<size_t> mem_split(model.splits.size(), 0);
|
||||
|
||||
const uint32_t qnext_state_slots = llama_qwen3next_state_slots(cparams, kv_size);
|
||||
if ((model.arch == LLM_ARCH_QWEN3NEXT || model.arch == LLM_ARCH_QWEN35MOE) && qnext_state_slots < std::max<uint32_t>(1, cparams.n_seq_max)) {
|
||||
if (llm_arch_is_hybrid(model.arch) && qnext_state_slots < std::max<uint32_t>(1, cparams.n_seq_max)) {
|
||||
LLAMA_LOG_WARN("%s: reducing qwen3next state slots from %u to %u to fit KV cache size\n",
|
||||
__func__, std::max<uint32_t>(1, cparams.n_seq_max), qnext_state_slots);
|
||||
}
|
||||
@@ -1398,6 +1398,19 @@ static llama_pos llama_kv_cache_seq_pos_max(struct llama_kv_cache & cache, llama
|
||||
return result;
|
||||
}
|
||||
|
||||
static llama_pos llama_kv_cache_seq_pos_min(struct llama_kv_cache & cache, llama_seq_id seq_id) {
|
||||
llama_pos result = -1;
|
||||
|
||||
for (uint32_t i = 0; i < cache.size; ++i) {
|
||||
if (cache.cells[i].has_seq_id(seq_id)) {
|
||||
result = cache.cells[i].pos;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
static void llama_kv_cache_defrag(struct llama_kv_cache & cache) {
|
||||
cache.do_defrag = true;
|
||||
}
|
||||
@@ -3227,7 +3240,7 @@ static int llama_decode_internal(
|
||||
auto tim1 = ggml_time_us();
|
||||
#endif
|
||||
uint32_t n_tokens = std::min(n_ubatch, n_tokens_all - cur_token);
|
||||
if ((model.arch == LLM_ARCH_QWEN3NEXT || model.arch == LLM_ARCH_QWEN35MOE) &&
|
||||
if (llm_arch_is_hybrid(model.arch) &&
|
||||
n_tokens > 1 &&
|
||||
batch_all.n_seq_id != nullptr &&
|
||||
batch_all.seq_id != nullptr) {
|
||||
@@ -5735,6 +5748,13 @@ void llama_kv_cache_seq_div(struct llama_context * ctx, llama_seq_id seq_id, lla
|
||||
llama_kv_cache_seq_div(ctx->kv_self, seq_id, p0, p1, d);
|
||||
}
|
||||
|
||||
llama_pos llama_kv_cache_seq_pos_min(struct llama_context * ctx, llama_seq_id seq_id) {
|
||||
if (ctx->kv_self.hybrid || ctx->kv_self.recurrent) {
|
||||
return llama_kv_cache_seq_pos_max(ctx->kv_self, seq_id);
|
||||
}
|
||||
return llama_kv_cache_seq_pos_min(ctx->kv_self, seq_id);
|
||||
}
|
||||
|
||||
llama_pos llama_kv_cache_seq_pos_max(struct llama_context * ctx, llama_seq_id seq_id) {
|
||||
return llama_kv_cache_seq_pos_max(ctx->kv_self, seq_id);
|
||||
}
|
||||
@@ -5876,10 +5896,11 @@ struct llama_data_write {
|
||||
}
|
||||
}
|
||||
|
||||
void write_kv_cache_data(const struct llama_context * ctx, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id = -1) {
|
||||
void write_kv_cache_data(const struct llama_context * ctx, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id = -1,
|
||||
llama_state_seq_flags flags = 0) {
|
||||
const struct llama_kv_cache & kv_self = ctx->kv_self;
|
||||
const struct llama_hparams & hparams = ctx->model.hparams;
|
||||
|
||||
bool need_kv = (flags & LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY) == 0;
|
||||
// v_state: 0 -> not transposed V cache
|
||||
// 1 -> transposed V cache
|
||||
// 2 -> no V cache (as it may be the case with MLA)
|
||||
@@ -5895,7 +5916,7 @@ struct llama_data_write {
|
||||
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
|
||||
const uint32_t n_embd_head_qk_rope = hparams.n_rot;
|
||||
const uint32_t kv_lora_rank = hparams.n_lora_kv;
|
||||
const bool has_k_cache = kv_self.k_l[il] != nullptr;
|
||||
const bool has_k_cache = kv_self.k_l[il] != nullptr && need_kv;
|
||||
|
||||
// Write key type
|
||||
const int32_t k_type_i = has_k_cache ? (int32_t) kv_self.k_l[il]->type : -1;
|
||||
@@ -5924,7 +5945,7 @@ struct llama_data_write {
|
||||
if (v_state == 0) {
|
||||
for (uint32_t il = 0; il < n_layer; ++il) {
|
||||
const uint32_t n_embd_v_gqa = llama_kv_v_row_embd(ctx->model, hparams, il);
|
||||
const bool has_v_cache = kv_self.v_l[il] != nullptr;
|
||||
const bool has_v_cache = kv_self.v_l[il] != nullptr && need_kv;
|
||||
|
||||
// Write value type
|
||||
const int32_t v_type_i = has_v_cache ? (int32_t) kv_self.v_l[il]->type : -1;
|
||||
@@ -5951,7 +5972,7 @@ struct llama_data_write {
|
||||
const uint32_t kv_size = kv_self.size;
|
||||
for (uint32_t il = 0; il < n_layer; ++il) {
|
||||
const uint32_t n_embd_v_gqa = llama_kv_v_row_embd(ctx->model, hparams, il);
|
||||
const bool has_v_cache = kv_self.v_l[il] != nullptr;
|
||||
const bool has_v_cache = kv_self.v_l[il] != nullptr && need_kv;
|
||||
|
||||
// Write value type
|
||||
const int32_t v_type_i = has_v_cache ? (int32_t) kv_self.v_l[il]->type : -1;
|
||||
@@ -6019,7 +6040,7 @@ struct llama_data_write {
|
||||
}
|
||||
}
|
||||
|
||||
void write_kv_cache(const struct llama_context * ctx, llama_seq_id seq_id = -1) {
|
||||
void write_kv_cache(const struct llama_context * ctx, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) {
|
||||
const struct llama_kv_cache & kv_self = ctx->kv_self;
|
||||
std::vector<std::pair<uint32_t, uint32_t>> cell_ranges; // ranges, from inclusive, to exclusive
|
||||
uint32_t cell_count = 0;
|
||||
@@ -6055,7 +6076,7 @@ struct llama_data_write {
|
||||
write(&cell_count, sizeof(cell_count));
|
||||
|
||||
write_kv_cache_meta(kv_self, cell_ranges, seq_id);
|
||||
write_kv_cache_data(ctx, cell_ranges, seq_id);
|
||||
write_kv_cache_data(ctx, cell_ranges, seq_id, flags);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -6266,10 +6287,10 @@ struct llama_data_read {
|
||||
GGML_ASSERT(sum_split_row_size == row_size);
|
||||
}
|
||||
|
||||
bool read_kv_cache_data(struct llama_context * ctx, uint32_t cell_count, llama_seq_id seq_id = -1) {
|
||||
bool read_kv_cache_data(struct llama_context * ctx, uint32_t cell_count, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) {
|
||||
const struct llama_hparams & hparams = ctx->model.hparams;
|
||||
struct llama_kv_cache & kv_self = ctx->kv_self;
|
||||
|
||||
bool need_kv = (flags & LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY) == 0;
|
||||
// v_state: 0 -> not transposed V cache
|
||||
// 1 -> transposed V cache
|
||||
// 2 -> no V cache (as it may be the case with MLA)
|
||||
@@ -6298,7 +6319,7 @@ struct llama_data_read {
|
||||
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
|
||||
const uint32_t n_embd_head_qk_rope = hparams.n_rot;
|
||||
const uint32_t kv_lora_rank = hparams.n_lora_kv;
|
||||
const bool has_k_cache = kv_self.k_l[il] != nullptr;
|
||||
const bool has_k_cache = kv_self.k_l[il] != nullptr && need_kv;
|
||||
|
||||
|
||||
// Read type of key
|
||||
@@ -6346,7 +6367,7 @@ struct llama_data_read {
|
||||
if (v_state == 0) {
|
||||
for (uint32_t il = 0; il < n_layer; ++il) {
|
||||
const uint32_t n_embd_v_gqa = llama_kv_v_row_embd(ctx->model, hparams, il);
|
||||
const bool has_v_cache = kv_self.v_l[il] != nullptr;
|
||||
const bool has_v_cache = kv_self.v_l[il] != nullptr && need_kv;
|
||||
|
||||
// Read type of value
|
||||
int32_t v_type_i_ref;
|
||||
@@ -6394,7 +6415,7 @@ struct llama_data_read {
|
||||
// For each layer, read the values for each cell (transposed)
|
||||
for (uint32_t il = 0; il < n_layer; ++il) {
|
||||
const uint32_t n_embd_v_gqa = llama_kv_v_row_embd(ctx->model, hparams, il);
|
||||
const bool has_v_cache = kv_self.v_l[il] != nullptr;
|
||||
const bool has_v_cache = kv_self.v_l[il] != nullptr && need_kv;
|
||||
|
||||
// Read type of value
|
||||
int32_t v_type_i_ref;
|
||||
@@ -6529,11 +6550,11 @@ struct llama_data_read {
|
||||
return true;
|
||||
}
|
||||
|
||||
void read_kv_cache(struct llama_context * ctx, llama_seq_id seq_id = -1) {
|
||||
void read_kv_cache(struct llama_context * ctx, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) {
|
||||
uint32_t cell_count;
|
||||
read_to(&cell_count, sizeof(cell_count));
|
||||
|
||||
bool res = read_kv_cache_meta(ctx, cell_count, seq_id) && read_kv_cache_data(ctx, cell_count, seq_id);
|
||||
bool res = read_kv_cache_meta(ctx, cell_count, seq_id) && read_kv_cache_data(ctx, cell_count, seq_id, flags);
|
||||
|
||||
if (!res) {
|
||||
if (seq_id == -1) {
|
||||
@@ -6895,41 +6916,41 @@ bool llama_state_save_file(struct llama_context * ctx, const char * path_session
|
||||
}
|
||||
}
|
||||
|
||||
static size_t llama_state_seq_get_data_internal(struct llama_context * ctx, llama_data_write & data_ctx, llama_seq_id seq_id) {
|
||||
static size_t llama_state_seq_get_data_internal(struct llama_context * ctx, llama_data_write & data_ctx, llama_seq_id seq_id, llama_state_seq_flags flags) {
|
||||
llama_synchronize(ctx);
|
||||
|
||||
data_ctx.write_kv_cache(ctx, seq_id);
|
||||
data_ctx.write_kv_cache(ctx, seq_id, flags);
|
||||
|
||||
return data_ctx.get_size_written();
|
||||
}
|
||||
|
||||
size_t llama_state_seq_get_size(struct llama_context * ctx, llama_seq_id seq_id) {
|
||||
size_t llama_state_seq_get_size(struct llama_context * ctx, llama_seq_id seq_id, llama_state_seq_flags flags) {
|
||||
llama_data_write_dummy data_ctx;
|
||||
return llama_state_seq_get_data_internal(ctx, data_ctx, seq_id);
|
||||
return llama_state_seq_get_data_internal(ctx, data_ctx, seq_id, flags);
|
||||
}
|
||||
|
||||
size_t llama_state_seq_get_data(struct llama_context * ctx, uint8_t * dst, size_t size, llama_seq_id seq_id) {
|
||||
size_t llama_state_seq_get_data(struct llama_context * ctx, uint8_t * dst, size_t size, llama_seq_id seq_id, llama_state_seq_flags flags) {
|
||||
llama_data_write_buffer data_ctx(dst, size, ctx->model);
|
||||
try {
|
||||
return llama_state_seq_get_data_internal(ctx, data_ctx, seq_id);
|
||||
return llama_state_seq_get_data_internal(ctx, data_ctx, seq_id, flags);
|
||||
} catch (const std::exception & err) {
|
||||
LLAMA_LOG_ERROR("%s: error saving sequence state: %s\n", __func__, err.what());
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
|
||||
static size_t llama_state_seq_set_data_internal(struct llama_context * ctx, llama_data_read & data_ctx, llama_seq_id dest_seq_id) {
|
||||
static size_t llama_state_seq_set_data_internal(struct llama_context * ctx, llama_data_read & data_ctx, llama_seq_id dest_seq_id, llama_state_seq_flags flags) {
|
||||
llama_synchronize(ctx);
|
||||
|
||||
data_ctx.read_kv_cache(ctx, dest_seq_id);
|
||||
data_ctx.read_kv_cache(ctx, dest_seq_id, flags);
|
||||
|
||||
return data_ctx.get_size_read();
|
||||
}
|
||||
|
||||
size_t llama_state_seq_set_data(struct llama_context * ctx, const uint8_t * src, size_t size, llama_seq_id dest_seq_id) {
|
||||
size_t llama_state_seq_set_data(struct llama_context * ctx, const uint8_t * src, size_t size, llama_seq_id dest_seq_id, llama_state_seq_flags flags) {
|
||||
llama_data_read_buffer data_ctx(src, size);
|
||||
try {
|
||||
return llama_state_seq_set_data_internal(ctx, data_ctx, dest_seq_id);
|
||||
return llama_state_seq_set_data_internal(ctx, data_ctx, dest_seq_id, flags);
|
||||
} catch (const std::exception & err) {
|
||||
LLAMA_LOG_ERROR("%s: error loading sequence state: %s\n", __func__, err.what());
|
||||
return 0;
|
||||
@@ -6948,7 +6969,7 @@ static size_t llama_state_seq_save_file_internal(struct llama_context * ctx, con
|
||||
|
||||
// save the context state using stream saving
|
||||
llama_data_write_file data_ctx(&file, ctx->model);
|
||||
llama_state_seq_get_data_internal(ctx, data_ctx, seq_id);
|
||||
llama_state_seq_get_data_internal(ctx, data_ctx, seq_id, 0);
|
||||
|
||||
const size_t res = file.tell();
|
||||
GGML_ASSERT(res == sizeof(uint32_t) * 3 + sizeof(llama_token) * n_token_count + data_ctx.get_size_written());
|
||||
@@ -6986,7 +7007,7 @@ static size_t llama_state_seq_load_file_internal(struct llama_context * ctx, con
|
||||
{
|
||||
const size_t state_size = file.size() - file.tell();
|
||||
llama_data_read_file data_ctx(&file);
|
||||
const size_t nread = llama_state_seq_set_data_internal(ctx, data_ctx, dest_seq_id);
|
||||
const size_t nread = llama_state_seq_set_data_internal(ctx, data_ctx, dest_seq_id, 0);
|
||||
if (!nread) {
|
||||
LLAMA_LOG_ERROR("%s: failed to restore sequence state\n", __func__);
|
||||
return 0;
|
||||
|
||||
Reference in New Issue
Block a user