server: add string ban

This commit is contained in:
firecoperana
2026-01-19 21:24:47 -06:00
parent 28f8320f3a
commit c96ad27cd0
5 changed files with 333 additions and 53 deletions

View File

@@ -1521,6 +1521,28 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
params.antiprompt.emplace_back(argv[i]);
return true;
}
if (arg == "--banned-string-file") {
CHECK_ARG
std::string files = read_file(std::string(argv[i]));
std::vector<std::string> ban_strings=string_split(files, "\n");
std::vector<std::string> ban_phrases;
for (auto& str : ban_strings) {
std::erase(str, '"');
if (!str.empty()) {
ban_phrases.push_back(str);
}
}
std::sort(ban_phrases.begin(), ban_phrases.end(), [](const std::string& a, const std::string& b) {
return a.length() > b.length();
});
params.ban_phrases = ban_phrases;
return true;
}
if (arg == "--banned-n") {
CHECK_ARG
params.banned_n = std::stoi(argv[i]);
return true;
}
if (arg == "-ld" || arg == "--logdir") {
CHECK_ARG
params.logdir = argv[i];
@@ -2231,6 +2253,8 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param
options.push_back({ "*", " --top-n-sigma t", "top-n-sigma parmeter (default: %.1f, 0.0 = disabled)", (double)sparams.top_n_sigma});
options.push_back({ "*", " --adaptive-target", "adaptive-p sampling: (default: %.2f, <0.0 = disabled)", (double)sparams.adaptive_target});
options.push_back({ "*", " --adaptive-decay", "adaptive-p sampling: (default: %.2f)", (double)sparams.adaptive_decay});
options.push_back({ "*", " --banned-string-file", "file path of the list of banned strings on each line" });
options.push_back({ "*", " --banned-n", "number of tokens banned in the phrase during rewind. -1 means all tokens: (default: %d)",params.banned_n });
options.push_back({ "*", " -l TOKEN_ID(+/-)BIAS", "modifies the likelihood of token appearing in the completion,\n"
"i.e. `--logit-bias 15043+1` to increase likelihood of token ' Hello',\n"
"or `--logit-bias 15043-1` to decrease likelihood of token ' Hello'" });
@@ -2625,6 +2649,18 @@ std::string string_get_sortable_timestamp() {
return std::string(timestamp_no_ns) + "." + std::string(timestamp_ns);
}
// could be improved to support more languages
std::string string_lower(const std::string& str) {
std::string result = str;
for (char& c : result) {
if (c >= 'A' && c <= 'Z') {
c = static_cast<char>(c + ('a' - 'A'));
}
}
return result;
}
void string_replace_all(std::string & s, const std::string & search, const std::string & replace) {
if (search.empty()) {
return; // Avoid infinite loop if 'search' is an empty string