From 4a247593dcea93507088c9baedc9cce19960fda3 Mon Sep 17 00:00:00 2001
From: SneedwareInc <254158255+SneedwareInc@users.noreply.github.com>
Date: Wed, 11 Mar 2026 15:30:27 +0100
Subject: [PATCH] Make string ban more robust and add regex ban (#1243)

* Test new ctx_sampling->n_rewind system
* CRLF quickfix
* Adaptive p check
* merge banned_n
* Fix attempt 1
* Fix attempt 2
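For reference, a sketch of a request body exercising the new parameters. The
field names come from this patch; the values, defaults noted in comments, and
the helper function itself are illustrative only:

    #include <nlohmann/json.hpp>
    using json = nlohmann::json;

    json make_ban_request() {
        json body;
        body["prompt"]         = "Once upon a time";
        body["n_predict"]      = 256;
        body["banned_strings"] = json::array({"as an ai", "i cannot"}); // matched case-insensitively
        body["banned_regex"]   = json::array({"Chapter [0-9]+"});       // case-sensitive std::regex
        body["banned_regex_case_insensitive"] = json::array({"the end"});
        body["banbuffer_size"]   = 0;       // 0: derive buffer size from longest pattern + 1
        body["rewind_count_max"] = -1;      // -1: heuristic limit, 0: unlimited, >0: strict cap
        body["banned_bias"]      = -100.0f; // logit bias applied to offending tokens on rewind
        body["banned_n"]         = 1;       // buffered tokens to ban (!= 0 enables, > 0 caps)
        body["saturate_predict"] = true;    // refund rewound tokens against n_predict
        return body;
    }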
---
 examples/server/server-context.cpp | 330 +++++++++++++++++++++++------
 examples/server/server-context.h   |   8 +-
 2 files changed, 271 insertions(+), 67 deletions(-)

diff --git a/examples/server/server-context.cpp b/examples/server/server-context.cpp
index 6e7b3dae..e768eb16 100644
--- a/examples/server/server-context.cpp
+++ b/examples/server/server-context.cpp
@@ -11,6 +11,7 @@
 #include "mtmd.h"
 #include "mtmd-helper.h"
+#include <regex>
 
 server_context::~server_context() {
     if (ctx) {
@@ -375,6 +376,11 @@ void server_slot::reset() {
     generated_token_probs.clear();
     checkpoint_pos = 0;
 
+    positional_bans.clear();
+    ban_phrases.clear();
+    ban_regex.clear();
+    ban_regex_ci.clear();
+
     // Reset speculative decoding stats
     n_draft_total = 0;
     n_draft_accepted = 0;
@@ -864,6 +870,7 @@ bool server_context::launch_slot_with_task(server_slot& slot, server_task& task)
     slot.params.include_usage = json_value(stream_opt, "include_usage", false);
     slot.params.cache_prompt = json_value(data, "cache_prompt", true);
     slot.params.n_predict = json_value(data, "n_predict", json_value(data, "max_tokens", defaults.n_predict));
+    slot.saturate_predict = json_value(data, "saturate_predict", false);
     slot.sparams.top_k = json_value(data, "top_k", default_sparams.top_k);
     slot.sparams.top_p = json_value(data, "top_p", default_sparams.top_p);
     slot.sparams.min_p = json_value(data, "min_p", default_sparams.min_p);
@@ -1216,6 +1223,10 @@ bool server_context::launch_slot_with_task(server_slot& slot, server_task& task)
     {
         // ban string
+        int32_t banbuffer_size = json_value(data, "banbuffer_size", 0);
+        slot.n_buffer = 0; // Ensure buffer calculation starts fresh for this slot
+        slot.rewind_count_max = json_value(data, "rewind_count_max", -1);
+
         const auto& banned_strings = data.find("banned_strings");
         if (banned_strings != data.end() && banned_strings->is_array()) {
             slot.ban_phrases.clear();
@@ -1224,15 +1235,14 @@ bool server_context::launch_slot_with_task(server_slot& slot, server_task& task)
                 std::string s = val.get<std::string>();
                 if (!s.empty()) {
                     s = string_lower(s);
-                    auto ban_tokens = common_tokenize(llama_get_model(ctx), s, false, true);
-                    if (ban_tokens.size() > slot.n_buffer) {
-                        slot.n_buffer = ban_tokens.size();
+                    // Use string length instead of token count
+                    if (s.length() > slot.n_buffer) {
+                        slot.n_buffer = s.length();
                     }
                     slot.ban_phrases.push_back(s);
                 }
             }
         }
-        slot.n_buffer = slot.n_buffer + 3; // extra buffer in case
         std::sort(slot.ban_phrases.begin(), slot.ban_phrases.end(),
             [](const std::string& a, const std::string& b) { return a.length() > b.length(); });
@@ -1245,20 +1255,72 @@ bool server_context::launch_slot_with_task(server_slot& slot, server_task& task)
             for (auto & val : params_base.ban_phrases) {
                 if (!val.empty()) {
                     val = string_lower(val);
-                    auto ban_tokens = common_tokenize(llama_get_model(ctx), val, false, true);
-                    if (ban_tokens.size() > slot.n_buffer) {
-                        slot.n_buffer = ban_tokens.size();
+                    // Use string length instead of token count
+                    if (val.length() > slot.n_buffer) {
+                        slot.n_buffer = val.length();
                     }
                     slot.ban_phrases.push_back(val);
                 }
             }
-            slot.n_buffer = slot.n_buffer + 3; // extra buffer in case
-            params_base.n_buffer = slot.n_buffer;
+            params_base.n_buffer = slot.n_buffer + 1; // buffer is longest string + 1
         } else {
             slot.ban_phrases = params_base.ban_phrases;
             slot.n_buffer = params_base.n_buffer;
         }
     }
+
+    // ban regex
+    slot.ban_regex.clear();
+    const auto& banned_regex = data.find("banned_regex");
+    if (banned_regex != data.end() && banned_regex->is_array()) {
+        for (const auto& val : data["banned_regex"]) {
+            if (val.is_string()) {
+                std::string s = val.get<std::string>();
+                if (!s.empty()) {
+                    try {
+                        std::regex re(s);
+                        slot.ban_regex.push_back(s);
+                        if (s.length() > slot.n_buffer) {
+                            slot.n_buffer = s.length();
+                        }
+                    } catch (const std::regex_error& e) {
+                        send_error(task, "Invalid regex in banned_regex: " + s, ERROR_TYPE_INVALID_REQUEST);
+                        return false;
+                    }
+                }
+            }
+        }
+    }
+
+    // ban regex, case insensitive
+    slot.ban_regex_ci.clear();
+    const auto& banned_regex_ci = data.find("banned_regex_case_insensitive");
+    if (banned_regex_ci != data.end() && banned_regex_ci->is_array()) {
+        for (const auto& val : data["banned_regex_case_insensitive"]) {
+            if (val.is_string()) {
+                std::string s = val.get<std::string>();
+                if (!s.empty()) {
+                    try {
+                        std::regex re(s, std::regex_constants::icase);
+                        slot.ban_regex_ci.push_back(s);
+                        if (s.length() > slot.n_buffer) {
+                            slot.n_buffer = s.length();
+                        }
+                    } catch (const std::regex_error& e) {
+                        send_error(task, "Invalid regex in banned_regex_case_insensitive: " + s, ERROR_TYPE_INVALID_REQUEST);
+                        return false;
+                    }
+                }
+            }
+        }
+    }
+
+    if (banbuffer_size > 0) {
+        slot.n_buffer = banbuffer_size;
+    } else {
+        slot.n_buffer = slot.n_buffer + 1; // buffer is longest string/regex + 1
+    }
+
     slot.logit_bias = slot.sparams.logit_bias; // keep a copy to restore
     slot.ban_phrases_bias = json_value(data, "banned_bias", params_base.ban_phrases_bias);
     slot.banned_n = json_value(data, "banned_n", params_base.banned_n);
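// --- Illustrative aside (not part of the patch) ----------------------------
// A minimal sketch of the buffer-sizing rule established above, assuming the
// combined list of banned strings and regex sources is available in one
// vector. n_buffer counts buffered tokens while pattern lengths are measured
// in characters; since each buffered token is assumed to render to at least
// one character, "longest pattern + 1" is a conservative token-count bound.
// The function name is local to this sketch.
//
//     #include <algorithm>
//     #include <cstdint>
//     #include <string>
//     #include <vector>
//
//     static size_t compute_n_buffer(const std::vector<std::string>& patterns,
//                                    int32_t banbuffer_size) {
//         if (banbuffer_size > 0) {
//             return (size_t) banbuffer_size; // explicit override wins
//         }
//         size_t longest = 0;
//         for (const auto& p : patterns) {
//             longest = std::max(longest, p.length());
//         }
//         return longest + 1;                 // longest string/regex + 1
//     }
// ----------------------------------------------------------------------------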
@@ -3261,12 +3323,24 @@ void server_context::release_slot_after_final_response(server_slot & slot) {
 void server_context::send_token_results(completion_token_outputs& results, server_slot& slot, int32_t n) {
     int count = 0;
+    bool released = false;
+
+    int32_t start_pos = slot.n_past - (int32_t)slot.token_buffer.size() + 1;
+
     for (auto& it : results) {
         bool has_next = process_token(it, slot);
+
+        // Clean up positional bans for the token we just confirmed/sent
+        slot.positional_bans.erase(start_pos + count);
+
         count++;
         if (!has_next) {
+            if (slot.stopped_limit && !slot.stopped_eos && !slot.stopped_word) {
+                continue;
+            }
             send_final_response(slot);
             release_slot_after_final_response(slot);
+            released = true;
             break;
         }
         if (n > 0 && count >= n) {
@@ -3274,6 +3348,11 @@ void server_context::send_token_results(completion_token_outputs& results, serve
             break;
         }
     }
 
+    if (!released && slot.stopped_limit && !slot.stopped_eos && !slot.stopped_word) {
+        send_final_response(slot);
+        release_slot_after_final_response(slot);
+    }
+
     if (count > 0) {
         slot.sampled = results[results.size()-1].tok;
         results.erase(results.begin(), results.begin() + count);
@@ -3281,50 +3360,100 @@ void server_context::send_token_results(completion_token_outputs& results, serve
     }
 }
 
-inline int32_t check_ban_phrase(const server_slot& slot) {
-    bool found = false;
-    size_t n = slot.token_buffer.size();
-    size_t start;
-    int32_t n_rewind = 0;
+inline int32_t check_ban_phrase(server_slot& slot) {
+    if (slot.token_buffer.empty()) return -1;
+
     std::string string_buffer;
-    llama_tokens tokens;
-    for (auto& it : slot.token_buffer) {
-        string_buffer = string_buffer + it.text_to_send;
-        tokens.push_back(it.tok);
+    std::vector<size_t> token_offsets;
+
+    for (const auto& it : slot.token_buffer) {
+        token_offsets.push_back(string_buffer.size());
+        string_buffer += it.text_to_send;
     }
-    string_buffer = string_lower(string_buffer);
-    for (auto it : slot.ban_phrases) {
-        start = string_buffer.find(it);
-        // has been sorted from longest to shortest
+
+    size_t best_start = std::string::npos;
+    bool found = false;
+    std::string string_buffer_lower = string_lower(string_buffer);
+
+    // 1. Check strings
+    for (const auto& phrase : slot.ban_phrases) {
+        size_t start = string_buffer_lower.find(phrase);
         if (start != std::string::npos) {
-            found = true;
-            break;
+            if (start < best_start) {
+                best_start = start;
+                found = true;
+            }
         }
     }
-    if (found) {
-        std::vector<size_t> unused;
-        LLAMA_LOG_DEBUG("Banned string dectected: %s\n", string_buffer.substr(start).c_str());
-        n = find_n_tokens_from_string(slot.ctx, tokens, start, 0, unused);
-        n_rewind = (int32_t) slot.token_buffer.size() - (int32_t) n;
+
+    // 2. Check regex
+    for (const auto& pattern : slot.ban_regex) {
+        try {
+            std::regex re(pattern);
+            std::smatch match;
+            if (std::regex_search(string_buffer, match, re)) {
+                if (match.position() < best_start) {
+                    best_start = match.position();
+                    found = true;
+                }
+            }
+        } catch (...) { continue; }
     }
-    return n_rewind;
+
+    // 3. Check regex, case insensitive
+    for (const auto& pattern : slot.ban_regex_ci) {
+        try {
+            std::regex re(pattern, std::regex_constants::icase);
+            std::smatch match;
+            if (std::regex_search(string_buffer, match, re)) {
+                if (match.position() < best_start) {
+                    best_start = match.position();
+                    found = true;
+                }
+            }
+        } catch (...) { continue; }
+    }
+
+    if (found) {
+        int32_t token_idx = -1;
+        for (size_t i = 0; i < token_offsets.size(); ++i) {
+            size_t len = (i == token_offsets.size() - 1)
+                ? string_buffer.size() - token_offsets[i]
+                : token_offsets[i+1] - token_offsets[i];
+
+            if (best_start >= token_offsets[i] && best_start < token_offsets[i] + len) {
+                token_idx = (int32_t)i;
+                break;
+            }
+        }
+
+        if (token_idx != -1) {
+            int32_t abs_pos = slot.n_past - (int32_t)slot.token_buffer.size() + 1 + token_idx;
+            return abs_pos;
+        }
+    }
+
+    return -1;
 }
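// --- Illustrative aside (not part of the patch) ----------------------------
// A standalone sketch of the offset-to-token mapping check_ban_phrase performs
// above: token_offsets[i] is the byte offset where token i's text begins in
// the concatenated buffer, so a match starting at best_start belongs to token
// i when token_offsets[i] <= best_start < start of token i+1. The function
// name is local to this sketch.
//
//     #include <cstddef>
//     #include <vector>
//
//     static int find_token_for_offset(const std::vector<size_t>& token_offsets,
//                                      size_t buffer_size, size_t best_start) {
//         for (size_t i = 0; i < token_offsets.size(); ++i) {
//             // end of token i's text: the next token's start,
//             // or the end of the buffer for the last token
//             const size_t end = (i + 1 < token_offsets.size())
//                 ? token_offsets[i + 1] : buffer_size;
//             if (best_start >= token_offsets[i] && best_start < end) {
//                 return (int) i;
//             }
//         }
//         return -1; // offset lies past the buffered text
//     }
// ----------------------------------------------------------------------------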
 
-inline void rewind_context(server_slot& slot, int32_t n_rewind) {
+inline void rewind_context(server_slot& slot, int32_t ban_pos) {
     slot.rewind_count++;
-    int32_t n_keep_rewind = (int32_t)slot.token_buffer.size() - n_rewind;
-    std::set<llama_token> tokens;
-    // ban all tokens for better coherence
+
+    int32_t buffer_start_pos = slot.n_past - (int32_t)slot.token_buffer.size() + 1;
+    int32_t n_keep_buffer = ban_pos - buffer_start_pos;
+    if (n_keep_buffer < 0) n_keep_buffer = 0;
+
     if (slot.banned_n != 0) {
         int32_t n = 0;
-        for (auto result = slot.token_buffer.begin() + n_keep_rewind; result != slot.token_buffer.end(); result++)
-        {
-            if (!tokens.contains(result->tok)) {
-                slot.ctx_sampling->params.logit_bias[result->tok] += slot.ban_phrases_bias;
-            }
-            else {
-                tokens.insert(result->tok);
+        for (auto result = slot.token_buffer.begin() + n_keep_buffer; result != slot.token_buffer.end(); result++) {
+            llama_token banned_tok = result->tok;
+
+            if (n == 0) {
+                LLAMA_LOG_INFO("Banned pattern detected at pos %d. Banning token %d ('%s') and rewinding.\n",
+                    ban_pos, banned_tok, result->text_to_send.c_str());
             }
+
+            slot.positional_bans[ban_pos].insert(banned_tok);
             n++;
             if (slot.banned_n > 0 && n == slot.banned_n) {
                 break;
@@ -3332,52 +3461,114 @@ inline void rewind_context(server_slot& slot, int32_t n_rewind) {
             }
         }
     }
-    slot.token_buffer.resize(n_keep_rewind);
-    size_t n_keep = slot.cache_tokens.size() - n_rewind;
-    slot.sampled = slot.cache_tokens[n_keep];
-    slot.cache_tokens.keep_first(n_keep);
-    llama_kv_cache_seq_rm(slot.ctx, slot.id, n_keep, -1);
-
+    int32_t n_rewind_total = (slot.n_past + 1) - ban_pos;
+
+    size_t n_keep_cache = 0;
+    if (ban_pos > 0) {
+        n_keep_cache = (size_t)(ban_pos - 1);
+    }
+
+    if (n_keep_cache > slot.cache_tokens.size()) {
+        n_keep_cache = slot.cache_tokens.size();
+    }
+
+    if (n_keep_cache < slot.cache_tokens.size()) {
+        slot.sampled = slot.cache_tokens[n_keep_cache];
+    } else {
+        slot.sampled = 0;
+    }
+
+    // Truncate cache
+    slot.cache_tokens.keep_first(n_keep_cache);
+    slot.n_past = slot.cache_tokens.n_tokens();
+
+    // Remove from KV cache
+    llama_kv_cache_seq_rm(slot.ctx, slot.id, slot.n_past, -1);
+
+    // Truncate buffer
+    slot.token_buffer.resize(n_keep_buffer);
+
+    // Adjust decoded count
+    if (slot.saturate_predict) {
+        slot.n_decoded -= n_rewind_total;
+        if (slot.n_decoded < 0) slot.n_decoded = 0;
+    }
 }
 
 void server_context::buffer_and_check_string_ban(server_slot & slot, completion_token_output & result) {
     slot.token_buffer.push_back(result);
     bool next_token = has_next_token(result, slot);
-    bool send_result = slot.token_buffer.size() >= slot.n_buffer || !next_token;
+    // If the buffer is full or generation stopped, we might send tokens
+    bool buffer_full = slot.token_buffer.size() >= slot.n_buffer;
+
+    int32_t ban_pos = -1;
     int32_t n_rewind = 0;
     bool sent_results = false;
-    // don't restore if last time was also rewind
-    if (!slot.rewind_status) {
-        slot.ctx_sampling->params.logit_bias = slot.logit_bias; // restore logit bias
+
+    // Always reset logit bias to base before checking bans
+    slot.ctx_sampling->params.logit_bias = slot.logit_bias;
+
+    if (slot.ban_phrases.size() > 0 || slot.ban_regex.size() > 0 || slot.ban_regex_ci.size() > 0) {
+        ban_pos = check_ban_phrase(slot);
+        if (ban_pos >= 0 && slot.sparams.adaptive_target >= 0.0f) {
+            int32_t buffer_start_pos = slot.n_past - (int32_t)slot.token_buffer.size() + 1;
+            int32_t n_keep_buffer = ban_pos - buffer_start_pos;
+            if (n_keep_buffer < 0) n_keep_buffer = 0;
+            n_rewind = (int32_t)slot.token_buffer.size() - n_keep_buffer;
+        }
     }
-    if (slot.ban_phrases.size() > 0) {
-        n_rewind = check_ban_phrase(slot);
+
+    bool allow_rewind = true;
+
+    if (ban_pos >= 0) {
+        if (slot.rewind_count_max == -1) {
+            // Automatic / heuristic logic
+            // Account for strings + regex + regex_ci
+            size_t total_bans = slot.ban_phrases.size() + slot.ban_regex.size() + slot.ban_regex_ci.size();
+
+            // Heuristic: allow if under 20 OR under 2 * total_bans
+            // Conversely: stop if >= 20 AND > 2 * total_bans
+            if (slot.rewind_count >= 20 && slot.rewind_count > 2 * total_bans) {
+                allow_rewind = false;
+            }
+        }
+        else if (slot.rewind_count_max > 0) {
+            // Strict limit logic
+            if (slot.rewind_count >= slot.rewind_count_max) {
+                allow_rewind = false;
+            }
+        }
+        // If slot.rewind_count_max == 0, allow_rewind remains true (infinite)
+    }
-    // if found string in the ban
-    if (n_rewind > 0 && (slot.rewind_count <20 || slot.rewind_count <= 2 * slot.ban_phrases.size())) {
-        rewind_context(slot, n_rewind);
+
+    if (ban_pos >= 0 && allow_rewind) {
+        rewind_context(slot, ban_pos);
         slot.rewind_status = true;
     }
-    else if (send_result) {
+    else if (buffer_full || !next_token) {
        slot.rewind_status = false;
        slot.rewind_count = 0;
+
        if (!next_token) {
-            // send all remaining tokens in the buffer
+            // send all remaining tokens
            send_token_results(slot.token_buffer, slot);
        }
        else {
-            // send 1 token
+            // send 1 token from the front (FIFO)
            send_token_results(slot.token_buffer, slot, 1);
        }
-        sent_results = true;
+        if (slot.sparams.adaptive_target >= 0.0f) {
+            sent_results = true;
+        }
    }
    else {
-        // buffer the result
-        slot.sampled = result.tok; // for common batch add
+        // buffer the result, wait for more tokens to validate the string
+        slot.sampled = result.tok;
+    }
+    if (slot.sparams.adaptive_target >= 0.0f) {
+        slot.ctx_sampling->n_rewind = sent_results ? -1 : n_rewind;
    }
-
-    slot.ctx_sampling->n_rewind = sent_results ? -1 : n_rewind;
 }
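// --- Illustrative aside (not part of the patch) ----------------------------
// The rewind budget above, extracted into a pure function with the same
// semantics: rewind_count_max == -1 selects the heuristic (stop only once the
// slot has rewound at least 20 times AND more than twice the number of
// configured patterns), 0 never stops, and a positive value is a hard cap.
// The function name is local to this sketch.
//
//     #include <cstddef>
//     #include <cstdint>
//
//     static bool allow_rewind(int32_t rewind_count, int32_t rewind_count_max,
//                              size_t total_bans) {
//         if (rewind_count_max == -1) {
//             return !(rewind_count >= 20 && rewind_count > (int32_t)(2 * total_bans));
//         }
//         if (rewind_count_max > 0) {
//             return rewind_count < rewind_count_max; // strict limit
//         }
//         return true;                                // 0 = unlimited
//     }
// ----------------------------------------------------------------------------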
 
 void server_context::process_batch_tokens(int32_t & n_batch) {
@@ -3465,6 +3656,15 @@ void server_context::process_batch_tokens(int32_t & n_batch) {
             continue; // sample using speculative decoding
         }
 
+        // Restore and apply positional bans
+        slot.ctx_sampling->params.logit_bias = slot.logit_bias;
+        auto ban_it = slot.positional_bans.find(slot.n_past);
+        if (ban_it != slot.positional_bans.end()) {
+            for (llama_token tok : ban_it->second) {
+                slot.ctx_sampling->params.logit_bias[tok] += slot.ban_phrases_bias;
+            }
+        }
+
         completion_token_output result;
         const int tok_idx = slot.i_batch - i;
         const llama_token id = common_sampler_sample(slot.ctx_sampling, ctx, tok_idx);
@@ -3607,4 +3807,4 @@ json server_context::model_meta() const {
         {"n_params", llama_model_n_params(model)},
         {"size", llama_model_size(model)},
     };
-}
+}
\ No newline at end of file

diff --git a/examples/server/server-context.h b/examples/server/server-context.h
index a0008aae..38675d81 100644
--- a/examples/server/server-context.h
+++ b/examples/server/server-context.h
@@ -81,7 +81,7 @@ struct server_slot {
     bool stopped_eos = false;
     bool stopped_word = false;
     bool stopped_limit = false;
-
+    bool saturate_predict = false;
     bool oaicompat = false;
     std::string oaicompat_model;
 
@@ -91,12 +91,16 @@ struct server_slot {
     // For context rewind / token buffer
     size_t n_buffer = 0;
     int32_t rewind_count = 0;
+    int32_t rewind_count_max = -1;
     bool rewind_status = false;
     std::unordered_map<llama_token, float> logit_bias;
-    std::vector<std::string>ban_phrases;
+    std::vector<std::string> ban_phrases;
+    std::vector<std::string> ban_regex;
+    std::vector<std::string> ban_regex_ci;
     completion_token_outputs token_buffer;
     float ban_phrases_bias = 0;
     int32_t banned_n = 1;
+    std::map<int32_t, std::set<llama_token>> positional_bans;
     server_prompt server_cached_prompt;
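As a closing note, a standalone sketch of the intended positional_bans
lifecycle, assuming llama_token is int32_t as in llama.cpp: a ban recorded for
an absolute position only biases sampling when generation is back at exactly
that position, and it is erased once a token at that position has been
accepted and streamed out. The type and field names mirror the header above;
the demo struct itself is hypothetical.

    #include <cstdint>
    #include <map>
    #include <set>
    #include <unordered_map>

    using llama_token = int32_t;

    struct positional_ban_demo {
        std::map<int32_t, std::set<llama_token>> positional_bans;
        float ban_phrases_bias = -100.0f;

        // called before sampling at absolute position n_past, after logit_bias
        // has been reset to the base copy (cf. process_batch_tokens)
        void apply(int32_t n_past, std::unordered_map<llama_token, float>& logit_bias) {
            auto it = positional_bans.find(n_past);
            if (it == positional_bans.end()) return;
            for (llama_token tok : it->second) {
                logit_bias[tok] += ban_phrases_bias; // discourage the banned continuation
            }
        }

        // called when the token at `pos` is confirmed and sent
        // (cf. send_token_results)
        void on_token_sent(int32_t pos) {
            positional_bans.erase(pos);
        }
    };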