diff --git a/common/common.cpp b/common/common.cpp index f61cb93b..08249ab7 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -2051,6 +2051,11 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa params.ctx_checkpoints_interval = std::stoi(argv[i]); return true; } + if (arg == "--ctx-checkpoints-tolerance") { + CHECK_ARG + params.ctx_checkpoints_tolerance = std::stoi(argv[i]); + return true; + } if (arg == "-cram" || arg == "--cache-ram") { CHECK_ARG params.cache_ram_mib = std::stoi(argv[i]); @@ -2248,6 +2253,7 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param options.push_back({ "*", "--ctx-checkpoints N", "max number of context checkpoints to create per slot (default: %d)",params.ctx_checkpoints_n}); options.push_back({ "*", "--ctx-checkpoints-interval N", "minimum number of tokens between each context checkpoint. (default: %d, <=0 disable)",params.ctx_checkpoints_interval}); + options.push_back({ "*", "--ctx-checkpoints-tolerance N", "the number of tokens before the full prompt to create the checkpoint. (default: %d, <=0 disable)",params.ctx_checkpoints_tolerance}); options.push_back({ "*", "-cram, --cache-ram N", "set the maximum cache size in MiB (default: %d, -1 - no limit, 0 - disable)",params.cache_ram_mib }); options.push_back({ "*", "-crs, --cache-ram-similarity N", "max of similarity of prompt tokens to cache tokens that triggers prompt cache (default: %.2f).",params.cache_ram_similarity }); options.push_back({ "*", "-cram-n-min --cache-ram-n-min N", "minimum number of the cached tokens that triggers prompt cache (default: %d).", params.cache_ram_n_min }); diff --git a/common/common.h b/common/common.h index 44653a6f..b6773d60 100644 --- a/common/common.h +++ b/common/common.h @@ -281,7 +281,6 @@ struct gpt_params { int32_t banned_n = 1; // number of tokens that are banned in the phrase size_t n_buffer = 0; // number of token buffers for string ban bool can_ban_phrases = true; // whether to ban strings - bool do_checkpoint = false; // do checkpoint for recurrent models only std::vector kv_overrides; std::vector tensor_buft_overrides; @@ -420,8 +419,10 @@ struct gpt_params { float slot_prompt_similarity = 0.1f; + bool do_checkpoint = false; // do checkpoint for recurrent models only int32_t ctx_checkpoints_n = 8; // max number of context checkpoints per slot int32_t ctx_checkpoints_interval = 512; // minimum number of tokens between each context checkpoints + int32_t ctx_checkpoints_tolerance = 5; // the number of tokens before the full prompt to create the checkpoint int32_t cache_ram_mib = 8192; // -1 = no limit, 0 - disable, 1 = 1 MiB, etc. int32_t cache_ram_n_min = 0; // min number of tokens required to save in the ram float cache_ram_similarity = 0.5f; // similarity of tokens to cached tokens diff --git a/examples/server/server-context.cpp b/examples/server/server-context.cpp index 05f66f0c..da0f4965 100644 --- a/examples/server/server-context.cpp +++ b/examples/server/server-context.cpp @@ -1024,6 +1024,7 @@ bool server_context::launch_slot_with_task(server_slot& slot, server_task& task) } else if (penalty_prompt->is_array()) { const auto n_tokens = penalty_prompt->size(); + slot.sparams.penalty_prompt_tokens.clear(); slot.sparams.penalty_prompt_tokens.reserve(n_tokens + std::max(0, slot.params.n_predict)); const int n_vocab = llama_n_vocab(model); @@ -1067,6 +1068,7 @@ bool server_context::launch_slot_with_task(server_slot& slot, server_task& task) const auto preserved_tokens = data.find("preserved_tokens"); if (preserved_tokens != data.end()) { + slot.sparams.preserved_tokens.clear(); for (const auto& t : *preserved_tokens) { auto ids = common_tokenize(model, t.get(), /* add_special= */ false, /* parse_special= */ true); if (ids.size() == 1) { @@ -1081,6 +1083,7 @@ bool server_context::launch_slot_with_task(server_slot& slot, server_task& task) } const auto grammar_triggers = data.find("grammar_triggers"); if (grammar_triggers != data.end()) { + slot.sparams.grammar_triggers.clear(); for (const auto& t : *grammar_triggers) { server_grammar_trigger ct(t); if (ct.value.type == COMMON_GRAMMAR_TRIGGER_TYPE_WORD) { @@ -3058,6 +3061,12 @@ void server_context::batch_pending_prompt(const int32_t n_ubatch, const int32_t slot_npast++; slot.n_past_prompt++; slot.n_past++; + slot.do_checkpoint = false; + if (params_base.do_checkpoint && slot.n_prompt_tokens - slot.n_past_prompt == params_base.ctx_checkpoints_tolerance) { + slot.do_checkpoint = true; + break; + } + } LOG_VERBOSE("prompt processing progress", { {"id_slot", slot.id}, @@ -3413,8 +3422,13 @@ void server_context::process_batch_tokens(int32_t & n_batch) { for (auto& slot : slots) { if (slot.state != SLOT_STATE_PROCESSING || slot.i_batch < (int)i || slot.i_batch >= (int)(i + n_tokens)) { + // save checkpoint during prompt processing if (slot.command == SLOT_COMMAND_LOAD_PROMPT) { - create_checkpoint_at_interval(slot, params_base); + if (slot.do_checkpoint) { + create_checkpoint(slot); + } else { + create_checkpoint_at_interval(slot, params_base); + } } continue; // continue loop of slots } @@ -3459,12 +3473,13 @@ void server_context::process_batch_tokens(int32_t & n_batch) { slot.t_start_generation = ggml_time_us(); slot.t_prompt_processing = (slot.t_start_generation - slot.t_start_process_prompt) / 1e3; metrics.on_prompt_eval(slot); - if (params_base.do_checkpoint) { + // create checkpoint after prompt processing ends + if (params_base.ctx_checkpoints_tolerance<=0 && params_base.do_checkpoint) { create_checkpoint(slot); } } - // save checkpoint during generation + // create checkpoint during generation if (slot.n_decoded > 1) { create_checkpoint_at_interval(slot, params_base); } diff --git a/examples/server/server-context.h b/examples/server/server-context.h index f6444ffb..ffc71831 100644 --- a/examples/server/server-context.h +++ b/examples/server/server-context.h @@ -105,6 +105,7 @@ struct server_slot { void prompt_load(server_prompt_cache& prompt_cache, const server_tokens& tokens); size_t checkpoint_pos = 0; + bool do_checkpoint = false; // sampling llama_token sampled; // in speculative mode, this is the last accepted token