server: exclude thinking tokens when finding the slot (#1079)

refactor find slot

enable by default

Fix load prompt

rename variables

Co-authored-by: firecoperana <firecoperana>
Author:    firecoperana
Date:      2025-12-22 02:46:45 -06:00
Committed: GitHub
Parent:    21fc9322f9
Commit:    5562605076
8 changed files with 247 additions and 33 deletions
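
What the change does, in short: when the server scans its prompt cache for the best slot to reuse, the model's earlier reasoning ("thinking") output is now stripped from both the cached prompt and the incoming request before the common prefix and token similarity are computed, so a follow-up request that no longer carries the reasoning text can still match a cached slot. Below is a minimal, self-contained sketch of why that matters; the strip_think helper and the literal <think>...</think> markers are illustrative assumptions, not the server's actual API:

// Sketch only: show how stripping reasoning spans lengthens the usable common prefix.
#include <algorithm>
#include <cstdio>
#include <string>

// Remove every <think>...</think> span (markers assumed for illustration).
static std::string strip_think(std::string s) {
    const std::string open = "<think>", close = "</think>";
    size_t pos;
    while ((pos = s.find(open)) != std::string::npos) {
        const size_t end = s.find(close, pos);
        if (end == std::string::npos) { s.erase(pos); break; }
        s.erase(pos, end + close.size() - pos);
    }
    return s;
}

// Length of the common prefix of two strings.
static size_t common_prefix(const std::string& a, const std::string& b) {
    const size_t n = std::min(a.size(), b.size());
    size_t i = 0;
    while (i < n && a[i] == b[i]) { ++i; }
    return i;
}

int main() {
    // The cached prompt still holds the reasoning from the previous turn;
    // the client's follow-up request typically does not.
    const std::string cached  = "user: hi\nassistant: <think>ponder...</think>hello!\nuser: tell me more";
    const std::string request = "user: hi\nassistant: hello!\nuser: tell me more, please";

    printf("prefix with thinking kept    : %zu\n", common_prefix(cached, request));
    printf("prefix with thinking stripped: %zu\n", common_prefix(strip_think(cached), strip_think(request)));
    return 0;
}

With the thinking tokens left in, the match stops right where the reasoning begins; with them stripped, nearly the whole conversation prefix can be reused, which is what lets the cache pick the existing slot.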


@@ -692,19 +692,40 @@ size_t server_prompt_cache::n_tokens() const {
 }
 bool server_prompt_cache::load(server_prompt& prompt, const server_tokens& tokens_new, llama_context* ctx, int32_t id_slot) {
-    const auto lcp_best = prompt.tokens.get_common_prefix(ctx, tokens_new);
-    float f_keep_best = float(lcp_best.second) / prompt.tokens.size();
-    float sim_best = prompt.tokens.get_tokens_similarity(ctx, tokens_new, prompt.n_kept_prompt, prompt.n_discarded_prompt);
+    thinking_tokens think_tokens;
+    for (auto it = states.begin(); it != states.end(); ++it) {
+        think_tokens = it->think_tokens;
+        break;
+    }
+    server_tokens prompt_tokens;
+    server_tokens tokens_new_ex;
+    if (think_tokens.exclude) {
+        prompt_tokens = server_tokens(prompt.tokens.get_text_tokens_exclude_think(ctx, think_tokens), false);
+        tokens_new_ex = server_tokens(tokens_new.get_text_tokens_exclude_think(ctx, think_tokens), false);
+    }
+    else {
+        prompt_tokens = std::move(prompt.tokens); //server_tokens(prompt.tokens.get_text_tokens(), false);
+        tokens_new_ex = server_tokens(tokens_new.get_text_tokens(), false);
+    }
+    const auto lcp_best = prompt_tokens.get_common_prefix(ctx, tokens_new_ex);
+    float f_keep_best = float(lcp_best.second) / prompt_tokens.size();
+    float sim_best = prompt_tokens.get_tokens_similarity(ctx, tokens_new_ex, prompt.n_kept_prompt, prompt.n_discarded_prompt);
     LLAMA_LOG_INFO(" - looking for better prompt, base f_keep = %.3f, sim = %.3f, n_keep = %d, n_discarded_prompt = %d\n", f_keep_best, sim_best, prompt.n_kept_prompt, prompt.n_discarded_prompt);
     auto it_best = states.end();
     // find the most similar cached prompt, that would also preserve the most context
     for (auto it = states.begin(); it != states.end(); ++it) {
-        const auto lcp_cur = it->tokens.get_common_prefix(ctx, tokens_new);
-        const float f_keep_cur = float(lcp_cur.first) / it->tokens.size();
-        const float sim_cur = it->tokens.get_tokens_similarity(ctx, tokens_new, it->n_kept_prompt, it->n_discarded_prompt);
+        server_tokens tokens;
+        if (think_tokens.exclude) {
+            tokens = server_tokens(it->tokens.get_text_tokens_exclude_think(ctx, think_tokens), false);
+        }
+        else {
+            tokens = std::move(it->tokens);
+        }
+        const auto lcp_cur = tokens.get_common_prefix(ctx, tokens_new_ex);
+        const float f_keep_cur = float(lcp_cur.first) / tokens.size();
+        const float sim_cur = tokens.get_tokens_similarity(ctx, tokens_new_ex, it->n_kept_prompt, it->n_discarded_prompt);
         if (sim_best < sim_cur) {
             f_keep_best = f_keep_cur;
             sim_best = sim_cur;
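
For orientation, the loop above keeps whichever cached state scores the highest similarity against the (think-stripped) new tokens, updating f_keep_best and sim_best as it goes. A tiny stand-alone illustration of that pick-the-highest-similarity rule, with made-up candidate values:

#include <cstdio>
#include <vector>

// Toy stand-ins for cached states: how much of each would be kept, and how
// similar each is to the incoming (think-stripped) request.
struct candidate {
    const char* name;
    float f_keep;
    float sim;
};

int main() {
    const std::vector<candidate> states = {
        { "state A", 0.42f, 0.61f },
        { "state B", 0.95f, 0.58f },
        { "state C", 0.77f, 0.83f },
    };

    // Same shape as the loop in the hunk above: the highest similarity wins.
    float f_keep_best = 0.0f;
    float sim_best    = 0.0f;
    const candidate* best = nullptr;
    for (const auto& c : states) {
        if (sim_best < c.sim) {
            f_keep_best = c.f_keep;
            sim_best    = c.sim;
            best        = &c;
        }
    }

    printf("best: %s (f_keep = %.2f, sim = %.2f)\n", best ? best->name : "none", f_keep_best, sim_best);
    return 0;
}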
@@ -778,6 +799,7 @@ server_prompt* server_prompt_cache::alloc(const server_prompt& prompt, size_t st
         /*.tokens =*/ server_tokens(prompt.tokens.get_text_tokens(), false),
         /*.n_keep =*/ prompt.n_kept_prompt,
         /*.n_discarded_prompt =*/ prompt.n_discarded_prompt,
+        /*.think_tokens =*/ prompt.think_tokens,
         /*.data =*/ std::move(state_data),
         /*.checkpoints =*/ prompt.checkpoints,
     };
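
Why alloc() now copies the field: each cached entry remembers the thinking-token configuration it was created with, which is what load() reads back via it->think_tokens at the top of the first hunk. A rough, hypothetical sketch of what such a carrier could look like; the real thinking_tokens and cached-state definitions live elsewhere in the server sources and are not part of this diff:

#include <string>

// Hypothetical stand-in for the real thinking_tokens type (members assumed).
struct thinking_tokens_sketch {
    bool exclude = true;            // enabled by default, per the commit message
    std::string begin = "<think>";  // assumed opening marker, for illustration
    std::string end   = "</think>"; // assumed closing marker, for illustration
};

// Hypothetical cached-state shape: the configuration travels with the entry so
// that later comparisons know whether (and how) to strip reasoning spans.
struct cached_state_sketch {
    int n_kept_prompt      = 0;
    int n_discarded_prompt = 0;
    thinking_tokens_sketch think_tokens;
};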