server: cache prompt to host memory (#954)

* server : host-memory prompt caching

Change similarity calculation and prompt save conditions

Remove unneeded token limit

Rename variable

Separate prompt save and load logic

Change default values

Update log messages

Remove prompt truncation logic

* add description

* bug fixes

* remove token limit in init

---------

Co-authored-by: firecoperana <firecoperana>
Authored by firecoperana on 2025-11-14 16:40:13 +00:00, committed by GitHub
parent 00dffb5e68
commit bb358223cd
4 changed files with 347 additions and 50 deletions
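Note: the new selection and caching logic in this diff calls a helper get_slot_similarity(lcp_len, n_prompt_new, n_prompt_cached), whose definition is outside the hunks shown here. A minimal sketch only, assuming the helper normalizes the common prefix by the longer of the two prompt lengths (the committed definition may differ):

#include <algorithm>
#include <cstddef>

// Hypothetical sketch, not the committed implementation: score the common prefix
// against the longer of the new and cached prompt lengths, so a short cached prompt
// cannot look highly similar to a much longer request.
static float get_slot_similarity(size_t lcp_len, size_t n_new, size_t n_cached) {
    const size_t denom = std::max<size_t>(1, std::max(n_new, n_cached));
    return float(lcp_len) / denom; // in [0, 1]
}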


@@ -139,6 +139,7 @@ struct server_task {
int id = -1; // to be filled by server_queue
int id_multi = -1;
int id_target = -1;
//int id_slot = -1;
// used by SERVER_TASK_TYPE_INFERENCE
server_tokens tokens;
@@ -148,6 +149,10 @@ struct server_task {
bool infill = false;
bool embedding = false;
server_task() = default;
server_task(server_task_type type) : type(type) {}
};
struct server_task_result {
@@ -531,7 +536,7 @@ struct server_task_result {
}
};
-inline std::string stop_type_to_str(stop_type type) {
+static inline std::string stop_type_to_str(stop_type type) {
switch (type) {
case STOP_TYPE_EOS: return "eos";
case STOP_TYPE_WORD: return "word";
@@ -579,6 +584,212 @@ struct slot_params {
};
struct server_prompt_checkpoint {
llama_pos pos_min;
llama_pos pos_max;
std::vector<uint8_t> data;
size_t size() const {
return data.size();
}
};
struct server_prompt {
server_tokens tokens;
std::vector<uint8_t> data;
std::list<server_prompt_checkpoint> checkpoints;
size_t size() const {
size_t res = data.size();
for (const auto& checkpoint : checkpoints) {
res += checkpoint.size();
}
return res;
}
int n_tokens() const {
return tokens.size();
}
};
struct server_prompt_cache {
server_prompt_cache(int32_t limit_size_mib, size_t limit_tokens) {
this->limit_size = 1024ull * 1024ull * (limit_size_mib < 0 ? 0 : limit_size_mib);
this->limit_tokens = limit_tokens;
}
std::list<server_prompt> states;
// in bytes, 0 = no limit
size_t limit_size = 0;
// in tokens, 0 = no limit
size_t limit_tokens = 0;
size_t size() const {
size_t res = 0;
for (const auto& state : states) {
res += state.size();
}
return res;
}
size_t n_tokens() const {
size_t res = 0;
for (const auto& state : states) {
res += state.n_tokens();
}
return res;
}
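// insert a new cache entry for `prompt`, evicting cached prompts that are prefixes of it;
// returns nullptr if the prompt is already fully cached or the state buffer cannot be
// allocated (the caller is expected to fill the returned entry's data buffer)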
server_prompt* alloc(const server_prompt& prompt, size_t state_size) {
for (auto it = states.begin(); it != states.end();) {
const size_t len = it->tokens.get_common_prefix(prompt.tokens);
// first check if the current state is contained fully in the cache
if (len == prompt.tokens.size()) {
LLAMA_LOG_INFO("%s", " - prompt is already in the cache, skipping\n");
return nullptr;
}
// next, remove any cached prompts that are fully contained in the current prompt
else if (len == it->tokens.size()) {
LLAMA_LOG_INFO(" - removing obsolete cached prompt with length %zu\n", len);
it = states.erase(it);
}
else {
++it;
}
}
std::vector<uint8_t> state_data;
// check if we can allocate enough memory for the new state
try {
state_data.resize(state_size);
}
catch (const std::bad_alloc& e) {
LLAMA_LOG_INFO("failed to allocate memory for prompt cache state: %s\n", e.what());
limit_size = std::max<size_t>(1, 0.4 * size());
LLAMA_LOG_INFO(" - cache size limit reduced to %.3f MiB\n", limit_size / (1024.0 * 1024.0));
update();
return nullptr;
}
// TODO: for some reason we can't copy server_tokens, so we have to do this workaround
auto& cur = states.emplace_back();
cur = {
/*.tokens =*/ server_tokens(prompt.tokens.get_text_tokens(), false),
/*.data =*/ std::move(state_data),
/*.checkpoints =*/ prompt.checkpoints,
};
return &cur;
}
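// restore the cached prompt that best matches tokens_new into sequence id_slot,
// moving it out of the cache; returns false only if restoring the state data fails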
bool load(server_prompt& prompt, const server_tokens& tokens_new, llama_context* ctx, int32_t id_slot) {
const int lcp_best = prompt.tokens.get_common_prefix(tokens_new);
float f_keep_best = float(lcp_best) / prompt.tokens.size();
//float sim_best = float(lcp_best) / tokens_new.size();
float sim_best = get_slot_similarity(lcp_best, tokens_new.size(), prompt.tokens.size());
LLAMA_LOG_INFO(" - looking for better prompt, base f_keep = %.3f, sim = %.3f\n", f_keep_best, sim_best);
auto it_best = states.end();
// find the most similar cached prompt, that would also preserve the most context
for (auto it = states.begin(); it != states.end(); ++it) {
const int lcp_cur = it->tokens.get_common_prefix(tokens_new);
const float f_keep_cur = float(lcp_cur) / it->tokens.size();
//const float sim_cur = float(lcp_cur) / tokens_new.size();
const float sim_cur = get_slot_similarity(lcp_cur, tokens_new.size(), it->tokens.size());
if (sim_best < sim_cur) {
f_keep_best = f_keep_cur;
sim_best = sim_cur;
it_best = it;
}
}
if (it_best != states.end()) {
LLAMA_LOG_INFO(" - found better prompt with f_keep = %.3f, sim = %.3f\n", f_keep_best, sim_best);
const size_t size = it_best->data.size();
const size_t n = llama_state_seq_set_data(ctx, it_best->data.data(), size, id_slot);
if (n != size) {
LLAMA_LOG_INFO("failed to restore state with size %zu\n", size);
return false;
}
it_best->data.clear();
it_best->data.shrink_to_fit();
prompt = std::move(*it_best);
states.erase(it_best);
}
return true;
}
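// evict the oldest entries while the cache exceeds its byte limit (always keeping
// at least one) and log the resulting cache contents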
void update() {
if (limit_size > 0) {
// always keep at least one state, regardless of the limits
while (states.size() > 1 && size() > limit_size) {
if (states.empty()) {
break;
}
LLAMA_LOG_INFO(" - cache size limit reached, removing oldest entry (size = %.3f MiB)\n", states.front().size() / (1024.0 * 1024.0));
states.pop_front();
}
}
// average size per token
const float size_per_token = std::max<float>(1.0f, float(size()) / (std::max<size_t>(1, n_tokens())));
// dynamically increase the token limit if it can fit in the memory limit
const size_t limit_tokens_cur = limit_size > 0 ? std::max<size_t>(limit_tokens, limit_size / size_per_token) : limit_tokens;
//if (limit_tokens > 0) {
//
// while (states.size() > 1 && n_tokens() > limit_tokens_cur) {
// if (states.empty()) {
// break;
// }
// LLAMA_LOG_INFO(" - cache token limit (%zu, est: %zu) reached, removing oldest entry (size = %.3f MiB)\n",
// limit_tokens, limit_tokens_cur, states.front().size() / (1024.0 * 1024.0));
// states.pop_front();
// }
//}
LLAMA_LOG_INFO(" - cache state: %zu prompts, %.3f MiB (limits: %.3f MiB, %zu tokens, %zu est)\n",
states.size(), size() / (1024.0 * 1024.0), limit_size / (1024.0 * 1024.0), limit_tokens, limit_tokens_cur);
for (const auto& state : states) {
LLAMA_LOG_INFO(" - prompt %p: %7d tokens, checkpoints: %2zu, %9.3f MiB\n",
(const void*)&state, state.n_tokens(), state.checkpoints.size(), state.size() / (1024.0 * 1024.0));
}
}
};
struct server_slot {
int id;
int id_task = -1;
@@ -589,9 +800,12 @@ struct server_slot {
slot_state state = SLOT_STATE_IDLE;
slot_command command = SLOT_COMMAND_NONE;
llama_context* ctx = nullptr;
// used to determine the slot that has been used the longest
int64_t t_last_used = -1;
std::unique_ptr<const server_task> task;
// generation props
int32_t n_ctx = 0; // context size per slot
int32_t n_past = 0;
@@ -627,6 +841,33 @@ struct server_slot {
std::string oaicompat_model;
std::string stopping_word;
stop_type stop;
server_prompt server_prompt;
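// serialize this slot's sequence state into a new prompt-cache entry;
// does nothing further if the prompt is already cached or the allocation fails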
void prompt_save(server_prompt_cache & prompt_cache) const {
assert(server_prompt.data.size() == 0);
const size_t cur_size = llama_state_seq_get_size(ctx, id);
LLAMA_LOG_INFO(" - saving prompt with length %d, total state size = %.3f MiB\n",
(int)server_prompt.tokens.size(), cur_size / (1024.0 * 1024.0));
auto* cur = prompt_cache.alloc(server_prompt, cur_size);
if (cur == nullptr) {
return;
}
llama_state_seq_get_data(ctx, cur->data.data(), cur_size, id);
}
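// try to replace this slot's saved prompt (and its context sequence state)
// with the best-matching cached prompt for the upcoming tokens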
void prompt_load(server_prompt_cache& prompt_cache, const server_tokens& tokens) {
bool res = prompt_cache.load(server_prompt, tokens, ctx, id);
if (!res) {
LLAMA_LOG_INFO("failed to load prompt from cache\n");
}
}
// sampling
llama_token sampled;
struct llama_sampling_params sparams;
@@ -689,6 +930,8 @@ struct server_slot {
chat_msg = {};
json_schema = json();
generated_tool_call_ids.clear();
task.reset();
}
bool has_budget(gpt_params &global_params) {
@@ -726,6 +969,7 @@ struct server_slot {
if (state == SLOT_STATE_PROCESSING) {
t_token_generation = (ggml_time_us() - t_start_generation) / 1e3;
command = SLOT_COMMAND_RELEASE;
task.reset();
}
}
@@ -1176,12 +1420,16 @@ struct server_context {
server_queue queue_tasks;
server_response queue_results;
std::unique_ptr<server_prompt_cache> prompt_cache;
server_metrics metrics;
common_chat_templates_ptr chat_templates;
oaicompat_parser_options oai_parser_opt;
// Necessary similarity of prompt for slot selection
float slot_prompt_similarity = 0.0f;
int32_t cache_ram_n_min = 0;
float cache_ram_similarity = 0.5f;
~server_context() {
if (ctx) {
@@ -1340,6 +1588,7 @@ struct server_context {
server_slot slot;
slot.id = i;
slot.ctx = ctx;
slot.n_ctx = n_ctx_slot;
slot.n_predict = params.n_predict;
slot.mctx = mctx;
@@ -1412,6 +1661,21 @@ struct server_context {
metrics.init();
if (params.cache_ram_mib != 0) {
if (params.cache_ram_mib < 0) {
LLAMA_LOG_INFO("prompt cache is enabled, size limit: %s\n", "no limit");
}
else {
LLAMA_LOG_INFO("prompt cache is enabled, size limit: %d MiB\n", params.cache_ram_mib);
}
LLAMA_LOG_INFO("%s", "use `--cache-ram 0` to disable the prompt cache\n");
// only apply ram size limit. No token limit for now.
prompt_cache = std::make_unique<server_prompt_cache>(params.cache_ram_mib, 0);
}
else {
LLAMA_LOG_INFO("%s", "prompt cache is disabled - use `--cache-ram N` to enable it\n");
}
// thinking is enabled if:
// 1. It's not explicitly disabled (reasoning_budget == 0)
// 2. The chat template supports it
@@ -1483,11 +1747,12 @@ struct server_context {
server_slot * get_available_slot(const server_task & task) {
server_slot * ret = nullptr;
bool update_cache = false;
// find the slot that has at least n% prompt similarity
if (ret == nullptr && slot_prompt_similarity != 0.0f) {
int max_lcp_len = 0;
-float similarity = 0;
+float sim_best = 0;
for (server_slot & slot : slots) {
// skip the slot if it is not available
@@ -1499,23 +1764,22 @@ struct server_context {
if (cache_tokens.empty()) {
continue;
}
// length of the Longest Common Prefix between the current slot's prompt and the input prompt
-int lcp_len = cache_tokens.get_common_prefix(task.tokens);
-// fraction of the common substring length compared to the current slot's prompt length
-const float similarity = float(lcp_len) / task.tokens.size();
+size_t lcp_len = cache_tokens.get_common_prefix(task.tokens);
+// fraction of the Longest Common Prefix length with respect to the input prompt and cached prompt length
+const float sim_cur = get_slot_similarity(lcp_len, task.tokens.size(), cache_tokens.size());
// select the current slot if the criteria match
-if (lcp_len > max_lcp_len && similarity > slot_prompt_similarity) {
+if (sim_cur > sim_best && sim_cur > slot_prompt_similarity) {
+sim_best = sim_cur;
max_lcp_len = lcp_len;
ret = &slot;
}
}
if (ret != nullptr) {
LOG_VERBOSE("selected slot by lcp similarity", {
{"id_slot", ret->id},
{"max_lcp_len", max_lcp_len},
{"similarity", similarity},
{"similarity", sim_best},
});
}
}
@@ -1528,7 +1792,6 @@ struct server_context {
if (!slot.available()) {
continue;
}
// select the current slot if the criteria match
if (slot.t_last_used < t_last) {
t_last = slot.t_last_used;
@@ -1543,7 +1806,46 @@ struct server_context {
});
}
}
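// before handing the selected slot to the new task, decide whether its current prompt
// should be saved to the host-memory cache and whether a better cached prompt should be restored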
if (ret) {
const auto& tokens = ret->cache_tokens;
float f_keep = 0.0f;
if (!tokens.empty()) {
size_t lcp_len = tokens.get_common_prefix(task.tokens);
f_keep = float(lcp_len) / tokens.size();
// if we are about to lose a large portion of the existing context - save it in the prompt cache
if (f_keep < cache_ram_similarity) {
update_cache = true;
}
}
update_cache = update_cache && prompt_cache;
// cache prompts only for completion tasks
update_cache = update_cache && task.type == SERVER_TASK_TYPE_COMPLETION;
// only cache the prompt if the slot's context has at least cache_ram_n_min tokens
update_cache = update_cache && tokens.size() >= cache_ram_n_min;
// TODO: mtmd does not support prompt cache
update_cache = update_cache && (ret->mctx == nullptr);
LLAMA_LOG_INFO("prompt cache: cache size: %d, cache_ram_n_min: %d, f_keep: %.2f, cache_ram_similarity: %.2f\n",
tokens.size(), cache_ram_n_min, f_keep, cache_ram_similarity);
if (update_cache) {
const int64_t t_start = ggml_time_us();
LLAMA_LOG_INFO("updating prompt cache\n");
ret->server_prompt.tokens = server_tokens(tokens.get_text_tokens(), false); // copy cache tokens
ret->prompt_save(*prompt_cache);
LLAMA_LOG_INFO("prompt cache save took %.2f ms\n", (ggml_time_us() - t_start) / 1000.0);
}
// restore a previously saved prompt if one is available
if (prompt_cache && !prompt_cache->states.empty()) {
const int64_t t_start = ggml_time_us();
ret->server_prompt.tokens = server_tokens(tokens.get_text_tokens(), false); // copy cache tokens
ret->prompt_load(*prompt_cache, task.tokens);
prompt_cache->update();
ret->cache_tokens = server_tokens(ret->server_prompt.tokens.get_text_tokens(), false); // recover cache tokens
LLAMA_LOG_INFO("prompt cache load took %.2f ms\n", (ggml_time_us() - t_start) / 1000.0);
}
}
return ret;
}
@@ -3007,40 +3309,10 @@ struct server_context {
slot.params.n_keep = slot.n_prompt_tokens;
}
slot.params.n_keep = std::min(slot.n_ctx - 4, slot.params.n_keep);
-if (slot.ga_n == 1 && slot.n_prompt_tokens >= slot.n_ctx) {
-if (!params.ctx_shift) {
-send_error(slot, "the request exceeds the available context size, try increasing it", ERROR_TYPE_SERVER);
-slot.release();
-continue;
-}
-const int n_left = slot.n_ctx - slot.params.n_keep;
-const int n_block_size = n_left / 2;
-const int erased_blocks = (slot.n_prompt_tokens - slot.params.n_keep - n_block_size) / n_block_size;
-int n_keep = slot.params.n_keep;
-int n_discard = erased_blocks * n_block_size;
-llama_tokens new_tokens = prompt_tokens.get_text_tokens(); // copy
-for (size_t i = n_keep + n_discard; i < new_tokens.size(); i++) {
-new_tokens[i - n_discard] = new_tokens[i];
-}
-new_tokens.resize(prompt_tokens.size() - n_discard);
-prompt_tokens.clear();
-prompt_tokens.insert(new_tokens);
-slot.truncated = true;
-slot.n_prompt_tokens = prompt_tokens.size();
-LOG_VERBOSE("input truncated", {
-{"id_slot", slot.id},
-{"id_task", slot.id_task},
-{"n_ctx", slot.n_ctx},
-{"n_keep", slot.params.n_keep},
-{"n_left", n_left},
-{"n_prompt_tokens", slot.n_prompt_tokens},
-{"prompt_tokens", prompt_tokens.detokenize(ctx, true)},
-});
-GGML_ASSERT(slot.n_prompt_tokens < slot.n_ctx);
+if (slot.n_prompt_tokens >= slot.n_ctx) {
+send_error(slot, "the request exceeds the available context size, try increasing it", ERROR_TYPE_SERVER);
+slot.release();
+continue;
+}
llama_sampling_reset(llama_get_model_vocab(model), slot.ctx_sampling);
@@ -3881,6 +4153,8 @@ int main(int argc, char ** argv) {
// Necessary similarity of prompt for slot selection
ctx_server.slot_prompt_similarity = params.slot_prompt_similarity;
ctx_server.cache_ram_n_min = params.cache_ram_n_min;
ctx_server.cache_ram_similarity = params.cache_ram_similarity;
#ifdef SQLITE3_MODERN_CPP_SUPPORT
auto db_handle = std::make_shared<DatabaseHandle>(params.sql_save_file);
bool sqlite_extension_loaded = false;