Mirror of https://github.com/ikawrakow/ik_llama.cpp.git
server: add string ban
@@ -1521,6 +1521,28 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
        params.antiprompt.emplace_back(argv[i]);
        return true;
    }
    if (arg == "--banned-string-file") {
        CHECK_ARG
        std::string files = read_file(std::string(argv[i]));
        std::vector<std::string> ban_strings = string_split(files, "\n");
        std::vector<std::string> ban_phrases;
        for (auto & str : ban_strings) {
            std::erase(str, '"');
            if (!str.empty()) {
                ban_phrases.push_back(str);
            }
        }
        // longest phrases first, so the longest match wins when the token buffer is scanned
        std::sort(ban_phrases.begin(), ban_phrases.end(), [](const std::string & a, const std::string & b) {
            return a.length() > b.length();
        });
        params.ban_phrases = ban_phrases;
        return true;
    }
    if (arg == "--banned-n") {
        CHECK_ARG
        params.banned_n = std::stoi(argv[i]);
        return true;
    }
    if (arg == "-ld" || arg == "--logdir") {
        CHECK_ARG
        params.logdir = argv[i];
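For illustration only (not part of the commit): the file given to --banned-string-file is read whole, split on newlines, has all double-quote characters stripped, and empty lines are skipped, so a hypothetical bans.txt might contain

    as an AI language model
    "I cannot and will not"
    certainly! here is

and be loaded with something like: llama-server -m model.gguf --banned-string-file bans.txt --banned-n -1 (binary, file and model names are made up). The surviving phrases are sorted longest-first so that the longest banned phrase wins when the generation buffer is scanned later.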
@@ -2231,6 +2253,8 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param
    options.push_back({ "*", "       --top-n-sigma t",      "top-n-sigma parameter (default: %.1f, 0.0 = disabled)", (double)sparams.top_n_sigma});
    options.push_back({ "*", "       --adaptive-target",    "adaptive-p sampling: (default: %.2f, <0.0 = disabled)", (double)sparams.adaptive_target});
    options.push_back({ "*", "       --adaptive-decay",     "adaptive-p sampling: (default: %.2f)", (double)sparams.adaptive_decay});
    options.push_back({ "*", "       --banned-string-file", "path to a file with one banned string per line" });
    options.push_back({ "*", "       --banned-n",           "number of tokens of the banned phrase to ban during rewind, -1 = all tokens (default: %d)", params.banned_n });
    options.push_back({ "*", "-l TOKEN_ID(+/-)BIAS",        "modifies the likelihood of token appearing in the completion,\n"
                                                            "i.e. `--logit-bias 15043+1` to increase likelihood of token ' Hello',\n"
                                                            "or `--logit-bias 15043-1` to decrease likelihood of token ' Hello'" });
@@ -2625,6 +2649,18 @@ std::string string_get_sortable_timestamp() {
    return std::string(timestamp_no_ns) + "." + std::string(timestamp_ns);
}

// could be improved to support more languages
std::string string_lower(const std::string& str) {
    std::string result = str;
    for (char& c : result) {
        if (c >= 'A' && c <= 'Z') {
            c = static_cast<char>(c + ('a' - 'A'));
        }
    }
    return result;
}

void string_replace_all(std::string & s, const std::string & search, const std::string & replace) {
    if (search.empty()) {
        return; // Avoid infinite loop if 'search' is an empty string
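A quick illustration of the ASCII-only lowering above (not part of the commit, shown as comments):

    // string_lower("Hello, WORLD") == "hello, world"
    // string_lower("Caf\xC3\xA9")  == "caf\xC3\xA9"   // multi-byte UTF-8 bytes fall outside 'A'..'Z' and pass through unchanged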
@@ -144,40 +144,41 @@ struct gpt_params {
    uint32_t seed = LLAMA_DEFAULT_SEED; // RNG seed

    int32_t n_threads = cpu_get_num_math();
    int32_t n_threads_draft = -1;
    int32_t n_threads_batch = -1; // number of threads to use for batch processing (-1 = use n_threads)
    int32_t n_threads_batch_draft = -1;
    int32_t n_predict = -1; // new tokens to predict
    int32_t n_ctx = 0; // context size
    int32_t n_ctx_draft = 0; // context size for draft model
    int32_t n_batch = 2048; // logical batch size for prompt processing (must be >=32 to use BLAS)
    int32_t n_ubatch = 512; // physical batch size for prompt processing (must be >=32 to use BLAS)
    int32_t n_keep = 0; // number of tokens to keep from initial prompt
    int32_t n_draft = 16; // number of tokens to draft during speculative decoding
    int32_t n_draft_min = 1; // minimum number of tokens to draft during speculative decoding
    float p_draft_min = 0.8f; // minimum speculative decoding probability (greedy)
    int32_t n_chunks = -1; // max number of chunks to process (-1 = unlimited)
    int32_t n_parallel = 1; // number of parallel sequences to decode
    int32_t n_sequences = 1; // number of sequences to decode
    float p_split = 0.1f; // speculative decoding split probability
    int32_t n_gpu_layers = -1; // number of layers to store in VRAM (-1 - use default)
    int32_t n_gpu_layers_draft = -1; // number of layers to store in VRAM for the draft model (-1 - use default)
    int32_t main_gpu = 0; // the GPU that is used for scratch and small tensors
    int32_t max_gpu = 0; // max number of GPUs to use at a time for split mode "graph"
    float tensor_split[128] = {0}; // how split tensors should be distributed across GPUs
    int32_t grp_attn_n = 1; // group-attention factor
    int32_t grp_attn_w = 512; // group-attention width
    int32_t n_print = -1; // print token count every n tokens (-1 = disabled)
    float rope_freq_base = 0.0f; // RoPE base frequency
    float rope_freq_scale = 0.0f; // RoPE frequency scaling factor
    float yarn_ext_factor = -1.0f; // YaRN extrapolation mix factor
    float yarn_attn_factor = -1.0f; // YaRN magnitude scaling factor
    float yarn_beta_fast = -1.0f; // YaRN low correction dim
    float yarn_beta_slow = -1.0f; // YaRN high correction dim
    int32_t yarn_orig_ctx = 0; // YaRN original context length
    float defrag_thold = -1.0f; // KV cache defragmentation threshold
    float ban_phrases_bias = -999.0f; // logit bias applied to ban phrases
    int32_t max_extra_alloc_MiB = 256; // additional VRAM per GPU the scheduler may allocate for more efficient compute graph evaluation
    int32_t nrep = 1; // number of repetitions used in sweep bench

    ggml_backend_sched_eval_callback cb_eval = nullptr;
    void * cb_eval_user_data = nullptr;
@@ -213,8 +214,12 @@ struct gpt_params {

    std::string cuda_params = ""; // comma separated list of cuda parameters key=value1,key2=value2

    std::vector<std::string> in_files;    // all input files
    std::vector<std::string> antiprompt;  // strings upon which more user input is prompted (a.k.a. reverse prompts)
    std::vector<std::string> ban_phrases; // strings that are banned during generation
    int32_t banned_n = 1;                 // number of tokens of the banned phrase to ban during rewind (-1 = all)
    int32_t n_buffer = 0;                 // number of tokens to buffer for string-ban checking

    std::vector<llama_model_kv_override> kv_overrides;
    std::vector<llama_model_tensor_buft_override> tensor_buft_overrides;
    std::vector<std::pair<int,int>> offload_policy;
@@ -431,6 +436,7 @@ std::pair<long, std::vector<char>> common_remote_get_content(const std::string&
std::string string_join(const std::vector<std::string>& values, const std::string& separator);
std::string string_strip(const std::string & str);
std::string string_get_sortable_timestamp();
std::string string_lower(const std::string & str);

static bool string_starts_with(const std::string& str,
                               const std::string& prefix) { // While we wait for C++20's std::string::starts_with...
@@ -13,6 +13,7 @@
#include <string>
#include <vector>
#include <cinttypes>
#include <deque>
@@ -215,6 +216,8 @@ struct completion_token_output {
    static json probs_vector_to_json(const std::vector<completion_token_output>& probs, bool post_sampling_probs);
};

using completion_token_outputs = std::deque<completion_token_output>;

// convert a vector of completion_token_output to json
json probs_vector_to_json(const llama_context* ctx, const std::vector<completion_token_output>& probs);
@@ -318,16 +318,20 @@ void server_slot::reset() {
    n_past = 0;
    n_past_prompt = 0;
    n_sent_text = 0;

    drafted.clear();
    i_batch_dft.clear();

    n_sent_token_probs = 0;
    infill = false;
    ga_i = 0;
    n_past_se = 0;
    chat_format = COMMON_CHAT_FORMAT_CONTENT_ONLY;

    logit_bias.clear();
    token_buffer.clear();
    rewind_count = 0;
    n_buffer = 0;
    rewind_status = false;

    generated_token_probs.clear();
@@ -782,7 +786,7 @@ bool server_context::launch_slot_with_task(server_slot& slot, server_task& task)
    // Sampling parameter defaults are loaded from the global server context (but individual requests can still override them)
    llama_sampling_params default_sparams = params_base.sparams;
    auto& data = task.data;

    const llama_vocab* vocab = llama_model_get_vocab(model);
    if (data.count("__oaicompat") != 0) {
        slot.oaicompat = true;
        slot.oaicompat_model = json_value(data, "model", std::string(DEFAULT_OAICOMPAT_MODEL));
@@ -1046,8 +1050,10 @@ bool server_context::launch_slot_with_task(server_slot& slot, server_task& task)
    { // apply logit bias
        const auto& logit_bias = data.find("logit_bias");
        if (logit_bias != data.end() && (logit_bias->is_object() || logit_bias->is_array())) {
            slot.sparams.logit_bias.clear(); // only clear if user sets it
        }
        if (logit_bias != data.end() && logit_bias->is_array()) {
            const int n_vocab = llama_n_vocab(model);
            for (const auto& el : *logit_bias) {
                // TODO: we may want to throw errors here, in case "el" is incorrect
@@ -1078,12 +1084,86 @@ bool server_context::launch_slot_with_task(server_slot& slot, server_task& task)
                }
            }
        }
        else if (logit_bias != data.end() && logit_bias->is_object()) {
            const int n_vocab = llama_vocab_n_tokens(vocab);
            for (const auto& el : logit_bias->items()) {
                float bias;
                const auto& key = el.key();
                const auto& value = el.value();
                if (value.is_number()) {
                    bias = value.get<float>();
                }
                else if (value.is_boolean() && !value.get<bool>()) {
                    bias = -INFINITY;
                }
                else {
                    continue;
                }

                char* end;
                llama_token tok = strtol(key.c_str(), &end, 10);
                if (*end == 0) {
                    // the key is a plain token id
                    if (tok >= 0 && tok < n_vocab) {
                        slot.sparams.logit_bias[tok] = bias;
                    }
                }
                else {
                    // the key is a string: tokenize it and bias every resulting token
                    auto toks = common_tokenize(model, key, false);
                    for (auto tok : toks) {
                        slot.sparams.logit_bias[tok] = bias;
                    }
                }
            }
        }
        if (json_value(data, "ignore_eos", false) && has_eos_token) {
            slot.sparams.logit_bias[llama_token_eos(model)] = -INFINITY;
        }

    }

    {
        // ban string
        const auto& banned_strings = data.find("banned_strings");
        if (banned_strings != data.end() && banned_strings->is_array()) {
            slot.ban_phrases.clear();
            for (const auto& val : data["banned_strings"]) {
                if (val.is_string()) {
                    std::string s = val.get<std::string>();
                    if (!s.empty()) {
                        s = string_lower(s);
                        auto ban_tokens = common_tokenize(llama_get_model(ctx), s, false, true);
                        if (ban_tokens.size() > slot.n_buffer) {
                            slot.n_buffer = ban_tokens.size();
                        }
                        slot.ban_phrases.push_back(s);
                    }
                }
            }
            slot.n_buffer = slot.n_buffer + 3; // a little extra headroom, just in case
            std::sort(slot.ban_phrases.begin(), slot.ban_phrases.end(), [](const std::string& a, const std::string& b) {
                return a.length() > b.length();
            });
        }
        else if (params_base.ban_phrases.size() > 0 && params_base.n_buffer == 0) {
            slot.ban_phrases.clear();
            for (const auto & val : params_base.ban_phrases) {
                if (!val.empty()) {
                    std::string s = string_lower(val);
                    auto ban_tokens = common_tokenize(llama_get_model(ctx), s, false, true);
                    if (ban_tokens.size() > slot.n_buffer) {
                        slot.n_buffer = ban_tokens.size();
                    }
                    slot.ban_phrases.push_back(s);
                }
            }
            params_base.n_buffer = slot.n_buffer + 3;
            slot.n_buffer = slot.n_buffer + 3; // a little extra headroom, just in case
        }
        slot.logit_bias = slot.sparams.logit_bias; // keep a copy to restore
        slot.ban_phrases_bias = json_value(data, "banned_bias", params_base.ban_phrases_bias);
        slot.banned_n = json_value(data, "banned_n", params_base.banned_n);
    }

    {
        const auto& stop = data.find("stop");
        if (stop != data.end() && stop->is_array()) {
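For reference, a sketch of the JSON body a client might send to exercise the per-request fields parsed above; the field names come from the code, while the prompt, the values and the json alias (nlohmann::json, as used in the server sources) are illustrative:

    json body = {
        { "prompt",         "Write a short story." },
        { "logit_bias",     { { "15043", -1.0 }, { " As an AI", false } } }, // object form: numeric keys are token ids, other keys are tokenized, false => -INFINITY
        { "banned_strings", json::array({ "as an ai language model", "i cannot" }) },
        { "banned_bias",    -999.0 },  // optional, defaults to params_base.ban_phrases_bias
        { "banned_n",       -1 }       // optional, defaults to params_base.banned_n (-1 = all rewound tokens)
    };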
@@ -1196,6 +1276,28 @@ bool server_context::system_prompt_set(const std::string& sys_prompt) {
    return true;
}

// keep in sync with process_token(completion_token_output& result, server_slot& slot)
bool server_context::has_next_token(const completion_token_output& result, server_slot& slot) {
    bool next = true;
    //std::string generate_text = slot.generated_text + result.text_to_send;
    //bool incomplete = validate_utf8(generate_text) < generate_text.size();
    //if (incomplete) {
    //    next = true;
    //}
    if (slot.n_decoded > 0 && !slot.has_budget(params_base)) {
        next = false;
    }
    if (llama_token_is_eog(model, result.tok)) {
        next = false;
    }
    auto n_ctx_train = llama_n_ctx_train(model);
    if (slot.params.n_predict < 1 && slot.n_predict < 1 && slot.ga_n == 1
            && slot.n_prompt_tokens + slot.n_decoded >= n_ctx_train) {
        next = false;
    }
    return next;
}

bool server_context::process_token(completion_token_output& result, server_slot& slot) {
    // remember which tokens were sampled - used for repetition penalties during sampling
    const std::string token_str = result.text_to_send;
@@ -2523,7 +2625,6 @@ void server_context::batch_pending_prompt(const int32_t n_ubatch, const int32_t
        if (slot.n_past_prompt == slot.n_prompt_tokens) {
            slot.state = SLOT_STATE_PROCESSING;
            slot.command = SLOT_COMMAND_NONE;

            GGML_ASSERT(batch.n_tokens > 0);
            GGML_ASSERT((size_t)slot.n_prompt_tokens == slot.prompt_tokens.size());
            common_sampler_reset(llama_get_model_vocab(model), slot.ctx_sampling);
@@ -2651,6 +2752,124 @@ bool server_context::accept_special_token(const server_slot& slot, const llama_
    return params_base.special || slot.sparams.preserved_tokens.find(token) != slot.sparams.preserved_tokens.end();
};

void server_context::send_token_results(completion_token_outputs& results, server_slot& slot, int32_t n) {
    int count = 0;
    for (auto& it : results) {
        bool has_next = process_token(it, slot);
        count++;
        if (!has_next) {
            slot.release();
            slot.print_timings();
            send_final_response(slot);
            metrics.on_prediction(slot);
            break;
        }
        if (n > 0 && count >= n) {
            break;
        }
    }

    if (count > 0) {
        slot.sampled = results[results.size() - 1].tok;
        results.erase(results.begin(), results.begin() + count);
    }
}

inline int32_t check_ban_phrase(const server_slot& slot) {
    bool found = false;
    size_t n = slot.token_buffer.size();
    size_t start;
    int32_t n_rewind = 0;
    std::string string_buffer;
    llama_tokens tokens;
    for (auto& it : slot.token_buffer) {
        string_buffer = string_buffer + it.text_to_send;
        tokens.push_back(it.tok);
    }
    string_buffer = string_lower(string_buffer);
    for (const auto& it : slot.ban_phrases) {
        start = string_buffer.find(it);
        // has been sorted from longest to shortest
        if (start != std::string::npos) {
            found = true;
            break;
        }
    }
    if (found) {
        std::vector<size_t> unused;
        LLAMA_LOG_DEBUG("Banned string detected: %s\n", string_buffer.substr(start).c_str());
        n = find_n_tokens_from_string(slot.ctx, tokens, start, 0, unused);
        n_rewind = (int32_t) slot.token_buffer.size() - (int32_t) n;
    }
    return n_rewind;
}

inline void rewind_context(server_slot& slot, int32_t n_rewind) {
    slot.rewind_count++;
    int32_t n_keep_rewind = (int32_t) slot.token_buffer.size() - n_rewind;
    std::set<llama_token> tokens;
    // ban all tokens for better coherence
    if (slot.banned_n != 0) {
        int32_t n = 0;
        for (auto result = slot.token_buffer.begin() + n_keep_rewind; result != slot.token_buffer.end(); result++)
        {
            if (!tokens.contains(result->tok)) {
                // bias each distinct rewound token only once
                slot.ctx_sampling->params.logit_bias[result->tok] += slot.ban_phrases_bias;
                tokens.insert(result->tok);
            }
            n++;
            if (slot.banned_n > 0 && n == slot.banned_n) {
                break;
            }
        }
    }

    slot.token_buffer.resize(n_keep_rewind);
    size_t n_keep = slot.cache_tokens.size() - n_rewind;
    slot.sampled = slot.cache_tokens[n_keep];
    slot.cache_tokens.keep_first(n_keep);
}

void server_context::buffer_and_check_string_ban(server_slot & slot, completion_token_output & result) {
    slot.token_buffer.push_back(result);

    bool next_token = has_next_token(result, slot);
    bool send_result = slot.token_buffer.size() >= slot.n_buffer || !next_token;
    int32_t n_rewind = 0;
    // don't restore if the last step was also a rewind
    if (!slot.rewind_status) {
        slot.ctx_sampling->params.logit_bias = slot.logit_bias; // restore logit bias
    }
    if (slot.ban_phrases.size() > 0) {
        n_rewind = check_ban_phrase(slot);
    }
    // a banned string was found in the buffer
    if (n_rewind > 0 && slot.rewind_count <= 2 * slot.ban_phrases.size()) {
        rewind_context(slot, n_rewind);
        slot.rewind_status = true;
    }
    else if (send_result) {
        slot.rewind_status = false;
        slot.rewind_count = 0;
        if (!next_token) {
            // send all remaining tokens in the buffer
            send_token_results(slot.token_buffer, slot);
        }
        else {
            // send 1 token
            send_token_results(slot.token_buffer, slot, 1);
        }
    }
    else {
        // buffer the result
        slot.sampled = result.tok; // for common batch add
    }
}
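To make the buffering and rewind above concrete, a hypothetical trace (all numbers invented for illustration):

    // - slot.n_buffer == 8 and the lower-cased text of the 8 buffered tokens ends in "as an ai language model".
    // - check_ban_phrase() finds the phrase at some byte offset, find_n_tokens_from_string() maps that offset
    //   back to n == 5 tokens, so n_rewind = 8 - 5 = 3: the last 3 buffered tokens overlap the banned phrase.
    // - rewind_context() adds ban_phrases_bias to the logit bias of (up to banned_n of) those 3 tokens,
    //   truncates token_buffer to 5 entries and drops the same 3 tokens from cache_tokens, so decoding
    //   re-samples from the truncated prefix.
    // - when the buffer is full and nothing is banned, only the oldest token is flushed via send_token_results()
    //   (all remaining tokens are flushed once generation ends), and the saved per-request logit_bias copy is
    //   restored on the next non-rewind pass.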

void server_context::process_batch_tokens(int32_t & n_batch) {
    for (int32_t i = 0; i < batch.n_tokens; i += n_batch) {
        const int32_t n_tokens = std::min(n_batch, batch.n_tokens - i);
@@ -2668,7 +2887,6 @@ void server_context::process_batch_tokens(int32_t & n_batch) {
        };

        const int ret = llama_decode(ctx, batch_view);

        if (ret != 0) {
            if (n_batch == 1 || ret < 0) {
                int user_cancel = -3;
@@ -2721,17 +2939,17 @@ void server_context::process_batch_tokens(int32_t & n_batch) {
                continue; // continue loop of slots
            }

            if (slot.i_batch_dft.size() > 0) {
                continue; // sample using speculative decoding
            }

            completion_token_output result;
            const int tok_idx = slot.i_batch - i;
            const llama_token id = common_sampler_sample(slot.ctx_sampling, ctx, NULL, tok_idx);

            common_sampler_accept(slot.ctx_sampling, ctx, id, true);

            slot.n_decoded += 1;

            const int64_t t_current = ggml_time_us();

            if (slot.n_decoded == 1) {
@@ -2745,16 +2963,17 @@ void server_context::process_batch_tokens(int32_t & n_batch) {
            result.tok = id;
            result.prob = 1.0f; // TODO: set it here instead of doing inside populate_token_probs
            result.text_to_send = common_token_to_piece(ctx, result.tok, accept_special_token(slot, result.tok));

            if (slot.sparams.n_probs > 0) {
                populate_token_probs(slot, result, slot.params.post_sampling_probs, params_base.special, tok_idx);
            }

            if (slot.n_buffer == 0) {
                slot.token_buffer = { result };
                send_token_results(slot.token_buffer, slot);
            } else {
                // buffer the result and check the string ban.
                // if a banned string is found, rewind, apply the logit bias and regenerate
                buffer_and_check_string_ban(slot, result);
            }

            slot.i_batch = -1;
@@ -2794,7 +3013,7 @@ void server_context::update_slots() {
    common_batch_clear(batch);

    // first, add sampled tokens from any ongoing sequences
    add_sampled_tokens(); // Prepare batch for inference

    // process in chunks of params.n_batch
    int32_t n_batch = llama_n_batch(ctx);
@@ -2806,7 +3025,7 @@ void server_context::update_slots() {
    int32_t batch_type = batch.n_tokens > 0 ? 0 : -1;

    // next, batch any pending prompts without exceeding n_batch
    batch_pending_prompt(n_ubatch, n_batch, batch_type); // Prepare batch for prompt processing

    if (batch.n_tokens == 0) {
        LOG_VERBOSE("no tokens to decode", {});
@@ -2821,7 +3040,7 @@ void server_context::update_slots() {
    llama_set_embeddings(ctx, batch_type == 1);

    // process the created batch of tokens
    process_batch_tokens(n_batch); // Decode with batch

    LOG_VERBOSE("run slots completed", {});
}
@@ -83,6 +83,16 @@ struct server_slot {
    std::string stopping_word;
    stop_type stop;

    // For context rewind / token buffer
    int32_t n_buffer = 0;
    int32_t rewind_count = 0;
    bool rewind_status = false;
    std::unordered_map<llama_token, float> logit_bias;
    std::vector<std::string> ban_phrases;
    completion_token_outputs token_buffer;
    float ban_phrases_bias = 0;
    int32_t banned_n = 1;

    server_prompt server_cached_prompt;

    void prompt_save(server_prompt_cache& prompt_cache) const;
@@ -315,5 +325,11 @@ struct server_context {

    bool accept_special_token(const server_slot& slot, const llama_token token);

    bool has_next_token(const completion_token_output& result, server_slot& slot);

    void send_token_results(completion_token_outputs& results, server_slot& slot, int32_t n = 0);

    void buffer_and_check_string_ban(server_slot& slot, completion_token_output& result);

    json model_meta() const;
};