server : support multi-modal context checkpoints and prompt caching (#1398)

* server : support multi-modal context checkpoints and prompt caching

do not create checkpoint right after image processing

improve mtmd check for slot ops

fix context shift

do not abort if template parsing fails

* change to debug message when detecting ban token

---------

Co-authored-by: firecoperana <firecoperana>
This commit is contained in:
firecoperana
2026-03-13 02:07:57 -05:00
committed by GitHub
parent d2141b802b
commit 433531ddae
10 changed files with 741 additions and 593 deletions

View File

@@ -420,7 +420,7 @@ struct gpt_params {
float slot_prompt_similarity = 0.1f;
bool do_checkpoint = false; // do checkpoint for recurrent models only
int32_t ctx_checkpoints_n = 8; // max number of context checkpoints per slot
int32_t ctx_checkpoints_n = 32; // max number of context checkpoints per slot
int32_t ctx_checkpoints_interval = 512; // minimum number of tokens between each context checkpoints
int32_t ctx_checkpoints_tolerance = 5; // the number of tokens before the full prompt to create the checkpoint
int32_t cache_ram_mib = 8192; // -1 = no limit, 0 - disable, 1 = 1 MiB, etc.

View File

@@ -101,7 +101,7 @@ std::string regex_to_reversed_partial_regex(const std::string & pattern) {
sequence->back() += *it;
auto is_star = *it == '*';
++it;
if (is_star) {
if (it != end && is_star) {
if (*it == '?') {
++it;
}

File diff suppressed because it is too large Load Diff

View File

@@ -352,12 +352,19 @@ public:
server_tokens(const llama_tokens& tokens, bool has_mtmd);
llama_pos pos_next() const;
// Return the next position after the first n_tokens tokens; if n_tokens < 0, return the next position after all tokens.
llama_pos pos_next(int64_t n_tokens = -1) const;
// number of tokens with position <= max_pos
size_t size_up_to_pos(llama_pos max_pos) const;
// Total number of token entries in this sequence, including any
// multimodal placeholder entries (see has_mtmd_data()).
// NOTE(review): narrows size_t -> int; assumes the count fits in int.
int n_tokens() const {
return tokens.size();
}
bool has_mtmd_data() {
return !map_idx_to_media.empty();
}
// for debugging
std::string str() const;
@@ -412,7 +419,7 @@ public:
size_t get_common_prefix_exact(const server_tokens& b) const;
llama_tokens get_text_tokens_exclude_think(const llama_context* ctx, const thinking_tokens& think_token) const;
server_tokens get_tokens_exclude_think(const llama_context * ctx, const thinking_tokens & think_token) const;
common_prefix get_common_prefix(const llama_context* ctx, const server_tokens& b, bool exact = false) const;
// take first n tokens of tokens list a
@@ -431,6 +438,8 @@ public:
int32_t seq_id,
size_t& n_tokens_out) const;
server_tokens clone() const;
// Keep the first n_keep and remove n_discard tokens from tokens
void discard_n_tokens(int32_t n_keep, int32_t n_discard);

View File

@@ -87,11 +87,6 @@ bool server_context::load_model(const gpt_params& params_) {
}
LOG_INFO("loaded multimodal model, '%s'\n", mmproj_path.c_str());
if (params_base.ctx_shift) {
params_base.ctx_shift = false;
LOG_WARNING("%s\n", "ctx_shift is not supported by multimodal, it will be disabled");
}
//if (params.n_cache_reuse) {
// params_base.n_cache_reuse = 0;
// SRV_WRN("%s\n", "cache_reuse is not supported by multimodal, it will be disabled");
@@ -298,9 +293,8 @@ void server_context::init() {
}
catch (const std::exception & e) {
SRV_ERR("%s: chat template parsing error: %s\n", __func__, e.what());
SRV_ERR("%s: please consider disabling jinja via --no-jinja, or use a custom chat template via --chat-template\n", __func__);
SRV_ERR("%s: for example: --no-jinja --chat-template chatml\n", __func__);
return;
SRV_ERR("%s: please consider enabling jinja via --jinja, or use a custom chat template via --chat-template\n", __func__);
SRV_ERR("%s: for example: --chat-template chatml\n", __func__);
}
// thinking is enabled if:
@@ -375,6 +369,8 @@ void server_slot::reset() {
generated_token_probs.clear();
checkpoint_pos = 0;
image_just_processed = false;
do_checkpoint = false;
positional_bans.clear();
ban_phrases.clear();
@@ -463,6 +459,7 @@ void server_slot::release() {
if (state == SLOT_STATE_PROCESSING) {
t_token_generation = (ggml_time_us() - t_start_generation) / 1e3;
command = SLOT_COMMAND_RELEASE;
state = SLOT_STATE_IDLE;
task.reset();
llama_decode_reset();
}
@@ -697,7 +694,7 @@ std::pair<common_prefix, float> server_context::calculate_slot_similarity(const
}
void server_context::copy_data_to_cached_prompt(const server_tokens & tokens, server_slot & slot) {
slot.server_cached_prompt.tokens = server_tokens(tokens.get_text_tokens(), false); // copy cache tokens
slot.server_cached_prompt.tokens = tokens.clone(); // copy cache tokens
slot.server_cached_prompt.n_discarded_prompt = slot.n_discarded_prompt;
slot.server_cached_prompt.n_kept_prompt = slot.n_kept_prompt;
slot.server_cached_prompt.think_tokens = slot.params.think_tokens;
@@ -722,13 +719,10 @@ server_slot* server_context::get_available_slot(const server_task& task) {
if (cache_tokens.empty()) {
continue;
}
bool exclude_think = !cache_tokens.has_mtmd && slot.params.think_tokens.exclude;
std::pair<common_prefix, float> sim;
if (exclude_think) {
auto temp = slot.cache_tokens.get_text_tokens_exclude_think(slot.ctx, slot.params.think_tokens);
server_tokens cache_tokens_exclude_think = server_tokens(temp, false);
temp = task.tokens.get_text_tokens_exclude_think(slot.ctx, slot.params.think_tokens);
server_tokens prompt_tokens_exclude_think = server_tokens(temp, false);
if (slot.params.think_tokens.exclude) {
server_tokens cache_tokens_exclude_think = slot.cache_tokens.get_tokens_exclude_think(slot.ctx, slot.params.think_tokens);
server_tokens prompt_tokens_exclude_think = task.tokens.get_tokens_exclude_think(slot.ctx, slot.params.think_tokens);
sim = calculate_slot_similarity(slot, ctx, cache_tokens_exclude_think, prompt_tokens_exclude_think);
}
else {
@@ -780,13 +774,9 @@ server_slot* server_context::get_available_slot(const server_task& task) {
float f_keep = 0;
size_t cache_token_size = tokens.size();
if (!tokens.empty()) {
bool exclude_think = !tokens.has_mtmd && ret->params.think_tokens.exclude;
if (exclude_think) {
auto temp = tokens.get_text_tokens_exclude_think(ret->ctx, ret->params.think_tokens);
server_tokens cache_exclude_think = server_tokens(temp, false);
temp = task.tokens.get_text_tokens_exclude_think(ret->ctx, ret->params.think_tokens);
server_tokens prompt_exclude_think = server_tokens(temp, false);
if (ret->params.think_tokens.exclude) {
server_tokens cache_exclude_think = tokens.get_tokens_exclude_think(ret->ctx, ret->params.think_tokens);
server_tokens prompt_exclude_think = task.tokens.get_tokens_exclude_think(ret->ctx, ret->params.think_tokens);
cache_token_size = cache_exclude_think.size();
f_keep = calculate_slot_f_keep(*ret, ret->ctx, cache_exclude_think, prompt_exclude_think);
@@ -807,9 +797,6 @@ server_slot* server_context::get_available_slot(const server_task& task) {
// don't update the cache if the slot's context is above cache_ram_n_min
update_cache = update_cache && cache_token_size >= cache_ram_n_min;
// TODO: mtmd does not support prompt cache
update_cache = update_cache && (ret->mctx == nullptr);
LLAMA_LOG_INFO("======== Prompt cache: cache size: %d, n_keep: %d, n_discarded_prompt: %d, cache_ram_n_min: %d, f_keep: %.2f, cache_ram_similarity: %.2f\n",
(int)tokens.size(), ret->n_kept_prompt, ret->n_discarded_prompt, cache_ram_n_min, f_keep, cache_ram_similarity);
if (update_cache) {
@@ -829,7 +816,7 @@ server_slot* server_context::get_available_slot(const server_task& task) {
ret->prompt_load(*prompt_cache, task.tokens);
prompt_cache->update();
ret->cache_tokens = server_tokens(ret->server_cached_prompt.tokens.get_text_tokens(), false); // recover cache tokens
ret->cache_tokens = ret->server_cached_prompt.tokens.clone(); // recover cache tokens
ret->n_discarded_prompt = ret->server_cached_prompt.n_discarded_prompt;
ret->n_kept_prompt = ret->server_cached_prompt.n_kept_prompt;
@@ -1335,11 +1322,14 @@ bool server_context::launch_slot_with_task(server_slot& slot, server_task& task)
// - the model architecture is marked as recurrent or hybrid
//
// TODO: try to make this conditional on the context or the memory module, instead of the model type
// do_checkpoint = do_checkpoint && llama_model_has_recurrent(model);
params_base.do_checkpoint = do_checkpoint;
if (slot.n_buffer != 0) {
LLAMA_LOG_WARN("Recurrent model does not support banned strings.\n");
LLAMA_LOG_WARN("banned strings is not supported by recurrent model, it will be disabled.\n");
}
if (params_base.ctx_shift) {
params_base.ctx_shift = false;
LOG_WARNING("%s\n", "ctx_shift is not supported by recurrent model, it will be disabled");
}
}
{
const auto& stop = data.find("stop");
@@ -1713,18 +1703,18 @@ void server_context::send_error(const int id_task, const int id_multi, const std
{"error", error},
});
server_task_result res;
res.id = id_task;
res.id_multi = id_multi;
res.stop = false;
res.error = true;
res.data = format_error_response(error, type);
queue_results.send(res);
auto res = std::make_unique<server_task_result_error>();
res->id = id_task;
res->id_multi = id_multi;
res->stop = false;
res->error = true;
res->err_type = type;
res->err_msg = error;
queue_results.send(std::move(res));
}
// if multimodal is enabled, send an error and return false
bool server_context::ensure_no_mtmd(const int id_task) {
bool server_context::check_no_mtmd(const int id_task) {
if (mctx) {
int id_multi = 0;
send_error(id_task, id_multi, "This feature is not supported by multimodal", ERROR_TYPE_NOT_SUPPORTED);
@@ -2127,9 +2117,6 @@ void server_context::process_single_task(server_task&& task) {
} break;
case SERVER_TASK_TYPE_SLOT_SAVE:
{
if (!ensure_no_mtmd(task.id)) {
break;
}
int id_slot = task.data.at("id_slot");
server_slot* slot = get_slot_by_id(id_slot);
if (slot == nullptr) {
@@ -2142,7 +2129,9 @@ void server_context::process_single_task(server_task&& task) {
queue_tasks.defer(std::move(task));
break;
}
if (slot->cache_tokens.has_mtmd_data() && !check_no_mtmd(task.id)) {
break;
}
const size_t token_count = slot->cache_tokens.size();
const int64_t t_start = ggml_time_us();
@@ -2171,7 +2160,6 @@ void server_context::process_single_task(server_task&& task) {
} break;
case SERVER_TASK_TYPE_SLOT_RESTORE:
{
if (!ensure_no_mtmd(task.id)) break;
int id_slot = task.data.at("id_slot");
server_slot* slot = get_slot_by_id(id_slot);
if (slot == nullptr) {
@@ -2184,7 +2172,9 @@ void server_context::process_single_task(server_task&& task) {
queue_tasks.defer(std::move(task));
break;
}
if (slot->cache_tokens.has_mtmd_data() && !check_no_mtmd(task.id)) {
break;
}
const int64_t t_start = ggml_time_us();
std::string filename = task.data.at("filename");
@@ -2199,7 +2189,9 @@ void server_context::process_single_task(server_task&& task) {
break;
}
slot->cache_tokens.resize(token_count);
if (mctx) {
slot->cache_tokens.has_mtmd = true;
}
const int64_t t_end = ggml_time_us();
const double t_restore_ms = (t_end - t_start) / 1000.0;
@@ -2220,7 +2212,6 @@ void server_context::process_single_task(server_task&& task) {
} break;
case SERVER_TASK_TYPE_SLOT_ERASE:
{
if (!ensure_no_mtmd(task.id)) break;
int id_slot = task.data.at("id_slot");
server_slot* slot = get_slot_by_id(id_slot);
if (slot == nullptr) {
@@ -2233,7 +2224,9 @@ void server_context::process_single_task(server_task&& task) {
queue_tasks.defer(std::move(task));
break;
}
if (slot->cache_tokens.has_mtmd_data() && !check_no_mtmd(task.id)) {
break;
}
// Erase token cache
const size_t n_erased = slot->cache_tokens.size();
llama_kv_cache_seq_rm(ctx, slot->id, -1, -1);
@@ -2489,13 +2482,59 @@ void server_context::print_tokens(const server_tokens& prompt, const server_toke
}
void server_context::discard_n_kv_and_cache_tokens(llama_context* ctx, server_slot& slot, int32_t n_keep, int32_t n_discard) {
llama_kv_cache_seq_rm(ctx, slot.id, n_keep, n_keep + n_discard);
llama_kv_cache_seq_add(ctx, slot.id, n_keep + n_discard, system_tokens.size() + slot.n_past, -n_discard);
auto kv_keep = slot.cache_tokens.pos_next(n_keep);
auto kv_discard = slot.cache_tokens.pos_next(n_keep + n_discard) - kv_keep;
auto kv_past = slot.cache_tokens.pos_next(slot.n_past);
int32_t pos_min = llama_kv_cache_seq_pos_min(slot.ctx, slot.id);
const auto pos_max = llama_kv_cache_seq_pos_max(slot.ctx, slot.id);
llama_kv_cache_seq_rm(ctx, slot.id, kv_keep, kv_keep + kv_discard);
llama_kv_cache_seq_add(ctx, slot.id, kv_keep + kv_discard, kv_past, -kv_discard);
if (slot.params.cache_prompt) {
slot.cache_tokens.discard_n_tokens(n_keep, n_discard);
}
}
// Returns whether a context shift that keeps the first n_keep tokens and
// discards the following n_discard tokens is safe for this token sequence.
// Text-only sequences can always be shifted. For multimodal sequences a
// LLAMA_TOKEN_NULL placeholder at either cut boundary means the shift would
// split a media (image/audio) chunk, so it must be rejected.
inline static bool tokens_support_context_shift(const server_tokens & tokens, int32_t n_keep,
int32_t n_discard) {
if (!tokens.has_mtmd) {
return true;
}
bool can_shift = true;
if (n_keep > 0 && n_keep <= tokens.n_tokens()) {
// token just before the keep/discard boundary must be real text
can_shift = tokens[n_keep - 1] != LLAMA_TOKEN_NULL;
}
if (n_discard + n_keep > 0 && n_discard + n_keep <= tokens.n_tokens()) {
// token just before the end of the discarded range must be real text
can_shift = can_shift && tokens[n_discard + n_keep - 1] != LLAMA_TOKEN_NULL;
}
return can_shift;
}
// Adjusts n_keep / n_discard in place so that neither context-shift cut
// boundary lands on a LLAMA_TOKEN_NULL placeholder (i.e. inside a
// multimodal media chunk). No-op for text-only sequences.
// Strategy: shrink n_keep backwards past a chunk, and grow n_discard
// forwards past a chunk. The bounds check runs AFTER each step so the
// loop condition is only re-evaluated while the index stays valid.
// NOTE(review): outer guards use n_tokens() while the inner break tests
// use size() — presumably equivalent; confirm against server_tokens.
inline static void adjust_n_to_support_context_shift(const server_tokens & tokens, int32_t & n_keep,
int32_t & n_discard) {
if (!tokens.has_mtmd) {
return;
}
if (n_keep > 0 && n_keep <= tokens.n_tokens()) {
// move the keep boundary left until it sits on a real text token
while (tokens[n_keep - 1] == LLAMA_TOKEN_NULL) {
n_keep--;
if (n_keep<1 || n_keep>tokens.size()) {
break;
}
}
}
if (n_discard + n_keep > 0 && n_discard + n_keep <= tokens.n_tokens()) {
// extend the discard range right until its end sits on a real text token
while (tokens[n_discard + n_keep - 1] == LLAMA_TOKEN_NULL) {
n_discard++;
if (n_discard + n_keep<1 || n_discard + n_keep>tokens.size()) {
break;
}
}
}
}
// convert keep first few and discard next tokens in a to b
void server_context::context_shift_find_n_tokens(llama_context* ctx, const server_tokens& a, const server_tokens& b, int32_t n_keep,
int32_t n_discard, int32_t& n_kept, int32_t& n_discarded, bool exact) {
@@ -2519,7 +2558,10 @@ void server_context::context_shift_prompt(llama_context* ctx, server_slot& slot,
int n_keep = std::max(0, slot.params.n_keep + add_bos_token);
const int n_left = slot.n_ctx - n_keep;
int n_discard = slot.params.n_discard ? slot.params.n_discard : (n_left / 2);
adjust_n_to_support_context_shift(slot.prompt_tokens, n_keep, n_discard);
if (n_discard<=0 || !tokens_support_context_shift(slot.prompt_tokens, n_keep, n_discard)) {
return;
}
int n_discard_prompt = 0;
// we still need to truncate input since we have not discarded enough tokens
while (slot.n_prompt_tokens - slot.n_discarded_prompt >= slot.n_ctx) {
@@ -2598,15 +2640,11 @@ void server_context::context_shift() {
if (!params_base.ctx_shift) {
// this check is redundant (for good)
// we should never get here, because generation should already stopped in process_token()
send_error(slot, "context shift is disabled", ERROR_TYPE_SERVER);
slot.print_timings();
slot.release();
send_error(slot, "context shift is disabled", ERROR_TYPE_SERVER);
continue;
}
if (mctx) {
// we should never reach this because params_base.ctx_shift is automatically disabled if mmproj is loaded
// we don't support ctx_shift because an image chunk may contains multiple tokens
GGML_ABORT("not supported by multimodal");
}
// Shift context
int n_keep = slot.params.n_keep < 0 ? slot.prompt_tokens.size() : slot.params.n_keep;
if (add_bos_token) {
@@ -2614,11 +2652,12 @@ void server_context::context_shift() {
}
n_keep = std::min(slot.n_ctx - 4, n_keep);
const int n_left = (int)system_tokens.size() + slot.n_past - n_keep;
const int n_discard = slot.params.n_discard ? slot.params.n_discard : (n_left / 2);
const int32_t n_left = (int)system_tokens.size() + slot.n_past - n_keep;
int32_t n_discard = slot.params.n_discard ? slot.params.n_discard : (n_left / 2);
int32_t n_kept;
int32_t n_discard_cache;
if (n_discard > 0) {
adjust_n_to_support_context_shift(slot.cache_tokens, n_keep, n_discard);
if (n_discard > 0 && tokens_support_context_shift(slot.cache_tokens, n_keep, n_discard)) {
context_shift_find_n_tokens(ctx, slot.prompt_tokens, slot.cache_tokens, n_keep,
n_discard, n_kept, n_discard_cache);
LOG_INFO("slot context shift", {
@@ -2725,21 +2764,21 @@ void server_context::create_checkpoint_at_interval(server_slot & slot, const gp
if (params_base.do_checkpoint && params_base.ctx_checkpoints_interval > 0) {
auto pos = llama_kv_cache_seq_pos_max(slot.ctx, slot.id);
if (slot.checkpoint_pos + params_base.ctx_checkpoints_interval <= 1 + pos) {
create_checkpoint(slot);
slot.checkpoint_pos = pos;
bool created = create_checkpoint(slot);
if (created) {
slot.checkpoint_pos = pos;
}
}
}
}
void server_context::apply_checkpoint(server_slot & slot) {
const auto pos_min_thold = std::max(0, slot.n_past - 1);
if (!mctx && slot.n_past > 0 && slot.n_past < slot.cache_tokens.n_tokens()) {
llama_pos pos_next = slot.cache_tokens.pos_next(slot.n_past);
const auto pos_min_thold = std::max(0, pos_next - 1);
if (slot.n_past > 0 && slot.n_past < slot.cache_tokens.n_tokens()) {
int32_t pos_min = llama_kv_cache_seq_pos_min(slot.ctx, slot.id);
if (pos_min > pos_min_thold) {
// TODO: support can be added in the future when corresponding vision models get released
GGML_ASSERT(!slot.cache_tokens.has_mtmd);
SLT_WRN(slot, "n_past = %d, slot.prompt.tokens.size() = %d, seq_id = %d, pos_min = %d\n", slot.n_past, (int)slot.cache_tokens.size(), slot.id, pos_min);
// search for a context checkpoint
@@ -2765,8 +2804,10 @@ void server_context::apply_checkpoint(server_slot & slot) {
do_reset = true;
//printf("[DEBUG] `do_reset` was set to `true` after failing to restore a checkpoint");
} else {
slot.n_past = std::min(slot.n_past, std::max(it->pos_min + 1, it->pos_max));
slot.n_past_prompt = std::min(slot.n_past_prompt, std::max(it->pos_min_prompt + 1, it->pos_max_prompt));
slot.n_past = std::min(slot.n_past, std::max(it->pos_min+1, it->pos_max));
slot.n_past = slot.cache_tokens.size_up_to_pos(slot.n_past-1);
slot.n_past_prompt = std::min(slot.n_past_prompt, std::max(it->pos_min_prompt+1, it->pos_max_prompt));
slot.n_past_prompt = slot.prompt_tokens.size_up_to_pos(slot.n_past_prompt-1);
SLT_WRN(slot, "restored context checkpoint took %.2f ms (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", (ggml_time_us() - t_start) / 1000.0, it->pos_min, it->pos_max, (float)checkpoint_size / 1024 / 1024);
}
}
@@ -2794,8 +2835,8 @@ void server_context::apply_checkpoint(server_slot & slot) {
}
}
void server_context::create_checkpoint(server_slot & slot) {
bool do_checkpoint = true;
bool server_context::create_checkpoint(server_slot & slot) {
bool do_checkpoint = !slot.image_just_processed;
int32_t pos_min = llama_kv_cache_seq_pos_min(slot.ctx, slot.id);
const auto pos_max = llama_kv_cache_seq_pos_max(slot.ctx, slot.id);
@@ -2833,6 +2874,7 @@ void server_context::create_checkpoint(server_slot & slot) {
(int)slot.server_cached_prompt.checkpoints.size(), params_base.ctx_checkpoints_n, cur.pos_min, cur.pos_max, (float)cur.data.size() / 1024 / 1024,
(ggml_time_us() - t_start) / 1000.0);
}
return do_checkpoint;
}
void server_context::batch_pending_prompt(const int32_t n_ubatch, const int32_t n_batch, int32_t & batch_type) {
@@ -2935,12 +2977,6 @@ void server_context::batch_pending_prompt(const int32_t n_ubatch, const int32_t
slot.release();
continue;
}
if (mctx) {
// we should never reach this because params.ctx_shift is automatically disabled if mmproj is loaded
// we don't support ctx_shift because an image chunk may contains multiple tokens
GGML_ABORT("not supported by multimodal");
}
context_shift_prompt(ctx, slot);
slot.truncated = true;
LOG_VERBOSE("input truncated", {
@@ -3100,7 +3136,7 @@ void server_context::batch_pending_prompt(const int32_t n_ubatch, const int32_t
slot.n_past += n_tokens_out;
slot.n_past_prompt += n_tokens_out;
slot.n_prompt_tokens_processed += n_tokens_out;
slot.image_just_processed = true; // do not checkpoint right after an image chunk
}
@@ -3137,7 +3173,7 @@ void server_context::batch_pending_prompt(const int32_t n_ubatch, const int32_t
slot_npast++;
slot.n_past_prompt++;
slot.n_past++;
slot.do_checkpoint = false;
slot.image_just_processed = false;
if (params_base.do_checkpoint && slot.n_prompt_tokens - slot.n_past_prompt == params_base.ctx_checkpoints_tolerance) {
slot.do_checkpoint = true;
break;
@@ -3286,6 +3322,8 @@ void server_context::speculative_decoding_accept() {
if (slot.n_buffer == 0 || !params_base.can_ban_phrases) {
if (!process_token(result, slot)) {
// release slot because of stop condition
slot.cache_tokens.push_back(slot.sampled);
slot.n_past++;
send_final_response(slot);
release_slot_after_final_response(slot);
break;
@@ -3338,6 +3376,8 @@ void server_context::send_token_results(completion_token_outputs& results, serve
if (slot.stopped_limit && !slot.stopped_eos && !slot.stopped_word) {
continue;
}
slot.cache_tokens.push_back(slot.sampled);
slot.n_past++;
send_final_response(slot);
release_slot_after_final_response(slot);
released = true;
@@ -3349,6 +3389,8 @@ void server_context::send_token_results(completion_token_outputs& results, serve
}
if (!released && slot.stopped_limit && !slot.stopped_eos && !slot.stopped_word) {
slot.cache_tokens.push_back(slot.sampled);
slot.n_past++;
send_final_response(slot);
release_slot_after_final_response(slot);
}
@@ -3381,10 +3423,10 @@ inline int32_t check_ban_phrase(server_slot& slot) {
if (start != std::string::npos) {
if (start < best_start) {
best_start = start;
found = true;
}
found = true;
}
}
}
// 2. Check regex
for (const auto& pattern : slot.ban_regex) {
@@ -3424,8 +3466,8 @@ inline int32_t check_ban_phrase(server_slot& slot) {
if (best_start >= token_offsets[i] && best_start < token_offsets[i] + len) {
token_idx = (int32_t)i;
break;
}
}
}
}
if (token_idx != -1) {
int32_t abs_pos = slot.n_past - (int32_t)slot.token_buffer.size() + 1 + token_idx;
@@ -3449,7 +3491,7 @@ inline void rewind_context(server_slot& slot, int32_t ban_pos) {
llama_token banned_tok = result->tok;
if (n == 0) {
LLAMA_LOG_INFO("Banned pattern detected at pos %d. Banning token %d ('%s') and rewinding.\n",
LLAMA_LOG_DEBUG("Banned pattern detected at pos %d. Banning token %d ('%s') and rewinding.\n",
ban_pos, banned_tok, result->text_to_send.c_str());
}
@@ -3462,11 +3504,11 @@ inline void rewind_context(server_slot& slot, int32_t ban_pos) {
}
int32_t n_rewind_total = (slot.n_past + 1) - ban_pos;
size_t n_keep_cache = 0;
if (ban_pos > 0) {
n_keep_cache = (size_t)(ban_pos - 1);
}
}
if (n_keep_cache > slot.cache_tokens.size()) {
n_keep_cache = slot.cache_tokens.size();
@@ -3516,7 +3558,7 @@ void server_context::buffer_and_check_string_ban(server_slot & slot, completion_
int32_t n_keep_buffer = ban_pos - buffer_start_pos;
if (n_keep_buffer < 0) n_keep_buffer = 0;
n_rewind = (int32_t)slot.token_buffer.size() - n_keep_buffer;
}
}
}
bool allow_rewind = true;
@@ -3559,16 +3601,16 @@ void server_context::buffer_and_check_string_ban(server_slot & slot, completion_
send_token_results(slot.token_buffer, slot, 1);
}
if (slot.sparams.adaptive_target >= 0.0f) {
sent_results = true;
}
sent_results = true;
}
}
else {
// buffer the result, wait for more tokens to validate string
slot.sampled = result.tok;
}
if (slot.sparams.adaptive_target >= 0.0f) {
slot.ctx_sampling->n_rewind = sent_results ? -1 : n_rewind;
}
slot.ctx_sampling->n_rewind = sent_results ? -1 : n_rewind;
}
}
void server_context::process_batch_tokens(int32_t & n_batch) {

View File

@@ -110,6 +110,7 @@ struct server_slot {
size_t checkpoint_pos = 0;
bool do_checkpoint = false;
bool image_just_processed = false;
// sampling
llama_token sampled; // in speculative mode, this is the last accepted token
@@ -302,7 +303,7 @@ struct server_context {
void send_error(const int id_task, const int id_multi, const std::string& error, const enum error_type type = ERROR_TYPE_SERVER);
// if multimodal is enabled, send an error and return false
bool ensure_no_mtmd(const int id_task);
bool check_no_mtmd(const int id_task);
void send_partial_response(server_slot& slot, completion_token_output tkn);
@@ -363,7 +364,7 @@ struct server_context {
// Re-aggregates all active vectors and updates the model state
bool apply_control_vectors_internal();
void create_checkpoint(server_slot & slot);
bool create_checkpoint(server_slot & slot);
void apply_checkpoint(server_slot & slot);

View File

@@ -1081,12 +1081,12 @@ bool server_prompt_cache::load(server_prompt& prompt, const server_tokens& token
server_tokens prompt_tokens;
server_tokens tokens_new_ex;
if (think_tokens.exclude) {
prompt_tokens = server_tokens(prompt.tokens.get_text_tokens_exclude_think(ctx, think_tokens), false);
tokens_new_ex = server_tokens(tokens_new.get_text_tokens_exclude_think(ctx, think_tokens), false);
prompt_tokens = prompt.tokens.get_tokens_exclude_think(ctx, think_tokens);
tokens_new_ex = tokens_new.get_tokens_exclude_think(ctx, think_tokens);
}
else {
prompt_tokens = std::move(prompt.tokens); //server_tokens(prompt.tokens.get_text_tokens(), false);
tokens_new_ex = server_tokens(tokens_new.get_text_tokens(), false);
prompt_tokens = std::move(prompt.tokens);
tokens_new_ex = tokens_new.clone();
}
const auto lcp_best = prompt_tokens.get_common_prefix(ctx, tokens_new_ex);
float f_keep_best = float(lcp_best.second) / prompt_tokens.size();
@@ -1099,7 +1099,7 @@ bool server_prompt_cache::load(server_prompt& prompt, const server_tokens& token
for (auto it = states.begin(); it != states.end(); ++it) {
server_tokens tokens;
if (think_tokens.exclude) {
tokens = server_tokens(it->tokens.get_text_tokens_exclude_think(ctx, think_tokens), false);
tokens = it->tokens.get_tokens_exclude_think(ctx, think_tokens);
}
else {
tokens = std::move(it->tokens);
@@ -1136,7 +1136,7 @@ bool server_prompt_cache::load(server_prompt& prompt, const server_tokens& token
server_prompt* server_prompt_cache::alloc(const server_prompt& prompt, size_t state_size) {
for (auto it = states.begin(); it != states.end();) {
auto tokens_ctx_shift = server_tokens(prompt.tokens.get_text_tokens(), false); // copy cache tokens
auto tokens_ctx_shift = prompt.tokens.clone(); // copy cache tokens
tokens_ctx_shift.discard_n_tokens(prompt.n_kept_prompt, prompt.n_discarded_prompt);
auto prefix = it->tokens.get_common_prefix(ctx, tokens_ctx_shift);
const size_t len = prefix.first;
@@ -1177,7 +1177,7 @@ server_prompt* server_prompt_cache::alloc(const server_prompt& prompt, size_t st
// TODO: for some reason we can't copy server_tokens, so we have to do this workaround
auto& cur = states.emplace_back();
cur = {
/*.tokens =*/ server_tokens(prompt.tokens.get_text_tokens(), false),
/*.tokens =*/ prompt.tokens.clone(),
/*.n_keep =*/ prompt.n_kept_prompt,
/*.n_discarded_prompt =*/ prompt.n_discarded_prompt,
/*.think_tokens =*/ prompt.think_tokens,

View File

@@ -371,6 +371,16 @@ struct server_prompt {
return tokens.size();
}
// Deep-copies this cached prompt. tokens.clone() is used so the copy owns
// its own server_tokens (including any multimodal media mappings) instead
// of sharing state; the remaining members are copied by value.
server_prompt clone() const {
return server_prompt{
tokens.clone(),
n_kept_prompt,
n_discarded_prompt,
think_tokens,
data,
checkpoints
};
}
};
struct server_prompt_cache {

View File

@@ -125,7 +125,7 @@ ggml_cgraph * llm_build_context::build_k_shift() {
GGML_ASSERT(kv_self.size == n_ctx);
const auto & rope_type_shift = hparams.rope_type == LLAMA_ROPE_TYPE_MROPE
const auto & rope_type_shift = hparams.rope_type == LLAMA_ROPE_TYPE_MROPE || hparams.rope_type == LLAMA_ROPE_TYPE_IMROPE
// @ngxson : this is a workaround
// for M-RoPE, we want to rotate the whole vector when doing KV shift
// a normal RoPE should work, we just need to use the correct ordering

View File

@@ -4058,12 +4058,18 @@ static void llama_kv_cache_defrag_internal(struct llama_context & lctx) {
//LLAMA_LOG_INFO("(tmp log) KV defrag time: %.3f ms\n", (t_end - t_start)/1000.0);
}
// Whether a KV-cache K-shift can be applied for this context's model.
// Returns false for MLA-based architectures and for the IMROPE rope type.
static bool get_can_shift(struct llama_context & lctx) {
const auto & model = lctx.model;
if (model.arch == LLM_ARCH_DEEPSEEK2 || model.arch == LLM_ARCH_GLM_DSA) {
// not supported due to MLA
return false;
}
return model.hparams.rope_type != LLAMA_ROPE_TYPE_IMROPE;
}
static int32_t llama_kv_cache_update_internal(struct llama_context & lctx) {
bool need_reserve = false;
// apply K-shift if needed
if (lctx.model.hparams.rope_type != LLAMA_ROPE_TYPE_NONE && lctx.kv_self.has_shift) {
if (lctx.model.arch == LLM_ARCH_DEEPSEEK2 || lctx.model.arch == LLM_ARCH_GLM_DSA) { // not supported due to MLA
if (!get_can_shift(lctx)) {
return 1;
}