// Patch summary (leading text before the first "diff --git" header; not part of any hunk).
//
// Purpose: change periodic context checkpointing so that the interval is a *minimum* token
// distance measured in KV-cache positions, and move the "should we checkpoint" decision out of
// create_checkpoint() into a flag computed when a task is launched.
//
// common/common.cpp
//   - Help text for --ctx-checkpoints-interval now reads "minimum number of tokens between each
//     context checkpoint".
// common/common.h (struct gpt_params)
//   - Adds `bool do_checkpoint = false`.
//   - Default of `ctx_checkpoints_interval` changes from 0 to 512.
// examples/server/server-context.h
//   - server_slot gains `size_t checkpoint_pos = 0`.
//   - server_context declares create_checkpoint_at_interval(server_slot &, const gpt_params &).
// examples/server/server-context.cpp
//   - server_slot::reset(): sets checkpoint_pos back to 0.
//   - launch_slot_with_task(): inside the llama_model_has_recurrent(...) branch, stores
//     `ctx_checkpoints_n > 0 && task.type == SERVER_TASK_TYPE_COMPLETION` into
//     params_base.do_checkpoint.
//   - create_checkpoint_at_interval() (new): when do_checkpoint is set and the interval is > 0,
//     reads llama_kv_cache_seq_pos_max(slot.ctx, slot.id) and, if
//     checkpoint_pos + interval <= pos + 1, calls create_checkpoint(slot) and records pos in
//     slot.checkpoint_pos.
//   - create_checkpoint(): the task-type / recurrent-model gating and its early return are
//     removed (the local now starts as `true`); the log line is reordered so the timing is the
//     last field.
//   - release_slot_after_final_response(): calls create_checkpoint() only when
//     params_base.do_checkpoint is set.
//   - process_batch_tokens(): slots skipped by the batch-window test call
//     create_checkpoint_at_interval() when slot.command == SLOT_COMMAND_LOAD_PROMPT; the block
//     that records t_start_generation / prompt-eval metrics now creates a checkpoint when
//     do_checkpoint is set; the old `n_decoded % interval == 0` test is replaced by a call to
//     create_checkpoint_at_interval() when slot.n_decoded > 1.
//   - update_slots(): one line with no visible content is removed and re-added (presumably a
//     whitespace-only change).
//
// NOTE(review): do_checkpoint is stored on params_base, which other server_context methods read
//   without receiving it as an argument, yet it is derived from a single task's type at launch.
//   If several slots can run tasks of different types at once, the flag reflects only the most
//   recently launched one -- looks like it should live on server_slot; confirm.
// NOTE(review): checkpoint_pos is size_t while the sibling llama_kv_cache_seq_pos_min() result is
//   stored in an int32_t, so `checkpoint_pos + interval <= 1 + pos` is presumably a mixed
//   signed/unsigned comparison -- confirm pos can never be below -1 here.
// NOTE(review): the parameter name `params_base` in create_checkpoint_at_interval() matches the
//   name other methods use for server_context state; presumably it shadows a member -- confirm
//   whether the parameter is needed at all.
// NOTE(review): launch_slot_with_task() keeps a commented-out
//   `do_checkpoint = do_checkpoint && llama_model_has_recurrent(model);` statement.
// NOTE(review): this text has lost its original line breaks, indentation and blank context
//   lines (hunk line counts no longer match the visible lines), so it cannot be applied as-is.
//
diff --git a/common/common.cpp b/common/common.cpp index 9767b9d7..942379d6 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -2247,7 +2247,7 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param options.push_back({ "*", "-cd, --ctx-size-draft N", "size of the prompt context for the draft model (default: %d, 0 = loaded from model)", params.speculative.n_ctx }); options.push_back({ "*", "--ctx-checkpoints N", "max number of context checkpoints to create per slot (default: %d)",params.ctx_checkpoints_n}); - options.push_back({ "*", "--ctx-checkpoints-interval N", "number of tokens between each context checkpoint. (default: %d, <=0 disable)",params.ctx_checkpoints_interval}); + options.push_back({ "*", "--ctx-checkpoints-interval N", "minimum number of tokens between each context checkpoint. (default: %d, <=0 disable)",params.ctx_checkpoints_interval}); options.push_back({ "*", "-cram, --cache-ram N", "set the maximum cache size in MiB (default: %d, -1 - no limit, 0 - disable)",params.cache_ram_mib }); options.push_back({ "*", "-crs, --cache-ram-similarity N", "max of similarity of prompt tokens to cache tokens that triggers prompt cache (default: %.2f).",params.cache_ram_similarity }); options.push_back({ "*", "-cram-n-min --cache-ram-n-min N", "minimum number of the cached tokens that triggers prompt cache (default: %d).", params.cache_ram_n_min }); diff --git a/common/common.h b/common/common.h index da209432..44653a6f 100644 --- a/common/common.h +++ b/common/common.h @@ -280,7 +280,8 @@ struct gpt_params { std::vector ban_phrases; // strings that are banned in generation int32_t banned_n = 1; // number of tokens that are banned in the phrase size_t n_buffer = 0; // number of token buffers for string ban - bool can_ban_phrases = true; // whether to ban strings + bool can_ban_phrases = true; // whether to ban strings + bool do_checkpoint = false; // do checkpoint for recurrent models only std::vector kv_overrides; std::vector 
tensor_buft_overrides; @@ -420,7 +421,7 @@ struct gpt_params { float slot_prompt_similarity = 0.1f; int32_t ctx_checkpoints_n = 8; // max number of context checkpoints per slot - int32_t ctx_checkpoints_interval = 0; // number of tokens between each context checkpoints + int32_t ctx_checkpoints_interval = 512; // minimum number of tokens between each context checkpoints int32_t cache_ram_mib = 8192; // -1 = no limit, 0 - disable, 1 = 1 MiB, etc. int32_t cache_ram_n_min = 0; // min number of tokens required to save in the ram float cache_ram_similarity = 0.5f; // similarity of tokens to cached tokens diff --git a/examples/server/server-context.cpp b/examples/server/server-context.cpp index 445890ae..05f66f0c 100644 --- a/examples/server/server-context.cpp +++ b/examples/server/server-context.cpp @@ -361,7 +361,7 @@ void server_slot::reset() { rewind_status = false; generated_token_probs.clear(); - + checkpoint_pos = 0; // Reset speculative decoding stats n_draft_total = 0; @@ -1248,6 +1248,16 @@ bool server_context::launch_slot_with_task(server_slot& slot, server_task& task) } if (llama_model_has_recurrent(llama_get_model(slot.ctx))) { params_base.can_ban_phrases = false; + bool do_checkpoint = params_base.ctx_checkpoints_n > 0; + // make checkpoints only for completion tasks + do_checkpoint = do_checkpoint && task.type == SERVER_TASK_TYPE_COMPLETION; + // make a checkpoint of the parts of the memory that cannot be rolled back. 
+ // checkpoints are created only if: + // - the model architecture is marked as recurrent or hybrid + // + // TODO: try to make this conditional on the context or the memory module, instead of the model type + // do_checkpoint = do_checkpoint && llama_model_has_recurrent(model); + params_base.do_checkpoint = do_checkpoint; if (slot.n_buffer != 0) { LLAMA_LOG_WARN("Recurrent model does not support banned strings.\n"); } @@ -2632,6 +2642,16 @@ void server_context::add_sampled_tokens() { } } +void server_context::create_checkpoint_at_interval(server_slot & slot, const gpt_params & params_base) { + if (params_base.do_checkpoint && params_base.ctx_checkpoints_interval > 0) { + auto pos = llama_kv_cache_seq_pos_max(slot.ctx, slot.id); + if (slot.checkpoint_pos + params_base.ctx_checkpoints_interval <= 1 + pos) { + create_checkpoint(slot); + slot.checkpoint_pos = pos; + } + } +} + void server_context::apply_checkpoint(server_slot & slot) { const auto pos_min_thold = std::max(0, slot.n_past - 1); if (!mctx && slot.n_past > 0 && slot.n_past < slot.cache_tokens.n_tokens()) { @@ -2696,20 +2716,7 @@ void server_context::apply_checkpoint(server_slot & slot) { } void server_context::create_checkpoint(server_slot & slot) { - bool do_checkpoint = params_base.ctx_checkpoints_n > 0; - - // make checkpoints only for completion tasks - do_checkpoint = do_checkpoint && slot.task->type == SERVER_TASK_TYPE_COMPLETION; - - // make a checkpoint of the parts of the memory that cannot be rolled back. 
- // checkpoints are created only if: - // - the model architecture is marked as recurrent or hybrid - // - // TODO: try to make this conditional on the context or the memory module, instead of the model type - do_checkpoint = do_checkpoint && llama_model_has_recurrent(model); - if (!do_checkpoint) { - return; - } + bool do_checkpoint = true; int32_t pos_min = llama_kv_cache_seq_pos_min(slot.ctx, slot.id); const auto pos_max = llama_kv_cache_seq_pos_max(slot.ctx, slot.id); @@ -2743,9 +2750,9 @@ void server_context::create_checkpoint(server_slot & slot) { llama_state_seq_get_data(ctx, cur.data.data(), checkpoint_size, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); - SLT_WRN(slot, "created context checkpoint %d of %d took %.2f ms (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", - (int)slot.server_cached_prompt.checkpoints.size(), params_base.ctx_checkpoints_n, - (ggml_time_us() - t_start) / 1000.0, cur.pos_min, cur.pos_max, (float)cur.data.size() / 1024 / 1024); + SLT_WRN(slot, "created context checkpoint %d of %d (pos_min = %d, pos_max = %d, size = %.3f MiB, took %.2f ms)\n", + (int)slot.server_cached_prompt.checkpoints.size(), params_base.ctx_checkpoints_n, cur.pos_min, cur.pos_max, (float)cur.data.size() / 1024 / 1024, + (ggml_time_us() - t_start) / 1000.0); } } @@ -3221,7 +3228,9 @@ bool server_context::accept_special_token(const server_slot& slot, const llama_ void server_context::release_slot_after_final_response(server_slot & slot) { slot.print_timings(); - create_checkpoint(slot); + if (params_base.do_checkpoint) { + create_checkpoint(slot); + } slot.release(); slot.released = true; metrics.on_prediction(slot); @@ -3404,6 +3413,9 @@ void server_context::process_batch_tokens(int32_t & n_batch) { for (auto& slot : slots) { if (slot.state != SLOT_STATE_PROCESSING || slot.i_batch < (int)i || slot.i_batch >= (int)(i + n_tokens)) { + if (slot.command == SLOT_COMMAND_LOAD_PROMPT) { + create_checkpoint_at_interval(slot, params_base); + } continue; // continue loop of 
slots } @@ -3447,14 +3459,16 @@ void server_context::process_batch_tokens(int32_t & n_batch) { slot.t_start_generation = ggml_time_us(); slot.t_prompt_processing = (slot.t_start_generation - slot.t_start_process_prompt) / 1e3; metrics.on_prompt_eval(slot); - } - - // save checkpoint during generation - if (params_base.ctx_checkpoints_interval > 0) { - if (slot.n_decoded % params_base.ctx_checkpoints_interval == 0) { + if (params_base.do_checkpoint) { create_checkpoint(slot); } } + + // save checkpoint during generation + if (slot.n_decoded > 1) { + create_checkpoint_at_interval(slot, params_base); + } + slot.t_token_generation = std::max(1, t_current - slot.t_start_generation) / 1e3; result.tok = id; @@ -3516,7 +3530,7 @@ void server_context::update_slots() { // apply context-shift if needed // TODO: simplify and improve context_shift(); - + // start populating the batch for this iteration common_batch_clear(batch); diff --git a/examples/server/server-context.h b/examples/server/server-context.h index 2ccc7ed6..f6444ffb 100644 --- a/examples/server/server-context.h +++ b/examples/server/server-context.h @@ -104,6 +104,8 @@ struct server_slot { void prompt_load(server_prompt_cache& prompt_cache, const server_tokens& tokens); + size_t checkpoint_pos = 0; + // sampling llama_token sampled; // in speculative mode, this is the last accepted token llama_tokens drafted; @@ -358,5 +360,7 @@ struct server_context { void apply_checkpoint(server_slot & slot); + void create_checkpoint_at_interval(server_slot & slot, const gpt_params & params_base); + void release_slot_after_final_response(server_slot & slot); };