diff --git a/common/common.cpp b/common/common.cpp
index 4873a27d..badc4317 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -1532,9 +1532,6 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
                 ban_phrases.push_back(str);
             }
         }
-        std::sort(ban_phrases.begin(), ban_phrases.end(), [](const std::string& a, const std::string& b) {
-            return a.length() > b.length();
-        });
         params.ban_phrases = ban_phrases;
         return true;
     }
diff --git a/common/common.h b/common/common.h
index 58ffded1..55a56018 100644
--- a/common/common.h
+++ b/common/common.h
@@ -216,9 +216,9 @@ struct gpt_params {
 
     std::vector<std::string> in_files;   // all input files
     std::vector<std::string> antiprompt; // strings upon which more user input is prompted (a.k.a. reverse prompts)
-    std::vector<std::string> ban_phrases; //strings that are banned in generation
-    int32_t banned_n = 1; // number of tokens that are banned in the phrase
-    int32_t n_buffer; // number of token buffers for string ban
+    std::vector<std::string> ban_phrases; // strings that are banned in generation
+    int32_t banned_n = 1; // number of tokens that are banned in the phrase
+    size_t n_buffer = 0; // number of token buffers for string ban
     std::vector<llama_model_kv_override> kv_overrides;
     std::vector<llama_model_tensor_buft_override> tensor_buft_overrides;
 
diff --git a/examples/server/server-context.cpp b/examples/server/server-context.cpp
index dd5d4001..130d499e 100644
--- a/examples/server/server-context.cpp
+++ b/examples/server/server-context.cpp
@@ -1143,21 +1143,28 @@ bool server_context::launch_slot_with_task(server_slot& slot, server_task& task)
         std::sort(slot.ban_phrases.begin(), slot.ban_phrases.end(), [](const std::string& a, const std::string& b) {
             return a.length() > b.length();
         });
-    }
-    else if (params_base.ban_phrases.size()>0 && params_base.n_buffer == 0) {
-        slot.ban_phrases.clear();
-        for (const auto & val : params_base.ban_phrases) {
-            if (!val.empty()) {
-                std::string s = string_lower(val);
-                auto ban_tokens = common_tokenize(llama_get_model(ctx), s, false, true);
-                if (ban_tokens.size() > slot.n_buffer) {
-                    slot.n_buffer = ban_tokens.size();
+    } else if (params_base.ban_phrases.size() > 0) {
+        if (params_base.n_buffer == 0) {
+            slot.ban_phrases.clear();
+            std::sort(params_base.ban_phrases.begin(), params_base.ban_phrases.end(), [](const std::string & a, const std::string & b) {
+                return a.length() > b.length();
+            });
+            for (auto & val : params_base.ban_phrases) {
+                if (!val.empty()) {
+                    val = string_lower(val);
+                    auto ban_tokens = common_tokenize(llama_get_model(ctx), val, false, true);
+                    if (ban_tokens.size() > slot.n_buffer) {
+                        slot.n_buffer = ban_tokens.size();
+                    }
+                    slot.ban_phrases.push_back(val);
                 }
-                slot.ban_phrases.push_back(s);
-            }
+            }
+            slot.n_buffer = slot.n_buffer + 3; // extra buffer in case
+            params_base.n_buffer = slot.n_buffer;
+        } else {
+            slot.ban_phrases = params_base.ban_phrases;
+            slot.n_buffer = params_base.n_buffer;
         }
-        params_base.n_buffer = slot.n_buffer + 3;
-        slot.n_buffer = slot.n_buffer + 3; // extra buffer in case
     }
     slot.logit_bias = slot.sparams.logit_bias; // keep a copy to restore
     slot.ban_phrases_bias = json_value(data, "banned_bias", params_base.ban_phrases_bias);
diff --git a/examples/server/server-context.h b/examples/server/server-context.h
index fc2dc029..a9d9573c 100644
--- a/examples/server/server-context.h
+++ b/examples/server/server-context.h
@@ -84,7 +84,7 @@ struct server_slot {
     stop_type stop;
 
     // For context rewind/ token buffer
-    int32_t n_buffer = 0;
+    size_t n_buffer = 0;
     int32_t rewind_count = 0;
     bool rewind_status = false;
     std::unordered_map<llama_token, float> logit_bias;