Make string ban more robust and add regex ban (#1243)

* Test new ctx_sampling->n_rewind system

* CRLF quickfix

* Adaptive p check

* merge banned_n

* Fix attempt 1

* Fix attempt 2
This commit is contained in:
SneedwareInc
2026-03-11 15:30:27 +01:00
committed by GitHub
parent fd4638f0e8
commit 4a247593dc
2 changed files with 271 additions and 67 deletions

View File

@@ -11,6 +11,7 @@
#include "mtmd.h"
#include "mtmd-helper.h"
#include <regex>
server_context::~server_context() {
if (ctx) {
@@ -375,6 +376,11 @@ void server_slot::reset() {
generated_token_probs.clear();
checkpoint_pos = 0;
positional_bans.clear();
ban_phrases.clear();
ban_regex.clear();
ban_regex_ci.clear();
// Reset speculative decoding stats
n_draft_total = 0;
n_draft_accepted = 0;
@@ -864,6 +870,7 @@ bool server_context::launch_slot_with_task(server_slot& slot, server_task& task)
slot.params.include_usage = json_value(stream_opt, "include_usage", false);
slot.params.cache_prompt = json_value(data, "cache_prompt", true);
slot.params.n_predict = json_value(data, "n_predict", json_value(data, "max_tokens", defaults.n_predict));
slot.saturate_predict = json_value(data, "saturate_predict", false);
slot.sparams.top_k = json_value(data, "top_k", default_sparams.top_k);
slot.sparams.top_p = json_value(data, "top_p", default_sparams.top_p);
slot.sparams.min_p = json_value(data, "min_p", default_sparams.min_p);
@@ -1216,6 +1223,10 @@ bool server_context::launch_slot_with_task(server_slot& slot, server_task& task)
{
// ban string
int32_t banbuffer_size = json_value(data, "banbuffer_size", 0);
slot.n_buffer = 0; // Ensure buffer calculation starts fresh for this slot
slot.rewind_count_max = json_value(data, "rewind_count_max", -1);
const auto& banned_strings = data.find("banned_strings");
if (banned_strings != data.end() && banned_strings->is_array()) {
slot.ban_phrases.clear();
@@ -1224,15 +1235,14 @@ bool server_context::launch_slot_with_task(server_slot& slot, server_task& task)
std::string s = val.get<std::string>();
if (!s.empty()) {
s = string_lower(s);
auto ban_tokens = common_tokenize(llama_get_model(ctx), s, false, true);
if (ban_tokens.size() > slot.n_buffer) {
slot.n_buffer = ban_tokens.size();
// Use string length instead of token count
if (s.length() > slot.n_buffer) {
slot.n_buffer = s.length();
}
slot.ban_phrases.push_back(s);
}
}
}
slot.n_buffer = slot.n_buffer + 3; // extra buffer in case
std::sort(slot.ban_phrases.begin(), slot.ban_phrases.end(), [](const std::string& a, const std::string& b) {
return a.length() > b.length();
});
@@ -1245,20 +1255,72 @@ bool server_context::launch_slot_with_task(server_slot& slot, server_task& task)
for (auto & val : params_base.ban_phrases) {
if (!val.empty()) {
val = string_lower(val);
auto ban_tokens = common_tokenize(llama_get_model(ctx), val, false, true);
if (ban_tokens.size() > slot.n_buffer) {
slot.n_buffer = ban_tokens.size();
// Use string length instead of token count
if (val.length() > slot.n_buffer) {
slot.n_buffer = val.length();
}
slot.ban_phrases.push_back(val);
}
}
slot.n_buffer = slot.n_buffer + 3; // extra buffer in case
params_base.n_buffer = slot.n_buffer;
params_base.n_buffer = slot.n_buffer + 1; // buffer is longest string + 1
} else {
slot.ban_phrases = params_base.ban_phrases;
slot.n_buffer = params_base.n_buffer;
}
}
// ban regex
slot.ban_regex.clear();
const auto& banned_regex = data.find("banned_regex");
if (banned_regex != data.end() && banned_regex->is_array()) {
for (const auto& val : data["banned_regex"]) {
if (val.is_string()) {
std::string s = val.get<std::string>();
if (!s.empty()) {
try {
std::regex re(s);
slot.ban_regex.push_back(s);
if (s.length() > slot.n_buffer) {
slot.n_buffer = s.length();
}
} catch (const std::regex_error& e) {
send_error(task, "Invalid regex in banned_regex: " + s, ERROR_TYPE_INVALID_REQUEST);
return false;
}
}
}
}
}
// ban regex case insensitive
slot.ban_regex_ci.clear();
const auto& banned_regex_ci = data.find("banned_regex_case_insensitive");
if (banned_regex_ci != data.end() && banned_regex_ci->is_array()) {
for (const auto& val : data["banned_regex_case_insensitive"]) {
if (val.is_string()) {
std::string s = val.get<std::string>();
if (!s.empty()) {
try {
std::regex re(s, std::regex_constants::icase);
slot.ban_regex_ci.push_back(s);
if (s.length() > slot.n_buffer) {
slot.n_buffer = s.length();
}
} catch (const std::regex_error& e) {
send_error(task, "Invalid regex in banned_regex_case_insensitive: " + s, ERROR_TYPE_INVALID_REQUEST);
return false;
}
}
}
}
}
// Finalise the token-buffer size used for ban detection.
// A positive banbuffer_size replaces whatever value slot.n_buffer accumulated so far;
// otherwise the accumulated value is extended by one.
// NOTE(review): an override smaller than the longest banned phrase presumably lets that
// phrase leave the buffer before it can be matched - confirm whether it should be clamped.
if (banbuffer_size > 0) {
slot.n_buffer = banbuffer_size;
} else {
slot.n_buffer = slot.n_buffer + 1; // buffer is longest string/regex + 1
}
slot.logit_bias = slot.sparams.logit_bias; // keep a copy to restore
slot.ban_phrases_bias = json_value(data, "banned_bias", params_base.ban_phrases_bias);
slot.banned_n = json_value(data, "banned_n", params_base.banned_n);
@@ -3261,12 +3323,24 @@ void server_context::release_slot_after_final_response(server_slot & slot) {
void server_context::send_token_results(completion_token_outputs& results, server_slot& slot, int32_t n) {
int count = 0;
bool released = false;
int32_t start_pos = slot.n_past - (int32_t)slot.token_buffer.size() + 1;
for (auto& it : results) {
bool has_next = process_token(it, slot);
// Clean up positional bans for the token we just confirmed/sent
slot.positional_bans.erase(start_pos + count);
count++;
if (!has_next) {
if (slot.stopped_limit && !slot.stopped_eos && !slot.stopped_word) {
continue;
}
send_final_response(slot);
release_slot_after_final_response(slot);
released = true;
break;
}
if (n > 0 && count >= n) {
@@ -3274,6 +3348,11 @@ void server_context::send_token_results(completion_token_outputs& results, serve
}
}
if (!released && slot.stopped_limit && !slot.stopped_eos && !slot.stopped_word) {
send_final_response(slot);
release_slot_after_final_response(slot);
}
if (count > 0) {
slot.sampled = results[results.size()-1].tok;
results.erase(results.begin(), results.begin() + count);
@@ -3281,50 +3360,100 @@ void server_context::send_token_results(completion_token_outputs& results, serve
}
inline int32_t check_ban_phrase(const server_slot& slot) {
bool found = false;
size_t n = slot.token_buffer.size();
size_t start;
int32_t n_rewind = 0;
inline int32_t check_ban_phrase(server_slot& slot) {
if (slot.token_buffer.empty()) return -1;
std::string string_buffer;
llama_tokens tokens;
for (auto& it : slot.token_buffer) {
string_buffer = string_buffer + it.text_to_send;
tokens.push_back(it.tok);
std::vector<size_t> token_offsets;
for (const auto& it : slot.token_buffer) {
token_offsets.push_back(string_buffer.size());
string_buffer += it.text_to_send;
}
string_buffer = string_lower(string_buffer);
for (auto it : slot.ban_phrases) {
start = string_buffer.find(it);
// has been sorted from longest to shortest
size_t best_start = std::string::npos;
bool found = false;
std::string string_buffer_lower = string_lower(string_buffer);
// 1. Check strings
for (const auto& phrase : slot.ban_phrases) {
size_t start = string_buffer_lower.find(phrase);
if (start != std::string::npos) {
found = true;
break;
if (start < best_start) {
best_start = start;
found = true;
}
}
}
if (found) {
std::vector<size_t> unused;
LLAMA_LOG_DEBUG("Banned string dectected: %s\n", string_buffer.substr(start).c_str());
n = find_n_tokens_from_string(slot.ctx, tokens, start, 0, unused);
n_rewind = (int32_t) slot.token_buffer.size() - (int32_t) n;
// 2. Check regex
for (const auto& pattern : slot.ban_regex) {
try {
std::regex re(pattern);
std::smatch match;
if (std::regex_search(string_buffer, match, re)) {
if (match.position() < best_start) {
best_start = match.position();
found = true;
}
}
} catch (...) { continue; }
}
return n_rewind;
// 3. Check regex case insensitive
for (const auto& pattern : slot.ban_regex_ci) {
try {
std::regex re(pattern, std::regex_constants::icase);
std::smatch match;
if (std::regex_search(string_buffer, match, re)) {
if (match.position() < best_start) {
best_start = match.position();
found = true;
}
}
} catch (...) { continue; }
}
if (found) {
int32_t token_idx = -1;
for (size_t i = 0; i < token_offsets.size(); ++i) {
size_t len = (i == token_offsets.size() - 1)
? string_buffer.size() - token_offsets[i]
: token_offsets[i+1] - token_offsets[i];
if (best_start >= token_offsets[i] && best_start < token_offsets[i] + len) {
token_idx = (int32_t)i;
break;
}
}
if (token_idx != -1) {
int32_t abs_pos = slot.n_past - (int32_t)slot.token_buffer.size() + 1 + token_idx;
return abs_pos;
}
}
return -1;
}
inline void rewind_context(server_slot& slot, int32_t n_rewind) {
inline void rewind_context(server_slot& slot, int32_t ban_pos) {
slot.rewind_count++;
int32_t n_keep_rewind = (int32_t)slot.token_buffer.size() - n_rewind;
std::set<llama_token> tokens;
// ban all tokens for better coherence
int32_t buffer_start_pos = slot.n_past - (int32_t)slot.token_buffer.size() + 1;
int32_t n_keep_buffer = ban_pos - buffer_start_pos;
if (n_keep_buffer < 0) n_keep_buffer = 0;
if (slot.banned_n != 0) {
int32_t n = 0;
for (auto result = slot.token_buffer.begin() + n_keep_rewind; result != slot.token_buffer.end(); result++)
{
if (!tokens.contains(result->tok)) {
slot.ctx_sampling->params.logit_bias[result->tok] += slot.ban_phrases_bias;
}
else {
tokens.insert(result->tok);
for (auto result = slot.token_buffer.begin() + n_keep_buffer; result != slot.token_buffer.end(); result++) {
llama_token banned_tok = result->tok;
if (n == 0) {
LLAMA_LOG_INFO("Banned pattern detected at pos %d. Banning token %d ('%s') and rewinding.\n",
ban_pos, banned_tok, result->text_to_send.c_str());
}
slot.positional_bans[ban_pos].insert(banned_tok);
n++;
if (slot.banned_n > 0 && n == slot.banned_n) {
break;
@@ -3332,52 +3461,114 @@ inline void rewind_context(server_slot& slot, int32_t n_rewind) {
}
}
slot.token_buffer.resize(n_keep_rewind);
size_t n_keep = slot.cache_tokens.size() - n_rewind;
slot.sampled = slot.cache_tokens[n_keep];
slot.cache_tokens.keep_first(n_keep);
llama_kv_cache_seq_rm(slot.ctx, slot.id, n_keep, -1);
int32_t n_rewind_total = (slot.n_past + 1) - ban_pos;
size_t n_keep_cache = 0;
if (ban_pos > 0) {
n_keep_cache = (size_t)(ban_pos - 1);
}
if (n_keep_cache > slot.cache_tokens.size()) {
n_keep_cache = slot.cache_tokens.size();
}
if (n_keep_cache < slot.cache_tokens.size()) {
slot.sampled = slot.cache_tokens[n_keep_cache];
} else {
slot.sampled = 0;
}
// Truncate cache
slot.cache_tokens.keep_first(n_keep_cache);
slot.n_past = slot.cache_tokens.n_tokens();
// Remove from KV cache
llama_kv_cache_seq_rm(slot.ctx, slot.id, slot.n_past, -1);
// Truncate buffer
slot.token_buffer.resize(n_keep_buffer);
// Adjust decoded count
if (slot.saturate_predict) {
slot.n_decoded -= n_rewind_total;
if (slot.n_decoded < 0) slot.n_decoded = 0;
}
}
void server_context::buffer_and_check_string_ban(server_slot & slot, completion_token_output & result) {
slot.token_buffer.push_back(result);
bool next_token = has_next_token(result, slot);
bool send_result = slot.token_buffer.size() >= slot.n_buffer || !next_token;
// If buffer full or generation stopped, we might send tokens
bool buffer_full = slot.token_buffer.size() >= slot.n_buffer;
int32_t ban_pos = -1;
int32_t n_rewind = 0;
bool sent_results = false;
// don't restore if last time was also rewind
if (!slot.rewind_status) {
slot.ctx_sampling->params.logit_bias = slot.logit_bias; // restore logit bias
// Always reset logit bias to base before checking bans
slot.ctx_sampling->params.logit_bias = slot.logit_bias;
if (slot.ban_phrases.size() > 0 || slot.ban_regex.size() > 0 || slot.ban_regex_ci.size() > 0) {
ban_pos = check_ban_phrase(slot);
if (ban_pos >= 0 && slot.sparams.adaptive_target >= 0.0f) {
int32_t buffer_start_pos = slot.n_past - (int32_t)slot.token_buffer.size() + 1;
int32_t n_keep_buffer = ban_pos - buffer_start_pos;
if (n_keep_buffer < 0) n_keep_buffer = 0;
n_rewind = (int32_t)slot.token_buffer.size() - n_keep_buffer;
}
}
if (slot.ban_phrases.size() > 0) {
n_rewind = check_ban_phrase(slot);
bool allow_rewind = true;
if (ban_pos >= 0) {
if (slot.rewind_count_max == -1) {
// Automatic / Heuristic logic
// Account for strings + regex + regex_ci
size_t total_bans = slot.ban_phrases.size() + slot.ban_regex.size() + slot.ban_regex_ci.size();
// Heuristic: Allow if under 20 OR under 2 * total_bans
// Conversely: Stop if >= 20 AND > 2 * total_bans
if (slot.rewind_count >= 20 && slot.rewind_count > 2 * total_bans) {
allow_rewind = false;
}
}
else if (slot.rewind_count_max > 0) {
// Strict limit logic
if (slot.rewind_count >= slot.rewind_count_max) {
allow_rewind = false;
}
}
// If slot.rewind_count_max == 0, allow_rewind remains true (Infinite)
}
// if found string in the ban
if (n_rewind > 0 && (slot.rewind_count <20 || slot.rewind_count <= 2 * slot.ban_phrases.size())) {
rewind_context(slot, n_rewind);
if (ban_pos >= 0 && allow_rewind) {
rewind_context(slot, ban_pos);
slot.rewind_status = true;
}
else if (send_result) {
else if (buffer_full || !next_token) {
slot.rewind_status = false;
slot.rewind_count = 0;
if (!next_token) {
// send all remaining tokens in the buffer
// send all remaining tokens
send_token_results(slot.token_buffer, slot);
}
else {
// send 1 token
// send 1 token from the front (FIFO)
send_token_results(slot.token_buffer, slot, 1);
}
sent_results = true;
if (slot.sparams.adaptive_target >= 0.0f) {
sent_results = true;
}
}
else {
// buffer the result
slot.sampled = result.tok; // for common batch add
// buffer the result, wait for more tokens to validate string
slot.sampled = result.tok;
}
if (slot.sparams.adaptive_target >= 0.0f) {
slot.ctx_sampling->n_rewind = sent_results ? -1 : n_rewind;
}
slot.ctx_sampling->n_rewind = sent_results ? -1 : n_rewind;
}
void server_context::process_batch_tokens(int32_t & n_batch) {
@@ -3465,6 +3656,15 @@ void server_context::process_batch_tokens(int32_t & n_batch) {
continue; // sample using speculative decoding
}
// RESTORE AND APPLY POSITIONAL BANS
slot.ctx_sampling->params.logit_bias = slot.logit_bias;
auto ban_it = slot.positional_bans.find(slot.n_past);
if (ban_it != slot.positional_bans.end()) {
for (llama_token tok : ban_it->second) {
slot.ctx_sampling->params.logit_bias[tok] += slot.ban_phrases_bias;
}
}
completion_token_output result;
const int tok_idx = slot.i_batch - i;
const llama_token id = common_sampler_sample(slot.ctx_sampling, ctx, tok_idx);
@@ -3607,4 +3807,4 @@ json server_context::model_meta() const {
{"n_params", llama_model_n_params(model)},
{"size", llama_model_size(model)},
};
}
}

View File

@@ -81,7 +81,7 @@ struct server_slot {
bool stopped_eos = false;
bool stopped_word = false;
bool stopped_limit = false;
bool saturate_predict = false;
bool oaicompat = false;
std::string oaicompat_model;
@@ -91,12 +91,16 @@ struct server_slot {
// For context rewind / token buffer
size_t n_buffer = 0;
int32_t rewind_count = 0;
int32_t rewind_count_max = -1;
bool rewind_status = false;
std::unordered_map<llama_token, float> logit_bias;
std::vector<std::string>ban_phrases;
std::vector<std::string> ban_phrases;
std::vector<std::string> ban_regex;
std::vector<std::string> ban_regex_ci;
completion_token_outputs token_buffer;
float ban_phrases_bias = 0;
int32_t banned_n = 1;
std::map<int32_t, std::set<llama_token>> positional_bans;
server_prompt server_cached_prompt;