mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-05-01 03:41:53 +00:00
Make string ban more robust and add regex ban (#1243)
* Test new ctx_sampling->n_rewind system * CRLF quickfix * Adaptive p check * merge banned_n * Fix attempt 1 * Fix attempt 2
This commit is contained in:
@@ -11,6 +11,7 @@
|
|||||||
#include "mtmd.h"
|
#include "mtmd.h"
|
||||||
#include "mtmd-helper.h"
|
#include "mtmd-helper.h"
|
||||||
|
|
||||||
|
#include <regex>
|
||||||
|
|
||||||
server_context::~server_context() {
|
server_context::~server_context() {
|
||||||
if (ctx) {
|
if (ctx) {
|
||||||
@@ -375,6 +376,11 @@ void server_slot::reset() {
|
|||||||
generated_token_probs.clear();
|
generated_token_probs.clear();
|
||||||
checkpoint_pos = 0;
|
checkpoint_pos = 0;
|
||||||
|
|
||||||
|
positional_bans.clear();
|
||||||
|
ban_phrases.clear();
|
||||||
|
ban_regex.clear();
|
||||||
|
ban_regex_ci.clear();
|
||||||
|
|
||||||
// Reset speculative decoding stats
|
// Reset speculative decoding stats
|
||||||
n_draft_total = 0;
|
n_draft_total = 0;
|
||||||
n_draft_accepted = 0;
|
n_draft_accepted = 0;
|
||||||
@@ -864,6 +870,7 @@ bool server_context::launch_slot_with_task(server_slot& slot, server_task& task)
|
|||||||
slot.params.include_usage = json_value(stream_opt, "include_usage", false);
|
slot.params.include_usage = json_value(stream_opt, "include_usage", false);
|
||||||
slot.params.cache_prompt = json_value(data, "cache_prompt", true);
|
slot.params.cache_prompt = json_value(data, "cache_prompt", true);
|
||||||
slot.params.n_predict = json_value(data, "n_predict", json_value(data, "max_tokens", defaults.n_predict));
|
slot.params.n_predict = json_value(data, "n_predict", json_value(data, "max_tokens", defaults.n_predict));
|
||||||
|
slot.saturate_predict = json_value(data, "saturate_predict", false);
|
||||||
slot.sparams.top_k = json_value(data, "top_k", default_sparams.top_k);
|
slot.sparams.top_k = json_value(data, "top_k", default_sparams.top_k);
|
||||||
slot.sparams.top_p = json_value(data, "top_p", default_sparams.top_p);
|
slot.sparams.top_p = json_value(data, "top_p", default_sparams.top_p);
|
||||||
slot.sparams.min_p = json_value(data, "min_p", default_sparams.min_p);
|
slot.sparams.min_p = json_value(data, "min_p", default_sparams.min_p);
|
||||||
@@ -1216,6 +1223,10 @@ bool server_context::launch_slot_with_task(server_slot& slot, server_task& task)
|
|||||||
|
|
||||||
{
|
{
|
||||||
// ban string
|
// ban string
|
||||||
|
int32_t banbuffer_size = json_value(data, "banbuffer_size", 0);
|
||||||
|
slot.n_buffer = 0; // Ensure buffer calculation starts fresh for this slot
|
||||||
|
slot.rewind_count_max = json_value(data, "rewind_count_max", -1);
|
||||||
|
|
||||||
const auto& banned_strings = data.find("banned_strings");
|
const auto& banned_strings = data.find("banned_strings");
|
||||||
if (banned_strings != data.end() && banned_strings->is_array()) {
|
if (banned_strings != data.end() && banned_strings->is_array()) {
|
||||||
slot.ban_phrases.clear();
|
slot.ban_phrases.clear();
|
||||||
@@ -1224,15 +1235,14 @@ bool server_context::launch_slot_with_task(server_slot& slot, server_task& task)
|
|||||||
std::string s = val.get<std::string>();
|
std::string s = val.get<std::string>();
|
||||||
if (!s.empty()) {
|
if (!s.empty()) {
|
||||||
s = string_lower(s);
|
s = string_lower(s);
|
||||||
auto ban_tokens = common_tokenize(llama_get_model(ctx), s, false, true);
|
// Use string length instead of token count
|
||||||
if (ban_tokens.size() > slot.n_buffer) {
|
if (s.length() > slot.n_buffer) {
|
||||||
slot.n_buffer = ban_tokens.size();
|
slot.n_buffer = s.length();
|
||||||
}
|
}
|
||||||
slot.ban_phrases.push_back(s);
|
slot.ban_phrases.push_back(s);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
slot.n_buffer = slot.n_buffer + 3; // extra buffer in case
|
|
||||||
std::sort(slot.ban_phrases.begin(), slot.ban_phrases.end(), [](const std::string& a, const std::string& b) {
|
std::sort(slot.ban_phrases.begin(), slot.ban_phrases.end(), [](const std::string& a, const std::string& b) {
|
||||||
return a.length() > b.length();
|
return a.length() > b.length();
|
||||||
});
|
});
|
||||||
@@ -1245,20 +1255,72 @@ bool server_context::launch_slot_with_task(server_slot& slot, server_task& task)
|
|||||||
for (auto & val : params_base.ban_phrases) {
|
for (auto & val : params_base.ban_phrases) {
|
||||||
if (!val.empty()) {
|
if (!val.empty()) {
|
||||||
val = string_lower(val);
|
val = string_lower(val);
|
||||||
auto ban_tokens = common_tokenize(llama_get_model(ctx), val, false, true);
|
// Use string length instead of token count
|
||||||
if (ban_tokens.size() > slot.n_buffer) {
|
if (val.length() > slot.n_buffer) {
|
||||||
slot.n_buffer = ban_tokens.size();
|
slot.n_buffer = val.length();
|
||||||
}
|
}
|
||||||
slot.ban_phrases.push_back(val);
|
slot.ban_phrases.push_back(val);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
slot.n_buffer = slot.n_buffer + 3; // extra buffer in case
|
params_base.n_buffer = slot.n_buffer + 1; // buffer is longest string + 1
|
||||||
params_base.n_buffer = slot.n_buffer;
|
|
||||||
} else {
|
} else {
|
||||||
slot.ban_phrases = params_base.ban_phrases;
|
slot.ban_phrases = params_base.ban_phrases;
|
||||||
slot.n_buffer = params_base.n_buffer;
|
slot.n_buffer = params_base.n_buffer;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ban regex
|
||||||
|
slot.ban_regex.clear();
|
||||||
|
const auto& banned_regex = data.find("banned_regex");
|
||||||
|
if (banned_regex != data.end() && banned_regex->is_array()) {
|
||||||
|
for (const auto& val : data["banned_regex"]) {
|
||||||
|
if (val.is_string()) {
|
||||||
|
std::string s = val.get<std::string>();
|
||||||
|
if (!s.empty()) {
|
||||||
|
try {
|
||||||
|
std::regex re(s);
|
||||||
|
slot.ban_regex.push_back(s);
|
||||||
|
if (s.length() > slot.n_buffer) {
|
||||||
|
slot.n_buffer = s.length();
|
||||||
|
}
|
||||||
|
} catch (const std::regex_error& e) {
|
||||||
|
send_error(task, "Invalid regex in banned_regex: " + s, ERROR_TYPE_INVALID_REQUEST);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ban regex case insensitive
|
||||||
|
slot.ban_regex_ci.clear();
|
||||||
|
const auto& banned_regex_ci = data.find("banned_regex_case_insensitive");
|
||||||
|
if (banned_regex_ci != data.end() && banned_regex_ci->is_array()) {
|
||||||
|
for (const auto& val : data["banned_regex_case_insensitive"]) {
|
||||||
|
if (val.is_string()) {
|
||||||
|
std::string s = val.get<std::string>();
|
||||||
|
if (!s.empty()) {
|
||||||
|
try {
|
||||||
|
std::regex re(s, std::regex_constants::icase);
|
||||||
|
slot.ban_regex_ci.push_back(s);
|
||||||
|
if (s.length() > slot.n_buffer) {
|
||||||
|
slot.n_buffer = s.length();
|
||||||
|
}
|
||||||
|
} catch (const std::regex_error& e) {
|
||||||
|
send_error(task, "Invalid regex in banned_regex_case_insensitive: " + s, ERROR_TYPE_INVALID_REQUEST);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (banbuffer_size > 0) {
|
||||||
|
slot.n_buffer = banbuffer_size;
|
||||||
|
} else {
|
||||||
|
slot.n_buffer = slot.n_buffer + 1; // buffer is longest string/regex + 1
|
||||||
|
}
|
||||||
|
|
||||||
slot.logit_bias = slot.sparams.logit_bias; // keep a copy to restore
|
slot.logit_bias = slot.sparams.logit_bias; // keep a copy to restore
|
||||||
slot.ban_phrases_bias = json_value(data, "banned_bias", params_base.ban_phrases_bias);
|
slot.ban_phrases_bias = json_value(data, "banned_bias", params_base.ban_phrases_bias);
|
||||||
slot.banned_n = json_value(data, "banned_n", params_base.banned_n);
|
slot.banned_n = json_value(data, "banned_n", params_base.banned_n);
|
||||||
@@ -3261,12 +3323,24 @@ void server_context::release_slot_after_final_response(server_slot & slot) {
|
|||||||
|
|
||||||
void server_context::send_token_results(completion_token_outputs& results, server_slot& slot, int32_t n) {
|
void server_context::send_token_results(completion_token_outputs& results, server_slot& slot, int32_t n) {
|
||||||
int count = 0;
|
int count = 0;
|
||||||
|
bool released = false;
|
||||||
|
|
||||||
|
int32_t start_pos = slot.n_past - (int32_t)slot.token_buffer.size() + 1;
|
||||||
|
|
||||||
for (auto& it : results) {
|
for (auto& it : results) {
|
||||||
bool has_next = process_token(it, slot);
|
bool has_next = process_token(it, slot);
|
||||||
|
|
||||||
|
// Clean up positional bans for the token we just confirmed/sent
|
||||||
|
slot.positional_bans.erase(start_pos + count);
|
||||||
|
|
||||||
count++;
|
count++;
|
||||||
if (!has_next) {
|
if (!has_next) {
|
||||||
|
if (slot.stopped_limit && !slot.stopped_eos && !slot.stopped_word) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
send_final_response(slot);
|
send_final_response(slot);
|
||||||
release_slot_after_final_response(slot);
|
release_slot_after_final_response(slot);
|
||||||
|
released = true;
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
if (n > 0 && count >= n) {
|
if (n > 0 && count >= n) {
|
||||||
@@ -3274,6 +3348,11 @@ void server_context::send_token_results(completion_token_outputs& results, serve
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (!released && slot.stopped_limit && !slot.stopped_eos && !slot.stopped_word) {
|
||||||
|
send_final_response(slot);
|
||||||
|
release_slot_after_final_response(slot);
|
||||||
|
}
|
||||||
|
|
||||||
if (count > 0) {
|
if (count > 0) {
|
||||||
slot.sampled = results[results.size()-1].tok;
|
slot.sampled = results[results.size()-1].tok;
|
||||||
results.erase(results.begin(), results.begin() + count);
|
results.erase(results.begin(), results.begin() + count);
|
||||||
@@ -3281,50 +3360,100 @@ void server_context::send_token_results(completion_token_outputs& results, serve
|
|||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
inline int32_t check_ban_phrase(const server_slot& slot) {
|
inline int32_t check_ban_phrase(server_slot& slot) {
|
||||||
bool found = false;
|
if (slot.token_buffer.empty()) return -1;
|
||||||
size_t n = slot.token_buffer.size();
|
|
||||||
size_t start;
|
|
||||||
int32_t n_rewind = 0;
|
|
||||||
std::string string_buffer;
|
std::string string_buffer;
|
||||||
llama_tokens tokens;
|
std::vector<size_t> token_offsets;
|
||||||
for (auto& it : slot.token_buffer) {
|
|
||||||
string_buffer = string_buffer + it.text_to_send;
|
for (const auto& it : slot.token_buffer) {
|
||||||
tokens.push_back(it.tok);
|
token_offsets.push_back(string_buffer.size());
|
||||||
|
string_buffer += it.text_to_send;
|
||||||
}
|
}
|
||||||
string_buffer = string_lower(string_buffer);
|
|
||||||
for (auto it : slot.ban_phrases) {
|
size_t best_start = std::string::npos;
|
||||||
start = string_buffer.find(it);
|
bool found = false;
|
||||||
// has been sorted from longest to shortest
|
std::string string_buffer_lower = string_lower(string_buffer);
|
||||||
|
|
||||||
|
// 1. Check strings
|
||||||
|
for (const auto& phrase : slot.ban_phrases) {
|
||||||
|
size_t start = string_buffer_lower.find(phrase);
|
||||||
if (start != std::string::npos) {
|
if (start != std::string::npos) {
|
||||||
found = true;
|
if (start < best_start) {
|
||||||
break;
|
best_start = start;
|
||||||
|
found = true;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (found) {
|
|
||||||
std::vector<size_t> unused;
|
// 2. Check regex
|
||||||
LLAMA_LOG_DEBUG("Banned string dectected: %s\n", string_buffer.substr(start).c_str());
|
for (const auto& pattern : slot.ban_regex) {
|
||||||
n = find_n_tokens_from_string(slot.ctx, tokens, start, 0, unused);
|
try {
|
||||||
n_rewind = (int32_t) slot.token_buffer.size() - (int32_t) n;
|
std::regex re(pattern);
|
||||||
|
std::smatch match;
|
||||||
|
if (std::regex_search(string_buffer, match, re)) {
|
||||||
|
if (match.position() < best_start) {
|
||||||
|
best_start = match.position();
|
||||||
|
found = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} catch (...) { continue; }
|
||||||
}
|
}
|
||||||
return n_rewind;
|
|
||||||
|
// 3. Check regex case insensitive
|
||||||
|
for (const auto& pattern : slot.ban_regex_ci) {
|
||||||
|
try {
|
||||||
|
std::regex re(pattern, std::regex_constants::icase);
|
||||||
|
std::smatch match;
|
||||||
|
if (std::regex_search(string_buffer, match, re)) {
|
||||||
|
if (match.position() < best_start) {
|
||||||
|
best_start = match.position();
|
||||||
|
found = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} catch (...) { continue; }
|
||||||
|
}
|
||||||
|
|
||||||
|
if (found) {
|
||||||
|
int32_t token_idx = -1;
|
||||||
|
for (size_t i = 0; i < token_offsets.size(); ++i) {
|
||||||
|
size_t len = (i == token_offsets.size() - 1)
|
||||||
|
? string_buffer.size() - token_offsets[i]
|
||||||
|
: token_offsets[i+1] - token_offsets[i];
|
||||||
|
|
||||||
|
if (best_start >= token_offsets[i] && best_start < token_offsets[i] + len) {
|
||||||
|
token_idx = (int32_t)i;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (token_idx != -1) {
|
||||||
|
int32_t abs_pos = slot.n_past - (int32_t)slot.token_buffer.size() + 1 + token_idx;
|
||||||
|
return abs_pos;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return -1;
|
||||||
}
|
}
|
||||||
|
|
||||||
inline void rewind_context(server_slot& slot, int32_t n_rewind) {
|
inline void rewind_context(server_slot& slot, int32_t ban_pos) {
|
||||||
slot.rewind_count++;
|
slot.rewind_count++;
|
||||||
int32_t n_keep_rewind = (int32_t)slot.token_buffer.size() - n_rewind;
|
|
||||||
std::set<llama_token> tokens;
|
int32_t buffer_start_pos = slot.n_past - (int32_t)slot.token_buffer.size() + 1;
|
||||||
// ban all tokens for better coherence
|
int32_t n_keep_buffer = ban_pos - buffer_start_pos;
|
||||||
|
if (n_keep_buffer < 0) n_keep_buffer = 0;
|
||||||
|
|
||||||
if (slot.banned_n != 0) {
|
if (slot.banned_n != 0) {
|
||||||
int32_t n = 0;
|
int32_t n = 0;
|
||||||
for (auto result = slot.token_buffer.begin() + n_keep_rewind; result != slot.token_buffer.end(); result++)
|
for (auto result = slot.token_buffer.begin() + n_keep_buffer; result != slot.token_buffer.end(); result++) {
|
||||||
{
|
llama_token banned_tok = result->tok;
|
||||||
if (!tokens.contains(result->tok)) {
|
|
||||||
slot.ctx_sampling->params.logit_bias[result->tok] += slot.ban_phrases_bias;
|
if (n == 0) {
|
||||||
}
|
LLAMA_LOG_INFO("Banned pattern detected at pos %d. Banning token %d ('%s') and rewinding.\n",
|
||||||
else {
|
ban_pos, banned_tok, result->text_to_send.c_str());
|
||||||
tokens.insert(result->tok);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
slot.positional_bans[ban_pos].insert(banned_tok);
|
||||||
n++;
|
n++;
|
||||||
if (slot.banned_n > 0 && n == slot.banned_n) {
|
if (slot.banned_n > 0 && n == slot.banned_n) {
|
||||||
break;
|
break;
|
||||||
@@ -3332,52 +3461,114 @@ inline void rewind_context(server_slot& slot, int32_t n_rewind) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
slot.token_buffer.resize(n_keep_rewind);
|
int32_t n_rewind_total = (slot.n_past + 1) - ban_pos;
|
||||||
size_t n_keep = slot.cache_tokens.size() - n_rewind;
|
|
||||||
slot.sampled = slot.cache_tokens[n_keep];
|
size_t n_keep_cache = 0;
|
||||||
slot.cache_tokens.keep_first(n_keep);
|
if (ban_pos > 0) {
|
||||||
llama_kv_cache_seq_rm(slot.ctx, slot.id, n_keep, -1);
|
n_keep_cache = (size_t)(ban_pos - 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (n_keep_cache > slot.cache_tokens.size()) {
|
||||||
|
n_keep_cache = slot.cache_tokens.size();
|
||||||
|
}
|
||||||
|
|
||||||
|
if (n_keep_cache < slot.cache_tokens.size()) {
|
||||||
|
slot.sampled = slot.cache_tokens[n_keep_cache];
|
||||||
|
} else {
|
||||||
|
slot.sampled = 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Truncate cache
|
||||||
|
slot.cache_tokens.keep_first(n_keep_cache);
|
||||||
|
slot.n_past = slot.cache_tokens.n_tokens();
|
||||||
|
|
||||||
|
// Remove from KV cache
|
||||||
|
llama_kv_cache_seq_rm(slot.ctx, slot.id, slot.n_past, -1);
|
||||||
|
|
||||||
|
// Truncate buffer
|
||||||
|
slot.token_buffer.resize(n_keep_buffer);
|
||||||
|
|
||||||
|
// Adjust decoded count
|
||||||
|
if (slot.saturate_predict) {
|
||||||
|
slot.n_decoded -= n_rewind_total;
|
||||||
|
if (slot.n_decoded < 0) slot.n_decoded = 0;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void server_context::buffer_and_check_string_ban(server_slot & slot, completion_token_output & result) {
|
void server_context::buffer_and_check_string_ban(server_slot & slot, completion_token_output & result) {
|
||||||
slot.token_buffer.push_back(result);
|
slot.token_buffer.push_back(result);
|
||||||
|
|
||||||
bool next_token = has_next_token(result, slot);
|
bool next_token = has_next_token(result, slot);
|
||||||
bool send_result = slot.token_buffer.size() >= slot.n_buffer || !next_token;
|
// If buffer full or generation stopped, we might send tokens
|
||||||
|
bool buffer_full = slot.token_buffer.size() >= slot.n_buffer;
|
||||||
|
|
||||||
|
int32_t ban_pos = -1;
|
||||||
int32_t n_rewind = 0;
|
int32_t n_rewind = 0;
|
||||||
bool sent_results = false;
|
bool sent_results = false;
|
||||||
// don't restore if last time was also rewind
|
|
||||||
if (!slot.rewind_status) {
|
// Always reset logit bias to base before checking bans
|
||||||
slot.ctx_sampling->params.logit_bias = slot.logit_bias; // restore logit bias
|
slot.ctx_sampling->params.logit_bias = slot.logit_bias;
|
||||||
|
|
||||||
|
if (slot.ban_phrases.size() > 0 || slot.ban_regex.size() > 0 || slot.ban_regex_ci.size() > 0) {
|
||||||
|
ban_pos = check_ban_phrase(slot);
|
||||||
|
if (ban_pos >= 0 && slot.sparams.adaptive_target >= 0.0f) {
|
||||||
|
int32_t buffer_start_pos = slot.n_past - (int32_t)slot.token_buffer.size() + 1;
|
||||||
|
int32_t n_keep_buffer = ban_pos - buffer_start_pos;
|
||||||
|
if (n_keep_buffer < 0) n_keep_buffer = 0;
|
||||||
|
n_rewind = (int32_t)slot.token_buffer.size() - n_keep_buffer;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
if (slot.ban_phrases.size() > 0) {
|
|
||||||
n_rewind = check_ban_phrase(slot);
|
bool allow_rewind = true;
|
||||||
|
|
||||||
|
if (ban_pos >= 0) {
|
||||||
|
if (slot.rewind_count_max == -1) {
|
||||||
|
// Automatic / Heuristic logic
|
||||||
|
// Account for strings + regex + regex_ci
|
||||||
|
size_t total_bans = slot.ban_phrases.size() + slot.ban_regex.size() + slot.ban_regex_ci.size();
|
||||||
|
|
||||||
|
// Heuristic: Allow if under 20 OR under 2 * total_bans
|
||||||
|
// Conversely: Stop if >= 20 AND > 2 * total_bans
|
||||||
|
if (slot.rewind_count >= 20 && slot.rewind_count > 2 * total_bans) {
|
||||||
|
allow_rewind = false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else if (slot.rewind_count_max > 0) {
|
||||||
|
// Strict limit logic
|
||||||
|
if (slot.rewind_count >= slot.rewind_count_max) {
|
||||||
|
allow_rewind = false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// If slot.rewind_count_max == 0, allow_rewind remains true (Infinite)
|
||||||
}
|
}
|
||||||
// if found string in the ban
|
|
||||||
if (n_rewind > 0 && (slot.rewind_count <20 || slot.rewind_count <= 2 * slot.ban_phrases.size())) {
|
if (ban_pos >= 0 && allow_rewind) {
|
||||||
rewind_context(slot, n_rewind);
|
rewind_context(slot, ban_pos);
|
||||||
slot.rewind_status = true;
|
slot.rewind_status = true;
|
||||||
}
|
}
|
||||||
else if (send_result) {
|
else if (buffer_full || !next_token) {
|
||||||
slot.rewind_status = false;
|
slot.rewind_status = false;
|
||||||
slot.rewind_count = 0;
|
slot.rewind_count = 0;
|
||||||
|
|
||||||
if (!next_token) {
|
if (!next_token) {
|
||||||
// send all remaining tokens in the buffer
|
// send all remaining tokens
|
||||||
send_token_results(slot.token_buffer, slot);
|
send_token_results(slot.token_buffer, slot);
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
// send 1 token
|
// send 1 token from the front (FIFO)
|
||||||
send_token_results(slot.token_buffer, slot, 1);
|
send_token_results(slot.token_buffer, slot, 1);
|
||||||
}
|
}
|
||||||
sent_results = true;
|
if (slot.sparams.adaptive_target >= 0.0f) {
|
||||||
|
sent_results = true;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
// buffer the result
|
// buffer the result, wait for more tokens to validate string
|
||||||
slot.sampled = result.tok; // for common batch add
|
slot.sampled = result.tok;
|
||||||
|
}
|
||||||
|
if (slot.sparams.adaptive_target >= 0.0f) {
|
||||||
|
slot.ctx_sampling->n_rewind = sent_results ? -1 : n_rewind;
|
||||||
}
|
}
|
||||||
|
|
||||||
slot.ctx_sampling->n_rewind = sent_results ? -1 : n_rewind;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void server_context::process_batch_tokens(int32_t & n_batch) {
|
void server_context::process_batch_tokens(int32_t & n_batch) {
|
||||||
@@ -3465,6 +3656,15 @@ void server_context::process_batch_tokens(int32_t & n_batch) {
|
|||||||
continue; // sample using speculative decoding
|
continue; // sample using speculative decoding
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RESTORE AND APPLY POSITIONAL BANS
|
||||||
|
slot.ctx_sampling->params.logit_bias = slot.logit_bias;
|
||||||
|
auto ban_it = slot.positional_bans.find(slot.n_past);
|
||||||
|
if (ban_it != slot.positional_bans.end()) {
|
||||||
|
for (llama_token tok : ban_it->second) {
|
||||||
|
slot.ctx_sampling->params.logit_bias[tok] += slot.ban_phrases_bias;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
completion_token_output result;
|
completion_token_output result;
|
||||||
const int tok_idx = slot.i_batch - i;
|
const int tok_idx = slot.i_batch - i;
|
||||||
const llama_token id = common_sampler_sample(slot.ctx_sampling, ctx, tok_idx);
|
const llama_token id = common_sampler_sample(slot.ctx_sampling, ctx, tok_idx);
|
||||||
@@ -3607,4 +3807,4 @@ json server_context::model_meta() const {
|
|||||||
{"n_params", llama_model_n_params(model)},
|
{"n_params", llama_model_n_params(model)},
|
||||||
{"size", llama_model_size(model)},
|
{"size", llama_model_size(model)},
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
@@ -81,7 +81,7 @@ struct server_slot {
|
|||||||
bool stopped_eos = false;
|
bool stopped_eos = false;
|
||||||
bool stopped_word = false;
|
bool stopped_word = false;
|
||||||
bool stopped_limit = false;
|
bool stopped_limit = false;
|
||||||
|
bool saturate_predict = false;
|
||||||
bool oaicompat = false;
|
bool oaicompat = false;
|
||||||
|
|
||||||
std::string oaicompat_model;
|
std::string oaicompat_model;
|
||||||
@@ -91,12 +91,16 @@ struct server_slot {
|
|||||||
// For context rewind/ token buffer
|
// For context rewind/ token buffer
|
||||||
size_t n_buffer = 0;
|
size_t n_buffer = 0;
|
||||||
int32_t rewind_count = 0;
|
int32_t rewind_count = 0;
|
||||||
|
int32_t rewind_count_max = -1;
|
||||||
bool rewind_status = false;
|
bool rewind_status = false;
|
||||||
std::unordered_map<llama_token, float> logit_bias;
|
std::unordered_map<llama_token, float> logit_bias;
|
||||||
std::vector<std::string>ban_phrases;
|
std::vector<std::string> ban_phrases;
|
||||||
|
std::vector<std::string> ban_regex;
|
||||||
|
std::vector<std::string> ban_regex_ci;
|
||||||
completion_token_outputs token_buffer;
|
completion_token_outputs token_buffer;
|
||||||
float ban_phrases_bias = 0;
|
float ban_phrases_bias = 0;
|
||||||
int32_t banned_n = 1;
|
int32_t banned_n = 1;
|
||||||
|
std::map<int32_t, std::set<llama_token>> positional_bans;
|
||||||
|
|
||||||
server_prompt server_cached_prompt;
|
server_prompt server_cached_prompt;
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user