diff --git a/common/common.cpp b/common/common.cpp index 789fcbd3..2ff5f204 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -1816,6 +1816,21 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa params.ctx_shift = false; return true; } + if (arg == "-cram" || arg == "--cache-ram") { + CHECK_ARG + params.cache_ram_mib = std::stoi(argv[i]); + return true; + } + if (arg == "-crs" || arg == "--cache-ram-similarity") { + CHECK_ARG + params.cache_ram_similarity = std::stof(argv[i]); + return true; + } + if (arg == "-cram-n-min" || arg == "--cache-ram-n-min") { + CHECK_ARG + params.cache_ram_n_min = std::stoi(argv[i]); + return true; + } if (arg == "--pos") { CHECK_ARG params.i_pos = std::stoi(argv[i]); @@ -1995,6 +2010,9 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param options.push_back({ "*", "-c, --ctx-size N", "size of the prompt context (default: %d, 0 = loaded from model)", params.n_ctx }); options.push_back({ "*", "-cd, --ctx-size-draft N", "size of the prompt context for the draft model (default: %d, 0 = loaded from model)", params.n_ctx_draft }); + options.push_back({ "*", "-cram, --cache-ram N", "set the maximum cache size in MiB (default: %d, -1 - no limit, 0 - disable)", params.cache_ram_mib }); + options.push_back({ "*", "-crs, --cache-ram-similarity N", "similarity threshold below which a slot's context is saved to the prompt cache (default: %.2f).", params.cache_ram_similarity }); + options.push_back({ "*", "-cram-n-min, --cache-ram-n-min N", "minimum number of tokens in the slot context required to save it to the prompt cache (default: %d).", params.cache_ram_n_min }); options.push_back({ "*", "-n, --predict N", "number of tokens to predict (default: %d, -1 = infinity, -2 = until context filled)", params.n_predict }); options.push_back({ "*", "-b, --batch-size N", "logical maximum batch size (default: %d)", params.n_batch }); options.push_back({ "*", "-ub, --ubatch-size N", "physical maximum batch size 
(default: %d)", params.n_ubatch }); @@ -2007,7 +2025,7 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param options.push_back({ "*", "-no-fmoe, --no-fused-moe", "disable fused MoE (default: %s)", params.fused_moe_up_gate ? "enabled" : "disabled" }); options.push_back({ "*", "-ger, --grouped-expert-routing", "enable grouped expert routing (default: %s)", params.grouped_expert_routing ? "enabled" : "disabled" }); options.push_back({ "*", "-no-fug, --no-fused-up-gate", "disaable fused up-gate (default: %s)", params.fused_up_gate ? "enabled" : "disabled" }); - options.push_back({ "*", "-no-mmad, --no-fused-mul-multiadd", "disaable fused mul-multi_add (default: %s)", params.fused_mmad? "enabled" : "disabled" }); + options.push_back({ "*", "-no-mmad, --no-fused-mul-multiadd", "disable fused mul-multi_add (default: %s)", params.fused_mmad? "enabled" : "disabled" }); options.push_back({ "*", "-rcache, --rope-cache", "enable RoPE cache (default: %s)", params.rope_cache ? "enabled" : "disabled" }); options.push_back({ "*", "-gr, --graph-reuse", "enable graph reuse (default: %s)", params.graph_reuse ? "enabled" : "disabled" }); options.push_back({ "*", "-ser, --smart-expert-reduction,","experts reduction (default: %d,%g)", params.min_experts, params.thresh_experts}); diff --git a/common/common.h b/common/common.h index f27eca97..8c07127e 100644 --- a/common/common.h +++ b/common/common.h @@ -330,7 +330,10 @@ struct gpt_params { std::string sql_save_file; std::string sqlite_zstd_ext_file; - float slot_prompt_similarity = 0.5f; + float slot_prompt_similarity = 0.1f; + int32_t cache_ram_mib = 8192; // -1 = no limit, 0 - disable, 1 = 1 MiB, etc. 
+ int32_t cache_ram_n_min = 0; // min number of tokens required to save a prompt in RAM + float cache_ram_similarity = 0.5f; // similarity of tokens to cached tokens // batched-bench params bool is_pp_shared = false; diff --git a/examples/server/server.cpp b/examples/server/server.cpp index e4f03470..a39c11ef 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -139,6 +139,7 @@ struct server_task { int id = -1; // to be filled by server_queue int id_multi = -1; int id_target = -1; + //int id_slot = -1; // used by SERVER_TASK_TYPE_INFERENCE server_tokens tokens; @@ -148,6 +149,10 @@ struct server_task { bool infill = false; bool embedding = false; + + server_task() = default; + server_task(server_task_type type) : type(type) {} + }; struct server_task_result { @@ -531,7 +536,7 @@ struct server_task_result { } }; -inline std::string stop_type_to_str(stop_type type) { +static inline std::string stop_type_to_str(stop_type type) { switch (type) { case STOP_TYPE_EOS: return "eos"; case STOP_TYPE_WORD: return "word"; @@ -579,6 +584,212 @@ struct slot_params { }; + + +struct server_prompt_checkpoint { + llama_pos pos_min; + llama_pos pos_max; + + std::vector<uint8_t> data; + + size_t size() const { + return data.size(); + } +}; + +struct server_prompt { + server_tokens tokens; + + std::vector<uint8_t> data; + + std::list<server_prompt_checkpoint> checkpoints; + + size_t size() const { + size_t res = data.size(); + + for (const auto& checkpoint : checkpoints) { + res += checkpoint.size(); + } + + return res; + } + + int n_tokens() const { + return tokens.size(); + } +}; + +struct server_prompt_cache { + server_prompt_cache(int32_t limit_size_mib, size_t limit_tokens) { + this->limit_size = 1024ull * 1024ull * (limit_size_mib < 0 ? 
0 : limit_size_mib); + this->limit_tokens = limit_tokens; + } + + std::list<server_prompt> states; + + // in bytes, 0 = no limit + size_t limit_size = 0; + + // in tokens, 0 = no limit + size_t limit_tokens = 0; + + size_t size() const { + size_t res = 0; + + for (const auto& state : states) { + res += state.size(); + } + + return res; + } + + size_t n_tokens() const { + size_t res = 0; + + for (const auto& state : states) { + res += state.n_tokens(); + } + + return res; + } + + server_prompt* alloc(const server_prompt& prompt, size_t state_size) { + for (auto it = states.begin(); it != states.end();) { + const size_t len = it->tokens.get_common_prefix(prompt.tokens); + + // first check if the current state is contained fully in the cache + if (len == prompt.tokens.size()) { + LLAMA_LOG_INFO("%s", " - prompt is already in the cache, skipping\n"); + return nullptr; + } + // next, remove any cached prompts that are fully contained in the current prompt + else if(len == it->tokens.size()) { + LLAMA_LOG_INFO(" - removing obsolete cached prompt with length %d\n", (int) len); + it = states.erase(it); + } + else { + ++it; + } + } + + std::vector<uint8_t> state_data; + + // check if we can allocate enough memory for the new state + try { + state_data.resize(state_size); + } + catch (const std::bad_alloc& e) { + LLAMA_LOG_INFO("failed to allocate memory for prompt cache state: %s\n", e.what()); + + limit_size = std::max<size_t>(1, 0.4 * size()); + + LLAMA_LOG_INFO(" - cache size limit reduced to %.3f MiB\n", limit_size / (1024.0 * 1024.0)); + + update(); + + return nullptr; + } + + // TODO: for some reason we can't copy server_tokens, so we have to do this workaround + auto& cur = states.emplace_back(); + cur = { + /*.tokens =*/ server_tokens(prompt.tokens.get_text_tokens(), false), + /*.data =*/ std::move(state_data), + /*.checkpoints =*/ prompt.checkpoints, + }; + + return &cur; + } + + bool load(server_prompt& prompt, const server_tokens& tokens_new, llama_context* ctx, int32_t id_slot) { + const int 
lcp_best = prompt.tokens.get_common_prefix(tokens_new); + + float f_keep_best = float(lcp_best) / prompt.tokens.size(); + //float sim_best = float(lcp_best) / tokens_new.size(); + float sim_best = get_slot_similarity(lcp_best, tokens_new.size(), prompt.tokens.size()); + + LLAMA_LOG_INFO(" - looking for better prompt, base f_keep = %.3f, sim = %.3f\n", f_keep_best, sim_best); + + auto it_best = states.end(); + + // find the most similar cached prompt, that would also preserve the most context + for (auto it = states.begin(); it != states.end(); ++it) { + const int lcp_cur = it->tokens.get_common_prefix(tokens_new); + + const float f_keep_cur = float(lcp_cur) / it->tokens.size(); + //const float sim_cur = float(lcp_cur) / tokens_new.size(); + const float sim_cur = get_slot_similarity(lcp_cur, tokens_new.size(), it->tokens.size()); + if (sim_best < sim_cur) { + f_keep_best = f_keep_cur; + sim_best = sim_cur; + it_best = it; + } + } + + if (it_best != states.end()) { + LLAMA_LOG_INFO(" - found better prompt with f_keep = %.3f, sim = %.3f\n", f_keep_best, sim_best); + const size_t size = it_best->data.size(); + const size_t n = llama_state_seq_set_data(ctx, it_best->data.data(), size, id_slot); + if (n != size) { + LLAMA_LOG_INFO("failed to restore state with size %zu\n", size); + return false; + } + + it_best->data.clear(); + it_best->data.shrink_to_fit(); + + prompt = std::move(*it_best); + + states.erase(it_best); + } + + return true; + } + + void update() { + if (limit_size > 0) { + // always keep at least one state, regardless of the limits + while (states.size() > 1 && size() > limit_size) { + if (states.empty()) { + break; + } + + LLAMA_LOG_INFO(" - cache size limit reached, removing oldest entry (size = %.3f MiB)\n", states.front().size() / (1024.0 * 1024.0)); + + states.pop_front(); + } + } + + // average size per token + const float size_per_token = std::max(1.0f, float(size()) / (std::max<size_t>(1, n_tokens()))); + + // dynamically increase the token limit if it can 
fit in the memory limit + const size_t limit_tokens_cur = limit_size > 0 ? std::max<size_t>(limit_tokens, limit_size / size_per_token) : limit_tokens; + + //if (limit_tokens > 0) { + // + // while (states.size() > 1 && n_tokens() > limit_tokens_cur) { + // if (states.empty()) { + // break; + // } + + // LLAMA_LOG_INFO(" - cache token limit (%zu, est: %zu) reached, removing oldest entry (size = %.3f MiB)\n", + // limit_tokens, limit_tokens_cur, states.front().size() / (1024.0 * 1024.0)); + + // states.pop_front(); + // } + //} + + LLAMA_LOG_INFO(" - cache state: %zu prompts, %.3f MiB (limits: %.3f MiB, %zu tokens, %zu est)\n", + states.size(), size() / (1024.0 * 1024.0), limit_size / (1024.0 * 1024.0), limit_tokens, limit_tokens_cur); + + for (const auto& state : states) { + LLAMA_LOG_INFO(" - prompt %p: %7d tokens, checkpoints: %2zu, %9.3f MiB\n", + (const void*)&state, state.n_tokens(), state.checkpoints.size(), state.size() / (1024.0 * 1024.0)); + } + } +}; + + struct server_slot { int id; int id_task = -1; @@ -589,9 +800,12 @@ struct server_slot { slot_state state = SLOT_STATE_IDLE; slot_command command = SLOT_COMMAND_NONE; + llama_context* ctx = nullptr; // used to determine the slot that has been used the longest int64_t t_last_used = -1; + std::unique_ptr<server_task> task; + // generation props int32_t n_ctx = 0; // context size per slot int32_t n_past = 0; @@ -627,6 +841,33 @@ struct server_slot { std::string oaicompat_model; std::string stopping_word; stop_type stop; + + server_prompt server_prompt; + + void prompt_save(server_prompt_cache & prompt_cache) const { + assert(server_prompt.data.size() == 0); + + const size_t cur_size = llama_state_seq_get_size(ctx, id); + + LLAMA_LOG_INFO(" - saving prompt with length %d, total state size = %.3f MiB\n", + (int)server_prompt.tokens.size(), cur_size / (1024.0 * 1024.0)); + + auto* cur = prompt_cache.alloc(server_prompt, cur_size); + if (cur == nullptr) { + return; + } + + llama_state_seq_get_data(ctx, cur->data.data(), cur_size, 
id); + } + + void prompt_load(server_prompt_cache& prompt_cache, const server_tokens& tokens) { + bool res = prompt_cache.load(server_prompt, tokens, ctx, id); + if (!res) { + LLAMA_LOG_INFO("failed to load prompt from cache\n"); + } + } + + // sampling llama_token sampled; struct llama_sampling_params sparams; @@ -689,6 +930,8 @@ struct server_slot { chat_msg = {}; json_schema = json(); generated_tool_call_ids.clear(); + + task.reset(); } bool has_budget(gpt_params &global_params) { @@ -726,6 +969,7 @@ struct server_slot { if (state == SLOT_STATE_PROCESSING) { t_token_generation = (ggml_time_us() - t_start_generation) / 1e3; command = SLOT_COMMAND_RELEASE; + task.reset(); } } @@ -1176,12 +1420,16 @@ struct server_context { server_queue queue_tasks; server_response queue_results; + std::unique_ptr<server_prompt_cache> prompt_cache; + server_metrics metrics; common_chat_templates_ptr chat_templates; oaicompat_parser_options oai_parser_opt; // Necessary similarity of prompt for slot selection float slot_prompt_similarity = 0.0f; + int32_t cache_ram_n_min = 0; + float cache_ram_similarity = 0.5f; ~server_context() { if (ctx) { @@ -1340,6 +1588,7 @@ struct server_context { server_slot slot; slot.id = i; + slot.ctx = ctx; slot.n_ctx = n_ctx_slot; slot.n_predict = params.n_predict; slot.mctx = mctx; @@ -1412,6 +1661,21 @@ struct server_context { metrics.init(); + if (params.cache_ram_mib != 0) { + if (params.cache_ram_mib < 0) { + LLAMA_LOG_INFO("prompt cache is enabled, size limit: %s\n", "no limit"); + } + else { + LLAMA_LOG_INFO("prompt cache is enabled, size limit: %d MiB\n", params.cache_ram_mib); + } + LLAMA_LOG_INFO("%s", "use `--cache-ram 0` to disable the prompt cache\n"); + // only apply ram size limit. No token limit for now. + prompt_cache = std::make_unique<server_prompt_cache>(params.cache_ram_mib, 0); + } + else { + LLAMA_LOG_INFO("%s", "prompt cache is disabled - use `--cache-ram N` to enable it\n"); + } + + // thinking is enabled if: // 1. It's not explicitly disabled (reasoning_budget == 0) // 2. 
The chat template supports it @@ -1483,11 +1747,12 @@ struct server_context { server_slot * get_available_slot(const server_task & task) { server_slot * ret = nullptr; + bool update_cache = false; // find the slot that has at least n% prompt similarity if (ret == nullptr && slot_prompt_similarity != 0.0f) { int max_lcp_len = 0; - float similarity = 0; + float sim_best = 0; for (server_slot & slot : slots) { // skip the slot if it is not available @@ -1499,23 +1764,22 @@ struct server_context { if (cache_tokens.empty()) { continue; } - // length of the Longest Common Prefix between the current slot's prompt and the input prompt - int lcp_len = cache_tokens.get_common_prefix(task.tokens); - // fraction of the common substring length compared to the current slot's prompt length - const float similarity = float(lcp_len) / task.tokens.size(); + size_t lcp_len = cache_tokens.get_common_prefix(task.tokens); + // fraction of the Longest Common Prefix length with respect to the input prompt and cached prompt length + const float sim_cur = get_slot_similarity(lcp_len, task.tokens.size(), cache_tokens.size()); // select the current slot if the criteria match - if (lcp_len > max_lcp_len && similarity > slot_prompt_similarity) { + if (sim_cur > sim_best && sim_cur > slot_prompt_similarity) { + sim_best = sim_cur; max_lcp_len = lcp_len; ret = &slot; } } - if (ret != nullptr) { LOG_VERBOSE("selected slot by lcp similarity", { {"id_slot", ret->id}, {"max_lcp_len", max_lcp_len}, - {"similarity", similarity}, + {"similarity", sim_best}, }); } } @@ -1528,7 +1792,6 @@ struct server_context { if (!slot.available()) { continue; } - // select the current slot if the criteria match if (slot.t_last_used < t_last) { t_last = slot.t_last_used; @@ -1543,7 +1806,46 @@ struct server_context { }); } } + if (ret) { + const auto& tokens = ret->cache_tokens; + float f_keep = 0.0f; + if (!tokens.empty()) { + size_t lcp_len = tokens.get_common_prefix(task.tokens); + f_keep = float(lcp_len) / 
tokens.size(); + // if we are about to lose a large portion of the existing context - save it in the prompt cache + if (f_keep < cache_ram_similarity) { + update_cache = true; + } + } + update_cache = update_cache && prompt_cache; + // cache prompts only for completion tasks + update_cache = update_cache && task.type == SERVER_TASK_TYPE_COMPLETION; + // don't update the cache if the slot's context is below cache_ram_n_min + update_cache = update_cache && tokens.size() >= cache_ram_n_min; + + // TODO: mtmd does not support prompt cache + update_cache = update_cache && (ret->mctx == nullptr); + + LLAMA_LOG_INFO("prompt cache: cache size: %d, cache_ram_n_min: %d, f_keep: %.2f, cache_ram_similarity: %.2f\n", + (int) tokens.size(), cache_ram_n_min, f_keep, cache_ram_similarity); + if (update_cache) { + const int64_t t_start = ggml_time_us(); + LLAMA_LOG_INFO("updating prompt cache\n"); + ret->server_prompt.tokens = server_tokens(tokens.get_text_tokens(), false); // copy cache tokens + ret->prompt_save(*prompt_cache); + LLAMA_LOG_INFO("prompt cache save took %.2f ms\n", (ggml_time_us() - t_start) / 1000.0); + } + // has prompts saved earlier to load (guard: prompt_cache is null when --cache-ram 0) + if (prompt_cache && !prompt_cache->states.empty()) { + const int64_t t_start = ggml_time_us(); + ret->server_prompt.tokens = server_tokens(tokens.get_text_tokens(), false); // copy cache tokens + ret->prompt_load(*prompt_cache, task.tokens); + prompt_cache->update(); + ret->cache_tokens = server_tokens(ret->server_prompt.tokens.get_text_tokens(), false); // recover cache tokens + LLAMA_LOG_INFO("prompt cache load took %.2f ms\n", (ggml_time_us() - t_start) / 1000.0); + } + } return ret; } @@ -3007,40 +3309,10 @@ struct server_context { slot.params.n_keep = slot.n_prompt_tokens; } slot.params.n_keep = std::min(slot.n_ctx - 4, slot.params.n_keep); - - if (slot.ga_n == 1 && slot.n_prompt_tokens >= slot.n_ctx) { - if (!params.ctx_shift) { - send_error(slot, "the request exceeds the available context size, try increasing it", ERROR_TYPE_SERVER); - 
slot.release(); - continue; - } - const int n_left = slot.n_ctx - slot.params.n_keep; - - const int n_block_size = n_left / 2; - const int erased_blocks = (slot.n_prompt_tokens - slot.params.n_keep - n_block_size) / n_block_size; - int n_keep = slot.params.n_keep; - int n_discard = erased_blocks * n_block_size; - llama_tokens new_tokens = prompt_tokens.get_text_tokens(); // copy - for (size_t i = n_keep + n_discard; i < new_tokens.size(); i++) { - new_tokens[i - n_discard] = new_tokens[i]; - } - new_tokens.resize(prompt_tokens.size() - n_discard); - prompt_tokens.clear(); - prompt_tokens.insert(new_tokens); - slot.truncated = true; - slot.n_prompt_tokens = prompt_tokens.size(); - - LOG_VERBOSE("input truncated", { - {"id_slot", slot.id}, - {"id_task", slot.id_task}, - {"n_ctx", slot.n_ctx}, - {"n_keep", slot.params.n_keep}, - {"n_left", n_left}, - {"n_prompt_tokens", slot.n_prompt_tokens}, - {"prompt_tokens", prompt_tokens.detokenize(ctx, true)}, - }); - - GGML_ASSERT(slot.n_prompt_tokens < slot.n_ctx); + if (slot.n_prompt_tokens >= slot.n_ctx) { + send_error(slot, "the request exceeds the available context size, try increasing it", ERROR_TYPE_SERVER); + slot.release(); + continue; } llama_sampling_reset(llama_get_model_vocab(model), slot.ctx_sampling); @@ -3881,6 +4153,8 @@ int main(int argc, char ** argv) { // Necessary similarity of prompt for slot selection ctx_server.slot_prompt_similarity = params.slot_prompt_similarity; + ctx_server.cache_ram_n_min = params.cache_ram_n_min; + ctx_server.cache_ram_similarity = params.cache_ram_similarity; #ifdef SQLITE3_MODERN_CPP_SUPPORT auto db_handle = std::make_shared(params.sql_save_file); bool sqlite_extension_loaded = false; diff --git a/examples/server/utils.hpp b/examples/server/utils.hpp index fbd67573..67f4f816 100644 --- a/examples/server/utils.hpp +++ b/examples/server/utils.hpp @@ -255,6 +255,10 @@ static std::string gen_tool_call_id() { // // other common utils // +static float get_slot_similarity(size_t lcp, 
size_t prompt_length, size_t cache_length) { + float sim = float(lcp) * 2 / (prompt_length + cache_length); + return sim; +} static size_t common_part(const std::vector<llama_token> & a, const std::vector<llama_token> & b) { size_t i; @@ -1026,7 +1030,7 @@ public: } } - server_tokens(std::vector<llama_token>& tokens, bool has_mtmd) : has_mtmd(has_mtmd), tokens(tokens) {} + server_tokens(const std::vector<llama_token>& tokens, bool has_mtmd) : has_mtmd(has_mtmd), tokens(tokens) {} llama_pos pos_next() const { if (!has_mtmd) { @@ -1068,9 +1072,7 @@ public: if (it != map_pos_to_media.end()) { return it->second; } - else { - throw std::runtime_error("Chunk not found"); - } + throw std::runtime_error("Chunk not found"); } void push_back(llama_token tok) {