server: add checkpoint tolerance and fix grammar_trigger init (#1346)

Co-authored-by: firecoperana <firecoperana>
This commit is contained in:
firecoperana
2026-03-02 00:45:32 -06:00
committed by GitHub
parent a568e12c8f
commit 8f9e19d57c
4 changed files with 27 additions and 4 deletions

View File

@@ -2051,6 +2051,11 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
params.ctx_checkpoints_interval = std::stoi(argv[i]);
return true;
}
if (arg == "--ctx-checkpoints-tolerance") {
CHECK_ARG
params.ctx_checkpoints_tolerance = std::stoi(argv[i]);
return true;
}
if (arg == "-cram" || arg == "--cache-ram") {
CHECK_ARG
params.cache_ram_mib = std::stoi(argv[i]);
@@ -2248,6 +2253,7 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param
options.push_back({ "*", "--ctx-checkpoints N", "max number of context checkpoints to create per slot (default: %d)",params.ctx_checkpoints_n});
options.push_back({ "*", "--ctx-checkpoints-interval N", "minimum number of tokens between each context checkpoint. (default: %d, <=0 disable)",params.ctx_checkpoints_interval});
options.push_back({ "*", "--ctx-checkpoints-tolerance N", "the number of tokens before the full prompt to create the checkpoint. (default: %d, <=0 disable)",params.ctx_checkpoints_tolerance});
options.push_back({ "*", "-cram, --cache-ram N", "set the maximum cache size in MiB (default: %d, -1 - no limit, 0 - disable)",params.cache_ram_mib });
options.push_back({ "*", "-crs, --cache-ram-similarity N", "max of similarity of prompt tokens to cache tokens that triggers prompt cache (default: %.2f).",params.cache_ram_similarity });
options.push_back({ "*", "-cram-n-min --cache-ram-n-min N", "minimum number of the cached tokens that triggers prompt cache (default: %d).", params.cache_ram_n_min });

View File

@@ -281,7 +281,6 @@ struct gpt_params {
int32_t banned_n = 1; // number of tokens that are banned in the phrase
size_t n_buffer = 0; // number of token buffers for string ban
bool can_ban_phrases = true; // whether to ban strings
bool do_checkpoint = false; // do checkpoint for recurrent models only
std::vector<llama_model_kv_override> kv_overrides;
std::vector<llama_model_tensor_buft_override> tensor_buft_overrides;
@@ -420,8 +419,10 @@ struct gpt_params {
float slot_prompt_similarity = 0.1f;
bool do_checkpoint = false; // do checkpoint for recurrent models only
int32_t ctx_checkpoints_n = 8; // max number of context checkpoints per slot
int32_t ctx_checkpoints_interval = 512; // minimum number of tokens between each context checkpoints
int32_t ctx_checkpoints_tolerance = 5; // the number of tokens before the full prompt to create the checkpoint
int32_t cache_ram_mib = 8192; // -1 = no limit, 0 - disable, 1 = 1 MiB, etc.
int32_t cache_ram_n_min = 0; // min number of tokens required to save in the ram
float cache_ram_similarity = 0.5f; // similarity of tokens to cached tokens

View File

@@ -1024,6 +1024,7 @@ bool server_context::launch_slot_with_task(server_slot& slot, server_task& task)
}
else if (penalty_prompt->is_array()) {
const auto n_tokens = penalty_prompt->size();
slot.sparams.penalty_prompt_tokens.clear();
slot.sparams.penalty_prompt_tokens.reserve(n_tokens + std::max(0, slot.params.n_predict));
const int n_vocab = llama_n_vocab(model);
@@ -1067,6 +1068,7 @@ bool server_context::launch_slot_with_task(server_slot& slot, server_task& task)
const auto preserved_tokens = data.find("preserved_tokens");
if (preserved_tokens != data.end()) {
slot.sparams.preserved_tokens.clear();
for (const auto& t : *preserved_tokens) {
auto ids = common_tokenize(model, t.get<std::string>(), /* add_special= */ false, /* parse_special= */ true);
if (ids.size() == 1) {
@@ -1081,6 +1083,7 @@ bool server_context::launch_slot_with_task(server_slot& slot, server_task& task)
}
const auto grammar_triggers = data.find("grammar_triggers");
if (grammar_triggers != data.end()) {
slot.sparams.grammar_triggers.clear();
for (const auto& t : *grammar_triggers) {
server_grammar_trigger ct(t);
if (ct.value.type == COMMON_GRAMMAR_TRIGGER_TYPE_WORD) {
@@ -3058,6 +3061,12 @@ void server_context::batch_pending_prompt(const int32_t n_ubatch, const int32_t
slot_npast++;
slot.n_past_prompt++;
slot.n_past++;
slot.do_checkpoint = false;
if (params_base.do_checkpoint && slot.n_prompt_tokens - slot.n_past_prompt == params_base.ctx_checkpoints_tolerance) {
slot.do_checkpoint = true;
break;
}
}
LOG_VERBOSE("prompt processing progress", {
{"id_slot", slot.id},
@@ -3413,8 +3422,13 @@ void server_context::process_batch_tokens(int32_t & n_batch) {
for (auto& slot : slots) {
if (slot.state != SLOT_STATE_PROCESSING || slot.i_batch < (int)i || slot.i_batch >= (int)(i + n_tokens)) {
// save checkpoint during prompt processing
if (slot.command == SLOT_COMMAND_LOAD_PROMPT) {
create_checkpoint_at_interval(slot, params_base);
if (slot.do_checkpoint) {
create_checkpoint(slot);
} else {
create_checkpoint_at_interval(slot, params_base);
}
}
continue; // continue loop of slots
}
@@ -3459,12 +3473,13 @@ void server_context::process_batch_tokens(int32_t & n_batch) {
slot.t_start_generation = ggml_time_us();
slot.t_prompt_processing = (slot.t_start_generation - slot.t_start_process_prompt) / 1e3;
metrics.on_prompt_eval(slot);
if (params_base.do_checkpoint) {
// create checkpoint after prompt processing ends
if (params_base.ctx_checkpoints_tolerance<=0 && params_base.do_checkpoint) {
create_checkpoint(slot);
}
}
// save checkpoint during generation
// create checkpoint during generation
if (slot.n_decoded > 1) {
create_checkpoint_at_interval(slot, params_base);
}

View File

@@ -105,6 +105,7 @@ struct server_slot {
void prompt_load(server_prompt_cache& prompt_cache, const server_tokens& tokens);
size_t checkpoint_pos = 0;
bool do_checkpoint = false;
// sampling
llama_token sampled; // in speculative mode, this is the last accepted token