mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-03-01 01:24:08 +00:00
init n_buffer
This commit is contained in:
@@ -1532,9 +1532,6 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
|
||||
ban_phrases.push_back(str);
|
||||
}
|
||||
}
|
||||
std::sort(ban_phrases.begin(), ban_phrases.end(), [](const std::string& a, const std::string& b) {
|
||||
return a.length() > b.length();
|
||||
});
|
||||
params.ban_phrases = ban_phrases;
|
||||
return true;
|
||||
}
|
||||
|
||||
@@ -216,9 +216,9 @@ struct gpt_params {
|
||||
|
||||
std::vector<std::string> in_files; // all input files
|
||||
std::vector<std::string> antiprompt; // strings upon which more user input is prompted (a.k.a. reverse prompts)
|
||||
std::vector<std::string> ban_phrases; //strings that are banned in generation
|
||||
int32_t banned_n = 1; // number of tokens that are banned in the phrase
|
||||
int32_t n_buffer; // number of token buffers for string ban
|
||||
std::vector<std::string> ban_phrases; // strings that are banned in generation
|
||||
int32_t banned_n = 1; // number of tokens that are banned in the phrase
|
||||
size_t n_buffer = 0; // number of token buffers for string ban
|
||||
|
||||
std::vector<llama_model_kv_override> kv_overrides;
|
||||
std::vector<llama_model_tensor_buft_override> tensor_buft_overrides;
|
||||
|
||||
@@ -1143,21 +1143,28 @@ bool server_context::launch_slot_with_task(server_slot& slot, server_task& task)
|
||||
std::sort(slot.ban_phrases.begin(), slot.ban_phrases.end(), [](const std::string& a, const std::string& b) {
|
||||
return a.length() > b.length();
|
||||
});
|
||||
}
|
||||
else if (params_base.ban_phrases.size()>0 && params_base.n_buffer == 0) {
|
||||
slot.ban_phrases.clear();
|
||||
for (const auto & val : params_base.ban_phrases) {
|
||||
if (!val.empty()) {
|
||||
std::string s = string_lower(val);
|
||||
auto ban_tokens = common_tokenize(llama_get_model(ctx), s, false, true);
|
||||
if (ban_tokens.size() > slot.n_buffer) {
|
||||
slot.n_buffer = ban_tokens.size();
|
||||
} else if (params_base.ban_phrases.size() > 0) {
|
||||
if (params_base.n_buffer == 0) {
|
||||
slot.ban_phrases.clear();
|
||||
std::sort(params_base.ban_phrases.begin(), params_base.ban_phrases.end(), [](const std::string & a, const std::string & b) {
|
||||
return a.length() > b.length();
|
||||
});
|
||||
for (auto & val : params_base.ban_phrases) {
|
||||
if (!val.empty()) {
|
||||
val = string_lower(val);
|
||||
auto ban_tokens = common_tokenize(llama_get_model(ctx), val, false, true);
|
||||
if (ban_tokens.size() > slot.n_buffer) {
|
||||
slot.n_buffer = ban_tokens.size();
|
||||
}
|
||||
slot.ban_phrases.push_back(val);
|
||||
}
|
||||
slot.ban_phrases.push_back(s);
|
||||
}
|
||||
}
|
||||
slot.n_buffer = slot.n_buffer + 3; // extra buffer in case
|
||||
params_base.n_buffer = slot.n_buffer;
|
||||
} else {
|
||||
slot.ban_phrases = params_base.ban_phrases;
|
||||
slot.n_buffer = params_base.n_buffer;
|
||||
}
|
||||
params_base.n_buffer = slot.n_buffer + 3;
|
||||
slot.n_buffer = slot.n_buffer + 3; // extra buffer in case
|
||||
}
|
||||
slot.logit_bias = slot.sparams.logit_bias; // keep a copy to restore
|
||||
slot.ban_phrases_bias = json_value(data, "banned_bias", params_base.ban_phrases_bias);
|
||||
|
||||
@@ -84,7 +84,7 @@ struct server_slot {
|
||||
stop_type stop;
|
||||
|
||||
// For context rewind/ token buffer
|
||||
int32_t n_buffer = 0;
|
||||
size_t n_buffer = 0;
|
||||
int32_t rewind_count = 0;
|
||||
bool rewind_status = false;
|
||||
std::unordered_map<llama_token, float> logit_bias;
|
||||
|
||||
Reference in New Issue
Block a user