mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-03-02 18:10:02 +00:00
server: add checkpoint tolerance and fix grammar_trigger init (#1346)
Co-authored-by: firecoperana <firecoperana>
This commit is contained in:
@@ -1024,6 +1024,7 @@ bool server_context::launch_slot_with_task(server_slot& slot, server_task& task)
|
||||
}
|
||||
else if (penalty_prompt->is_array()) {
|
||||
const auto n_tokens = penalty_prompt->size();
|
||||
slot.sparams.penalty_prompt_tokens.clear();
|
||||
slot.sparams.penalty_prompt_tokens.reserve(n_tokens + std::max(0, slot.params.n_predict));
|
||||
|
||||
const int n_vocab = llama_n_vocab(model);
|
||||
@@ -1067,6 +1068,7 @@ bool server_context::launch_slot_with_task(server_slot& slot, server_task& task)
|
||||
|
||||
const auto preserved_tokens = data.find("preserved_tokens");
|
||||
if (preserved_tokens != data.end()) {
|
||||
slot.sparams.preserved_tokens.clear();
|
||||
for (const auto& t : *preserved_tokens) {
|
||||
auto ids = common_tokenize(model, t.get<std::string>(), /* add_special= */ false, /* parse_special= */ true);
|
||||
if (ids.size() == 1) {
|
||||
@@ -1081,6 +1083,7 @@ bool server_context::launch_slot_with_task(server_slot& slot, server_task& task)
|
||||
}
|
||||
const auto grammar_triggers = data.find("grammar_triggers");
|
||||
if (grammar_triggers != data.end()) {
|
||||
slot.sparams.grammar_triggers.clear();
|
||||
for (const auto& t : *grammar_triggers) {
|
||||
server_grammar_trigger ct(t);
|
||||
if (ct.value.type == COMMON_GRAMMAR_TRIGGER_TYPE_WORD) {
|
||||
@@ -3058,6 +3061,12 @@ void server_context::batch_pending_prompt(const int32_t n_ubatch, const int32_t
|
||||
slot_npast++;
|
||||
slot.n_past_prompt++;
|
||||
slot.n_past++;
|
||||
slot.do_checkpoint = false;
|
||||
if (params_base.do_checkpoint && slot.n_prompt_tokens - slot.n_past_prompt == params_base.ctx_checkpoints_tolerance) {
|
||||
slot.do_checkpoint = true;
|
||||
break;
|
||||
}
|
||||
|
||||
}
|
||||
LOG_VERBOSE("prompt processing progress", {
|
||||
{"id_slot", slot.id},
|
||||
@@ -3413,8 +3422,13 @@ void server_context::process_batch_tokens(int32_t & n_batch) {
|
||||
|
||||
for (auto& slot : slots) {
|
||||
if (slot.state != SLOT_STATE_PROCESSING || slot.i_batch < (int)i || slot.i_batch >= (int)(i + n_tokens)) {
|
||||
// save checkpoint during prompt processing
|
||||
if (slot.command == SLOT_COMMAND_LOAD_PROMPT) {
|
||||
create_checkpoint_at_interval(slot, params_base);
|
||||
if (slot.do_checkpoint) {
|
||||
create_checkpoint(slot);
|
||||
} else {
|
||||
create_checkpoint_at_interval(slot, params_base);
|
||||
}
|
||||
}
|
||||
continue; // continue loop of slots
|
||||
}
|
||||
@@ -3459,12 +3473,13 @@ void server_context::process_batch_tokens(int32_t & n_batch) {
|
||||
slot.t_start_generation = ggml_time_us();
|
||||
slot.t_prompt_processing = (slot.t_start_generation - slot.t_start_process_prompt) / 1e3;
|
||||
metrics.on_prompt_eval(slot);
|
||||
if (params_base.do_checkpoint) {
|
||||
// create checkpoint after prompt processing ends
|
||||
if (params_base.ctx_checkpoints_tolerance<=0 && params_base.do_checkpoint) {
|
||||
create_checkpoint(slot);
|
||||
}
|
||||
}
|
||||
|
||||
// save checkpoint during generation
|
||||
// create checkpoint during generation
|
||||
if (slot.n_decoded > 1) {
|
||||
create_checkpoint_at_interval(slot, params_base);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user