Handle context shift better to reduce prompt processing (pp)

Add context-shift args

Add back ga_n in context shift
This commit is contained in:
firecoperana
2025-11-16 07:54:56 -06:00
parent 4d003e29ee
commit 415015f386
3 changed files with 172 additions and 44 deletions

View File

@@ -1343,6 +1343,58 @@ public:
n_tokens_out = new_n_tokens;
return 0;
}
// Keep the first n_keep tokens and remove the n_discard tokens that follow them,
// moving everything after the removed range down (as a context shift would).
//
//   n_keep    - number of leading tokens left untouched
//   n_discard - number of tokens removed, starting at index n_keep
//
// The call leaves the object unchanged when:
//   - n_keep is negative or n_discard is not positive, or
//   - no token lies beyond the discarded range (n_keep + n_discard >= token count);
//     this matches the previous behaviour, where the tail had to be non-empty.
// Otherwise the contents are rebuilt through clear() + insert().
void discard_n_tokens(int32_t n_keep, int32_t n_discard) {
    // Negative values would wrap around once mixed with unsigned indices and
    // lead to out-of-range accesses; a zero discard has nothing to remove.
    if (n_keep < 0 || n_discard <= 0) {
        return;
    }

    llama_tokens new_tokens = get_text_tokens(); // copy

    // Sum in size_t so that n_keep + n_discard cannot overflow int32_t.
    const size_t discard_end = static_cast<size_t>(n_keep) + static_cast<size_t>(n_discard);
    if (discard_end >= new_tokens.size()) {
        return;
    }

    // NOTE(review): assumes size() equals the length of get_text_tokens(), so that
    // erasing here yields the same final length as the former resize(size() - n_discard).
    const auto first = new_tokens.begin() + n_keep;
    new_tokens.erase(first, first + n_discard);

    clear();
    insert(new_tokens);
}
// Similarity between prompt and cached
float get_tokens_similarity(const server_tokens& tokens, int n_keep = 0, int n_discard = 0) const {
GGML_ASSERT(n_keep >= 0 && n_discard >= 0);
float sim_cur = 0;
if (n_keep == 0 && n_discard == 0) {
size_t lcp_len= get_common_prefix(tokens);
sim_cur = get_slot_similarity(lcp_len, tokens.size(), size());
}
else {
// remove tokens due to context shift and compare
auto tokens_ctx_shift = server_tokens(tokens.get_text_tokens(), false); // copy cache tokens
tokens_ctx_shift.discard_n_tokens(n_keep, n_discard);
size_t lcp_len = get_common_prefix(tokens_ctx_shift);
sim_cur = get_slot_similarity(lcp_len, tokens_ctx_shift.size(), size());
}
return sim_cur;
}
// Similarity between the common part and the cache: the length of the common
// prefix shared with `tokens`, divided by size() of this object.
//
// When n_keep or n_discard is non-zero, the prefix is computed against a copy of
// `tokens` from which n_discard tokens after the first n_keep have been removed
// (context shift). Both n_keep and n_discard must be non-negative (asserted).
//
// Returns 0 when this object holds no tokens, instead of the NaN that a 0/0
// floating-point division would produce.
float get_cached_tokens_similarity(const server_tokens& tokens, int n_keep = 0, int n_discard = 0) const {
    GGML_ASSERT(n_keep >= 0 && n_discard >= 0);

    const auto n_self = size();
    if (n_self == 0) {
        // nothing to cover; avoid dividing zero by zero
        return 0.0f;
    }

    size_t lcp_len = 0;
    if (n_keep == 0 && n_discard == 0) {
        lcp_len = get_common_prefix(tokens);
    } else {
        // remove tokens due to context shift and compare
        auto tokens_ctx_shift = server_tokens(tokens.get_text_tokens(), false); // copy of `tokens`
        tokens_ctx_shift.discard_n_tokens(n_keep, n_discard);
        lcp_len = get_common_prefix(tokens_ctx_shift);
    }

    return static_cast<float>(lcp_len) / n_self;
}
};
// Computes FNV-1a hash of the data