mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-02-27 08:34:09 +00:00
server: cache prompt to host memory (#954)
* server : host-memory prompt caching change similarity calculation and prompt save conditions Remove unneeded token limit rename variable Separate prompt save and load logic change default values change log remove truncate prompt logic * add description * bug fixes * remove token limit in init --------- Co-authored-by: firecoperana <firecoperana>
This commit is contained in:
@@ -139,6 +139,7 @@ struct server_task {
|
||||
int id = -1; // to be filled by server_queue
|
||||
int id_multi = -1;
|
||||
int id_target = -1;
|
||||
//int id_slot = -1;
|
||||
|
||||
// used by SERVER_TASK_TYPE_INFERENCE
|
||||
server_tokens tokens;
|
||||
@@ -148,6 +149,10 @@ struct server_task {
|
||||
|
||||
bool infill = false;
|
||||
bool embedding = false;
|
||||
|
||||
server_task() = default;
|
||||
server_task(server_task_type type) : type(type) {}
|
||||
|
||||
};
|
||||
|
||||
struct server_task_result {
|
||||
@@ -531,7 +536,7 @@ struct server_task_result {
|
||||
}
|
||||
};
|
||||
|
||||
inline std::string stop_type_to_str(stop_type type) {
|
||||
static inline std::string stop_type_to_str(stop_type type) {
|
||||
switch (type) {
|
||||
case STOP_TYPE_EOS: return "eos";
|
||||
case STOP_TYPE_WORD: return "word";
|
||||
@@ -579,6 +584,212 @@ struct slot_params {
|
||||
|
||||
};
|
||||
|
||||
|
||||
struct server_prompt_checkpoint {
|
||||
llama_pos pos_min;
|
||||
llama_pos pos_max;
|
||||
|
||||
std::vector<uint8_t> data;
|
||||
|
||||
size_t size() const {
|
||||
return data.size();
|
||||
}
|
||||
};
|
||||
|
||||
struct server_prompt {
|
||||
server_tokens tokens;
|
||||
|
||||
std::vector<uint8_t> data;
|
||||
|
||||
std::list<server_prompt_checkpoint> checkpoints;
|
||||
|
||||
size_t size() const {
|
||||
size_t res = data.size();
|
||||
|
||||
for (const auto& checkpoint : checkpoints) {
|
||||
res += checkpoint.size();
|
||||
}
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
int n_tokens() const {
|
||||
return tokens.size();
|
||||
}
|
||||
};
|
||||
|
||||
struct server_prompt_cache {
|
||||
server_prompt_cache(int32_t limit_size_mib, size_t limit_tokens) {
|
||||
this->limit_size = 1024ull * 1024ull * (limit_size_mib < 0 ? 0 : limit_size_mib);
|
||||
this->limit_tokens = limit_tokens;
|
||||
}
|
||||
|
||||
std::list<server_prompt> states;
|
||||
|
||||
// in bytes, 0 = no limit
|
||||
size_t limit_size = 0;
|
||||
|
||||
// in tokens, 0 = no limit
|
||||
size_t limit_tokens = 0;
|
||||
|
||||
size_t size() const {
|
||||
size_t res = 0;
|
||||
|
||||
for (const auto& state : states) {
|
||||
res += state.size();
|
||||
}
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
size_t n_tokens() const {
|
||||
size_t res = 0;
|
||||
|
||||
for (const auto& state : states) {
|
||||
res += state.n_tokens();
|
||||
}
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
server_prompt* alloc(const server_prompt& prompt, size_t state_size) {
|
||||
for (auto it = states.begin(); it != states.end();) {
|
||||
const size_t len = it->tokens.get_common_prefix(prompt.tokens);
|
||||
|
||||
// first check if the current state is contained fully in the cache
|
||||
if (len == prompt.tokens.size()) {
|
||||
LLAMA_LOG_INFO("%s", " - prompt is already in the cache, skipping\n");
|
||||
return nullptr;
|
||||
}
|
||||
// next, remove any cached prompts that are fully contained in the current prompt
|
||||
else if(len == it->tokens.size()) {
|
||||
LLAMA_LOG_INFO(" - removing obsolete cached prompt with length %d\n", len);
|
||||
it = states.erase(it);
|
||||
}
|
||||
else {
|
||||
++it;
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<uint8_t> state_data;
|
||||
|
||||
// check if we can allocate enough memory for the new state
|
||||
try {
|
||||
state_data.resize(state_size);
|
||||
}
|
||||
catch (const std::bad_alloc& e) {
|
||||
LLAMA_LOG_INFO("failed to allocate memory for prompt cache state: %s\n", e.what());
|
||||
|
||||
limit_size = std::max<size_t>(1, 0.4 * size());
|
||||
|
||||
LLAMA_LOG_INFO(" - cache size limit reduced to %.3f MiB\n", limit_size / (1024.0 * 1024.0));
|
||||
|
||||
update();
|
||||
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// TODO: for some reason we can't copy server_tokens, so we have to do this workaround
|
||||
auto& cur = states.emplace_back();
|
||||
cur = {
|
||||
/*.tokens =*/ server_tokens(prompt.tokens.get_text_tokens(), false),
|
||||
/*.data =*/ std::move(state_data),
|
||||
/*.checkpoints =*/ prompt.checkpoints,
|
||||
};
|
||||
|
||||
return &cur;
|
||||
}
|
||||
|
||||
bool load(server_prompt& prompt, const server_tokens& tokens_new, llama_context* ctx, int32_t id_slot) {
|
||||
const int lcp_best = prompt.tokens.get_common_prefix(tokens_new);
|
||||
|
||||
float f_keep_best = float(lcp_best) / prompt.tokens.size();
|
||||
//float sim_best = float(lcp_best) / tokens_new.size();
|
||||
float sim_best = get_slot_similarity(lcp_best, tokens_new.size(), prompt.tokens.size());
|
||||
|
||||
LLAMA_LOG_INFO(" - looking for better prompt, base f_keep = %.3f, sim = %.3f\n", f_keep_best, sim_best);
|
||||
|
||||
auto it_best = states.end();
|
||||
|
||||
// find the most similar cached prompt, that would also preserve the most context
|
||||
for (auto it = states.begin(); it != states.end(); ++it) {
|
||||
const int lcp_cur = it->tokens.get_common_prefix(tokens_new);
|
||||
|
||||
const float f_keep_cur = float(lcp_cur) / it->tokens.size();
|
||||
//const float sim_cur = float(lcp_cur) / tokens_new.size();
|
||||
const float sim_cur = get_slot_similarity(lcp_cur, tokens_new.size(), it->tokens.size());
|
||||
if (sim_best < sim_cur) {
|
||||
f_keep_best = f_keep_cur;
|
||||
sim_best = sim_cur;
|
||||
it_best = it;
|
||||
}
|
||||
}
|
||||
|
||||
if (it_best != states.end()) {
|
||||
LLAMA_LOG_INFO(" - found better prompt with f_keep = %.3f, sim = %.3f\n", f_keep_best, sim_best);
|
||||
const size_t size = it_best->data.size();
|
||||
const size_t n = llama_state_seq_set_data(ctx, it_best->data.data(), size, id_slot);
|
||||
if (n != size) {
|
||||
LLAMA_LOG_INFO("failed to restore state with size %zu\n", size);
|
||||
return false;
|
||||
}
|
||||
|
||||
it_best->data.clear();
|
||||
it_best->data.shrink_to_fit();
|
||||
|
||||
prompt = std::move(*it_best);
|
||||
|
||||
states.erase(it_best);
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
void update() {
|
||||
if (limit_size > 0) {
|
||||
// always keep at least one state, regardless of the limits
|
||||
while (states.size() > 1 && size() > limit_size) {
|
||||
if (states.empty()) {
|
||||
break;
|
||||
}
|
||||
|
||||
LLAMA_LOG_INFO(" - cache size limit reached, removing oldest entry (size = %.3f MiB)\n", states.front().size() / (1024.0 * 1024.0));
|
||||
|
||||
states.pop_front();
|
||||
}
|
||||
}
|
||||
|
||||
// average size per token
|
||||
const float size_per_token = std::max<float>(1.0f, float(size()) / (std::max<size_t>(1, n_tokens())));
|
||||
|
||||
// dynamically increase the token limit if it can fit in the memory limit
|
||||
const size_t limit_tokens_cur = limit_size > 0 ? std::max<size_t>(limit_tokens, limit_size / size_per_token) : limit_tokens;
|
||||
|
||||
//if (limit_tokens > 0) {
|
||||
//
|
||||
// while (states.size() > 1 && n_tokens() > limit_tokens_cur) {
|
||||
// if (states.empty()) {
|
||||
// break;
|
||||
// }
|
||||
|
||||
// LLAMA_LOG_INFO(" - cache token limit (%zu, est: %zu) reached, removing oldest entry (size = %.3f MiB)\n",
|
||||
// limit_tokens, limit_tokens_cur, states.front().size() / (1024.0 * 1024.0));
|
||||
|
||||
// states.pop_front();
|
||||
// }
|
||||
//}
|
||||
|
||||
LLAMA_LOG_INFO(" - cache state: %zu prompts, %.3f MiB (limits: %.3f MiB, %zu tokens, %zu est)\n",
|
||||
states.size(), size() / (1024.0 * 1024.0), limit_size / (1024.0 * 1024.0), limit_tokens, limit_tokens_cur);
|
||||
|
||||
for (const auto& state : states) {
|
||||
LLAMA_LOG_INFO(" - prompt %p: %7d tokens, checkpoints: %2zu, %9.3f MiB\n",
|
||||
(const void*)&state, state.n_tokens(), state.checkpoints.size(), state.size() / (1024.0 * 1024.0));
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
struct server_slot {
|
||||
int id;
|
||||
int id_task = -1;
|
||||
@@ -589,9 +800,12 @@ struct server_slot {
|
||||
slot_state state = SLOT_STATE_IDLE;
|
||||
slot_command command = SLOT_COMMAND_NONE;
|
||||
|
||||
llama_context* ctx = nullptr;
|
||||
// used to determine the slot that has been used the longest
|
||||
int64_t t_last_used = -1;
|
||||
|
||||
std::unique_ptr<const server_task> task;
|
||||
|
||||
// generation props
|
||||
int32_t n_ctx = 0; // context size per slot
|
||||
int32_t n_past = 0;
|
||||
@@ -627,6 +841,33 @@ struct server_slot {
|
||||
std::string oaicompat_model;
|
||||
std::string stopping_word;
|
||||
stop_type stop;
|
||||
|
||||
server_prompt server_prompt;
|
||||
|
||||
void prompt_save(server_prompt_cache & prompt_cache) const {
|
||||
assert(server_prompt.data.size() == 0);
|
||||
|
||||
const size_t cur_size = llama_state_seq_get_size(ctx, id);
|
||||
|
||||
LLAMA_LOG_INFO(" - saving prompt with length %d, total state size = %.3f MiB\n",
|
||||
(int)server_prompt.tokens.size(), cur_size / (1024.0 * 1024.0));
|
||||
|
||||
auto* cur = prompt_cache.alloc(server_prompt, cur_size);
|
||||
if (cur == nullptr) {
|
||||
return;
|
||||
}
|
||||
|
||||
llama_state_seq_get_data(ctx, cur->data.data(), cur_size, id);
|
||||
}
|
||||
|
||||
void prompt_load(server_prompt_cache& prompt_cache, const server_tokens& tokens) {
|
||||
bool res = prompt_cache.load(server_prompt, tokens, ctx, id);
|
||||
if (!res) {
|
||||
LLAMA_LOG_INFO("failed to load prompt from cache\n");
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// sampling
|
||||
llama_token sampled;
|
||||
struct llama_sampling_params sparams;
|
||||
@@ -689,6 +930,8 @@ struct server_slot {
|
||||
chat_msg = {};
|
||||
json_schema = json();
|
||||
generated_tool_call_ids.clear();
|
||||
|
||||
task.reset();
|
||||
}
|
||||
|
||||
bool has_budget(gpt_params &global_params) {
|
||||
@@ -726,6 +969,7 @@ struct server_slot {
|
||||
if (state == SLOT_STATE_PROCESSING) {
|
||||
t_token_generation = (ggml_time_us() - t_start_generation) / 1e3;
|
||||
command = SLOT_COMMAND_RELEASE;
|
||||
task.reset();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1176,12 +1420,16 @@ struct server_context {
|
||||
server_queue queue_tasks;
|
||||
server_response queue_results;
|
||||
|
||||
std::unique_ptr<server_prompt_cache> prompt_cache;
|
||||
|
||||
server_metrics metrics;
|
||||
|
||||
common_chat_templates_ptr chat_templates;
|
||||
oaicompat_parser_options oai_parser_opt;
|
||||
// Necessary similarity of prompt for slot selection
|
||||
float slot_prompt_similarity = 0.0f;
|
||||
int32_t cache_ram_n_min = 0;
|
||||
float cache_ram_similarity = 0.5f;
|
||||
|
||||
~server_context() {
|
||||
if (ctx) {
|
||||
@@ -1340,6 +1588,7 @@ struct server_context {
|
||||
server_slot slot;
|
||||
|
||||
slot.id = i;
|
||||
slot.ctx = ctx;
|
||||
slot.n_ctx = n_ctx_slot;
|
||||
slot.n_predict = params.n_predict;
|
||||
slot.mctx = mctx;
|
||||
@@ -1412,6 +1661,21 @@ struct server_context {
|
||||
|
||||
metrics.init();
|
||||
|
||||
if (params.cache_ram_mib != 0) {
|
||||
if (params.cache_ram_mib < 0) {
|
||||
LLAMA_LOG_INFO("prompt cache is enabled, size limit: %s\n", "no limit");
|
||||
}
|
||||
else {
|
||||
LLAMA_LOG_INFO("prompt cache is enabled, size limit: %d MiB\n", params.cache_ram_mib);
|
||||
}
|
||||
LLAMA_LOG_INFO("%s", "use `--cache-ram 0` to disable the prompt cache\n");
|
||||
// only apply ram size limit. No token limit for now.
|
||||
prompt_cache = std::make_unique<server_prompt_cache>(params.cache_ram_mib, 0);
|
||||
}
|
||||
else {
|
||||
LLAMA_LOG_INFO("%s", "prompt cache is disabled - use `--cache-ram N` to enable it\n");
|
||||
}
|
||||
|
||||
// thinking is enabled if:
|
||||
// 1. It's not explicitly disabled (reasoning_budget == 0)
|
||||
// 2. The chat template supports it
|
||||
@@ -1483,11 +1747,12 @@ struct server_context {
|
||||
|
||||
server_slot * get_available_slot(const server_task & task) {
|
||||
server_slot * ret = nullptr;
|
||||
bool update_cache = false;
|
||||
|
||||
// find the slot that has at least n% prompt similarity
|
||||
if (ret == nullptr && slot_prompt_similarity != 0.0f) {
|
||||
int max_lcp_len = 0;
|
||||
float similarity = 0;
|
||||
float sim_best = 0;
|
||||
|
||||
for (server_slot & slot : slots) {
|
||||
// skip the slot if it is not available
|
||||
@@ -1499,23 +1764,22 @@ struct server_context {
|
||||
if (cache_tokens.empty()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// length of the Longest Common Prefix between the current slot's prompt and the input prompt
|
||||
int lcp_len = cache_tokens.get_common_prefix(task.tokens);
|
||||
// fraction of the common substring length compared to the current slot's prompt length
|
||||
const float similarity = float(lcp_len) / task.tokens.size();
|
||||
size_t lcp_len = cache_tokens.get_common_prefix(task.tokens);
|
||||
// fraction of the Longest Common Prefix length with respect to the input prompt and cached prompt length
|
||||
const float sim_cur = get_slot_similarity(lcp_len, task.tokens.size(), cache_tokens.size());
|
||||
// select the current slot if the criteria match
|
||||
if (lcp_len > max_lcp_len && similarity > slot_prompt_similarity) {
|
||||
if (sim_cur > sim_best && sim_cur > slot_prompt_similarity) {
|
||||
sim_best = sim_cur;
|
||||
max_lcp_len = lcp_len;
|
||||
ret = &slot;
|
||||
}
|
||||
}
|
||||
|
||||
if (ret != nullptr) {
|
||||
LOG_VERBOSE("selected slot by lcp similarity", {
|
||||
{"id_slot", ret->id},
|
||||
{"max_lcp_len", max_lcp_len},
|
||||
{"similarity", similarity},
|
||||
{"similarity", sim_best},
|
||||
});
|
||||
}
|
||||
}
|
||||
@@ -1528,7 +1792,6 @@ struct server_context {
|
||||
if (!slot.available()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// select the current slot if the criteria match
|
||||
if (slot.t_last_used < t_last) {
|
||||
t_last = slot.t_last_used;
|
||||
@@ -1543,7 +1806,46 @@ struct server_context {
|
||||
});
|
||||
}
|
||||
}
|
||||
if (ret) {
|
||||
const auto& tokens = ret->cache_tokens;
|
||||
float f_keep = 0.0f;
|
||||
if (!tokens.empty()) {
|
||||
size_t lcp_len = tokens.get_common_prefix(task.tokens);
|
||||
f_keep = float(lcp_len) / tokens.size();
|
||||
// if we are about to lose a large portion of the existing context - save it in the prompt cache
|
||||
if (f_keep < cache_ram_similarity) {
|
||||
update_cache = true;
|
||||
}
|
||||
}
|
||||
update_cache = update_cache && prompt_cache;
|
||||
// cache prompts only for completion tasks
|
||||
update_cache = update_cache && task.type == SERVER_TASK_TYPE_COMPLETION;
|
||||
|
||||
// don't update the cache if the slot's context is above cache_ram_n_min
|
||||
update_cache = update_cache && tokens.size() >= cache_ram_n_min;
|
||||
|
||||
// TODO: mtmd does not support prompt cache
|
||||
update_cache = update_cache && (ret->mctx == nullptr);
|
||||
|
||||
LLAMA_LOG_INFO("prompt cache: cache size: %d, cache_ram_n_min: %d, f_keep: %.2f, cache_ram_similarity: %.2f\n",
|
||||
tokens.size(), cache_ram_n_min, f_keep, cache_ram_similarity);
|
||||
if (update_cache) {
|
||||
const int64_t t_start = ggml_time_us();
|
||||
LLAMA_LOG_INFO("updating prompt cache\n");
|
||||
ret->server_prompt.tokens = server_tokens(tokens.get_text_tokens(), false); // copy cache tokens
|
||||
ret->prompt_save(*prompt_cache);
|
||||
LLAMA_LOG_INFO("prompt cache save took %.2f ms\n", (ggml_time_us() - t_start) / 1000.0);
|
||||
}
|
||||
// has prompts saved earlier to load
|
||||
if (!prompt_cache->states.empty()) {
|
||||
const int64_t t_start = ggml_time_us();
|
||||
ret->server_prompt.tokens = server_tokens(tokens.get_text_tokens(), false); // copy cache tokens
|
||||
ret->prompt_load(*prompt_cache, task.tokens);
|
||||
prompt_cache->update();
|
||||
ret->cache_tokens = server_tokens(ret->server_prompt.tokens.get_text_tokens(), false); // recover cache tokens
|
||||
LLAMA_LOG_INFO("prompt cache load took %.2f ms\n", (ggml_time_us() - t_start) / 1000.0);
|
||||
}
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
@@ -3007,40 +3309,10 @@ struct server_context {
|
||||
slot.params.n_keep = slot.n_prompt_tokens;
|
||||
}
|
||||
slot.params.n_keep = std::min(slot.n_ctx - 4, slot.params.n_keep);
|
||||
|
||||
if (slot.ga_n == 1 && slot.n_prompt_tokens >= slot.n_ctx) {
|
||||
if (!params.ctx_shift) {
|
||||
send_error(slot, "the request exceeds the available context size, try increasing it", ERROR_TYPE_SERVER);
|
||||
slot.release();
|
||||
continue;
|
||||
}
|
||||
const int n_left = slot.n_ctx - slot.params.n_keep;
|
||||
|
||||
const int n_block_size = n_left / 2;
|
||||
const int erased_blocks = (slot.n_prompt_tokens - slot.params.n_keep - n_block_size) / n_block_size;
|
||||
int n_keep = slot.params.n_keep;
|
||||
int n_discard = erased_blocks * n_block_size;
|
||||
llama_tokens new_tokens = prompt_tokens.get_text_tokens(); // copy
|
||||
for (size_t i = n_keep + n_discard; i < new_tokens.size(); i++) {
|
||||
new_tokens[i - n_discard] = new_tokens[i];
|
||||
}
|
||||
new_tokens.resize(prompt_tokens.size() - n_discard);
|
||||
prompt_tokens.clear();
|
||||
prompt_tokens.insert(new_tokens);
|
||||
slot.truncated = true;
|
||||
slot.n_prompt_tokens = prompt_tokens.size();
|
||||
|
||||
LOG_VERBOSE("input truncated", {
|
||||
{"id_slot", slot.id},
|
||||
{"id_task", slot.id_task},
|
||||
{"n_ctx", slot.n_ctx},
|
||||
{"n_keep", slot.params.n_keep},
|
||||
{"n_left", n_left},
|
||||
{"n_prompt_tokens", slot.n_prompt_tokens},
|
||||
{"prompt_tokens", prompt_tokens.detokenize(ctx, true)},
|
||||
});
|
||||
|
||||
GGML_ASSERT(slot.n_prompt_tokens < slot.n_ctx);
|
||||
if (slot.n_prompt_tokens >= slot.n_ctx) {
|
||||
send_error(slot, "the request exceeds the available context size, try increasing it", ERROR_TYPE_SERVER);
|
||||
slot.release();
|
||||
continue;
|
||||
}
|
||||
|
||||
llama_sampling_reset(llama_get_model_vocab(model), slot.ctx_sampling);
|
||||
@@ -3881,6 +4153,8 @@ int main(int argc, char ** argv) {
|
||||
|
||||
// Necessary similarity of prompt for slot selection
|
||||
ctx_server.slot_prompt_similarity = params.slot_prompt_similarity;
|
||||
ctx_server.cache_ram_n_min = params.cache_ram_n_min;
|
||||
ctx_server.cache_ram_similarity = params.cache_ram_similarity;
|
||||
#ifdef SQLITE3_MODERN_CPP_SUPPORT
|
||||
auto db_handle = std::make_shared<DatabaseHandle>(params.sql_save_file);
|
||||
bool sqlite_extension_loaded = false;
|
||||
|
||||
Reference in New Issue
Block a user