mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-04-20 14:39:45 +00:00
server : support multi-modal context checkpoints and prompt caching (#1398)
* server : support multi-modal context checkpoints and prompt caching; do not create checkpoint right after image processing; improve mtmd check for slot ops; fix context shift; do not abort if template parse failed * change to debug message when detecting ban token --------- Co-authored-by: firecoperana <firecoperana>
This commit is contained in:
@@ -420,7 +420,7 @@ struct gpt_params {
|
||||
float slot_prompt_similarity = 0.1f;
|
||||
|
||||
bool do_checkpoint = false; // do checkpoint for recurrent models only
|
||||
int32_t ctx_checkpoints_n = 8; // max number of context checkpoints per slot
|
||||
int32_t ctx_checkpoints_n = 32; // max number of context checkpoints per slot
|
||||
int32_t ctx_checkpoints_interval = 512; // minimum number of tokens between each context checkpoints
|
||||
int32_t ctx_checkpoints_tolerance = 5; // the number of tokens before the full prompt to create the checkpoint
|
||||
int32_t cache_ram_mib = 8192; // -1 = no limit, 0 - disable, 1 = 1 MiB, etc.
|
||||
|
||||
@@ -101,7 +101,7 @@ std::string regex_to_reversed_partial_regex(const std::string & pattern) {
|
||||
sequence->back() += *it;
|
||||
auto is_star = *it == '*';
|
||||
++it;
|
||||
if (is_star) {
|
||||
if (it != end && is_star) {
|
||||
if (*it == '?') {
|
||||
++it;
|
||||
}
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -352,12 +352,19 @@ public:
|
||||
|
||||
server_tokens(const llama_tokens& tokens, bool has_mtmd);
|
||||
|
||||
llama_pos pos_next() const;
|
||||
// the next position after n_tokens. if n_tokens < 0, return the next position after all tokens.
|
||||
llama_pos pos_next(int64_t n_tokens = -1) const;
|
||||
|
||||
// number of tokens with position <= max_pos
|
||||
size_t size_up_to_pos(llama_pos max_pos) const;
|
||||
|
||||
int n_tokens() const {
|
||||
return tokens.size();
|
||||
}
|
||||
|
||||
bool has_mtmd_data() {
|
||||
return !map_idx_to_media.empty();
|
||||
}
|
||||
// for debugging
|
||||
std::string str() const;
|
||||
|
||||
@@ -412,7 +419,7 @@ public:
|
||||
|
||||
size_t get_common_prefix_exact(const server_tokens& b) const;
|
||||
|
||||
llama_tokens get_text_tokens_exclude_think(const llama_context* ctx, const thinking_tokens& think_token) const;
|
||||
server_tokens get_tokens_exclude_think(const llama_context * ctx, const thinking_tokens & think_token) const;
|
||||
|
||||
common_prefix get_common_prefix(const llama_context* ctx, const server_tokens& b, bool exact = false) const;
|
||||
// take first n tokens of tokens list a
|
||||
@@ -431,6 +438,8 @@ public:
|
||||
int32_t seq_id,
|
||||
size_t& n_tokens_out) const;
|
||||
|
||||
server_tokens clone() const;
|
||||
|
||||
// Keep the first n_keep and remove n_discard tokens from tokens
|
||||
void discard_n_tokens(int32_t n_keep, int32_t n_discard);
|
||||
|
||||
|
||||
@@ -87,11 +87,6 @@ bool server_context::load_model(const gpt_params& params_) {
|
||||
}
|
||||
LOG_INFO("loaded multimodal model, '%s'\n", mmproj_path.c_str());
|
||||
|
||||
if (params_base.ctx_shift) {
|
||||
params_base.ctx_shift = false;
|
||||
LOG_WARNING("%s\n", "ctx_shift is not supported by multimodal, it will be disabled");
|
||||
}
|
||||
|
||||
//if (params.n_cache_reuse) {
|
||||
// params_base.n_cache_reuse = 0;
|
||||
// SRV_WRN("%s\n", "cache_reuse is not supported by multimodal, it will be disabled");
|
||||
@@ -298,9 +293,8 @@ void server_context::init() {
|
||||
}
|
||||
catch (const std::exception & e) {
|
||||
SRV_ERR("%s: chat template parsing error: %s\n", __func__, e.what());
|
||||
SRV_ERR("%s: please consider disabling jinja via --no-jinja, or use a custom chat template via --chat-template\n", __func__);
|
||||
SRV_ERR("%s: for example: --no-jinja --chat-template chatml\n", __func__);
|
||||
return;
|
||||
SRV_ERR("%s: please consider enabling jinja via --jinja, or use a custom chat template via --chat-template\n", __func__);
|
||||
SRV_ERR("%s: for example: --chat-template chatml\n", __func__);
|
||||
}
|
||||
|
||||
// thinking is enabled if:
|
||||
@@ -375,6 +369,8 @@ void server_slot::reset() {
|
||||
|
||||
generated_token_probs.clear();
|
||||
checkpoint_pos = 0;
|
||||
image_just_processed = false;
|
||||
do_checkpoint = false;
|
||||
|
||||
positional_bans.clear();
|
||||
ban_phrases.clear();
|
||||
@@ -463,6 +459,7 @@ void server_slot::release() {
|
||||
if (state == SLOT_STATE_PROCESSING) {
|
||||
t_token_generation = (ggml_time_us() - t_start_generation) / 1e3;
|
||||
command = SLOT_COMMAND_RELEASE;
|
||||
state = SLOT_STATE_IDLE;
|
||||
task.reset();
|
||||
llama_decode_reset();
|
||||
}
|
||||
@@ -697,7 +694,7 @@ std::pair<common_prefix, float> server_context::calculate_slot_similarity(const
|
||||
}
|
||||
|
||||
void server_context::copy_data_to_cached_prompt(const server_tokens & tokens, server_slot & slot) {
|
||||
slot.server_cached_prompt.tokens = server_tokens(tokens.get_text_tokens(), false); // copy cache tokens
|
||||
slot.server_cached_prompt.tokens = tokens.clone(); // copy cache tokens
|
||||
slot.server_cached_prompt.n_discarded_prompt = slot.n_discarded_prompt;
|
||||
slot.server_cached_prompt.n_kept_prompt = slot.n_kept_prompt;
|
||||
slot.server_cached_prompt.think_tokens = slot.params.think_tokens;
|
||||
@@ -722,13 +719,10 @@ server_slot* server_context::get_available_slot(const server_task& task) {
|
||||
if (cache_tokens.empty()) {
|
||||
continue;
|
||||
}
|
||||
bool exclude_think = !cache_tokens.has_mtmd && slot.params.think_tokens.exclude;
|
||||
std::pair<common_prefix, float> sim;
|
||||
if (exclude_think) {
|
||||
auto temp = slot.cache_tokens.get_text_tokens_exclude_think(slot.ctx, slot.params.think_tokens);
|
||||
server_tokens cache_tokens_exclude_think = server_tokens(temp, false);
|
||||
temp = task.tokens.get_text_tokens_exclude_think(slot.ctx, slot.params.think_tokens);
|
||||
server_tokens prompt_tokens_exclude_think = server_tokens(temp, false);
|
||||
if (slot.params.think_tokens.exclude) {
|
||||
server_tokens cache_tokens_exclude_think = slot.cache_tokens.get_tokens_exclude_think(slot.ctx, slot.params.think_tokens);
|
||||
server_tokens prompt_tokens_exclude_think = task.tokens.get_tokens_exclude_think(slot.ctx, slot.params.think_tokens);
|
||||
sim = calculate_slot_similarity(slot, ctx, cache_tokens_exclude_think, prompt_tokens_exclude_think);
|
||||
}
|
||||
else {
|
||||
@@ -780,13 +774,9 @@ server_slot* server_context::get_available_slot(const server_task& task) {
|
||||
float f_keep = 0;
|
||||
size_t cache_token_size = tokens.size();
|
||||
if (!tokens.empty()) {
|
||||
bool exclude_think = !tokens.has_mtmd && ret->params.think_tokens.exclude;
|
||||
if (exclude_think) {
|
||||
auto temp = tokens.get_text_tokens_exclude_think(ret->ctx, ret->params.think_tokens);
|
||||
server_tokens cache_exclude_think = server_tokens(temp, false);
|
||||
|
||||
temp = task.tokens.get_text_tokens_exclude_think(ret->ctx, ret->params.think_tokens);
|
||||
server_tokens prompt_exclude_think = server_tokens(temp, false);
|
||||
if (ret->params.think_tokens.exclude) {
|
||||
server_tokens cache_exclude_think = tokens.get_tokens_exclude_think(ret->ctx, ret->params.think_tokens);
|
||||
server_tokens prompt_exclude_think = task.tokens.get_tokens_exclude_think(ret->ctx, ret->params.think_tokens);
|
||||
|
||||
cache_token_size = cache_exclude_think.size();
|
||||
f_keep = calculate_slot_f_keep(*ret, ret->ctx, cache_exclude_think, prompt_exclude_think);
|
||||
@@ -807,9 +797,6 @@ server_slot* server_context::get_available_slot(const server_task& task) {
|
||||
// don't update the cache if the slot's context is below cache_ram_n_min
|
||||
update_cache = update_cache && cache_token_size >= cache_ram_n_min;
|
||||
|
||||
// TODO: mtmd does not support prompt cache
|
||||
update_cache = update_cache && (ret->mctx == nullptr);
|
||||
|
||||
LLAMA_LOG_INFO("======== Prompt cache: cache size: %d, n_keep: %d, n_discarded_prompt: %d, cache_ram_n_min: %d, f_keep: %.2f, cache_ram_similarity: %.2f\n",
|
||||
(int)tokens.size(), ret->n_kept_prompt, ret->n_discarded_prompt, cache_ram_n_min, f_keep, cache_ram_similarity);
|
||||
if (update_cache) {
|
||||
@@ -829,7 +816,7 @@ server_slot* server_context::get_available_slot(const server_task& task) {
|
||||
ret->prompt_load(*prompt_cache, task.tokens);
|
||||
prompt_cache->update();
|
||||
|
||||
ret->cache_tokens = server_tokens(ret->server_cached_prompt.tokens.get_text_tokens(), false); // recover cache tokens
|
||||
ret->cache_tokens = ret->server_cached_prompt.tokens.clone(); // recover cache tokens
|
||||
ret->n_discarded_prompt = ret->server_cached_prompt.n_discarded_prompt;
|
||||
ret->n_kept_prompt = ret->server_cached_prompt.n_kept_prompt;
|
||||
|
||||
@@ -1335,11 +1322,14 @@ bool server_context::launch_slot_with_task(server_slot& slot, server_task& task)
|
||||
// - the model architecture is marked as recurrent or hybrid
|
||||
//
|
||||
// TODO: try to make this conditional on the context or the memory module, instead of the model type
|
||||
// do_checkpoint = do_checkpoint && llama_model_has_recurrent(model);
|
||||
params_base.do_checkpoint = do_checkpoint;
|
||||
if (slot.n_buffer != 0) {
|
||||
LLAMA_LOG_WARN("Recurrent model does not support banned strings.\n");
|
||||
LLAMA_LOG_WARN("banned strings is not supported by recurrent model, it will be disabled.\n");
|
||||
}
|
||||
if (params_base.ctx_shift) {
|
||||
params_base.ctx_shift = false;
|
||||
LOG_WARNING("%s\n", "ctx_shift is not supported by recurrent model, it will be disabled");
|
||||
}
|
||||
}
|
||||
{
|
||||
const auto& stop = data.find("stop");
|
||||
@@ -1713,18 +1703,18 @@ void server_context::send_error(const int id_task, const int id_multi, const std
|
||||
{"error", error},
|
||||
});
|
||||
|
||||
server_task_result res;
|
||||
res.id = id_task;
|
||||
res.id_multi = id_multi;
|
||||
res.stop = false;
|
||||
res.error = true;
|
||||
res.data = format_error_response(error, type);
|
||||
|
||||
queue_results.send(res);
|
||||
auto res = std::make_unique<server_task_result_error>();
|
||||
res->id = id_task;
|
||||
res->id_multi = id_multi;
|
||||
res->stop = false;
|
||||
res->error = true;
|
||||
res->err_type = type;
|
||||
res->err_msg = error;
|
||||
queue_results.send(std::move(res));
|
||||
}
|
||||
|
||||
// if multimodal is enabled, send an error and return false
|
||||
bool server_context::ensure_no_mtmd(const int id_task) {
|
||||
bool server_context::check_no_mtmd(const int id_task) {
|
||||
if (mctx) {
|
||||
int id_multi = 0;
|
||||
send_error(id_task, id_multi, "This feature is not supported by multimodal", ERROR_TYPE_NOT_SUPPORTED);
|
||||
@@ -2127,9 +2117,6 @@ void server_context::process_single_task(server_task&& task) {
|
||||
} break;
|
||||
case SERVER_TASK_TYPE_SLOT_SAVE:
|
||||
{
|
||||
if (!ensure_no_mtmd(task.id)) {
|
||||
break;
|
||||
}
|
||||
int id_slot = task.data.at("id_slot");
|
||||
server_slot* slot = get_slot_by_id(id_slot);
|
||||
if (slot == nullptr) {
|
||||
@@ -2142,7 +2129,9 @@ void server_context::process_single_task(server_task&& task) {
|
||||
queue_tasks.defer(std::move(task));
|
||||
break;
|
||||
}
|
||||
|
||||
if (slot->cache_tokens.has_mtmd_data() && !check_no_mtmd(task.id)) {
|
||||
break;
|
||||
}
|
||||
const size_t token_count = slot->cache_tokens.size();
|
||||
const int64_t t_start = ggml_time_us();
|
||||
|
||||
@@ -2171,7 +2160,6 @@ void server_context::process_single_task(server_task&& task) {
|
||||
} break;
|
||||
case SERVER_TASK_TYPE_SLOT_RESTORE:
|
||||
{
|
||||
if (!ensure_no_mtmd(task.id)) break;
|
||||
int id_slot = task.data.at("id_slot");
|
||||
server_slot* slot = get_slot_by_id(id_slot);
|
||||
if (slot == nullptr) {
|
||||
@@ -2184,7 +2172,9 @@ void server_context::process_single_task(server_task&& task) {
|
||||
queue_tasks.defer(std::move(task));
|
||||
break;
|
||||
}
|
||||
|
||||
if (slot->cache_tokens.has_mtmd_data() && !check_no_mtmd(task.id)) {
|
||||
break;
|
||||
}
|
||||
const int64_t t_start = ggml_time_us();
|
||||
|
||||
std::string filename = task.data.at("filename");
|
||||
@@ -2199,7 +2189,9 @@ void server_context::process_single_task(server_task&& task) {
|
||||
break;
|
||||
}
|
||||
slot->cache_tokens.resize(token_count);
|
||||
|
||||
if (mctx) {
|
||||
slot->cache_tokens.has_mtmd = true;
|
||||
}
|
||||
const int64_t t_end = ggml_time_us();
|
||||
const double t_restore_ms = (t_end - t_start) / 1000.0;
|
||||
|
||||
@@ -2220,7 +2212,6 @@ void server_context::process_single_task(server_task&& task) {
|
||||
} break;
|
||||
case SERVER_TASK_TYPE_SLOT_ERASE:
|
||||
{
|
||||
if (!ensure_no_mtmd(task.id)) break;
|
||||
int id_slot = task.data.at("id_slot");
|
||||
server_slot* slot = get_slot_by_id(id_slot);
|
||||
if (slot == nullptr) {
|
||||
@@ -2233,7 +2224,9 @@ void server_context::process_single_task(server_task&& task) {
|
||||
queue_tasks.defer(std::move(task));
|
||||
break;
|
||||
}
|
||||
|
||||
if (slot->cache_tokens.has_mtmd_data() && !check_no_mtmd(task.id)) {
|
||||
break;
|
||||
}
|
||||
// Erase token cache
|
||||
const size_t n_erased = slot->cache_tokens.size();
|
||||
llama_kv_cache_seq_rm(ctx, slot->id, -1, -1);
|
||||
@@ -2489,13 +2482,59 @@ void server_context::print_tokens(const server_tokens& prompt, const server_toke
|
||||
}
|
||||
|
||||
void server_context::discard_n_kv_and_cache_tokens(llama_context* ctx, server_slot& slot, int32_t n_keep, int32_t n_discard) {
|
||||
llama_kv_cache_seq_rm(ctx, slot.id, n_keep, n_keep + n_discard);
|
||||
llama_kv_cache_seq_add(ctx, slot.id, n_keep + n_discard, system_tokens.size() + slot.n_past, -n_discard);
|
||||
auto kv_keep = slot.cache_tokens.pos_next(n_keep);
|
||||
auto kv_discard = slot.cache_tokens.pos_next(n_keep + n_discard) - kv_keep;
|
||||
auto kv_past = slot.cache_tokens.pos_next(slot.n_past);
|
||||
int32_t pos_min = llama_kv_cache_seq_pos_min(slot.ctx, slot.id);
|
||||
const auto pos_max = llama_kv_cache_seq_pos_max(slot.ctx, slot.id);
|
||||
llama_kv_cache_seq_rm(ctx, slot.id, kv_keep, kv_keep + kv_discard);
|
||||
llama_kv_cache_seq_add(ctx, slot.id, kv_keep + kv_discard, kv_past, -kv_discard);
|
||||
if (slot.params.cache_prompt) {
|
||||
slot.cache_tokens.discard_n_tokens(n_keep, n_discard);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
inline static bool tokens_support_context_shift(const server_tokens & tokens, int32_t n_keep,
|
||||
int32_t n_discard) {
|
||||
bool can_shift = !tokens.has_mtmd;
|
||||
if (tokens.has_mtmd) {
|
||||
can_shift = true;
|
||||
if (n_keep > 0 && n_keep<= tokens.n_tokens()) {
|
||||
can_shift = tokens[n_keep - 1] != LLAMA_TOKEN_NULL;
|
||||
}
|
||||
if (n_discard + n_keep > 0 && n_discard + n_keep <= tokens.n_tokens()) {
|
||||
can_shift = can_shift && tokens[n_discard + n_keep - 1] != LLAMA_TOKEN_NULL;
|
||||
}
|
||||
}
|
||||
return can_shift;
|
||||
}
|
||||
|
||||
inline static void adjust_n_to_support_context_shift(const server_tokens & tokens, int32_t & n_keep,
|
||||
int32_t & n_discard) {
|
||||
if (!tokens.has_mtmd) {
|
||||
return;
|
||||
}
|
||||
if (n_keep > 0 && n_keep <= tokens.n_tokens()) {
|
||||
while (tokens[n_keep - 1] == LLAMA_TOKEN_NULL) {
|
||||
n_keep--;
|
||||
if (n_keep<1 || n_keep>tokens.size()) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (n_discard + n_keep > 0 && n_discard + n_keep <= tokens.n_tokens()) {
|
||||
while (tokens[n_discard + n_keep - 1] == LLAMA_TOKEN_NULL) {
|
||||
n_discard++;
|
||||
if (n_discard + n_keep<1 || n_discard + n_keep>tokens.size()) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
|
||||
// convert keep first few and discard next tokens in a to b
|
||||
void server_context::context_shift_find_n_tokens(llama_context* ctx, const server_tokens& a, const server_tokens& b, int32_t n_keep,
|
||||
int32_t n_discard, int32_t& n_kept, int32_t& n_discarded, bool exact) {
|
||||
@@ -2519,7 +2558,10 @@ void server_context::context_shift_prompt(llama_context* ctx, server_slot& slot,
|
||||
int n_keep = std::max(0, slot.params.n_keep + add_bos_token);
|
||||
const int n_left = slot.n_ctx - n_keep;
|
||||
int n_discard = slot.params.n_discard ? slot.params.n_discard : (n_left / 2);
|
||||
|
||||
adjust_n_to_support_context_shift(slot.prompt_tokens, n_keep, n_discard);
|
||||
if (n_discard<=0 || !tokens_support_context_shift(slot.prompt_tokens, n_keep, n_discard)) {
|
||||
return;
|
||||
}
|
||||
int n_discard_prompt = 0;
|
||||
// we still need to truncate input since we have not discarded enough tokens
|
||||
while (slot.n_prompt_tokens - slot.n_discarded_prompt >= slot.n_ctx) {
|
||||
@@ -2598,15 +2640,11 @@ void server_context::context_shift() {
|
||||
if (!params_base.ctx_shift) {
|
||||
// this check is redundant (for good)
|
||||
// we should never get here, because generation should have already stopped in process_token()
|
||||
send_error(slot, "context shift is disabled", ERROR_TYPE_SERVER);
|
||||
slot.print_timings();
|
||||
slot.release();
|
||||
send_error(slot, "context shift is disabled", ERROR_TYPE_SERVER);
|
||||
continue;
|
||||
}
|
||||
if (mctx) {
|
||||
// we should never reach this because params_base.ctx_shift is automatically disabled if mmproj is loaded
|
||||
// we don't support ctx_shift because an image chunk may contains multiple tokens
|
||||
GGML_ABORT("not supported by multimodal");
|
||||
}
|
||||
// Shift context
|
||||
int n_keep = slot.params.n_keep < 0 ? slot.prompt_tokens.size() : slot.params.n_keep;
|
||||
if (add_bos_token) {
|
||||
@@ -2614,11 +2652,12 @@ void server_context::context_shift() {
|
||||
}
|
||||
n_keep = std::min(slot.n_ctx - 4, n_keep);
|
||||
|
||||
const int n_left = (int)system_tokens.size() + slot.n_past - n_keep;
|
||||
const int n_discard = slot.params.n_discard ? slot.params.n_discard : (n_left / 2);
|
||||
const int32_t n_left = (int)system_tokens.size() + slot.n_past - n_keep;
|
||||
int32_t n_discard = slot.params.n_discard ? slot.params.n_discard : (n_left / 2);
|
||||
int32_t n_kept;
|
||||
int32_t n_discard_cache;
|
||||
if (n_discard > 0) {
|
||||
adjust_n_to_support_context_shift(slot.cache_tokens, n_keep, n_discard);
|
||||
if (n_discard > 0 && tokens_support_context_shift(slot.cache_tokens, n_keep, n_discard)) {
|
||||
context_shift_find_n_tokens(ctx, slot.prompt_tokens, slot.cache_tokens, n_keep,
|
||||
n_discard, n_kept, n_discard_cache);
|
||||
LOG_INFO("slot context shift", {
|
||||
@@ -2725,21 +2764,21 @@ void server_context::create_checkpoint_at_interval(server_slot & slot, const gp
|
||||
if (params_base.do_checkpoint && params_base.ctx_checkpoints_interval > 0) {
|
||||
auto pos = llama_kv_cache_seq_pos_max(slot.ctx, slot.id);
|
||||
if (slot.checkpoint_pos + params_base.ctx_checkpoints_interval <= 1 + pos) {
|
||||
create_checkpoint(slot);
|
||||
slot.checkpoint_pos = pos;
|
||||
bool created = create_checkpoint(slot);
|
||||
if (created) {
|
||||
slot.checkpoint_pos = pos;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void server_context::apply_checkpoint(server_slot & slot) {
|
||||
const auto pos_min_thold = std::max(0, slot.n_past - 1);
|
||||
if (!mctx && slot.n_past > 0 && slot.n_past < slot.cache_tokens.n_tokens()) {
|
||||
llama_pos pos_next = slot.cache_tokens.pos_next(slot.n_past);
|
||||
const auto pos_min_thold = std::max(0, pos_next - 1);
|
||||
if (slot.n_past > 0 && slot.n_past < slot.cache_tokens.n_tokens()) {
|
||||
int32_t pos_min = llama_kv_cache_seq_pos_min(slot.ctx, slot.id);
|
||||
|
||||
if (pos_min > pos_min_thold) {
|
||||
// TODO: support can be added in the future when corresponding vision models get released
|
||||
GGML_ASSERT(!slot.cache_tokens.has_mtmd);
|
||||
|
||||
SLT_WRN(slot, "n_past = %d, slot.prompt.tokens.size() = %d, seq_id = %d, pos_min = %d\n", slot.n_past, (int)slot.cache_tokens.size(), slot.id, pos_min);
|
||||
|
||||
// search for a context checkpoint
|
||||
@@ -2765,8 +2804,10 @@ void server_context::apply_checkpoint(server_slot & slot) {
|
||||
do_reset = true;
|
||||
//printf("[DEBUG] `do_reset` was set to `true` after failing to restore a checkpoint");
|
||||
} else {
|
||||
slot.n_past = std::min(slot.n_past, std::max(it->pos_min + 1, it->pos_max));
|
||||
slot.n_past_prompt = std::min(slot.n_past_prompt, std::max(it->pos_min_prompt + 1, it->pos_max_prompt));
|
||||
slot.n_past = std::min(slot.n_past, std::max(it->pos_min+1, it->pos_max));
|
||||
slot.n_past = slot.cache_tokens.size_up_to_pos(slot.n_past-1);
|
||||
slot.n_past_prompt = std::min(slot.n_past_prompt, std::max(it->pos_min_prompt+1, it->pos_max_prompt));
|
||||
slot.n_past_prompt = slot.prompt_tokens.size_up_to_pos(slot.n_past_prompt-1);
|
||||
SLT_WRN(slot, "restored context checkpoint took %.2f ms (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", (ggml_time_us() - t_start) / 1000.0, it->pos_min, it->pos_max, (float)checkpoint_size / 1024 / 1024);
|
||||
}
|
||||
}
|
||||
@@ -2794,8 +2835,8 @@ void server_context::apply_checkpoint(server_slot & slot) {
|
||||
}
|
||||
}
|
||||
|
||||
void server_context::create_checkpoint(server_slot & slot) {
|
||||
bool do_checkpoint = true;
|
||||
bool server_context::create_checkpoint(server_slot & slot) {
|
||||
bool do_checkpoint = !slot.image_just_processed;
|
||||
int32_t pos_min = llama_kv_cache_seq_pos_min(slot.ctx, slot.id);
|
||||
const auto pos_max = llama_kv_cache_seq_pos_max(slot.ctx, slot.id);
|
||||
|
||||
@@ -2833,6 +2874,7 @@ void server_context::create_checkpoint(server_slot & slot) {
|
||||
(int)slot.server_cached_prompt.checkpoints.size(), params_base.ctx_checkpoints_n, cur.pos_min, cur.pos_max, (float)cur.data.size() / 1024 / 1024,
|
||||
(ggml_time_us() - t_start) / 1000.0);
|
||||
}
|
||||
return do_checkpoint;
|
||||
}
|
||||
|
||||
void server_context::batch_pending_prompt(const int32_t n_ubatch, const int32_t n_batch, int32_t & batch_type) {
|
||||
@@ -2935,12 +2977,6 @@ void server_context::batch_pending_prompt(const int32_t n_ubatch, const int32_t
|
||||
slot.release();
|
||||
continue;
|
||||
}
|
||||
if (mctx) {
|
||||
// we should never reach this because params.ctx_shift is automatically disabled if mmproj is loaded
|
||||
// we don't support ctx_shift because an image chunk may contains multiple tokens
|
||||
GGML_ABORT("not supported by multimodal");
|
||||
}
|
||||
|
||||
context_shift_prompt(ctx, slot);
|
||||
slot.truncated = true;
|
||||
LOG_VERBOSE("input truncated", {
|
||||
@@ -3100,7 +3136,7 @@ void server_context::batch_pending_prompt(const int32_t n_ubatch, const int32_t
|
||||
slot.n_past += n_tokens_out;
|
||||
slot.n_past_prompt += n_tokens_out;
|
||||
slot.n_prompt_tokens_processed += n_tokens_out;
|
||||
|
||||
slot.image_just_processed = true; // do not checkpoint right after an image chunk
|
||||
}
|
||||
|
||||
|
||||
@@ -3137,7 +3173,7 @@ void server_context::batch_pending_prompt(const int32_t n_ubatch, const int32_t
|
||||
slot_npast++;
|
||||
slot.n_past_prompt++;
|
||||
slot.n_past++;
|
||||
slot.do_checkpoint = false;
|
||||
slot.image_just_processed = false;
|
||||
if (params_base.do_checkpoint && slot.n_prompt_tokens - slot.n_past_prompt == params_base.ctx_checkpoints_tolerance) {
|
||||
slot.do_checkpoint = true;
|
||||
break;
|
||||
@@ -3286,6 +3322,8 @@ void server_context::speculative_decoding_accept() {
|
||||
if (slot.n_buffer == 0 || !params_base.can_ban_phrases) {
|
||||
if (!process_token(result, slot)) {
|
||||
// release slot because of stop condition
|
||||
slot.cache_tokens.push_back(slot.sampled);
|
||||
slot.n_past++;
|
||||
send_final_response(slot);
|
||||
release_slot_after_final_response(slot);
|
||||
break;
|
||||
@@ -3338,6 +3376,8 @@ void server_context::send_token_results(completion_token_outputs& results, serve
|
||||
if (slot.stopped_limit && !slot.stopped_eos && !slot.stopped_word) {
|
||||
continue;
|
||||
}
|
||||
slot.cache_tokens.push_back(slot.sampled);
|
||||
slot.n_past++;
|
||||
send_final_response(slot);
|
||||
release_slot_after_final_response(slot);
|
||||
released = true;
|
||||
@@ -3349,6 +3389,8 @@ void server_context::send_token_results(completion_token_outputs& results, serve
|
||||
}
|
||||
|
||||
if (!released && slot.stopped_limit && !slot.stopped_eos && !slot.stopped_word) {
|
||||
slot.cache_tokens.push_back(slot.sampled);
|
||||
slot.n_past++;
|
||||
send_final_response(slot);
|
||||
release_slot_after_final_response(slot);
|
||||
}
|
||||
@@ -3381,10 +3423,10 @@ inline int32_t check_ban_phrase(server_slot& slot) {
|
||||
if (start != std::string::npos) {
|
||||
if (start < best_start) {
|
||||
best_start = start;
|
||||
found = true;
|
||||
}
|
||||
found = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 2. Check regex
|
||||
for (const auto& pattern : slot.ban_regex) {
|
||||
@@ -3424,8 +3466,8 @@ inline int32_t check_ban_phrase(server_slot& slot) {
|
||||
if (best_start >= token_offsets[i] && best_start < token_offsets[i] + len) {
|
||||
token_idx = (int32_t)i;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (token_idx != -1) {
|
||||
int32_t abs_pos = slot.n_past - (int32_t)slot.token_buffer.size() + 1 + token_idx;
|
||||
@@ -3449,7 +3491,7 @@ inline void rewind_context(server_slot& slot, int32_t ban_pos) {
|
||||
llama_token banned_tok = result->tok;
|
||||
|
||||
if (n == 0) {
|
||||
LLAMA_LOG_INFO("Banned pattern detected at pos %d. Banning token %d ('%s') and rewinding.\n",
|
||||
LLAMA_LOG_DEBUG("Banned pattern detected at pos %d. Banning token %d ('%s') and rewinding.\n",
|
||||
ban_pos, banned_tok, result->text_to_send.c_str());
|
||||
}
|
||||
|
||||
@@ -3462,11 +3504,11 @@ inline void rewind_context(server_slot& slot, int32_t ban_pos) {
|
||||
}
|
||||
|
||||
int32_t n_rewind_total = (slot.n_past + 1) - ban_pos;
|
||||
|
||||
|
||||
size_t n_keep_cache = 0;
|
||||
if (ban_pos > 0) {
|
||||
n_keep_cache = (size_t)(ban_pos - 1);
|
||||
}
|
||||
}
|
||||
|
||||
if (n_keep_cache > slot.cache_tokens.size()) {
|
||||
n_keep_cache = slot.cache_tokens.size();
|
||||
@@ -3516,7 +3558,7 @@ void server_context::buffer_and_check_string_ban(server_slot & slot, completion_
|
||||
int32_t n_keep_buffer = ban_pos - buffer_start_pos;
|
||||
if (n_keep_buffer < 0) n_keep_buffer = 0;
|
||||
n_rewind = (int32_t)slot.token_buffer.size() - n_keep_buffer;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
bool allow_rewind = true;
|
||||
@@ -3559,16 +3601,16 @@ void server_context::buffer_and_check_string_ban(server_slot & slot, completion_
|
||||
send_token_results(slot.token_buffer, slot, 1);
|
||||
}
|
||||
if (slot.sparams.adaptive_target >= 0.0f) {
|
||||
sent_results = true;
|
||||
}
|
||||
sent_results = true;
|
||||
}
|
||||
}
|
||||
else {
|
||||
// buffer the result, wait for more tokens to validate string
|
||||
slot.sampled = result.tok;
|
||||
}
|
||||
if (slot.sparams.adaptive_target >= 0.0f) {
|
||||
slot.ctx_sampling->n_rewind = sent_results ? -1 : n_rewind;
|
||||
}
|
||||
slot.ctx_sampling->n_rewind = sent_results ? -1 : n_rewind;
|
||||
}
|
||||
}
|
||||
|
||||
void server_context::process_batch_tokens(int32_t & n_batch) {
|
||||
|
||||
@@ -110,6 +110,7 @@ struct server_slot {
|
||||
|
||||
size_t checkpoint_pos = 0;
|
||||
bool do_checkpoint = false;
|
||||
bool image_just_processed = false;
|
||||
|
||||
// sampling
|
||||
llama_token sampled; // in speculative mode, this is the last accepted token
|
||||
@@ -302,7 +303,7 @@ struct server_context {
|
||||
void send_error(const int id_task, const int id_multi, const std::string& error, const enum error_type type = ERROR_TYPE_SERVER);
|
||||
|
||||
// if multimodal is enabled, send an error and return false
|
||||
bool ensure_no_mtmd(const int id_task);
|
||||
bool check_no_mtmd(const int id_task);
|
||||
|
||||
void send_partial_response(server_slot& slot, completion_token_output tkn);
|
||||
|
||||
@@ -363,7 +364,7 @@ struct server_context {
|
||||
// Re-aggregates all active vectors and updates the model state
|
||||
bool apply_control_vectors_internal();
|
||||
|
||||
void create_checkpoint(server_slot & slot);
|
||||
bool create_checkpoint(server_slot & slot);
|
||||
|
||||
void apply_checkpoint(server_slot & slot);
|
||||
|
||||
|
||||
@@ -1081,12 +1081,12 @@ bool server_prompt_cache::load(server_prompt& prompt, const server_tokens& token
|
||||
server_tokens prompt_tokens;
|
||||
server_tokens tokens_new_ex;
|
||||
if (think_tokens.exclude) {
|
||||
prompt_tokens = server_tokens(prompt.tokens.get_text_tokens_exclude_think(ctx, think_tokens), false);
|
||||
tokens_new_ex = server_tokens(tokens_new.get_text_tokens_exclude_think(ctx, think_tokens), false);
|
||||
prompt_tokens = prompt.tokens.get_tokens_exclude_think(ctx, think_tokens);
|
||||
tokens_new_ex = tokens_new.get_tokens_exclude_think(ctx, think_tokens);
|
||||
}
|
||||
else {
|
||||
prompt_tokens = std::move(prompt.tokens); //server_tokens(prompt.tokens.get_text_tokens(), false);
|
||||
tokens_new_ex = server_tokens(tokens_new.get_text_tokens(), false);
|
||||
prompt_tokens = std::move(prompt.tokens);
|
||||
tokens_new_ex = tokens_new.clone();
|
||||
}
|
||||
const auto lcp_best = prompt_tokens.get_common_prefix(ctx, tokens_new_ex);
|
||||
float f_keep_best = float(lcp_best.second) / prompt_tokens.size();
|
||||
@@ -1099,7 +1099,7 @@ bool server_prompt_cache::load(server_prompt& prompt, const server_tokens& token
|
||||
for (auto it = states.begin(); it != states.end(); ++it) {
|
||||
server_tokens tokens;
|
||||
if (think_tokens.exclude) {
|
||||
tokens = server_tokens(it->tokens.get_text_tokens_exclude_think(ctx, think_tokens), false);
|
||||
tokens = it->tokens.get_tokens_exclude_think(ctx, think_tokens);
|
||||
}
|
||||
else {
|
||||
tokens = std::move(it->tokens);
|
||||
@@ -1136,7 +1136,7 @@ bool server_prompt_cache::load(server_prompt& prompt, const server_tokens& token
|
||||
|
||||
server_prompt* server_prompt_cache::alloc(const server_prompt& prompt, size_t state_size) {
|
||||
for (auto it = states.begin(); it != states.end();) {
|
||||
auto tokens_ctx_shift = server_tokens(prompt.tokens.get_text_tokens(), false); // copy cache tokens
|
||||
auto tokens_ctx_shift = prompt.tokens.clone(); // copy cache tokens
|
||||
tokens_ctx_shift.discard_n_tokens(prompt.n_kept_prompt, prompt.n_discarded_prompt);
|
||||
auto prefix = it->tokens.get_common_prefix(ctx, tokens_ctx_shift);
|
||||
const size_t len = prefix.first;
|
||||
@@ -1177,7 +1177,7 @@ server_prompt* server_prompt_cache::alloc(const server_prompt& prompt, size_t st
|
||||
// TODO: for some reason we can't copy server_tokens, so we have to do this workaround
|
||||
auto& cur = states.emplace_back();
|
||||
cur = {
|
||||
/*.tokens =*/ server_tokens(prompt.tokens.get_text_tokens(), false),
|
||||
/*.tokens =*/ prompt.tokens.clone(),
|
||||
/*.n_keep =*/ prompt.n_kept_prompt,
|
||||
/*.n_discarded_prompt =*/ prompt.n_discarded_prompt,
|
||||
/*.think_tokens =*/ prompt.think_tokens,
|
||||
|
||||
@@ -371,6 +371,16 @@ struct server_prompt {
|
||||
return tokens.size();
|
||||
}
|
||||
|
||||
// Builds a copy of this prompt. The token list is duplicated through
// server_tokens::clone(); every other field is copied as-is.
// The initializers are positional, so their order must stay in sync with the
// order in which the members are declared.
server_prompt clone() const {
    return server_prompt{
        tokens.clone(),                     // explicit clone of the token list
        n_kept_prompt, n_discarded_prompt,  // context-shift bookkeeping
        think_tokens,
        data, checkpoints
    };
}
|
||||
};
|
||||
|
||||
struct server_prompt_cache {
|
||||
|
||||
@@ -125,7 +125,7 @@ ggml_cgraph * llm_build_context::build_k_shift() {
|
||||
|
||||
GGML_ASSERT(kv_self.size == n_ctx);
|
||||
|
||||
const auto & rope_type_shift = hparams.rope_type == LLAMA_ROPE_TYPE_MROPE
|
||||
const auto & rope_type_shift = hparams.rope_type == LLAMA_ROPE_TYPE_MROPE || hparams.rope_type == LLAMA_ROPE_TYPE_IMROPE
|
||||
// @ngxson : this is a workaround
|
||||
// for M-RoPE, we want to rotate the whole vector when doing KV shift
|
||||
// a normal RoPE should work, we just need to use the correct ordering
|
||||
|
||||
@@ -4058,12 +4058,18 @@ static void llama_kv_cache_defrag_internal(struct llama_context & lctx) {
|
||||
//LLAMA_LOG_INFO("(tmp log) KV defrag time: %.3f ms\n", (t_end - t_start)/1000.0);
|
||||
}
|
||||
|
||||
static bool get_can_shift(struct llama_context & lctx) {
|
||||
bool no_shift = lctx.model.arch == LLM_ARCH_DEEPSEEK2 || lctx.model.arch == LLM_ARCH_GLM_DSA; // not supported due to MLA
|
||||
no_shift = no_shift || lctx.model.hparams.rope_type == LLAMA_ROPE_TYPE_IMROPE;
|
||||
return !no_shift;
|
||||
}
|
||||
|
||||
static int32_t llama_kv_cache_update_internal(struct llama_context & lctx) {
|
||||
bool need_reserve = false;
|
||||
|
||||
// apply K-shift if needed
|
||||
if (lctx.model.hparams.rope_type != LLAMA_ROPE_TYPE_NONE && lctx.kv_self.has_shift) {
|
||||
if (lctx.model.arch == LLM_ARCH_DEEPSEEK2 || lctx.model.arch == LLM_ARCH_GLM_DSA) { // not supported due to MLA
|
||||
if (!get_can_shift(lctx)) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user