server: add checkpoint tolerance and fix grammar_trigger init (#1346)

Co-authored-by: firecoperana <firecoperana>
This commit is contained in:
firecoperana
2026-03-02 00:45:32 -06:00
committed by GitHub
parent a568e12c8f
commit 8f9e19d57c
4 changed files with 27 additions and 4 deletions

View File

@@ -1024,6 +1024,7 @@ bool server_context::launch_slot_with_task(server_slot& slot, server_task& task)
}
else if (penalty_prompt->is_array()) {
const auto n_tokens = penalty_prompt->size();
slot.sparams.penalty_prompt_tokens.clear();
slot.sparams.penalty_prompt_tokens.reserve(n_tokens + std::max(0, slot.params.n_predict));
const int n_vocab = llama_n_vocab(model);
@@ -1067,6 +1068,7 @@ bool server_context::launch_slot_with_task(server_slot& slot, server_task& task)
const auto preserved_tokens = data.find("preserved_tokens");
if (preserved_tokens != data.end()) {
slot.sparams.preserved_tokens.clear();
for (const auto& t : *preserved_tokens) {
auto ids = common_tokenize(model, t.get<std::string>(), /* add_special= */ false, /* parse_special= */ true);
if (ids.size() == 1) {
@@ -1081,6 +1083,7 @@ bool server_context::launch_slot_with_task(server_slot& slot, server_task& task)
}
const auto grammar_triggers = data.find("grammar_triggers");
if (grammar_triggers != data.end()) {
slot.sparams.grammar_triggers.clear();
for (const auto& t : *grammar_triggers) {
server_grammar_trigger ct(t);
if (ct.value.type == COMMON_GRAMMAR_TRIGGER_TYPE_WORD) {
@@ -3058,6 +3061,12 @@ void server_context::batch_pending_prompt(const int32_t n_ubatch, const int32_t
slot_npast++;
slot.n_past_prompt++;
slot.n_past++;
slot.do_checkpoint = false;
if (params_base.do_checkpoint && slot.n_prompt_tokens - slot.n_past_prompt == params_base.ctx_checkpoints_tolerance) {
slot.do_checkpoint = true;
break;
}
}
LOG_VERBOSE("prompt processing progress", {
{"id_slot", slot.id},
@@ -3413,8 +3422,13 @@ void server_context::process_batch_tokens(int32_t & n_batch) {
for (auto& slot : slots) {
if (slot.state != SLOT_STATE_PROCESSING || slot.i_batch < (int)i || slot.i_batch >= (int)(i + n_tokens)) {
// save checkpoint during prompt processing
if (slot.command == SLOT_COMMAND_LOAD_PROMPT) {
create_checkpoint_at_interval(slot, params_base);
if (slot.do_checkpoint) {
create_checkpoint(slot);
} else {
create_checkpoint_at_interval(slot, params_base);
}
}
continue; // continue loop of slots
}
@@ -3459,12 +3473,13 @@ void server_context::process_batch_tokens(int32_t & n_batch) {
slot.t_start_generation = ggml_time_us();
slot.t_prompt_processing = (slot.t_start_generation - slot.t_start_process_prompt) / 1e3;
metrics.on_prompt_eval(slot);
if (params_base.do_checkpoint) {
// create checkpoint after prompt processing ends
if (params_base.ctx_checkpoints_tolerance<=0 && params_base.do_checkpoint) {
create_checkpoint(slot);
}
}
// save checkpoint during generation
// create checkpoint during generation
if (slot.n_decoded > 1) {
create_checkpoint_at_interval(slot, params_base);
}