diff --git a/common/common.cpp b/common/common.cpp index 6486c097..942379d6 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -2041,6 +2041,16 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa } return true; } + if (arg == "--ctx-checkpoints") { + CHECK_ARG + params.ctx_checkpoints_n = std::stoi(argv[i]); + return true; + } + if (arg == "--ctx-checkpoints-interval") { + CHECK_ARG + params.ctx_checkpoints_interval = std::stoi(argv[i]); + return true; + } if (arg == "-cram" || arg == "--cache-ram") { CHECK_ARG params.cache_ram_mib = std::stoi(argv[i]); @@ -2235,7 +2245,10 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param options.push_back({ "*", "-c, --ctx-size N", "size of the prompt context (default: %d, 0 = loaded from model)", params.n_ctx }); options.push_back({ "*", "-cd, --ctx-size-draft N", "size of the prompt context for the draft model (default: %d, 0 = loaded from model)", params.speculative.n_ctx }); - options.push_back({ "*", "-cram, --cache-ram N", "set the maximum cache size in MiB (default: %d, -1 - no limit, 0 - disable)",params.cache_ram_mib }); + + options.push_back({ "*", "--ctx-checkpoints N", "max number of context checkpoints to create per slot (default: %d)",params.ctx_checkpoints_n}); + options.push_back({ "*", "--ctx-checkpoints-interval N", "minimum number of tokens between each context checkpoint. (default: %d, <=0 disable)",params.ctx_checkpoints_interval}); + options.push_back({ "*", "-cram, --cache-ram N", "set the maximum cache size in MiB (default: %d, -1 - no limit, 0 - disable)",params.cache_ram_mib }); options.push_back({ "*", "-crs, --cache-ram-similarity N", "max of similarity of prompt tokens to cache tokens that triggers prompt cache (default: %.2f).",params.cache_ram_similarity }); options.push_back({ "*", "-cram-n-min --cache-ram-n-min N", "minimum number of the cached tokens that triggers prompt cache (default: %d).", params.cache_ram_n_min }); options.push_back({ "*", "-n, --predict N", "number of tokens to predict (default: %d, -1 = infinity, -2 = until context filled)", params.n_predict }); diff --git a/common/common.h b/common/common.h index c14d2e22..44653a6f 100644 --- a/common/common.h +++ b/common/common.h @@ -280,6 +280,8 @@ struct gpt_params { std::vector ban_phrases; // strings that are banned in generation int32_t banned_n = 1; // number of tokens that are banned in the phrase size_t n_buffer = 0; // number of token buffers for string ban + bool can_ban_phrases = true; // whether to ban strings + bool do_checkpoint = false; // do checkpoint for recurrent models only std::vector kv_overrides; std::vector tensor_buft_overrides; @@ -418,7 +420,8 @@ struct gpt_params { float slot_prompt_similarity = 0.1f; - int32_t n_ctx_checkpoints = 8; // max number of context checkpoints per slot + int32_t ctx_checkpoints_n = 8; // max number of context checkpoints per slot + int32_t ctx_checkpoints_interval = 512; // minimum number of tokens between each context checkpoints int32_t cache_ram_mib = 8192; // -1 = no limit, 0 - disable, 1 = 1 MiB, etc. int32_t cache_ram_n_min = 0; // min number of tokens required to save in the ram float cache_ram_similarity = 0.5f; // similarity of tokens to cached tokens diff --git a/examples/save-load-state/save-load-state.cpp b/examples/save-load-state/save-load-state.cpp index 36a72aa9..4ed5d5ce 100644 --- a/examples/save-load-state/save-load-state.cpp +++ b/examples/save-load-state/save-load-state.cpp @@ -188,8 +188,8 @@ int main(int argc, char ** argv) { // save seq 0 and load into seq 1 { // save kv of seq 0 - std::vector seq_store(llama_state_seq_get_size(ctx3, 0)); - const size_t ncopy = llama_state_seq_get_data(ctx3, seq_store.data(), seq_store.size(), 0); + std::vector seq_store(llama_state_seq_get_size(ctx3, 0, 0)); + const size_t ncopy = llama_state_seq_get_data(ctx3, seq_store.data(), seq_store.size(), 0, 0); if (ncopy != seq_store.size()) { fprintf(stderr, "\n%s : seq copy data length %zd does not match expected length %zd\n", __func__, ncopy, seq_store.size()); llama_free(ctx3); @@ -203,7 +203,7 @@ int main(int argc, char ** argv) { fprintf(stderr, "%s : kv cache cleared\n", __func__); // restore kv into seq 1 - const size_t nset = llama_state_seq_set_data(ctx3, seq_store.data(), seq_store.size(), 1); + const size_t nset = llama_state_seq_set_data(ctx3, seq_store.data(), seq_store.size(), 1, 0); if (nset != seq_store.size()) { fprintf(stderr, "\n%s : seq set data length %zd does not match expected length %zd\n", __func__, nset, seq_store.size()); llama_free(ctx3); diff --git a/examples/server/server-context.cpp b/examples/server/server-context.cpp index af9e1afb..05f66f0c 100644 --- a/examples/server/server-context.cpp +++ b/examples/server/server-context.cpp @@ -315,7 +315,7 @@ void server_context::init() { void server_slot::prompt_save(server_prompt_cache& prompt_cache) const { assert(server_cached_prompt.data.size() == 0); - const size_t cur_size = llama_state_seq_get_size(ctx, id); + const size_t cur_size = llama_state_seq_get_size(ctx, id, 0); LLAMA_LOG_INFO(" - saving prompt with length %d, total state size = %.3f MiB\n", (int)server_cached_prompt.tokens.size(), cur_size / (1024.0 * 1024.0)); @@ -325,7 +325,7 @@ void server_slot::prompt_save(server_prompt_cache& prompt_cache) const { return; } - llama_state_seq_get_data(ctx, cur->data.data(), cur_size, id); + llama_state_seq_get_data(ctx, cur->data.data(), cur_size, id, 0); } void server_slot::prompt_load(server_prompt_cache& prompt_cache, const server_tokens& tokens) { @@ -361,7 +361,7 @@ void server_slot::reset() { rewind_status = false; generated_token_probs.clear(); - + checkpoint_pos = 0; // Reset speculative decoding stats n_draft_total = 0; @@ -1246,7 +1246,22 @@ bool server_context::launch_slot_with_task(server_slot& slot, server_task& task) slot.ban_phrases_bias = json_value(data, "banned_bias", params_base.ban_phrases_bias); slot.banned_n = json_value(data, "banned_n", params_base.banned_n); } - + if (llama_model_has_recurrent(llama_get_model(slot.ctx))) { + params_base.can_ban_phrases = false; + bool do_checkpoint = params_base.ctx_checkpoints_n > 0; + // make checkpoints only for completion tasks + do_checkpoint = do_checkpoint && task.type == SERVER_TASK_TYPE_COMPLETION; + // make a checkpoint of the parts of the memory that cannot be rolled back. + // checkpoints are created only if: + // - the model architecture is marked as recurrent or hybrid + // + // TODO: try to make this conditional on the context or the memory module, instead of the model type + // do_checkpoint = do_checkpoint && llama_model_has_recurrent(model); + params_base.do_checkpoint = do_checkpoint; + if (slot.n_buffer != 0) { + LLAMA_LOG_WARN("Recurrent model does not support banned strings.\n"); + } + } { const auto& stop = data.find("stop"); if (stop != data.end() && stop->is_array()) { @@ -2142,7 +2157,7 @@ void server_context::process_single_task(server_task&& task) { // Erase token cache const size_t n_erased = slot->cache_tokens.size(); - llama_kv_cache_seq_rm(ctx, slot->id + 1, -1, -1); + llama_kv_cache_seq_rm(ctx, slot->id, -1, -1); slot->cache_tokens.clear(); server_task_result result; @@ -2552,6 +2567,7 @@ void server_context::context_shift() { void server_context::add_sampled_tokens() { for (auto& slot : slots) { + slot.released = false; if (slot.state == SLOT_STATE_IDLE) { continue; } @@ -2626,15 +2642,22 @@ void server_context::add_sampled_tokens() { } } +void server_context::create_checkpoint_at_interval(server_slot & slot, const gpt_params & params_base) { + if (params_base.do_checkpoint && params_base.ctx_checkpoints_interval > 0) { + auto pos = llama_kv_cache_seq_pos_max(slot.ctx, slot.id); + if (slot.checkpoint_pos + params_base.ctx_checkpoints_interval <= 1 + pos) { + create_checkpoint(slot); + slot.checkpoint_pos = pos; + } + } +} + void server_context::apply_checkpoint(server_slot & slot) { const auto pos_min_thold = std::max(0, slot.n_past - 1); if (!mctx && slot.n_past > 0 && slot.n_past < slot.cache_tokens.n_tokens()) { - int32_t pos_min = 0; - if (llama_model_is_hybrid(llama_get_model(slot.ctx)) || llama_model_is_recurrent(llama_get_model(slot.ctx))) { - pos_min = llama_kv_cache_seq_pos_max(slot.ctx, slot.id); - } + int32_t pos_min = llama_kv_cache_seq_pos_min(slot.ctx, slot.id); - if (pos_min > pos_min_thold+2) { + if (pos_min > pos_min_thold) { // TODO: support can be added in the future when corresponding vision models get released GGML_ASSERT(!slot.cache_tokens.has_mtmd); @@ -2654,8 +2677,9 @@ void server_context::apply_checkpoint(server_slot & slot) { if (!do_reset) { // restore the context checkpoint + const int64_t t_start = ggml_time_us(); const size_t checkpoint_size = it->data.size(); - const size_t n = llama_state_seq_set_data(ctx, it->data.data(), checkpoint_size, slot.id); + const size_t n = llama_state_seq_set_data(ctx, it->data.data(), checkpoint_size, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); if (n != checkpoint_size) { SLT_ERR(slot, "failed to restore context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", it->pos_min, it->pos_max, (float)checkpoint_size / 1024 / 1024); @@ -2663,7 +2687,8 @@ void server_context::apply_checkpoint(server_slot & slot) { //printf("[DEBUG] `do_reset` was set to `true` after failing to restore a checkpoint"); } else { slot.n_past = std::min(slot.n_past, std::max(it->pos_min + 1, it->pos_max)); - SLT_WRN(slot, "restored context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", it->pos_min, it->pos_max, (float)checkpoint_size / 1024 / 1024); + slot.n_past_prompt = std::min(slot.n_past_prompt, std::max(it->pos_min_prompt + 1, it->pos_max_prompt)); + SLT_WRN(slot, "restored context checkpoint took %.2f ms (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", (ggml_time_us() - t_start) / 1000.0, it->pos_min, it->pos_max, (float)checkpoint_size / 1024 / 1024); } } @@ -2691,56 +2716,44 @@ void server_context::apply_checkpoint(server_slot & slot) { } void server_context::create_checkpoint(server_slot & slot) { - //bool do_checkpoint = params_base.n_ctx_checkpoints > 0; + bool do_checkpoint = true; + int32_t pos_min = llama_kv_cache_seq_pos_min(slot.ctx, slot.id); + const auto pos_max = llama_kv_cache_seq_pos_max(slot.ctx, slot.id); - //// make checkpoints only for completion tasks - //do_checkpoint = do_checkpoint && slot.task->type == SERVER_TASK_TYPE_COMPLETION; + // no need for empty or small checkpoints + do_checkpoint = do_checkpoint && (pos_min >= 0 && pos_max >= 16); - //// make a checkpoint of the parts of the memory that cannot be rolled back. - //// checkpoints are created only if: - //// - the model architecture is marked as recurrent or hybrid - //// - //// TODO: try to make this conditional on the context or the memory module, instead of the model type - //do_checkpoint = do_checkpoint && ( - // llama_model_is_recurrent(model) || - // llama_model_is_hybrid(model) - // ); - //int32_t pos_min = 0; - //if (llama_model_is_recurrent(model) || llama_model_is_hybrid(model)) { - // pos_min = llama_kv_cache_seq_pos_max(slot.ctx, slot.id); - //} - //const auto pos_max = llama_kv_cache_seq_pos_max(slot.ctx, slot.id); + // no need to create checkpoints that are too close together + do_checkpoint = do_checkpoint && (slot.server_cached_prompt.checkpoints.empty() || pos_max > slot.server_cached_prompt.checkpoints.back().pos_max); - //// no need for empty or small checkpoints - //do_checkpoint = do_checkpoint && (pos_min >= 0 && pos_max >= 5); + if (do_checkpoint) { + const int64_t t_start = ggml_time_us(); + while (slot.server_cached_prompt.checkpoints.size() >= (size_t)params_base.ctx_checkpoints_n) { + // make room for the new checkpoint, if needed + const auto & cur = slot.server_cached_prompt.checkpoints.front(); - //// no need to create checkpoints that are too close together - //do_checkpoint = do_checkpoint && (slot.server_cached_prompt.checkpoints.empty() || pos_max > slot.server_cached_prompt.checkpoints.back().pos_max + 64); + SLT_WRN(slot, "erasing old context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", + cur.pos_min, cur.pos_max, (float)cur.data.size() / 1024 / 1024); - //if (do_checkpoint) { - // while (slot.server_cached_prompt.checkpoints.size() >= (size_t)params_base.n_ctx_checkpoints) { - // // make room for the new checkpoint, if needed - // const auto & cur = slot.server_cached_prompt.checkpoints.front(); + slot.server_cached_prompt.checkpoints.erase(slot.server_cached_prompt.checkpoints.begin()); + } - // SLT_WRN(slot, "erasing old context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", - // cur.pos_min, cur.pos_max, (float)cur.data.size() / 1024 / 1024); + const size_t checkpoint_size = llama_state_seq_get_size(ctx, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); - // slot.server_cached_prompt.checkpoints.erase(slot.server_cached_prompt.checkpoints.begin()); - // } + auto & cur = slot.server_cached_prompt.checkpoints.emplace_back(server_prompt_checkpoint{ + /*.pos_min = */ pos_min, + /*.pos_max = */ pos_max, + /*.pos_min_prompt = */ pos_min + slot.n_past_offset, + /*.pos_max_prompt = */ pos_max + slot.n_past_offset , + /*.data = */ std::vector(checkpoint_size), + }); - // const size_t checkpoint_size = llama_state_seq_get_size(ctx, slot.id); + llama_state_seq_get_data(ctx, cur.data.data(), checkpoint_size, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); - // auto & cur = slot.server_cached_prompt.checkpoints.emplace_back(server_prompt_checkpoint{ - // /*.pos_min = */ pos_min, - // /*.pos_max = */ pos_max, - // /*.data = */ std::vector(checkpoint_size), - // }); - - // llama_state_seq_get_data(ctx, cur.data.data(), checkpoint_size, slot.id); - - // SLT_WRN(slot, "created context checkpoint %d of %d (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", - // (int)slot.server_cached_prompt.checkpoints.size(), params_base.n_ctx_checkpoints, cur.pos_min, cur.pos_max, (float)cur.data.size() / 1024 / 1024); - //} + SLT_WRN(slot, "created context checkpoint %d of %d (pos_min = %d, pos_max = %d, size = %.3f MiB, took %.2f ms)\n", + (int)slot.server_cached_prompt.checkpoints.size(), params_base.ctx_checkpoints_n, cur.pos_min, cur.pos_max, (float)cur.data.size() / 1024 / 1024, + (ggml_time_us() - t_start) / 1000.0); + } } void server_context::batch_pending_prompt(const int32_t n_ubatch, const int32_t n_batch, int32_t & batch_type) { @@ -2798,8 +2811,6 @@ void server_context::batch_pending_prompt(const int32_t n_ubatch, const int32_t } slot.n_past = 0; - slot.n_buffer = 0; - slot.token_buffer.clear(); slot.n_prompt_tokens = prompt_tokens.size(); LOG_VERBOSE("prompt tokenized", { @@ -2900,6 +2911,8 @@ void server_context::batch_pending_prompt(const int32_t n_ubatch, const int32_t } slot.n_past = prefix.first; slot.n_past_prompt = prefix.second; + slot.n_past_offset = slot.n_past_prompt - slot.n_past; + if (slot.n_past != slot.n_past_prompt) { LLAMA_LOG_INFO("Mistokenization found and handled successfully.\n"); } @@ -3074,8 +3087,6 @@ void server_context::batch_pending_prompt(const int32_t n_ubatch, const int32_t slot.n_decoded = 0; slot.i_batch = batch.n_tokens - 1; - //create_checkpoint(slot); - LOG_VERBOSE("prompt done", { {"id_slot", slot.id}, {"n_past", slot.n_past}, @@ -3187,14 +3198,11 @@ void server_context::speculative_decoding_accept() { populate_token_probs(slot, result, slot.params.post_sampling_probs, params_base.special, i); } - if (slot.n_buffer == 0 || llama_model_is_hybrid(llama_get_model(slot.ctx)) || llama_model_is_recurrent(llama_get_model(slot.ctx))) { + if (slot.n_buffer == 0 || !params_base.can_ban_phrases) { if (!process_token(result, slot)) { // release slot because of stop condition send_final_response(slot); - //create_checkpoint(slot); - slot.release(); - slot.print_timings(); - metrics.on_prediction(slot); + release_slot_after_final_response(slot); break; } } else { @@ -3218,6 +3226,15 @@ bool server_context::accept_special_token(const server_slot& slot, const llama_ return params_base.special || slot.sparams.preserved_tokens.find(token) != slot.sparams.preserved_tokens.end(); } +void server_context::release_slot_after_final_response(server_slot & slot) { + slot.print_timings(); + if (params_base.do_checkpoint) { + create_checkpoint(slot); + } + slot.release(); + slot.released = true; + metrics.on_prediction(slot); +} void server_context::send_token_results(completion_token_outputs& results, server_slot& slot, int32_t n) { int count = 0; @@ -3226,10 +3243,7 @@ void server_context::send_token_results(completion_token_outputs& results, serve count++; if (!has_next) { send_final_response(slot); - //create_checkpoint(slot); - slot.release(); - slot.print_timings(); - metrics.on_prediction(slot); + release_slot_after_final_response(slot); break; } if (n > 0 && count >= n) { @@ -3266,7 +3280,7 @@ inline int32_t check_ban_phrase(const server_slot& slot) { } if (found) { std::vector unused; - LLAMA_LOG_DEBUG("Banned string dectected: %s\n ", string_buffer.substr(start).c_str()); + LLAMA_LOG_DEBUG("Banned string dectected: %s\n", string_buffer.substr(start).c_str()); n = find_n_tokens_from_string(slot.ctx, tokens, start, 0, unused); n_rewind = (int32_t) slot.token_buffer.size() - (int32_t) n; } @@ -3299,6 +3313,8 @@ inline void rewind_context(server_slot& slot, int32_t n_rewind) { size_t n_keep = slot.cache_tokens.size() - n_rewind; slot.sampled = slot.cache_tokens[n_keep]; slot.cache_tokens.keep_first(n_keep); + llama_kv_cache_seq_rm(slot.ctx, slot.id, n_keep, -1); + } void server_context::buffer_and_check_string_ban(server_slot & slot, completion_token_output & result) { @@ -3397,6 +3413,9 @@ void server_context::process_batch_tokens(int32_t & n_batch) { for (auto& slot : slots) { if (slot.state != SLOT_STATE_PROCESSING || slot.i_batch < (int)i || slot.i_batch >= (int)(i + n_tokens)) { + if (slot.command == SLOT_COMMAND_LOAD_PROMPT) { + create_checkpoint_at_interval(slot, params_base); + } continue; // continue loop of slots } @@ -3440,6 +3459,14 @@ void server_context::process_batch_tokens(int32_t & n_batch) { slot.t_start_generation = ggml_time_us(); slot.t_prompt_processing = (slot.t_start_generation - slot.t_start_process_prompt) / 1e3; metrics.on_prompt_eval(slot); + if (params_base.do_checkpoint) { + create_checkpoint(slot); + } + } + + // save checkpoint during generation + if (slot.n_decoded > 1) { + create_checkpoint_at_interval(slot, params_base); } slot.t_token_generation = std::max(1, t_current - slot.t_start_generation) / 1e3; @@ -3452,7 +3479,7 @@ void server_context::process_batch_tokens(int32_t & n_batch) { } // no ban string for recurrent/hybrid model - if (slot.n_buffer == 0 || llama_model_is_hybrid(llama_get_model(slot.ctx)) || llama_model_is_recurrent(llama_get_model(slot.ctx))) { + if (slot.n_buffer == 0 || !params_base.can_ban_phrases) { slot.token_buffer = { result }; send_token_results(slot.token_buffer, slot); } else { @@ -3503,7 +3530,7 @@ void server_context::update_slots() { // apply context-shift if needed // TODO: simplify and improve context_shift(); - + // start populating the batch for this iteration common_batch_clear(batch); diff --git a/examples/server/server-context.h b/examples/server/server-context.h index 8ff64d1a..f6444ffb 100644 --- a/examples/server/server-context.h +++ b/examples/server/server-context.h @@ -32,6 +32,7 @@ struct server_slot { llama_batch batch_spec = {}; llama_context * ctx_dft = nullptr; + bool released = false; slot_state state = SLOT_STATE_IDLE; slot_command command = SLOT_COMMAND_NONE; @@ -45,6 +46,7 @@ struct server_slot { int32_t n_ctx = 0; // context size per slot int32_t n_past = 0; int32_t n_past_prompt = 0; + int32_t n_past_offset = 0; int32_t n_decoded = 0; int32_t n_remaining = -1; int32_t n_discarded_prompt = 0; @@ -102,6 +104,8 @@ struct server_slot { void prompt_load(server_prompt_cache& prompt_cache, const server_tokens& tokens); + size_t checkpoint_pos = 0; + // sampling llama_token sampled; // in speculative mode, this is the last accepted token llama_tokens drafted; @@ -355,4 +359,8 @@ struct server_context { void create_checkpoint(server_slot & slot); void apply_checkpoint(server_slot & slot); + + void create_checkpoint_at_interval(server_slot & slot, const gpt_params & params_base); + + void release_slot_after_final_response(server_slot & slot); }; diff --git a/examples/server/server-task.cpp b/examples/server/server-task.cpp index 873f214c..318149b1 100644 --- a/examples/server/server-task.cpp +++ b/examples/server/server-task.cpp @@ -1117,7 +1117,7 @@ bool server_prompt_cache::load(server_prompt& prompt, const server_tokens& token if (it_best != states.end()) { LLAMA_LOG_INFO(" - found better prompt with f_keep = %.3f, sim = %.3f, n_keep = %d, n_discarded_prompt = %d\n", f_keep_best, sim_best, it_best->n_kept_prompt, it_best->n_discarded_prompt); const size_t size = it_best->data.size(); - const size_t n = llama_state_seq_set_data(ctx, it_best->data.data(), size, id_slot); + const size_t n = llama_state_seq_set_data(ctx, it_best->data.data(), size, id_slot, 0); if (n != size) { LLAMA_LOG_INFO("failed to restore state with size %zu\n", size); return false; diff --git a/examples/server/server-task.h b/examples/server/server-task.h index cbb2ce16..f3d3705f 100644 --- a/examples/server/server-task.h +++ b/examples/server/server-task.h @@ -344,6 +344,8 @@ using server_task_result_ptr = std::unique_ptr; struct server_prompt_checkpoint { llama_pos pos_min; llama_pos pos_max; + llama_pos pos_min_prompt; + llama_pos pos_max_prompt; std::vector data; diff --git a/include/llama.h b/include/llama.h index aec65af3..338381b4 100644 --- a/include/llama.h +++ b/include/llama.h @@ -645,6 +645,8 @@ extern "C" { // Returns true if the model is hybrid (like Jamba, Granite, etc.) LLAMA_API bool llama_model_is_hybrid(const struct llama_model * model); + LLAMA_API bool llama_model_has_recurrent(const struct llama_model * model); + // Returns 0 on success LLAMA_API uint32_t llama_model_quantize( const char * fname_inp, @@ -735,6 +737,11 @@ extern "C" { llama_seq_id * cells_sequences; }; + // work only with partial states, such as recurrent cache (e.g. Mamba) +#define LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY 1 + + typedef uint32_t llama_state_seq_flags; + // Create an empty KV cache view. (use only for debugging purposes) LLAMA_API struct llama_kv_cache_view llama_kv_cache_view_init(const struct llama_context * ctx, int32_t n_seq_max); @@ -813,6 +820,11 @@ extern "C" { struct llama_context * ctx, llama_seq_id seq_id); + // Returns the smallest position present in the KV cache for the specified sequence + LLAMA_API llama_pos llama_kv_cache_seq_pos_min( + struct llama_context * ctx, + llama_seq_id seq_id); + // Defragment the KV cache // This will be applied: // - lazily on next llama_decode() @@ -889,14 +901,16 @@ extern "C" { // Get the exact size needed to copy the KV cache of a single sequence LLAMA_API size_t llama_state_seq_get_size( struct llama_context * ctx, - llama_seq_id seq_id); + llama_seq_id seq_id, + llama_state_seq_flags flags); // Copy the KV cache of a single sequence into the specified buffer LLAMA_API size_t llama_state_seq_get_data( struct llama_context * ctx, uint8_t * dst, size_t size, - llama_seq_id seq_id); + llama_seq_id seq_id, + llama_state_seq_flags flags); // Copy the sequence data (originally copied with `llama_state_seq_get_data`) into the specified sequence // Returns: @@ -906,7 +920,8 @@ extern "C" { struct llama_context * ctx, const uint8_t * src, size_t size, - llama_seq_id dest_seq_id); + llama_seq_id dest_seq_id, + llama_state_seq_flags flags); LLAMA_API size_t llama_state_seq_save_file( struct llama_context * ctx, diff --git a/src/llama-build-context.cpp b/src/llama-build-context.cpp index a7f9e785..975aafb1 100644 --- a/src/llama-build-context.cpp +++ b/src/llama-build-context.cpp @@ -142,7 +142,7 @@ ggml_cgraph * llm_build_context::build_k_shift() { ggml_set_input(lctx.inp_K_shift); for (int il = 0; il < n_layer; ++il) { - if ((model.arch == LLM_ARCH_QWEN3NEXT || model.arch == LLM_ARCH_QWEN35MOE) && hparams.is_recurrent(il)) { + if (llm_arch_is_hybrid(model.arch) && hparams.is_recurrent(il)) { continue; } if (kv_self.k_l[il] == nullptr) { @@ -241,7 +241,7 @@ ggml_cgraph * llm_build_context::build_defrag(const std::vector & ids) } for (int il = 0; il < n_layer; ++il) { - if ((model.arch == LLM_ARCH_QWEN3NEXT || model.arch == LLM_ARCH_QWEN35MOE) && hparams.is_recurrent(il)) { + if (llm_arch_is_hybrid(model.arch) && hparams.is_recurrent(il)) { continue; } if (kv_self.k_l[il] == nullptr) { diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 4cda5330..733e8137 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -1739,3 +1739,7 @@ bool llama_model_is_recurrent(const llama_model * model) { bool llama_model_is_hybrid(const llama_model * model) { return llm_arch_is_hybrid(model->arch); } + +bool llama_model_has_recurrent(const llama_model * model) { + return llm_arch_is_hybrid(model->arch) || llm_arch_is_recurrent(model->arch); +} diff --git a/src/llama.cpp b/src/llama.cpp index a5ceeb4f..84377705 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -671,7 +671,7 @@ static inline uint32_t llama_kv_v_row_embd( uint32_t il) { // qwen3next recurrent state is stored in a dedicated V-cache tail (per sequence), // so per-token V rows include only attention values. - if (model.arch == LLM_ARCH_QWEN3NEXT || model.arch == LLM_ARCH_QWEN35MOE) { + if (llm_arch_is_hybrid(model.arch)) { return hparams.n_embd_v_gqa(il); } @@ -732,7 +732,7 @@ static bool llama_kv_cache_init( cache.hybrid = llm_arch_is_hybrid(model.arch); // qwen3next uses hybrid recurrent+attention cache semantics. Keep V rows in // standard layout to match the mainline hybrid path when flash attention is off. - cache.v_trans = !cache.recurrent && !cparams.flash_attn && model.arch != LLM_ARCH_QWEN3NEXT && model.arch != LLM_ARCH_QWEN35MOE; + cache.v_trans = !cache.recurrent && !cparams.flash_attn && !llm_arch_is_hybrid(model.arch); cache.head = 0; cache.size = kv_size; @@ -744,7 +744,7 @@ static bool llama_kv_cache_init( cache.cells.clear(); cache.cells.resize(kv_size); - if (cache.recurrent || model.arch == LLM_ARCH_QWEN3NEXT || model.arch == LLM_ARCH_QWEN35MOE) { + if (cache.recurrent || llm_arch_is_hybrid(model.arch)) { // init state copy sources for (uint32_t i = 0; i < cache.size; ++i) { cache.cells[i].src = i; @@ -829,7 +829,7 @@ static bool llama_kv_cache_init( std::vector mem_split(model.splits.size(), 0); const uint32_t qnext_state_slots = llama_qwen3next_state_slots(cparams, kv_size); - if ((model.arch == LLM_ARCH_QWEN3NEXT || model.arch == LLM_ARCH_QWEN35MOE) && qnext_state_slots < std::max(1, cparams.n_seq_max)) { + if (llm_arch_is_hybrid(model.arch) && qnext_state_slots < std::max(1, cparams.n_seq_max)) { LLAMA_LOG_WARN("%s: reducing qwen3next state slots from %u to %u to fit KV cache size\n", __func__, std::max(1, cparams.n_seq_max), qnext_state_slots); } @@ -1398,6 +1398,19 @@ static llama_pos llama_kv_cache_seq_pos_max(struct llama_kv_cache & cache, llama return result; } +static llama_pos llama_kv_cache_seq_pos_min(struct llama_kv_cache & cache, llama_seq_id seq_id) { + llama_pos result = -1; + + for (uint32_t i = 0; i < cache.size; ++i) { + if (cache.cells[i].has_seq_id(seq_id)) { + result = cache.cells[i].pos; + break; + } + } + + return result; +} + static void llama_kv_cache_defrag(struct llama_kv_cache & cache) { cache.do_defrag = true; } @@ -3227,7 +3240,7 @@ static int llama_decode_internal( auto tim1 = ggml_time_us(); #endif uint32_t n_tokens = std::min(n_ubatch, n_tokens_all - cur_token); - if ((model.arch == LLM_ARCH_QWEN3NEXT || model.arch == LLM_ARCH_QWEN35MOE) && + if (llm_arch_is_hybrid(model.arch) && n_tokens > 1 && batch_all.n_seq_id != nullptr && batch_all.seq_id != nullptr) { @@ -5735,6 +5748,13 @@ void llama_kv_cache_seq_div(struct llama_context * ctx, llama_seq_id seq_id, lla llama_kv_cache_seq_div(ctx->kv_self, seq_id, p0, p1, d); } +llama_pos llama_kv_cache_seq_pos_min(struct llama_context * ctx, llama_seq_id seq_id) { + if (ctx->kv_self.hybrid || ctx->kv_self.recurrent) { + return llama_kv_cache_seq_pos_max(ctx->kv_self, seq_id); + } + return llama_kv_cache_seq_pos_min(ctx->kv_self, seq_id); +} + llama_pos llama_kv_cache_seq_pos_max(struct llama_context * ctx, llama_seq_id seq_id) { return llama_kv_cache_seq_pos_max(ctx->kv_self, seq_id); } @@ -5876,10 +5896,11 @@ struct llama_data_write { } } - void write_kv_cache_data(const struct llama_context * ctx, const std::vector> & cell_ranges, llama_seq_id seq_id = -1) { + void write_kv_cache_data(const struct llama_context * ctx, const std::vector> & cell_ranges, llama_seq_id seq_id = -1, + llama_state_seq_flags flags = 0) { const struct llama_kv_cache & kv_self = ctx->kv_self; const struct llama_hparams & hparams = ctx->model.hparams; - + bool need_kv = (flags & LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY) == 0; // v_state: 0 -> not transposed V cache // 1 -> transposed V cache // 2 -> no V cache (as it may be the case with MLA) @@ -5895,7 +5916,7 @@ struct llama_data_write { const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s(); const uint32_t n_embd_head_qk_rope = hparams.n_rot; const uint32_t kv_lora_rank = hparams.n_lora_kv; - const bool has_k_cache = kv_self.k_l[il] != nullptr; + const bool has_k_cache = kv_self.k_l[il] != nullptr && need_kv; // Write key type const int32_t k_type_i = has_k_cache ? (int32_t) kv_self.k_l[il]->type : -1; @@ -5924,7 +5945,7 @@ struct llama_data_write { if (v_state == 0) { for (uint32_t il = 0; il < n_layer; ++il) { const uint32_t n_embd_v_gqa = llama_kv_v_row_embd(ctx->model, hparams, il); - const bool has_v_cache = kv_self.v_l[il] != nullptr; + const bool has_v_cache = kv_self.v_l[il] != nullptr && need_kv; // Write value type const int32_t v_type_i = has_v_cache ? (int32_t) kv_self.v_l[il]->type : -1; @@ -5951,7 +5972,7 @@ struct llama_data_write { const uint32_t kv_size = kv_self.size; for (uint32_t il = 0; il < n_layer; ++il) { const uint32_t n_embd_v_gqa = llama_kv_v_row_embd(ctx->model, hparams, il); - const bool has_v_cache = kv_self.v_l[il] != nullptr; + const bool has_v_cache = kv_self.v_l[il] != nullptr && need_kv; // Write value type const int32_t v_type_i = has_v_cache ? (int32_t) kv_self.v_l[il]->type : -1; @@ -6019,7 +6040,7 @@ struct llama_data_write { } } - void write_kv_cache(const struct llama_context * ctx, llama_seq_id seq_id = -1) { + void write_kv_cache(const struct llama_context * ctx, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) { const struct llama_kv_cache & kv_self = ctx->kv_self; std::vector> cell_ranges; // ranges, from inclusive, to exclusive uint32_t cell_count = 0; @@ -6055,7 +6076,7 @@ struct llama_data_write { write(&cell_count, sizeof(cell_count)); write_kv_cache_meta(kv_self, cell_ranges, seq_id); - write_kv_cache_data(ctx, cell_ranges, seq_id); + write_kv_cache_data(ctx, cell_ranges, seq_id, flags); } }; @@ -6266,10 +6287,10 @@ struct llama_data_read { GGML_ASSERT(sum_split_row_size == row_size); } - bool read_kv_cache_data(struct llama_context * ctx, uint32_t cell_count, llama_seq_id seq_id = -1) { + bool read_kv_cache_data(struct llama_context * ctx, uint32_t cell_count, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) { const struct llama_hparams & hparams = ctx->model.hparams; struct llama_kv_cache & kv_self = ctx->kv_self; - + bool need_kv = (flags & LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY) == 0; // v_state: 0 -> not transposed V cache // 1 -> transposed V cache // 2 -> no V cache (as it may be the case with MLA) @@ -6298,7 +6319,7 @@ struct llama_data_read { const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s(); const uint32_t n_embd_head_qk_rope = hparams.n_rot; const uint32_t kv_lora_rank = hparams.n_lora_kv; - const bool has_k_cache = kv_self.k_l[il] != nullptr; + const bool has_k_cache = kv_self.k_l[il] != nullptr && need_kv; // Read type of key @@ -6346,7 +6367,7 @@ struct llama_data_read { if (v_state == 0) { for (uint32_t il = 0; il < n_layer; ++il) { const uint32_t n_embd_v_gqa = llama_kv_v_row_embd(ctx->model, hparams, il); - const bool has_v_cache = kv_self.v_l[il] != nullptr; + const bool has_v_cache = kv_self.v_l[il] != nullptr && need_kv; // Read type of value int32_t v_type_i_ref; @@ -6394,7 +6415,7 @@ struct llama_data_read { // For each layer, read the values for each cell (transposed) for (uint32_t il = 0; il < n_layer; ++il) { const uint32_t n_embd_v_gqa = llama_kv_v_row_embd(ctx->model, hparams, il); - const bool has_v_cache = kv_self.v_l[il] != nullptr; + const bool has_v_cache = kv_self.v_l[il] != nullptr && need_kv; // Read type of value int32_t v_type_i_ref; @@ -6529,11 +6550,11 @@ struct llama_data_read { return true; } - void read_kv_cache(struct llama_context * ctx, llama_seq_id seq_id = -1) { + void read_kv_cache(struct llama_context * ctx, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) { uint32_t cell_count; read_to(&cell_count, sizeof(cell_count)); - bool res = read_kv_cache_meta(ctx, cell_count, seq_id) && read_kv_cache_data(ctx, cell_count, seq_id); + bool res = read_kv_cache_meta(ctx, cell_count, seq_id) && read_kv_cache_data(ctx, cell_count, seq_id, flags); if (!res) { if (seq_id == -1) { @@ -6895,41 +6916,41 @@ bool llama_state_save_file(struct llama_context * ctx, const char * path_session } } -static size_t llama_state_seq_get_data_internal(struct llama_context * ctx, llama_data_write & data_ctx, llama_seq_id seq_id) { +static size_t llama_state_seq_get_data_internal(struct llama_context * ctx, llama_data_write & data_ctx, llama_seq_id seq_id, llama_state_seq_flags flags) { llama_synchronize(ctx); - data_ctx.write_kv_cache(ctx, seq_id); + data_ctx.write_kv_cache(ctx, seq_id, flags); return data_ctx.get_size_written(); } -size_t llama_state_seq_get_size(struct llama_context * ctx, llama_seq_id seq_id) { +size_t llama_state_seq_get_size(struct llama_context * ctx, llama_seq_id seq_id, llama_state_seq_flags flags) { llama_data_write_dummy data_ctx; - return llama_state_seq_get_data_internal(ctx, data_ctx, seq_id); + return llama_state_seq_get_data_internal(ctx, data_ctx, seq_id, flags); } -size_t llama_state_seq_get_data(struct llama_context * ctx, uint8_t * dst, size_t size, llama_seq_id seq_id) { +size_t llama_state_seq_get_data(struct llama_context * ctx, uint8_t * dst, size_t size, llama_seq_id seq_id, llama_state_seq_flags flags) { llama_data_write_buffer data_ctx(dst, size, ctx->model); try { - return llama_state_seq_get_data_internal(ctx, data_ctx, seq_id); + return llama_state_seq_get_data_internal(ctx, data_ctx, seq_id, flags); } catch (const std::exception & err) { LLAMA_LOG_ERROR("%s: error saving sequence state: %s\n", __func__, err.what()); return 0; } } -static size_t llama_state_seq_set_data_internal(struct llama_context * ctx, llama_data_read & data_ctx, llama_seq_id dest_seq_id) { +static size_t llama_state_seq_set_data_internal(struct llama_context * ctx, llama_data_read & data_ctx, llama_seq_id dest_seq_id, llama_state_seq_flags flags) { llama_synchronize(ctx); - data_ctx.read_kv_cache(ctx, dest_seq_id); + data_ctx.read_kv_cache(ctx, dest_seq_id, flags); return data_ctx.get_size_read(); } -size_t llama_state_seq_set_data(struct llama_context * ctx, const uint8_t * src, size_t size, llama_seq_id dest_seq_id) { +size_t llama_state_seq_set_data(struct llama_context * ctx, const uint8_t * src, size_t size, llama_seq_id dest_seq_id, llama_state_seq_flags flags) { llama_data_read_buffer data_ctx(src, size); try { - return llama_state_seq_set_data_internal(ctx, data_ctx, dest_seq_id); + return llama_state_seq_set_data_internal(ctx, data_ctx, dest_seq_id, flags); } catch (const std::exception & err) { LLAMA_LOG_ERROR("%s: error loading sequence state: %s\n", __func__, err.what()); return 0; @@ -6948,7 +6969,7 @@ static size_t llama_state_seq_save_file_internal(struct llama_context * ctx, con // save the context state using stream saving llama_data_write_file data_ctx(&file, ctx->model); - llama_state_seq_get_data_internal(ctx, data_ctx, seq_id); + llama_state_seq_get_data_internal(ctx, data_ctx, seq_id, 0); const size_t res = file.tell(); GGML_ASSERT(res == sizeof(uint32_t) * 3 + sizeof(llama_token) * n_token_count + data_ctx.get_size_written()); @@ -6986,7 +7007,7 @@ static size_t llama_state_seq_load_file_internal(struct llama_context * ctx, con { const size_t state_size = file.size() - file.tell(); llama_data_read_file data_ctx(&file); - const size_t nread = llama_state_seq_set_data_internal(ctx, data_ctx, dest_seq_id); + const size_t nread = llama_state_seq_set_data_internal(ctx, data_ctx, dest_seq_id, 0); if (!nread) { LLAMA_LOG_ERROR("%s: failed to restore sequence state\n", __func__); return 0;