From bb358223cd9f7b6e827541937ac21a039151e4a9 Mon Sep 17 00:00:00 2001 From: firecoperana Date: Fri, 14 Nov 2025 16:40:13 +0000 Subject: [PATCH] server: cache prompt to host memory (#954) * server : host-memory prompt caching change similarity calculation and prompt save conditions Remove unneeded token limit rename variable Separate prompt save and load logic change default values change log remove truncate prompt logic * add description * bug fixes * remove token limit in init --------- Co-authored-by: firecoperana --- common/common.cpp | 20 +- common/common.h | 5 +- examples/server/server.cpp | 362 ++++++++++++++++++++++++++++++++----- examples/server/utils.hpp | 10 +- 4 files changed, 347 insertions(+), 50 deletions(-) diff --git a/common/common.cpp b/common/common.cpp index 789fcbd3..2ff5f204 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -1816,6 +1816,21 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa params.ctx_shift = false; return true; } + if (arg == "-cram" || arg == "--cache-ram") { + CHECK_ARG + params.cache_ram_mib = std::stoi(argv[i]); + return true; + } + if (arg == "-crs" || arg == "--cache-ram-similarity") { + CHECK_ARG + params.cache_ram_similarity = std::stof(argv[i]); + return true; + } + if (arg == "-cram-n-min" || arg == "--cache-ram-n-min") { + CHECK_ARG + params.cache_ram_n_min = std::stoi(argv[i]); + return true; + } if (arg == "--pos") { CHECK_ARG params.i_pos = std::stoi(argv[i]); @@ -1995,6 +2010,9 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param options.push_back({ "*", "-c, --ctx-size N", "size of the prompt context (default: %d, 0 = loaded from model)", params.n_ctx }); options.push_back({ "*", "-cd, --ctx-size-draft N", "size of the prompt context for the draft model (default: %d, 0 = loaded from model)", params.n_ctx_draft }); + options.push_back({ "*", "-cram, --cache-ram N", "set the maximum cache size in MiB (default: %d, -1 - no limit, 0 - disable)",params.cache_ram_mib }); + options.push_back({ "*", "-crs, --cache-ram-similarity N", "max of similarity of prompt tokens to cache tokens that triggers prompt cache (default: %.2f).",params.cache_ram_similarity }); + options.push_back({ "*", "-cram-n-min --cache-ram-n-min N", "minimum number of the cached tokens that triggers prompt cache (default: %d).", params.cache_ram_n_min }); options.push_back({ "*", "-n, --predict N", "number of tokens to predict (default: %d, -1 = infinity, -2 = until context filled)", params.n_predict }); options.push_back({ "*", "-b, --batch-size N", "logical maximum batch size (default: %d)", params.n_batch }); options.push_back({ "*", "-ub, --ubatch-size N", "physical maximum batch size (default: %d)", params.n_ubatch }); @@ -2007,7 +2025,7 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param options.push_back({ "*", "-no-fmoe, --no-fused-moe", "disable fused MoE (default: %s)", params.fused_moe_up_gate ? "enabled" : "disabled" }); options.push_back({ "*", "-ger, --grouped-expert-routing", "enable grouped expert routing (default: %s)", params.grouped_expert_routing ? "enabled" : "disabled" }); options.push_back({ "*", "-no-fug, --no-fused-up-gate", "disaable fused up-gate (default: %s)", params.fused_up_gate ? "enabled" : "disabled" }); - options.push_back({ "*", "-no-mmad, --no-fused-mul-multiadd", "disaable fused mul-multi_add (default: %s)", params.fused_mmad? "enabled" : "disabled" }); + options.push_back({ "*", "-no-mmad, --no-fused-mul-multiadd", "disable fused mul-multi_add (default: %s)", params.fused_mmad? "enabled" : "disabled" }); options.push_back({ "*", "-rcache, --rope-cache", "enable RoPE cache (default: %s)", params.rope_cache ? "enabled" : "disabled" }); options.push_back({ "*", "-gr, --graph-reuse", "enable graph reuse (default: %s)", params.graph_reuse ? "enabled" : "disabled" }); options.push_back({ "*", "-ser, --smart-expert-reduction,","experts reduction (default: %d,%g)", params.min_experts, params.thresh_experts}); diff --git a/common/common.h b/common/common.h index f27eca97..8c07127e 100644 --- a/common/common.h +++ b/common/common.h @@ -330,7 +330,10 @@ struct gpt_params { std::string sql_save_file; std::string sqlite_zstd_ext_file; - float slot_prompt_similarity = 0.5f; + float slot_prompt_similarity = 0.1f; + int32_t cache_ram_mib = 8192; // -1 = no limit, 0 - disable, 1 = 1 MiB, etc. + int32_t cache_ram_n_min = 0; // min number of tokens required to save in the ram + float cache_ram_similarity = 0.5f; // similarity of tokens to cached tokens // batched-bench params bool is_pp_shared = false; diff --git a/examples/server/server.cpp b/examples/server/server.cpp index e4f03470..a39c11ef 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -139,6 +139,7 @@ struct server_task { int id = -1; // to be filled by server_queue int id_multi = -1; int id_target = -1; + //int id_slot = -1; // used by SERVER_TASK_TYPE_INFERENCE server_tokens tokens; @@ -148,6 +149,10 @@ struct server_task { bool infill = false; bool embedding = false; + + server_task() = default; + server_task(server_task_type type) : type(type) {} + }; struct server_task_result { @@ -531,7 +536,7 @@ struct server_task_result { } }; -inline std::string stop_type_to_str(stop_type type) { +static inline std::string stop_type_to_str(stop_type type) { switch (type) { case STOP_TYPE_EOS: return "eos"; case STOP_TYPE_WORD: return "word"; @@ -579,6 +584,212 @@ struct slot_params { }; + +struct server_prompt_checkpoint { + llama_pos pos_min; + llama_pos pos_max; + + std::vector data; + + size_t size() const { + return data.size(); + } +}; + +struct server_prompt { + server_tokens tokens; + + std::vector data; + + std::list checkpoints; + + size_t size() const { + size_t res = data.size(); + + for (const auto& checkpoint : checkpoints) { + res += checkpoint.size(); + } + + return res; + } + + int n_tokens() const { + return tokens.size(); + } +}; + +struct server_prompt_cache { + server_prompt_cache(int32_t limit_size_mib, size_t limit_tokens) { + this->limit_size = 1024ull * 1024ull * (limit_size_mib < 0 ? 0 : limit_size_mib); + this->limit_tokens = limit_tokens; + } + + std::list states; + + // in bytes, 0 = no limit + size_t limit_size = 0; + + // in tokens, 0 = no limit + size_t limit_tokens = 0; + + size_t size() const { + size_t res = 0; + + for (const auto& state : states) { + res += state.size(); + } + + return res; + } + + size_t n_tokens() const { + size_t res = 0; + + for (const auto& state : states) { + res += state.n_tokens(); + } + + return res; + } + + server_prompt* alloc(const server_prompt& prompt, size_t state_size) { + for (auto it = states.begin(); it != states.end();) { + const size_t len = it->tokens.get_common_prefix(prompt.tokens); + + // first check if the current state is contained fully in the cache + if (len == prompt.tokens.size()) { + LLAMA_LOG_INFO("%s", " - prompt is already in the cache, skipping\n"); + return nullptr; + } + // next, remove any cached prompts that are fully contained in the current prompt + else if(len == it->tokens.size()) { + LLAMA_LOG_INFO(" - removing obsolete cached prompt with length %d\n", len); + it = states.erase(it); + } + else { + ++it; + } + } + + std::vector state_data; + + // check if we can allocate enough memory for the new state + try { + state_data.resize(state_size); + } + catch (const std::bad_alloc& e) { + LLAMA_LOG_INFO("failed to allocate memory for prompt cache state: %s\n", e.what()); + + limit_size = std::max(1, 0.4 * size()); + + LLAMA_LOG_INFO(" - cache size limit reduced to %.3f MiB\n", limit_size / (1024.0 * 1024.0)); + + update(); + + return nullptr; + } + + // TODO: for some reason we can't copy server_tokens, so we have to do this workaround + auto& cur = states.emplace_back(); + cur = { + /*.tokens =*/ server_tokens(prompt.tokens.get_text_tokens(), false), + /*.data =*/ std::move(state_data), + /*.checkpoints =*/ prompt.checkpoints, + }; + + return &cur; + } + + bool load(server_prompt& prompt, const server_tokens& tokens_new, llama_context* ctx, int32_t id_slot) { + const int lcp_best = prompt.tokens.get_common_prefix(tokens_new); + + float f_keep_best = float(lcp_best) / prompt.tokens.size(); + //float sim_best = float(lcp_best) / tokens_new.size(); + float sim_best = get_slot_similarity(lcp_best, tokens_new.size(), prompt.tokens.size()); + + LLAMA_LOG_INFO(" - looking for better prompt, base f_keep = %.3f, sim = %.3f\n", f_keep_best, sim_best); + + auto it_best = states.end(); + + // find the most similar cached prompt, that would also preserve the most context + for (auto it = states.begin(); it != states.end(); ++it) { + const int lcp_cur = it->tokens.get_common_prefix(tokens_new); + + const float f_keep_cur = float(lcp_cur) / it->tokens.size(); + //const float sim_cur = float(lcp_cur) / tokens_new.size(); + const float sim_cur = get_slot_similarity(lcp_cur, tokens_new.size(), it->tokens.size()); + if (sim_best < sim_cur) { + f_keep_best = f_keep_cur; + sim_best = sim_cur; + it_best = it; + } + } + + if (it_best != states.end()) { + LLAMA_LOG_INFO(" - found better prompt with f_keep = %.3f, sim = %.3f\n", f_keep_best, sim_best); + const size_t size = it_best->data.size(); + const size_t n = llama_state_seq_set_data(ctx, it_best->data.data(), size, id_slot); + if (n != size) { + LLAMA_LOG_INFO("failed to restore state with size %zu\n", size); + return false; + } + + it_best->data.clear(); + it_best->data.shrink_to_fit(); + + prompt = std::move(*it_best); + + states.erase(it_best); + } + + return true; + } + + void update() { + if (limit_size > 0) { + // always keep at least one state, regardless of the limits + while (states.size() > 1 && size() > limit_size) { + if (states.empty()) { + break; + } + + LLAMA_LOG_INFO(" - cache size limit reached, removing oldest entry (size = %.3f MiB)\n", states.front().size() / (1024.0 * 1024.0)); + + states.pop_front(); + } + } + + // average size per token + const float size_per_token = std::max(1.0f, float(size()) / (std::max(1, n_tokens()))); + + // dynamically increase the token limit if it can fit in the memory limit + const size_t limit_tokens_cur = limit_size > 0 ? std::max(limit_tokens, limit_size / size_per_token) : limit_tokens; + + //if (limit_tokens > 0) { + // + // while (states.size() > 1 && n_tokens() > limit_tokens_cur) { + // if (states.empty()) { + // break; + // } + + // LLAMA_LOG_INFO(" - cache token limit (%zu, est: %zu) reached, removing oldest entry (size = %.3f MiB)\n", + // limit_tokens, limit_tokens_cur, states.front().size() / (1024.0 * 1024.0)); + + // states.pop_front(); + // } + //} + + LLAMA_LOG_INFO(" - cache state: %zu prompts, %.3f MiB (limits: %.3f MiB, %zu tokens, %zu est)\n", + states.size(), size() / (1024.0 * 1024.0), limit_size / (1024.0 * 1024.0), limit_tokens, limit_tokens_cur); + + for (const auto& state : states) { + LLAMA_LOG_INFO(" - prompt %p: %7d tokens, checkpoints: %2zu, %9.3f MiB\n", + (const void*)&state, state.n_tokens(), state.checkpoints.size(), state.size() / (1024.0 * 1024.0)); + } + } +}; + + struct server_slot { int id; int id_task = -1; @@ -589,9 +800,12 @@ struct server_slot { slot_state state = SLOT_STATE_IDLE; slot_command command = SLOT_COMMAND_NONE; + llama_context* ctx = nullptr; // used to determine the slot that has been used the longest int64_t t_last_used = -1; + std::unique_ptr task; + // generation props int32_t n_ctx = 0; // context size per slot int32_t n_past = 0; @@ -627,6 +841,33 @@ struct server_slot { std::string oaicompat_model; std::string stopping_word; stop_type stop; + + server_prompt server_prompt; + + void prompt_save(server_prompt_cache & prompt_cache) const { + assert(server_prompt.data.size() == 0); + + const size_t cur_size = llama_state_seq_get_size(ctx, id); + + LLAMA_LOG_INFO(" - saving prompt with length %d, total state size = %.3f MiB\n", + (int)server_prompt.tokens.size(), cur_size / (1024.0 * 1024.0)); + + auto* cur = prompt_cache.alloc(server_prompt, cur_size); + if (cur == nullptr) { + return; + } + + llama_state_seq_get_data(ctx, cur->data.data(), cur_size, id); + } + + void prompt_load(server_prompt_cache& prompt_cache, const server_tokens& tokens) { + bool res = prompt_cache.load(server_prompt, tokens, ctx, id); + if (!res) { + LLAMA_LOG_INFO("failed to load prompt from cache\n"); + } + } + + // sampling llama_token sampled; struct llama_sampling_params sparams; @@ -689,6 +930,8 @@ struct server_slot { chat_msg = {}; json_schema = json(); generated_tool_call_ids.clear(); + + task.reset(); } bool has_budget(gpt_params &global_params) { @@ -726,6 +969,7 @@ struct server_slot { if (state == SLOT_STATE_PROCESSING) { t_token_generation = (ggml_time_us() - t_start_generation) / 1e3; command = SLOT_COMMAND_RELEASE; + task.reset(); } } @@ -1176,12 +1420,16 @@ struct server_context { server_queue queue_tasks; server_response queue_results; + std::unique_ptr prompt_cache; + server_metrics metrics; common_chat_templates_ptr chat_templates; oaicompat_parser_options oai_parser_opt; // Necessary similarity of prompt for slot selection float slot_prompt_similarity = 0.0f; + int32_t cache_ram_n_min = 0; + float cache_ram_similarity = 0.5f; ~server_context() { if (ctx) { @@ -1340,6 +1588,7 @@ struct server_context { server_slot slot; slot.id = i; + slot.ctx = ctx; slot.n_ctx = n_ctx_slot; slot.n_predict = params.n_predict; slot.mctx = mctx; @@ -1412,6 +1661,21 @@ struct server_context { metrics.init(); + if (params.cache_ram_mib != 0) { + if (params.cache_ram_mib < 0) { + LLAMA_LOG_INFO("prompt cache is enabled, size limit: %s\n", "no limit"); + } + else { + LLAMA_LOG_INFO("prompt cache is enabled, size limit: %d MiB\n", params.cache_ram_mib); + } + LLAMA_LOG_INFO("%s", "use `--cache-ram 0` to disable the prompt cache\n"); + // only apply ram size limit. No token limit for now. + prompt_cache = std::make_unique(params.cache_ram_mib, 0); + } + else { + LLAMA_LOG_INFO("%s", "prompt cache is disabled - use `--cache-ram N` to enable it\n"); + } + // thinking is enabled if: // 1. It's not explicitly disabled (reasoning_budget == 0) // 2. The chat template supports it @@ -1483,11 +1747,12 @@ struct server_context { server_slot * get_available_slot(const server_task & task) { server_slot * ret = nullptr; + bool update_cache = false; // find the slot that has at least n% prompt similarity if (ret == nullptr && slot_prompt_similarity != 0.0f) { int max_lcp_len = 0; - float similarity = 0; + float sim_best = 0; for (server_slot & slot : slots) { // skip the slot if it is not available @@ -1499,23 +1764,22 @@ struct server_context { if (cache_tokens.empty()) { continue; } - // length of the Longest Common Prefix between the current slot's prompt and the input prompt - int lcp_len = cache_tokens.get_common_prefix(task.tokens); - // fraction of the common substring length compared to the current slot's prompt length - const float similarity = float(lcp_len) / task.tokens.size(); + size_t lcp_len = cache_tokens.get_common_prefix(task.tokens); + // fraction of the Longest Common Prefix length with respect to the input prompt and cached prompt length + const float sim_cur = get_slot_similarity(lcp_len, task.tokens.size(), cache_tokens.size()); // select the current slot if the criteria match - if (lcp_len > max_lcp_len && similarity > slot_prompt_similarity) { + if (sim_cur > sim_best && sim_cur > slot_prompt_similarity) { + sim_best = sim_cur; max_lcp_len = lcp_len; ret = &slot; } } - if (ret != nullptr) { LOG_VERBOSE("selected slot by lcp similarity", { {"id_slot", ret->id}, {"max_lcp_len", max_lcp_len}, - {"similarity", similarity}, + {"similarity", sim_best}, }); } } @@ -1528,7 +1792,6 @@ struct server_context { if (!slot.available()) { continue; } - // select the current slot if the criteria match if (slot.t_last_used < t_last) { t_last = slot.t_last_used; @@ -1543,7 +1806,46 @@ struct server_context { }); } } + if (ret) { + const auto& tokens = ret->cache_tokens; + float f_keep = 0.0f; + if (!tokens.empty()) { + size_t lcp_len = tokens.get_common_prefix(task.tokens); + f_keep = float(lcp_len) / tokens.size(); + // if we are about to lose a large portion of the existing context - save it in the prompt cache + if (f_keep < cache_ram_similarity) { + update_cache = true; + } + } + update_cache = update_cache && prompt_cache; + // cache prompts only for completion tasks + update_cache = update_cache && task.type == SERVER_TASK_TYPE_COMPLETION; + // don't update the cache if the slot's context is above cache_ram_n_min + update_cache = update_cache && tokens.size() >= cache_ram_n_min; + + // TODO: mtmd does not support prompt cache + update_cache = update_cache && (ret->mctx == nullptr); + + LLAMA_LOG_INFO("prompt cache: cache size: %d, cache_ram_n_min: %d, f_keep: %.2f, cache_ram_similarity: %.2f\n", + tokens.size(), cache_ram_n_min, f_keep, cache_ram_similarity); + if (update_cache) { + const int64_t t_start = ggml_time_us(); + LLAMA_LOG_INFO("updating prompt cache\n"); + ret->server_prompt.tokens = server_tokens(tokens.get_text_tokens(), false); // copy cache tokens + ret->prompt_save(*prompt_cache); + LLAMA_LOG_INFO("prompt cache save took %.2f ms\n", (ggml_time_us() - t_start) / 1000.0); + } + // has prompts saved earlier to load + if (!prompt_cache->states.empty()) { + const int64_t t_start = ggml_time_us(); + ret->server_prompt.tokens = server_tokens(tokens.get_text_tokens(), false); // copy cache tokens + ret->prompt_load(*prompt_cache, task.tokens); + prompt_cache->update(); + ret->cache_tokens = server_tokens(ret->server_prompt.tokens.get_text_tokens(), false); // recover cache tokens + LLAMA_LOG_INFO("prompt cache load took %.2f ms\n", (ggml_time_us() - t_start) / 1000.0); + } + } return ret; } @@ -3007,40 +3309,10 @@ struct server_context { slot.params.n_keep = slot.n_prompt_tokens; } slot.params.n_keep = std::min(slot.n_ctx - 4, slot.params.n_keep); - - if (slot.ga_n == 1 && slot.n_prompt_tokens >= slot.n_ctx) { - if (!params.ctx_shift) { - send_error(slot, "the request exceeds the available context size, try increasing it", ERROR_TYPE_SERVER); - slot.release(); - continue; - } - const int n_left = slot.n_ctx - slot.params.n_keep; - - const int n_block_size = n_left / 2; - const int erased_blocks = (slot.n_prompt_tokens - slot.params.n_keep - n_block_size) / n_block_size; - int n_keep = slot.params.n_keep; - int n_discard = erased_blocks * n_block_size; - llama_tokens new_tokens = prompt_tokens.get_text_tokens(); // copy - for (size_t i = n_keep + n_discard; i < new_tokens.size(); i++) { - new_tokens[i - n_discard] = new_tokens[i]; - } - new_tokens.resize(prompt_tokens.size() - n_discard); - prompt_tokens.clear(); - prompt_tokens.insert(new_tokens); - slot.truncated = true; - slot.n_prompt_tokens = prompt_tokens.size(); - - LOG_VERBOSE("input truncated", { - {"id_slot", slot.id}, - {"id_task", slot.id_task}, - {"n_ctx", slot.n_ctx}, - {"n_keep", slot.params.n_keep}, - {"n_left", n_left}, - {"n_prompt_tokens", slot.n_prompt_tokens}, - {"prompt_tokens", prompt_tokens.detokenize(ctx, true)}, - }); - - GGML_ASSERT(slot.n_prompt_tokens < slot.n_ctx); + if (slot.n_prompt_tokens >= slot.n_ctx) { + send_error(slot, "the request exceeds the available context size, try increasing it", ERROR_TYPE_SERVER); + slot.release(); + continue; } llama_sampling_reset(llama_get_model_vocab(model), slot.ctx_sampling); @@ -3881,6 +4153,8 @@ int main(int argc, char ** argv) { // Necessary similarity of prompt for slot selection ctx_server.slot_prompt_similarity = params.slot_prompt_similarity; + ctx_server.cache_ram_n_min = params.cache_ram_n_min; + ctx_server.cache_ram_similarity = params.cache_ram_similarity; #ifdef SQLITE3_MODERN_CPP_SUPPORT auto db_handle = std::make_shared(params.sql_save_file); bool sqlite_extension_loaded = false; diff --git a/examples/server/utils.hpp b/examples/server/utils.hpp index fbd67573..67f4f816 100644 --- a/examples/server/utils.hpp +++ b/examples/server/utils.hpp @@ -255,6 +255,10 @@ static std::string gen_tool_call_id() { // // other common utils // +static float get_slot_similarity(size_t lcp, size_t prompt_length, size_t cache_length) { + float sim = float(lcp) * 2 / (prompt_length + cache_length); + return sim; +} static size_t common_part(const std::vector & a, const std::vector & b) { size_t i; @@ -1026,7 +1030,7 @@ public: } } - server_tokens(std::vector& tokens, bool has_mtmd) : has_mtmd(has_mtmd), tokens(tokens) {} + server_tokens(const std::vector& tokens, bool has_mtmd) : has_mtmd(has_mtmd), tokens(tokens) {} llama_pos pos_next() const { if (!has_mtmd) { @@ -1068,9 +1072,7 @@ public: if (it != map_pos_to_media.end()) { return it->second; } - else { - throw std::runtime_error("Chunk not found"); - } + throw std::runtime_error("Chunk not found"); } void push_back(llama_token tok) {