mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-03-11 22:40:01 +00:00
Make string ban more robust and add regex ban (#1243)
* Test new ctx_sampling->n_rewind system * CRLF quickfix * Adaptive p check * merge banned_n * Fix attempt 1 * Fix attempt 2
This commit is contained in:
@@ -11,6 +11,7 @@
|
||||
#include "mtmd.h"
|
||||
#include "mtmd-helper.h"
|
||||
|
||||
#include <regex>
|
||||
|
||||
server_context::~server_context() {
|
||||
if (ctx) {
|
||||
@@ -375,6 +376,11 @@ void server_slot::reset() {
|
||||
generated_token_probs.clear();
|
||||
checkpoint_pos = 0;
|
||||
|
||||
positional_bans.clear();
|
||||
ban_phrases.clear();
|
||||
ban_regex.clear();
|
||||
ban_regex_ci.clear();
|
||||
|
||||
// Reset speculative decoding stats
|
||||
n_draft_total = 0;
|
||||
n_draft_accepted = 0;
|
||||
@@ -864,6 +870,7 @@ bool server_context::launch_slot_with_task(server_slot& slot, server_task& task)
|
||||
slot.params.include_usage = json_value(stream_opt, "include_usage", false);
|
||||
slot.params.cache_prompt = json_value(data, "cache_prompt", true);
|
||||
slot.params.n_predict = json_value(data, "n_predict", json_value(data, "max_tokens", defaults.n_predict));
|
||||
slot.saturate_predict = json_value(data, "saturate_predict", false);
|
||||
slot.sparams.top_k = json_value(data, "top_k", default_sparams.top_k);
|
||||
slot.sparams.top_p = json_value(data, "top_p", default_sparams.top_p);
|
||||
slot.sparams.min_p = json_value(data, "min_p", default_sparams.min_p);
|
||||
@@ -1216,6 +1223,10 @@ bool server_context::launch_slot_with_task(server_slot& slot, server_task& task)
|
||||
|
||||
{
|
||||
// ban string
|
||||
int32_t banbuffer_size = json_value(data, "banbuffer_size", 0);
|
||||
slot.n_buffer = 0; // Ensure buffer calculation starts fresh for this slot
|
||||
slot.rewind_count_max = json_value(data, "rewind_count_max", -1);
|
||||
|
||||
const auto& banned_strings = data.find("banned_strings");
|
||||
if (banned_strings != data.end() && banned_strings->is_array()) {
|
||||
slot.ban_phrases.clear();
|
||||
@@ -1224,15 +1235,14 @@ bool server_context::launch_slot_with_task(server_slot& slot, server_task& task)
|
||||
std::string s = val.get<std::string>();
|
||||
if (!s.empty()) {
|
||||
s = string_lower(s);
|
||||
auto ban_tokens = common_tokenize(llama_get_model(ctx), s, false, true);
|
||||
if (ban_tokens.size() > slot.n_buffer) {
|
||||
slot.n_buffer = ban_tokens.size();
|
||||
// Use string length instead of token count
|
||||
if (s.length() > slot.n_buffer) {
|
||||
slot.n_buffer = s.length();
|
||||
}
|
||||
slot.ban_phrases.push_back(s);
|
||||
}
|
||||
}
|
||||
}
|
||||
slot.n_buffer = slot.n_buffer + 3; // extra buffer in case
|
||||
std::sort(slot.ban_phrases.begin(), slot.ban_phrases.end(), [](const std::string& a, const std::string& b) {
|
||||
return a.length() > b.length();
|
||||
});
|
||||
@@ -1245,20 +1255,72 @@ bool server_context::launch_slot_with_task(server_slot& slot, server_task& task)
|
||||
for (auto & val : params_base.ban_phrases) {
|
||||
if (!val.empty()) {
|
||||
val = string_lower(val);
|
||||
auto ban_tokens = common_tokenize(llama_get_model(ctx), val, false, true);
|
||||
if (ban_tokens.size() > slot.n_buffer) {
|
||||
slot.n_buffer = ban_tokens.size();
|
||||
// Use string length instead of token count
|
||||
if (val.length() > slot.n_buffer) {
|
||||
slot.n_buffer = val.length();
|
||||
}
|
||||
slot.ban_phrases.push_back(val);
|
||||
}
|
||||
}
|
||||
slot.n_buffer = slot.n_buffer + 3; // extra buffer in case
|
||||
params_base.n_buffer = slot.n_buffer;
|
||||
params_base.n_buffer = slot.n_buffer + 1; // buffer is longest string + 1
|
||||
} else {
|
||||
slot.ban_phrases = params_base.ban_phrases;
|
||||
slot.n_buffer = params_base.n_buffer;
|
||||
}
|
||||
}
|
||||
|
||||
// ban regex
|
||||
slot.ban_regex.clear();
|
||||
const auto& banned_regex = data.find("banned_regex");
|
||||
if (banned_regex != data.end() && banned_regex->is_array()) {
|
||||
for (const auto& val : data["banned_regex"]) {
|
||||
if (val.is_string()) {
|
||||
std::string s = val.get<std::string>();
|
||||
if (!s.empty()) {
|
||||
try {
|
||||
std::regex re(s);
|
||||
slot.ban_regex.push_back(s);
|
||||
if (s.length() > slot.n_buffer) {
|
||||
slot.n_buffer = s.length();
|
||||
}
|
||||
} catch (const std::regex_error& e) {
|
||||
send_error(task, "Invalid regex in banned_regex: " + s, ERROR_TYPE_INVALID_REQUEST);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ban regex case insensitive
|
||||
slot.ban_regex_ci.clear();
|
||||
const auto& banned_regex_ci = data.find("banned_regex_case_insensitive");
|
||||
if (banned_regex_ci != data.end() && banned_regex_ci->is_array()) {
|
||||
for (const auto& val : data["banned_regex_case_insensitive"]) {
|
||||
if (val.is_string()) {
|
||||
std::string s = val.get<std::string>();
|
||||
if (!s.empty()) {
|
||||
try {
|
||||
std::regex re(s, std::regex_constants::icase);
|
||||
slot.ban_regex_ci.push_back(s);
|
||||
if (s.length() > slot.n_buffer) {
|
||||
slot.n_buffer = s.length();
|
||||
}
|
||||
} catch (const std::regex_error& e) {
|
||||
send_error(task, "Invalid regex in banned_regex_case_insensitive: " + s, ERROR_TYPE_INVALID_REQUEST);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (banbuffer_size > 0) {
|
||||
slot.n_buffer = banbuffer_size;
|
||||
} else {
|
||||
slot.n_buffer = slot.n_buffer + 1; // buffer is longest string/regex + 1
|
||||
}
|
||||
|
||||
slot.logit_bias = slot.sparams.logit_bias; // keep a copy to restore
|
||||
slot.ban_phrases_bias = json_value(data, "banned_bias", params_base.ban_phrases_bias);
|
||||
slot.banned_n = json_value(data, "banned_n", params_base.banned_n);
|
||||
@@ -3261,12 +3323,24 @@ void server_context::release_slot_after_final_response(server_slot & slot) {
|
||||
|
||||
void server_context::send_token_results(completion_token_outputs& results, server_slot& slot, int32_t n) {
|
||||
int count = 0;
|
||||
bool released = false;
|
||||
|
||||
int32_t start_pos = slot.n_past - (int32_t)slot.token_buffer.size() + 1;
|
||||
|
||||
for (auto& it : results) {
|
||||
bool has_next = process_token(it, slot);
|
||||
|
||||
// Clean up positional bans for the token we just confirmed/sent
|
||||
slot.positional_bans.erase(start_pos + count);
|
||||
|
||||
count++;
|
||||
if (!has_next) {
|
||||
if (slot.stopped_limit && !slot.stopped_eos && !slot.stopped_word) {
|
||||
continue;
|
||||
}
|
||||
send_final_response(slot);
|
||||
release_slot_after_final_response(slot);
|
||||
released = true;
|
||||
break;
|
||||
}
|
||||
if (n > 0 && count >= n) {
|
||||
@@ -3274,6 +3348,11 @@ void server_context::send_token_results(completion_token_outputs& results, serve
|
||||
}
|
||||
}
|
||||
|
||||
if (!released && slot.stopped_limit && !slot.stopped_eos && !slot.stopped_word) {
|
||||
send_final_response(slot);
|
||||
release_slot_after_final_response(slot);
|
||||
}
|
||||
|
||||
if (count > 0) {
|
||||
slot.sampled = results[results.size()-1].tok;
|
||||
results.erase(results.begin(), results.begin() + count);
|
||||
@@ -3281,50 +3360,100 @@ void server_context::send_token_results(completion_token_outputs& results, serve
|
||||
|
||||
}
|
||||
|
||||
inline int32_t check_ban_phrase(const server_slot& slot) {
|
||||
bool found = false;
|
||||
size_t n = slot.token_buffer.size();
|
||||
size_t start;
|
||||
int32_t n_rewind = 0;
|
||||
inline int32_t check_ban_phrase(server_slot& slot) {
|
||||
if (slot.token_buffer.empty()) return -1;
|
||||
|
||||
std::string string_buffer;
|
||||
llama_tokens tokens;
|
||||
for (auto& it : slot.token_buffer) {
|
||||
string_buffer = string_buffer + it.text_to_send;
|
||||
tokens.push_back(it.tok);
|
||||
std::vector<size_t> token_offsets;
|
||||
|
||||
for (const auto& it : slot.token_buffer) {
|
||||
token_offsets.push_back(string_buffer.size());
|
||||
string_buffer += it.text_to_send;
|
||||
}
|
||||
string_buffer = string_lower(string_buffer);
|
||||
for (auto it : slot.ban_phrases) {
|
||||
start = string_buffer.find(it);
|
||||
// has been sorted from longest to shortest
|
||||
|
||||
size_t best_start = std::string::npos;
|
||||
bool found = false;
|
||||
std::string string_buffer_lower = string_lower(string_buffer);
|
||||
|
||||
// 1. Check strings
|
||||
for (const auto& phrase : slot.ban_phrases) {
|
||||
size_t start = string_buffer_lower.find(phrase);
|
||||
if (start != std::string::npos) {
|
||||
found = true;
|
||||
break;
|
||||
if (start < best_start) {
|
||||
best_start = start;
|
||||
found = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (found) {
|
||||
std::vector<size_t> unused;
|
||||
LLAMA_LOG_DEBUG("Banned string dectected: %s\n", string_buffer.substr(start).c_str());
|
||||
n = find_n_tokens_from_string(slot.ctx, tokens, start, 0, unused);
|
||||
n_rewind = (int32_t) slot.token_buffer.size() - (int32_t) n;
|
||||
|
||||
// 2. Check regex
|
||||
for (const auto& pattern : slot.ban_regex) {
|
||||
try {
|
||||
std::regex re(pattern);
|
||||
std::smatch match;
|
||||
if (std::regex_search(string_buffer, match, re)) {
|
||||
if (match.position() < best_start) {
|
||||
best_start = match.position();
|
||||
found = true;
|
||||
}
|
||||
}
|
||||
} catch (...) { continue; }
|
||||
}
|
||||
return n_rewind;
|
||||
|
||||
// 3. Check regex case insensitive
|
||||
for (const auto& pattern : slot.ban_regex_ci) {
|
||||
try {
|
||||
std::regex re(pattern, std::regex_constants::icase);
|
||||
std::smatch match;
|
||||
if (std::regex_search(string_buffer, match, re)) {
|
||||
if (match.position() < best_start) {
|
||||
best_start = match.position();
|
||||
found = true;
|
||||
}
|
||||
}
|
||||
} catch (...) { continue; }
|
||||
}
|
||||
|
||||
if (found) {
|
||||
int32_t token_idx = -1;
|
||||
for (size_t i = 0; i < token_offsets.size(); ++i) {
|
||||
size_t len = (i == token_offsets.size() - 1)
|
||||
? string_buffer.size() - token_offsets[i]
|
||||
: token_offsets[i+1] - token_offsets[i];
|
||||
|
||||
if (best_start >= token_offsets[i] && best_start < token_offsets[i] + len) {
|
||||
token_idx = (int32_t)i;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (token_idx != -1) {
|
||||
int32_t abs_pos = slot.n_past - (int32_t)slot.token_buffer.size() + 1 + token_idx;
|
||||
return abs_pos;
|
||||
}
|
||||
}
|
||||
|
||||
return -1;
|
||||
}
|
||||
|
||||
inline void rewind_context(server_slot& slot, int32_t n_rewind) {
|
||||
inline void rewind_context(server_slot& slot, int32_t ban_pos) {
|
||||
slot.rewind_count++;
|
||||
int32_t n_keep_rewind = (int32_t)slot.token_buffer.size() - n_rewind;
|
||||
std::set<llama_token> tokens;
|
||||
// ban all tokens for better coherence
|
||||
|
||||
int32_t buffer_start_pos = slot.n_past - (int32_t)slot.token_buffer.size() + 1;
|
||||
int32_t n_keep_buffer = ban_pos - buffer_start_pos;
|
||||
if (n_keep_buffer < 0) n_keep_buffer = 0;
|
||||
|
||||
if (slot.banned_n != 0) {
|
||||
int32_t n = 0;
|
||||
for (auto result = slot.token_buffer.begin() + n_keep_rewind; result != slot.token_buffer.end(); result++)
|
||||
{
|
||||
if (!tokens.contains(result->tok)) {
|
||||
slot.ctx_sampling->params.logit_bias[result->tok] += slot.ban_phrases_bias;
|
||||
}
|
||||
else {
|
||||
tokens.insert(result->tok);
|
||||
for (auto result = slot.token_buffer.begin() + n_keep_buffer; result != slot.token_buffer.end(); result++) {
|
||||
llama_token banned_tok = result->tok;
|
||||
|
||||
if (n == 0) {
|
||||
LLAMA_LOG_INFO("Banned pattern detected at pos %d. Banning token %d ('%s') and rewinding.\n",
|
||||
ban_pos, banned_tok, result->text_to_send.c_str());
|
||||
}
|
||||
|
||||
slot.positional_bans[ban_pos].insert(banned_tok);
|
||||
n++;
|
||||
if (slot.banned_n > 0 && n == slot.banned_n) {
|
||||
break;
|
||||
@@ -3332,52 +3461,114 @@ inline void rewind_context(server_slot& slot, int32_t n_rewind) {
|
||||
}
|
||||
}
|
||||
|
||||
slot.token_buffer.resize(n_keep_rewind);
|
||||
size_t n_keep = slot.cache_tokens.size() - n_rewind;
|
||||
slot.sampled = slot.cache_tokens[n_keep];
|
||||
slot.cache_tokens.keep_first(n_keep);
|
||||
llama_kv_cache_seq_rm(slot.ctx, slot.id, n_keep, -1);
|
||||
|
||||
int32_t n_rewind_total = (slot.n_past + 1) - ban_pos;
|
||||
|
||||
size_t n_keep_cache = 0;
|
||||
if (ban_pos > 0) {
|
||||
n_keep_cache = (size_t)(ban_pos - 1);
|
||||
}
|
||||
|
||||
if (n_keep_cache > slot.cache_tokens.size()) {
|
||||
n_keep_cache = slot.cache_tokens.size();
|
||||
}
|
||||
|
||||
if (n_keep_cache < slot.cache_tokens.size()) {
|
||||
slot.sampled = slot.cache_tokens[n_keep_cache];
|
||||
} else {
|
||||
slot.sampled = 0;
|
||||
}
|
||||
|
||||
// Truncate cache
|
||||
slot.cache_tokens.keep_first(n_keep_cache);
|
||||
slot.n_past = slot.cache_tokens.n_tokens();
|
||||
|
||||
// Remove from KV cache
|
||||
llama_kv_cache_seq_rm(slot.ctx, slot.id, slot.n_past, -1);
|
||||
|
||||
// Truncate buffer
|
||||
slot.token_buffer.resize(n_keep_buffer);
|
||||
|
||||
// Adjust decoded count
|
||||
if (slot.saturate_predict) {
|
||||
slot.n_decoded -= n_rewind_total;
|
||||
if (slot.n_decoded < 0) slot.n_decoded = 0;
|
||||
}
|
||||
}
|
||||
|
||||
void server_context::buffer_and_check_string_ban(server_slot & slot, completion_token_output & result) {
|
||||
slot.token_buffer.push_back(result);
|
||||
|
||||
bool next_token = has_next_token(result, slot);
|
||||
bool send_result = slot.token_buffer.size() >= slot.n_buffer || !next_token;
|
||||
// If buffer full or generation stopped, we might send tokens
|
||||
bool buffer_full = slot.token_buffer.size() >= slot.n_buffer;
|
||||
|
||||
int32_t ban_pos = -1;
|
||||
int32_t n_rewind = 0;
|
||||
bool sent_results = false;
|
||||
// don't restore if last time was also rewind
|
||||
if (!slot.rewind_status) {
|
||||
slot.ctx_sampling->params.logit_bias = slot.logit_bias; // restore logit bias
|
||||
|
||||
// Always reset logit bias to base before checking bans
|
||||
slot.ctx_sampling->params.logit_bias = slot.logit_bias;
|
||||
|
||||
if (slot.ban_phrases.size() > 0 || slot.ban_regex.size() > 0 || slot.ban_regex_ci.size() > 0) {
|
||||
ban_pos = check_ban_phrase(slot);
|
||||
if (ban_pos >= 0 && slot.sparams.adaptive_target >= 0.0f) {
|
||||
int32_t buffer_start_pos = slot.n_past - (int32_t)slot.token_buffer.size() + 1;
|
||||
int32_t n_keep_buffer = ban_pos - buffer_start_pos;
|
||||
if (n_keep_buffer < 0) n_keep_buffer = 0;
|
||||
n_rewind = (int32_t)slot.token_buffer.size() - n_keep_buffer;
|
||||
}
|
||||
}
|
||||
if (slot.ban_phrases.size() > 0) {
|
||||
n_rewind = check_ban_phrase(slot);
|
||||
|
||||
bool allow_rewind = true;
|
||||
|
||||
if (ban_pos >= 0) {
|
||||
if (slot.rewind_count_max == -1) {
|
||||
// Automatic / Heuristic logic
|
||||
// Account for strings + regex + regex_ci
|
||||
size_t total_bans = slot.ban_phrases.size() + slot.ban_regex.size() + slot.ban_regex_ci.size();
|
||||
|
||||
// Heuristic: Allow if under 20 OR under 2 * total_bans
|
||||
// Conversely: Stop if >= 20 AND > 2 * total_bans
|
||||
if (slot.rewind_count >= 20 && slot.rewind_count > 2 * total_bans) {
|
||||
allow_rewind = false;
|
||||
}
|
||||
}
|
||||
else if (slot.rewind_count_max > 0) {
|
||||
// Strict limit logic
|
||||
if (slot.rewind_count >= slot.rewind_count_max) {
|
||||
allow_rewind = false;
|
||||
}
|
||||
}
|
||||
// If slot.rewind_count_max == 0, allow_rewind remains true (Infinite)
|
||||
}
|
||||
// if found string in the ban
|
||||
if (n_rewind > 0 && (slot.rewind_count <20 || slot.rewind_count <= 2 * slot.ban_phrases.size())) {
|
||||
rewind_context(slot, n_rewind);
|
||||
|
||||
if (ban_pos >= 0 && allow_rewind) {
|
||||
rewind_context(slot, ban_pos);
|
||||
slot.rewind_status = true;
|
||||
}
|
||||
else if (send_result) {
|
||||
else if (buffer_full || !next_token) {
|
||||
slot.rewind_status = false;
|
||||
slot.rewind_count = 0;
|
||||
|
||||
if (!next_token) {
|
||||
// send all remaining tokens in the buffer
|
||||
// send all remaining tokens
|
||||
send_token_results(slot.token_buffer, slot);
|
||||
}
|
||||
else {
|
||||
// send 1 token
|
||||
// send 1 token from the front (FIFO)
|
||||
send_token_results(slot.token_buffer, slot, 1);
|
||||
}
|
||||
sent_results = true;
|
||||
if (slot.sparams.adaptive_target >= 0.0f) {
|
||||
sent_results = true;
|
||||
}
|
||||
}
|
||||
else {
|
||||
// buffer the result
|
||||
slot.sampled = result.tok; // for common batch add
|
||||
// buffer the result, wait for more tokens to validate string
|
||||
slot.sampled = result.tok;
|
||||
}
|
||||
if (slot.sparams.adaptive_target >= 0.0f) {
|
||||
slot.ctx_sampling->n_rewind = sent_results ? -1 : n_rewind;
|
||||
}
|
||||
|
||||
slot.ctx_sampling->n_rewind = sent_results ? -1 : n_rewind;
|
||||
}
|
||||
|
||||
void server_context::process_batch_tokens(int32_t & n_batch) {
|
||||
@@ -3465,6 +3656,15 @@ void server_context::process_batch_tokens(int32_t & n_batch) {
|
||||
continue; // sample using speculative decoding
|
||||
}
|
||||
|
||||
// RESTORE AND APPLY POSITIONAL BANS
|
||||
slot.ctx_sampling->params.logit_bias = slot.logit_bias;
|
||||
auto ban_it = slot.positional_bans.find(slot.n_past);
|
||||
if (ban_it != slot.positional_bans.end()) {
|
||||
for (llama_token tok : ban_it->second) {
|
||||
slot.ctx_sampling->params.logit_bias[tok] += slot.ban_phrases_bias;
|
||||
}
|
||||
}
|
||||
|
||||
completion_token_output result;
|
||||
const int tok_idx = slot.i_batch - i;
|
||||
const llama_token id = common_sampler_sample(slot.ctx_sampling, ctx, tok_idx);
|
||||
@@ -3607,4 +3807,4 @@ json server_context::model_meta() const {
|
||||
{"n_params", llama_model_n_params(model)},
|
||||
{"size", llama_model_size(model)},
|
||||
};
|
||||
}
|
||||
}
|
||||
@@ -81,7 +81,7 @@ struct server_slot {
|
||||
bool stopped_eos = false;
|
||||
bool stopped_word = false;
|
||||
bool stopped_limit = false;
|
||||
|
||||
bool saturate_predict = false;
|
||||
bool oaicompat = false;
|
||||
|
||||
std::string oaicompat_model;
|
||||
@@ -91,12 +91,16 @@ struct server_slot {
|
||||
// For context rewind/ token buffer
|
||||
size_t n_buffer = 0;
|
||||
int32_t rewind_count = 0;
|
||||
int32_t rewind_count_max = -1;
|
||||
bool rewind_status = false;
|
||||
std::unordered_map<llama_token, float> logit_bias;
|
||||
std::vector<std::string>ban_phrases;
|
||||
std::vector<std::string> ban_phrases;
|
||||
std::vector<std::string> ban_regex;
|
||||
std::vector<std::string> ban_regex_ci;
|
||||
completion_token_outputs token_buffer;
|
||||
float ban_phrases_bias = 0;
|
||||
int32_t banned_n = 1;
|
||||
std::map<int32_t, std::set<llama_token>> positional_bans;
|
||||
|
||||
server_prompt server_cached_prompt;
|
||||
|
||||
|
||||
Reference in New Issue
Block a user