diff --git a/common/common.cpp b/common/common.cpp
index 664a6f4c..17b9f72d 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -253,6 +253,30 @@ common_webui common_webui_from_name(const std::string& format) {
}
}
+thinking_tokens thinking_tokens_from_string(const std::string& format) {
+ thinking_tokens think_token;
+ std::string token_string = string_strip(format);
+ if (token_string == "none" || token_string == "None") {
+ think_token.exclude = false;
+ return think_token;
+ }
+ else if (token_string == "auto" || token_string == "Auto") {
+ think_token.exclude = true;
+ think_token.begin = "<think>";
+ think_token.end = "</think>";
+ return think_token;
+ }
+ // otherwise, use the user-provided start/end tokens (comma separated)
+ auto start_end = string_split(token_string, ",");
+ if (start_end.size() == 2) {
+ think_token.exclude = true;
+ think_token.begin = start_end[0];
+ think_token.end = start_end[1];
+ }
+ return think_token;
+}
+
+
static std::string read_file(const std::string& fname) {
std::ifstream file(fname);
if (!file) {
@@ -1745,6 +1769,11 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
}
return true;
}
+ if (arg == "--reasoning-tokens") {
+ CHECK_ARG
+ params.think_tokens = thinking_tokens_from_string(std::string(argv[i]));
+ return true;
+ }
if (arg == "--reasoning-budget") {
CHECK_ARG
params.reasoning_budget = std::stoi(argv[i]);
@@ -2160,6 +2189,7 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param
options.push_back({ "main", " --cfg-negative-prompt-file FNAME",
"negative prompt file to use for guidance" });
options.push_back({ "main", " --cfg-scale N", "strength of guidance (default: %.1f, 1.0 = disable)", (double)sparams.cfg_scale });
+ options.push_back({ "template" });
options.push_back({ "main", " --jinja",
"set custom jinja chat template (default: template taken from model's metadata)\n"
"if suffix/prefix are specified, template will be disabled\n"
@@ -2176,7 +2206,15 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param
"- deepseek-legacy: keeps `` tags in `message.content` while also populating `message.reasoning_content`\n"
"(default: none)", });
options.push_back({ "main", " --chat-template-kwargs JSON", "sets additional params for the json template parser"});
- options.push_back({ "main", " --reasoning-budget N", "controls the amount of thinking allowed; currently only one of: -1 for unrestricted thinking budget, or 0 to disable thinking (default: -1)" });
+ options.push_back({ "main", " --reasoning-budget N", "controls the amount of thinking allowed.\n"
+ "currently only one of: -1 for unrestricted thinking budget, or 0 to disable thinking"
+ "(default: -1)" });
+ options.push_back({ "main", " --reasoning-tokens FORMAT", "exclude reasoning tokens to select the slot more accurately.\n"
+ "none: include all tokens\n"
+ "auto: exclude all tokens between and \n"
+ "Or comma separated start and end tokens such as [THINK],[/THINK]\n"
+ "(default: auto)" });
+
options.push_back({ "main", " --no-prefill-assistant", "whether to prefill the assistant's response if the last message is an assistant message (default: prefill enabled)\n"
"when this flag is set, if the last message is an assistant message then it will be treated as a full message and not prefilled\n" });
options.push_back({ "grammar" });
diff --git a/common/common.h b/common/common.h
index e0b7f53b..f3a70beb 100644
--- a/common/common.h
+++ b/common/common.h
@@ -119,6 +119,15 @@ enum common_webui {
common_webui common_webui_from_name(const std::string& format);
+struct thinking_tokens {
+ bool exclude = true; // exclude reasoning sections when computing prompt similarity
+ std::string begin = "<think>"; // string that opens a reasoning section
+ std::string end = "</think>"; // string that closes a reasoning section
+};
+
+thinking_tokens thinking_tokens_from_string(const std::string& format);
+
+
struct model_paths {
std::string path = ""; // model local path // NOLINT
std::string url = ""; // model url to download // NOLINT
@@ -314,6 +323,7 @@ struct gpt_params {
std::string system_prompt = "";
bool enable_chat_template = true;
common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK;
+ thinking_tokens think_tokens;
int reasoning_budget = -1;
bool prefill_assistant = true;
diff --git a/examples/server/server-common.cpp b/examples/server/server-common.cpp
index 6be7ff5f..42bdbc2c 100644
--- a/examples/server/server-common.cpp
+++ b/examples/server/server-common.cpp
@@ -1773,6 +1773,84 @@ server_tokens::server_tokens(const llama_tokens& tokens, bool has_mtmd) : has_mt
return max_idx; // all tokens are equal
}
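+ // return the text tokens with every [begin, end] reasoning span removed:
+ // detokenize, locate the spans, then rebuild the token list without them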
+ llama_tokens server_tokens::get_text_tokens_exclude_think(const llama_context* ctx, const thinking_tokens& think_token) const {
+ if (!think_token.exclude) {
+ return get_text_tokens();
+ }
+ GGML_ASSERT((think_token.begin != "" && think_token.end != "") && "think tokens cannot be empty");
+ std::string startStr = think_token.begin;
+ std::string endStr = think_token.end;
+
+ llama_tokens tokens = get_text_tokens();
+ std::string str = llama_detokenize(ctx, tokens, true);
+
+ std::vector<std::pair<size_t, size_t>> results;
+ // Find all positions of start and end
+ std::vector<size_t> startPositions;
+ std::vector<size_t> endPositions;
+
+ size_t pos = 0;
+ // Find all start positions
+ while ((pos = str.find(startStr, pos)) != std::string::npos) {
+ startPositions.push_back(pos);
+ pos += startStr.length();
+ }
+
+ pos = 0;
+ // Find all end positions
+ while ((pos = str.find(endStr, pos)) != std::string::npos) {
+ endPositions.push_back(pos + endStr.length());
+ pos += endStr.length();
+ }
+
+ // pair each start with the first end that follows it; starts that fall inside an already matched span are skipped
+ for (size_t i = 0; i < startPositions.size(); i++) {
+ for (size_t j = 0; j < endPositions.size(); j++) {
+ if (results.size()) {
+ // start must be after last end
+ if (startPositions[i] > results[results.size() - 1].second && endPositions[j] > startPositions[i]) {
+ results.push_back({ startPositions[i], endPositions[j] });
+ break;
+ }
+ }
+ else {
+ if (endPositions[j] > startPositions[i]) {
+ results.push_back({ startPositions[i], endPositions[j] });
+ break;
+ }
+ }
+
+ }
+ }
+ if (!results.size()) {
+ return tokens;
+ }
+
+ // Walk the tokens, accumulating detokenized piece lengths, and drop tokens whose characters fall inside an excluded span
+ pos = 0;
+ size_t n = 0;
+ size_t string_len = 0;
+ llama_tokens tokens_new;
+ auto model = llama_get_model(ctx);
+ for (n = 0; n < tokens.size(); ++n) {
+ str = llama_token_to_piece(model, tokens[n], true);
+ string_len = string_len + str.size();
+ if (string_len <= results[pos].first) {
+ tokens_new.push_back(tokens[n]);
+ }
+ else if (string_len <= results[pos].second) {
+ continue;
+ }
+ else {
+ tokens_new.push_back(tokens[n]);
+ if (pos+1 < results.size()) {
+ pos++;
+ }
+ }
+ }
+ return tokens_new;
+ }
+
common_prefix server_tokens::get_common_prefix(const llama_context* ctx, const server_tokens& b, bool exact) const {
common_prefix token_prefix;
diff --git a/examples/server/server-common.h b/examples/server/server-common.h
index 2289a3c7..227e91db 100644
--- a/examples/server/server-common.h
+++ b/examples/server/server-common.h
@@ -171,6 +171,7 @@ std::string tokens_to_str(llama_context* ctx, const llama_tokens& tokens);
// format incomplete utf-8 multibyte character for output
std::string tokens_to_output_formatted_string(const llama_context* ctx, const llama_token token);
+
struct common_prefix {
size_t first = 0;
size_t second = 0;
@@ -389,6 +390,7 @@ public:
size_t get_common_prefix_exact(const server_tokens& b) const;
+ llama_tokens get_text_tokens_exclude_think(const llama_context* ctx, const thinking_tokens& think_token) const;
common_prefix get_common_prefix(const llama_context* ctx, const server_tokens& b, bool exact = false) const;
// take first n tokens of tokens list a
diff --git a/examples/server/server-context.cpp b/examples/server/server-context.cpp
index 45ef8988..a3f3155a 100644
--- a/examples/server/server-context.cpp
+++ b/examples/server/server-context.cpp
@@ -176,7 +176,13 @@ void server_context::init() {
slot.n_predict = params.n_predict;
slot.mctx = mctx;
slot.cache_tokens.has_mtmd = mctx != nullptr;
-
+ slot.params.think_tokens = params.think_tokens;
+ if (params.think_tokens.exclude) {
+ SRV_WRN("Exclude reasoning tokens when selecting slot based on similarity: start: %s, end: %s\nuse `--reasoning-tokens none` to disable.\n", params.think_tokens.begin.c_str(), params.think_tokens.end.c_str() );
+ }
+ else {
+ SRV_WRN("%s", "Include reasoning tokens when selecting slot based on similarity\nuse `--reasoning-tokens auto` to exclude reasoning tokens.\n");
+ }
LOG_INFO("new slot", {
{"id_slot", slot.id},
{"n_ctx_slot", slot.n_ctx}
@@ -585,6 +591,44 @@ server_slot* server_context::get_slot_by_id(int id) {
return nullptr;
}
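+// f_keep: fraction of the cached prompt (a) that remains usable for the new prompt (b), accounting for context shift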
+float server_context::calculate_slot_f_keep(const server_slot & slot, llama_context * ctx, const server_tokens & a, const server_tokens & b) {
+ float f_keep = 0.0f;
+ if (!a.empty()) {
+ if (slot.ga_n == 1 && slot.n_discarded_prompt > 0 && b.size() >= slot.n_ctx) {
+ f_keep = a.get_cached_tokens_similarity(slot.ctx, b, slot.params.n_keep + add_bos_token, slot.n_discarded_prompt);
+ }
+ else {
+ f_keep = a.get_cached_tokens_similarity(slot.ctx, b, 0, 0);
+ }
+ }
+ return f_keep;
+}
+
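+// compute the longest common prefix and the similarity score between cached tokens (a) and the incoming prompt (b)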
+std::pair<common_prefix, float> server_context::calculate_slot_similarity(const server_slot& slot, llama_context* ctx, const server_tokens& a, const server_tokens& b) {
+ std::pair<common_prefix, float> sim;
+ // length of the Longest Common Prefix between the current slot's prompt and the input prompt
+ common_prefix lcp_len = a.get_common_prefix(slot.ctx, b);
+ // fraction of the Longest Common Prefix length with respect to the input prompt and cached prompt length
+ float sim_cur = a.get_tokens_similarity(slot.ctx, b, 0, 0);
+ // handle context shift
+ if (slot.ga_n == 1 && slot.n_discarded_prompt > 0 && b.size() >= slot.n_ctx) {
+ float sim_cur_ctx_shift = a.get_tokens_similarity(slot.ctx, b, slot.n_kept_prompt, slot.n_discarded_prompt);
+ if (sim_cur_ctx_shift > sim_cur) {
+ sim_cur = sim_cur_ctx_shift;
+ }
+ }
+ sim.first = lcp_len;
+ sim.second = sim_cur;
+ return sim;
+}
+
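+// snapshot the slot's cached tokens plus its context-shift and thinking-token state into server_cached_prompt before saving/loading the prompt cache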
+void server_context::copy_data_to_cached_prompt(const server_tokens & tokens, server_slot & slot) {
+ slot.server_cached_prompt.tokens = server_tokens(tokens.get_text_tokens(), false); // copy cache tokens
+ slot.server_cached_prompt.n_discarded_prompt = slot.n_discarded_prompt;
+ slot.server_cached_prompt.n_kept_prompt = slot.n_kept_prompt;
+ slot.server_cached_prompt.think_tokens = slot.params.think_tokens;
+}
+
server_slot* server_context::get_available_slot(const server_task& task) {
server_slot* ret = nullptr;
bool update_cache = false;
@@ -599,22 +643,25 @@ server_slot* server_context::get_available_slot(const server_task& task) {
if (!slot.available()) {
continue;
}
- const auto& cache_tokens = slot.cache_tokens;
+ auto& cache_tokens = slot.cache_tokens;
// skip the slot if it does not contains prompt
if (cache_tokens.empty()) {
continue;
}
- // length of the Longest Common Prefix between the current slot's prompt and the input prompt
- auto lcp_len = cache_tokens.get_common_prefix(slot.ctx, task.tokens);
- // fraction of the Longest Common Prefix length with respect to the input prompt and cached prompt length
- float sim_cur = cache_tokens.get_tokens_similarity(slot.ctx, task.tokens, 0, 0);
- // handle context shift
- if (slot.ga_n == 1 && slot.n_discarded_prompt > 0 && task.tokens.size() >= slot.n_ctx) {
- float sim_cur_ctx_shift = cache_tokens.get_tokens_similarity(slot.ctx, task.tokens, slot.n_kept_prompt, slot.n_discarded_prompt);
- if (sim_cur_ctx_shift > sim_cur) {
- sim_cur = sim_cur_ctx_shift;
- }
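+ // think-token filtering works on detokenized text only, so multimodal (mtmd) prompts are compared as-is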
+ bool exclude_think = !cache_tokens.has_mtmd && slot.params.think_tokens.exclude;
+ std::pair<common_prefix, float> sim;
+ if (exclude_think) {
+ auto temp = slot.cache_tokens.get_text_tokens_exclude_think(slot.ctx, slot.params.think_tokens);
+ server_tokens cache_tokens_exclude_think = server_tokens(temp, false);
+ temp = task.tokens.get_text_tokens_exclude_think(slot.ctx, slot.params.think_tokens);
+ server_tokens prompt_tokens_exclude_think = server_tokens(temp, false);
+ sim = calculate_slot_similarity(slot, ctx, cache_tokens_exclude_think, prompt_tokens_exclude_think);
}
+ else {
+ sim = calculate_slot_similarity(slot, ctx, cache_tokens, task.tokens);
+ }
+ common_prefix lcp_len = sim.first;
+ float sim_cur = sim.second;
// select the current slot if the criteria match
if (sim_cur > sim_best && sim_cur > slot_prompt_similarity) {
@@ -655,26 +702,36 @@ server_slot* server_context::get_available_slot(const server_task& task) {
}
}
if (ret) {
- const auto& tokens = ret->cache_tokens;
- float f_keep = 0.0f;
+ auto& tokens = ret->cache_tokens;
+ float f_keep = 0.0f;
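+ // token count used for the cache_ram_n_min check below; recomputed without reasoning spans when excluding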
+ size_t cache_token_size = tokens.size();
if (!tokens.empty()) {
- if (ret->ga_n == 1 && ret->n_discarded_prompt > 0 && task.tokens.size() >= ret->n_ctx) {
- f_keep = tokens.get_cached_tokens_similarity(ret->ctx, task.tokens, ret->params.n_keep + add_bos_token, ret->n_discarded_prompt);
+ bool exclude_think = !tokens.has_mtmd && ret->params.think_tokens.exclude;
+ if (exclude_think) {
+ auto temp = tokens.get_text_tokens_exclude_think(ret->ctx, ret->params.think_tokens);
+ server_tokens cache_exclude_think = server_tokens(temp, false);
+
+ temp = task.tokens.get_text_tokens_exclude_think(ret->ctx, ret->params.think_tokens);
+ server_tokens prompt_exclude_think = server_tokens(temp, false);
+
+ cache_token_size = cache_exclude_think.size();
+ f_keep = calculate_slot_f_keep(*ret, ret->ctx, cache_exclude_think, prompt_exclude_think);
}
else {
- f_keep = tokens.get_cached_tokens_similarity(ret->ctx, task.tokens, 0, 0);
+ f_keep = calculate_slot_f_keep(*ret, ret->ctx, tokens, task.tokens);
}
// if we are about to lose a large portion of the existing context - save it in the prompt cache
if (f_keep < cache_ram_similarity) {
update_cache = true;
}
}
+
update_cache = update_cache && prompt_cache;
// cache prompts only for completion tasks
update_cache = update_cache && task.type == SERVER_TASK_TYPE_COMPLETION;
// don't update the cache if the slot's context is above cache_ram_n_min
- update_cache = update_cache && tokens.size() >= cache_ram_n_min;
+ update_cache = update_cache && cache_token_size >= cache_ram_n_min;
// TODO: mtmd does not support prompt cache
update_cache = update_cache && (ret->mctx == nullptr);
@@ -684,9 +741,8 @@ server_slot* server_context::get_available_slot(const server_task& task) {
if (update_cache) {
const int64_t t_start = ggml_time_us();
LLAMA_LOG_INFO("updating prompt cache\n");
- ret->server_cached_prompt.tokens = server_tokens(tokens.get_text_tokens(), false); // copy cache tokens
- ret->server_cached_prompt.n_discarded_prompt = ret->n_discarded_prompt;
- ret->server_cached_prompt.n_kept_prompt = ret->n_kept_prompt;
+ // copy cache tokens
+ copy_data_to_cached_prompt(tokens, *ret);
ret->prompt_save(*prompt_cache);
LLAMA_LOG_INFO("prompt cache save took %.2f ms\n", (ggml_time_us() - t_start) / 1000.0);
@@ -694,9 +750,7 @@ server_slot* server_context::get_available_slot(const server_task& task) {
// has prompts saved earlier to load
if (prompt_cache && !prompt_cache->states.empty()) {
const int64_t t_start = ggml_time_us();
- ret->server_cached_prompt.tokens = server_tokens(tokens.get_text_tokens(), false); // copy cache tokens
- ret->server_cached_prompt.n_discarded_prompt = ret->n_discarded_prompt;
- ret->server_cached_prompt.n_kept_prompt = ret->n_kept_prompt;
+ copy_data_to_cached_prompt(tokens, *ret);
ret->prompt_load(*prompt_cache, task.tokens);
prompt_cache->update();
@@ -1959,7 +2013,6 @@ void server_context::context_shift_find_n_tokens(llama_context* ctx, const serve
}
void server_context::context_shift_prompt(llama_context* ctx, server_slot& slot, bool exact) {
- //server_tokens prompt_tokens = std::move(slot.prompt_tokens);
int n_keep = std::max(0, slot.params.n_keep + add_bos_token);
const int n_left = slot.n_ctx - n_keep;
int n_discard = slot.params.n_discard ? slot.params.n_discard : (n_left / 2);
diff --git a/examples/server/server-context.h b/examples/server/server-context.h
index 568db95a..a53d4786 100644
--- a/examples/server/server-context.h
+++ b/examples/server/server-context.h
@@ -34,6 +34,8 @@ struct slot_params {
int32_t n_discard = 0; // number of tokens after n_keep that may be discarded when shifting context, 0 defaults to half
int32_t n_predict = -1; // new tokens to predict
+ thinking_tokens think_tokens;
+
std::vector<std::string> antiprompt;
bool timings_per_token = false;
@@ -259,6 +261,12 @@ struct server_context {
server_slot* get_slot_by_id(int id);
+ float calculate_slot_f_keep(const server_slot& slot, llama_context* ctx, const server_tokens& a, const server_tokens& b);
+
+ std::pair<common_prefix, float> calculate_slot_similarity(const server_slot& slot, llama_context* ctx, const server_tokens& a, const server_tokens& b);
+
+ void copy_data_to_cached_prompt(const server_tokens& tokens, server_slot& slot);
+
server_slot* get_available_slot(const server_task& task);
bool launch_slot_with_task(server_slot& slot, server_task& task);
@@ -302,12 +310,14 @@ struct server_context {
void print_tokens(const server_tokens& prompt, const server_tokens& cache, size_t start1 = 0, size_t start2 = 0, size_t length = 10);
+ // discard tokens in kv cache and cached tokens
void discard_n_kv_and_cache_tokens(llama_context* ctx, server_slot& slot, int32_t n_keep, int32_t n_discard);
// convert keep first few and discard next tokens in a to b
void context_shift_find_n_tokens(llama_context* ctx, const server_tokens& a, const server_tokens& b, int32_t n_keep,
int32_t n_discard, int32_t& n_kept, int32_t& n_discarded, bool exact = false);
+ // handle context shift for prompt
void context_shift_prompt(llama_context* ctx, server_slot& slot, bool exact = false);
void update_slots();
diff --git a/examples/server/server-task.cpp b/examples/server/server-task.cpp
index 335dc85c..f2121c64 100644
--- a/examples/server/server-task.cpp
+++ b/examples/server/server-task.cpp
@@ -692,19 +692,40 @@ size_t server_prompt_cache::n_tokens() const {
}
bool server_prompt_cache::load(server_prompt& prompt, const server_tokens& tokens_new, llama_context* ctx, int32_t id_slot) {
- const auto lcp_best = prompt.tokens.get_common_prefix(ctx, tokens_new);
-
- float f_keep_best = float(lcp_best.second) / prompt.tokens.size();
- float sim_best = prompt.tokens.get_tokens_similarity(ctx, tokens_new, prompt.n_kept_prompt, prompt.n_discarded_prompt);
+ thinking_tokens think_tokens;
+ // use the thinking-token config stored with the first cached prompt state
+ if (!states.empty()) {
+ think_tokens = states.front().think_tokens;
+ }
+ server_tokens prompt_tokens;
+ server_tokens tokens_new_ex;
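+ // when excluding, compare prompts with reasoning spans stripped so cached thinking content does not skew similarity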
+ if (think_tokens.exclude) {
+ prompt_tokens = server_tokens(prompt.tokens.get_text_tokens_exclude_think(ctx, think_tokens), false);
+ tokens_new_ex = server_tokens(tokens_new.get_text_tokens_exclude_think(ctx, think_tokens), false);
+ }
+ else {
+ prompt_tokens = server_tokens(prompt.tokens.get_text_tokens(), false); // copy: prompt.tokens is still needed by the caller
+ tokens_new_ex = server_tokens(tokens_new.get_text_tokens(), false);
+ }
+ const auto lcp_best = prompt_tokens.get_common_prefix(ctx, tokens_new_ex);
+ float f_keep_best = float(lcp_best.second) / prompt_tokens.size();
+ float sim_best = prompt_tokens.get_tokens_similarity(ctx, tokens_new_ex, prompt.n_kept_prompt, prompt.n_discarded_prompt);
LLAMA_LOG_INFO(" - looking for better prompt, base f_keep = %.3f, sim = %.3f, n_keep = %d, n_discarded_prompt = %d\n", f_keep_best, sim_best, prompt.n_kept_prompt, prompt.n_discarded_prompt);
auto it_best = states.end();
// find the most similar cached prompt, that would also preserve the most context
for (auto it = states.begin(); it != states.end(); ++it) {
- const auto lcp_cur = it->tokens.get_common_prefix(ctx, tokens_new);
- const float f_keep_cur = float(lcp_cur.first) / it->tokens.size();
- const float sim_cur = it->tokens.get_tokens_similarity(ctx, tokens_new, it->n_kept_prompt, it->n_discarded_prompt);
+ server_tokens tokens;
+ if (think_tokens.exclude) {
+ tokens = server_tokens(it->tokens.get_text_tokens_exclude_think(ctx, think_tokens), false);
+ }
+ else {
+ tokens = server_tokens(it->tokens.get_text_tokens(), false); // copy: moving would leave the cached state empty
+ }
+ const auto lcp_cur = tokens.get_common_prefix(ctx, tokens_new_ex);
+ const float f_keep_cur = float(lcp_cur.first) / tokens.size();
+ const float sim_cur = tokens.get_tokens_similarity(ctx, tokens_new_ex, it->n_kept_prompt, it->n_discarded_prompt);
if (sim_best < sim_cur) {
f_keep_best = f_keep_cur;
sim_best = sim_cur;
@@ -778,6 +799,7 @@ server_prompt* server_prompt_cache::alloc(const server_prompt& prompt, size_t st
/*.tokens =*/ server_tokens(prompt.tokens.get_text_tokens(), false),
/*.n_keep =*/ prompt.n_kept_prompt,
/*.n_discarded_prompt =*/ prompt.n_discarded_prompt,
+ /*.think_tokens =*/ prompt.think_tokens,
/*.data =*/ std::move(state_data),
/*.checkpoints =*/ prompt.checkpoints,
};
diff --git a/examples/server/server-task.h b/examples/server/server-task.h
index 4be2e001..f8537e72 100644
--- a/examples/server/server-task.h
+++ b/examples/server/server-task.h
@@ -177,6 +177,7 @@ struct server_prompt {
server_tokens tokens;
int n_kept_prompt;
int n_discarded_prompt;
+ thinking_tokens think_tokens;
std::vector<uint8_t> data;