From 0b126b2ca6437b6d0d7828e51f1d08e3c409afb8 Mon Sep 17 00:00:00 2001 From: firecoperana Date: Wed, 26 Nov 2025 03:34:26 -0600 Subject: [PATCH] Fix prompt tokenization issue during prompt processing (#1008) * Find common tokens between prompt and cache Fix wrong context size usage for mtmd Use start position of common part server: handle context shift * Add size check for inexact match * Change --------- Co-authored-by: firecoperana --- examples/server/server.cpp | 304 ++++++++++++++++---------- examples/server/utils.hpp | 424 +++++++++++++++++++++++++++++++------ 2 files changed, 552 insertions(+), 176 deletions(-) diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 057a6e72..745e4c76 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -608,10 +608,11 @@ struct server_prompt_checkpoint { } }; + struct server_prompt { server_tokens tokens; - int n_keep; - int n_discarded; + int n_kept_prompt; + int n_discarded_prompt; std::vector data; @@ -633,7 +634,8 @@ struct server_prompt { }; struct server_prompt_cache { - server_prompt_cache(int32_t limit_size_mib, size_t limit_tokens) { + server_prompt_cache(llama_context * ctx,int32_t limit_size_mib, size_t limit_tokens) { + this->ctx = ctx; this->limit_size = 1024ull * 1024ull * (limit_size_mib < 0 ? 0 : limit_size_mib); this->limit_tokens = limit_tokens; } @@ -645,7 +647,7 @@ struct server_prompt_cache { // in tokens, 0 = no limit size_t limit_tokens = 0; - + llama_context* ctx; size_t size() const { size_t res = 0; @@ -662,18 +664,18 @@ struct server_prompt_cache { for (const auto& state : states) { res += state.n_tokens(); } - return res; } server_prompt* alloc(const server_prompt& prompt, size_t state_size) { for (auto it = states.begin(); it != states.end();) { auto tokens_ctx_shift = server_tokens(prompt.tokens.get_text_tokens(), false); // copy cache tokens - tokens_ctx_shift.discard_n_tokens(prompt.n_keep, prompt.n_discarded); - - const size_t len = it->tokens.get_common_prefix(tokens_ctx_shift); + tokens_ctx_shift.discard_n_tokens(prompt.n_kept_prompt, prompt.n_discarded_prompt); + auto prefix = it->tokens.get_common_prefix(ctx, tokens_ctx_shift); + const size_t len = prefix.first; + const size_t len_prompt = prefix.second; // first check if the current state is contained fully in the cache - if (len == tokens_ctx_shift.size()) { + if (len_prompt == tokens_ctx_shift.size()) { LLAMA_LOG_INFO("%s", " - prompt is already in the cache, skipping\n"); return nullptr; } @@ -709,8 +711,8 @@ struct server_prompt_cache { auto& cur = states.emplace_back(); cur = { /*.tokens =*/ server_tokens(prompt.tokens.get_text_tokens(), false), - /*.n_keep =*/ prompt.n_keep, - /*.n_discarded =*/ prompt.n_discarded, + /*.n_keep =*/ prompt.n_kept_prompt, + /*.n_discarded_prompt =*/ prompt.n_discarded_prompt, /*.data =*/ std::move(state_data), /*.checkpoints =*/ prompt.checkpoints, }; @@ -719,19 +721,19 @@ struct server_prompt_cache { } bool load(server_prompt& prompt, const server_tokens& tokens_new, llama_context* ctx, int32_t id_slot) { - const int lcp_best = prompt.tokens.get_common_prefix(tokens_new); + const auto lcp_best = prompt.tokens.get_common_prefix(ctx, tokens_new); - float f_keep_best = float(lcp_best) / prompt.tokens.size(); - float sim_best = prompt.tokens.get_tokens_similarity(tokens_new, prompt.n_keep, prompt.n_discarded); - LLAMA_LOG_INFO(" - looking for better prompt, base f_keep = %.3f, sim = %.3f, n_keep = %d, n_discarded = %d\n", f_keep_best, sim_best, prompt.n_keep, prompt.n_discarded); + float 
f_keep_best = float(lcp_best.second) / prompt.tokens.size(); + float sim_best = prompt.tokens.get_tokens_similarity(ctx, tokens_new, prompt.n_kept_prompt, prompt.n_discarded_prompt); + LLAMA_LOG_INFO(" - looking for better prompt, base f_keep = %.3f, sim = %.3f, n_keep = %d, n_discarded_prompt = %d\n", f_keep_best, sim_best, prompt.n_kept_prompt, prompt.n_discarded_prompt); auto it_best = states.end(); // find the most similar cached prompt, that would also preserve the most context for (auto it = states.begin(); it != states.end(); ++it) { - const int lcp_cur = it->tokens.get_common_prefix(tokens_new); - const float f_keep_cur = float(lcp_cur) / it->tokens.size(); - const float sim_cur = it->tokens.get_tokens_similarity(tokens_new, it->n_keep, it->n_discarded); + const auto lcp_cur = it->tokens.get_common_prefix(ctx, tokens_new); + const float f_keep_cur = float(lcp_cur.first) / it->tokens.size(); + const float sim_cur = it->tokens.get_tokens_similarity(ctx, tokens_new, it->n_kept_prompt, it->n_discarded_prompt); if (sim_best < sim_cur) { f_keep_best = f_keep_cur; sim_best = sim_cur; @@ -740,7 +742,7 @@ struct server_prompt_cache { } if (it_best != states.end()) { - LLAMA_LOG_INFO(" - found better prompt with f_keep = %.3f, sim = %.3f, n_keep = %d, n_discarded = %d\n", f_keep_best, sim_best, it_best->n_keep, it_best->n_discarded); + LLAMA_LOG_INFO(" - found better prompt with f_keep = %.3f, sim = %.3f, n_keep = %d, n_discarded_prompt = %d\n", f_keep_best, sim_best, it_best->n_kept_prompt, it_best->n_discarded_prompt); const size_t size = it_best->data.size(); const size_t n = llama_state_seq_set_data(ctx, it_best->data.data(), size, id_slot); if (n != size) { @@ -798,7 +800,7 @@ struct server_prompt_cache { for (const auto& state : states) { LLAMA_LOG_INFO(" - prompt %p: %7d tokens, %7d discarded, checkpoints: %2zu, %9.3f MiB\n", - (const void*)&state, state.n_tokens(), state.n_discarded, state.checkpoints.size(), state.size() / (1024.0 * 1024.0)); + (const void*)&state, state.n_tokens(), state.n_discarded_prompt, state.checkpoints.size(), state.size() / (1024.0 * 1024.0)); } } }; @@ -823,10 +825,12 @@ struct server_slot { // generation props int32_t n_ctx = 0; // context size per slot int32_t n_past = 0; + int32_t n_past_prompt = 0; int32_t n_decoded = 0; int32_t n_remaining = -1; - int32_t n_discarded = 0; - int32_t n_kept = 0; + int32_t n_discarded_prompt = 0; + int32_t n_kept_prompt = 0; + int32_t i_batch = -1; int32_t n_predict = -1; // TODO: disambiguate from params.n_predict @@ -989,6 +993,7 @@ struct server_slot { } } + json get_formated_timings() const { return json { {"prompt_n", n_prompt_tokens_processed}, @@ -1693,7 +1698,7 @@ struct server_context { } LLAMA_LOG_INFO("%s", "use `--cache-ram 0` to disable the prompt cache\n"); // only apply ram size limit. No token limit for now. 
- prompt_cache = std::make_unique(params.cache_ram_mib, 0); + prompt_cache = std::make_unique(ctx,params.cache_ram_mib, 0); } else { LLAMA_LOG_INFO("%s", "prompt cache is disabled - use `--cache-ram N` to enable it\n"); @@ -1788,13 +1793,12 @@ struct server_context { continue; } // length of the Longest Common Prefix between the current slot's prompt and the input prompt - // print_tokens(task.tokens, cache_tokens); - size_t lcp_len = cache_tokens.get_common_prefix(task.tokens); + auto lcp_len = cache_tokens.get_common_prefix(slot.ctx,task.tokens); // fraction of the Longest Common Prefix length with respect to the input prompt and cached prompt length - float sim_cur = cache_tokens.get_tokens_similarity(task.tokens, 0, 0); + float sim_cur = cache_tokens.get_tokens_similarity(slot.ctx, task.tokens, 0, 0); // handle context shift - if (slot.ga_n == 1 && slot.n_discarded > 0 && task.tokens.size()>=slot.n_ctx) { - float sim_cur_ctx_shift = cache_tokens.get_tokens_similarity(task.tokens, slot.n_kept, slot.n_discarded); + if (slot.ga_n == 1 && slot.n_discarded_prompt > 0 && task.tokens.size()>=slot.n_ctx) { + float sim_cur_ctx_shift = cache_tokens.get_tokens_similarity(slot.ctx, task.tokens, slot.n_kept_prompt, slot.n_discarded_prompt); if (sim_cur_ctx_shift > sim_cur) { sim_cur = sim_cur_ctx_shift; } @@ -1803,7 +1807,7 @@ struct server_context { // select the current slot if the criteria match if (sim_cur > sim_best && sim_cur > slot_prompt_similarity) { sim_best = sim_cur; - max_lcp_len = lcp_len; + max_lcp_len = lcp_len.first; ret = &slot; } } @@ -1842,11 +1846,11 @@ struct server_context { const auto& tokens = ret->cache_tokens; float f_keep = 0.0f; if (!tokens.empty()) { - if (ret->ga_n == 1 && ret->n_discarded > 0 && task.tokens.size() >= ret->n_ctx) { - f_keep = tokens.get_cached_tokens_similarity(task.tokens, ret->params.n_keep + add_bos_token, ret->n_discarded); + if (ret->ga_n == 1 && ret->n_discarded_prompt > 0 && task.tokens.size() >= ret->n_ctx) { + f_keep = tokens.get_cached_tokens_similarity(ret->ctx, task.tokens, ret->params.n_keep + add_bos_token, ret->n_discarded_prompt); } else { - f_keep = tokens.get_cached_tokens_similarity(task.tokens, 0, 0); + f_keep = tokens.get_cached_tokens_similarity(ret->ctx,task.tokens, 0, 0); } // if we are about to lose a large portion of the existing context - save it in the prompt cache if (f_keep < cache_ram_similarity) { @@ -1863,14 +1867,14 @@ struct server_context { // TODO: mtmd does not support prompt cache update_cache = update_cache && (ret->mctx == nullptr); - LLAMA_LOG_INFO("prompt cache: cache size: %d, n_keep: %d, n_discarded: %d, cache_ram_n_min: %d, f_keep: %.2f, cache_ram_similarity: %.2f\n", - (int)tokens.size(), ret->n_kept, ret->n_discarded, cache_ram_n_min, f_keep, cache_ram_similarity); + LLAMA_LOG_INFO("======== Prompt cache: cache size: %d, n_keep: %d, n_discarded_prompt: %d, cache_ram_n_min: %d, f_keep: %.2f, cache_ram_similarity: %.2f\n", + (int)tokens.size(), ret->n_kept_prompt, ret->n_discarded_prompt, cache_ram_n_min, f_keep, cache_ram_similarity); if (update_cache) { const int64_t t_start = ggml_time_us(); LLAMA_LOG_INFO("updating prompt cache\n"); ret->server_cached_prompt.tokens = server_tokens(tokens.get_text_tokens(), false); // copy cache tokens - ret->server_cached_prompt.n_discarded = ret->n_discarded; - ret->server_cached_prompt.n_keep = ret->n_kept; + ret->server_cached_prompt.n_discarded_prompt = ret->n_discarded_prompt; + ret->server_cached_prompt.n_kept_prompt = ret->n_kept_prompt; 
ret->prompt_save(*prompt_cache); LLAMA_LOG_INFO("prompt cache save took %.2f ms\n", (ggml_time_us() - t_start) / 1000.0); @@ -1879,15 +1883,15 @@ struct server_context { if (prompt_cache && !prompt_cache->states.empty()) { const int64_t t_start = ggml_time_us(); ret->server_cached_prompt.tokens = server_tokens(tokens.get_text_tokens(), false); // copy cache tokens - ret->server_cached_prompt.n_discarded = ret->n_discarded; - ret->server_cached_prompt.n_keep = ret->n_kept; + ret->server_cached_prompt.n_discarded_prompt = ret->n_discarded_prompt; + ret->server_cached_prompt.n_kept_prompt = ret->n_kept_prompt; ret->prompt_load(*prompt_cache, task.tokens); prompt_cache->update(); ret->cache_tokens = server_tokens(ret->server_cached_prompt.tokens.get_text_tokens(), false); // recover cache tokens - ret->n_discarded = ret->server_cached_prompt.n_discarded; - ret->n_kept = ret->server_cached_prompt.n_keep; + ret->n_discarded_prompt = ret->server_cached_prompt.n_discarded_prompt; + ret->n_kept_prompt = ret->server_cached_prompt.n_kept_prompt; LLAMA_LOG_INFO("prompt cache load took %.2f ms\n", (ggml_time_us() - t_start) / 1000.0); } @@ -3093,9 +3097,14 @@ struct server_context { queue_results.send(result); } - void print_tokens(const server_tokens & prompt, const server_tokens& cache) { - LLAMA_LOG_INFO( "prompt: %s\n", prompt.detokenize(ctx, true).c_str()); - LLAMA_LOG_INFO( "cache: %s\n", cache.detokenize(ctx, true).c_str()); + void print_tokens(const server_tokens & prompt, const server_tokens& cache, size_t start1 = 0, size_t start2=0 , size_t length = 10) { + if (cache.size() > start2) { + LLAMA_LOG_INFO("cache : %s\n", cache.detokenize(ctx, true, start2, length).c_str()); + } + if (prompt.size()> start1) { + LLAMA_LOG_INFO("prompt: %s\n", prompt.detokenize(ctx, true, start1, length).c_str()); + } + } void discard_n_kv_and_cache_tokens(llama_context* ctx, server_slot& slot, int32_t n_keep, int32_t n_discard) { @@ -3106,6 +3115,60 @@ struct server_context { } } + // convert keep first few and discard next tokens in a to b + void context_shift_find_n_tokens(llama_context* ctx, const server_tokens& a, const server_tokens& b, int32_t n_keep, + int32_t n_discard, int32_t& n_kept, int32_t& n_discarded, bool exact = false) { + + common_prefix ctx_keep_prefix = a.get_common_prefix_first_n(ctx, b, n_keep, exact); + common_prefix ctx_total_discard_prefix = a.get_common_prefix_first_n(ctx, b, n_discard + n_keep, exact); + // only if there is enough common token + int32_t discard_offset = ctx_total_discard_prefix.first - (n_discard + n_keep); + int32_t keep_offset = ctx_keep_prefix.first - n_keep; + n_kept = ctx_keep_prefix.second - keep_offset; + n_discarded = ctx_total_discard_prefix.second - ctx_keep_prefix.second - discard_offset; + if (n_kept < 0) { + n_kept = n_keep; + } + if (n_discarded < 0) { + n_discarded = n_discard; + } + } + + void context_shift_prompt(llama_context* ctx, server_slot& slot, bool exact = false) { + //server_tokens prompt_tokens = std::move(slot.prompt_tokens); + int n_keep = std::max(0, slot.params.n_keep + add_bos_token); + const int n_left = slot.n_ctx - n_keep; + int n_discard = slot.params.n_discard ? 
slot.params.n_discard : (n_left / 2); + + int n_discard_prompt = 0; + // we still need to truncate input since we have not discarded enough tokens + while (slot.n_prompt_tokens - slot.n_discarded_prompt >= slot.n_ctx) { + slot.n_discarded_prompt = slot.n_discarded_prompt + n_discard; + n_discard_prompt = n_discard_prompt + n_discard; + } + + // Handle mistokenization between prompt and cache during context shift + // + int32_t n_discard_cache = n_discard_prompt; + int32_t n_kept = n_keep; + slot.prompt_tokens.discard_n_tokens(n_keep, slot.n_discarded_prompt - n_discard_prompt); + if (n_discard_prompt > 0) { + context_shift_find_n_tokens(ctx, slot.prompt_tokens, slot.cache_tokens, n_keep, + n_discard, n_kept, n_discard_cache, exact); + } + + int n_discard_cache_max = std::max((int32_t)slot.cache_tokens.size() - n_kept, 0); + n_discard_cache = std::min(n_discard_cache, n_discard_cache_max); + // discard matching tokens from cache and kv cache to avoid reprocessing the prompt + if (n_discard_cache > 0) { + discard_n_kv_and_cache_tokens(ctx, slot, n_kept, n_discard_cache); + } + // discard extra tokens from prompts + slot.n_kept_prompt = n_keep; + slot.prompt_tokens.discard_n_tokens(n_keep, n_discard_prompt); + slot.n_prompt_tokens = slot.prompt_tokens.size(); + } + void update_slots() { if (system_need_update) { system_prompt_update(); @@ -3189,23 +3252,29 @@ struct server_context { const int n_left = (int) system_tokens.size() + slot.n_past - n_keep; const int n_discard = slot.params.n_discard ? slot.params.n_discard : (n_left / 2); + int32_t n_kept; + int32_t n_discard_cache; + if (n_discard > 0) { + context_shift_find_n_tokens(ctx, slot.prompt_tokens, slot.cache_tokens, n_keep, + n_discard, n_kept, n_discard_cache); + LOG_INFO("slot context shift", { + {"id_slot", slot.id}, + {"id_task", slot.id_task}, + {"n_keep", n_keep}, + {"n_left", n_left}, + {"n_discard", n_discard}, + {"n_ctx", n_ctx}, + {"n_past", slot.n_past}, + {"n_system_tokens", system_tokens.size()}, + {"n_cache_tokens", slot.cache_tokens.size()} + }); + slot.n_discarded_prompt = slot.n_discarded_prompt + n_discard; + slot.n_kept_prompt = n_keep; + discard_n_kv_and_cache_tokens(ctx, slot, n_kept, n_discard_cache); + slot.n_past -= n_discard_cache; + slot.truncated = true; + } - LOG_INFO("slot context shift", { - {"id_slot", slot.id}, - {"id_task", slot.id_task}, - {"n_keep", n_keep}, - {"n_left", n_left}, - {"n_discard", n_discard}, - {"n_ctx", n_ctx}, - {"n_past", slot.n_past}, - {"n_system_tokens", system_tokens.size()}, - {"n_cache_tokens", slot.cache_tokens.size()} - }); - slot.n_discarded = slot.n_discarded + n_discard; - slot.n_kept = n_keep; - discard_n_kv_and_cache_tokens(ctx, slot, n_keep, n_discard); - slot.n_past -= n_discard; - slot.truncated = true; } } } @@ -3229,7 +3298,7 @@ struct server_context { // TODO: we always have to take into account the "system_tokens" // this is not great and needs to be improved somehow - llama_batch_add(batch, slot.sampled, system_tokens.size() + slot_npast, { slot.id }, true); + llama_batch_add(batch, slot.sampled, system_tokens.size() + slot.cache_tokens.pos_next(), { slot.id }, true); slot.n_past += 1; @@ -3355,43 +3424,39 @@ struct server_context { slot.release(); continue; } - int n_keep = std::max(0, slot.params.n_keep + add_bos_token); - const int n_left = slot.n_ctx - n_keep; - int n_discard = slot.params.n_discard ? 
slot.params.n_discard : (n_left / 2); - int n_discard_cache = 0; - // we still need to truncate input since we have not discarded enough tokens - while (slot.n_prompt_tokens - slot.n_discarded >= slot.n_ctx) { - slot.n_discarded = slot.n_discarded + n_discard; - n_discard_cache = n_discard_cache + n_discard; + if (mctx) { + // we should never reach this because params.ctx_shift is automatically disabled if mmproj is loaded + // we don't support ctx_shift because an image chunk may contains multiple tokens + GGML_ABORT("not supported by multimodal"); } - int n_discard_cache_max = std::max((int)slot.cache_tokens.size() - n_keep, 0); - n_discard_cache = std::min(n_discard_cache, n_discard_cache_max); - // discard matching tokens from cache and kv cache to avoid reprocessing the prompt - if (n_discard_cache > 0) { - discard_n_kv_and_cache_tokens(ctx, slot, n_keep, n_discard_cache); - } - // discard extra tokens from prompts - n_discard = slot.n_discarded; - slot.n_kept = n_keep; - prompt_tokens.discard_n_tokens(n_keep, n_discard); + + context_shift_prompt(ctx, slot); slot.truncated = true; - slot.n_prompt_tokens = prompt_tokens.size(); LOG_VERBOSE("input truncated", { {"id_slot", slot.id}, {"id_task", slot.id_task}, {"n_ctx", slot.n_ctx}, {"n_keep", slot.params.n_keep}, - {"n_left", n_left}, + {"n_left", slot.n_ctx- slot.params.n_keep}, {"n_prompt_tokens", slot.n_prompt_tokens}, {"prompt_tokens", prompt_tokens.detokenize(ctx, true)}, }); GGML_ASSERT(slot.n_prompt_tokens < slot.n_ctx); - //print_tokens(prompt_tokens, slot.cache_tokens); - + +#ifndef NDEBUG + // debug + common_prefix prefix = slot.cache_tokens.get_common_prefix(ctx, prompt_tokens, false); + int32_t back = 1; + if (slot.cache_tokens.size() && slot.cache_tokens.size() > prefix.first+20 + && prefix.second >= back && prefix.first >= back) { + LLAMA_LOG_INFO("After context shift :\n"); + print_tokens(slot.prompt_tokens, slot.cache_tokens, prefix.second - back, prefix.first - back, 50); + } +#endif } else { - slot.n_discarded = 0; + slot.n_discarded_prompt = 0; } llama_sampling_reset(llama_get_model_vocab(model), slot.ctx_sampling); @@ -3402,7 +3467,28 @@ struct server_context { GGML_ASSERT(slot.ga_n == 1); // reuse any previously computed tokens that are common with the new prompt - slot.n_past = slot.cache_tokens.get_common_prefix(prompt_tokens); + common_prefix prefix = slot.cache_tokens.get_common_prefix(ctx, prompt_tokens, true); // string level match + common_prefix prefix_nonexact = slot.cache_tokens.get_common_prefix(ctx, prompt_tokens, false); + auto n_past0 = slot.cache_tokens.get_common_prefix_exact(prompt_tokens); // token level match + LLAMA_LOG_INFO("======== Cache: cache_size = %ld, n_past0 = %ld, n_past1 = %ld, n_past_prompt1 = %ld, n_past2 = %ld, n_past_prompt2 = %ld\n", (int32_t) slot.cache_tokens.size(), (int32_t) n_past0, (int32_t) prefix.first, prefix.second, (int32_t) prefix_nonexact.first, (int32_t) prefix_nonexact.second); + int32_t size_threshold = 20; + if (prefix.first + size_threshold < prefix_nonexact.first) { + LLAMA_LOG_WARN("Common part contains missing or extra space and new line\n"); + prefix = prefix_nonexact; + } + slot.n_past = prefix.first; + slot.n_past_prompt = prefix.second; + if (slot.n_past != slot.n_past_prompt) { + LLAMA_LOG_INFO("Mistokenization found and handled successfully.\n"); + } + if ((slot.n_past + size_threshold < slot.cache_tokens.size())) + { + LLAMA_LOG_WARN("Common part does not match fully\n"); + int32_t back = 4; + if (prefix.second >= back && prefix.first >= back) { + 
print_tokens(slot.prompt_tokens, slot.cache_tokens, prefix.second - back, prefix.first - back, 30); + } + } // push the prompt into the sampling context (do not apply grammar) for (int i = 0; i < slot.n_past; ++i) { @@ -3411,13 +3497,14 @@ struct server_context { } } - if (slot.n_past == slot.n_prompt_tokens && slot.n_past > 0) { + if (slot.n_past_prompt == slot.n_prompt_tokens && slot.n_past_prompt > 0) { // we have to evaluate at least 1 token to generate logits. LOG_INFO("we have to evaluate at least 1 token to generate logits", { { "id_slot", slot.id }, { "id_task", slot.id_task } }); + slot.n_past_prompt--; slot.n_past--; if (slot.ga_i > 0) { slot.n_past_se--; @@ -3443,7 +3530,10 @@ struct server_context { } // keep only the common part + // remove the non-common part from the cache + slot.cache_tokens.keep_first(slot.n_past); int p0 = (int) system_tokens.size() + slot.n_past; + p0 = system_tokens.size() + slot.cache_tokens.pos_next(); if (!llama_kv_cache_seq_rm(ctx, slot.id, p0, -1)) { // could not partially delete (likely using a non-Transformer model) llama_kv_cache_seq_rm(ctx, slot.id, -1, -1); @@ -3462,9 +3552,6 @@ struct server_context { llama_sampling_reset(llama_get_model_vocab(model), slot.ctx_sampling); } - // remove the non-common part from the cache - slot.cache_tokens.keep_first(slot.n_past); - LOG_INFO("kv cache rm [p0, end)", { { "id_slot", slot.id }, { "id_task", slot.id_task }, @@ -3472,13 +3559,12 @@ struct server_context { }); // check if we should process the image - if (slot.n_past < slot.n_prompt_tokens - && slot.prompt_tokens[slot.n_past] == LLAMA_TOKEN_NULL) { + if (slot.n_past_prompt < slot.n_prompt_tokens + && slot.prompt_tokens[slot.n_past_prompt] == LLAMA_TOKEN_NULL) { // process the image - int32_t new_n_past; - size_t new_n_tokens; - int32_t res = slot.prompt_tokens.process_chunk(ctx, mctx, slot.n_past, slot.id, new_n_past, new_n_tokens); - int32_t n_pos = new_n_past - slot.n_past; + size_t n_tokens_out = 0; + llama_pos p1 = slot.cache_tokens.pos_next()+slot.n_past_prompt-slot.n_past; // add offset to prompt + int32_t res = slot.prompt_tokens.process_chunk(ctx, mctx, slot.n_past_prompt, p1, slot.id, n_tokens_out); if (res != 0) { LLAMA_LOG_ERROR("failed to process image, res = %d\n", res); slot.release(); @@ -3488,12 +3574,14 @@ struct server_context { // add the image chunk to cache { - const auto& chunk = slot.prompt_tokens.find_chunk(slot.n_past); + const auto& chunk = slot.prompt_tokens.find_chunk(slot.n_past_prompt); slot.cache_tokens.push_back(chunk.get()); // copy } - slot.n_past += n_pos; - slot.n_prompt_tokens_processed += new_n_tokens; + slot.n_past += n_tokens_out; + slot.n_past_prompt += n_tokens_out; + slot.n_prompt_tokens_processed += n_tokens_out; + } @@ -3506,9 +3594,9 @@ struct server_context { // add prompt tokens for processing in the current batch // TODO: the self-extend stuff here is a mess - simplify and/or abstract it somehow - while (slot.n_past < slot.n_prompt_tokens && batch.n_tokens < n_batch) { + while (slot.n_past_prompt < slot.n_prompt_tokens && batch.n_tokens < n_batch) { // get next token to process - llama_token cur_tok = slot.prompt_tokens[slot.n_past]; + llama_token cur_tok = slot.prompt_tokens[slot.n_past_prompt]; if (cur_tok == LLAMA_TOKEN_NULL) { break; // end of text chunk } @@ -3520,13 +3608,15 @@ struct server_context { } } - llama_batch_add(batch, cur_tok, system_tokens.size() + slot_npast, { slot.id }, false); - { - slot.cache_tokens.push_back(cur_tok); - } + int p0=system_tokens.size() + 
slot.cache_tokens.pos_next(); + llama_batch_add(batch, cur_tok, p0, { slot.id }, false); + + slot.cache_tokens.push_back(cur_tok); + slot.n_prompt_tokens_processed++; slot_npast++; + slot.n_past_prompt++; slot.n_past++; } LOG_VERBOSE("prompt processing progress", { @@ -3538,7 +3628,7 @@ struct server_context { }); // entire prompt has been processed - start decoding new tokens - if (slot.n_past == slot.n_prompt_tokens) { + if (slot.n_past_prompt == slot.n_prompt_tokens) { slot.state = SLOT_STATE_PROCESSING; slot.command = SLOT_COMMAND_NONE; @@ -3643,11 +3733,13 @@ struct server_context { slot.state = SLOT_STATE_PROCESSING; slot.command = SLOT_COMMAND_NONE; slot.release(); + LLAMA_LOG_INFO("n_past =% d\n", slot.cache_tokens.size()); send_error(slot, "Input prompt is too big compared to KV size. Please try increasing KV size."); } break; // break loop of n_batch } + // retry with half the batch size to try to find a free slot in the KV cache n_batch /= 2; i -= n_batch; @@ -3775,10 +3867,10 @@ struct server_context { // construct the speculation batch llama_batch_clear(slot.batch_spec); - llama_batch_add(slot.batch_spec, id, slot.n_past, { slot.id }, true); + llama_batch_add(slot.batch_spec, id, slot.cache_tokens.pos_next(), { slot.id }, true); for (size_t i = 0; i < draft.size(); ++i) { - llama_batch_add(slot.batch_spec, draft[i], slot.n_past + 1 + i, { slot.id }, true); + llama_batch_add(slot.batch_spec, draft[i], slot.cache_tokens.pos_next() + 1 + i, { slot.id }, true); } LOG_VERBOSE("decoding speculative batch", { diff --git a/examples/server/utils.hpp b/examples/server/utils.hpp index 3c4b6dd1..4b4747fd 100644 --- a/examples/server/utils.hpp +++ b/examples/server/utils.hpp @@ -1,6 +1,7 @@ #pragma once #include "llama.h" +#include #include "common.h" // Change JSON_ASSERT from assert() to GGML_ASSERT: @@ -14,6 +15,7 @@ #include #include #include +#include // increase max payload length to allow use of larger context size #define CPPHTTPLIB_FORM_URL_ENCODED_PAYLOAD_MAX_LENGTH 1048576 @@ -333,6 +335,165 @@ static std::string tokens_to_output_formatted_string(const llama_context * ctx, return out; } +struct common_prefix { + size_t first = 0; + size_t second = 0; +}; + +common_prefix common_prefix_add(const common_prefix& a, const common_prefix& b) { + common_prefix prefix; + prefix.first = a.first + b.first; + prefix.second = a.second + b.second; + return prefix; +} + +common_prefix find_common_string_prefix(const std::string & a_str, const std::string & b_str, const std::set& ignore_set) { + size_t i = 0; + size_t j = 0; + while (i < a_str.size() && j < b_str.size()) { + auto a_chr = a_str[i]; + auto b_chr = b_str[j]; + if (a_chr == b_chr) { + ++i; + ++j; + } + else if (ignore_set.count(a_chr) && ignore_set.count(b_chr)) { + ++i; + ++j; + } + else if (ignore_set.count(a_chr)) { + ++i; + } + else if (ignore_set.count(b_chr)) { + ++j; + } + else { + break; + } + } + common_prefix string_prefix; + string_prefix.first = i; + string_prefix.second = j; + return string_prefix; +} + +size_t find_n_tokens_from_string(const llama_context* ctx, const llama_tokens& a, const size_t max_size, size_t start, + std::vector & map) { + size_t n = 0; + size_t string_len = 0; + std::string str; + auto model = llama_get_model(ctx); + for (n = start; n < a.size(); ++n) { + str = llama_token_to_piece(model, a[n], true); + string_len = string_len + str.size(); + if (string_len <= max_size) { + map.push_back(string_len); + } + else { + break; + } + } + return map.size(); +} + +std::string 
remove_with_set(std::string str, const std::set& chars_to_remove) { + str.erase(std::remove_if(str.begin(), str.end(), + [&chars_to_remove](char c) { return chars_to_remove.find(c) != chars_to_remove.end(); }), + str.end()); + return str; +} + +common_prefix find_largest_common_number(const std::vector& a_list, const std::vector& b_list) { + common_prefix token_prefix; + token_prefix.first = 0; + token_prefix.second = 0; + int i = a_list.size() - 1; // start from end of a + int j = b_list.size() - 1; // start from end of b + if (i < 0 || j < 0) { + return token_prefix; + } + while (i >= 0 && j >= 0) { + if (a_list[i] == b_list[j]) { + // found largest common value + token_prefix.first = (size_t)i + 1; + token_prefix.second = (size_t)j + 1; + break; + } + else if (a_list[i] > b_list[j]) { + --i; + } + else { + --j; + } + } + return token_prefix; +} + +size_t find_n_tokens_from_string_with_ignore(const llama_context* ctx, const llama_tokens& a, const size_t max_size, size_t start, const std::set & ignore_set, + std::vector& map) { + bool use_ignore = ignore_set.size()>0; + size_t n = 0; + size_t string_len = 0; + size_t string_len_ignore = 0; + std::string str; + std::string str_ignore; + auto model = llama_get_model(ctx); + for (n = start; n < a.size(); ++n) { + str = llama_token_to_piece(model, a[n], true); + string_len = string_len + str.size(); + if (use_ignore) { + str_ignore = remove_with_set(str, ignore_set); + } + else { + str_ignore = str; + } + string_len_ignore = string_len_ignore + str_ignore.size(); + if (string_len <= max_size) { + map.push_back(string_len_ignore); + } + else { + break; + } + } + return map.size(); +} + +common_prefix find_common_text_token_prefix(const llama_context * ctx, const llama_tokens & a, const llama_tokens& b, + size_t start, bool exact) { + common_prefix token_prefix; + if (a.size()<= start || b.size()<= start) { + return token_prefix; + } + std::set ignore_set = { ' ', '\n' ,'\r'}; + + llama_tokens a_sub(a.begin() + start, a.end()); + llama_tokens b_sub(b.begin() + start, b.end()); + + std::string a_str = llama_detokenize(ctx, a_sub, true); + std::string b_str = llama_detokenize(ctx, b_sub, true); + common_prefix string_prefix; + + std::vector a_list; + std::vector b_list; + + if (exact) { + size_t lcp = common_part(a_str, b_str); + string_prefix.first = lcp; + string_prefix.second = lcp; + token_prefix.first = find_n_tokens_from_string(ctx, a_sub, string_prefix.first, 0, a_list); + token_prefix.second = find_n_tokens_from_string(ctx, b_sub, string_prefix.second, 0, b_list); + } + else { + string_prefix = find_common_string_prefix(a_str, b_str, ignore_set); + token_prefix.first = find_n_tokens_from_string_with_ignore(ctx, a_sub, string_prefix.first, 0, ignore_set, a_list); + token_prefix.second = find_n_tokens_from_string_with_ignore(ctx, b_sub, string_prefix.second, 0, ignore_set, b_list); + } + + token_prefix = find_largest_common_number(a_list, b_list); + return token_prefix; +} + + struct completion_token_output { llama_token tok; std::string text_to_send; @@ -1000,19 +1161,22 @@ struct server_tokens { private: // disallow accessing these members directly, risking out-of-sync - // map a **start** position in tokens to the image chunk - std::unordered_map map_pos_to_media; + // map a **start** index in tokens to the image chunk + // note: the order need to be in-sync with tokens + std::map map_idx_to_media; // list of tokens - // it can include LLAMA_TOKEN_NULL, which is used to indicate a token that is not a text token - // a mtmd_input_chunk 
can occupy multiple tokens, one llama_token per **position** - // important: for models using mrope, an image can contain multiple tokens but will use only one **position** - std::vector tokens; + // if the token is LLAMA_TOKEN_NULL, it indicates that this position is occupied by media chunk + // otherwise, it is a normal text token + // note: a non-text chunk can occupy multiple tokens (aka memory cells) in the token list + // note(2): for M-RoPE, an image can occupy different number of pos; do not assume 1-to-1 mapping tokens <-> pos + llama_tokens tokens; - // for ex. with input of 5 text tokens and 2 images: - // [0] [1] [2] [3] [4] [img0] [img0] [img0] [img1] [img1] - // pos 0 1 2 3 4 5 6 7 8 9 - // map_pos_to_media will contain: {5, img0}, {8, img1} + // for ex. with input of 5 text tokens and 2 images (each image occupies 3 tokens and 2 pos): + // [0] [1] [2] [3] [4] [img0] [img0] [img0] [img1] [img1] [img1] + // idx 0 1 2 3 4 5 6 7 8 9 10 + // pos 0 1 2 3 4 5 5 5 7 7 7 + // map_idx_to_media will contain: {5, img0}, {8, img1} public: server_tokens() = default; @@ -1036,7 +1200,8 @@ public: } } - server_tokens(const std::vector& tokens, bool has_mtmd) : has_mtmd(has_mtmd), tokens(tokens) {} + server_tokens(const llama_tokens& tokens, bool has_mtmd) : has_mtmd(has_mtmd), tokens(tokens) { + } llama_pos pos_next() const { if (!has_mtmd) { @@ -1045,7 +1210,7 @@ public: llama_pos res = tokens.size(); - for (auto it = map_pos_to_media.begin(); it != map_pos_to_media.end(); ++it) { + for (auto it = map_idx_to_media.begin(); it != map_idx_to_media.end(); ++it) { const auto& chunk = it->second; res += mtmd_input_chunk_get_n_pos(chunk.get()) - mtmd_input_chunk_get_n_tokens(chunk.get()); } @@ -1057,7 +1222,9 @@ public: std::string str() const { std::ostringstream oss; oss << "tokens: "; - for (const auto& t : tokens) { + for (size_t idx = 0; idx < tokens.size(); ++idx) { + llama_token t = tokens[idx]; + oss << "idx:" << idx << " "; if (t == LLAMA_TOKEN_NULL) { oss << " "; } @@ -1066,16 +1233,16 @@ public: } } oss << "\n"; - oss << "image pos: "; - for (const auto& it : map_pos_to_media) { + oss << "image idx: "; + for (const auto& it : map_idx_to_media) { oss << it.first << ", "; } return oss.str(); } - const mtmd::input_chunk_ptr& find_chunk(llama_pos pos) const { - auto it = map_pos_to_media.find(pos); - if (it != map_pos_to_media.end()) { + const mtmd::input_chunk_ptr& find_chunk(size_t idx) const { + auto it = map_idx_to_media.find(idx); + if (it != map_idx_to_media.end()) { return it->second; } throw std::runtime_error("Chunk not found"); @@ -1093,17 +1260,17 @@ public: auto type = mtmd_input_chunk_get_type(chunk); if (type == MTMD_INPUT_CHUNK_TYPE_IMAGE || type == MTMD_INPUT_CHUNK_TYPE_AUDIO) { GGML_ASSERT(has_mtmd); - const int n_pos = mtmd_input_chunk_get_n_pos(chunk); - llama_pos start_pos = tokens.size(); - for (int i = 0; i < n_pos; ++i) { + const size_t n_tokens = mtmd_input_chunk_get_n_tokens(chunk); + size_t start_idx = tokens.size(); + for (size_t i = 0; i < n_tokens; ++i) { tokens.emplace_back(LLAMA_TOKEN_NULL); } mtmd::input_chunk_ptr new_chunk(mtmd_input_chunk_copy(chunk)); - map_pos_to_media[start_pos] = std::move(new_chunk); + map_idx_to_media[start_idx] = std::move(new_chunk); } else if (type == MTMD_INPUT_CHUNK_TYPE_TEXT) { size_t n_tokens; - auto text_tokens = mtmd_input_chunk_get_tokens_text(chunk, &n_tokens); + const auto* text_tokens = mtmd_input_chunk_get_tokens_text(chunk, &n_tokens); for (size_t i = 0; i < n_tokens; ++i) { push_back(text_tokens[i]); } @@ -1115,7 
+1282,7 @@ public: // appends server tokens, updates the media map. copies media chunks. void push_back(server_tokens& tokens) { - size_t start_pos = size(); + size_t start_idx = size(); for (size_t i = 0; i < tokens.size(); i++) { push_back(tokens[i]); } @@ -1123,10 +1290,10 @@ public: // Assert if we are copying MTMD chunks to a server_tokens that does not have mtmd. // We could also just check, but this will prevent silently dropping MTMD data. GGML_ASSERT(has_mtmd); - for (auto it = tokens.map_pos_to_media.begin(); it != tokens.map_pos_to_media.end(); ) { - auto chunk = tokens.map_pos_to_media[it->first].get(); + for (auto it = tokens.map_idx_to_media.begin(); it != tokens.map_idx_to_media.end(); ) { + auto* chunk = tokens.map_idx_to_media[it->first].get(); mtmd::input_chunk_ptr new_chunk(mtmd_input_chunk_copy(chunk)); - map_pos_to_media[start_pos + it->first] = std::move(new_chunk); + map_idx_to_media[start_idx + it->first] = std::move(new_chunk); } } } @@ -1164,7 +1331,6 @@ public: } llama_tokens tokens_data() { - return tokens; } @@ -1212,10 +1378,10 @@ public: } } // remove all image chunks that are not used anymore - for (auto it = map_pos_to_media.begin(); it != map_pos_to_media.end(); ) { - llama_pos pos = it->first; - if (pos >= (llama_pos)n) { - it = map_pos_to_media.erase(it); + for (auto it = map_idx_to_media.begin(); it != map_idx_to_media.end(); ) { + size_t idx = it->first; + if (idx >= n) { + it = map_idx_to_media.erase(it); } else { ++it; @@ -1236,7 +1402,37 @@ public: return llama_detokenize(ctx, text_tokens, special); } - size_t get_common_prefix(const server_tokens& b) const { + std::string detokenize(const llama_context* ctx, bool special, size_t start, size_t length) const { + std::string str; + if (tokens.size() <= start || length == 0) { + return str; + } + llama_tokens text_tokens; + text_tokens.reserve(tokens.size() - start); + size_t i = 0; + size_t count = 0; + for (const auto& t : tokens) { + if (t != LLAMA_TOKEN_NULL && i>=start) { + text_tokens.push_back(t); + ++count; + if (count >= length) { + break; + } + } + ++i; + } + return llama_detokenize(ctx, text_tokens, special); + } + + size_t find_n_from_tokens(const llama_context* ctx, const server_tokens& b, bool special, + size_t start, const size_t length) { + std::string str = detokenize(ctx, special, start, length); + std::vector tmp; + size_t n = find_n_tokens_from_string(ctx, b.tokens, start, length, tmp); + return n; + } + + size_t get_common_prefix_exact(const server_tokens& b) const { const size_t max_idx = std::min(tokens.size(), b.tokens.size()); if (!has_mtmd) { @@ -1262,12 +1458,12 @@ public: const std::string id_ai = mtmd_input_chunk_get_id(a_chunk.get()); const std::string id_bi = mtmd_input_chunk_get_id(b_chunk.get()); - const size_t pos_a = mtmd_input_chunk_get_n_pos(a_chunk.get()); - const size_t pos_b = mtmd_input_chunk_get_n_pos(b_chunk.get()); + const size_t n_tok_a = mtmd_input_chunk_get_n_tokens(a_chunk.get()); + const size_t n_tok_b = mtmd_input_chunk_get_n_tokens(b_chunk.get()); - if (id_ai == id_bi && pos_a == pos_b) { - GGML_ASSERT(pos_a > 0 && "Invalid media chunk"); // should never happen - i += pos_a - 1; // will be +1 by the for loop + if (id_ai == id_bi && n_tok_a == n_tok_b) { + GGML_ASSERT(n_tok_a > 0 && "Invalid media chunk"); // should never happen + i += n_tok_a - 1; // will be +1 by the for loop continue; } @@ -1285,6 +1481,94 @@ public: } + common_prefix get_common_prefix(const llama_context* ctx, const server_tokens& b, bool exact = false) const { + common_prefix 
token_prefix; + + size_t n = get_common_prefix_exact(b); // strict token match as a starting point + token_prefix.first = n; + token_prefix.second = n; + + if (!has_mtmd) { + token_prefix = find_common_text_token_prefix(ctx, this->tokens, b.tokens, n, exact); + token_prefix.first += n; + token_prefix.second += n; + return token_prefix; + } + size_t i = n; + size_t j = n; + llama_tokens a_list; + llama_tokens b_list; + while (i < size() && j < b.size()) { + llama_token ai = tokens[i]; + llama_token bi = b.tokens[j]; + if (ai != LLAMA_TOKEN_NULL) { + a_list.push_back(ai); + ++i; + } + if (bi != LLAMA_TOKEN_NULL) { + b_list.push_back(bi); + ++j; + } + if (ai == LLAMA_TOKEN_NULL && bi == LLAMA_TOKEN_NULL) { + common_prefix prefix = find_common_text_token_prefix(ctx, a_list, b_list, 0, exact); + // text match or empty + if (prefix.first == a_list.size() && prefix.second == b_list.size()) { + a_list.clear(); + b_list.clear(); + const auto& a_chunk = find_chunk(i); + const auto& b_chunk = b.find_chunk(j); + + GGML_ASSERT(a_chunk && b_chunk); + + const std::string id_ai = mtmd_input_chunk_get_id(a_chunk.get()); + const std::string id_bi = mtmd_input_chunk_get_id(b_chunk.get()); + + const size_t n_tok_a = mtmd_input_chunk_get_n_tokens(a_chunk.get()); + const size_t n_tok_b = mtmd_input_chunk_get_n_tokens(b_chunk.get()); + + // image match + if (id_ai == id_bi && n_tok_a == n_tok_b) { + GGML_ASSERT(n_tok_a > 0 && "Invalid media chunk"); // should never happen + i += n_tok_a; + j += n_tok_a; + prefix.first += n_tok_a; + prefix.second += n_tok_a; + token_prefix = common_prefix_add(prefix, token_prefix); + } else { + // do no include image token prefix + // only return text token prefix + token_prefix = common_prefix_add(prefix, token_prefix); + return token_prefix; + } + } + else { + // text not match + token_prefix = common_prefix_add(prefix, token_prefix); + return token_prefix; + } + } + } + common_prefix prefix = find_common_text_token_prefix(ctx, a_list, b_list, 0, exact); + token_prefix = common_prefix_add(prefix, token_prefix); + + return token_prefix; + + } + + // take first n tokens of tokens list a + // find the common prefix between a and b + common_prefix get_common_prefix_first_n(const llama_context* ctx, const server_tokens& b, size_t n, bool exact = false) const { + // not work for mtmd + GGML_ASSERT(!has_mtmd); // only allow this if mtmd is disabled + auto tokens = get_text_tokens(); + if (n > tokens.size()) { + n = tokens.size(); + } + llama_tokens copy(tokens.begin(), tokens.begin()+n); + server_tokens a = server_tokens(copy, false); + return a.get_common_prefix(ctx, b, exact); + } + // make sure all text tokens are within the vocab range bool validate(const struct llama_context* ctx) const { const llama_model* model = llama_get_model(ctx); @@ -1296,8 +1580,8 @@ public: if (t == LLAMA_TOKEN_NULL) { try { const auto& chunk = find_chunk(i); - size_t n_pos = mtmd_input_chunk_get_n_pos(chunk.get()); - i += n_pos - 1; // will be +1 by the for loop + size_t n_tokens = mtmd_input_chunk_get_n_tokens(chunk.get()); + i += n_tokens - 1; // will be +1 by the for loop } catch (const std::exception& e) { return false; @@ -1312,41 +1596,33 @@ public: // encode and decode the image chunk int32_t process_chunk( - llama_context * ctx, - mtmd_context * mctx, - llama_pos n_past, + llama_context* ctx, + mtmd_context* mctx, + size_t idx, + llama_pos pos, int32_t seq_id, - llama_pos & n_pos_out, - size_t & n_tokens_out) { - char buffer[512]; - auto& chunk = find_chunk(n_past); + size_t& n_tokens_out) const 
{ + const auto& chunk = find_chunk(idx); const char* name = mtmd_input_chunk_get_type(chunk.get()) == MTMD_INPUT_CHUNK_TYPE_IMAGE ? "image" : "audio"; - snprintf(buffer, 512, "processing : %s",name); - LOG_INFO(buffer, {}); + LLAMA_LOG_INFO("processing %s...\n", name); int32_t n_batch = llama_n_batch(ctx); int64_t t0 = ggml_time_ms(); - llama_pos new_n_past = n_past; + llama_pos new_n_past; // unused for now int32_t result = mtmd_helper_eval_chunk_single(mctx, ctx, chunk.get(), - n_past, + pos, seq_id, n_batch, true, // logits last &new_n_past); - // get number of tokens in the image - const size_t new_n_tokens = mtmd_input_chunk_get_n_tokens(chunk.get()); - snprintf(buffer, 512, "processed in %g ms", 1.*(ggml_time_ms() - t0)); - LOG_INFO(buffer, {}); + LLAMA_LOG_INFO("%s processed in %" PRId64 " ms\n", name, ggml_time_ms() - t0); if (result != 0) { - snprintf(buffer, 512, "mtmd_helper_eval failed with status %d", result); - LOG_ERROR(buffer, {}); - n_pos_out = n_past; + LLAMA_LOG_ERROR("mtmd_helper_eval failed with status %d", result); n_tokens_out = 0; return result; } - n_pos_out = new_n_past; - n_tokens_out = new_n_tokens; + n_tokens_out = mtmd_input_chunk_get_n_tokens(chunk.get()); return 0; } @@ -1368,37 +1644,37 @@ public: } // Similarity between prompt and cached - float get_tokens_similarity(const server_tokens& tokens, int n_keep = 0, int n_discard = 0) const { + float get_tokens_similarity(const llama_context* ctx, const server_tokens& tokens, int n_keep = 0, int n_discard = 0) const { GGML_ASSERT(n_keep >= 0 && n_discard >= 0); float sim_cur = 0; if (n_keep == 0 && n_discard == 0) { - size_t lcp_len= get_common_prefix(tokens); - sim_cur = get_slot_similarity(lcp_len, tokens.size(), size()); + auto lcp_len= get_common_prefix(ctx, tokens); + sim_cur = get_slot_similarity(lcp_len.second, tokens.size(), size()); } else { // remove tokens due to context shift and compare auto tokens_ctx_shift = server_tokens(tokens.get_text_tokens(), false); // copy cache tokens tokens_ctx_shift.discard_n_tokens(n_keep, n_discard); - size_t lcp_len = get_common_prefix(tokens_ctx_shift); - sim_cur = get_slot_similarity(lcp_len, tokens_ctx_shift.size(), size()); + auto lcp_len = get_common_prefix(ctx, tokens_ctx_shift); + sim_cur = get_slot_similarity(lcp_len.second, tokens_ctx_shift.size(), size()); } return sim_cur; } // Similarity between common part and cache - float get_cached_tokens_similarity(const server_tokens& tokens, int n_keep = 0, int n_discard = 0) const { + float get_cached_tokens_similarity(const llama_context* ctx, const server_tokens& tokens, int n_keep = 0, int n_discard = 0) const { GGML_ASSERT(n_keep >= 0 && n_discard >= 0); float sim_cur = 0; if (n_keep == 0 && n_discard == 0) { - size_t lcp_len = get_common_prefix(tokens); - sim_cur = (float) lcp_len/size(); + auto lcp_len = get_common_prefix(ctx, tokens); + sim_cur = (float) lcp_len.first/size(); } else { // remove tokens due to context shift and compare auto tokens_ctx_shift = server_tokens(tokens.get_text_tokens(), false); // copy cache tokens tokens_ctx_shift.discard_n_tokens(n_keep, n_discard); - size_t lcp_len = get_common_prefix(tokens_ctx_shift); - sim_cur = (float) lcp_len / size(); + auto lcp_len = get_common_prefix(ctx, tokens_ctx_shift); + sim_cur = (float) lcp_len.first / size(); } return sim_cur; } @@ -1541,3 +1817,11 @@ inline void print_files_info(const std::vector& files) { std::cout << std::dec << "\n\n"; // Reset to decimal } } + +inline bool prompt_cache_equal(llama_context* ctx, const server_tokens& 
cache_tokens, + const server_tokens& prompt_tokens, size_t start, const common_prefix & prefix ) { + std::string common_cache = cache_tokens.detokenize(ctx, true, start, prefix.first); + std::string common_prompt = prompt_tokens.detokenize(ctx, true, start, prefix.second); + bool equal = common_cache == common_prompt; + return equal; +}
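
The core idea of the patch is that the cached tokens and the re-tokenized prompt can encode the same text with different token boundaries (typically around spaces and newlines), so prefix matching is done on the detokenized strings and reported as a pair of lengths, one per side. The standalone sketch below condenses the whitespace-tolerant walk that utils.hpp adds as find_common_string_prefix(); the helper name and the demo strings are illustrative only, not part of the patch.

#include <cstdio>
#include <set>
#include <string>
#include <utility>

// Characters in ignore_set may be skipped on either side independently, so the
// matched lengths can differ between the two strings - which is why the patch
// returns a pair (cache-side length, prompt-side length) instead of one number.
static std::pair<size_t, size_t> common_prefix_ignoring(
        const std::string & a, const std::string & b,
        const std::set<char> & ignore_set) {
    size_t i = 0, j = 0;
    while (i < a.size() && j < b.size()) {
        if (a[i] == b[j]) {
            ++i; ++j;
        } else if (ignore_set.count(a[i]) && ignore_set.count(b[j])) {
            ++i; ++j;
        } else if (ignore_set.count(a[i])) {
            ++i;
        } else if (ignore_set.count(b[j])) {
            ++j;
        } else {
            break;
        }
    }
    return {i, j};
}

int main() {
    const std::set<char> ws = {' ', '\n', '\r'};
    // the cached text has an extra newline that the re-tokenized prompt lost
    auto p = common_prefix_ignoring("Hello\n world!", "Hello world?", ws);
    std::printf("cache side: %zu chars, prompt side: %zu chars\n", p.first, p.second);
    // prints: cache side: 12 chars, prompt side: 11 chars
    return 0;
}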
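
To turn the character-level match back into token counts, the patch records, for each side, the cumulative detokenized length after every token (with the ignored characters stripped in the inexact mode) and then scans the two cumulative lists from the back for the largest equal value; the matching indices are the token counts that cover the same text on each side. A minimal sketch of that scan follows; the helper name is assumed for illustration, the logic mirrors find_largest_common_number().

#include <cstdio>
#include <utility>
#include <vector>

// a and b hold cumulative text lengths per token for the cache side and the
// prompt side. The result is the largest pair of token counts whose covered
// text has the same length on both sides.
static std::pair<size_t, size_t> largest_matching_token_counts(
        const std::vector<size_t> & a, const std::vector<size_t> & b) {
    int i = (int) a.size() - 1;
    int j = (int) b.size() - 1;
    while (i >= 0 && j >= 0) {
        if (a[i] == b[j]) {
            return {(size_t) i + 1, (size_t) j + 1};
        }
        if (a[i] > b[j]) {
            --i;
        } else {
            --j;
        }
    }
    return {0, 0};
}

int main() {
    // the cache tokenized the common text as 3 pieces, the new prompt as 2 pieces
    std::vector<size_t> cache_len  = {5, 11, 12};
    std::vector<size_t> prompt_len = {5, 12};
    auto n = largest_matching_token_counts(cache_len, prompt_len);
    std::printf("keep %zu cache tokens == %zu prompt tokens\n", n.first, n.second);
    // prints: keep 3 cache tokens == 2 prompt tokens
    return 0;
}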
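
On the server side, context_shift_prompt() drops prompt tokens in blocks: after the first n_keep tokens, blocks of n_discard tokens (half of the free window, n_ctx - n_keep, by default) are discarded until the remainder fits into the slot's context, and the running total is kept in slot.n_discarded_prompt so later requests can apply the same shift before comparing against the cache. The sketch below isolates just that accumulation; the function name is invented and the divide-by-zero guard is added for the sketch, it is not taken from the patch.

#include <algorithm>
#include <cstdint>
#include <cstdio>

// Mirrors the accumulation loop in context_shift_prompt(): keep the first
// n_keep tokens, then discard whole blocks of n_discard tokens until what is
// left of the prompt fits into the context window.
static int32_t compute_n_discarded_prompt(int32_t n_prompt_tokens, int32_t n_ctx,
                                          int32_t n_keep, int32_t n_discard_param) {
    n_keep = std::max(0, n_keep);
    const int32_t n_left    = n_ctx - n_keep;
    const int32_t n_discard = n_discard_param ? n_discard_param : n_left / 2;
    if (n_discard <= 0) {
        return 0; // guard for the sketch only (covers n_keep >= n_ctx)
    }
    int32_t n_discarded = 0;
    while (n_prompt_tokens - n_discarded >= n_ctx) {
        n_discarded += n_discard; // discard another block of n_discard tokens
    }
    return n_discarded;
}

int main() {
    // 9000-token prompt, 4096-token context, keep the first 64 tokens
    const int32_t n_discarded = compute_n_discarded_prompt(9000, 4096, 64, 0);
    std::printf("n_discarded_prompt = %d\n", n_discarded); // 6048 = 3 * 2016
    return 0;
}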