diff --git a/common/common.cpp b/common/common.cpp
index 3192fd37..4873a27d 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -1521,6 +1521,28 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
         params.antiprompt.emplace_back(argv[i]);
         return true;
     }
+    if (arg == "--banned-string-file") {
+        CHECK_ARG
+        std::string files = read_file(std::string(argv[i]));
+        std::vector<std::string> ban_strings = string_split(files, "\n");
+        std::vector<std::string> ban_phrases;
+        for (auto& str : ban_strings) {
+            std::erase(str, '"');
+            if (!str.empty()) {
+                ban_phrases.push_back(str);
+            }
+        }
+        std::sort(ban_phrases.begin(), ban_phrases.end(), [](const std::string& a, const std::string& b) {
+            return a.length() > b.length();
+        });
+        params.ban_phrases = ban_phrases;
+        return true;
+    }
+    if (arg == "--banned-n") {
+        CHECK_ARG
+        params.banned_n = std::stoi(argv[i]);
+        return true;
+    }
     if (arg == "-ld" || arg == "--logdir") {
         CHECK_ARG
         params.logdir = argv[i];
@@ -2231,6 +2253,8 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param
     options.push_back({ "*", " --top-n-sigma t", "top-n-sigma parmeter (default: %.1f, 0.0 = disabled)", (double)sparams.top_n_sigma});
     options.push_back({ "*", " --adaptive-target", "adaptive-p sampling: (default: %.2f, <0.0 = disabled)", (double)sparams.adaptive_target});
     options.push_back({ "*", " --adaptive-decay", "adaptive-p sampling: (default: %.2f)", (double)sparams.adaptive_decay});
+    options.push_back({ "*", " --banned-string-file", "path to a file containing banned strings, one per line" });
+    options.push_back({ "*", " --banned-n", "number of tokens of the banned phrase to ban during rewind, -1 = all tokens (default: %d)", params.banned_n });
     options.push_back({ "*", " -l TOKEN_ID(+/-)BIAS", "modifies the likelihood of token appearing in the completion,\n"
                              "i.e. `--logit-bias 15043+1` to increase likelihood of token ' Hello',\n"
                              "or `--logit-bias 15043-1` to decrease likelihood of token ' Hello'" });
@@ -2625,6 +2649,18 @@ std::string string_get_sortable_timestamp() {
     return std::string(timestamp_no_ns) + "." + std::string(timestamp_ns);
 }
+// could be improved to support more languages
+std::string string_lower(const std::string& str) {
+    std::string result = str;
+    for (char& c : result) {
+        if (c >= 'A' && c <= 'Z') {
+            c = static_cast<char>(c + ('a' - 'A'));
+        }
+    }
+    return result;
+}
+
+
 
 void string_replace_all(std::string & s, const std::string & search, const std::string & replace) {
     if (search.empty()) {
         return; // Avoid infinite loop if 'search' is an empty string
diff --git a/common/common.h b/common/common.h
index 1de82a6c..58ffded1 100644
--- a/common/common.h
+++ b/common/common.h
@@ -144,40 +144,41 @@ struct gpt_params {
     uint32_t seed = LLAMA_DEFAULT_SEED; // RNG seed
 
     int32_t n_threads = cpu_get_num_math();
-    int32_t n_threads_draft = -1;
-    int32_t n_threads_batch = -1; // number of threads to use for batch processing (-1 = use n_threads)
-    int32_t n_threads_batch_draft = -1;
-    int32_t n_predict = -1; // new tokens to predict
-    int32_t n_ctx = 0; // context size
-    int32_t n_ctx_draft = 0; // context size for draft model
-    int32_t n_batch = 2048; // logical batch size for prompt processing (must be >=32 to use BLAS)
-    int32_t n_ubatch = 512; // physical batch size for prompt processing (must be >=32 to use BLAS)
-    int32_t n_keep = 0; // number of tokens to keep from initial prompt
-    int32_t n_draft = 16; // number of tokens to draft during speculative decoding
-    int32_t n_draft_min = 1; // minimum number of tokens to draft during speculative decoding
-    float p_draft_min = 0.8f; // minimum speculative decoding probability (greedy)
-    int32_t n_chunks = -1; // max number of chunks to process (-1 = unlimited)
-    int32_t n_parallel = 1; // number of parallel sequences to decode
-    int32_t n_sequences = 1; // number of sequences to decode
-    float p_split = 0.1f; // speculative decoding split probability
-    int32_t n_gpu_layers = -1; // number of layers to store in VRAM (-1 - use default)
-    int32_t n_gpu_layers_draft = -1; // number of layers to store in VRAM for the draft model (-1 - use default)
-    int32_t main_gpu = 0; // the GPU that is used for scratch and small tensors
-    int32_t max_gpu = 0; // max number of GPUs to use at a time for split mode "graph"
-    float tensor_split[128] = {0}; // how split tensors should be distributed across GPUs
-    int32_t grp_attn_n = 1; // group-attention factor
-    int32_t grp_attn_w = 512; // group-attention width
-    int32_t n_print = -1; // print token count every n tokens (-1 = disabled)
-    float rope_freq_base = 0.0f; // RoPE base frequency
-    float rope_freq_scale = 0.0f; // RoPE frequency scaling factor
-    float yarn_ext_factor = -1.0f; // YaRN extrapolation mix factor
-    float yarn_attn_factor = -1.0f; // YaRN magnitude scaling factor
-    float yarn_beta_fast = -1.0f; // YaRN low correction dim
-    float yarn_beta_slow = -1.0f; // YaRN high correction dim
-    int32_t yarn_orig_ctx = 0; // YaRN original context length
-    float defrag_thold = -1.0f; // KV cache defragmentation threshold
-    int32_t max_extra_alloc_MiB = 256; // extra VRAM per GPU the scheduler may allocate for more efficient compute graph evaluation
-    int32_t nrep = 1; // number of repetitions used in sweep bench
+    int32_t n_threads_draft = -1;
+    int32_t n_threads_batch = -1; // number of threads to use for batch processing (-1 = use n_threads)
+    int32_t n_threads_batch_draft = -1;
+    int32_t n_predict = -1; // new tokens to predict
+    int32_t n_ctx = 0; // context size
+    int32_t n_ctx_draft = 0; // context size for draft model
+    int32_t n_batch = 2048; // logical batch size for prompt processing (must be >=32 to use BLAS)
+    int32_t n_ubatch = 512; // physical batch size for prompt processing (must be >=32 to use BLAS)
+    int32_t n_keep = 0; // number of tokens to keep from initial prompt
+    int32_t n_draft = 16; // number of tokens to draft during speculative decoding
+    int32_t n_draft_min = 1; // minimum number of tokens to draft during speculative decoding
+    float p_draft_min = 0.8f; // minimum speculative decoding probability (greedy)
+    int32_t n_chunks = -1; // max number of chunks to process (-1 = unlimited)
+    int32_t n_parallel = 1; // number of parallel sequences to decode
+    int32_t n_sequences = 1; // number of sequences to decode
+    float p_split = 0.1f; // speculative decoding split probability
+    int32_t n_gpu_layers = -1; // number of layers to store in VRAM (-1 - use default)
+    int32_t n_gpu_layers_draft = -1; // number of layers to store in VRAM for the draft model (-1 - use default)
+    int32_t main_gpu = 0; // the GPU that is used for scratch and small tensors
+    int32_t max_gpu = 0; // max number of GPUs to use at a time for split mode "graph"
+    float tensor_split[128] = {0}; // how split tensors should be distributed across GPUs
+    int32_t grp_attn_n = 1; // group-attention factor
+    int32_t grp_attn_w = 512; // group-attention width
+    int32_t n_print = -1; // print token count every n tokens (-1 = disabled)
+    float rope_freq_base = 0.0f; // RoPE base frequency
+    float rope_freq_scale = 0.0f; // RoPE frequency scaling factor
+    float yarn_ext_factor = -1.0f; // YaRN extrapolation mix factor
+    float yarn_attn_factor = -1.0f; // YaRN magnitude scaling factor
+    float yarn_beta_fast = -1.0f; // YaRN low correction dim
+    float yarn_beta_slow = -1.0f; // YaRN high correction dim
+    int32_t yarn_orig_ctx = 0; // YaRN original context length
+    float defrag_thold = -1.0f; // KV cache defragmentation threshold
+    float ban_phrases_bias = -999.0f; // logit bias applied to ban phrases
+    int32_t max_extra_alloc_MiB = 256; // additional VRAM per GPU the scheduler may allocate for more efficient compute graph evaluation
+    int32_t nrep = 1; // number of repetitions used in sweep bench
 
     ggml_backend_sched_eval_callback cb_eval = nullptr;
     void * cb_eval_user_data = nullptr;
@@ -213,8 +214,12 @@ struct gpt_params {
     std::string cuda_params = ""; // comma separated list of cuda parameters key=value1,key2=value2
 
-    std::vector<std::string> in_files;   // all input files
-    std::vector<std::string> antiprompt; // strings upon which more user input is prompted (a.k.a. reverse prompts)
+    std::vector<std::string> in_files;    // all input files
+    std::vector<std::string> antiprompt;  // strings upon which more user input is prompted (a.k.a. reverse prompts)
+    std::vector<std::string> ban_phrases; // strings that are banned in generation
+    int32_t banned_n = 1;  // number of tokens of the banned phrase that are banned during rewind
+    int32_t n_buffer = 0;  // number of tokens to buffer for string banning
+
     std::vector<llama_model_kv_override> kv_overrides;
     std::vector<llama_model_tensor_buft_override> tensor_buft_overrides;
    std::vector<std::pair<int,int>> offload_policy;
@@ -431,6 +436,7 @@ std::pair<long, std::vector<char>> common_remote_get_content(const std::string&
 std::string string_join(const std::vector<std::string>& values, const std::string& separator);
 std::string string_strip(const std::string & str);
 std::string string_get_sortable_timestamp();
+std::string string_lower(const std::string & str);
 
 static bool string_starts_with(const std::string& str, const std::string& prefix) {
     // While we wait for C++20's std::string::starts_with...
diff --git a/examples/server/server-common.h b/examples/server/server-common.h
index 52d1e5b3..db6ff0b9 100644
--- a/examples/server/server-common.h
+++ b/examples/server/server-common.h
@@ -13,6 +13,7 @@
 #include
 #include
 #include
+#include <deque>
 
@@ -215,6 +216,8 @@ struct completion_token_output {
     static json probs_vector_to_json(const std::vector<completion_token_output>& probs, bool post_sampling_probs);
 };
 
+using completion_token_outputs = std::deque<completion_token_output>;
+
 // convert a vector of completion_token_output to json
 json probs_vector_to_json(const llama_context* ctx, const std::vector<completion_token_output>& probs);
 
diff --git a/examples/server/server-context.cpp b/examples/server/server-context.cpp
index 3c4ff874..66b38575 100644
--- a/examples/server/server-context.cpp
+++ b/examples/server/server-context.cpp
@@ -318,16 +318,20 @@ void server_slot::reset() {
     n_past = 0;
     n_past_prompt = 0;
     n_sent_text = 0;
-    drafted.clear();
     i_batch_dft.clear();
-    n_sent_token_probs = 0;
     infill = false;
     ga_i = 0;
     n_past_se = 0;
     chat_format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
+    logit_bias.clear();
+    token_buffer.clear();
+    rewind_count = 0;
+    n_buffer = 0;
+    rewind_status = false;
+    generated_token_probs.clear();
 
@@ -782,7 +786,7 @@ bool server_context::launch_slot_with_task(server_slot& slot, server_task& task)
     // Sampling parameter defaults are loaded from the global server context (but individual requests can still override them)
     llama_sampling_params default_sparams = params_base.sparams;
     auto& data = task.data;
-
+    const llama_vocab* vocab = llama_model_get_vocab(model);
     if (data.count("__oaicompat") != 0) {
         slot.oaicompat = true;
         slot.oaicompat_model = json_value(data, "model", std::string(DEFAULT_OAICOMPAT_MODEL));
@@ -1046,8 +1050,10 @@ bool server_context::launch_slot_with_task(server_slot& slot, server_task& task)
     {
         // apply logit bias
         const auto& logit_bias = data.find("logit_bias");
-        if (logit_bias != data.end() && logit_bias->is_array()) {
+        if (logit_bias != data.end() && (logit_bias->is_object() || logit_bias->is_array())) {
             slot.sparams.logit_bias.clear(); // only clear if user sets it
+        }
+        if (logit_bias != data.end() && logit_bias->is_array()) {
             const int n_vocab = llama_n_vocab(model);
             for (const auto& el : *logit_bias) {
                 // TODO: we may want to throw errors here, in case "el" is incorrect
@@ -1078,12 +1084,86 @@ bool server_context::launch_slot_with_task(server_slot& slot, server_task& task)
                 }
             }
         }
+        else if (logit_bias != data.end() && logit_bias->is_object()) {
+            const int n_vocab = llama_vocab_n_tokens(vocab);
+            for (const auto& el : logit_bias->items()) {
+                float bias;
+                const auto& key = el.key();
+                const auto& value = el.value();
+                if (value.is_number()) {
+                    bias = value.get<float>();
+                }
+                else if (value.is_boolean() && !value.get<bool>()) {
+                    bias = -INFINITY;
+                }
+                else {
+                    continue;
+                }
+
+                char* end;
+                llama_token tok = strtol(key.c_str(), &end, 10);
+                if (*end == 0) {
+                    if (tok >= 0 && tok < n_vocab) {
+                        slot.sparams.logit_bias[tok] = bias;
+                    }
+                }
+                else {
+                    auto toks = common_tokenize(model, key, false);
+                    for (auto tok : toks) {
+                        slot.sparams.logit_bias[tok] = bias;
+                    }
+                }
+            }
+        }
 
         if (json_value(data, "ignore_eos", false) && has_eos_token) {
             slot.sparams.logit_bias[llama_token_eos(model)] = -INFINITY;
         }
     }
 
+    {
+        // ban string
+        const auto& banned_strings = data.find("banned_strings");
+        if (banned_strings != data.end() && banned_strings->is_array()) {
+            slot.ban_phrases.clear();
+            for (const auto& val : data["banned_strings"]) {
+                if (val.is_string()) {
+                    std::string s = val.get<std::string>();
+                    if (!s.empty()) {
+                        s = string_lower(s);
+                        auto ban_tokens = common_tokenize(llama_get_model(ctx), s, false, true);
+                        if (ban_tokens.size() > slot.n_buffer) {
+                            slot.n_buffer = ban_tokens.size();
+                        }
+                        slot.ban_phrases.push_back(s);
+                    }
+                }
+            }
+            slot.n_buffer = slot.n_buffer + 3; // extra buffer in case
+            std::sort(slot.ban_phrases.begin(), slot.ban_phrases.end(), [](const std::string& a, const std::string& b) {
+                return a.length() > b.length();
+            });
+        }
+        else if (params_base.ban_phrases.size() > 0 && params_base.n_buffer == 0) {
+            slot.ban_phrases.clear();
+            for (const auto & val : params_base.ban_phrases) {
+                if (!val.empty()) {
+                    std::string s = string_lower(val);
+                    auto ban_tokens = common_tokenize(llama_get_model(ctx), s, false, true);
+                    if (ban_tokens.size() > slot.n_buffer) {
+                        slot.n_buffer = ban_tokens.size();
+                    }
+                    slot.ban_phrases.push_back(s);
+                }
+            }
+            params_base.n_buffer = slot.n_buffer + 3;
+            slot.n_buffer = slot.n_buffer + 3; // extra buffer in case
+        }
+        slot.logit_bias = slot.sparams.logit_bias; // keep a copy to restore
+        slot.ban_phrases_bias = json_value(data, "banned_bias", params_base.ban_phrases_bias);
+        slot.banned_n = json_value(data, "banned_n", params_base.banned_n);
+    }
+
     {
         const auto& stop = data.find("stop");
         if (stop != data.end() && stop->is_array()) {
@@ -1196,6 +1276,28 @@ bool server_context::system_prompt_set(const std::string& sys_prompt) {
     return true;
 }
 
+// keep in sync with process_token(completion_token_output& result, server_slot& slot)
+bool server_context::has_next_token(const completion_token_output& result, server_slot& slot) {
+    bool next = true;
+    //std::string generate_text = slot.generated_text + result.text_to_send;
+    //bool incomplete = validate_utf8(generate_text) < generate_text.size();
+    //if (incomplete) {
+    //    next = true;
+    //}
+    if (slot.n_decoded > 0 && !slot.has_budget(params_base)) {
+        next = false;
+    }
+    if (llama_token_is_eog(model, result.tok)) {
+        next = false;
+    }
+    auto n_ctx_train = llama_n_ctx_train(model);
+    if (slot.params.n_predict < 1 && slot.n_predict < 1 && slot.ga_n == 1
+            && slot.n_prompt_tokens + slot.n_decoded >= n_ctx_train) {
+        next = false;
+    }
+    return next;
+}
+
 bool server_context::process_token(completion_token_output& result, server_slot& slot) {
     // remember which tokens were sampled - used for repetition penalties during sampling
     const std::string token_str = result.text_to_send;
@@ -2523,7 +2625,6 @@ void server_context::batch_pending_prompt(const int32_t n_ubatch, const int32_t
             if (slot.n_past_prompt == slot.n_prompt_tokens) {
                 slot.state = SLOT_STATE_PROCESSING;
                 slot.command = SLOT_COMMAND_NONE;
-                GGML_ASSERT(batch.n_tokens > 0);
                 GGML_ASSERT((size_t)slot.n_prompt_tokens == slot.prompt_tokens.size());
 
                 common_sampler_reset(llama_get_model_vocab(model), slot.ctx_sampling);
@@ -2651,6 +2752,124 @@ bool server_context::accept_special_token(const server_slot& slot, const llama_
     return params_base.special || slot.sparams.preserved_tokens.find(token) != slot.sparams.preserved_tokens.end();
 };
 
+
+void server_context::send_token_results(completion_token_outputs& results, server_slot& slot, int32_t n) {
+    int count = 0;
+    for (auto& it : results) {
+        bool has_next = process_token(it, slot);
+        count++;
+        if (!has_next) {
+            slot.release();
+            slot.print_timings();
+            send_final_response(slot);
+            metrics.on_prediction(slot);
+            break;
+        }
+        if (n > 0 && count >= n) {
+            break;
+        }
+    }
+
+    if (count > 0) {
+        slot.sampled = results[results.size()-1].tok;
+        results.erase(results.begin(), results.begin() + count);
+    }
+
+}
+
+inline int32_t check_ban_phrase(const server_slot& slot) {
+    bool found = false;
+    size_t n = slot.token_buffer.size();
+    size_t start;
+    int32_t n_rewind = 0;
+    std::string string_buffer;
+    llama_tokens tokens;
+    for (auto& it : slot.token_buffer) {
+        string_buffer = string_buffer + it.text_to_send;
+        tokens.push_back(it.tok);
+    }
+    string_buffer = string_lower(string_buffer);
+    for (auto it : slot.ban_phrases) {
+        start = string_buffer.find(it);
+        // has been sorted from longest to shortest
+        if (start != std::string::npos) {
+            found = true;
+            break;
+        }
+    }
+    if (found) {
+        std::vector unused;
+        LLAMA_LOG_DEBUG("Banned string detected: %s\n", string_buffer.substr(start).c_str());
+        n = find_n_tokens_from_string(slot.ctx, tokens, start, 0, unused);
+        n_rewind = (int32_t) slot.token_buffer.size() - (int32_t) n;
+    }
+    return n_rewind;
+}
+
+inline void rewind_context(server_slot& slot, int32_t n_rewind) {
+    slot.rewind_count++;
+    int32_t n_keep_rewind = (int32_t)slot.token_buffer.size() - n_rewind;
+    std::set<llama_token> tokens;
+    // ban all tokens for better coherence
+    if (slot.banned_n != 0) {
+        int32_t n = 0;
+        for (auto result = slot.token_buffer.begin() + n_keep_rewind; result != slot.token_buffer.end(); result++)
+        {
+            // apply the bias once per distinct token in the rewound span
+            if (!tokens.contains(result->tok)) {
+                slot.ctx_sampling->params.logit_bias[result->tok] += slot.ban_phrases_bias;
+                tokens.insert(result->tok);
+            }
+            n++;
+            if (slot.banned_n > 0 && n == slot.banned_n) {
+                break;
+            }
+        }
+    }
+
+    slot.token_buffer.resize(n_keep_rewind);
+    size_t n_keep = slot.cache_tokens.size() - n_rewind;
+    slot.sampled = slot.cache_tokens[n_keep];
+    slot.cache_tokens.keep_first(n_keep);
+}
+
+void server_context::buffer_and_check_string_ban(server_slot & slot, completion_token_output & result) {
+    slot.token_buffer.push_back(result);
+
+    bool next_token = has_next_token(result, slot);
+    bool send_result = slot.token_buffer.size() >= slot.n_buffer || !next_token;
+    int32_t n_rewind = 0;
+    // don't restore if last time was also rewind
+    if (!slot.rewind_status) {
+        slot.ctx_sampling->params.logit_bias = slot.logit_bias; // restore logit bias
+    }
+    if (slot.ban_phrases.size() > 0) {
+        n_rewind = check_ban_phrase(slot);
+    }
+    // if found string in the ban
+    if (n_rewind > 0 && slot.rewind_count <= 2 * slot.ban_phrases.size()) {
+        rewind_context(slot, n_rewind);
+        slot.rewind_status = true;
+    }
+    else if (send_result) {
+        slot.rewind_status = false;
+        slot.rewind_count = 0;
+        if (!next_token) {
+            // send all remaining tokens in the buffer
+            send_token_results(slot.token_buffer, slot);
+        }
+        else {
+            // send 1 token
+            send_token_results(slot.token_buffer, slot, 1);
+        }
+    }
+    else {
+        // buffer the result
+        slot.sampled = result.tok; // for common batch add
+    }
+}
+
 void server_context::process_batch_tokens(int32_t & n_batch) {
     for (int32_t i = 0; i < batch.n_tokens; i += n_batch) {
         const int32_t n_tokens = std::min(n_batch, batch.n_tokens - i);
@@ -2668,7 +2887,6 @@ void server_context::process_batch_tokens(int32_t & n_batch) {
        };
 
        const int ret = llama_decode(ctx, batch_view);
-
        if (ret != 0) {
            if (n_batch == 1 || ret < 0) {
                int user_cancel = -3;
@@ -2721,17 +2939,17 @@ void server_context::process_batch_tokens(int32_t & n_batch) {
                continue; // continue loop of slots
            }
 
-            completion_token_output result;
            if (slot.i_batch_dft.size() > 0) {
                continue; // sample using speculative decoding
            }
+
+            completion_token_output result;
            const int tok_idx = slot.i_batch - i;
 
            const llama_token id = common_sampler_sample(slot.ctx_sampling, ctx, NULL, tok_idx);
            common_sampler_accept(slot.ctx_sampling, ctx, id, true);
 
            slot.n_decoded += 1;
-
            const int64_t t_current = ggml_time_us();
 
            if (slot.n_decoded == 1) {
@@ -2745,16 +2963,17 @@ void server_context::process_batch_tokens(int32_t & n_batch) {
            result.tok = id;
            result.prob = 1.0f; // TODO: set it here instead of doing inside populate_token_probs
            result.text_to_send = common_token_to_piece(ctx, result.tok, accept_special_token(slot, result.tok));
-
            if (slot.sparams.n_probs > 0) {
                populate_token_probs(slot, result, slot.params.post_sampling_probs, params_base.special, tok_idx);
            }
-            if (!process_token(result, slot)) {
-                slot.release();
-                slot.print_timings();
-                send_final_response(slot);
-                metrics.on_prediction(slot);
+
+            if (slot.n_buffer == 0) {
+                slot.token_buffer = { result };
+                send_token_results(slot.token_buffer, slot);
+            } else {
+                // buffer the result and check string ban.
+                // if banned, we need to go back, apply logit bias and regenerate
+                buffer_and_check_string_ban(slot, result);
            }
 
            slot.i_batch = -1;
@@ -2794,7 +3013,7 @@ void server_context::update_slots() {
    common_batch_clear(batch);
 
    // frist, add sampled tokens from any ongoing sequences
-    add_sampled_tokens();
+    add_sampled_tokens(); // Prepare batch for inference
 
    // process in chunks of params.n_batch
    int32_t n_batch = llama_n_batch(ctx);
@@ -2806,7 +3025,7 @@ void server_context::update_slots() {
    int32_t batch_type = batch.n_tokens > 0 ? 0 : -1;
 
    // next, batch any pending prompts without exceeding n_batch
-    batch_pending_prompt(n_ubatch, n_batch, batch_type);
+    batch_pending_prompt(n_ubatch, n_batch, batch_type); // Prepare batch for prompt processing
 
    if (batch.n_tokens == 0) {
        LOG_VERBOSE("no tokens to decode", {});
@@ -2821,7 +3040,7 @@ void server_context::update_slots() {
    llama_set_embeddings(ctx, batch_type == 1);
 
    // process the created batch of tokens
-    process_batch_tokens(n_batch);
+    process_batch_tokens(n_batch); // Decode the batch
 
    LOG_VERBOSE("run slots completed", {});
 }
diff --git a/examples/server/server-context.h b/examples/server/server-context.h
index 34493565..fc2dc029 100644
--- a/examples/server/server-context.h
+++ b/examples/server/server-context.h
@@ -83,6 +83,16 @@ struct server_slot {
    std::string stopping_word;
    stop_type stop;
 
+    // for context rewind / token buffering
+    int32_t n_buffer = 0;
+    int32_t rewind_count = 0;
+    bool rewind_status = false;
+    std::unordered_map<llama_token, float> logit_bias;
+    std::vector<std::string> ban_phrases;
+    completion_token_outputs token_buffer;
+    float ban_phrases_bias = 0;
+    int32_t banned_n = 1;
+
    server_prompt server_cached_prompt;
 
    void prompt_save(server_prompt_cache& prompt_cache) const;
@@ -315,5 +325,11 @@ struct server_context {
 
    bool accept_special_token(const server_slot& slot, const llama_token token);
 
+    bool has_next_token(const completion_token_output& result, server_slot& slot);
+
+    void send_token_results(completion_token_outputs& results, server_slot& slot, int32_t n = 0);
+
+    void buffer_and_check_string_ban(server_slot& slot, completion_token_output& result);
+
    json model_meta() const;
 };
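
Usage sketch: the field names below ("banned_strings", "banned_bias", "banned_n") are exactly the ones read in launch_slot_with_task above; the endpoint path, prompt text, and concrete values are illustrative assumptions only. Matching is case-insensitive, since both the buffered text and the phrases go through string_lower().

  POST /completion
  {
    "prompt": "Continue the story:",
    "banned_strings": ["as an AI language model", "I cannot help with"],
    "banned_bias": -999.0,
    "banned_n": -1
  }

Server-wide defaults can be supplied with --banned-string-file (one banned string per line) and --banned-n; per-request values override them via json_value, the same way other sampling parameters fall back to params_base.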