mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-04-30 11:21:56 +00:00
Server: Handle context shift better to reduce prompt processing time (#973)
* Handle context shift better to reduce pp Add context-shift args Add back ga_n in context shift * optimize discard function and bring back n_keep = -1 --------- Co-authored-by: firecoperana <firecoperana>
This commit is contained in:
@@ -1816,6 +1816,21 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
|
|||||||
params.ctx_shift = false;
|
params.ctx_shift = false;
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
if (arg == "--context-shift") {
|
||||||
|
CHECK_ARG
|
||||||
|
std::string next_arg{ argv[i] };
|
||||||
|
for (auto& c : next_arg) c = std::tolower(c);
|
||||||
|
if (next_arg == "auto" || next_arg == "1" || next_arg == "on") {
|
||||||
|
params.ctx_shift = true;
|
||||||
|
}
|
||||||
|
else if (next_arg == "off" || next_arg == "0") {
|
||||||
|
params.ctx_shift = false;
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
invalid_param = true;
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
if (arg == "-cram" || arg == "--cache-ram") {
|
if (arg == "-cram" || arg == "--cache-ram") {
|
||||||
CHECK_ARG
|
CHECK_ARG
|
||||||
params.cache_ram_mib = std::stoi(argv[i]);
|
params.cache_ram_mib = std::stoi(argv[i]);
|
||||||
@@ -2173,6 +2188,7 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param
|
|||||||
options.push_back({ "*", " --mmproj FILE", "path to a multimodal projector file for LLaVA. see examples/llava/README.md" });
|
options.push_back({ "*", " --mmproj FILE", "path to a multimodal projector file for LLaVA. see examples/llava/README.md" });
|
||||||
options.push_back({ "*", " --image FILE", "path to an image file. use with multimodal models. Specify multiple times for batching" });
|
options.push_back({ "*", " --image FILE", "path to an image file. use with multimodal models. Specify multiple times for batching" });
|
||||||
options.push_back({ "*", " --no-context-shift", "disable context-shift." });
|
options.push_back({ "*", " --no-context-shift", "disable context-shift." });
|
||||||
|
options.push_back({ "*", "--context-shift (auto|on|off|0|1)", "set context-shift (default: %s)", params.ctx_shift ? "on" : "off" });
|
||||||
options.push_back({ "backend" });
|
options.push_back({ "backend" });
|
||||||
options.push_back({ "*", " --rpc SERVERS", "comma separated list of RPC servers" });
|
options.push_back({ "*", " --rpc SERVERS", "comma separated list of RPC servers" });
|
||||||
options.push_back({ "*", "-cuda, --cuda-params", "comma separate list of cuda parameters" });
|
options.push_back({ "*", "-cuda, --cuda-params", "comma separate list of cuda parameters" });
|
||||||
|
|||||||
@@ -598,6 +598,8 @@ struct server_prompt_checkpoint {
|
|||||||
|
|
||||||
struct server_prompt {
|
struct server_prompt {
|
||||||
server_tokens tokens;
|
server_tokens tokens;
|
||||||
|
int n_keep;
|
||||||
|
int n_discarded;
|
||||||
|
|
||||||
std::vector<uint8_t> data;
|
std::vector<uint8_t> data;
|
||||||
|
|
||||||
@@ -654,10 +656,12 @@ struct server_prompt_cache {
|
|||||||
|
|
||||||
server_prompt* alloc(const server_prompt& prompt, size_t state_size) {
|
server_prompt* alloc(const server_prompt& prompt, size_t state_size) {
|
||||||
for (auto it = states.begin(); it != states.end();) {
|
for (auto it = states.begin(); it != states.end();) {
|
||||||
const size_t len = it->tokens.get_common_prefix(prompt.tokens);
|
auto tokens_ctx_shift = server_tokens(prompt.tokens.get_text_tokens(), false); // copy cache tokens
|
||||||
|
tokens_ctx_shift.discard_n_tokens(prompt.n_keep, prompt.n_discarded);
|
||||||
|
|
||||||
|
const size_t len = it->tokens.get_common_prefix(tokens_ctx_shift);
|
||||||
// first check if the current state is contained fully in the cache
|
// first check if the current state is contained fully in the cache
|
||||||
if (len == prompt.tokens.size()) {
|
if (len == tokens_ctx_shift.size()) {
|
||||||
LLAMA_LOG_INFO("%s", " - prompt is already in the cache, skipping\n");
|
LLAMA_LOG_INFO("%s", " - prompt is already in the cache, skipping\n");
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
@@ -692,9 +696,11 @@ struct server_prompt_cache {
|
|||||||
// TODO: for some reason we can't copy server_tokens, so we have to do this workaround
|
// TODO: for some reason we can't copy server_tokens, so we have to do this workaround
|
||||||
auto& cur = states.emplace_back();
|
auto& cur = states.emplace_back();
|
||||||
cur = {
|
cur = {
|
||||||
/*.tokens =*/ server_tokens(prompt.tokens.get_text_tokens(), false),
|
/*.tokens =*/ server_tokens(prompt.tokens.get_text_tokens(), false),
|
||||||
/*.data =*/ std::move(state_data),
|
/*.n_keep =*/ prompt.n_keep,
|
||||||
/*.checkpoints =*/ prompt.checkpoints,
|
/*.n_discarded =*/ prompt.n_discarded,
|
||||||
|
/*.data =*/ std::move(state_data),
|
||||||
|
/*.checkpoints =*/ prompt.checkpoints,
|
||||||
};
|
};
|
||||||
|
|
||||||
return &cur;
|
return &cur;
|
||||||
@@ -704,20 +710,16 @@ struct server_prompt_cache {
|
|||||||
const int lcp_best = prompt.tokens.get_common_prefix(tokens_new);
|
const int lcp_best = prompt.tokens.get_common_prefix(tokens_new);
|
||||||
|
|
||||||
float f_keep_best = float(lcp_best) / prompt.tokens.size();
|
float f_keep_best = float(lcp_best) / prompt.tokens.size();
|
||||||
//float sim_best = float(lcp_best) / tokens_new.size();
|
float sim_best = prompt.tokens.get_tokens_similarity(tokens_new, prompt.n_keep, prompt.n_discarded);
|
||||||
float sim_best = get_slot_similarity(lcp_best, tokens_new.size(), prompt.tokens.size());
|
LLAMA_LOG_INFO(" - looking for better prompt, base f_keep = %.3f, sim = %.3f, n_keep = %d, n_discarded = %d\n", f_keep_best, sim_best, prompt.n_keep, prompt.n_discarded);
|
||||||
|
|
||||||
LLAMA_LOG_INFO(" - looking for better prompt, base f_keep = %.3f, sim = %.3f\n", f_keep_best, sim_best);
|
|
||||||
|
|
||||||
auto it_best = states.end();
|
auto it_best = states.end();
|
||||||
|
|
||||||
// find the most similar cached prompt, that would also preserve the most context
|
// find the most similar cached prompt, that would also preserve the most context
|
||||||
for (auto it = states.begin(); it != states.end(); ++it) {
|
for (auto it = states.begin(); it != states.end(); ++it) {
|
||||||
const int lcp_cur = it->tokens.get_common_prefix(tokens_new);
|
const int lcp_cur = it->tokens.get_common_prefix(tokens_new);
|
||||||
|
|
||||||
const float f_keep_cur = float(lcp_cur) / it->tokens.size();
|
const float f_keep_cur = float(lcp_cur) / it->tokens.size();
|
||||||
//const float sim_cur = float(lcp_cur) / tokens_new.size();
|
const float sim_cur = it->tokens.get_tokens_similarity(tokens_new, it->n_keep, it->n_discarded);
|
||||||
const float sim_cur = get_slot_similarity(lcp_cur, tokens_new.size(), it->tokens.size());
|
|
||||||
if (sim_best < sim_cur) {
|
if (sim_best < sim_cur) {
|
||||||
f_keep_best = f_keep_cur;
|
f_keep_best = f_keep_cur;
|
||||||
sim_best = sim_cur;
|
sim_best = sim_cur;
|
||||||
@@ -726,7 +728,7 @@ struct server_prompt_cache {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if (it_best != states.end()) {
|
if (it_best != states.end()) {
|
||||||
LLAMA_LOG_INFO(" - found better prompt with f_keep = %.3f, sim = %.3f\n", f_keep_best, sim_best);
|
LLAMA_LOG_INFO(" - found better prompt with f_keep = %.3f, sim = %.3f, n_keep = %d, n_discarded = %d\n", f_keep_best, sim_best, it_best->n_keep, it_best->n_discarded);
|
||||||
const size_t size = it_best->data.size();
|
const size_t size = it_best->data.size();
|
||||||
const size_t n = llama_state_seq_set_data(ctx, it_best->data.data(), size, id_slot);
|
const size_t n = llama_state_seq_set_data(ctx, it_best->data.data(), size, id_slot);
|
||||||
if (n != size) {
|
if (n != size) {
|
||||||
@@ -783,8 +785,8 @@ struct server_prompt_cache {
|
|||||||
states.size(), size() / (1024.0 * 1024.0), limit_size / (1024.0 * 1024.0), limit_tokens, limit_tokens_cur);
|
states.size(), size() / (1024.0 * 1024.0), limit_size / (1024.0 * 1024.0), limit_tokens, limit_tokens_cur);
|
||||||
|
|
||||||
for (const auto& state : states) {
|
for (const auto& state : states) {
|
||||||
LLAMA_LOG_INFO(" - prompt %p: %7d tokens, checkpoints: %2zu, %9.3f MiB\n",
|
LLAMA_LOG_INFO(" - prompt %p: %7d tokens, %7d discarded, checkpoints: %2zu, %9.3f MiB\n",
|
||||||
(const void*)&state, state.n_tokens(), state.checkpoints.size(), state.size() / (1024.0 * 1024.0));
|
(const void*)&state, state.n_tokens(), state.n_discarded, state.checkpoints.size(), state.size() / (1024.0 * 1024.0));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@@ -811,6 +813,8 @@ struct server_slot {
|
|||||||
int32_t n_past = 0;
|
int32_t n_past = 0;
|
||||||
int32_t n_decoded = 0;
|
int32_t n_decoded = 0;
|
||||||
int32_t n_remaining = -1;
|
int32_t n_remaining = -1;
|
||||||
|
int32_t n_discarded = 0;
|
||||||
|
int32_t n_kept = 0;
|
||||||
int32_t i_batch = -1;
|
int32_t i_batch = -1;
|
||||||
int32_t n_predict = -1; // TODO: disambiguate from params.n_predict
|
int32_t n_predict = -1; // TODO: disambiguate from params.n_predict
|
||||||
|
|
||||||
@@ -1765,9 +1769,18 @@ struct server_context {
|
|||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
// length of the Longest Common Prefix between the current slot's prompt and the input prompt
|
// length of the Longest Common Prefix between the current slot's prompt and the input prompt
|
||||||
|
// print_tokens(task.tokens, cache_tokens);
|
||||||
size_t lcp_len = cache_tokens.get_common_prefix(task.tokens);
|
size_t lcp_len = cache_tokens.get_common_prefix(task.tokens);
|
||||||
// fraction of the Longest Common Prefix length with respect to the input prompt and cached prompt length
|
// fraction of the Longest Common Prefix length with respect to the input prompt and cached prompt length
|
||||||
const float sim_cur = get_slot_similarity(lcp_len, task.tokens.size(), cache_tokens.size());
|
float sim_cur = cache_tokens.get_tokens_similarity(task.tokens, 0, 0);
|
||||||
|
// handle context shift
|
||||||
|
if (slot.ga_n == 1 && slot.n_discarded > 0 && task.tokens.size()>=slot.n_ctx) {
|
||||||
|
float sim_cur_ctx_shift = cache_tokens.get_tokens_similarity(task.tokens, slot.n_kept, slot.n_discarded);
|
||||||
|
if (sim_cur_ctx_shift > sim_cur) {
|
||||||
|
sim_cur = sim_cur_ctx_shift;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// select the current slot if the criteria match
|
// select the current slot if the criteria match
|
||||||
if (sim_cur > sim_best && sim_cur > slot_prompt_similarity) {
|
if (sim_cur > sim_best && sim_cur > slot_prompt_similarity) {
|
||||||
sim_best = sim_cur;
|
sim_best = sim_cur;
|
||||||
@@ -1810,8 +1823,12 @@ struct server_context {
|
|||||||
const auto& tokens = ret->cache_tokens;
|
const auto& tokens = ret->cache_tokens;
|
||||||
float f_keep = 0.0f;
|
float f_keep = 0.0f;
|
||||||
if (!tokens.empty()) {
|
if (!tokens.empty()) {
|
||||||
size_t lcp_len = tokens.get_common_prefix(task.tokens);
|
if (ret->ga_n == 1 && ret->n_discarded > 0 && task.tokens.size() >= ret->n_ctx) {
|
||||||
f_keep = float(lcp_len) / tokens.size();
|
f_keep = tokens.get_cached_tokens_similarity(task.tokens, ret->params.n_keep + add_bos_token, ret->n_discarded);
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
f_keep = tokens.get_cached_tokens_similarity(task.tokens, 0, 0);
|
||||||
|
}
|
||||||
// if we are about to lose a large portion of the existing context - save it in the prompt cache
|
// if we are about to lose a large portion of the existing context - save it in the prompt cache
|
||||||
if (f_keep < cache_ram_similarity) {
|
if (f_keep < cache_ram_similarity) {
|
||||||
update_cache = true;
|
update_cache = true;
|
||||||
@@ -1827,12 +1844,15 @@ struct server_context {
|
|||||||
// TODO: mtmd does not support prompt cache
|
// TODO: mtmd does not support prompt cache
|
||||||
update_cache = update_cache && (ret->mctx == nullptr);
|
update_cache = update_cache && (ret->mctx == nullptr);
|
||||||
|
|
||||||
LLAMA_LOG_INFO("prompt cache: cache size: %d, cache_ram_n_min: %d, f_keep: %.2f, cache_ram_similarity: %.2f\n",
|
LLAMA_LOG_INFO("prompt cache: cache size: %d, n_keep: %d, n_discarded: %d, cache_ram_n_min: %d, f_keep: %.2f, cache_ram_similarity: %.2f\n",
|
||||||
(int)tokens.size(), cache_ram_n_min, f_keep, cache_ram_similarity);
|
(int)tokens.size(), ret->n_kept, ret->n_discarded, cache_ram_n_min, f_keep, cache_ram_similarity);
|
||||||
if (update_cache) {
|
if (update_cache) {
|
||||||
const int64_t t_start = ggml_time_us();
|
const int64_t t_start = ggml_time_us();
|
||||||
LLAMA_LOG_INFO("updating prompt cache\n");
|
LLAMA_LOG_INFO("updating prompt cache\n");
|
||||||
ret->server_cached_prompt.tokens = server_tokens(tokens.get_text_tokens(), false); // copy cache tokens
|
ret->server_cached_prompt.tokens = server_tokens(tokens.get_text_tokens(), false); // copy cache tokens
|
||||||
|
ret->server_cached_prompt.n_discarded = ret->n_discarded;
|
||||||
|
ret->server_cached_prompt.n_keep = ret->n_kept;
|
||||||
|
|
||||||
ret->prompt_save(*prompt_cache);
|
ret->prompt_save(*prompt_cache);
|
||||||
LLAMA_LOG_INFO("prompt cache save took %.2f ms\n", (ggml_time_us() - t_start) / 1000.0);
|
LLAMA_LOG_INFO("prompt cache save took %.2f ms\n", (ggml_time_us() - t_start) / 1000.0);
|
||||||
}
|
}
|
||||||
@@ -1840,9 +1860,16 @@ struct server_context {
|
|||||||
if (prompt_cache && !prompt_cache->states.empty()) {
|
if (prompt_cache && !prompt_cache->states.empty()) {
|
||||||
const int64_t t_start = ggml_time_us();
|
const int64_t t_start = ggml_time_us();
|
||||||
ret->server_cached_prompt.tokens = server_tokens(tokens.get_text_tokens(), false); // copy cache tokens
|
ret->server_cached_prompt.tokens = server_tokens(tokens.get_text_tokens(), false); // copy cache tokens
|
||||||
|
ret->server_cached_prompt.n_discarded = ret->n_discarded;
|
||||||
|
ret->server_cached_prompt.n_keep = ret->n_kept;
|
||||||
|
|
||||||
ret->prompt_load(*prompt_cache, task.tokens);
|
ret->prompt_load(*prompt_cache, task.tokens);
|
||||||
prompt_cache->update();
|
prompt_cache->update();
|
||||||
|
|
||||||
ret->cache_tokens = server_tokens(ret->server_cached_prompt.tokens.get_text_tokens(), false); // recover cache tokens
|
ret->cache_tokens = server_tokens(ret->server_cached_prompt.tokens.get_text_tokens(), false); // recover cache tokens
|
||||||
|
ret->n_discarded = ret->server_cached_prompt.n_discarded;
|
||||||
|
ret->n_kept = ret->server_cached_prompt.n_keep;
|
||||||
|
|
||||||
LLAMA_LOG_INFO("prompt cache load took %.2f ms\n", (ggml_time_us() - t_start) / 1000.0);
|
LLAMA_LOG_INFO("prompt cache load took %.2f ms\n", (ggml_time_us() - t_start) / 1000.0);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -3056,6 +3083,19 @@ struct server_context {
|
|||||||
queue_results.send(result);
|
queue_results.send(result);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void print_tokens(const server_tokens & prompt, const server_tokens& cache) {
|
||||||
|
LLAMA_LOG_INFO( "prompt: %s\n", prompt.detokenize(ctx, true).c_str());
|
||||||
|
LLAMA_LOG_INFO( "cache: %s\n", cache.detokenize(ctx, true).c_str());
|
||||||
|
}
|
||||||
|
|
||||||
|
void discard_n_kv_and_cache_tokens(llama_context* ctx, server_slot& slot, int32_t n_keep, int32_t n_discard) {
|
||||||
|
llama_kv_cache_seq_rm(ctx, slot.id, n_keep, n_keep + n_discard);
|
||||||
|
llama_kv_cache_seq_add(ctx, slot.id, n_keep + n_discard, system_tokens.size() + slot.n_past, -n_discard);
|
||||||
|
if (slot.params.cache_prompt) {
|
||||||
|
slot.cache_tokens.discard_n_tokens(n_keep, n_discard);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
void update_slots() {
|
void update_slots() {
|
||||||
if (system_need_update) {
|
if (system_need_update) {
|
||||||
system_prompt_update();
|
system_prompt_update();
|
||||||
@@ -3131,7 +3171,12 @@ struct server_context {
|
|||||||
GGML_ABORT("not supported by multimodal");
|
GGML_ABORT("not supported by multimodal");
|
||||||
}
|
}
|
||||||
// Shift context
|
// Shift context
|
||||||
const int n_keep = slot.params.n_keep + add_bos_token;
|
int n_keep = slot.params.n_keep < 0 ? slot.prompt_tokens.size() : slot.params.n_keep;
|
||||||
|
if (add_bos_token) {
|
||||||
|
n_keep += 1;
|
||||||
|
}
|
||||||
|
n_keep = std::min(slot.n_ctx - 4, n_keep);
|
||||||
|
|
||||||
const int n_left = (int) system_tokens.size() + slot.n_past - n_keep;
|
const int n_left = (int) system_tokens.size() + slot.n_past - n_keep;
|
||||||
const int n_discard = slot.params.n_discard ? slot.params.n_discard : (n_left / 2);
|
const int n_discard = slot.params.n_discard ? slot.params.n_discard : (n_left / 2);
|
||||||
|
|
||||||
@@ -3146,22 +3191,10 @@ struct server_context {
|
|||||||
{"n_system_tokens", system_tokens.size()},
|
{"n_system_tokens", system_tokens.size()},
|
||||||
{"n_cache_tokens", slot.cache_tokens.size()}
|
{"n_cache_tokens", slot.cache_tokens.size()}
|
||||||
});
|
});
|
||||||
|
slot.n_discarded = slot.n_discarded + n_discard;
|
||||||
llama_kv_cache_seq_rm (ctx, slot.id, n_keep , n_keep + n_discard);
|
slot.n_kept = n_keep;
|
||||||
llama_kv_cache_seq_add(ctx, slot.id, n_keep + n_discard, system_tokens.size() + slot.n_past, -n_discard);
|
discard_n_kv_and_cache_tokens(ctx, slot, n_keep, n_discard);
|
||||||
|
|
||||||
if (slot.params.cache_prompt) {
|
|
||||||
llama_tokens new_tokens = slot.cache_tokens.get_text_tokens(); // copy
|
|
||||||
for (size_t i = n_keep + n_discard; i < new_tokens.size(); i++) {
|
|
||||||
new_tokens[i - n_discard] = new_tokens[i];
|
|
||||||
}
|
|
||||||
new_tokens.resize(slot.cache_tokens.size() - n_discard);
|
|
||||||
slot.cache_tokens.clear();
|
|
||||||
slot.cache_tokens.insert(new_tokens);
|
|
||||||
}
|
|
||||||
|
|
||||||
slot.n_past -= n_discard;
|
slot.n_past -= n_discard;
|
||||||
|
|
||||||
slot.truncated = true;
|
slot.truncated = true;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -3305,16 +3338,51 @@ struct server_context {
|
|||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// if input prompt is too big, truncate it (if group attention self-extend is disabled)
|
// if input prompt is too big, truncate it (if group attention self-extend is disabled)
|
||||||
if (slot.params.n_keep < 0) {
|
// context shift for prompt processing
|
||||||
slot.params.n_keep = slot.n_prompt_tokens;
|
if (slot.ga_n == 1 && slot.n_prompt_tokens >= slot.n_ctx) {
|
||||||
}
|
if (!params.ctx_shift) {
|
||||||
slot.params.n_keep = std::min(slot.n_ctx - 4, slot.params.n_keep);
|
send_error(slot, "the request exceeds the available context size, try increasing it", ERROR_TYPE_SERVER);
|
||||||
if (slot.n_prompt_tokens >= slot.n_ctx) {
|
slot.release();
|
||||||
send_error(slot, "the request exceeds the available context size, try increasing it", ERROR_TYPE_SERVER);
|
continue;
|
||||||
slot.release();
|
}
|
||||||
continue;
|
int n_keep = std::max(0, slot.params.n_keep + add_bos_token);
|
||||||
}
|
const int n_left = slot.n_ctx - n_keep;
|
||||||
|
int n_discard = slot.params.n_discard ? slot.params.n_discard : (n_left / 2);
|
||||||
|
int n_discard_cache = 0;
|
||||||
|
// we still need to truncate input since we have not discarded enough tokens
|
||||||
|
while (slot.n_prompt_tokens - slot.n_discarded >= slot.n_ctx) {
|
||||||
|
slot.n_discarded = slot.n_discarded + n_discard;
|
||||||
|
n_discard_cache = n_discard_cache + n_discard;
|
||||||
|
}
|
||||||
|
int n_discard_cache_max = std::max((int)slot.cache_tokens.size() - n_keep, 0);
|
||||||
|
n_discard_cache = std::min(n_discard_cache, n_discard_cache_max);
|
||||||
|
// discard matching tokens from cache and kv cache to avoid reprocessing the prompt
|
||||||
|
if (n_discard_cache > 0) {
|
||||||
|
discard_n_kv_and_cache_tokens(ctx, slot, n_keep, n_discard_cache);
|
||||||
|
}
|
||||||
|
// discard extra tokens from prompts
|
||||||
|
n_discard = slot.n_discarded;
|
||||||
|
slot.n_kept = n_keep;
|
||||||
|
prompt_tokens.discard_n_tokens(n_keep, n_discard);
|
||||||
|
slot.truncated = true;
|
||||||
|
slot.n_prompt_tokens = prompt_tokens.size();
|
||||||
|
LOG_VERBOSE("input truncated", {
|
||||||
|
{"id_slot", slot.id},
|
||||||
|
{"id_task", slot.id_task},
|
||||||
|
{"n_ctx", slot.n_ctx},
|
||||||
|
{"n_keep", slot.params.n_keep},
|
||||||
|
{"n_left", n_left},
|
||||||
|
{"n_prompt_tokens", slot.n_prompt_tokens},
|
||||||
|
{"prompt_tokens", prompt_tokens.detokenize(ctx, true)},
|
||||||
|
});
|
||||||
|
|
||||||
|
GGML_ASSERT(slot.n_prompt_tokens < slot.n_ctx);
|
||||||
|
//print_tokens(prompt_tokens, slot.cache_tokens);
|
||||||
|
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
slot.n_discarded = 0;
|
||||||
|
}
|
||||||
llama_sampling_reset(llama_get_model_vocab(model), slot.ctx_sampling);
|
llama_sampling_reset(llama_get_model_vocab(model), slot.ctx_sampling);
|
||||||
|
|
||||||
if (!slot.params.cache_prompt) {
|
if (!slot.params.cache_prompt) {
|
||||||
|
|||||||
@@ -1343,6 +1343,59 @@ public:
|
|||||||
n_tokens_out = new_n_tokens;
|
n_tokens_out = new_n_tokens;
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Keep the first n_keep and remove n_discard tokens from tokens
|
||||||
|
void discard_n_tokens(int32_t n_keep, int32_t n_discard) {
|
||||||
|
if (n_discard <= 0 || n_keep + n_discard >= size()) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
llama_tokens new_tokens = get_text_tokens(); // copy
|
||||||
|
for (size_t i = n_keep + n_discard; i < new_tokens.size(); i++) {
|
||||||
|
new_tokens[i - n_discard] = new_tokens[i];
|
||||||
|
}
|
||||||
|
int32_t token_size = (int32_t) size();
|
||||||
|
new_tokens.resize(token_size - n_discard);
|
||||||
|
clear();
|
||||||
|
insert(new_tokens);
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
// Similarity between prompt and cached
|
||||||
|
float get_tokens_similarity(const server_tokens& tokens, int n_keep = 0, int n_discard = 0) const {
|
||||||
|
GGML_ASSERT(n_keep >= 0 && n_discard >= 0);
|
||||||
|
float sim_cur = 0;
|
||||||
|
if (n_keep == 0 && n_discard == 0) {
|
||||||
|
size_t lcp_len= get_common_prefix(tokens);
|
||||||
|
sim_cur = get_slot_similarity(lcp_len, tokens.size(), size());
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
// remove tokens due to context shift and compare
|
||||||
|
auto tokens_ctx_shift = server_tokens(tokens.get_text_tokens(), false); // copy cache tokens
|
||||||
|
tokens_ctx_shift.discard_n_tokens(n_keep, n_discard);
|
||||||
|
size_t lcp_len = get_common_prefix(tokens_ctx_shift);
|
||||||
|
sim_cur = get_slot_similarity(lcp_len, tokens_ctx_shift.size(), size());
|
||||||
|
}
|
||||||
|
return sim_cur;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Similarity between common part and cache
|
||||||
|
float get_cached_tokens_similarity(const server_tokens& tokens, int n_keep = 0, int n_discard = 0) const {
|
||||||
|
GGML_ASSERT(n_keep >= 0 && n_discard >= 0);
|
||||||
|
float sim_cur = 0;
|
||||||
|
if (n_keep == 0 && n_discard == 0) {
|
||||||
|
size_t lcp_len = get_common_prefix(tokens);
|
||||||
|
sim_cur = (float) lcp_len/size();
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
// remove tokens due to context shift and compare
|
||||||
|
auto tokens_ctx_shift = server_tokens(tokens.get_text_tokens(), false); // copy cache tokens
|
||||||
|
tokens_ctx_shift.discard_n_tokens(n_keep, n_discard);
|
||||||
|
size_t lcp_len = get_common_prefix(tokens_ctx_shift);
|
||||||
|
sim_cur = (float) lcp_len / size();
|
||||||
|
}
|
||||||
|
return sim_cur;
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
// Computes FNV-1a hash of the data
|
// Computes FNV-1a hash of the data
|
||||||
|
|||||||
Reference in New Issue
Block a user