server: add string ban

Author: firecoperana
Date: 2026-01-19 21:24:47 -06:00
Parent: 28f8320f3a
Commit: c96ad27cd0

5 changed files with 333 additions and 53 deletions
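The string ban is driven by three new request fields read in launch_slot_with_task: "banned_strings" (array of phrases, lowercased and matched against buffered output), "banned_bias" (logit bias added to rewound tokens) and "banned_n" (how many rewound tokens to bias). Defaults come from params_base. A minimal client-side sketch, assuming the nlohmann::json type the server already uses; the prompt and values are illustrative only:

    #include <nlohmann/json.hpp>
    using json = nlohmann::json;

    // illustrative request body; semantics follow the parsing code in the diff below
    json req;
    req["prompt"]         = "Once upon a time";
    req["n_predict"]      = 128;
    // phrases are lowercased server-side and searched in the buffered generation
    req["banned_strings"] = json::array({ "as an ai language model", "i cannot" });
    // bias added to each rewound token when a phrase is detected (negative discourages)
    req["banned_bias"]    = -5.0f;
    // 0 = rewind without biasing, < 0 = bias every rewound token, > 0 = bias at most this many
    req["banned_n"]       = 3;

Per-request banned_strings take precedence over the server-level ban_phrases defaults; when neither is set, n_buffer stays 0 and tokens are streamed without buffering.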

@@ -318,16 +318,20 @@ void server_slot::reset() {
n_past = 0;
n_past_prompt = 0;
n_sent_text = 0;
drafted.clear();
i_batch_dft.clear();
n_sent_token_probs = 0;
infill = false;
ga_i = 0;
n_past_se = 0;
chat_format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
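// string-ban state: saved logit bias, buffered tokens and rewind bookkeeping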
logit_bias.clear();
token_buffer.clear();
rewind_count = 0;
n_buffer = 0;
rewind_status = false;
generated_token_probs.clear();
@@ -782,7 +786,7 @@ bool server_context::launch_slot_with_task(server_slot& slot, server_task& task)
// Sampling parameter defaults are loaded from the global server context (but individual requests can still override them)
llama_sampling_params default_sparams = params_base.sparams;
auto& data = task.data;
const llama_vocab* vocab = llama_model_get_vocab(model);
if (data.count("__oaicompat") != 0) {
slot.oaicompat = true;
slot.oaicompat_model = json_value(data, "model", std::string(DEFAULT_OAICOMPAT_MODEL));
@@ -1046,8 +1050,10 @@ bool server_context::launch_slot_with_task(server_slot& slot, server_task& task)
{ // apply logit bias
const auto& logit_bias = data.find("logit_bias");
if (logit_bias != data.end() && (logit_bias->is_object() || logit_bias->is_array())) {
slot.sparams.logit_bias.clear(); // only clear if user sets it
}
if (logit_bias != data.end() && logit_bias->is_array()) {
const int n_vocab = llama_n_vocab(model);
for (const auto& el : *logit_bias) {
// TODO: we may want to throw errors here, in case "el" is incorrect
@@ -1078,12 +1084,86 @@ bool server_context::launch_slot_with_task(server_slot& slot, server_task& task)
}
}
}
else if (logit_bias != data.end() && logit_bias->is_object()) {
const int n_vocab = llama_vocab_n_tokens(vocab);
for (const auto& el : logit_bias->items()) {
float bias;
const auto& key = el.key();
const auto& value = el.value();
if (value.is_number()) {
bias = value.get<float>();
}
else if (value.is_boolean() && !value.get<bool>()) {
bias = -INFINITY;
}
else {
continue;
}
char* end;
llama_token tok = strtol(key.c_str(), &end, 10);
if (*end == 0) {
if (tok >= 0 && tok < n_vocab) {
slot.sparams.logit_bias[tok] = bias;
}
}
else {
auto toks = common_tokenize(model, key, false);
for (auto tok : toks) {
slot.sparams.logit_bias[tok] = bias;
}
}
}
}
if (json_value(data, "ignore_eos", false) && has_eos_token) {
slot.sparams.logit_bias[llama_token_eos(model)] = -INFINITY;
}
}
{
// banned strings: collect the phrases to ban and size the detection buffer
const auto& banned_strings = data.find("banned_strings");
if (banned_strings != data.end() && banned_strings->is_array()) {
slot.ban_phrases.clear();
for (const auto& val : data["banned_strings"]) {
if (val.is_string()) {
std::string s = val.get<std::string>();
if (!s.empty()) {
s = string_lower(s);
auto ban_tokens = common_tokenize(llama_get_model(ctx), s, false, true);
if (ban_tokens.size() > slot.n_buffer) {
slot.n_buffer = ban_tokens.size();
}
slot.ban_phrases.push_back(s);
}
}
}
slot.n_buffer = slot.n_buffer + 3; // a few tokens of extra headroom, just in case
std::sort(slot.ban_phrases.begin(), slot.ban_phrases.end(), [](const std::string& a, const std::string& b) {
return a.length() > b.length();
});
}
else if (params_base.ban_phrases.size() > 0 && params_base.n_buffer == 0) {
slot.ban_phrases.clear();
for (const auto & val : params_base.ban_phrases) {
if (!val.empty()) {
std::string s = string_lower(val);
auto ban_tokens = common_tokenize(llama_get_model(ctx), s, false, true);
if (ban_tokens.size() > slot.n_buffer) {
slot.n_buffer = ban_tokens.size();
}
slot.ban_phrases.push_back(s);
}
}
params_base.n_buffer = slot.n_buffer + 3;
slot.n_buffer = slot.n_buffer + 3; // a few tokens of extra headroom, just in case
}
slot.logit_bias = slot.sparams.logit_bias; // keep a copy to restore
slot.ban_phrases_bias = json_value(data, "banned_bias", params_base.ban_phrases_bias);
slot.banned_n = json_value(data, "banned_n", params_base.banned_n);
}
{
const auto& stop = data.find("stop");
if (stop != data.end() && stop->is_array()) {
@@ -1196,6 +1276,28 @@ bool server_context::system_prompt_set(const std::string& sys_prompt) {
return true;
}
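// returns whether generation can continue after `result`, without mutating the slot:
// running out of budget, an end-of-generation token, or hitting the training context
// limit (with the default n_predict) all stop it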
// keep in sync with process_token(completion_token_output& result, server_slot& slot)
bool server_context::has_next_token(const completion_token_output& result, server_slot& slot) {
bool next = true;
//std::string generate_text = slot.generated_text + result.text_to_send;
//bool incomplete = validate_utf8(generate_text) < generate_text.size();
//if (incomplete) {
// next = true;
//}
if (slot.n_decoded > 0 && !slot.has_budget(params_base)) {
next = false;
}
if (llama_token_is_eog(model, result.tok)) {
next = false;
}
auto n_ctx_train = llama_n_ctx_train(model);
if (slot.params.n_predict < 1 && slot.n_predict < 1 && slot.ga_n == 1
&& slot.n_prompt_tokens + slot.n_decoded >= n_ctx_train) {
next = false;
}
return next;
}
bool server_context::process_token(completion_token_output& result, server_slot& slot) {
// remember which tokens were sampled - used for repetition penalties during sampling
const std::string token_str = result.text_to_send;
@@ -2523,7 +2625,6 @@ void server_context::batch_pending_prompt(const int32_t n_ubatch, const int32_t
if (slot.n_past_prompt == slot.n_prompt_tokens) {
slot.state = SLOT_STATE_PROCESSING;
slot.command = SLOT_COMMAND_NONE;
GGML_ASSERT(batch.n_tokens > 0);
GGML_ASSERT((size_t)slot.n_prompt_tokens == slot.prompt_tokens.size());
common_sampler_reset(llama_get_model_vocab(model), slot.ctx_sampling);
@@ -2651,6 +2752,124 @@ bool server_context::accept_special_token(const server_slot& slot, const llama_
return params_base.special || slot.sparams.preserved_tokens.find(token) != slot.sparams.preserved_tokens.end();
};
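// flush up to `n` buffered token outputs through process_token (all of them when n <= 0);
// stops early and finalizes the slot if any of them ends generation, then drops the sent
// entries from the buffer and keeps the newest buffered token as slot.sampled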
void server_context::send_token_results(completion_token_outputs& results, server_slot& slot, int32_t n) {
int count = 0;
for (auto& it : results) {
bool has_next = process_token(it, slot);
count++;
if (!has_next) {
slot.release();
slot.print_timings();
send_final_response(slot);
metrics.on_prediction(slot);
break;
}
if (n > 0 && count >= n) {
break;
}
}
if (count > 0) {
slot.sampled = results[results.size()-1].tok;
results.erase(results.begin(), results.begin() + count);
}
}
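// concatenate the buffered token pieces, lowercase them and search for the banned phrases;
// returns the number of trailing buffered tokens to rewind, or 0 if no phrase was found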
inline int32_t check_ban_phrase(const server_slot& slot) {
bool found = false;
size_t n = slot.token_buffer.size();
size_t start;
int32_t n_rewind = 0;
std::string string_buffer;
llama_tokens tokens;
for (auto& it : slot.token_buffer) {
string_buffer = string_buffer + it.text_to_send;
tokens.push_back(it.tok);
}
string_buffer = string_lower(string_buffer);
for (auto it : slot.ban_phrases) {
start = string_buffer.find(it);
// ban_phrases is sorted from longest to shortest, so longer phrases are matched first
if (start != std::string::npos) {
found = true;
break;
}
}
if (found) {
std::vector<size_t> unused;
LLAMA_LOG_DEBUG("Banned string detected: %s\n", string_buffer.substr(start).c_str());
n = find_n_tokens_from_string(slot.ctx, tokens, start, 0, unused);
n_rewind = (int32_t) slot.token_buffer.size() - (int32_t) n;
}
return n_rewind;
}
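// drop the last n_rewind tokens from the token buffer and the cached context and,
// unless banned_n == 0, add ban_phrases_bias to the rewound tokens (all of them when
// banned_n < 0, at most banned_n otherwise) so the resampled continuation avoids them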
inline void rewind_context(server_slot& slot, int32_t n_rewind) {
slot.rewind_count++;
int32_t n_keep_rewind = (int32_t)slot.token_buffer.size() - n_rewind;
std::set<llama_token> tokens;
// bias all of the rewound tokens, not just the first, for better coherence
if (slot.banned_n != 0) {
int32_t n = 0;
for (auto result = slot.token_buffer.begin() + n_keep_rewind; result != slot.token_buffer.end(); result++)
{
if (!tokens.contains(result->tok)) {
slot.ctx_sampling->params.logit_bias[result->tok] += slot.ban_phrases_bias;
tokens.insert(result->tok); // bias each distinct token only once
}
n++;
if (slot.banned_n > 0 && n == slot.banned_n) {
break;
}
}
}
slot.token_buffer.resize(n_keep_rewind);
size_t n_keep = slot.cache_tokens.size() - n_rewind;
slot.sampled = slot.cache_tokens[n_keep];
slot.cache_tokens.keep_first(n_keep);
}
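// buffer each sampled token and scan the window for banned strings: on a match, rewind
// (bounded by a budget of 2 * ban_phrases.size() rewinds); otherwise, once the window is
// full, send out the oldest buffered token, or flush everything when generation ends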
void server_context::buffer_and_check_string_ban(server_slot & slot, completion_token_output & result) {
slot.token_buffer.push_back(result);
bool next_token = has_next_token(result, slot);
bool send_result = slot.token_buffer.size() >= slot.n_buffer || !next_token;
int32_t n_rewind = 0;
// don't restore the logit bias if the previous step was also a rewind
if (!slot.rewind_status) {
slot.ctx_sampling->params.logit_bias = slot.logit_bias; // restore logit bias
}
if (slot.ban_phrases.size() > 0) {
n_rewind = check_ban_phrase(slot);
}
// a banned string was found in the buffered text
if (n_rewind > 0 && slot.rewind_count <= 2 * slot.ban_phrases.size()) {
rewind_context(slot, n_rewind);
slot.rewind_status = true;
}
else if (send_result) {
slot.rewind_status = false;
slot.rewind_count = 0;
if (!next_token) {
// send all remaining tokens in the buffer
send_token_results(slot.token_buffer, slot);
}
else {
// send 1 token
send_token_results(slot.token_buffer, slot, 1);
}
}
else {
// buffer the result
slot.sampled = result.tok; // for common batch add
}
}
void server_context::process_batch_tokens(int32_t & n_batch) {
for (int32_t i = 0; i < batch.n_tokens; i += n_batch) {
const int32_t n_tokens = std::min(n_batch, batch.n_tokens - i);
@@ -2668,7 +2887,6 @@ void server_context::process_batch_tokens(int32_t & n_batch) {
};
const int ret = llama_decode(ctx, batch_view);
if (ret != 0) {
if (n_batch == 1 || ret < 0) {
int user_cancel = -3;
@@ -2721,17 +2939,17 @@ void server_context::process_batch_tokens(int32_t & n_batch) {
continue; // continue loop of slots
}
if (slot.i_batch_dft.size() > 0) {
continue; // sample using speculative decoding
}
completion_token_output result;
const int tok_idx = slot.i_batch - i;
const llama_token id = common_sampler_sample(slot.ctx_sampling, ctx, NULL, tok_idx);
common_sampler_accept(slot.ctx_sampling, ctx, id, true);
slot.n_decoded += 1;
const int64_t t_current = ggml_time_us();
if (slot.n_decoded == 1) {
@@ -2745,16 +2963,17 @@ void server_context::process_batch_tokens(int32_t & n_batch) {
result.tok = id;
result.prob = 1.0f; // TODO: set it here instead of doing inside populate_token_probs
result.text_to_send = common_token_to_piece(ctx, result.tok, accept_special_token(slot, result.tok));
if (slot.sparams.n_probs > 0) {
populate_token_probs(slot, result, slot.params.post_sampling_probs, params_base.special, tok_idx);
}
if (slot.n_buffer == 0) {
slot.token_buffer = { result };
send_token_results(slot.token_buffer, slot);
} else {
// buffer the result and check the string ban;
// on a ban we need to go back, apply the logit bias and regenerate
buffer_and_check_string_ban(slot, result);
}
slot.i_batch = -1;
@@ -2794,7 +3013,7 @@ void server_context::update_slots() {
common_batch_clear(batch);
// first, add sampled tokens from any ongoing sequences
add_sampled_tokens(); // Prepare batch for inference
// process in chunks of params.n_batch
int32_t n_batch = llama_n_batch(ctx);
@@ -2806,7 +3025,7 @@ void server_context::update_slots() {
int32_t batch_type = batch.n_tokens > 0 ? 0 : -1;
// next, batch any pending prompts without exceeding n_batch
batch_pending_prompt(n_ubatch, n_batch, batch_type); // Prepare batch for prompt processing
if (batch.n_tokens == 0) {
LOG_VERBOSE("no tokens to decode", {});
@@ -2821,7 +3040,7 @@ void server_context::update_slots() {
llama_set_embeddings(ctx, batch_type == 1);
// process the created batch of tokens
process_batch_tokens(n_batch); // Decode the batch
LOG_VERBOSE("run slots completed", {});
}