server: exclude thinking tokens when finding the slot (#1079)

Refactor slot selection

Enable exclusion by default

Fix prompt loading

Rename variables

Co-authored-by: firecoperana <firecoperana>
Author: firecoperana
Date: 2025-12-22 02:46:45 -06:00
Committed by: GitHub
Parent: 21fc9322f9
Commit: 5562605076

8 changed files with 247 additions and 33 deletions
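
Note: the core idea of this change is that reasoning ("thinking") spans are stripped from both the cached prompt and the incoming prompt before measuring similarity, because clients typically do not resend reasoning text with the next request. A minimal sketch of the stripping step, assuming string-level begin/end markers such as "<think>"/"</think>"; the actual code operates on token sequences via server_tokens::get_text_tokens_exclude_think():

    #include <cstddef>
    #include <string>

    // Remove every begin...end span from the text. Plain strings are used here
    // only to keep the sketch self-contained; marker values are model-dependent.
    static std::string strip_think(const std::string & text,
                                   const std::string & begin, const std::string & end) {
        std::string out;
        size_t pos = 0;
        while (pos < text.size()) {
            size_t b = text.find(begin, pos);
            if (b == std::string::npos) {
                out += text.substr(pos); // no more reasoning spans
                break;
            }
            out += text.substr(pos, b - pos);
            size_t e = text.find(end, b + begin.size());
            if (e == std::string::npos) {
                break; // unterminated reasoning span: drop the remainder
            }
            pos = e + end.size();
        }
        return out;
    }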

@@ -176,7 +176,13 @@ void server_context::init() {
         slot.n_predict = params.n_predict;
         slot.mctx = mctx;
         slot.cache_tokens.has_mtmd = mctx != nullptr;
+        slot.params.think_tokens = params.think_tokens;
+        if (params.think_tokens.exclude) {
+            SRV_WRN("Exclude reasoning tokens when selecting slot based on similarity: start: %s, end: %s\nuse `--reasoning-tokens none` to disable.\n", params.think_tokens.begin.c_str(), params.think_tokens.end.c_str());
+        }
+        else {
+            SRV_WRN("%s", "Include reasoning tokens when selecting slot based on similarity\nuse `--reasoning-tokens auto` to exclude reasoning tokens.\n");
+        }
         LOG_INFO("new slot", {
             {"id_slot", slot.id},
             {"n_ctx_slot", slot.n_ctx}
@@ -585,6 +591,44 @@ server_slot* server_context::get_slot_by_id(int id) {
     return nullptr;
 }
+
+float server_context::calculate_slot_f_keep(const server_slot & slot, llama_context * ctx, const server_tokens & a, const server_tokens & b) {
+    float f_keep = 0.0f;
+    if (!a.empty()) {
+        if (slot.ga_n == 1 && slot.n_discarded_prompt > 0 && b.size() >= slot.n_ctx) {
+            f_keep = a.get_cached_tokens_similarity(slot.ctx, b, slot.params.n_keep + add_bos_token, slot.n_discarded_prompt);
+        }
+        else {
+            f_keep = a.get_cached_tokens_similarity(slot.ctx, b, 0, 0);
+        }
+    }
+    return f_keep;
+}
+
+std::pair<common_prefix, float> server_context::calculate_slot_similarity(const server_slot& slot, llama_context* ctx, const server_tokens& a, const server_tokens& b) {
+    std::pair<common_prefix, float> sim;
+    // length of the Longest Common Prefix between the current slot's prompt and the input prompt
+    common_prefix lcp_len = a.get_common_prefix(slot.ctx, b);
+    // fraction of the Longest Common Prefix length with respect to the input prompt and cached prompt length
+    float sim_cur = a.get_tokens_similarity(slot.ctx, b, 0, 0);
+    // handle context shift
+    if (slot.ga_n == 1 && slot.n_discarded_prompt > 0 && b.size() >= slot.n_ctx) {
+        float sim_cur_ctx_shift = a.get_tokens_similarity(slot.ctx, b, slot.n_kept_prompt, slot.n_discarded_prompt);
+        if (sim_cur_ctx_shift > sim_cur) {
+            sim_cur = sim_cur_ctx_shift;
+        }
+    }
+    sim.first = lcp_len;
+    sim.second = sim_cur;
+    return sim;
+}
+
+void server_context::copy_data_to_cached_prompt(const server_tokens & tokens, server_slot & slot) {
+    slot.server_cached_prompt.tokens = server_tokens(tokens.get_text_tokens(), false); // copy cache tokens
+    slot.server_cached_prompt.n_discarded_prompt = slot.n_discarded_prompt;
+    slot.server_cached_prompt.n_kept_prompt = slot.n_kept_prompt;
+    slot.server_cached_prompt.think_tokens = slot.params.think_tokens;
+}
+
 server_slot* server_context::get_available_slot(const server_task& task) {
     server_slot* ret = nullptr;
     bool update_cache = false;
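
The context-shift branch in calculate_slot_similarity() compares the prompt against the cache's post-shift layout (the first n_kept tokens kept, then n_discarded tokens skipped) and keeps the better score. A hypothetical illustration of that two-piece comparison; the name and the similarity definition here are assumptions, not the server_tokens implementation:

    #include <cstddef>
    #include <vector>

    static float shifted_similarity(const std::vector<int> & cache,
                                    const std::vector<int> & prompt,
                                    size_t n_kept, size_t n_discarded) {
        size_t match = 0, c = 0;
        // piece 1: the pinned head is identical in both sequences
        while (c < n_kept && c < cache.size() && c < prompt.size() && cache[c] == prompt[c]) {
            match++; c++;
        }
        // piece 2: the cache resumes n_discarded tokens further into the prompt
        size_t p = c + n_discarded;
        while (c < cache.size() && p < prompt.size() && cache[c] == prompt[p]) {
            match++; c++; p++;
        }
        return cache.empty() ? 0.0f : (float) match / (float) cache.size();
    }
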
@@ -599,22 +643,25 @@ server_slot* server_context::get_available_slot(const server_task& task) {
             if (!slot.available()) {
                 continue;
             }
-            const auto& cache_tokens = slot.cache_tokens;
+            auto& cache_tokens = slot.cache_tokens;
             // skip the slot if it does not contain a prompt
             if (cache_tokens.empty()) {
                 continue;
             }
-            // length of the Longest Common Prefix between the current slot's prompt and the input prompt
-            auto lcp_len = cache_tokens.get_common_prefix(slot.ctx, task.tokens);
-            // fraction of the Longest Common Prefix length with respect to the input prompt and cached prompt length
-            float sim_cur = cache_tokens.get_tokens_similarity(slot.ctx, task.tokens, 0, 0);
-            // handle context shift
-            if (slot.ga_n == 1 && slot.n_discarded_prompt > 0 && task.tokens.size() >= slot.n_ctx) {
-                float sim_cur_ctx_shift = cache_tokens.get_tokens_similarity(slot.ctx, task.tokens, slot.n_kept_prompt, slot.n_discarded_prompt);
-                if (sim_cur_ctx_shift > sim_cur) {
-                    sim_cur = sim_cur_ctx_shift;
-                }
-            }
+            bool exclude_think = !cache_tokens.has_mtmd && slot.params.think_tokens.exclude;
+            std::pair<common_prefix, float> sim;
+            if (exclude_think) {
+                auto temp = slot.cache_tokens.get_text_tokens_exclude_think(slot.ctx, slot.params.think_tokens);
+                server_tokens cache_tokens_exclude_think = server_tokens(temp, false);
+                temp = task.tokens.get_text_tokens_exclude_think(slot.ctx, slot.params.think_tokens);
+                server_tokens prompt_tokens_exclude_think = server_tokens(temp, false);
+                sim = calculate_slot_similarity(slot, ctx, cache_tokens_exclude_think, prompt_tokens_exclude_think);
+            }
+            else {
+                sim = calculate_slot_similarity(slot, ctx, cache_tokens, task.tokens);
+            }
+            common_prefix lcp_len = sim.first;
+            float sim_cur = sim.second;
             // select the current slot if the criteria match
             if (sim_cur > sim_best && sim_cur > slot_prompt_similarity) {
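
Why stripping raises the score: the slot's cache ends with the previous turn's reasoning span plus answer, while the next request usually contains the history without the reasoning. A small self-contained demonstration with hypothetical strings:

    #include <cassert>
    #include <string>

    static size_t common_prefix_len(const std::string & a, const std::string & b) {
        size_t n = 0;
        while (n < a.size() && n < b.size() && a[n] == b[n]) n++;
        return n;
    }

    int main() {
        // what the slot cached (reasoning span included) vs. what arrives next
        std::string cached   = "SYS U1 <think>step</think> A1";
        std::string prompt   = "SYS U1 A1 U2";
        assert(common_prefix_len(cached, prompt) == 7);  // match stops at "<think>"
        // with the reasoning span stripped from the cache, the prefix covers A1 too
        std::string stripped = "SYS U1 A1";
        assert(common_prefix_len(stripped, prompt) == 9);
        return 0;
    }
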
@@ -655,26 +702,36 @@ server_slot* server_context::get_available_slot(const server_task& task) {
         }
     }
     if (ret) {
-        const auto& tokens = ret->cache_tokens;
-        float f_keep = 0.0f;
+        auto& tokens = ret->cache_tokens;
+        float f_keep = 0;
+        size_t cache_token_size = tokens.size();
         if (!tokens.empty()) {
-            if (ret->ga_n == 1 && ret->n_discarded_prompt > 0 && task.tokens.size() >= ret->n_ctx) {
-                f_keep = tokens.get_cached_tokens_similarity(ret->ctx, task.tokens, ret->params.n_keep + add_bos_token, ret->n_discarded_prompt);
+            bool exclude_think = !tokens.has_mtmd && ret->params.think_tokens.exclude;
+            if (exclude_think) {
+                auto temp = tokens.get_text_tokens_exclude_think(ret->ctx, ret->params.think_tokens);
+                server_tokens cache_exclude_think = server_tokens(temp, false);
+                temp = task.tokens.get_text_tokens_exclude_think(ret->ctx, ret->params.think_tokens);
+                server_tokens prompt_exclude_think = server_tokens(temp, false);
+                cache_token_size = cache_exclude_think.size();
+                f_keep = calculate_slot_f_keep(*ret, ret->ctx, cache_exclude_think, prompt_exclude_think);
             }
             else {
-                f_keep = tokens.get_cached_tokens_similarity(ret->ctx, task.tokens, 0, 0);
+                f_keep = calculate_slot_f_keep(*ret, ret->ctx, tokens, task.tokens);
             }
             // if we are about to lose a large portion of the existing context - save it in the prompt cache
             if (f_keep < cache_ram_similarity) {
                 update_cache = true;
             }
         }
         update_cache = update_cache && prompt_cache;
         // cache prompts only for completion tasks
         update_cache = update_cache && task.type == SERVER_TASK_TYPE_COMPLETION;
-        // don't update the cache if the slot's context is below cache_ram_n_min
-        update_cache = update_cache && tokens.size() >= cache_ram_n_min;
+        // don't update the cache if the slot's context is below cache_ram_n_min
+        update_cache = update_cache && cache_token_size >= cache_ram_n_min;
         // TODO: mtmd does not support prompt cache
         update_cache = update_cache && (ret->mctx == nullptr);
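
Read together, the gating above amounts to a single predicate; note that the size check now uses cache_token_size, which reflects think-token stripping when it is enabled (restated here for readability):

    update_cache = update_cache                       // f_keep fell below cache_ram_similarity
        && prompt_cache                               // a RAM prompt cache is configured
        && task.type == SERVER_TASK_TYPE_COMPLETION   // completion tasks only
        && cache_token_size >= cache_ram_n_min        // enough tokens to be worth saving
        && ret->mctx == nullptr;                      // mtmd does not support prompt cache
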
@@ -684,9 +741,8 @@ server_slot* server_context::get_available_slot(const server_task& task) {
         if (update_cache) {
             const int64_t t_start = ggml_time_us();
             LLAMA_LOG_INFO("updating prompt cache\n");
-            ret->server_cached_prompt.tokens = server_tokens(tokens.get_text_tokens(), false); // copy cache tokens
-            ret->server_cached_prompt.n_discarded_prompt = ret->n_discarded_prompt;
-            ret->server_cached_prompt.n_kept_prompt = ret->n_kept_prompt;
+            // copy cache tokens
+            copy_data_to_cached_prompt(tokens, *ret);
             ret->prompt_save(*prompt_cache);
             LLAMA_LOG_INFO("prompt cache save took %.2f ms\n", (ggml_time_us() - t_start) / 1000.0);
@@ -694,9 +750,7 @@ server_slot* server_context::get_available_slot(const server_task& task) {
         // has prompts saved earlier to load
         if (prompt_cache && !prompt_cache->states.empty()) {
             const int64_t t_start = ggml_time_us();
-            ret->server_cached_prompt.tokens = server_tokens(tokens.get_text_tokens(), false); // copy cache tokens
-            ret->server_cached_prompt.n_discarded_prompt = ret->n_discarded_prompt;
-            ret->server_cached_prompt.n_kept_prompt = ret->n_kept_prompt;
+            copy_data_to_cached_prompt(tokens, *ret);
             ret->prompt_load(*prompt_cache, task.tokens);
             prompt_cache->update();
@@ -1959,7 +2013,6 @@ void server_context::context_shift_find_n_tokens(llama_context* ctx, const serve
 }
 void server_context::context_shift_prompt(llama_context* ctx, server_slot& slot, bool exact) {
-    //server_tokens prompt_tokens = std::move(slot.prompt_tokens);
     int n_keep = std::max(0, slot.params.n_keep + add_bos_token);
     const int n_left = slot.n_ctx - n_keep;
     int n_discard = slot.params.n_discard ? slot.params.n_discard : (n_left / 2);
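
A worked example of the shift arithmetic above, with illustrative values:

    // n_ctx = 4096, params.n_keep = 63, add_bos_token = 1, params.n_discard = 0:
    //   n_keep    = max(0, 63 + 1) = 64    tokens pinned at the start of the context
    //   n_left    = 4096 - 64      = 4032  tokens eligible for shifting
    //   n_discard = 4032 / 2       = 2016  tokens dropped per shift (default: half of n_left)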