server: exclude thinking tokens when finding the slot (#1079)

Refactor find slot

Enable by default

Fix prompt loading

Rename variables

Co-authored-by: firecoperana <firecoperana>
firecoperana
2025-12-22 02:46:45 -06:00
committed by GitHub
parent 21fc9322f9
commit 5562605076
8 changed files with 247 additions and 33 deletions
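The change makes slot matching ignore configured "thinking" spans: before comparing a request's prompt against each slot's cache, the server drops every complete begin...end tag range from the token list, so leftover chain-of-thought in the cache no longer defeats prefix reuse. Below is a minimal string-level sketch of the idea, assuming the usual <think>/</think> tags; the thinking_tokens name and its begin/end/exclude fields come from the diff below, while the tag values and defaults here are assumptions, not part of this commit.

#include <iostream>
#include <string>

// Mirrors the config fields used in the diff below; the default tag values
// are an assumption, not taken from this commit.
struct thinking_tokens {
    std::string begin = "<think>";
    std::string end   = "</think>";
    bool exclude      = true; // the commit enables exclusion by default
};

// Drop every complete begin...end block from the detokenized text.
static std::string strip_think(const std::string & str, const thinking_tokens & tt) {
    std::string out;
    size_t pos = 0;
    for (;;) {
        const size_t s = str.find(tt.begin, pos);
        if (s == std::string::npos) break;   // no more start tags
        const size_t e = str.find(tt.end, s + tt.begin.size());
        if (e == std::string::npos) break;   // unterminated block: keep it
        out += str.substr(pos, s - pos);     // keep text before the block
        pos = e + tt.end.size();             // resume after the block
    }
    out += str.substr(pos);
    return out;
}

int main() {
    std::cout << strip_think("Hello <think>scratch work</think>world", {}) << "\n";
    // prints: Hello world
}

The diff itself does the same thing one level lower: it detokenizes once, finds the character ranges in the string, then maps those ranges back onto token indices so the result is a token list rather than text.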


@@ -1773,6 +1773,84 @@ server_tokens::server_tokens(const llama_tokens& tokens, bool has_mtmd) : has_mt
return max_idx; // all tokens are equal
}
llama_tokens server_tokens::get_text_tokens_exclude_think(const llama_context* ctx, const thinking_tokens& think_token) const {
if (!think_token.exclude) {
return get_text_tokens();
}
GGML_ASSERT(!think_token.begin.empty() && !think_token.end.empty() && "think tokens cannot be empty");
const std::string & start_str = think_token.begin;
const std::string & end_str = think_token.end;
llama_tokens tokens = get_text_tokens();
// Work at the string level: detokenize once, locate the think ranges by
// character offset, then map those ranges back onto the tokens below.
std::string str = llama_detokenize(ctx, tokens, true);
std::vector<std::pair<size_t, size_t>> results;
std::vector<size_t> start_positions;
std::vector<size_t> end_positions;
size_t pos = 0;
// Find all start-tag positions
while ((pos = str.find(start_str, pos)) != std::string::npos) {
start_positions.push_back(pos);
pos += start_str.length();
}
pos = 0;
// Find all end-tag positions (stored as one past the closing tag)
while ((pos = str.find(end_str, pos)) != std::string::npos) {
end_positions.push_back(pos + end_str.length());
pos += end_str.length();
}
// Pair each start with the first end that follows it; a start must also come
// after the previously matched block so the ranges never overlap or nest.
for (size_t i = 0; i < start_positions.size(); i++) {
for (size_t j = 0; j < end_positions.size(); j++) {
const bool after_last = results.empty() || start_positions[i] > results.back().second;
if (after_last && end_positions[j] > start_positions[i]) {
results.push_back({ start_positions[i], end_positions[j] });
break;
}
}
}
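// Example: with tags "<think>"/"</think>" and
// str = "A<think>B</think>C<think>D</think>E",
// start_positions = {1, 18} and end_positions = {17, 34}, so
// results = {{1, 17}, {18, 34}} and (assuming one-character tokens)
// only "A", "C", "E" survive the loop below.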
if (results.empty()) {
// no complete think block found
return tokens;
}
// Exclude the tokens: walk the token list, accumulating the length of each
// token's text piece, and drop every token whose text falls inside the
// current think range.
pos = 0;
size_t string_len = 0;
llama_tokens tokens_new;
auto model = llama_get_model(ctx);
for (size_t n = 0; n < tokens.size(); ++n) {
str = llama_token_to_piece(model, tokens[n], true);
string_len += str.size();
if (string_len <= results[pos].first) {
// still before the current think block
tokens_new.push_back(tokens[n]);
}
else if (string_len <= results[pos].second) {
// inside the current think block
continue;
}
else {
// past the current think block; advance to the next range if any
tokens_new.push_back(tokens[n]);
if (pos + 1 < results.size()) {
pos++;
}
}
}
return tokens_new;
}
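// Hypothetical call pattern for the helper above (the actual wiring in
// server.cpp is outside this hunk; prompt, cache, and params.think_token
// are placeholders, not names from this diff):
//   llama_tokens a = prompt.get_text_tokens_exclude_think(ctx, params.think_token);
//   llama_tokens b = cache.get_text_tokens_exclude_think(ctx, params.think_token);
// Stripping both sides before prefix matching lets a slot whose cache still
// holds old thinking content match a request that no longer carries it.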
common_prefix server_tokens::get_common_prefix(const llama_context* ctx, const server_tokens& b, bool exact) const {
common_prefix token_prefix;