mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-04-30 11:21:56 +00:00
server: enable checkpoint for recurrent models (#1310)
* server: enable checkpoint for recurrent models create checkpoint after cancel fix ban string and rm context during rewind add checkpoint interval only save recurrent cache * save checkpoint during pp --------- Co-authored-by: firecoperana <firecoperana>
This commit is contained in:
@@ -2041,6 +2041,16 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
|
|||||||
}
|
}
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
if (arg == "--ctx-checkpoints") {
|
||||||
|
CHECK_ARG
|
||||||
|
params.ctx_checkpoints_n = std::stoi(argv[i]);
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
if (arg == "--ctx-checkpoints-interval") {
|
||||||
|
CHECK_ARG
|
||||||
|
params.ctx_checkpoints_interval = std::stoi(argv[i]);
|
||||||
|
return true;
|
||||||
|
}
|
||||||
if (arg == "-cram" || arg == "--cache-ram") {
|
if (arg == "-cram" || arg == "--cache-ram") {
|
||||||
CHECK_ARG
|
CHECK_ARG
|
||||||
params.cache_ram_mib = std::stoi(argv[i]);
|
params.cache_ram_mib = std::stoi(argv[i]);
|
||||||
@@ -2235,7 +2245,10 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param
|
|||||||
|
|
||||||
options.push_back({ "*", "-c, --ctx-size N", "size of the prompt context (default: %d, 0 = loaded from model)", params.n_ctx });
|
options.push_back({ "*", "-c, --ctx-size N", "size of the prompt context (default: %d, 0 = loaded from model)", params.n_ctx });
|
||||||
options.push_back({ "*", "-cd, --ctx-size-draft N", "size of the prompt context for the draft model (default: %d, 0 = loaded from model)", params.speculative.n_ctx });
|
options.push_back({ "*", "-cd, --ctx-size-draft N", "size of the prompt context for the draft model (default: %d, 0 = loaded from model)", params.speculative.n_ctx });
|
||||||
options.push_back({ "*", "-cram, --cache-ram N", "set the maximum cache size in MiB (default: %d, -1 - no limit, 0 - disable)",params.cache_ram_mib });
|
|
||||||
|
options.push_back({ "*", "--ctx-checkpoints N", "max number of context checkpoints to create per slot (default: %d)",params.ctx_checkpoints_n});
|
||||||
|
options.push_back({ "*", "--ctx-checkpoints-interval N", "minimum number of tokens between each context checkpoint. (default: %d, <=0 disable)",params.ctx_checkpoints_interval});
|
||||||
|
options.push_back({ "*", "-cram, --cache-ram N", "set the maximum cache size in MiB (default: %d, -1 - no limit, 0 - disable)",params.cache_ram_mib });
|
||||||
options.push_back({ "*", "-crs, --cache-ram-similarity N", "max of similarity of prompt tokens to cache tokens that triggers prompt cache (default: %.2f).",params.cache_ram_similarity });
|
options.push_back({ "*", "-crs, --cache-ram-similarity N", "max of similarity of prompt tokens to cache tokens that triggers prompt cache (default: %.2f).",params.cache_ram_similarity });
|
||||||
options.push_back({ "*", "-cram-n-min --cache-ram-n-min N", "minimum number of the cached tokens that triggers prompt cache (default: %d).", params.cache_ram_n_min });
|
options.push_back({ "*", "-cram-n-min --cache-ram-n-min N", "minimum number of the cached tokens that triggers prompt cache (default: %d).", params.cache_ram_n_min });
|
||||||
options.push_back({ "*", "-n, --predict N", "number of tokens to predict (default: %d, -1 = infinity, -2 = until context filled)", params.n_predict });
|
options.push_back({ "*", "-n, --predict N", "number of tokens to predict (default: %d, -1 = infinity, -2 = until context filled)", params.n_predict });
|
||||||
|
|||||||
@@ -280,6 +280,8 @@ struct gpt_params {
|
|||||||
std::vector<std::string> ban_phrases; // strings that are banned in generation
|
std::vector<std::string> ban_phrases; // strings that are banned in generation
|
||||||
int32_t banned_n = 1; // number of tokens that are banned in the phrase
|
int32_t banned_n = 1; // number of tokens that are banned in the phrase
|
||||||
size_t n_buffer = 0; // number of token buffers for string ban
|
size_t n_buffer = 0; // number of token buffers for string ban
|
||||||
|
bool can_ban_phrases = true; // whether to ban strings
|
||||||
|
bool do_checkpoint = false; // do checkpoint for recurrent models only
|
||||||
|
|
||||||
std::vector<llama_model_kv_override> kv_overrides;
|
std::vector<llama_model_kv_override> kv_overrides;
|
||||||
std::vector<llama_model_tensor_buft_override> tensor_buft_overrides;
|
std::vector<llama_model_tensor_buft_override> tensor_buft_overrides;
|
||||||
@@ -418,7 +420,8 @@ struct gpt_params {
|
|||||||
|
|
||||||
float slot_prompt_similarity = 0.1f;
|
float slot_prompt_similarity = 0.1f;
|
||||||
|
|
||||||
int32_t n_ctx_checkpoints = 8; // max number of context checkpoints per slot
|
int32_t ctx_checkpoints_n = 8; // max number of context checkpoints per slot
|
||||||
|
int32_t ctx_checkpoints_interval = 512; // minimum number of tokens between each context checkpoints
|
||||||
int32_t cache_ram_mib = 8192; // -1 = no limit, 0 - disable, 1 = 1 MiB, etc.
|
int32_t cache_ram_mib = 8192; // -1 = no limit, 0 - disable, 1 = 1 MiB, etc.
|
||||||
int32_t cache_ram_n_min = 0; // min number of tokens required to save in the ram
|
int32_t cache_ram_n_min = 0; // min number of tokens required to save in the ram
|
||||||
float cache_ram_similarity = 0.5f; // similarity of tokens to cached tokens
|
float cache_ram_similarity = 0.5f; // similarity of tokens to cached tokens
|
||||||
|
|||||||
@@ -188,8 +188,8 @@ int main(int argc, char ** argv) {
|
|||||||
// save seq 0 and load into seq 1
|
// save seq 0 and load into seq 1
|
||||||
{
|
{
|
||||||
// save kv of seq 0
|
// save kv of seq 0
|
||||||
std::vector<uint8_t> seq_store(llama_state_seq_get_size(ctx3, 0));
|
std::vector<uint8_t> seq_store(llama_state_seq_get_size(ctx3, 0, 0));
|
||||||
const size_t ncopy = llama_state_seq_get_data(ctx3, seq_store.data(), seq_store.size(), 0);
|
const size_t ncopy = llama_state_seq_get_data(ctx3, seq_store.data(), seq_store.size(), 0, 0);
|
||||||
if (ncopy != seq_store.size()) {
|
if (ncopy != seq_store.size()) {
|
||||||
fprintf(stderr, "\n%s : seq copy data length %zd does not match expected length %zd\n", __func__, ncopy, seq_store.size());
|
fprintf(stderr, "\n%s : seq copy data length %zd does not match expected length %zd\n", __func__, ncopy, seq_store.size());
|
||||||
llama_free(ctx3);
|
llama_free(ctx3);
|
||||||
@@ -203,7 +203,7 @@ int main(int argc, char ** argv) {
|
|||||||
fprintf(stderr, "%s : kv cache cleared\n", __func__);
|
fprintf(stderr, "%s : kv cache cleared\n", __func__);
|
||||||
|
|
||||||
// restore kv into seq 1
|
// restore kv into seq 1
|
||||||
const size_t nset = llama_state_seq_set_data(ctx3, seq_store.data(), seq_store.size(), 1);
|
const size_t nset = llama_state_seq_set_data(ctx3, seq_store.data(), seq_store.size(), 1, 0);
|
||||||
if (nset != seq_store.size()) {
|
if (nset != seq_store.size()) {
|
||||||
fprintf(stderr, "\n%s : seq set data length %zd does not match expected length %zd\n", __func__, nset, seq_store.size());
|
fprintf(stderr, "\n%s : seq set data length %zd does not match expected length %zd\n", __func__, nset, seq_store.size());
|
||||||
llama_free(ctx3);
|
llama_free(ctx3);
|
||||||
|
|||||||
@@ -315,7 +315,7 @@ void server_context::init() {
|
|||||||
void server_slot::prompt_save(server_prompt_cache& prompt_cache) const {
|
void server_slot::prompt_save(server_prompt_cache& prompt_cache) const {
|
||||||
assert(server_cached_prompt.data.size() == 0);
|
assert(server_cached_prompt.data.size() == 0);
|
||||||
|
|
||||||
const size_t cur_size = llama_state_seq_get_size(ctx, id);
|
const size_t cur_size = llama_state_seq_get_size(ctx, id, 0);
|
||||||
|
|
||||||
LLAMA_LOG_INFO(" - saving prompt with length %d, total state size = %.3f MiB\n",
|
LLAMA_LOG_INFO(" - saving prompt with length %d, total state size = %.3f MiB\n",
|
||||||
(int)server_cached_prompt.tokens.size(), cur_size / (1024.0 * 1024.0));
|
(int)server_cached_prompt.tokens.size(), cur_size / (1024.0 * 1024.0));
|
||||||
@@ -325,7 +325,7 @@ void server_slot::prompt_save(server_prompt_cache& prompt_cache) const {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
llama_state_seq_get_data(ctx, cur->data.data(), cur_size, id);
|
llama_state_seq_get_data(ctx, cur->data.data(), cur_size, id, 0);
|
||||||
}
|
}
|
||||||
|
|
||||||
void server_slot::prompt_load(server_prompt_cache& prompt_cache, const server_tokens& tokens) {
|
void server_slot::prompt_load(server_prompt_cache& prompt_cache, const server_tokens& tokens) {
|
||||||
@@ -361,7 +361,7 @@ void server_slot::reset() {
|
|||||||
rewind_status = false;
|
rewind_status = false;
|
||||||
|
|
||||||
generated_token_probs.clear();
|
generated_token_probs.clear();
|
||||||
|
checkpoint_pos = 0;
|
||||||
|
|
||||||
// Reset speculative decoding stats
|
// Reset speculative decoding stats
|
||||||
n_draft_total = 0;
|
n_draft_total = 0;
|
||||||
@@ -1246,7 +1246,22 @@ bool server_context::launch_slot_with_task(server_slot& slot, server_task& task)
|
|||||||
slot.ban_phrases_bias = json_value(data, "banned_bias", params_base.ban_phrases_bias);
|
slot.ban_phrases_bias = json_value(data, "banned_bias", params_base.ban_phrases_bias);
|
||||||
slot.banned_n = json_value(data, "banned_n", params_base.banned_n);
|
slot.banned_n = json_value(data, "banned_n", params_base.banned_n);
|
||||||
}
|
}
|
||||||
|
if (llama_model_has_recurrent(llama_get_model(slot.ctx))) {
|
||||||
|
params_base.can_ban_phrases = false;
|
||||||
|
bool do_checkpoint = params_base.ctx_checkpoints_n > 0;
|
||||||
|
// make checkpoints only for completion tasks
|
||||||
|
do_checkpoint = do_checkpoint && task.type == SERVER_TASK_TYPE_COMPLETION;
|
||||||
|
// make a checkpoint of the parts of the memory that cannot be rolled back.
|
||||||
|
// checkpoints are created only if:
|
||||||
|
// - the model architecture is marked as recurrent or hybrid
|
||||||
|
//
|
||||||
|
// TODO: try to make this conditional on the context or the memory module, instead of the model type
|
||||||
|
// do_checkpoint = do_checkpoint && llama_model_has_recurrent(model);
|
||||||
|
params_base.do_checkpoint = do_checkpoint;
|
||||||
|
if (slot.n_buffer != 0) {
|
||||||
|
LLAMA_LOG_WARN("Recurrent model does not support banned strings.\n");
|
||||||
|
}
|
||||||
|
}
|
||||||
{
|
{
|
||||||
const auto& stop = data.find("stop");
|
const auto& stop = data.find("stop");
|
||||||
if (stop != data.end() && stop->is_array()) {
|
if (stop != data.end() && stop->is_array()) {
|
||||||
@@ -2142,7 +2157,7 @@ void server_context::process_single_task(server_task&& task) {
|
|||||||
|
|
||||||
// Erase token cache
|
// Erase token cache
|
||||||
const size_t n_erased = slot->cache_tokens.size();
|
const size_t n_erased = slot->cache_tokens.size();
|
||||||
llama_kv_cache_seq_rm(ctx, slot->id + 1, -1, -1);
|
llama_kv_cache_seq_rm(ctx, slot->id, -1, -1);
|
||||||
slot->cache_tokens.clear();
|
slot->cache_tokens.clear();
|
||||||
|
|
||||||
server_task_result result;
|
server_task_result result;
|
||||||
@@ -2552,6 +2567,7 @@ void server_context::context_shift() {
|
|||||||
|
|
||||||
void server_context::add_sampled_tokens() {
|
void server_context::add_sampled_tokens() {
|
||||||
for (auto& slot : slots) {
|
for (auto& slot : slots) {
|
||||||
|
slot.released = false;
|
||||||
if (slot.state == SLOT_STATE_IDLE) {
|
if (slot.state == SLOT_STATE_IDLE) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
@@ -2626,15 +2642,22 @@ void server_context::add_sampled_tokens() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void server_context::create_checkpoint_at_interval(server_slot & slot, const gpt_params & params_base) {
|
||||||
|
if (params_base.do_checkpoint && params_base.ctx_checkpoints_interval > 0) {
|
||||||
|
auto pos = llama_kv_cache_seq_pos_max(slot.ctx, slot.id);
|
||||||
|
if (slot.checkpoint_pos + params_base.ctx_checkpoints_interval <= 1 + pos) {
|
||||||
|
create_checkpoint(slot);
|
||||||
|
slot.checkpoint_pos = pos;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
void server_context::apply_checkpoint(server_slot & slot) {
|
void server_context::apply_checkpoint(server_slot & slot) {
|
||||||
const auto pos_min_thold = std::max(0, slot.n_past - 1);
|
const auto pos_min_thold = std::max(0, slot.n_past - 1);
|
||||||
if (!mctx && slot.n_past > 0 && slot.n_past < slot.cache_tokens.n_tokens()) {
|
if (!mctx && slot.n_past > 0 && slot.n_past < slot.cache_tokens.n_tokens()) {
|
||||||
int32_t pos_min = 0;
|
int32_t pos_min = llama_kv_cache_seq_pos_min(slot.ctx, slot.id);
|
||||||
if (llama_model_is_hybrid(llama_get_model(slot.ctx)) || llama_model_is_recurrent(llama_get_model(slot.ctx))) {
|
|
||||||
pos_min = llama_kv_cache_seq_pos_max(slot.ctx, slot.id);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (pos_min > pos_min_thold+2) {
|
if (pos_min > pos_min_thold) {
|
||||||
// TODO: support can be added in the future when corresponding vision models get released
|
// TODO: support can be added in the future when corresponding vision models get released
|
||||||
GGML_ASSERT(!slot.cache_tokens.has_mtmd);
|
GGML_ASSERT(!slot.cache_tokens.has_mtmd);
|
||||||
|
|
||||||
@@ -2654,8 +2677,9 @@ void server_context::apply_checkpoint(server_slot & slot) {
|
|||||||
|
|
||||||
if (!do_reset) {
|
if (!do_reset) {
|
||||||
// restore the context checkpoint
|
// restore the context checkpoint
|
||||||
|
const int64_t t_start = ggml_time_us();
|
||||||
const size_t checkpoint_size = it->data.size();
|
const size_t checkpoint_size = it->data.size();
|
||||||
const size_t n = llama_state_seq_set_data(ctx, it->data.data(), checkpoint_size, slot.id);
|
const size_t n = llama_state_seq_set_data(ctx, it->data.data(), checkpoint_size, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
|
||||||
|
|
||||||
if (n != checkpoint_size) {
|
if (n != checkpoint_size) {
|
||||||
SLT_ERR(slot, "failed to restore context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", it->pos_min, it->pos_max, (float)checkpoint_size / 1024 / 1024);
|
SLT_ERR(slot, "failed to restore context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", it->pos_min, it->pos_max, (float)checkpoint_size / 1024 / 1024);
|
||||||
@@ -2663,7 +2687,8 @@ void server_context::apply_checkpoint(server_slot & slot) {
|
|||||||
//printf("[DEBUG] `do_reset` was set to `true` after failing to restore a checkpoint");
|
//printf("[DEBUG] `do_reset` was set to `true` after failing to restore a checkpoint");
|
||||||
} else {
|
} else {
|
||||||
slot.n_past = std::min(slot.n_past, std::max(it->pos_min + 1, it->pos_max));
|
slot.n_past = std::min(slot.n_past, std::max(it->pos_min + 1, it->pos_max));
|
||||||
SLT_WRN(slot, "restored context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", it->pos_min, it->pos_max, (float)checkpoint_size / 1024 / 1024);
|
slot.n_past_prompt = std::min(slot.n_past_prompt, std::max(it->pos_min_prompt + 1, it->pos_max_prompt));
|
||||||
|
SLT_WRN(slot, "restored context checkpoint took %.2f ms (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", (ggml_time_us() - t_start) / 1000.0, it->pos_min, it->pos_max, (float)checkpoint_size / 1024 / 1024);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -2691,56 +2716,44 @@ void server_context::apply_checkpoint(server_slot & slot) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void server_context::create_checkpoint(server_slot & slot) {
|
void server_context::create_checkpoint(server_slot & slot) {
|
||||||
//bool do_checkpoint = params_base.n_ctx_checkpoints > 0;
|
bool do_checkpoint = true;
|
||||||
|
int32_t pos_min = llama_kv_cache_seq_pos_min(slot.ctx, slot.id);
|
||||||
|
const auto pos_max = llama_kv_cache_seq_pos_max(slot.ctx, slot.id);
|
||||||
|
|
||||||
//// make checkpoints only for completion tasks
|
// no need for empty or small checkpoints
|
||||||
//do_checkpoint = do_checkpoint && slot.task->type == SERVER_TASK_TYPE_COMPLETION;
|
do_checkpoint = do_checkpoint && (pos_min >= 0 && pos_max >= 16);
|
||||||
|
|
||||||
//// make a checkpoint of the parts of the memory that cannot be rolled back.
|
// no need to create checkpoints that are too close together
|
||||||
//// checkpoints are created only if:
|
do_checkpoint = do_checkpoint && (slot.server_cached_prompt.checkpoints.empty() || pos_max > slot.server_cached_prompt.checkpoints.back().pos_max);
|
||||||
//// - the model architecture is marked as recurrent or hybrid
|
|
||||||
////
|
|
||||||
//// TODO: try to make this conditional on the context or the memory module, instead of the model type
|
|
||||||
//do_checkpoint = do_checkpoint && (
|
|
||||||
// llama_model_is_recurrent(model) ||
|
|
||||||
// llama_model_is_hybrid(model)
|
|
||||||
// );
|
|
||||||
//int32_t pos_min = 0;
|
|
||||||
//if (llama_model_is_recurrent(model) || llama_model_is_hybrid(model)) {
|
|
||||||
// pos_min = llama_kv_cache_seq_pos_max(slot.ctx, slot.id);
|
|
||||||
//}
|
|
||||||
//const auto pos_max = llama_kv_cache_seq_pos_max(slot.ctx, slot.id);
|
|
||||||
|
|
||||||
//// no need for empty or small checkpoints
|
if (do_checkpoint) {
|
||||||
//do_checkpoint = do_checkpoint && (pos_min >= 0 && pos_max >= 5);
|
const int64_t t_start = ggml_time_us();
|
||||||
|
while (slot.server_cached_prompt.checkpoints.size() >= (size_t)params_base.ctx_checkpoints_n) {
|
||||||
|
// make room for the new checkpoint, if needed
|
||||||
|
const auto & cur = slot.server_cached_prompt.checkpoints.front();
|
||||||
|
|
||||||
//// no need to create checkpoints that are too close together
|
SLT_WRN(slot, "erasing old context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n",
|
||||||
//do_checkpoint = do_checkpoint && (slot.server_cached_prompt.checkpoints.empty() || pos_max > slot.server_cached_prompt.checkpoints.back().pos_max + 64);
|
cur.pos_min, cur.pos_max, (float)cur.data.size() / 1024 / 1024);
|
||||||
|
|
||||||
//if (do_checkpoint) {
|
slot.server_cached_prompt.checkpoints.erase(slot.server_cached_prompt.checkpoints.begin());
|
||||||
// while (slot.server_cached_prompt.checkpoints.size() >= (size_t)params_base.n_ctx_checkpoints) {
|
}
|
||||||
// // make room for the new checkpoint, if needed
|
|
||||||
// const auto & cur = slot.server_cached_prompt.checkpoints.front();
|
|
||||||
|
|
||||||
// SLT_WRN(slot, "erasing old context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n",
|
const size_t checkpoint_size = llama_state_seq_get_size(ctx, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
|
||||||
// cur.pos_min, cur.pos_max, (float)cur.data.size() / 1024 / 1024);
|
|
||||||
|
|
||||||
// slot.server_cached_prompt.checkpoints.erase(slot.server_cached_prompt.checkpoints.begin());
|
auto & cur = slot.server_cached_prompt.checkpoints.emplace_back(server_prompt_checkpoint{
|
||||||
// }
|
/*.pos_min = */ pos_min,
|
||||||
|
/*.pos_max = */ pos_max,
|
||||||
|
/*.pos_min_prompt = */ pos_min + slot.n_past_offset,
|
||||||
|
/*.pos_max_prompt = */ pos_max + slot.n_past_offset ,
|
||||||
|
/*.data = */ std::vector<uint8_t>(checkpoint_size),
|
||||||
|
});
|
||||||
|
|
||||||
// const size_t checkpoint_size = llama_state_seq_get_size(ctx, slot.id);
|
llama_state_seq_get_data(ctx, cur.data.data(), checkpoint_size, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
|
||||||
|
|
||||||
// auto & cur = slot.server_cached_prompt.checkpoints.emplace_back(server_prompt_checkpoint{
|
SLT_WRN(slot, "created context checkpoint %d of %d (pos_min = %d, pos_max = %d, size = %.3f MiB, took %.2f ms)\n",
|
||||||
// /*.pos_min = */ pos_min,
|
(int)slot.server_cached_prompt.checkpoints.size(), params_base.ctx_checkpoints_n, cur.pos_min, cur.pos_max, (float)cur.data.size() / 1024 / 1024,
|
||||||
// /*.pos_max = */ pos_max,
|
(ggml_time_us() - t_start) / 1000.0);
|
||||||
// /*.data = */ std::vector<uint8_t>(checkpoint_size),
|
}
|
||||||
// });
|
|
||||||
|
|
||||||
// llama_state_seq_get_data(ctx, cur.data.data(), checkpoint_size, slot.id);
|
|
||||||
|
|
||||||
// SLT_WRN(slot, "created context checkpoint %d of %d (pos_min = %d, pos_max = %d, size = %.3f MiB)\n",
|
|
||||||
// (int)slot.server_cached_prompt.checkpoints.size(), params_base.n_ctx_checkpoints, cur.pos_min, cur.pos_max, (float)cur.data.size() / 1024 / 1024);
|
|
||||||
//}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void server_context::batch_pending_prompt(const int32_t n_ubatch, const int32_t n_batch, int32_t & batch_type) {
|
void server_context::batch_pending_prompt(const int32_t n_ubatch, const int32_t n_batch, int32_t & batch_type) {
|
||||||
@@ -2798,8 +2811,6 @@ void server_context::batch_pending_prompt(const int32_t n_ubatch, const int32_t
|
|||||||
}
|
}
|
||||||
|
|
||||||
slot.n_past = 0;
|
slot.n_past = 0;
|
||||||
slot.n_buffer = 0;
|
|
||||||
slot.token_buffer.clear();
|
|
||||||
slot.n_prompt_tokens = prompt_tokens.size();
|
slot.n_prompt_tokens = prompt_tokens.size();
|
||||||
|
|
||||||
LOG_VERBOSE("prompt tokenized", {
|
LOG_VERBOSE("prompt tokenized", {
|
||||||
@@ -2900,6 +2911,8 @@ void server_context::batch_pending_prompt(const int32_t n_ubatch, const int32_t
|
|||||||
}
|
}
|
||||||
slot.n_past = prefix.first;
|
slot.n_past = prefix.first;
|
||||||
slot.n_past_prompt = prefix.second;
|
slot.n_past_prompt = prefix.second;
|
||||||
|
slot.n_past_offset = slot.n_past_prompt - slot.n_past;
|
||||||
|
|
||||||
if (slot.n_past != slot.n_past_prompt) {
|
if (slot.n_past != slot.n_past_prompt) {
|
||||||
LLAMA_LOG_INFO("Mistokenization found and handled successfully.\n");
|
LLAMA_LOG_INFO("Mistokenization found and handled successfully.\n");
|
||||||
}
|
}
|
||||||
@@ -3074,8 +3087,6 @@ void server_context::batch_pending_prompt(const int32_t n_ubatch, const int32_t
|
|||||||
slot.n_decoded = 0;
|
slot.n_decoded = 0;
|
||||||
slot.i_batch = batch.n_tokens - 1;
|
slot.i_batch = batch.n_tokens - 1;
|
||||||
|
|
||||||
//create_checkpoint(slot);
|
|
||||||
|
|
||||||
LOG_VERBOSE("prompt done", {
|
LOG_VERBOSE("prompt done", {
|
||||||
{"id_slot", slot.id},
|
{"id_slot", slot.id},
|
||||||
{"n_past", slot.n_past},
|
{"n_past", slot.n_past},
|
||||||
@@ -3187,14 +3198,11 @@ void server_context::speculative_decoding_accept() {
|
|||||||
populate_token_probs(slot, result, slot.params.post_sampling_probs, params_base.special, i);
|
populate_token_probs(slot, result, slot.params.post_sampling_probs, params_base.special, i);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (slot.n_buffer == 0 || llama_model_is_hybrid(llama_get_model(slot.ctx)) || llama_model_is_recurrent(llama_get_model(slot.ctx))) {
|
if (slot.n_buffer == 0 || !params_base.can_ban_phrases) {
|
||||||
if (!process_token(result, slot)) {
|
if (!process_token(result, slot)) {
|
||||||
// release slot because of stop condition
|
// release slot because of stop condition
|
||||||
send_final_response(slot);
|
send_final_response(slot);
|
||||||
//create_checkpoint(slot);
|
release_slot_after_final_response(slot);
|
||||||
slot.release();
|
|
||||||
slot.print_timings();
|
|
||||||
metrics.on_prediction(slot);
|
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
@@ -3218,6 +3226,15 @@ bool server_context::accept_special_token(const server_slot& slot, const llama_
|
|||||||
return params_base.special || slot.sparams.preserved_tokens.find(token) != slot.sparams.preserved_tokens.end();
|
return params_base.special || slot.sparams.preserved_tokens.find(token) != slot.sparams.preserved_tokens.end();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void server_context::release_slot_after_final_response(server_slot & slot) {
|
||||||
|
slot.print_timings();
|
||||||
|
if (params_base.do_checkpoint) {
|
||||||
|
create_checkpoint(slot);
|
||||||
|
}
|
||||||
|
slot.release();
|
||||||
|
slot.released = true;
|
||||||
|
metrics.on_prediction(slot);
|
||||||
|
}
|
||||||
|
|
||||||
void server_context::send_token_results(completion_token_outputs& results, server_slot& slot, int32_t n) {
|
void server_context::send_token_results(completion_token_outputs& results, server_slot& slot, int32_t n) {
|
||||||
int count = 0;
|
int count = 0;
|
||||||
@@ -3226,10 +3243,7 @@ void server_context::send_token_results(completion_token_outputs& results, serve
|
|||||||
count++;
|
count++;
|
||||||
if (!has_next) {
|
if (!has_next) {
|
||||||
send_final_response(slot);
|
send_final_response(slot);
|
||||||
//create_checkpoint(slot);
|
release_slot_after_final_response(slot);
|
||||||
slot.release();
|
|
||||||
slot.print_timings();
|
|
||||||
metrics.on_prediction(slot);
|
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
if (n > 0 && count >= n) {
|
if (n > 0 && count >= n) {
|
||||||
@@ -3266,7 +3280,7 @@ inline int32_t check_ban_phrase(const server_slot& slot) {
|
|||||||
}
|
}
|
||||||
if (found) {
|
if (found) {
|
||||||
std::vector<size_t> unused;
|
std::vector<size_t> unused;
|
||||||
LLAMA_LOG_DEBUG("Banned string dectected: %s\n ", string_buffer.substr(start).c_str());
|
LLAMA_LOG_DEBUG("Banned string dectected: %s\n", string_buffer.substr(start).c_str());
|
||||||
n = find_n_tokens_from_string(slot.ctx, tokens, start, 0, unused);
|
n = find_n_tokens_from_string(slot.ctx, tokens, start, 0, unused);
|
||||||
n_rewind = (int32_t) slot.token_buffer.size() - (int32_t) n;
|
n_rewind = (int32_t) slot.token_buffer.size() - (int32_t) n;
|
||||||
}
|
}
|
||||||
@@ -3299,6 +3313,8 @@ inline void rewind_context(server_slot& slot, int32_t n_rewind) {
|
|||||||
size_t n_keep = slot.cache_tokens.size() - n_rewind;
|
size_t n_keep = slot.cache_tokens.size() - n_rewind;
|
||||||
slot.sampled = slot.cache_tokens[n_keep];
|
slot.sampled = slot.cache_tokens[n_keep];
|
||||||
slot.cache_tokens.keep_first(n_keep);
|
slot.cache_tokens.keep_first(n_keep);
|
||||||
|
llama_kv_cache_seq_rm(slot.ctx, slot.id, n_keep, -1);
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void server_context::buffer_and_check_string_ban(server_slot & slot, completion_token_output & result) {
|
void server_context::buffer_and_check_string_ban(server_slot & slot, completion_token_output & result) {
|
||||||
@@ -3397,6 +3413,9 @@ void server_context::process_batch_tokens(int32_t & n_batch) {
|
|||||||
|
|
||||||
for (auto& slot : slots) {
|
for (auto& slot : slots) {
|
||||||
if (slot.state != SLOT_STATE_PROCESSING || slot.i_batch < (int)i || slot.i_batch >= (int)(i + n_tokens)) {
|
if (slot.state != SLOT_STATE_PROCESSING || slot.i_batch < (int)i || slot.i_batch >= (int)(i + n_tokens)) {
|
||||||
|
if (slot.command == SLOT_COMMAND_LOAD_PROMPT) {
|
||||||
|
create_checkpoint_at_interval(slot, params_base);
|
||||||
|
}
|
||||||
continue; // continue loop of slots
|
continue; // continue loop of slots
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -3440,6 +3459,14 @@ void server_context::process_batch_tokens(int32_t & n_batch) {
|
|||||||
slot.t_start_generation = ggml_time_us();
|
slot.t_start_generation = ggml_time_us();
|
||||||
slot.t_prompt_processing = (slot.t_start_generation - slot.t_start_process_prompt) / 1e3;
|
slot.t_prompt_processing = (slot.t_start_generation - slot.t_start_process_prompt) / 1e3;
|
||||||
metrics.on_prompt_eval(slot);
|
metrics.on_prompt_eval(slot);
|
||||||
|
if (params_base.do_checkpoint) {
|
||||||
|
create_checkpoint(slot);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// save checkpoint during generation
|
||||||
|
if (slot.n_decoded > 1) {
|
||||||
|
create_checkpoint_at_interval(slot, params_base);
|
||||||
}
|
}
|
||||||
|
|
||||||
slot.t_token_generation = std::max<int64_t>(1, t_current - slot.t_start_generation) / 1e3;
|
slot.t_token_generation = std::max<int64_t>(1, t_current - slot.t_start_generation) / 1e3;
|
||||||
@@ -3452,7 +3479,7 @@ void server_context::process_batch_tokens(int32_t & n_batch) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// no ban string for recurrent/hybrid model
|
// no ban string for recurrent/hybrid model
|
||||||
if (slot.n_buffer == 0 || llama_model_is_hybrid(llama_get_model(slot.ctx)) || llama_model_is_recurrent(llama_get_model(slot.ctx))) {
|
if (slot.n_buffer == 0 || !params_base.can_ban_phrases) {
|
||||||
slot.token_buffer = { result };
|
slot.token_buffer = { result };
|
||||||
send_token_results(slot.token_buffer, slot);
|
send_token_results(slot.token_buffer, slot);
|
||||||
} else {
|
} else {
|
||||||
@@ -3503,7 +3530,7 @@ void server_context::update_slots() {
|
|||||||
// apply context-shift if needed
|
// apply context-shift if needed
|
||||||
// TODO: simplify and improve
|
// TODO: simplify and improve
|
||||||
context_shift();
|
context_shift();
|
||||||
|
|
||||||
// start populating the batch for this iteration
|
// start populating the batch for this iteration
|
||||||
common_batch_clear(batch);
|
common_batch_clear(batch);
|
||||||
|
|
||||||
|
|||||||
@@ -32,6 +32,7 @@ struct server_slot {
|
|||||||
llama_batch batch_spec = {};
|
llama_batch batch_spec = {};
|
||||||
llama_context * ctx_dft = nullptr;
|
llama_context * ctx_dft = nullptr;
|
||||||
|
|
||||||
|
bool released = false;
|
||||||
slot_state state = SLOT_STATE_IDLE;
|
slot_state state = SLOT_STATE_IDLE;
|
||||||
slot_command command = SLOT_COMMAND_NONE;
|
slot_command command = SLOT_COMMAND_NONE;
|
||||||
|
|
||||||
@@ -45,6 +46,7 @@ struct server_slot {
|
|||||||
int32_t n_ctx = 0; // context size per slot
|
int32_t n_ctx = 0; // context size per slot
|
||||||
int32_t n_past = 0;
|
int32_t n_past = 0;
|
||||||
int32_t n_past_prompt = 0;
|
int32_t n_past_prompt = 0;
|
||||||
|
int32_t n_past_offset = 0;
|
||||||
int32_t n_decoded = 0;
|
int32_t n_decoded = 0;
|
||||||
int32_t n_remaining = -1;
|
int32_t n_remaining = -1;
|
||||||
int32_t n_discarded_prompt = 0;
|
int32_t n_discarded_prompt = 0;
|
||||||
@@ -102,6 +104,8 @@ struct server_slot {
|
|||||||
|
|
||||||
void prompt_load(server_prompt_cache& prompt_cache, const server_tokens& tokens);
|
void prompt_load(server_prompt_cache& prompt_cache, const server_tokens& tokens);
|
||||||
|
|
||||||
|
size_t checkpoint_pos = 0;
|
||||||
|
|
||||||
// sampling
|
// sampling
|
||||||
llama_token sampled; // in speculative mode, this is the last accepted token
|
llama_token sampled; // in speculative mode, this is the last accepted token
|
||||||
llama_tokens drafted;
|
llama_tokens drafted;
|
||||||
@@ -355,4 +359,8 @@ struct server_context {
|
|||||||
void create_checkpoint(server_slot & slot);
|
void create_checkpoint(server_slot & slot);
|
||||||
|
|
||||||
void apply_checkpoint(server_slot & slot);
|
void apply_checkpoint(server_slot & slot);
|
||||||
|
|
||||||
|
void create_checkpoint_at_interval(server_slot & slot, const gpt_params & params_base);
|
||||||
|
|
||||||
|
void release_slot_after_final_response(server_slot & slot);
|
||||||
};
|
};
|
||||||
|
|||||||
@@ -1117,7 +1117,7 @@ bool server_prompt_cache::load(server_prompt& prompt, const server_tokens& token
|
|||||||
if (it_best != states.end()) {
|
if (it_best != states.end()) {
|
||||||
LLAMA_LOG_INFO(" - found better prompt with f_keep = %.3f, sim = %.3f, n_keep = %d, n_discarded_prompt = %d\n", f_keep_best, sim_best, it_best->n_kept_prompt, it_best->n_discarded_prompt);
|
LLAMA_LOG_INFO(" - found better prompt with f_keep = %.3f, sim = %.3f, n_keep = %d, n_discarded_prompt = %d\n", f_keep_best, sim_best, it_best->n_kept_prompt, it_best->n_discarded_prompt);
|
||||||
const size_t size = it_best->data.size();
|
const size_t size = it_best->data.size();
|
||||||
const size_t n = llama_state_seq_set_data(ctx, it_best->data.data(), size, id_slot);
|
const size_t n = llama_state_seq_set_data(ctx, it_best->data.data(), size, id_slot, 0);
|
||||||
if (n != size) {
|
if (n != size) {
|
||||||
LLAMA_LOG_INFO("failed to restore state with size %zu\n", size);
|
LLAMA_LOG_INFO("failed to restore state with size %zu\n", size);
|
||||||
return false;
|
return false;
|
||||||
|
|||||||
@@ -344,6 +344,8 @@ using server_task_result_ptr = std::unique_ptr<server_task_result>;
|
|||||||
struct server_prompt_checkpoint {
|
struct server_prompt_checkpoint {
|
||||||
llama_pos pos_min;
|
llama_pos pos_min;
|
||||||
llama_pos pos_max;
|
llama_pos pos_max;
|
||||||
|
llama_pos pos_min_prompt;
|
||||||
|
llama_pos pos_max_prompt;
|
||||||
|
|
||||||
std::vector<uint8_t> data;
|
std::vector<uint8_t> data;
|
||||||
|
|
||||||
|
|||||||
@@ -645,6 +645,8 @@ extern "C" {
|
|||||||
// Returns true if the model is hybrid (like Jamba, Granite, etc.)
|
// Returns true if the model is hybrid (like Jamba, Granite, etc.)
|
||||||
LLAMA_API bool llama_model_is_hybrid(const struct llama_model * model);
|
LLAMA_API bool llama_model_is_hybrid(const struct llama_model * model);
|
||||||
|
|
||||||
|
LLAMA_API bool llama_model_has_recurrent(const struct llama_model * model);
|
||||||
|
|
||||||
// Returns 0 on success
|
// Returns 0 on success
|
||||||
LLAMA_API uint32_t llama_model_quantize(
|
LLAMA_API uint32_t llama_model_quantize(
|
||||||
const char * fname_inp,
|
const char * fname_inp,
|
||||||
@@ -735,6 +737,11 @@ extern "C" {
|
|||||||
llama_seq_id * cells_sequences;
|
llama_seq_id * cells_sequences;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// work only with partial states, such as recurrent cache (e.g. Mamba)
|
||||||
|
#define LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY 1
|
||||||
|
|
||||||
|
typedef uint32_t llama_state_seq_flags;
|
||||||
|
|
||||||
// Create an empty KV cache view. (use only for debugging purposes)
|
// Create an empty KV cache view. (use only for debugging purposes)
|
||||||
LLAMA_API struct llama_kv_cache_view llama_kv_cache_view_init(const struct llama_context * ctx, int32_t n_seq_max);
|
LLAMA_API struct llama_kv_cache_view llama_kv_cache_view_init(const struct llama_context * ctx, int32_t n_seq_max);
|
||||||
|
|
||||||
@@ -813,6 +820,11 @@ extern "C" {
|
|||||||
struct llama_context * ctx,
|
struct llama_context * ctx,
|
||||||
llama_seq_id seq_id);
|
llama_seq_id seq_id);
|
||||||
|
|
||||||
|
// Returns the smallest position present in the KV cache for the specified sequence
|
||||||
|
LLAMA_API llama_pos llama_kv_cache_seq_pos_min(
|
||||||
|
struct llama_context * ctx,
|
||||||
|
llama_seq_id seq_id);
|
||||||
|
|
||||||
// Defragment the KV cache
|
// Defragment the KV cache
|
||||||
// This will be applied:
|
// This will be applied:
|
||||||
// - lazily on next llama_decode()
|
// - lazily on next llama_decode()
|
||||||
@@ -889,14 +901,16 @@ extern "C" {
|
|||||||
// Get the exact size needed to copy the KV cache of a single sequence
|
// Get the exact size needed to copy the KV cache of a single sequence
|
||||||
LLAMA_API size_t llama_state_seq_get_size(
|
LLAMA_API size_t llama_state_seq_get_size(
|
||||||
struct llama_context * ctx,
|
struct llama_context * ctx,
|
||||||
llama_seq_id seq_id);
|
llama_seq_id seq_id,
|
||||||
|
llama_state_seq_flags flags);
|
||||||
|
|
||||||
// Copy the KV cache of a single sequence into the specified buffer
|
// Copy the KV cache of a single sequence into the specified buffer
|
||||||
LLAMA_API size_t llama_state_seq_get_data(
|
LLAMA_API size_t llama_state_seq_get_data(
|
||||||
struct llama_context * ctx,
|
struct llama_context * ctx,
|
||||||
uint8_t * dst,
|
uint8_t * dst,
|
||||||
size_t size,
|
size_t size,
|
||||||
llama_seq_id seq_id);
|
llama_seq_id seq_id,
|
||||||
|
llama_state_seq_flags flags);
|
||||||
|
|
||||||
// Copy the sequence data (originally copied with `llama_state_seq_get_data`) into the specified sequence
|
// Copy the sequence data (originally copied with `llama_state_seq_get_data`) into the specified sequence
|
||||||
// Returns:
|
// Returns:
|
||||||
@@ -906,7 +920,8 @@ extern "C" {
|
|||||||
struct llama_context * ctx,
|
struct llama_context * ctx,
|
||||||
const uint8_t * src,
|
const uint8_t * src,
|
||||||
size_t size,
|
size_t size,
|
||||||
llama_seq_id dest_seq_id);
|
llama_seq_id dest_seq_id,
|
||||||
|
llama_state_seq_flags flags);
|
||||||
|
|
||||||
LLAMA_API size_t llama_state_seq_save_file(
|
LLAMA_API size_t llama_state_seq_save_file(
|
||||||
struct llama_context * ctx,
|
struct llama_context * ctx,
|
||||||
|
|||||||
@@ -142,7 +142,7 @@ ggml_cgraph * llm_build_context::build_k_shift() {
|
|||||||
ggml_set_input(lctx.inp_K_shift);
|
ggml_set_input(lctx.inp_K_shift);
|
||||||
|
|
||||||
for (int il = 0; il < n_layer; ++il) {
|
for (int il = 0; il < n_layer; ++il) {
|
||||||
if ((model.arch == LLM_ARCH_QWEN3NEXT || model.arch == LLM_ARCH_QWEN35MOE) && hparams.is_recurrent(il)) {
|
if (llm_arch_is_hybrid(model.arch) && hparams.is_recurrent(il)) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
if (kv_self.k_l[il] == nullptr) {
|
if (kv_self.k_l[il] == nullptr) {
|
||||||
@@ -241,7 +241,7 @@ ggml_cgraph * llm_build_context::build_defrag(const std::vector<uint32_t> & ids)
|
|||||||
}
|
}
|
||||||
|
|
||||||
for (int il = 0; il < n_layer; ++il) {
|
for (int il = 0; il < n_layer; ++il) {
|
||||||
if ((model.arch == LLM_ARCH_QWEN3NEXT || model.arch == LLM_ARCH_QWEN35MOE) && hparams.is_recurrent(il)) {
|
if (llm_arch_is_hybrid(model.arch) && hparams.is_recurrent(il)) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
if (kv_self.k_l[il] == nullptr) {
|
if (kv_self.k_l[il] == nullptr) {
|
||||||
|
|||||||
@@ -1739,3 +1739,7 @@ bool llama_model_is_recurrent(const llama_model * model) {
|
|||||||
bool llama_model_is_hybrid(const llama_model * model) {
|
bool llama_model_is_hybrid(const llama_model * model) {
|
||||||
return llm_arch_is_hybrid(model->arch);
|
return llm_arch_is_hybrid(model->arch);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool llama_model_has_recurrent(const llama_model * model) {
|
||||||
|
return llm_arch_is_hybrid(model->arch) || llm_arch_is_recurrent(model->arch);
|
||||||
|
}
|
||||||
|
|||||||
@@ -671,7 +671,7 @@ static inline uint32_t llama_kv_v_row_embd(
|
|||||||
uint32_t il) {
|
uint32_t il) {
|
||||||
// qwen3next recurrent state is stored in a dedicated V-cache tail (per sequence),
|
// qwen3next recurrent state is stored in a dedicated V-cache tail (per sequence),
|
||||||
// so per-token V rows include only attention values.
|
// so per-token V rows include only attention values.
|
||||||
if (model.arch == LLM_ARCH_QWEN3NEXT || model.arch == LLM_ARCH_QWEN35MOE) {
|
if (llm_arch_is_hybrid(model.arch)) {
|
||||||
return hparams.n_embd_v_gqa(il);
|
return hparams.n_embd_v_gqa(il);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -732,7 +732,7 @@ static bool llama_kv_cache_init(
|
|||||||
cache.hybrid = llm_arch_is_hybrid(model.arch);
|
cache.hybrid = llm_arch_is_hybrid(model.arch);
|
||||||
// qwen3next uses hybrid recurrent+attention cache semantics. Keep V rows in
|
// qwen3next uses hybrid recurrent+attention cache semantics. Keep V rows in
|
||||||
// standard layout to match the mainline hybrid path when flash attention is off.
|
// standard layout to match the mainline hybrid path when flash attention is off.
|
||||||
cache.v_trans = !cache.recurrent && !cparams.flash_attn && model.arch != LLM_ARCH_QWEN3NEXT && model.arch != LLM_ARCH_QWEN35MOE;
|
cache.v_trans = !cache.recurrent && !cparams.flash_attn && !llm_arch_is_hybrid(model.arch);
|
||||||
|
|
||||||
cache.head = 0;
|
cache.head = 0;
|
||||||
cache.size = kv_size;
|
cache.size = kv_size;
|
||||||
@@ -744,7 +744,7 @@ static bool llama_kv_cache_init(
|
|||||||
cache.cells.clear();
|
cache.cells.clear();
|
||||||
cache.cells.resize(kv_size);
|
cache.cells.resize(kv_size);
|
||||||
|
|
||||||
if (cache.recurrent || model.arch == LLM_ARCH_QWEN3NEXT || model.arch == LLM_ARCH_QWEN35MOE) {
|
if (cache.recurrent || llm_arch_is_hybrid(model.arch)) {
|
||||||
// init state copy sources
|
// init state copy sources
|
||||||
for (uint32_t i = 0; i < cache.size; ++i) {
|
for (uint32_t i = 0; i < cache.size; ++i) {
|
||||||
cache.cells[i].src = i;
|
cache.cells[i].src = i;
|
||||||
@@ -829,7 +829,7 @@ static bool llama_kv_cache_init(
|
|||||||
std::vector<size_t> mem_split(model.splits.size(), 0);
|
std::vector<size_t> mem_split(model.splits.size(), 0);
|
||||||
|
|
||||||
const uint32_t qnext_state_slots = llama_qwen3next_state_slots(cparams, kv_size);
|
const uint32_t qnext_state_slots = llama_qwen3next_state_slots(cparams, kv_size);
|
||||||
if ((model.arch == LLM_ARCH_QWEN3NEXT || model.arch == LLM_ARCH_QWEN35MOE) && qnext_state_slots < std::max<uint32_t>(1, cparams.n_seq_max)) {
|
if (llm_arch_is_hybrid(model.arch) && qnext_state_slots < std::max<uint32_t>(1, cparams.n_seq_max)) {
|
||||||
LLAMA_LOG_WARN("%s: reducing qwen3next state slots from %u to %u to fit KV cache size\n",
|
LLAMA_LOG_WARN("%s: reducing qwen3next state slots from %u to %u to fit KV cache size\n",
|
||||||
__func__, std::max<uint32_t>(1, cparams.n_seq_max), qnext_state_slots);
|
__func__, std::max<uint32_t>(1, cparams.n_seq_max), qnext_state_slots);
|
||||||
}
|
}
|
||||||
@@ -1398,6 +1398,19 @@ static llama_pos llama_kv_cache_seq_pos_max(struct llama_kv_cache & cache, llama
|
|||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static llama_pos llama_kv_cache_seq_pos_min(struct llama_kv_cache & cache, llama_seq_id seq_id) {
|
||||||
|
llama_pos result = -1;
|
||||||
|
|
||||||
|
for (uint32_t i = 0; i < cache.size; ++i) {
|
||||||
|
if (cache.cells[i].has_seq_id(seq_id)) {
|
||||||
|
result = cache.cells[i].pos;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
static void llama_kv_cache_defrag(struct llama_kv_cache & cache) {
|
static void llama_kv_cache_defrag(struct llama_kv_cache & cache) {
|
||||||
cache.do_defrag = true;
|
cache.do_defrag = true;
|
||||||
}
|
}
|
||||||
@@ -3227,7 +3240,7 @@ static int llama_decode_internal(
|
|||||||
auto tim1 = ggml_time_us();
|
auto tim1 = ggml_time_us();
|
||||||
#endif
|
#endif
|
||||||
uint32_t n_tokens = std::min(n_ubatch, n_tokens_all - cur_token);
|
uint32_t n_tokens = std::min(n_ubatch, n_tokens_all - cur_token);
|
||||||
if ((model.arch == LLM_ARCH_QWEN3NEXT || model.arch == LLM_ARCH_QWEN35MOE) &&
|
if (llm_arch_is_hybrid(model.arch) &&
|
||||||
n_tokens > 1 &&
|
n_tokens > 1 &&
|
||||||
batch_all.n_seq_id != nullptr &&
|
batch_all.n_seq_id != nullptr &&
|
||||||
batch_all.seq_id != nullptr) {
|
batch_all.seq_id != nullptr) {
|
||||||
@@ -5735,6 +5748,13 @@ void llama_kv_cache_seq_div(struct llama_context * ctx, llama_seq_id seq_id, lla
|
|||||||
llama_kv_cache_seq_div(ctx->kv_self, seq_id, p0, p1, d);
|
llama_kv_cache_seq_div(ctx->kv_self, seq_id, p0, p1, d);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
llama_pos llama_kv_cache_seq_pos_min(struct llama_context * ctx, llama_seq_id seq_id) {
|
||||||
|
if (ctx->kv_self.hybrid || ctx->kv_self.recurrent) {
|
||||||
|
return llama_kv_cache_seq_pos_max(ctx->kv_self, seq_id);
|
||||||
|
}
|
||||||
|
return llama_kv_cache_seq_pos_min(ctx->kv_self, seq_id);
|
||||||
|
}
|
||||||
|
|
||||||
llama_pos llama_kv_cache_seq_pos_max(struct llama_context * ctx, llama_seq_id seq_id) {
|
llama_pos llama_kv_cache_seq_pos_max(struct llama_context * ctx, llama_seq_id seq_id) {
|
||||||
return llama_kv_cache_seq_pos_max(ctx->kv_self, seq_id);
|
return llama_kv_cache_seq_pos_max(ctx->kv_self, seq_id);
|
||||||
}
|
}
|
||||||
@@ -5876,10 +5896,11 @@ struct llama_data_write {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void write_kv_cache_data(const struct llama_context * ctx, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id = -1) {
|
void write_kv_cache_data(const struct llama_context * ctx, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id = -1,
|
||||||
|
llama_state_seq_flags flags = 0) {
|
||||||
const struct llama_kv_cache & kv_self = ctx->kv_self;
|
const struct llama_kv_cache & kv_self = ctx->kv_self;
|
||||||
const struct llama_hparams & hparams = ctx->model.hparams;
|
const struct llama_hparams & hparams = ctx->model.hparams;
|
||||||
|
bool need_kv = (flags & LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY) == 0;
|
||||||
// v_state: 0 -> not transposed V cache
|
// v_state: 0 -> not transposed V cache
|
||||||
// 1 -> transposed V cache
|
// 1 -> transposed V cache
|
||||||
// 2 -> no V cache (as it may be the case with MLA)
|
// 2 -> no V cache (as it may be the case with MLA)
|
||||||
@@ -5895,7 +5916,7 @@ struct llama_data_write {
|
|||||||
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
|
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
|
||||||
const uint32_t n_embd_head_qk_rope = hparams.n_rot;
|
const uint32_t n_embd_head_qk_rope = hparams.n_rot;
|
||||||
const uint32_t kv_lora_rank = hparams.n_lora_kv;
|
const uint32_t kv_lora_rank = hparams.n_lora_kv;
|
||||||
const bool has_k_cache = kv_self.k_l[il] != nullptr;
|
const bool has_k_cache = kv_self.k_l[il] != nullptr && need_kv;
|
||||||
|
|
||||||
// Write key type
|
// Write key type
|
||||||
const int32_t k_type_i = has_k_cache ? (int32_t) kv_self.k_l[il]->type : -1;
|
const int32_t k_type_i = has_k_cache ? (int32_t) kv_self.k_l[il]->type : -1;
|
||||||
@@ -5924,7 +5945,7 @@ struct llama_data_write {
|
|||||||
if (v_state == 0) {
|
if (v_state == 0) {
|
||||||
for (uint32_t il = 0; il < n_layer; ++il) {
|
for (uint32_t il = 0; il < n_layer; ++il) {
|
||||||
const uint32_t n_embd_v_gqa = llama_kv_v_row_embd(ctx->model, hparams, il);
|
const uint32_t n_embd_v_gqa = llama_kv_v_row_embd(ctx->model, hparams, il);
|
||||||
const bool has_v_cache = kv_self.v_l[il] != nullptr;
|
const bool has_v_cache = kv_self.v_l[il] != nullptr && need_kv;
|
||||||
|
|
||||||
// Write value type
|
// Write value type
|
||||||
const int32_t v_type_i = has_v_cache ? (int32_t) kv_self.v_l[il]->type : -1;
|
const int32_t v_type_i = has_v_cache ? (int32_t) kv_self.v_l[il]->type : -1;
|
||||||
@@ -5951,7 +5972,7 @@ struct llama_data_write {
|
|||||||
const uint32_t kv_size = kv_self.size;
|
const uint32_t kv_size = kv_self.size;
|
||||||
for (uint32_t il = 0; il < n_layer; ++il) {
|
for (uint32_t il = 0; il < n_layer; ++il) {
|
||||||
const uint32_t n_embd_v_gqa = llama_kv_v_row_embd(ctx->model, hparams, il);
|
const uint32_t n_embd_v_gqa = llama_kv_v_row_embd(ctx->model, hparams, il);
|
||||||
const bool has_v_cache = kv_self.v_l[il] != nullptr;
|
const bool has_v_cache = kv_self.v_l[il] != nullptr && need_kv;
|
||||||
|
|
||||||
// Write value type
|
// Write value type
|
||||||
const int32_t v_type_i = has_v_cache ? (int32_t) kv_self.v_l[il]->type : -1;
|
const int32_t v_type_i = has_v_cache ? (int32_t) kv_self.v_l[il]->type : -1;
|
||||||
@@ -6019,7 +6040,7 @@ struct llama_data_write {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void write_kv_cache(const struct llama_context * ctx, llama_seq_id seq_id = -1) {
|
void write_kv_cache(const struct llama_context * ctx, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) {
|
||||||
const struct llama_kv_cache & kv_self = ctx->kv_self;
|
const struct llama_kv_cache & kv_self = ctx->kv_self;
|
||||||
std::vector<std::pair<uint32_t, uint32_t>> cell_ranges; // ranges, from inclusive, to exclusive
|
std::vector<std::pair<uint32_t, uint32_t>> cell_ranges; // ranges, from inclusive, to exclusive
|
||||||
uint32_t cell_count = 0;
|
uint32_t cell_count = 0;
|
||||||
@@ -6055,7 +6076,7 @@ struct llama_data_write {
|
|||||||
write(&cell_count, sizeof(cell_count));
|
write(&cell_count, sizeof(cell_count));
|
||||||
|
|
||||||
write_kv_cache_meta(kv_self, cell_ranges, seq_id);
|
write_kv_cache_meta(kv_self, cell_ranges, seq_id);
|
||||||
write_kv_cache_data(ctx, cell_ranges, seq_id);
|
write_kv_cache_data(ctx, cell_ranges, seq_id, flags);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -6266,10 +6287,10 @@ struct llama_data_read {
|
|||||||
GGML_ASSERT(sum_split_row_size == row_size);
|
GGML_ASSERT(sum_split_row_size == row_size);
|
||||||
}
|
}
|
||||||
|
|
||||||
bool read_kv_cache_data(struct llama_context * ctx, uint32_t cell_count, llama_seq_id seq_id = -1) {
|
bool read_kv_cache_data(struct llama_context * ctx, uint32_t cell_count, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) {
|
||||||
const struct llama_hparams & hparams = ctx->model.hparams;
|
const struct llama_hparams & hparams = ctx->model.hparams;
|
||||||
struct llama_kv_cache & kv_self = ctx->kv_self;
|
struct llama_kv_cache & kv_self = ctx->kv_self;
|
||||||
|
bool need_kv = (flags & LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY) == 0;
|
||||||
// v_state: 0 -> not transposed V cache
|
// v_state: 0 -> not transposed V cache
|
||||||
// 1 -> transposed V cache
|
// 1 -> transposed V cache
|
||||||
// 2 -> no V cache (as it may be the case with MLA)
|
// 2 -> no V cache (as it may be the case with MLA)
|
||||||
@@ -6298,7 +6319,7 @@ struct llama_data_read {
|
|||||||
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
|
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
|
||||||
const uint32_t n_embd_head_qk_rope = hparams.n_rot;
|
const uint32_t n_embd_head_qk_rope = hparams.n_rot;
|
||||||
const uint32_t kv_lora_rank = hparams.n_lora_kv;
|
const uint32_t kv_lora_rank = hparams.n_lora_kv;
|
||||||
const bool has_k_cache = kv_self.k_l[il] != nullptr;
|
const bool has_k_cache = kv_self.k_l[il] != nullptr && need_kv;
|
||||||
|
|
||||||
|
|
||||||
// Read type of key
|
// Read type of key
|
||||||
@@ -6346,7 +6367,7 @@ struct llama_data_read {
|
|||||||
if (v_state == 0) {
|
if (v_state == 0) {
|
||||||
for (uint32_t il = 0; il < n_layer; ++il) {
|
for (uint32_t il = 0; il < n_layer; ++il) {
|
||||||
const uint32_t n_embd_v_gqa = llama_kv_v_row_embd(ctx->model, hparams, il);
|
const uint32_t n_embd_v_gqa = llama_kv_v_row_embd(ctx->model, hparams, il);
|
||||||
const bool has_v_cache = kv_self.v_l[il] != nullptr;
|
const bool has_v_cache = kv_self.v_l[il] != nullptr && need_kv;
|
||||||
|
|
||||||
// Read type of value
|
// Read type of value
|
||||||
int32_t v_type_i_ref;
|
int32_t v_type_i_ref;
|
||||||
@@ -6394,7 +6415,7 @@ struct llama_data_read {
|
|||||||
// For each layer, read the values for each cell (transposed)
|
// For each layer, read the values for each cell (transposed)
|
||||||
for (uint32_t il = 0; il < n_layer; ++il) {
|
for (uint32_t il = 0; il < n_layer; ++il) {
|
||||||
const uint32_t n_embd_v_gqa = llama_kv_v_row_embd(ctx->model, hparams, il);
|
const uint32_t n_embd_v_gqa = llama_kv_v_row_embd(ctx->model, hparams, il);
|
||||||
const bool has_v_cache = kv_self.v_l[il] != nullptr;
|
const bool has_v_cache = kv_self.v_l[il] != nullptr && need_kv;
|
||||||
|
|
||||||
// Read type of value
|
// Read type of value
|
||||||
int32_t v_type_i_ref;
|
int32_t v_type_i_ref;
|
||||||
@@ -6529,11 +6550,11 @@ struct llama_data_read {
|
|||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
void read_kv_cache(struct llama_context * ctx, llama_seq_id seq_id = -1) {
|
void read_kv_cache(struct llama_context * ctx, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) {
|
||||||
uint32_t cell_count;
|
uint32_t cell_count;
|
||||||
read_to(&cell_count, sizeof(cell_count));
|
read_to(&cell_count, sizeof(cell_count));
|
||||||
|
|
||||||
bool res = read_kv_cache_meta(ctx, cell_count, seq_id) && read_kv_cache_data(ctx, cell_count, seq_id);
|
bool res = read_kv_cache_meta(ctx, cell_count, seq_id) && read_kv_cache_data(ctx, cell_count, seq_id, flags);
|
||||||
|
|
||||||
if (!res) {
|
if (!res) {
|
||||||
if (seq_id == -1) {
|
if (seq_id == -1) {
|
||||||
@@ -6895,41 +6916,41 @@ bool llama_state_save_file(struct llama_context * ctx, const char * path_session
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
static size_t llama_state_seq_get_data_internal(struct llama_context * ctx, llama_data_write & data_ctx, llama_seq_id seq_id) {
|
static size_t llama_state_seq_get_data_internal(struct llama_context * ctx, llama_data_write & data_ctx, llama_seq_id seq_id, llama_state_seq_flags flags) {
|
||||||
llama_synchronize(ctx);
|
llama_synchronize(ctx);
|
||||||
|
|
||||||
data_ctx.write_kv_cache(ctx, seq_id);
|
data_ctx.write_kv_cache(ctx, seq_id, flags);
|
||||||
|
|
||||||
return data_ctx.get_size_written();
|
return data_ctx.get_size_written();
|
||||||
}
|
}
|
||||||
|
|
||||||
size_t llama_state_seq_get_size(struct llama_context * ctx, llama_seq_id seq_id) {
|
size_t llama_state_seq_get_size(struct llama_context * ctx, llama_seq_id seq_id, llama_state_seq_flags flags) {
|
||||||
llama_data_write_dummy data_ctx;
|
llama_data_write_dummy data_ctx;
|
||||||
return llama_state_seq_get_data_internal(ctx, data_ctx, seq_id);
|
return llama_state_seq_get_data_internal(ctx, data_ctx, seq_id, flags);
|
||||||
}
|
}
|
||||||
|
|
||||||
size_t llama_state_seq_get_data(struct llama_context * ctx, uint8_t * dst, size_t size, llama_seq_id seq_id) {
|
size_t llama_state_seq_get_data(struct llama_context * ctx, uint8_t * dst, size_t size, llama_seq_id seq_id, llama_state_seq_flags flags) {
|
||||||
llama_data_write_buffer data_ctx(dst, size, ctx->model);
|
llama_data_write_buffer data_ctx(dst, size, ctx->model);
|
||||||
try {
|
try {
|
||||||
return llama_state_seq_get_data_internal(ctx, data_ctx, seq_id);
|
return llama_state_seq_get_data_internal(ctx, data_ctx, seq_id, flags);
|
||||||
} catch (const std::exception & err) {
|
} catch (const std::exception & err) {
|
||||||
LLAMA_LOG_ERROR("%s: error saving sequence state: %s\n", __func__, err.what());
|
LLAMA_LOG_ERROR("%s: error saving sequence state: %s\n", __func__, err.what());
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
static size_t llama_state_seq_set_data_internal(struct llama_context * ctx, llama_data_read & data_ctx, llama_seq_id dest_seq_id) {
|
static size_t llama_state_seq_set_data_internal(struct llama_context * ctx, llama_data_read & data_ctx, llama_seq_id dest_seq_id, llama_state_seq_flags flags) {
|
||||||
llama_synchronize(ctx);
|
llama_synchronize(ctx);
|
||||||
|
|
||||||
data_ctx.read_kv_cache(ctx, dest_seq_id);
|
data_ctx.read_kv_cache(ctx, dest_seq_id, flags);
|
||||||
|
|
||||||
return data_ctx.get_size_read();
|
return data_ctx.get_size_read();
|
||||||
}
|
}
|
||||||
|
|
||||||
size_t llama_state_seq_set_data(struct llama_context * ctx, const uint8_t * src, size_t size, llama_seq_id dest_seq_id) {
|
size_t llama_state_seq_set_data(struct llama_context * ctx, const uint8_t * src, size_t size, llama_seq_id dest_seq_id, llama_state_seq_flags flags) {
|
||||||
llama_data_read_buffer data_ctx(src, size);
|
llama_data_read_buffer data_ctx(src, size);
|
||||||
try {
|
try {
|
||||||
return llama_state_seq_set_data_internal(ctx, data_ctx, dest_seq_id);
|
return llama_state_seq_set_data_internal(ctx, data_ctx, dest_seq_id, flags);
|
||||||
} catch (const std::exception & err) {
|
} catch (const std::exception & err) {
|
||||||
LLAMA_LOG_ERROR("%s: error loading sequence state: %s\n", __func__, err.what());
|
LLAMA_LOG_ERROR("%s: error loading sequence state: %s\n", __func__, err.what());
|
||||||
return 0;
|
return 0;
|
||||||
@@ -6948,7 +6969,7 @@ static size_t llama_state_seq_save_file_internal(struct llama_context * ctx, con
|
|||||||
|
|
||||||
// save the context state using stream saving
|
// save the context state using stream saving
|
||||||
llama_data_write_file data_ctx(&file, ctx->model);
|
llama_data_write_file data_ctx(&file, ctx->model);
|
||||||
llama_state_seq_get_data_internal(ctx, data_ctx, seq_id);
|
llama_state_seq_get_data_internal(ctx, data_ctx, seq_id, 0);
|
||||||
|
|
||||||
const size_t res = file.tell();
|
const size_t res = file.tell();
|
||||||
GGML_ASSERT(res == sizeof(uint32_t) * 3 + sizeof(llama_token) * n_token_count + data_ctx.get_size_written());
|
GGML_ASSERT(res == sizeof(uint32_t) * 3 + sizeof(llama_token) * n_token_count + data_ctx.get_size_written());
|
||||||
@@ -6986,7 +7007,7 @@ static size_t llama_state_seq_load_file_internal(struct llama_context * ctx, con
|
|||||||
{
|
{
|
||||||
const size_t state_size = file.size() - file.tell();
|
const size_t state_size = file.size() - file.tell();
|
||||||
llama_data_read_file data_ctx(&file);
|
llama_data_read_file data_ctx(&file);
|
||||||
const size_t nread = llama_state_seq_set_data_internal(ctx, data_ctx, dest_seq_id);
|
const size_t nread = llama_state_seq_set_data_internal(ctx, data_ctx, dest_seq_id, 0);
|
||||||
if (!nread) {
|
if (!nread) {
|
||||||
LLAMA_LOG_ERROR("%s: failed to restore sequence state\n", __func__);
|
LLAMA_LOG_ERROR("%s: failed to restore sequence state\n", __func__);
|
||||||
return 0;
|
return 0;
|
||||||
|
|||||||
Reference in New Issue
Block a user