Mirror of https://github.com/ikawrakow/ik_llama.cpp.git (synced 2026-02-08 07:20:12 +00:00)
server: exclude thinking tokens when finding the slot (#1079)
* refactor find slot
* enable by default
* Fix load prompt
* rename variables

Co-authored-by: firecoperana <firecoperana>
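For context, the similarity-based slot selection compares the incoming prompt against each slot's cached prompt. Clients typically resend the conversation with the model's reasoning blocks stripped, so the cached prompt can contain think spans the new prompt lacks, which hurts the prefix match. A minimal standalone illustration of the idea, not part of this commit; the <think> delimiters below are placeholders for the begin/end markers carried in params.think_tokens:

    // Illustration only: stripping reasoning spans before comparing prompts.
    #include <iostream>
    #include <regex>
    #include <string>

    int main() {
        // hypothetical cached slot prompt vs. incoming prompt (client resends history without reasoning)
        const std::string cached = "User: hi\nAssistant: <think>plan the reply</think>Hello!\nUser: thanks";
        const std::string prompt = "User: hi\nAssistant: Hello!\nUser: thanks";

        const std::regex think("<think>[\\s\\S]*?</think>");
        const std::string cached_stripped = std::regex_replace(cached, think, "");

        // with the reasoning span removed, the cached prompt matches the incoming prompt again
        std::cout << std::boolalpha << (cached_stripped == prompt) << "\n";  // true
        return 0;
    }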
@@ -176,7 +176,13 @@ void server_context::init() {
        slot.n_predict = params.n_predict;
        slot.mctx = mctx;
        slot.cache_tokens.has_mtmd = mctx != nullptr;

        slot.params.think_tokens = params.think_tokens;
        if (params.think_tokens.exclude) {
            SRV_WRN("Exclude reasoning tokens when selecting slot based on similarity: start: %s, end: %s\nuse `--reasoning-tokens none` to disable.\n", params.think_tokens.begin.c_str(), params.think_tokens.end.c_str());
        }
        else {
            SRV_WRN("%s", "Include reasoning tokens when selecting slot based on similarity\nuse `--reasoning-tokens auto` to exclude reasoning tokens.\n");
        }
        LOG_INFO("new slot", {
            {"id_slot", slot.id},
            {"n_ctx_slot", slot.n_ctx}
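The two warnings above are the only user-facing documentation of the new switch in this diff: `--reasoning-tokens none` keeps reasoning tokens in the similarity check, while `--reasoning-tokens auto` excludes them (and, per the "enable by default" line in the commit message, exclusion appears to be the default). A hypothetical invocation to opt out; the binary name and model path are placeholders:

    ./llama-server -m model.gguf --reasoning-tokens none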
@@ -585,6 +591,44 @@ server_slot* server_context::get_slot_by_id(int id) {
    return nullptr;
}

float server_context::calculate_slot_f_keep(const server_slot & slot, llama_context * ctx, const server_tokens & a, const server_tokens & b) {
    float f_keep = 0.0f;
    if (!a.empty()) {
        if (slot.ga_n == 1 && slot.n_discarded_prompt > 0 && b.size() >= slot.n_ctx) {
            f_keep = a.get_cached_tokens_similarity(slot.ctx, b, slot.params.n_keep + add_bos_token, slot.n_discarded_prompt);
        }
        else {
            f_keep = a.get_cached_tokens_similarity(slot.ctx, b, 0, 0);
        }
    }
    return f_keep;
}

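What f_keep measures is defined inside server_tokens::get_cached_tokens_similarity, which is not part of this diff; conceptually it is the fraction of the slot's cached context that would survive reuse for the new prompt, optionally offset by n_keep/n_discarded when a context shift has happened. A rough standalone sketch of such a fraction, illustrative only and not the actual implementation:

    #include <cstddef>
    #include <vector>

    // Fraction of the cached token sequence that is also a prefix of the new prompt.
    // The real metric additionally accounts for the (n_keep, n_discarded) offsets
    // passed by calculate_slot_f_keep above.
    static float prefix_keep_fraction(const std::vector<int> & cached, const std::vector<int> & prompt) {
        if (cached.empty()) {
            return 0.0f;
        }
        std::size_t n = 0;
        while (n < cached.size() && n < prompt.size() && cached[n] == prompt[n]) {
            ++n;
        }
        return float(n) / float(cached.size());
    }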
std::pair<common_prefix, float> server_context::calculate_slot_similarity(const server_slot& slot, llama_context* ctx, const server_tokens& a, const server_tokens& b) {
    std::pair<common_prefix, float> sim;
    // length of the Longest Common Prefix between the current slot's prompt and the input prompt
    common_prefix lcp_len = a.get_common_prefix(slot.ctx, b);
    // fraction of the Longest Common Prefix length with respect to the input prompt and cached prompt length
    float sim_cur = a.get_tokens_similarity(slot.ctx, b, 0, 0);
    // handle context shift
    if (slot.ga_n == 1 && slot.n_discarded_prompt > 0 && b.size() >= slot.n_ctx) {
        float sim_cur_ctx_shift = a.get_tokens_similarity(slot.ctx, b, slot.n_kept_prompt, slot.n_discarded_prompt);
        if (sim_cur_ctx_shift > sim_cur) {
            sim_cur = sim_cur_ctx_shift;
        }
    }
    sim.first = lcp_len;
    sim.second = sim_cur;
    return sim;
}

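The (prefix, similarity) pair returned here is consumed by the slot-selection loop further down. A simplified sketch of that use, relying on names from the surrounding code (slots, ctx, task, slot_prompt_similarity) and omitting the exclude-think branching:

    float sim_best = 0.0f;
    server_slot * ret = nullptr;
    for (auto & slot : slots) {
        if (!slot.available() || slot.cache_tokens.empty()) {
            continue;
        }
        const auto sim = calculate_slot_similarity(slot, ctx, slot.cache_tokens, task.tokens);
        // pick the slot whose cached prompt is most similar to the incoming prompt,
        // but only if it clears the configured threshold
        if (sim.second > sim_best && sim.second > slot_prompt_similarity) {
            sim_best = sim.second;
            ret      = &slot;
        }
    }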
void server_context::copy_data_to_cached_prompt(const server_tokens & tokens, server_slot & slot) {
    slot.server_cached_prompt.tokens = server_tokens(tokens.get_text_tokens(), false); // copy cache tokens
    slot.server_cached_prompt.n_discarded_prompt = slot.n_discarded_prompt;
    slot.server_cached_prompt.n_kept_prompt = slot.n_kept_prompt;
    slot.server_cached_prompt.think_tokens = slot.params.think_tokens;
}

server_slot* server_context::get_available_slot(const server_task& task) {
    server_slot* ret = nullptr;
    bool update_cache = false;
@@ -599,22 +643,25 @@ server_slot* server_context::get_available_slot(const server_task& task) {
        if (!slot.available()) {
            continue;
        }
        const auto& cache_tokens = slot.cache_tokens;
        auto& cache_tokens = slot.cache_tokens;
        // skip the slot if it does not contain a prompt
        if (cache_tokens.empty()) {
            continue;
        }
        // length of the Longest Common Prefix between the current slot's prompt and the input prompt
        auto lcp_len = cache_tokens.get_common_prefix(slot.ctx, task.tokens);
        // fraction of the Longest Common Prefix length with respect to the input prompt and cached prompt length
        float sim_cur = cache_tokens.get_tokens_similarity(slot.ctx, task.tokens, 0, 0);
        // handle context shift
        if (slot.ga_n == 1 && slot.n_discarded_prompt > 0 && task.tokens.size() >= slot.n_ctx) {
            float sim_cur_ctx_shift = cache_tokens.get_tokens_similarity(slot.ctx, task.tokens, slot.n_kept_prompt, slot.n_discarded_prompt);
            if (sim_cur_ctx_shift > sim_cur) {
                sim_cur = sim_cur_ctx_shift;
            }
        }
        bool exclude_think = !cache_tokens.has_mtmd && slot.params.think_tokens.exclude;
        std::pair<common_prefix, float> sim;
        if (exclude_think) {
            auto temp = slot.cache_tokens.get_text_tokens_exclude_think(slot.ctx, slot.params.think_tokens);
            server_tokens cache_tokens_exclude_think = server_tokens(temp, false);
            temp = task.tokens.get_text_tokens_exclude_think(slot.ctx, slot.params.think_tokens);
            server_tokens prompt_tokens_exclude_think = server_tokens(temp, false);
            sim = calculate_slot_similarity(slot, ctx, cache_tokens_exclude_think, prompt_tokens_exclude_think);
        }
        else {
            sim = calculate_slot_similarity(slot, ctx, cache_tokens, task.tokens);
        }
        common_prefix lcp_len = sim.first;
        float sim_cur = sim.second;

        // select the current slot if the criteria match
        if (sim_cur > sim_best && sim_cur > slot_prompt_similarity) {
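Note that the exclude-think path only applies to text-only slots (cache_tokens.has_mtmd is false), and that both the cached tokens and the incoming prompt are filtered with the same think_tokens markers, so the similarity is computed over comparable sequences. get_text_tokens_exclude_think itself is not shown in this diff; a hypothetical token-level filter of the same shape, assuming the begin/end markers each map to a single token id:

    #include <vector>

    // Hypothetical sketch: drop everything between a begin marker and an end marker.
    // The actual filtering lives in server_tokens::get_text_tokens_exclude_think.
    static std::vector<int> filter_think(const std::vector<int> & toks, int tok_begin, int tok_end) {
        std::vector<int> out;
        bool in_think = false;
        for (const int t : toks) {
            if (t == tok_begin) { in_think = true;  continue; }
            if (t == tok_end)   { in_think = false; continue; }
            if (!in_think) {
                out.push_back(t);
            }
        }
        return out;
    }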
@@ -655,26 +702,36 @@ server_slot* server_context::get_available_slot(const server_task& task) {
        }
    }
    if (ret) {
        const auto& tokens = ret->cache_tokens;
        float f_keep = 0.0f;
        auto& tokens = ret->cache_tokens;
        float f_keep = 0;
        size_t cache_token_size = tokens.size();
        if (!tokens.empty()) {
            if (ret->ga_n == 1 && ret->n_discarded_prompt > 0 && task.tokens.size() >= ret->n_ctx) {
                f_keep = tokens.get_cached_tokens_similarity(ret->ctx, task.tokens, ret->params.n_keep + add_bos_token, ret->n_discarded_prompt);
            bool exclude_think = !tokens.has_mtmd && ret->params.think_tokens.exclude;
            if (exclude_think) {
                auto temp = tokens.get_text_tokens_exclude_think(ret->ctx, ret->params.think_tokens);
                server_tokens cache_exclude_think = server_tokens(temp, false);

                temp = task.tokens.get_text_tokens_exclude_think(ret->ctx, ret->params.think_tokens);
                server_tokens prompt_exclude_think = server_tokens(temp, false);

                cache_token_size = cache_exclude_think.size();
                f_keep = calculate_slot_f_keep(*ret, ret->ctx, cache_exclude_think, prompt_exclude_think);
            }
            else {
                f_keep = tokens.get_cached_tokens_similarity(ret->ctx, task.tokens, 0, 0);
                f_keep = calculate_slot_f_keep(*ret, ret->ctx, tokens, task.tokens);
            }
            // if we are about to lose a large portion of the existing context - save it in the prompt cache
            if (f_keep < cache_ram_similarity) {
                update_cache = true;
            }
        }

        update_cache = update_cache && prompt_cache;
        // cache prompts only for completion tasks
        update_cache = update_cache && task.type == SERVER_TASK_TYPE_COMPLETION;

        // don't update the cache if the slot's context is above cache_ram_n_min
        update_cache = update_cache && tokens.size() >= cache_ram_n_min;
        update_cache = update_cache && cache_token_size >= cache_ram_n_min;

        // TODO: mtmd does not support prompt cache
        update_cache = update_cache && (ret->mctx == nullptr);
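For readability, the net effect of the update_cache chain in this hunk can be condensed as follows; this is a restatement for illustration, not part of the commit, and tokens, task, cache_token_size, and the cache_ram_* thresholds are the surrounding variables:

    const bool update_cache =
           !tokens.empty()
        && f_keep < cache_ram_similarity             // about to lose a large portion of the existing context
        && prompt_cache                              // a prompt cache is available
        && task.type == SERVER_TASK_TYPE_COMPLETION  // cache prompts only for completion tasks
        && cache_token_size >= cache_ram_n_min       // enough tokens (counted with think tokens excluded, when applicable)
        && ret->mctx == nullptr;                     // mtmd does not support the prompt cache (TODO in the code)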
@@ -684,9 +741,8 @@ server_slot* server_context::get_available_slot(const server_task& task) {
    if (update_cache) {
        const int64_t t_start = ggml_time_us();
        LLAMA_LOG_INFO("updating prompt cache\n");
        ret->server_cached_prompt.tokens = server_tokens(tokens.get_text_tokens(), false); // copy cache tokens
        ret->server_cached_prompt.n_discarded_prompt = ret->n_discarded_prompt;
        ret->server_cached_prompt.n_kept_prompt = ret->n_kept_prompt;
        // copy cache tokens
        copy_data_to_cached_prompt(tokens, *ret);

        ret->prompt_save(*prompt_cache);
        LLAMA_LOG_INFO("prompt cache save took %.2f ms\n", (ggml_time_us() - t_start) / 1000.0);
@@ -694,9 +750,7 @@ server_slot* server_context::get_available_slot(const server_task& task) {
    // has prompts saved earlier to load
    if (prompt_cache && !prompt_cache->states.empty()) {
        const int64_t t_start = ggml_time_us();
        ret->server_cached_prompt.tokens = server_tokens(tokens.get_text_tokens(), false); // copy cache tokens
        ret->server_cached_prompt.n_discarded_prompt = ret->n_discarded_prompt;
        ret->server_cached_prompt.n_kept_prompt = ret->n_kept_prompt;
        copy_data_to_cached_prompt(tokens, *ret);

        ret->prompt_load(*prompt_cache, task.tokens);
        prompt_cache->update();
@@ -1959,7 +2013,6 @@ void server_context::context_shift_find_n_tokens(llama_context* ctx, const serve
}

void server_context::context_shift_prompt(llama_context* ctx, server_slot& slot, bool exact) {
    //server_tokens prompt_tokens = std::move(slot.prompt_tokens);
    int n_keep = std::max(0, slot.params.n_keep + add_bos_token);
    const int n_left = slot.n_ctx - n_keep;
    int n_discard = slot.params.n_discard ? slot.params.n_discard : (n_left / 2);
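A quick numeric check of the discard arithmetic above, with hypothetical values for the slot (context of 8192 tokens, n_keep of 256, a BOS token, and no explicit n_discard):

    #include <algorithm>
    #include <cstdio>

    int main() {
        // hypothetical slot configuration
        const int slot_n_ctx       = 8192;
        const int params_n_keep    = 256;
        const int add_bos_token    = 1;
        const int params_n_discard = 0;  // 0 means "discard half of the shiftable region"

        const int n_keep    = std::max(0, params_n_keep + add_bos_token);          // 257
        const int n_left    = slot_n_ctx - n_keep;                                 // 7935
        const int n_discard = params_n_discard ? params_n_discard : (n_left / 2);  // 3967

        std::printf("n_keep=%d n_left=%d n_discard=%d\n", n_keep, n_left, n_discard);
        return 0;
    }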