mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-04-22 15:39:23 +00:00
Fix prompt tokenization issue during prompt processing (#1008)
* Find common tokens between prompt and cache. Fix wrong context size usage for mtmd. Use start position of common part. server: handle context shift.
* Add size check for inexact match.
* Change.
--------- Co-authored-by: firecoperana <firecoperana>
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
#pragma once
|
||||
|
||||
#include "llama.h"
|
||||
#include <src/llama-impl.h>
|
||||
#include "common.h"
|
||||
|
||||
// Change JSON_ASSERT from assert() to GGML_ASSERT:
|
||||
@@ -14,6 +15,7 @@
|
||||
#include <vector>
|
||||
#include <sstream>
|
||||
#include <random>
|
||||
#include <set>
|
||||
|
||||
// increase max payload length to allow use of larger context size
|
||||
#define CPPHTTPLIB_FORM_URL_ENCODED_PAYLOAD_MAX_LENGTH 1048576
|
||||
@@ -333,6 +335,165 @@ static std::string tokens_to_output_formatted_string(const llama_context * ctx,
|
||||
return out;
|
||||
}
|
||||
|
||||
// A pair of prefix lengths produced by the prefix-matching helpers in this file.
// `first` is the length measured in the first sequence and `second` the length
// measured in the second sequence; both start at zero.
struct common_prefix {
    size_t first{0};
    size_t second{0};
};
|
||||
|
||||
// Component-wise sum of two prefix lengths.
// Returns a common_prefix whose `first` is a.first + b.first and whose `second`
// is a.second + b.second.
// inline: this definition sits in a header (#pragma once), so it has to be safe
// to include from more than one translation unit.
inline common_prefix common_prefix_add(const common_prefix& a, const common_prefix& b) {
    common_prefix sum;
    sum.first  = a.first + b.first;
    sum.second = a.second + b.second;
    return sum;
}
|
||||
|
||||
// Walks a_str and b_str in lockstep and reports how far each string can be
// consumed while the two still "agree", where characters contained in
// ignore_set may be skipped on either side.
//
// Rules applied at each step (i indexes a_str, j indexes b_str):
//   - equal characters: advance both
//   - both characters differ but are both ignorable: advance both
//   - only one side is ignorable: advance that side only
//   - otherwise: stop
// The walk also stops as soon as either string is exhausted, so trailing
// ignorable characters of the longer string are not consumed.
//
// Returns {first = characters consumed from a_str, second = characters consumed from b_str}.
// inline: this definition sits in a header (#pragma once), so it has to be safe
// to include from more than one translation unit.
inline common_prefix find_common_string_prefix(const std::string & a_str, const std::string & b_str, const std::set<char>& ignore_set) {
    size_t i = 0;
    size_t j = 0;
    while (i < a_str.size() && j < b_str.size()) {
        const char a_chr = a_str[i];
        const char b_chr = b_str[j];
        if (a_chr == b_chr) {
            ++i;
            ++j;
            continue;
        }
        // look each character up once instead of once per branch
        const bool skip_a = ignore_set.count(a_chr) != 0;
        const bool skip_b = ignore_set.count(b_chr) != 0;
        if (!skip_a && !skip_b) {
            break; // genuine mismatch
        }
        if (skip_a) {
            ++i;
        }
        if (skip_b) {
            ++j;
        }
    }
    common_prefix string_prefix;
    string_prefix.first = i;
    string_prefix.second = j;
    return string_prefix;
}
|
||||
|
||||
size_t find_n_tokens_from_string(const llama_context* ctx, const llama_tokens& a, const size_t max_size, size_t start,
|
||||
std::vector<size_t> & map) {
|
||||
size_t n = 0;
|
||||
size_t string_len = 0;
|
||||
std::string str;
|
||||
auto model = llama_get_model(ctx);
|
||||
for (n = start; n < a.size(); ++n) {
|
||||
str = llama_token_to_piece(model, a[n], true);
|
||||
string_len = string_len + str.size();
|
||||
if (string_len <= max_size) {
|
||||
map.push_back(string_len);
|
||||
}
|
||||
else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
return map.size();
|
||||
}
|
||||
|
||||
// Returns `str` with every character that appears in chars_to_remove deleted.
// The argument is taken by value, so the caller's string is left untouched.
// inline: this definition sits in a header (#pragma once), so it has to be safe
// to include from more than one translation unit.
inline std::string remove_with_set(std::string str, const std::set<char>& chars_to_remove) {
    const auto is_removed = [&chars_to_remove](char c) {
        return chars_to_remove.count(c) != 0;
    };
    // erase-remove: compact the kept characters, then drop the tail
    str.erase(std::remove_if(str.begin(), str.end(), is_removed), str.end());
    return str;
}
|
||||
|
||||
// Scans both lists from the back looking for a value that occurs in both.
//
// At each step the two current values are compared: on equality the 1-based
// positions are returned as {first = position in a_list, second = position in
// b_list}; otherwise the cursor pointing at the larger value moves one step
// towards the front. If either list is exhausted (or empty) the result is {0, 0}.
//
// NOTE(review): the scan only yields the largest shared value when both lists
// are in non-decreasing order; the callers in this file pass running totals.
//
// Indices are kept as size_t (counting down from size()) so large lists are not
// narrowed into int.
// inline: this definition sits in a header (#pragma once), so it has to be safe
// to include from more than one translation unit.
inline common_prefix find_largest_common_number(const std::vector<size_t>& a_list, const std::vector<size_t>& b_list) {
    common_prefix token_prefix;
    size_t i = a_list.size(); // one past the element currently inspected in a_list
    size_t j = b_list.size(); // one past the element currently inspected in b_list
    while (i > 0 && j > 0) {
        const size_t a_val = a_list[i - 1];
        const size_t b_val = b_list[j - 1];
        if (a_val == b_val) {
            // found largest common value
            token_prefix.first = i;
            token_prefix.second = j;
            break;
        }
        if (a_val > b_val) {
            --i;
        }
        else {
            --j;
        }
    }
    return token_prefix;
}
|
||||
|
||||
size_t find_n_tokens_from_string_with_ignore(const llama_context* ctx, const llama_tokens& a, const size_t max_size, size_t start, const std::set<char> & ignore_set,
|
||||
std::vector<size_t>& map) {
|
||||
bool use_ignore = ignore_set.size()>0;
|
||||
size_t n = 0;
|
||||
size_t string_len = 0;
|
||||
size_t string_len_ignore = 0;
|
||||
std::string str;
|
||||
std::string str_ignore;
|
||||
auto model = llama_get_model(ctx);
|
||||
for (n = start; n < a.size(); ++n) {
|
||||
str = llama_token_to_piece(model, a[n], true);
|
||||
string_len = string_len + str.size();
|
||||
if (use_ignore) {
|
||||
str_ignore = remove_with_set(str, ignore_set);
|
||||
}
|
||||
else {
|
||||
str_ignore = str;
|
||||
}
|
||||
string_len_ignore = string_len_ignore + str_ignore.size();
|
||||
if (string_len <= max_size) {
|
||||
map.push_back(string_len_ignore);
|
||||
}
|
||||
else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
return map.size();
|
||||
}
|
||||
|
||||
// Finds how many tokens of `a` and of `b`, counted from index `start`, cover the
// same text.
//
// Both tails are detokenized to strings and the shared leading part of the two
// strings is measured:
//   - exact == true : common_part(a_str, b_str) gives one length used for both sides
//   - exact == false: find_common_string_prefix is used with ' ', '\n' and '\r'
//                     treated as ignorable, so the two sides may consume different lengths
// Each tail is then walked token by token to collect running text lengths that
// fit inside its string prefix, and find_largest_common_number picks the last
// token boundary at which both sides have the same running length.
//
// Returns {first = tokens of a, second = tokens of b}, both relative to `start`.
// Returns {0, 0} when either list has no tokens at or after `start`.
// inline: this definition sits in a header (#pragma once), so it has to be safe
// to include from more than one translation unit.
inline common_prefix find_common_text_token_prefix(const llama_context * ctx, const llama_tokens & a, const llama_tokens& b,
    size_t start, bool exact) {
    if (a.size() <= start || b.size() <= start) {
        common_prefix empty_prefix;
        return empty_prefix;
    }

    llama_tokens a_sub(a.begin() + start, a.end());
    llama_tokens b_sub(b.begin() + start, b.end());

    std::string a_str = llama_detokenize(ctx, a_sub, true);
    std::string b_str = llama_detokenize(ctx, b_sub, true);

    // running text lengths at each token boundary, filled by the helpers below;
    // their return values are not needed because the boundary is chosen from the lists
    std::vector<size_t> a_list;
    std::vector<size_t> b_list;

    if (exact) {
        const size_t lcp = common_part(a_str, b_str);
        find_n_tokens_from_string(ctx, a_sub, lcp, 0, a_list);
        find_n_tokens_from_string(ctx, b_sub, lcp, 0, b_list);
    }
    else {
        const std::set<char> ignore_set = { ' ', '\n' ,'\r'};
        const common_prefix string_prefix = find_common_string_prefix(a_str, b_str, ignore_set);
        find_n_tokens_from_string_with_ignore(ctx, a_sub, string_prefix.first, 0, ignore_set, a_list);
        find_n_tokens_from_string_with_ignore(ctx, b_sub, string_prefix.second, 0, ignore_set, b_list);
    }

    return find_largest_common_number(a_list, b_list);
}
|
||||
|
||||
|
||||
struct completion_token_output {
|
||||
llama_token tok;
|
||||
std::string text_to_send;
|
||||
@@ -1000,19 +1161,22 @@ struct server_tokens {
|
||||
|
||||
private: // disallow accessing these members directly, risking out-of-sync
|
||||
|
||||
// map a **start** position in tokens to the image chunk
|
||||
std::unordered_map<llama_pos, mtmd::input_chunk_ptr> map_pos_to_media;
|
||||
// map a **start** index in tokens to the image chunk
|
||||
// note: the order needs to be in sync with tokens
|
||||
std::map<size_t, mtmd::input_chunk_ptr> map_idx_to_media;
|
||||
|
||||
// list of tokens
|
||||
// it can include LLAMA_TOKEN_NULL, which is used to indicate a token that is not a text token
|
||||
// a mtmd_input_chunk can occupy multiple tokens, one llama_token per **position**
|
||||
// important: for models using mrope, an image can contain multiple tokens but will use only one **position**
|
||||
std::vector<llama_token> tokens;
|
||||
// if the token is LLAMA_TOKEN_NULL, it indicates that this position is occupied by media chunk
|
||||
// otherwise, it is a normal text token
|
||||
// note: a non-text chunk can occupy multiple tokens (aka memory cells) in the token list
|
||||
// note(2): for M-RoPE, an image can occupy different number of pos; do not assume 1-to-1 mapping tokens <-> pos
|
||||
llama_tokens tokens;
|
||||
|
||||
// for ex. with input of 5 text tokens and 2 images:
|
||||
// [0] [1] [2] [3] [4] [img0] [img0] [img0] [img1] [img1]
|
||||
// pos 0 1 2 3 4 5 6 7 8 9
|
||||
// map_pos_to_media will contain: {5, img0}, {8, img1}
|
||||
// for ex. with input of 5 text tokens and 2 images (each image occupies 3 tokens and 2 pos):
|
||||
// [0] [1] [2] [3] [4] [img0] [img0] [img0] [img1] [img1] [img1]
|
||||
// idx 0 1 2 3 4 5 6 7 8 9 10
|
||||
// pos 0 1 2 3 4 5 5 5 7 7 7
|
||||
// map_idx_to_media will contain: {5, img0}, {8, img1}
|
||||
|
||||
public:
|
||||
server_tokens() = default;
|
||||
@@ -1036,7 +1200,8 @@ public:
|
||||
}
|
||||
}
|
||||
|
||||
server_tokens(const std::vector<llama_token>& tokens, bool has_mtmd) : has_mtmd(has_mtmd), tokens(tokens) {}
|
||||
// Builds the token list as a copy of `tokens` and stores the `has_mtmd` flag as given.
server_tokens(const llama_tokens& tokens, bool has_mtmd) : has_mtmd(has_mtmd), tokens(tokens) {}
|
||||
|
||||
llama_pos pos_next() const {
|
||||
if (!has_mtmd) {
|
||||
@@ -1045,7 +1210,7 @@ public:
|
||||
|
||||
llama_pos res = tokens.size();
|
||||
|
||||
for (auto it = map_pos_to_media.begin(); it != map_pos_to_media.end(); ++it) {
|
||||
for (auto it = map_idx_to_media.begin(); it != map_idx_to_media.end(); ++it) {
|
||||
const auto& chunk = it->second;
|
||||
res += mtmd_input_chunk_get_n_pos(chunk.get()) - mtmd_input_chunk_get_n_tokens(chunk.get());
|
||||
}
|
||||
@@ -1057,7 +1222,9 @@ public:
|
||||
std::string str() const {
|
||||
std::ostringstream oss;
|
||||
oss << "tokens: ";
|
||||
for (const auto& t : tokens) {
|
||||
for (size_t idx = 0; idx < tokens.size(); ++idx) {
|
||||
llama_token t = tokens[idx];
|
||||
oss << "idx:" << idx << " ";
|
||||
if (t == LLAMA_TOKEN_NULL) {
|
||||
oss << "<embd> ";
|
||||
}
|
||||
@@ -1066,16 +1233,16 @@ public:
|
||||
}
|
||||
}
|
||||
oss << "\n";
|
||||
oss << "image pos: ";
|
||||
for (const auto& it : map_pos_to_media) {
|
||||
oss << "image idx: ";
|
||||
for (const auto& it : map_idx_to_media) {
|
||||
oss << it.first << ", ";
|
||||
}
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
const mtmd::input_chunk_ptr& find_chunk(llama_pos pos) const {
|
||||
auto it = map_pos_to_media.find(pos);
|
||||
if (it != map_pos_to_media.end()) {
|
||||
const mtmd::input_chunk_ptr& find_chunk(size_t idx) const {
|
||||
auto it = map_idx_to_media.find(idx);
|
||||
if (it != map_idx_to_media.end()) {
|
||||
return it->second;
|
||||
}
|
||||
throw std::runtime_error("Chunk not found");
|
||||
@@ -1093,17 +1260,17 @@ public:
|
||||
auto type = mtmd_input_chunk_get_type(chunk);
|
||||
if (type == MTMD_INPUT_CHUNK_TYPE_IMAGE || type == MTMD_INPUT_CHUNK_TYPE_AUDIO) {
|
||||
GGML_ASSERT(has_mtmd);
|
||||
const int n_pos = mtmd_input_chunk_get_n_pos(chunk);
|
||||
llama_pos start_pos = tokens.size();
|
||||
for (int i = 0; i < n_pos; ++i) {
|
||||
const size_t n_tokens = mtmd_input_chunk_get_n_tokens(chunk);
|
||||
size_t start_idx = tokens.size();
|
||||
for (size_t i = 0; i < n_tokens; ++i) {
|
||||
tokens.emplace_back(LLAMA_TOKEN_NULL);
|
||||
}
|
||||
mtmd::input_chunk_ptr new_chunk(mtmd_input_chunk_copy(chunk));
|
||||
map_pos_to_media[start_pos] = std::move(new_chunk);
|
||||
map_idx_to_media[start_idx] = std::move(new_chunk);
|
||||
}
|
||||
else if (type == MTMD_INPUT_CHUNK_TYPE_TEXT) {
|
||||
size_t n_tokens;
|
||||
auto text_tokens = mtmd_input_chunk_get_tokens_text(chunk, &n_tokens);
|
||||
const auto* text_tokens = mtmd_input_chunk_get_tokens_text(chunk, &n_tokens);
|
||||
for (size_t i = 0; i < n_tokens; ++i) {
|
||||
push_back(text_tokens[i]);
|
||||
}
|
||||
@@ -1115,7 +1282,7 @@ public:
|
||||
|
||||
// appends server tokens, updates the media map. copies media chunks.
|
||||
void push_back(server_tokens& tokens) {
|
||||
size_t start_pos = size();
|
||||
size_t start_idx = size();
|
||||
for (size_t i = 0; i < tokens.size(); i++) {
|
||||
push_back(tokens[i]);
|
||||
}
|
||||
@@ -1123,10 +1290,10 @@ public:
|
||||
// Assert if we are copying MTMD chunks to a server_tokens that does not have mtmd.
|
||||
// We could also just check, but this will prevent silently dropping MTMD data.
|
||||
GGML_ASSERT(has_mtmd);
|
||||
for (auto it = tokens.map_pos_to_media.begin(); it != tokens.map_pos_to_media.end(); ) {
|
||||
auto chunk = tokens.map_pos_to_media[it->first].get();
|
||||
for (auto it = tokens.map_idx_to_media.begin(); it != tokens.map_idx_to_media.end(); ) {
|
||||
auto* chunk = tokens.map_idx_to_media[it->first].get();
|
||||
mtmd::input_chunk_ptr new_chunk(mtmd_input_chunk_copy(chunk));
|
||||
map_pos_to_media[start_pos + it->first] = std::move(new_chunk);
|
||||
map_idx_to_media[start_idx + it->first] = std::move(new_chunk);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1164,7 +1331,6 @@ public:
|
||||
}
|
||||
|
||||
// Returns the stored token list by value, i.e. the caller receives its own copy.
llama_tokens tokens_data() { return tokens; }
|
||||
|
||||
@@ -1212,10 +1378,10 @@ public:
|
||||
}
|
||||
}
|
||||
// remove all image chunks that are not used anymore
|
||||
for (auto it = map_pos_to_media.begin(); it != map_pos_to_media.end(); ) {
|
||||
llama_pos pos = it->first;
|
||||
if (pos >= (llama_pos)n) {
|
||||
it = map_pos_to_media.erase(it);
|
||||
for (auto it = map_idx_to_media.begin(); it != map_idx_to_media.end(); ) {
|
||||
size_t idx = it->first;
|
||||
if (idx >= n) {
|
||||
it = map_idx_to_media.erase(it);
|
||||
}
|
||||
else {
|
||||
++it;
|
||||
@@ -1236,7 +1402,37 @@ public:
|
||||
return llama_detokenize(ctx, text_tokens, special);
|
||||
}
|
||||
|
||||
size_t get_common_prefix(const server_tokens& b) const {
|
||||
std::string detokenize(const llama_context* ctx, bool special, size_t start, size_t length) const {
|
||||
std::string str;
|
||||
if (tokens.size() <= start || length == 0) {
|
||||
return str;
|
||||
}
|
||||
llama_tokens text_tokens;
|
||||
text_tokens.reserve(tokens.size() - start);
|
||||
size_t i = 0;
|
||||
size_t count = 0;
|
||||
for (const auto& t : tokens) {
|
||||
if (t != LLAMA_TOKEN_NULL && i>=start) {
|
||||
text_tokens.push_back(t);
|
||||
++count;
|
||||
if (count >= length) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
++i;
|
||||
}
|
||||
return llama_detokenize(ctx, text_tokens, special);
|
||||
}
|
||||
|
||||
size_t find_n_from_tokens(const llama_context* ctx, const server_tokens& b, bool special,
|
||||
size_t start, const size_t length) {
|
||||
std::string str = detokenize(ctx, special, start, length);
|
||||
std::vector<size_t> tmp;
|
||||
size_t n = find_n_tokens_from_string(ctx, b.tokens, start, length, tmp);
|
||||
return n;
|
||||
}
|
||||
|
||||
size_t get_common_prefix_exact(const server_tokens& b) const {
|
||||
const size_t max_idx = std::min(tokens.size(), b.tokens.size());
|
||||
|
||||
if (!has_mtmd) {
|
||||
@@ -1262,12 +1458,12 @@ public:
|
||||
const std::string id_ai = mtmd_input_chunk_get_id(a_chunk.get());
|
||||
const std::string id_bi = mtmd_input_chunk_get_id(b_chunk.get());
|
||||
|
||||
const size_t pos_a = mtmd_input_chunk_get_n_pos(a_chunk.get());
|
||||
const size_t pos_b = mtmd_input_chunk_get_n_pos(b_chunk.get());
|
||||
const size_t n_tok_a = mtmd_input_chunk_get_n_tokens(a_chunk.get());
|
||||
const size_t n_tok_b = mtmd_input_chunk_get_n_tokens(b_chunk.get());
|
||||
|
||||
if (id_ai == id_bi && pos_a == pos_b) {
|
||||
GGML_ASSERT(pos_a > 0 && "Invalid media chunk"); // should never happen
|
||||
i += pos_a - 1; // will be +1 by the for loop
|
||||
if (id_ai == id_bi && n_tok_a == n_tok_b) {
|
||||
GGML_ASSERT(n_tok_a > 0 && "Invalid media chunk"); // should never happen
|
||||
i += n_tok_a - 1; // will be +1 by the for loop
|
||||
continue;
|
||||
}
|
||||
|
||||
@@ -1285,6 +1481,94 @@ public:
|
||||
}
|
||||
|
||||
|
||||
// Computes how many leading tokens of this list (result.first) and of `b`
// (result.second) can be treated as a shared prefix.
//
// The strict token-by-token match from get_common_prefix_exact is the starting
// point. Beyond it, text is compared through find_common_text_token_prefix
// (`exact` is forwarded to it), so the two sides may contribute different token counts.
//
// Without mtmd the whole remainder is compared as text.
// With mtmd the remainder is processed segment by segment: text tokens are
// gathered on both sides until both cursors sit on a LLAMA_TOKEN_NULL entry;
// the gathered text must match completely and the two media chunks must have
// the same id and token count for the walk to continue past the chunk.
common_prefix get_common_prefix(const llama_context* ctx, const server_tokens& b, bool exact = false) const {
    // strict token match as a starting point
    const size_t n_exact = get_common_prefix_exact(b);

    if (!has_mtmd) {
        common_prefix text_only = find_common_text_token_prefix(ctx, this->tokens, b.tokens, n_exact, exact);
        text_only.first  += n_exact;
        text_only.second += n_exact;
        return text_only;
    }

    common_prefix result;
    result.first  = n_exact;
    result.second = n_exact;

    size_t ia = n_exact; // cursor into this->tokens
    size_t ib = n_exact; // cursor into b.tokens
    llama_tokens text_a; // text tokens gathered since the last matched media chunk
    llama_tokens text_b;

    while (ia < size() && ib < b.size()) {
        const llama_token tok_a = tokens[ia];
        const llama_token tok_b = b.tokens[ib];
        const bool a_at_media = tok_a == LLAMA_TOKEN_NULL;
        const bool b_at_media = tok_b == LLAMA_TOKEN_NULL;

        if (!a_at_media) {
            text_a.push_back(tok_a);
            ++ia;
        }
        if (!b_at_media) {
            text_b.push_back(tok_b);
            ++ib;
        }
        if (!(a_at_media && b_at_media)) {
            // keep gathering until both sides have reached a media chunk
            continue;
        }

        common_prefix segment = find_common_text_token_prefix(ctx, text_a, text_b, 0, exact);

        // the text in front of the chunks must match completely (or be empty on both sides)
        if (segment.first != text_a.size() || segment.second != text_b.size()) {
            return common_prefix_add(segment, result);
        }
        text_a.clear();
        text_b.clear();

        const auto& a_chunk = find_chunk(ia);
        const auto& b_chunk = b.find_chunk(ib);

        GGML_ASSERT(a_chunk && b_chunk);

        const std::string id_ai = mtmd_input_chunk_get_id(a_chunk.get());
        const std::string id_bi = mtmd_input_chunk_get_id(b_chunk.get());

        const size_t n_tok_a = mtmd_input_chunk_get_n_tokens(a_chunk.get());
        const size_t n_tok_b = mtmd_input_chunk_get_n_tokens(b_chunk.get());

        if (id_ai != id_bi || n_tok_a != n_tok_b) {
            // media differs: do not include the media tokens,
            // only return the text token prefix
            return common_prefix_add(segment, result);
        }

        // media match: step over the whole chunk on both sides
        GGML_ASSERT(n_tok_a > 0 && "Invalid media chunk"); // should never happen
        ia += n_tok_a;
        ib += n_tok_a;
        segment.first  += n_tok_a;
        segment.second += n_tok_a;
        result = common_prefix_add(segment, result);
    }

    // text left over after the last media chunk (or when one side ran out)
    const common_prefix tail = find_common_text_token_prefix(ctx, text_a, text_b, 0, exact);
    return common_prefix_add(tail, result);
}
|
||||
|
||||
// take first n tokens of tokens list a
|
||||
// find the common prefix between a and b
|
||||
common_prefix get_common_prefix_first_n(const llama_context* ctx, const server_tokens& b, size_t n, bool exact = false) const {
|
||||
// not work for mtmd
|
||||
GGML_ASSERT(!has_mtmd); // only allow this if mtmd is disabled
|
||||
auto tokens = get_text_tokens();
|
||||
if (n > tokens.size()) {
|
||||
n = tokens.size();
|
||||
}
|
||||
llama_tokens copy(tokens.begin(), tokens.begin()+n);
|
||||
server_tokens a = server_tokens(copy, false);
|
||||
return a.get_common_prefix(ctx, b, exact);
|
||||
}
|
||||
|
||||
// make sure all text tokens are within the vocab range
|
||||
bool validate(const struct llama_context* ctx) const {
|
||||
const llama_model* model = llama_get_model(ctx);
|
||||
@@ -1296,8 +1580,8 @@ public:
|
||||
if (t == LLAMA_TOKEN_NULL) {
|
||||
try {
|
||||
const auto& chunk = find_chunk(i);
|
||||
size_t n_pos = mtmd_input_chunk_get_n_pos(chunk.get());
|
||||
i += n_pos - 1; // will be +1 by the for loop
|
||||
size_t n_tokens = mtmd_input_chunk_get_n_tokens(chunk.get());
|
||||
i += n_tokens - 1; // will be +1 by the for loop
|
||||
}
|
||||
catch (const std::exception& e) {
|
||||
return false;
|
||||
@@ -1312,41 +1596,33 @@ public:
|
||||
|
||||
// encode and decode the image chunk
|
||||
int32_t process_chunk(
|
||||
llama_context * ctx,
|
||||
mtmd_context * mctx,
|
||||
llama_pos n_past,
|
||||
llama_context* ctx,
|
||||
mtmd_context* mctx,
|
||||
size_t idx,
|
||||
llama_pos pos,
|
||||
int32_t seq_id,
|
||||
llama_pos & n_pos_out,
|
||||
size_t & n_tokens_out) {
|
||||
char buffer[512];
|
||||
auto& chunk = find_chunk(n_past);
|
||||
size_t& n_tokens_out) const {
|
||||
const auto& chunk = find_chunk(idx);
|
||||
const char* name = mtmd_input_chunk_get_type(chunk.get()) == MTMD_INPUT_CHUNK_TYPE_IMAGE
|
||||
? "image" : "audio";
|
||||
snprintf(buffer, 512, "processing : %s",name);
|
||||
LOG_INFO(buffer, {});
|
||||
LLAMA_LOG_INFO("processing %s...\n", name);
|
||||
int32_t n_batch = llama_n_batch(ctx);
|
||||
int64_t t0 = ggml_time_ms();
|
||||
llama_pos new_n_past = n_past;
|
||||
llama_pos new_n_past; // unused for now
|
||||
int32_t result = mtmd_helper_eval_chunk_single(mctx, ctx,
|
||||
chunk.get(),
|
||||
n_past,
|
||||
pos,
|
||||
seq_id,
|
||||
n_batch,
|
||||
true, // logits last
|
||||
&new_n_past);
|
||||
// get number of tokens in the image
|
||||
const size_t new_n_tokens = mtmd_input_chunk_get_n_tokens(chunk.get());
|
||||
snprintf(buffer, 512, "processed in %g ms", 1.*(ggml_time_ms() - t0));
|
||||
LOG_INFO(buffer, {});
|
||||
LLAMA_LOG_INFO("%s processed in %" PRId64 " ms\n", name, ggml_time_ms() - t0);
|
||||
if (result != 0) {
|
||||
snprintf(buffer, 512, "mtmd_helper_eval failed with status %d", result);
|
||||
LOG_ERROR(buffer, {});
|
||||
n_pos_out = n_past;
|
||||
LLAMA_LOG_ERROR("mtmd_helper_eval failed with status %d", result);
|
||||
n_tokens_out = 0;
|
||||
return result;
|
||||
}
|
||||
n_pos_out = new_n_past;
|
||||
n_tokens_out = new_n_tokens;
|
||||
n_tokens_out = mtmd_input_chunk_get_n_tokens(chunk.get());
|
||||
return 0;
|
||||
}
|
||||
|
||||
@@ -1368,37 +1644,37 @@ public:
|
||||
}
|
||||
|
||||
// Similarity between prompt and cached
|
||||
float get_tokens_similarity(const server_tokens& tokens, int n_keep = 0, int n_discard = 0) const {
|
||||
float get_tokens_similarity(const llama_context* ctx, const server_tokens& tokens, int n_keep = 0, int n_discard = 0) const {
|
||||
GGML_ASSERT(n_keep >= 0 && n_discard >= 0);
|
||||
float sim_cur = 0;
|
||||
if (n_keep == 0 && n_discard == 0) {
|
||||
size_t lcp_len= get_common_prefix(tokens);
|
||||
sim_cur = get_slot_similarity(lcp_len, tokens.size(), size());
|
||||
auto lcp_len= get_common_prefix(ctx, tokens);
|
||||
sim_cur = get_slot_similarity(lcp_len.second, tokens.size(), size());
|
||||
}
|
||||
else {
|
||||
// remove tokens due to context shift and compare
|
||||
auto tokens_ctx_shift = server_tokens(tokens.get_text_tokens(), false); // copy cache tokens
|
||||
tokens_ctx_shift.discard_n_tokens(n_keep, n_discard);
|
||||
size_t lcp_len = get_common_prefix(tokens_ctx_shift);
|
||||
sim_cur = get_slot_similarity(lcp_len, tokens_ctx_shift.size(), size());
|
||||
auto lcp_len = get_common_prefix(ctx, tokens_ctx_shift);
|
||||
sim_cur = get_slot_similarity(lcp_len.second, tokens_ctx_shift.size(), size());
|
||||
}
|
||||
return sim_cur;
|
||||
}
|
||||
|
||||
// Similarity between common part and cache
|
||||
float get_cached_tokens_similarity(const server_tokens& tokens, int n_keep = 0, int n_discard = 0) const {
|
||||
float get_cached_tokens_similarity(const llama_context* ctx, const server_tokens& tokens, int n_keep = 0, int n_discard = 0) const {
|
||||
GGML_ASSERT(n_keep >= 0 && n_discard >= 0);
|
||||
float sim_cur = 0;
|
||||
if (n_keep == 0 && n_discard == 0) {
|
||||
size_t lcp_len = get_common_prefix(tokens);
|
||||
sim_cur = (float) lcp_len/size();
|
||||
auto lcp_len = get_common_prefix(ctx, tokens);
|
||||
sim_cur = (float) lcp_len.first/size();
|
||||
}
|
||||
else {
|
||||
// remove tokens due to context shift and compare
|
||||
auto tokens_ctx_shift = server_tokens(tokens.get_text_tokens(), false); // copy cache tokens
|
||||
tokens_ctx_shift.discard_n_tokens(n_keep, n_discard);
|
||||
size_t lcp_len = get_common_prefix(tokens_ctx_shift);
|
||||
sim_cur = (float) lcp_len / size();
|
||||
auto lcp_len = get_common_prefix(ctx, tokens_ctx_shift);
|
||||
sim_cur = (float) lcp_len.first / size();
|
||||
}
|
||||
return sim_cur;
|
||||
}
|
||||
@@ -1541,3 +1817,11 @@ inline void print_files_info(const std::vector<raw_buffer>& files) {
|
||||
std::cout << std::dec << "\n\n"; // Reset to decimal
|
||||
}
|
||||
}
|
||||
|
||||
// Checks whether the cache and the prompt produce the same text over their
// respective prefix windows.
// Both lists are detokenized from index `start`: prefix.first tokens are taken
// from cache_tokens and prefix.second tokens from prompt_tokens.
// Returns true when the two resulting strings are identical.
inline bool prompt_cache_equal(llama_context* ctx, const server_tokens& cache_tokens,
    const server_tokens& prompt_tokens, size_t start, const common_prefix & prefix ) {
    const std::string cache_text  = cache_tokens.detokenize(ctx, true, start, prefix.first);
    const std::string prompt_text = prompt_tokens.detokenize(ctx, true, start, prefix.second);
    return cache_text == prompt_text;
}
|
||||
|
||||
Reference in New Issue
Block a user