Fix prompt tokenization issue during prompt processing (#1008)

* Find common tokens between prompt and cache
Fix wrong context size usage for mtmd
Use start position of common part
server: handle context shift

* Add size check for inexact match

* Change

---------

Co-authored-by: firecoperana <firecoperana>
Author: firecoperana
Date: 2025-11-26 03:34:26 -06:00
Committed by: GitHub
parent 9337229274
commit 0b126b2ca6
2 changed files with 552 additions and 176 deletions
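
The core of the fix: when generated output is re-tokenized as part of the next request, the same text can split into a different number of tokens, so a token-by-token prefix match against the cache stops too early. The prefix is therefore matched on the detokenized text and reported as a pair of lengths, one per side. A minimal standalone sketch of that idea (illustrative names, plain strings instead of llama tokens):

    #include <algorithm>
    #include <cstdio>
    #include <string>
    #include <utility>
    #include <vector>

    using toks = std::vector<std::string>; // stand-in for detokenized tokens

    // string-level common prefix, mapped back to whole-token counts per side
    static std::pair<size_t, size_t> common_prefix_by_text(const toks & cache, const toks & prompt) {
        std::string a, b;
        for (const auto & t : cache)  a += t;
        for (const auto & t : prompt) b += t;

        // character-level common prefix of the two detokenized strings
        size_t nc = 0;
        const size_t lim = std::min(a.size(), b.size());
        while (nc < lim && a[nc] == b[nc]) ++nc;

        // count how many whole tokens on each side fit inside that prefix
        auto count_tokens = [nc](const toks & v) {
            size_t used = 0, n = 0;
            while (n < v.size() && used + v[n].size() <= nc) used += v[n++].size();
            return n;
        };
        return { count_tokens(cache), count_tokens(prompt) };
    }

    int main() {
        // identical text, tokenized differently on the two sides
        toks cache  = { "Hello", " wor", "ld", "!" };
        toks prompt = { "Hello", " world", "!", " How", " are", " you" };
        auto p = common_prefix_by_text(cache, prompt);
        std::printf("common prefix: %zu cache tokens, %zu prompt tokens\n", p.first, p.second);
        // -> 4 cache tokens and 3 prompt tokens cover the same text "Hello world!"
    }

The real helpers (get_common_prefix, get_common_prefix_first_n, get_tokens_similarity) need the vocab to detokenize, which is why the diff below threads a llama_context into the prompt cache constructor and into every call site.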


@@ -608,10 +608,11 @@ struct server_prompt_checkpoint {
}
};
struct server_prompt {
server_tokens tokens;
int n_keep;
int n_discarded;
int n_kept_prompt;
int n_discarded_prompt;
std::vector<uint8_t> data;
@@ -633,7 +634,8 @@ struct server_prompt {
};
struct server_prompt_cache {
server_prompt_cache(int32_t limit_size_mib, size_t limit_tokens) {
server_prompt_cache(llama_context * ctx, int32_t limit_size_mib, size_t limit_tokens) {
this->ctx = ctx;
this->limit_size = 1024ull * 1024ull * (limit_size_mib < 0 ? 0 : limit_size_mib);
this->limit_tokens = limit_tokens;
}
@@ -645,7 +647,7 @@ struct server_prompt_cache {
// in tokens, 0 = no limit
size_t limit_tokens = 0;
llama_context* ctx;
size_t size() const {
size_t res = 0;
@@ -662,18 +664,18 @@ struct server_prompt_cache {
for (const auto& state : states) {
res += state.n_tokens();
}
return res;
}
server_prompt* alloc(const server_prompt& prompt, size_t state_size) {
for (auto it = states.begin(); it != states.end();) {
auto tokens_ctx_shift = server_tokens(prompt.tokens.get_text_tokens(), false); // copy cache tokens
tokens_ctx_shift.discard_n_tokens(prompt.n_keep, prompt.n_discarded);
const size_t len = it->tokens.get_common_prefix(tokens_ctx_shift);
tokens_ctx_shift.discard_n_tokens(prompt.n_kept_prompt, prompt.n_discarded_prompt);
auto prefix = it->tokens.get_common_prefix(ctx, tokens_ctx_shift);
const size_t len = prefix.first;
const size_t len_prompt = prefix.second;
// first check if the current state is contained fully in the cache
if (len == tokens_ctx_shift.size()) {
if (len_prompt == tokens_ctx_shift.size()) {
LLAMA_LOG_INFO("%s", " - prompt is already in the cache, skipping\n");
return nullptr;
}
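
In alloc() above, only the prompt-side length of the pair decides whether a cached state already covers the whole (possibly context-shifted) prompt. The discard_n_tokens(n_keep, n_discard) call re-creates the shape a context shift leaves behind; a sketch of the assumed semantics (keep the first n_keep tokens, drop the next n_discard):

    #include <algorithm>
    #include <cstdio>
    #include <vector>

    // hypothetical stand-in for server_tokens::discard_n_tokens(n_keep, n_discard)
    static void discard_n_tokens(std::vector<int> & v, int n_keep, int n_discard) {
        n_discard = std::min<int>(n_discard, (int) v.size() - n_keep);
        if (n_keep < 0 || n_discard <= 0) return;
        v.erase(v.begin() + n_keep, v.begin() + n_keep + n_discard);
    }

    int main() {
        std::vector<int> prompt = {1, 2, 3, 4, 5, 6, 7, 8};
        discard_n_tokens(prompt, /*n_keep=*/2, /*n_discard=*/3); // a shift kept 2, dropped the next 3
        for (int t : prompt) std::printf("%d ", t);              // 1 2 6 7 8
        std::printf("\n");
        // alloc() skips inserting a new state when the string-level prefix against an
        // existing entry covers all of this shifted prompt (len_prompt == tokens_ctx_shift.size()).
    }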
@@ -709,8 +711,8 @@ struct server_prompt_cache {
auto& cur = states.emplace_back();
cur = {
/*.tokens =*/ server_tokens(prompt.tokens.get_text_tokens(), false),
/*.n_keep =*/ prompt.n_keep,
/*.n_discarded =*/ prompt.n_discarded,
/*.n_keep =*/ prompt.n_kept_prompt,
/*.n_discarded_prompt =*/ prompt.n_discarded_prompt,
/*.data =*/ std::move(state_data),
/*.checkpoints =*/ prompt.checkpoints,
};
@@ -719,19 +721,19 @@ struct server_prompt_cache {
}
bool load(server_prompt& prompt, const server_tokens& tokens_new, llama_context* ctx, int32_t id_slot) {
const int lcp_best = prompt.tokens.get_common_prefix(tokens_new);
const auto lcp_best = prompt.tokens.get_common_prefix(ctx, tokens_new);
float f_keep_best = float(lcp_best) / prompt.tokens.size();
float sim_best = prompt.tokens.get_tokens_similarity(tokens_new, prompt.n_keep, prompt.n_discarded);
LLAMA_LOG_INFO(" - looking for better prompt, base f_keep = %.3f, sim = %.3f, n_keep = %d, n_discarded = %d\n", f_keep_best, sim_best, prompt.n_keep, prompt.n_discarded);
float f_keep_best = float(lcp_best.second) / prompt.tokens.size();
float sim_best = prompt.tokens.get_tokens_similarity(ctx, tokens_new, prompt.n_kept_prompt, prompt.n_discarded_prompt);
LLAMA_LOG_INFO(" - looking for better prompt, base f_keep = %.3f, sim = %.3f, n_keep = %d, n_discarded_prompt = %d\n", f_keep_best, sim_best, prompt.n_kept_prompt, prompt.n_discarded_prompt);
auto it_best = states.end();
// find the most similar cached prompt, that would also preserve the most context
for (auto it = states.begin(); it != states.end(); ++it) {
const int lcp_cur = it->tokens.get_common_prefix(tokens_new);
const float f_keep_cur = float(lcp_cur) / it->tokens.size();
const float sim_cur = it->tokens.get_tokens_similarity(tokens_new, it->n_keep, it->n_discarded);
const auto lcp_cur = it->tokens.get_common_prefix(ctx, tokens_new);
const float f_keep_cur = float(lcp_cur.first) / it->tokens.size();
const float sim_cur = it->tokens.get_tokens_similarity(ctx, tokens_new, it->n_kept_prompt, it->n_discarded_prompt);
if (sim_best < sim_cur) {
f_keep_best = f_keep_cur;
sim_best = sim_cur;
@@ -740,7 +742,7 @@ struct server_prompt_cache {
}
if (it_best != states.end()) {
LLAMA_LOG_INFO(" - found better prompt with f_keep = %.3f, sim = %.3f, n_keep = %d, n_discarded = %d\n", f_keep_best, sim_best, it_best->n_keep, it_best->n_discarded);
LLAMA_LOG_INFO(" - found better prompt with f_keep = %.3f, sim = %.3f, n_keep = %d, n_discarded_prompt = %d\n", f_keep_best, sim_best, it_best->n_kept_prompt, it_best->n_discarded_prompt);
const size_t size = it_best->data.size();
const size_t n = llama_state_seq_set_data(ctx, it_best->data.data(), size, id_slot);
if (n != size) {
@@ -798,7 +800,7 @@ struct server_prompt_cache {
for (const auto& state : states) {
LLAMA_LOG_INFO(" - prompt %p: %7d tokens, %7d discarded, checkpoints: %2zu, %9.3f MiB\n",
(const void*)&state, state.n_tokens(), state.n_discarded, state.checkpoints.size(), state.size() / (1024.0 * 1024.0));
(const void*)&state, state.n_tokens(), state.n_discarded_prompt, state.checkpoints.size(), state.size() / (1024.0 * 1024.0));
}
}
};
@@ -823,10 +825,12 @@ struct server_slot {
// generation props
int32_t n_ctx = 0; // context size per slot
int32_t n_past = 0;
int32_t n_past_prompt = 0;
int32_t n_decoded = 0;
int32_t n_remaining = -1;
int32_t n_discarded = 0;
int32_t n_kept = 0;
int32_t n_discarded_prompt = 0;
int32_t n_kept_prompt = 0;
int32_t i_batch = -1;
int32_t n_predict = -1; // TODO: disambiguate from params.n_predict
@@ -989,6 +993,7 @@ struct server_slot {
}
}
json get_formated_timings() const {
return json {
{"prompt_n", n_prompt_tokens_processed},
@@ -1693,7 +1698,7 @@ struct server_context {
}
LLAMA_LOG_INFO("%s", "use `--cache-ram 0` to disable the prompt cache\n");
// only apply ram size limit. No token limit for now.
prompt_cache = std::make_unique<server_prompt_cache>(params.cache_ram_mib, 0);
prompt_cache = std::make_unique<server_prompt_cache>(ctx, params.cache_ram_mib, 0);
}
else {
LLAMA_LOG_INFO("%s", "prompt cache is disabled - use `--cache-ram N` to enable it\n");
@@ -1788,13 +1793,12 @@ struct server_context {
continue;
}
// length of the Longest Common Prefix between the current slot's prompt and the input prompt
// print_tokens(task.tokens, cache_tokens);
size_t lcp_len = cache_tokens.get_common_prefix(task.tokens);
auto lcp_len = cache_tokens.get_common_prefix(slot.ctx, task.tokens);
// fraction of the Longest Common Prefix length with respect to the input prompt and cached prompt length
float sim_cur = cache_tokens.get_tokens_similarity(task.tokens, 0, 0);
float sim_cur = cache_tokens.get_tokens_similarity(slot.ctx, task.tokens, 0, 0);
// handle context shift
if (slot.ga_n == 1 && slot.n_discarded > 0 && task.tokens.size()>=slot.n_ctx) {
float sim_cur_ctx_shift = cache_tokens.get_tokens_similarity(task.tokens, slot.n_kept, slot.n_discarded);
if (slot.ga_n == 1 && slot.n_discarded_prompt > 0 && task.tokens.size() >= slot.n_ctx) {
float sim_cur_ctx_shift = cache_tokens.get_tokens_similarity(slot.ctx, task.tokens, slot.n_kept_prompt, slot.n_discarded_prompt);
if (sim_cur_ctx_shift > sim_cur) {
sim_cur = sim_cur_ctx_shift;
}
@@ -1803,7 +1807,7 @@ struct server_context {
// select the current slot if the criteria match
if (sim_cur > sim_best && sim_cur > slot_prompt_similarity) {
sim_best = sim_cur;
max_lcp_len = lcp_len;
max_lcp_len = lcp_len.first;
ret = &slot;
}
}
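
Slot selection now also scores the new request against the view a previous context shift left behind, so a slot whose KV cache was shifted is not unfairly penalized. A standalone sketch (the similarity definition here, shared prefix over the longer length, is an assumption; the real get_tokens_similarity also matches at string level):

    #include <algorithm>
    #include <cstdio>
    #include <vector>

    // assumed similarity: shared token prefix divided by the longer sequence
    static float similarity(const std::vector<int> & a, const std::vector<int> & b) {
        size_t n = 0;
        while (n < a.size() && n < b.size() && a[n] == b[n]) ++n;
        if (a.empty() || b.empty()) return 0.0f;
        return float(n) / float(std::max(a.size(), b.size()));
    }

    // re-create the view a context shift left behind: keep n_keep, drop the next n_discard
    static std::vector<int> shifted(std::vector<int> v, int n_keep, int n_discard) {
        n_discard = std::min<int>(n_discard, (int) v.size() - n_keep);
        if (n_discard > 0) v.erase(v.begin() + n_keep, v.begin() + n_keep + n_discard);
        return v;
    }

    int main() {
        std::vector<int> cache  = {1, 2, 7, 8, 9};                   // slot state after a shift
        std::vector<int> prompt = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10};   // full re-sent prompt

        float sim       = similarity(cache, prompt);                 // 0.20 - looks like a poor slot
        float sim_shift = similarity(cache, shifted(prompt, 2, 4));  // 0.83 - actually a good match
        std::printf("sim = %.2f, sim with shift replayed = %.2f\n", sim, std::max(sim, sim_shift));
    }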
@@ -1842,11 +1846,11 @@ struct server_context {
const auto& tokens = ret->cache_tokens;
float f_keep = 0.0f;
if (!tokens.empty()) {
if (ret->ga_n == 1 && ret->n_discarded > 0 && task.tokens.size() >= ret->n_ctx) {
f_keep = tokens.get_cached_tokens_similarity(task.tokens, ret->params.n_keep + add_bos_token, ret->n_discarded);
if (ret->ga_n == 1 && ret->n_discarded_prompt > 0 && task.tokens.size() >= ret->n_ctx) {
f_keep = tokens.get_cached_tokens_similarity(ret->ctx, task.tokens, ret->params.n_keep + add_bos_token, ret->n_discarded_prompt);
}
else {
f_keep = tokens.get_cached_tokens_similarity(task.tokens, 0, 0);
f_keep = tokens.get_cached_tokens_similarity(ret->ctx, task.tokens, 0, 0);
}
// if we are about to lose a large portion of the existing context - save it in the prompt cache
if (f_keep < cache_ram_similarity) {
@@ -1863,14 +1867,14 @@ struct server_context {
// TODO: mtmd does not support prompt cache
update_cache = update_cache && (ret->mctx == nullptr);
LLAMA_LOG_INFO("prompt cache: cache size: %d, n_keep: %d, n_discarded: %d, cache_ram_n_min: %d, f_keep: %.2f, cache_ram_similarity: %.2f\n",
(int)tokens.size(), ret->n_kept, ret->n_discarded, cache_ram_n_min, f_keep, cache_ram_similarity);
LLAMA_LOG_INFO("======== Prompt cache: cache size: %d, n_keep: %d, n_discarded_prompt: %d, cache_ram_n_min: %d, f_keep: %.2f, cache_ram_similarity: %.2f\n",
(int)tokens.size(), ret->n_kept_prompt, ret->n_discarded_prompt, cache_ram_n_min, f_keep, cache_ram_similarity);
if (update_cache) {
const int64_t t_start = ggml_time_us();
LLAMA_LOG_INFO("updating prompt cache\n");
ret->server_cached_prompt.tokens = server_tokens(tokens.get_text_tokens(), false); // copy cache tokens
ret->server_cached_prompt.n_discarded = ret->n_discarded;
ret->server_cached_prompt.n_keep = ret->n_kept;
ret->server_cached_prompt.n_discarded_prompt = ret->n_discarded_prompt;
ret->server_cached_prompt.n_kept_prompt = ret->n_kept_prompt;
ret->prompt_save(*prompt_cache);
LLAMA_LOG_INFO("prompt cache save took %.2f ms\n", (ggml_time_us() - t_start) / 1000.0);
@@ -1879,15 +1883,15 @@ struct server_context {
if (prompt_cache && !prompt_cache->states.empty()) {
const int64_t t_start = ggml_time_us();
ret->server_cached_prompt.tokens = server_tokens(tokens.get_text_tokens(), false); // copy cache tokens
ret->server_cached_prompt.n_discarded = ret->n_discarded;
ret->server_cached_prompt.n_keep = ret->n_kept;
ret->server_cached_prompt.n_discarded_prompt = ret->n_discarded_prompt;
ret->server_cached_prompt.n_kept_prompt = ret->n_kept_prompt;
ret->prompt_load(*prompt_cache, task.tokens);
prompt_cache->update();
ret->cache_tokens = server_tokens(ret->server_cached_prompt.tokens.get_text_tokens(), false); // recover cache tokens
ret->n_discarded = ret->server_cached_prompt.n_discarded;
ret->n_kept = ret->server_cached_prompt.n_keep;
ret->n_discarded_prompt = ret->server_cached_prompt.n_discarded_prompt;
ret->n_kept_prompt = ret->server_cached_prompt.n_kept_prompt;
LLAMA_LOG_INFO("prompt cache load took %.2f ms\n", (ggml_time_us() - t_start) / 1000.0);
}
@@ -3093,9 +3097,14 @@ struct server_context {
queue_results.send(result);
}
void print_tokens(const server_tokens & prompt, const server_tokens& cache) {
LLAMA_LOG_INFO( "prompt: %s\n", prompt.detokenize(ctx, true).c_str());
LLAMA_LOG_INFO( "cache: %s\n", cache.detokenize(ctx, true).c_str());
void print_tokens(const server_tokens & prompt, const server_tokens & cache, size_t start1 = 0, size_t start2 = 0, size_t length = 10) {
if (cache.size() > start2) {
LLAMA_LOG_INFO("cache : %s\n", cache.detokenize(ctx, true, start2, length).c_str());
}
if (prompt.size() > start1) {
LLAMA_LOG_INFO("prompt: %s\n", prompt.detokenize(ctx, true, start1, length).c_str());
}
}
void discard_n_kv_and_cache_tokens(llama_context* ctx, server_slot& slot, int32_t n_keep, int32_t n_discard) {
@@ -3106,6 +3115,60 @@ struct server_context {
}
}
// map the "keep the first n_keep, discard the next n_discard" split from sequence a onto sequence b
void context_shift_find_n_tokens(llama_context* ctx, const server_tokens& a, const server_tokens& b, int32_t n_keep,
int32_t n_discard, int32_t& n_kept, int32_t& n_discarded, bool exact = false) {
common_prefix ctx_keep_prefix = a.get_common_prefix_first_n(ctx, b, n_keep, exact);
common_prefix ctx_total_discard_prefix = a.get_common_prefix_first_n(ctx, b, n_discard + n_keep, exact);
// the offsets are only meaningful if there are enough common tokens
int32_t discard_offset = ctx_total_discard_prefix.first - (n_discard + n_keep);
int32_t keep_offset = ctx_keep_prefix.first - n_keep;
n_kept = ctx_keep_prefix.second - keep_offset;
n_discarded = ctx_total_discard_prefix.second - ctx_keep_prefix.second - discard_offset;
if (n_kept < 0) {
n_kept = n_keep;
}
if (n_discarded < 0) {
n_discarded = n_discard;
}
}
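
The arithmetic above re-derives the keep/discard counts on the cache side, because the cache may tokenize the same text into a different number of tokens than the prompt. A worked sketch with stubbed prefix results, assuming .first counts prompt tokens of a and .second the matching cache tokens of b, mirroring the alloc() usage earlier in this diff:

    #include <cstdio>
    #include <utility>

    int main() {
        const int n_keep = 4, n_discard = 6;    // counts defined on the prompt tokenization

        // stubbed results of a.get_common_prefix_first_n(ctx, b, n, exact)
        std::pair<int, int> keep_prefix    = {4, 5};    // first 4 prompt tokens match 5 cache tokens
        std::pair<int, int> discard_prefix = {10, 12};  // first 10 prompt tokens match 12 cache tokens

        const int keep_offset    = keep_prefix.first    - n_keep;               // 0
        const int discard_offset = discard_prefix.first - (n_discard + n_keep); // 0

        int n_kept      = keep_prefix.second - keep_offset;                             // 5
        int n_discarded = discard_prefix.second - keep_prefix.second - discard_offset;  // 7

        if (n_kept < 0)      n_kept      = n_keep;      // fall back to the raw prompt-side counts
        if (n_discarded < 0) n_discarded = n_discard;

        std::printf("cache side: keep %d tokens, discard the next %d\n", n_kept, n_discarded);
    }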
void context_shift_prompt(llama_context* ctx, server_slot& slot, bool exact = false) {
//server_tokens prompt_tokens = std::move(slot.prompt_tokens);
int n_keep = std::max(0, slot.params.n_keep + add_bos_token);
const int n_left = slot.n_ctx - n_keep;
int n_discard = slot.params.n_discard ? slot.params.n_discard : (n_left / 2);
int n_discard_prompt = 0;
// we still need to truncate input since we have not discarded enough tokens
while (slot.n_prompt_tokens - slot.n_discarded_prompt >= slot.n_ctx) {
slot.n_discarded_prompt = slot.n_discarded_prompt + n_discard;
n_discard_prompt = n_discard_prompt + n_discard;
}
// Handle mistokenization between prompt and cache during context shift
//
int32_t n_discard_cache = n_discard_prompt;
int32_t n_kept = n_keep;
slot.prompt_tokens.discard_n_tokens(n_keep, slot.n_discarded_prompt - n_discard_prompt);
if (n_discard_prompt > 0) {
context_shift_find_n_tokens(ctx, slot.prompt_tokens, slot.cache_tokens, n_keep,
n_discard, n_kept, n_discard_cache, exact);
}
int n_discard_cache_max = std::max((int32_t)slot.cache_tokens.size() - n_kept, 0);
n_discard_cache = std::min(n_discard_cache, n_discard_cache_max);
// discard matching tokens from cache and kv cache to avoid reprocessing the prompt
if (n_discard_cache > 0) {
discard_n_kv_and_cache_tokens(ctx, slot, n_kept, n_discard_cache);
}
// discard extra tokens from prompts
slot.n_kept_prompt = n_keep;
slot.prompt_tokens.discard_n_tokens(n_keep, n_discard_prompt);
slot.n_prompt_tokens = slot.prompt_tokens.size();
}
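
The while loop above decides how many prompt tokens must go before anything touches the cache: n_discard-sized chunks are dropped (logically) after the first n_keep tokens until the remaining prompt fits the slot context. The same arithmetic in isolation:

    #include <cstdio>

    int main() {
        const int n_ctx = 4096, n_keep = 32;
        const int n_left = n_ctx - n_keep;
        const int n_discard = n_left / 2;       // default when params.n_discard == 0

        int n_prompt_tokens = 9000;             // incoming prompt is larger than the context
        int n_discarded_prompt = 0;
        while (n_prompt_tokens - n_discarded_prompt >= n_ctx) {
            n_discarded_prompt += n_discard;
        }
        std::printf("discard %d of %d prompt tokens (keep the first %d)\n",
                    n_discarded_prompt, n_prompt_tokens, n_keep);
        // the cache-side discard count is then derived via context_shift_find_n_tokens()
        // and clamped to the number of cache tokens that can actually be dropped
    }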
void update_slots() {
if (system_need_update) {
system_prompt_update();
@@ -3189,23 +3252,29 @@ struct server_context {
const int n_left = (int) system_tokens.size() + slot.n_past - n_keep;
const int n_discard = slot.params.n_discard ? slot.params.n_discard : (n_left / 2);
int32_t n_kept;
int32_t n_discard_cache;
if (n_discard > 0) {
context_shift_find_n_tokens(ctx, slot.prompt_tokens, slot.cache_tokens, n_keep,
n_discard, n_kept, n_discard_cache);
LOG_INFO("slot context shift", {
{"id_slot", slot.id},
{"id_task", slot.id_task},
{"n_keep", n_keep},
{"n_left", n_left},
{"n_discard", n_discard},
{"n_ctx", n_ctx},
{"n_past", slot.n_past},
{"n_system_tokens", system_tokens.size()},
{"n_cache_tokens", slot.cache_tokens.size()}
});
slot.n_discarded_prompt = slot.n_discarded_prompt + n_discard;
slot.n_kept_prompt = n_keep;
discard_n_kv_and_cache_tokens(ctx, slot, n_kept, n_discard_cache);
slot.n_past -= n_discard_cache;
slot.truncated = true;
}
LOG_INFO("slot context shift", {
{"id_slot", slot.id},
{"id_task", slot.id_task},
{"n_keep", n_keep},
{"n_left", n_left},
{"n_discard", n_discard},
{"n_ctx", n_ctx},
{"n_past", slot.n_past},
{"n_system_tokens", system_tokens.size()},
{"n_cache_tokens", slot.cache_tokens.size()}
});
slot.n_discarded = slot.n_discarded + n_discard;
slot.n_kept = n_keep;
discard_n_kv_and_cache_tokens(ctx, slot, n_keep, n_discard);
slot.n_past -= n_discard;
slot.truncated = true;
}
}
}
@@ -3229,7 +3298,7 @@ struct server_context {
// TODO: we always have to take into account the "system_tokens"
// this is not great and needs to be improved somehow
llama_batch_add(batch, slot.sampled, system_tokens.size() + slot_npast, { slot.id }, true);
llama_batch_add(batch, slot.sampled, system_tokens.size() + slot.cache_tokens.pos_next(), { slot.id }, true);
slot.n_past += 1;
@@ -3355,43 +3424,39 @@ struct server_context {
slot.release();
continue;
}
int n_keep = std::max(0, slot.params.n_keep + add_bos_token);
const int n_left = slot.n_ctx - n_keep;
int n_discard = slot.params.n_discard ? slot.params.n_discard : (n_left / 2);
int n_discard_cache = 0;
// we still need to truncate input since we have not discarded enough tokens
while (slot.n_prompt_tokens - slot.n_discarded >= slot.n_ctx) {
slot.n_discarded = slot.n_discarded + n_discard;
n_discard_cache = n_discard_cache + n_discard;
if (mctx) {
// we should never reach this because params.ctx_shift is automatically disabled if mmproj is loaded
// we don't support ctx_shift because an image chunk may contain multiple tokens
GGML_ABORT("not supported by multimodal");
}
int n_discard_cache_max = std::max((int)slot.cache_tokens.size() - n_keep, 0);
n_discard_cache = std::min(n_discard_cache, n_discard_cache_max);
// discard matching tokens from cache and kv cache to avoid reprocessing the prompt
if (n_discard_cache > 0) {
discard_n_kv_and_cache_tokens(ctx, slot, n_keep, n_discard_cache);
}
// discard extra tokens from prompts
n_discard = slot.n_discarded;
slot.n_kept = n_keep;
prompt_tokens.discard_n_tokens(n_keep, n_discard);
context_shift_prompt(ctx, slot);
slot.truncated = true;
slot.n_prompt_tokens = prompt_tokens.size();
LOG_VERBOSE("input truncated", {
{"id_slot", slot.id},
{"id_task", slot.id_task},
{"n_ctx", slot.n_ctx},
{"n_keep", slot.params.n_keep},
{"n_left", n_left},
{"n_left", slot.n_ctx- slot.params.n_keep},
{"n_prompt_tokens", slot.n_prompt_tokens},
{"prompt_tokens", prompt_tokens.detokenize(ctx, true)},
});
GGML_ASSERT(slot.n_prompt_tokens < slot.n_ctx);
//print_tokens(prompt_tokens, slot.cache_tokens);
#ifndef NDEBUG
// debug
common_prefix prefix = slot.cache_tokens.get_common_prefix(ctx, prompt_tokens, false);
int32_t back = 1;
if (slot.cache_tokens.size() && slot.cache_tokens.size() > prefix.first + 20
&& prefix.second >= back && prefix.first >= back) {
LLAMA_LOG_INFO("After context shift :\n");
print_tokens(slot.prompt_tokens, slot.cache_tokens, prefix.second - back, prefix.first - back, 50);
}
#endif
}
else {
slot.n_discarded = 0;
slot.n_discarded_prompt = 0;
}
llama_sampling_reset(llama_get_model_vocab(model), slot.ctx_sampling);
@@ -3402,7 +3467,28 @@ struct server_context {
GGML_ASSERT(slot.ga_n == 1);
// reuse any previously computed tokens that are common with the new prompt
slot.n_past = slot.cache_tokens.get_common_prefix(prompt_tokens);
common_prefix prefix = slot.cache_tokens.get_common_prefix(ctx, prompt_tokens, true); // string level match
common_prefix prefix_nonexact = slot.cache_tokens.get_common_prefix(ctx, prompt_tokens, false);
auto n_past0 = slot.cache_tokens.get_common_prefix_exact(prompt_tokens); // token level match
LLAMA_LOG_INFO("======== Cache: cache_size = %ld, n_past0 = %ld, n_past1 = %ld, n_past_prompt1 = %ld, n_past2 = %ld, n_past_prompt2 = %ld\n", (int32_t) slot.cache_tokens.size(), (int32_t) n_past0, (int32_t) prefix.first, prefix.second, (int32_t) prefix_nonexact.first, (int32_t) prefix_nonexact.second);
int32_t size_threshold = 20;
if (prefix.first + size_threshold < prefix_nonexact.first) {
LLAMA_LOG_WARN("Common part contains missing or extra space and new line\n");
prefix = prefix_nonexact;
}
slot.n_past = prefix.first;
slot.n_past_prompt = prefix.second;
if (slot.n_past != slot.n_past_prompt) {
LLAMA_LOG_INFO("Mistokenization found and handled successfully.\n");
}
if ((slot.n_past + size_threshold < slot.cache_tokens.size()))
{
LLAMA_LOG_WARN("Common part does not match fully\n");
int32_t back = 4;
if (prefix.second >= back && prefix.first >= back) {
print_tokens(slot.prompt_tokens, slot.cache_tokens, prefix.second - back, prefix.first - back, 30);
}
}
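
The selection above prefers the exact, token-for-token prefix and only falls back to the looser string-level prefix when it is longer by more than size_threshold tokens, i.e. when the divergence is plausibly just a whitespace/newline retokenization. The same decision with stubbed values:

    #include <cstdio>
    #include <utility>

    int main() {
        const int size_threshold = 20;

        std::pair<int, int> exact    = {100, 100};  // (cache tokens, prompt tokens) matched token-for-token
        std::pair<int, int> nonexact = {640, 637};  // string-level match, lengths differ per side

        std::pair<int, int> prefix = exact;
        if (prefix.first + size_threshold < nonexact.first) {
            // the divergence was only a tokenization artifact - reuse the longer match
            prefix = nonexact;
        }
        const int n_past        = prefix.first;     // how much of the KV/cache side to keep
        const int n_past_prompt = prefix.second;    // where prompt processing resumes
        std::printf("n_past = %d, n_past_prompt = %d\n", n_past, n_past_prompt);
    }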
// push the prompt into the sampling context (do not apply grammar)
for (int i = 0; i < slot.n_past; ++i) {
@@ -3411,13 +3497,14 @@ struct server_context {
}
}
if (slot.n_past == slot.n_prompt_tokens && slot.n_past > 0) {
if (slot.n_past_prompt == slot.n_prompt_tokens && slot.n_past_prompt > 0) {
// we have to evaluate at least 1 token to generate logits.
LOG_INFO("we have to evaluate at least 1 token to generate logits", {
{ "id_slot", slot.id },
{ "id_task", slot.id_task }
});
slot.n_past_prompt--;
slot.n_past--;
if (slot.ga_i > 0) {
slot.n_past_se--;
@@ -3443,7 +3530,10 @@ struct server_context {
}
// keep only the common part
// remove the non-common part from the cache
slot.cache_tokens.keep_first(slot.n_past);
int p0 = (int) system_tokens.size() + slot.n_past;
p0 = system_tokens.size() + slot.cache_tokens.pos_next();
if (!llama_kv_cache_seq_rm(ctx, slot.id, p0, -1)) {
// could not partially delete (likely using a non-Transformer model)
llama_kv_cache_seq_rm(ctx, slot.id, -1, -1);
@@ -3462,9 +3552,6 @@ struct server_context {
llama_sampling_reset(llama_get_model_vocab(model), slot.ctx_sampling);
}
// remove the non-common part from the cache
slot.cache_tokens.keep_first(slot.n_past);
LOG_INFO("kv cache rm [p0, end)", {
{ "id_slot", slot.id },
{ "id_task", slot.id_task },
@@ -3472,13 +3559,12 @@ struct server_context {
});
// check if we should process the image
if (slot.n_past < slot.n_prompt_tokens
&& slot.prompt_tokens[slot.n_past] == LLAMA_TOKEN_NULL) {
if (slot.n_past_prompt < slot.n_prompt_tokens
&& slot.prompt_tokens[slot.n_past_prompt] == LLAMA_TOKEN_NULL) {
// process the image
int32_t new_n_past;
size_t new_n_tokens;
int32_t res = slot.prompt_tokens.process_chunk(ctx, mctx, slot.n_past, slot.id, new_n_past, new_n_tokens);
int32_t n_pos = new_n_past - slot.n_past;
size_t n_tokens_out = 0;
llama_pos p1 = slot.cache_tokens.pos_next() + slot.n_past_prompt - slot.n_past; // add offset to prompt
int32_t res = slot.prompt_tokens.process_chunk(ctx, mctx, slot.n_past_prompt, p1, slot.id, n_tokens_out);
if (res != 0) {
LLAMA_LOG_ERROR("failed to process image, res = %d\n", res);
slot.release();
@@ -3488,12 +3574,14 @@ struct server_context {
// add the image chunk to cache
{
const auto& chunk = slot.prompt_tokens.find_chunk(slot.n_past);
const auto& chunk = slot.prompt_tokens.find_chunk(slot.n_past_prompt);
slot.cache_tokens.push_back(chunk.get()); // copy
}
slot.n_past += n_pos;
slot.n_prompt_tokens_processed += new_n_tokens;
slot.n_past += n_tokens_out;
slot.n_past_prompt += n_tokens_out;
slot.n_prompt_tokens_processed += n_tokens_out;
}
@@ -3506,9 +3594,9 @@ struct server_context {
// add prompt tokens for processing in the current batch
// TODO: the self-extend stuff here is a mess - simplify and/or abstract it somehow
while (slot.n_past < slot.n_prompt_tokens && batch.n_tokens < n_batch) {
while (slot.n_past_prompt < slot.n_prompt_tokens && batch.n_tokens < n_batch) {
// get next token to process
llama_token cur_tok = slot.prompt_tokens[slot.n_past];
llama_token cur_tok = slot.prompt_tokens[slot.n_past_prompt];
if (cur_tok == LLAMA_TOKEN_NULL) {
break; // end of text chunk
}
@@ -3520,13 +3608,15 @@ struct server_context {
}
}
llama_batch_add(batch, cur_tok, system_tokens.size() + slot_npast, { slot.id }, false);
{
slot.cache_tokens.push_back(cur_tok);
}
int p0 = system_tokens.size() + slot.cache_tokens.pos_next();
llama_batch_add(batch, cur_tok, p0, { slot.id }, false);
slot.cache_tokens.push_back(cur_tok);
slot.n_prompt_tokens_processed++;
slot_npast++;
slot.n_past_prompt++;
slot.n_past++;
}
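
The loop above now takes KV positions from the cache sequence (cache_tokens.pos_next()) rather than from the prompt index, so n_past_prompt and n_past can diverge after a mistokenized prefix without leaving a gap or overlap in the KV cache. A simplified standalone sketch (pos_next() approximated by the cache length, which holds for text-only tokens):

    #include <cstdio>
    #include <vector>

    struct batch_entry { int tok; int pos; };

    int main() {
        // 4 reused cache tokens cover the same text as the first 3 prompt tokens
        std::vector<int> cache  = {501, 502, 510, 511};
        std::vector<int> prompt = {501, 502, 503, 601, 602};
        int n_past_prompt = 3;                      // prompt index where processing resumes

        std::vector<batch_entry> batch;
        while (n_past_prompt < (int) prompt.size()) {
            const int tok = prompt[n_past_prompt];
            const int pos = (int) cache.size();     // stand-in for cache_tokens.pos_next()
            batch.push_back({tok, pos});            // llama_batch_add(batch, tok, pos, {id}, false)
            cache.push_back(tok);                   // slot.cache_tokens.push_back(tok)
            n_past_prompt++;                        // slot.n_past_prompt++ and slot.n_past++
        }
        for (const auto & e : batch) std::printf("tok %d @ pos %d\n", e.tok, e.pos);
        // positions 4 and 5 follow the 4 kept KV cells; indexing by n_past_prompt
        // (3, 4) would have collided with the last kept cell
    }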
LOG_VERBOSE("prompt processing progress", {
@@ -3538,7 +3628,7 @@ struct server_context {
});
// entire prompt has been processed - start decoding new tokens
if (slot.n_past == slot.n_prompt_tokens) {
if (slot.n_past_prompt == slot.n_prompt_tokens) {
slot.state = SLOT_STATE_PROCESSING;
slot.command = SLOT_COMMAND_NONE;
@@ -3643,11 +3733,13 @@ struct server_context {
slot.state = SLOT_STATE_PROCESSING;
slot.command = SLOT_COMMAND_NONE;
slot.release();
LLAMA_LOG_INFO("n_past =% d\n", slot.cache_tokens.size());
send_error(slot, "Input prompt is too big compared to KV size. Please try increasing KV size.");
}
break; // break loop of n_batch
}
// retry with half the batch size to try to find a free slot in the KV cache
n_batch /= 2;
i -= n_batch;
@@ -3775,10 +3867,10 @@ struct server_context {
// construct the speculation batch
llama_batch_clear(slot.batch_spec);
llama_batch_add(slot.batch_spec, id, slot.n_past, { slot.id }, true);
llama_batch_add(slot.batch_spec, id, slot.cache_tokens.pos_next(), { slot.id }, true);
for (size_t i = 0; i < draft.size(); ++i) {
llama_batch_add(slot.batch_spec, draft[i], slot.n_past + 1 + i, { slot.id }, true);
llama_batch_add(slot.batch_spec, draft[i], slot.cache_tokens.pos_next() + 1 + i, { slot.id }, true);
}
LOG_VERBOSE("decoding speculative batch", {