diff --git a/common/common.cpp b/common/common.cpp
index 664a6f4c..17b9f72d 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -253,6 +253,30 @@ common_webui common_webui_from_name(const std::string& format) {
     }
 }
 
+thinking_tokens thinking_tokens_from_string(const std::string& format) {
+    thinking_tokens think_token;
+    std::string token_string = string_strip(format);
+    if (token_string == "none" || token_string == "None") {
+        think_token.exclude = false;
+        return think_token;
+    }
+    else if (token_string == "auto" || token_string == "Auto") {
+        think_token.exclude = true;
+        think_token.begin = "<think>";
+        think_token.end = "</think>";
+        return think_token;
+    }
+    // Use user provided think tokens
+    auto start_end = string_split(format, ",");
+    if (start_end.size() == 2) {
+        think_token.exclude = true;
+        think_token.begin = start_end[0];
+        think_token.end = start_end[1];
+    }
+    return think_token;
+}
+
+
 static std::string read_file(const std::string& fname) {
     std::ifstream file(fname);
     if (!file) {
@@ -1745,6 +1769,11 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
         }
         return true;
     }
+    if (arg == "--reasoning-tokens") {
+        CHECK_ARG
+        params.think_tokens = thinking_tokens_from_string(std::string(argv[i]));
+        return true;
+    }
     if (arg == "--reasoning-budget") {
         CHECK_ARG
         params.reasoning_budget = std::stoi(argv[i]);
@@ -2160,6 +2189,7 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param
     options.push_back({ "main", " --cfg-negative-prompt-file FNAME", "negative prompt file to use for guidance" });
     options.push_back({ "main", " --cfg-scale N", "strength of guidance (default: %.1f, 1.0 = disable)", (double)sparams.cfg_scale });
 
+    options.push_back({ "template" });
     options.push_back({ "main", " --jinja", "set custom jinja chat template (default: template taken from model's metadata)\n"
         "if suffix/prefix are specified, template will be disabled\n"
@@ -2176,7 +2206,15 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param
         "- deepseek-legacy: keeps `<think>` tags in `message.content` while also populating `message.reasoning_content`\n"
         "(default: none)", });
     options.push_back({ "main", " --chat-template-kwargs JSON", "sets additional params for the json template parser"});
-    options.push_back({ "main", " --reasoning-budget N", "controls the amount of thinking allowed; currently only one of: -1 for unrestricted thinking budget, or 0 to disable thinking (default: -1)" });
+    options.push_back({ "main", " --reasoning-budget N", "controls the amount of thinking allowed.\n"
+        "currently only one of: -1 for unrestricted thinking budget, or 0 to disable thinking\n"
+        "(default: -1)" });
+    options.push_back({ "main", " --reasoning-tokens FORMAT", "exclude reasoning tokens to select the slot more accurately.\n"
+        "none: include all tokens\n"
+        "auto: exclude all tokens between <think> and </think>\n"
+        "or comma-separated start and end tokens such as [THINK],[/THINK]\n"
+        "(default: auto)" });
+
     options.push_back({ "main", " --no-prefill-assistant", "whether to prefill the assistant's response if the last message is an assistant message (default: prefill enabled)\n"
         "when this flag is set, if the last message is an assistant message then it will be treated as a full message and not prefilled\n" });
 
     options.push_back({ "grammar" });
diff --git a/common/common.h b/common/common.h
index e0b7f53b..f3a70beb 100644
--- a/common/common.h
+++ b/common/common.h
@@ -119,6 +119,15 @@ enum common_webui {
 
 common_webui common_webui_from_name(const std::string& format);
 
+struct thinking_tokens {
+    bool exclude = true;
+    std::string begin = "<think>";
+    std::string end = "</think>";
+};
+
+thinking_tokens thinking_tokens_from_string(const std::string& format);
+
+
 struct model_paths {
     std::string path = ""; // model local path // NOLINT
     std::string url = ""; // model url to download // NOLINT
@@ -314,6 +323,7 @@ struct gpt_params {
     std::string system_prompt = "";
     bool enable_chat_template = true;
     common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK;
+    thinking_tokens think_tokens;
     int reasoning_budget = -1;
 
     bool prefill_assistant = true;
diff --git a/examples/server/server-common.cpp b/examples/server/server-common.cpp
index 6be7ff5f..42bdbc2c 100644
--- a/examples/server/server-common.cpp
+++ b/examples/server/server-common.cpp
@@ -1773,6 +1773,84 @@ server_tokens::server_tokens(const llama_tokens& tokens, bool has_mtmd) : has_mt
     return max_idx; // all tokens are equal
 }
 
+llama_tokens server_tokens::get_text_tokens_exclude_think(const llama_context* ctx, const thinking_tokens& think_token) const {
+    if (!think_token.exclude) {
+        return get_text_tokens();
+    }
+    GGML_ASSERT((think_token.begin != "" && think_token.end != "") && "think tokens cannot be empty");
+    std::string startStr = think_token.begin;
+    std::string endStr = think_token.end;
+
+    llama_tokens tokens = get_text_tokens();
+    std::string str = llama_detokenize(ctx, tokens, true);
+
+    std::vector<std::pair<size_t, size_t>> results;
+    // Find all positions of start and end
+    std::vector<size_t> startPositions;
+    std::vector<size_t> endPositions;
+
+    size_t pos = 0;
+    // Find all start positions
+    while ((pos = str.find(startStr, pos)) != std::string::npos) {
+        startPositions.push_back(pos);
+        pos += startStr.length();
+    }
+
+    pos = 0;
+    // Find all end positions
+    while ((pos = str.find(endStr, pos)) != std::string::npos) {
+        endPositions.push_back(pos + endStr.length());
+        pos += endStr.length();
+    }
+
+    // For each start position, pair it with the first end position that comes after it
+    for (size_t i = 0; i < startPositions.size(); i++) {
+        for (size_t j = 0; j < endPositions.size(); j++) {
+            if (results.size()) {
+                // start must be after last end
+                if (startPositions[i] > results[results.size() - 1].second && endPositions[j] > startPositions[i]) {
+                    results.push_back({ startPositions[i], endPositions[j] });
+                    break;
+                }
+            }
+            else {
+                if (endPositions[j] > startPositions[i]) {
+                    results.push_back({ startPositions[i], endPositions[j] });
+                    break;
+                }
+            }
+        }
+    }
+    if (!results.size()) {
+        return tokens;
+    }
+
+    // Exclude tokens
+    pos = 0;
+    size_t n = 0;
+    size_t string_len = 0;
+    llama_tokens tokens_new;
+    auto model = llama_get_model(ctx);
+    for (n = 0; n < tokens.size(); ++n) {
+        str = llama_token_to_piece(model, tokens[n], true);
+        string_len = string_len + str.size();
+        if (string_len <= results[pos].first) {
+            tokens_new.push_back(tokens[n]);
+        }
+        else if (string_len <= results[pos].second) {
+            continue;
+        }
+        else {
+            tokens_new.push_back(tokens[n]);
+            if (pos + 1 < results.size()) {
+                pos++;
+            }
+        }
+    }
+    return tokens_new;
+}
+
 common_prefix server_tokens::get_common_prefix(const llama_context* ctx, const server_tokens& b, bool exact) const {
     common_prefix token_prefix;
diff --git a/examples/server/server-common.h b/examples/server/server-common.h
index 2289a3c7..227e91db 100644
--- a/examples/server/server-common.h
+++ b/examples/server/server-common.h
@@ -171,6 +171,7 @@ std::string tokens_to_str(llama_context* ctx, const llama_tokens& tokens);
 
 // format incomplete utf-8 multibyte character for output
 std::string tokens_to_output_formatted_string(const llama_context* ctx, const llama_token token);
 
+
 struct common_prefix {
     size_t first = 0;
     size_t second = 0;
@@ -389,6 +390,7 @@ public:
 
     size_t get_common_prefix_exact(const server_tokens& b) const;
 
+    llama_tokens get_text_tokens_exclude_think(const llama_context* ctx, const thinking_tokens& think_token) const;
     common_prefix get_common_prefix(const llama_context* ctx, const server_tokens& b, bool exact = false) const;
 
     // take first n tokens of tokens list a
diff --git a/examples/server/server-context.cpp b/examples/server/server-context.cpp
index 45ef8988..a3f3155a 100644
--- a/examples/server/server-context.cpp
+++ b/examples/server/server-context.cpp
@@ -176,7 +176,13 @@ void server_context::init() {
         slot.n_predict = params.n_predict;
         slot.mctx = mctx;
         slot.cache_tokens.has_mtmd = mctx != nullptr;
-
+        slot.params.think_tokens = params.think_tokens;
+        if (params.think_tokens.exclude) {
+            SRV_WRN("Exclude reasoning tokens when selecting slot based on similarity: start: %s, end: %s\nuse `--reasoning-tokens none` to disable.\n", params.think_tokens.begin.c_str(), params.think_tokens.end.c_str());
+        }
+        else {
+            SRV_WRN("%s", "Include reasoning tokens when selecting slot based on similarity\nuse `--reasoning-tokens auto` to exclude reasoning tokens.\n");
+        }
         LOG_INFO("new slot", {
             {"id_slot", slot.id},
             {"n_ctx_slot", slot.n_ctx}
@@ -585,6 +591,44 @@ server_slot* server_context::get_slot_by_id(int id) {
     return nullptr;
 }
 
+float server_context::calculate_slot_f_keep(const server_slot& slot, llama_context* ctx, const server_tokens& a, const server_tokens& b) {
+    float f_keep = 0.0f;
+    if (!a.empty()) {
+        if (slot.ga_n == 1 && slot.n_discarded_prompt > 0 && b.size() >= slot.n_ctx) {
+            f_keep = a.get_cached_tokens_similarity(slot.ctx, b, slot.params.n_keep + add_bos_token, slot.n_discarded_prompt);
+        }
+        else {
+            f_keep = a.get_cached_tokens_similarity(slot.ctx, b, 0, 0);
+        }
+    }
+    return f_keep;
+}
+
+std::pair<common_prefix, float> server_context::calculate_slot_similarity(const server_slot& slot, llama_context* ctx, const server_tokens& a, const server_tokens& b) {
+    std::pair<common_prefix, float> sim;
+    // length of the Longest Common Prefix between the current slot's prompt and the input prompt
+    common_prefix lcp_len = a.get_common_prefix(slot.ctx, b);
+    // fraction of the Longest Common Prefix length with respect to the input prompt and cached prompt length
+    float sim_cur = a.get_tokens_similarity(slot.ctx, b, 0, 0);
+    // handle context shift
+    if (slot.ga_n == 1 && slot.n_discarded_prompt > 0 && b.size() >= slot.n_ctx) {
+        float sim_cur_ctx_shift = a.get_tokens_similarity(slot.ctx, b, slot.n_kept_prompt, slot.n_discarded_prompt);
+        if (sim_cur_ctx_shift > sim_cur) {
+            sim_cur = sim_cur_ctx_shift;
+        }
+    }
+    sim.first = lcp_len;
+    sim.second = sim_cur;
+    return sim;
+}
+
+void server_context::copy_data_to_cached_prompt(const server_tokens& tokens, server_slot& slot) {
+    slot.server_cached_prompt.tokens = server_tokens(tokens.get_text_tokens(), false); // copy cache tokens
+    slot.server_cached_prompt.n_discarded_prompt = slot.n_discarded_prompt;
+    slot.server_cached_prompt.n_kept_prompt = slot.n_kept_prompt;
+    slot.server_cached_prompt.think_tokens = slot.params.think_tokens;
+}
+
 server_slot* server_context::get_available_slot(const server_task& task) {
     server_slot* ret = nullptr;
     bool update_cache = false;
@@ -599,22 +643,25 @@ server_slot* server_context::get_available_slot(const server_task& task) {
         if (!slot.available()) {
            continue;
         }
-        const auto& cache_tokens = slot.cache_tokens;
+        auto& cache_tokens = slot.cache_tokens;
         // skip the slot if it does not contains prompt
         if (cache_tokens.empty()) {
             continue;
         }
-        // length of the Longest Common Prefix between the current slot's prompt and the input prompt
-        auto lcp_len = cache_tokens.get_common_prefix(slot.ctx, task.tokens);
-        // fraction of the Longest Common Prefix length with respect to the input prompt and cached prompt length
-        float sim_cur = cache_tokens.get_tokens_similarity(slot.ctx, task.tokens, 0, 0);
-        // handle context shift
-        if (slot.ga_n == 1 && slot.n_discarded_prompt > 0 && task.tokens.size() >= slot.n_ctx) {
-            float sim_cur_ctx_shift = cache_tokens.get_tokens_similarity(slot.ctx, task.tokens, slot.n_kept_prompt, slot.n_discarded_prompt);
-            if (sim_cur_ctx_shift > sim_cur) {
-                sim_cur = sim_cur_ctx_shift;
-            }
+        bool exclude_think = !cache_tokens.has_mtmd && slot.params.think_tokens.exclude;
+        std::pair<common_prefix, float> sim;
+        if (exclude_think) {
+            auto temp = slot.cache_tokens.get_text_tokens_exclude_think(slot.ctx, slot.params.think_tokens);
+            server_tokens cache_tokens_exclude_think = server_tokens(temp, false);
+            temp = task.tokens.get_text_tokens_exclude_think(slot.ctx, slot.params.think_tokens);
+            server_tokens prompt_tokens_exclude_think = server_tokens(temp, false);
+            sim = calculate_slot_similarity(slot, ctx, cache_tokens_exclude_think, prompt_tokens_exclude_think);
         }
+        else {
+            sim = calculate_slot_similarity(slot, ctx, cache_tokens, task.tokens);
+        }
+        common_prefix lcp_len = sim.first;
+        float sim_cur = sim.second;
 
         // select the current slot if the criteria match
         if (sim_cur > sim_best && sim_cur > slot_prompt_similarity) {
@@ -655,26 +702,36 @@ server_slot* server_context::get_available_slot(const server_task& task) {
         }
     }
 
     if (ret) {
-        const auto& tokens = ret->cache_tokens;
-        float f_keep = 0.0f;
+        auto& tokens = ret->cache_tokens;
+        float f_keep = 0;
+        size_t cache_token_size = tokens.size();
 
         if (!tokens.empty()) {
-            if (ret->ga_n == 1 && ret->n_discarded_prompt > 0 && task.tokens.size() >= ret->n_ctx) {
-                f_keep = tokens.get_cached_tokens_similarity(ret->ctx, task.tokens, ret->params.n_keep + add_bos_token, ret->n_discarded_prompt);
+            bool exclude_think = !tokens.has_mtmd && ret->params.think_tokens.exclude;
+            if (exclude_think) {
+                auto temp = tokens.get_text_tokens_exclude_think(ret->ctx, ret->params.think_tokens);
+                server_tokens cache_exclude_think = server_tokens(temp, false);
+
+                temp = task.tokens.get_text_tokens_exclude_think(ret->ctx, ret->params.think_tokens);
+                server_tokens prompt_exclude_think = server_tokens(temp, false);
+
+                cache_token_size = cache_exclude_think.size();
+                f_keep = calculate_slot_f_keep(*ret, ret->ctx, cache_exclude_think, prompt_exclude_think);
             }
             else {
-                f_keep = tokens.get_cached_tokens_similarity(ret->ctx, task.tokens, 0, 0);
+                f_keep = calculate_slot_f_keep(*ret, ret->ctx, tokens, task.tokens);
            }
 
            // if we are about to lose a large portion of the existing context - save it in the prompt cache
            if (f_keep < cache_ram_similarity) {
                update_cache = true;
            }
        }
+        update_cache = update_cache && prompt_cache;
 
        // cache prompts only for completion tasks
        update_cache = update_cache && task.type == SERVER_TASK_TYPE_COMPLETION;
 
        // don't update the cache if the slot's context is above cache_ram_n_min
-       update_cache = update_cache && tokens.size() >= cache_ram_n_min;
+       update_cache = update_cache && cache_token_size >= cache_ram_n_min;
 
        // TODO: mtmd does not support prompt cache
        update_cache = update_cache && (ret->mctx == nullptr);
@@ -684,9 +741,8 @@ server_slot* server_context::get_available_slot(const server_task& task) {
 
        if (update_cache) {
            const int64_t t_start = ggml_time_us();
 
            LLAMA_LOG_INFO("updating prompt cache\n");
-           ret->server_cached_prompt.tokens = server_tokens(tokens.get_text_tokens(), false); // copy cache tokens
-           ret->server_cached_prompt.n_discarded_prompt = ret->n_discarded_prompt;
-           ret->server_cached_prompt.n_kept_prompt = ret->n_kept_prompt;
+           // copy cache tokens
+           copy_data_to_cached_prompt(tokens, *ret);
 
            ret->prompt_save(*prompt_cache);
 
            LLAMA_LOG_INFO("prompt cache save took %.2f ms\n", (ggml_time_us() - t_start) / 1000.0);
@@ -694,9 +750,7 @@ server_slot* server_context::get_available_slot(const server_task& task) {
 
        // has prompts saved earlier to load
        if (prompt_cache && !prompt_cache->states.empty()) {
            const int64_t t_start = ggml_time_us();
 
-           ret->server_cached_prompt.tokens = server_tokens(tokens.get_text_tokens(), false); // copy cache tokens
-           ret->server_cached_prompt.n_discarded_prompt = ret->n_discarded_prompt;
-           ret->server_cached_prompt.n_kept_prompt = ret->n_kept_prompt;
+           copy_data_to_cached_prompt(tokens, *ret);
 
            ret->prompt_load(*prompt_cache, task.tokens);
 
            prompt_cache->update();
@@ -1959,7 +2013,6 @@ void server_context::context_shift_find_n_tokens(llama_context* ctx, const serve
 }
 
 void server_context::context_shift_prompt(llama_context* ctx, server_slot& slot, bool exact) {
-    //server_tokens prompt_tokens = std::move(slot.prompt_tokens);
     int n_keep = std::max(0, slot.params.n_keep + add_bos_token);
     const int n_left = slot.n_ctx - n_keep;
     int n_discard = slot.params.n_discard ? slot.params.n_discard : (n_left / 2);
diff --git a/examples/server/server-context.h b/examples/server/server-context.h
index 568db95a..a53d4786 100644
--- a/examples/server/server-context.h
+++ b/examples/server/server-context.h
@@ -34,6 +34,8 @@ struct slot_params {
     int32_t n_discard = 0; // number of tokens after n_keep that may be discarded when shifting context, 0 defaults to half
     int32_t n_predict = -1; // new tokens to predict
 
+    thinking_tokens think_tokens;
+
     std::vector<std::string> antiprompt;
 
     bool timings_per_token = false;
@@ -259,6 +261,12 @@ struct server_context {
 
     server_slot* get_slot_by_id(int id);
 
+    float calculate_slot_f_keep(const server_slot& slot, llama_context* ctx, const server_tokens& a, const server_tokens& b);
+
+    std::pair<common_prefix, float> calculate_slot_similarity(const server_slot& slot, llama_context* ctx, const server_tokens& a, const server_tokens& b);
+
+    void copy_data_to_cached_prompt(const server_tokens& tokens, server_slot& slot);
+
     server_slot* get_available_slot(const server_task& task);
 
     bool launch_slot_with_task(server_slot& slot, server_task& task);
@@ -302,12 +310,14 @@ struct server_context {
 
     void print_tokens(const server_tokens& prompt, const server_tokens& cache, size_t start1 = 0, size_t start2 = 0, size_t length = 10);
 
+    // discard tokens in kv cache and cached tokens
     void discard_n_kv_and_cache_tokens(llama_context* ctx, server_slot& slot, int32_t n_keep, int32_t n_discard);
 
     // convert keep first few and discard next tokens in a to b
     void context_shift_find_n_tokens(llama_context* ctx, const server_tokens& a, const server_tokens& b, int32_t n_keep, int32_t n_discard, int32_t& n_kept, int32_t& n_discarded, bool exact = false);
 
+    // handle context shift for prompt
     void context_shift_prompt(llama_context* ctx, server_slot& slot, bool exact = false);
 
     void update_slots();
diff --git a/examples/server/server-task.cpp b/examples/server/server-task.cpp
index 335dc85c..f2121c64 100644
--- a/examples/server/server-task.cpp
+++ b/examples/server/server-task.cpp
@@ -692,19 +692,40 @@ size_t server_prompt_cache::n_tokens() const {
 }
 
 bool server_prompt_cache::load(server_prompt& prompt, const server_tokens& tokens_new, llama_context* ctx, int32_t id_slot) {
-    const auto lcp_best = prompt.tokens.get_common_prefix(ctx, tokens_new);
-
-    float f_keep_best = float(lcp_best.second) / prompt.tokens.size();
-    float sim_best = prompt.tokens.get_tokens_similarity(ctx, tokens_new, prompt.n_kept_prompt, prompt.n_discarded_prompt);
+    thinking_tokens think_tokens;
+    for (auto it = states.begin(); it != states.end(); ++it) {
+        think_tokens = it->think_tokens;
+        break;
+    }
+    server_tokens prompt_tokens;
+    server_tokens tokens_new_ex;
+    if (think_tokens.exclude) {
+        prompt_tokens = server_tokens(prompt.tokens.get_text_tokens_exclude_think(ctx, think_tokens), false);
+        tokens_new_ex = server_tokens(tokens_new.get_text_tokens_exclude_think(ctx, think_tokens), false);
+    }
+    else {
+        prompt_tokens = server_tokens(prompt.tokens.get_text_tokens(), false); // copy, do not move - prompt.tokens is still needed by the caller
+        tokens_new_ex = server_tokens(tokens_new.get_text_tokens(), false);
+    }
+    const auto lcp_best = prompt_tokens.get_common_prefix(ctx, tokens_new_ex);
+    float f_keep_best = float(lcp_best.second) / prompt_tokens.size();
+    float sim_best = prompt_tokens.get_tokens_similarity(ctx, tokens_new_ex, prompt.n_kept_prompt, prompt.n_discarded_prompt);
 
     LLAMA_LOG_INFO(" - looking for better prompt, base f_keep = %.3f, sim = %.3f, n_keep = %d, n_discarded_prompt = %d\n", f_keep_best, sim_best, prompt.n_kept_prompt, prompt.n_discarded_prompt);
 
     auto it_best = states.end();
 
     // find the most similar cached prompt, that would also preserve the most context
     for (auto it = states.begin(); it != states.end(); ++it) {
-        const auto lcp_cur = it->tokens.get_common_prefix(ctx, tokens_new);
-        const float f_keep_cur = float(lcp_cur.first) / it->tokens.size();
-        const float sim_cur = it->tokens.get_tokens_similarity(ctx, tokens_new, it->n_kept_prompt, it->n_discarded_prompt);
+        server_tokens tokens;
+        if (think_tokens.exclude) {
+            tokens = server_tokens(it->tokens.get_text_tokens_exclude_think(ctx, think_tokens), false);
+        }
+        else {
+            tokens = server_tokens(it->tokens.get_text_tokens(), false); // copy, do not move - moving would empty the cached state
+        }
+        const auto lcp_cur = tokens.get_common_prefix(ctx, tokens_new_ex);
+        const float f_keep_cur = float(lcp_cur.first) / tokens.size();
+        const float sim_cur = tokens.get_tokens_similarity(ctx, tokens_new_ex, it->n_kept_prompt, it->n_discarded_prompt);
 
         if (sim_best < sim_cur) {
             f_keep_best = f_keep_cur;
             sim_best = sim_cur;
@@ -778,6 +799,7 @@ server_prompt* server_prompt_cache::alloc(const server_prompt& prompt, size_t st
         /*.tokens              =*/ server_tokens(prompt.tokens.get_text_tokens(), false),
         /*.n_keep              =*/ prompt.n_kept_prompt,
         /*.n_discarded_prompt  =*/ prompt.n_discarded_prompt,
+        /*.think_tokens        =*/ prompt.think_tokens,
         /*.data                =*/ std::move(state_data),
         /*.checkpoints         =*/ prompt.checkpoints,
     };
diff --git a/examples/server/server-task.h b/examples/server/server-task.h
index 4be2e001..f8537e72 100644
--- a/examples/server/server-task.h
+++ b/examples/server/server-task.h
@@ -177,6 +177,7 @@ struct server_prompt {
     server_tokens tokens;
     int n_kept_prompt;
     int n_discarded_prompt;
+    thinking_tokens think_tokens;
 
     std::vector<uint8_t> data;
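
Note on the matching rules: the patch detects think regions on the detokenized prompt and then filters the token list (see `server_tokens::get_text_tokens_exclude_think`). The snippet below is a minimal, string-level sketch of the same `--reasoning-tokens` parsing and region-stripping behaviour, useful for sanity-checking a marker pair; the helper names `parse_reasoning_tokens` and `strip_think_regions` are illustrative only and are not part of the patch, which works on `llama_tokens` via `llama_detokenize` / `llama_token_to_piece`.

```cpp
#include <iostream>
#include <string>

// Stand-in for the patch's `thinking_tokens` struct (same defaults as "auto").
struct thinking_tokens {
    bool exclude = true;
    std::string begin = "<think>";
    std::string end   = "</think>";
};

// Mirrors thinking_tokens_from_string(): "none" keeps everything, "auto" uses
// <think>/</think>, anything else is read as a custom "BEGIN,END" marker pair.
static thinking_tokens parse_reasoning_tokens(const std::string & value) {
    thinking_tokens tt;
    if (value == "none" || value == "None") { tt.exclude = false; return tt; }
    if (value == "auto" || value == "Auto") { return tt; }
    const size_t comma = value.find(',');
    if (comma != std::string::npos) {
        tt.begin = value.substr(0, comma);
        tt.end   = value.substr(comma + 1);
    }
    return tt;
}

// Drops every begin..end region (markers included), scanning left to right,
// which is how the patch pairs start/end positions on the detokenized text.
static std::string strip_think_regions(const std::string & text, const thinking_tokens & tt) {
    if (!tt.exclude) {
        return text;
    }
    std::string out;
    size_t pos = 0;
    while (pos < text.size()) {
        const size_t b = text.find(tt.begin, pos);
        if (b == std::string::npos) { out += text.substr(pos); break; }
        const size_t e = text.find(tt.end, b + tt.begin.size());
        if (e == std::string::npos) { out += text.substr(pos); break; } // unterminated region is kept
        out += text.substr(pos, b - pos); // keep text before the region
        pos  = e + tt.end.size();         // skip the region itself
    }
    return out;
}

int main() {
    const thinking_tokens tt = parse_reasoning_tokens("auto");
    const std::string prompt = "Question<think>internal reasoning</think>Answer";
    std::cout << strip_think_regions(prompt, tt) << "\n"; // prints: QuestionAnswer
    return 0;
}
```

With `--reasoning-tokens none` nothing is stripped, and a custom pair such as `[THINK],[/THINK]` follows the same rules with those markers; slot selection and prompt-cache lookup then compute their similarity scores on the stripped token sequences.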