mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-03-02 10:00:07 +00:00
server: add checkpoint tolerance and fix grammar_trigger init (#1346)
Co-authored-by: firecoperana <firecoperana>
This commit is contained in:
@@ -2051,6 +2051,11 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
|
||||
params.ctx_checkpoints_interval = std::stoi(argv[i]);
|
||||
return true;
|
||||
}
|
||||
if (arg == "--ctx-checkpoints-tolerance") {
|
||||
CHECK_ARG
|
||||
params.ctx_checkpoints_tolerance = std::stoi(argv[i]);
|
||||
return true;
|
||||
}
|
||||
if (arg == "-cram" || arg == "--cache-ram") {
|
||||
CHECK_ARG
|
||||
params.cache_ram_mib = std::stoi(argv[i]);
|
||||
@@ -2248,6 +2253,7 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param
|
||||
|
||||
options.push_back({ "*", "--ctx-checkpoints N", "max number of context checkpoints to create per slot (default: %d)",params.ctx_checkpoints_n});
|
||||
options.push_back({ "*", "--ctx-checkpoints-interval N", "minimum number of tokens between each context checkpoint. (default: %d, <=0 disable)",params.ctx_checkpoints_interval});
|
||||
options.push_back({ "*", "--ctx-checkpoints-tolerance N", "the number of tokens before the full prompt to create the checkpoint. (default: %d, <=0 disable)",params.ctx_checkpoints_tolerance});
|
||||
options.push_back({ "*", "-cram, --cache-ram N", "set the maximum cache size in MiB (default: %d, -1 - no limit, 0 - disable)",params.cache_ram_mib });
|
||||
options.push_back({ "*", "-crs, --cache-ram-similarity N", "max of similarity of prompt tokens to cache tokens that triggers prompt cache (default: %.2f).",params.cache_ram_similarity });
|
||||
options.push_back({ "*", "-cram-n-min --cache-ram-n-min N", "minimum number of the cached tokens that triggers prompt cache (default: %d).", params.cache_ram_n_min });
|
||||
|
||||
@@ -281,7 +281,6 @@ struct gpt_params {
|
||||
int32_t banned_n = 1; // number of tokens that are banned in the phrase
|
||||
size_t n_buffer = 0; // number of token buffers for string ban
|
||||
bool can_ban_phrases = true; // whether to ban strings
|
||||
bool do_checkpoint = false; // do checkpoint for recurrent models only
|
||||
|
||||
std::vector<llama_model_kv_override> kv_overrides;
|
||||
std::vector<llama_model_tensor_buft_override> tensor_buft_overrides;
|
||||
@@ -420,8 +419,10 @@ struct gpt_params {
|
||||
|
||||
float slot_prompt_similarity = 0.1f;
|
||||
|
||||
bool do_checkpoint = false; // do checkpoint for recurrent models only
|
||||
int32_t ctx_checkpoints_n = 8; // max number of context checkpoints per slot
|
||||
int32_t ctx_checkpoints_interval = 512; // minimum number of tokens between each context checkpoints
|
||||
int32_t ctx_checkpoints_tolerance = 5; // the number of tokens before the full prompt to create the checkpoint
|
||||
int32_t cache_ram_mib = 8192; // -1 = no limit, 0 - disable, 1 = 1 MiB, etc.
|
||||
int32_t cache_ram_n_min = 0; // min number of tokens required to save in the ram
|
||||
float cache_ram_similarity = 0.5f; // similarity of tokens to cached tokens
|
||||
|
||||
@@ -1024,6 +1024,7 @@ bool server_context::launch_slot_with_task(server_slot& slot, server_task& task)
|
||||
}
|
||||
else if (penalty_prompt->is_array()) {
|
||||
const auto n_tokens = penalty_prompt->size();
|
||||
slot.sparams.penalty_prompt_tokens.clear();
|
||||
slot.sparams.penalty_prompt_tokens.reserve(n_tokens + std::max(0, slot.params.n_predict));
|
||||
|
||||
const int n_vocab = llama_n_vocab(model);
|
||||
@@ -1067,6 +1068,7 @@ bool server_context::launch_slot_with_task(server_slot& slot, server_task& task)
|
||||
|
||||
const auto preserved_tokens = data.find("preserved_tokens");
|
||||
if (preserved_tokens != data.end()) {
|
||||
slot.sparams.preserved_tokens.clear();
|
||||
for (const auto& t : *preserved_tokens) {
|
||||
auto ids = common_tokenize(model, t.get<std::string>(), /* add_special= */ false, /* parse_special= */ true);
|
||||
if (ids.size() == 1) {
|
||||
@@ -1081,6 +1083,7 @@ bool server_context::launch_slot_with_task(server_slot& slot, server_task& task)
|
||||
}
|
||||
const auto grammar_triggers = data.find("grammar_triggers");
|
||||
if (grammar_triggers != data.end()) {
|
||||
slot.sparams.grammar_triggers.clear();
|
||||
for (const auto& t : *grammar_triggers) {
|
||||
server_grammar_trigger ct(t);
|
||||
if (ct.value.type == COMMON_GRAMMAR_TRIGGER_TYPE_WORD) {
|
||||
@@ -3058,6 +3061,12 @@ void server_context::batch_pending_prompt(const int32_t n_ubatch, const int32_t
|
||||
slot_npast++;
|
||||
slot.n_past_prompt++;
|
||||
slot.n_past++;
|
||||
slot.do_checkpoint = false;
|
||||
if (params_base.do_checkpoint && slot.n_prompt_tokens - slot.n_past_prompt == params_base.ctx_checkpoints_tolerance) {
|
||||
slot.do_checkpoint = true;
|
||||
break;
|
||||
}
|
||||
|
||||
}
|
||||
LOG_VERBOSE("prompt processing progress", {
|
||||
{"id_slot", slot.id},
|
||||
@@ -3413,8 +3422,13 @@ void server_context::process_batch_tokens(int32_t & n_batch) {
|
||||
|
||||
for (auto& slot : slots) {
|
||||
if (slot.state != SLOT_STATE_PROCESSING || slot.i_batch < (int)i || slot.i_batch >= (int)(i + n_tokens)) {
|
||||
// save checkpoint during prompt processing
|
||||
if (slot.command == SLOT_COMMAND_LOAD_PROMPT) {
|
||||
create_checkpoint_at_interval(slot, params_base);
|
||||
if (slot.do_checkpoint) {
|
||||
create_checkpoint(slot);
|
||||
} else {
|
||||
create_checkpoint_at_interval(slot, params_base);
|
||||
}
|
||||
}
|
||||
continue; // continue loop of slots
|
||||
}
|
||||
@@ -3459,12 +3473,13 @@ void server_context::process_batch_tokens(int32_t & n_batch) {
|
||||
slot.t_start_generation = ggml_time_us();
|
||||
slot.t_prompt_processing = (slot.t_start_generation - slot.t_start_process_prompt) / 1e3;
|
||||
metrics.on_prompt_eval(slot);
|
||||
if (params_base.do_checkpoint) {
|
||||
// create checkpoint after prompt processing ends
|
||||
if (params_base.ctx_checkpoints_tolerance<=0 && params_base.do_checkpoint) {
|
||||
create_checkpoint(slot);
|
||||
}
|
||||
}
|
||||
|
||||
// save checkpoint during generation
|
||||
// create checkpoint during generation
|
||||
if (slot.n_decoded > 1) {
|
||||
create_checkpoint_at_interval(slot, params_base);
|
||||
}
|
||||
|
||||
@@ -105,6 +105,7 @@ struct server_slot {
|
||||
void prompt_load(server_prompt_cache& prompt_cache, const server_tokens& tokens);
|
||||
|
||||
size_t checkpoint_pos = 0;
|
||||
bool do_checkpoint = false;
|
||||
|
||||
// sampling
|
||||
llama_token sampled; // in speculative mode, this is the last accepted token
|
||||
|
||||
Reference in New Issue
Block a user