Fix missing embedding results, CORS, and a crash in verbose mode in the server (#924)

* server: fix crash when the prompt has an image and is too long

* server: fix CORS

* server: fix empty result for embedding

* change error message for prompt truncation

* server: fix slot id for save and load state

* bug fix

* server: update slot similarity to handle mtmd

* server: quick hack to calculate the number of tokens processed with an image

* server: fix out of range error when detokenizing prompt under verbose

* Add back Access-Control-Allow-Origin

* Server: Add prompt token counts to embedding results

---------

Co-authored-by: firecoperana <firecoperana>
firecoperana
2025-11-09 12:16:03 +00:00
committed by GitHub
parent 5cc15d0ecf
commit b63309a918
3 changed files with 139 additions and 91 deletions

View File

@@ -1741,7 +1741,6 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
return true;
}
if (arg == "--no-context-shift") {
CHECK_ARG
params.ctx_shift = false;
return true;
}
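
Aside: --no-context-shift is a plain boolean flag, so it must not go through CHECK_ARG, which exists for options that take a value. A minimal sketch of the pattern (the macro body and the --ctx-size handling below are reconstructed assumptions, not copied from the repository):

#include <cstdlib>
#include <string>

struct sketch_params { int n_ctx = 4096; bool ctx_shift = true; };

static bool sketch_find_arg(int argc, char ** argv, int & i, sketch_params & params, bool & invalid_param) {
    const std::string arg = argv[i];
    // assumed definition: consume the next argv entry or flag a missing argument
    #define CHECK_ARG if (++i >= argc) { invalid_param = true; return true; }
    if (arg == "--ctx-size") {
        CHECK_ARG                       // value-taking option: advance to its argument
        params.n_ctx = std::atoi(argv[i]);
        return true;
    }
    if (arg == "--no-context-shift") {  // plain flag: nothing to consume, so no CHECK_ARG
        params.ctx_shift = false;
        return true;
    }
    #undef CHECK_ARG
    return false;
}
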

View File

@@ -1467,11 +1467,11 @@ struct server_context {
return nullptr;
}
server_slot * get_available_slot(const std::string & prompt) {
server_slot * get_available_slot(const server_task & task) {
server_slot * ret = nullptr;
// find the slot that has at least n% prompt similarity
if (ret == nullptr && slot_prompt_similarity != 0.0f && !prompt.empty()) {
if (ret == nullptr && slot_prompt_similarity != 0.0f) {
int max_lcp_len = 0;
float similarity = 0;
@@ -1480,24 +1480,16 @@ struct server_context {
if (!slot.available()) {
continue;
}
const auto & cache_tokens = slot.cache_tokens;
// skip the slot if it does not contains prompt
if (!slot.prompt.is_string()) {
if (cache_tokens.empty()) {
continue;
}
// current slot's prompt
std::string slot_prompt = slot.prompt.get<std::string>();
// length of the current slot's prompt
int slot_prompt_len = slot_prompt.size();
// length of the Longest Common Prefix between the current slot's prompt and the input prompt
int lcp_len = common_part(slot_prompt, prompt);
int lcp_len = cache_tokens.get_common_prefix(task.tokens);
// fraction of the common substring length compared to the current slot's prompt length
similarity = static_cast<float>(lcp_len) / slot_prompt_len;
const float similarity = float(lcp_len) / task.tokens.size();
// select the current slot if the criteria match
if (lcp_len > max_lcp_len && similarity > slot_prompt_similarity) {
max_lcp_len = lcp_len;
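
A standalone sketch of the new token-level similarity test (hypothetical helper; assumes get_common_prefix() amounts to a longest-common-prefix count over token ids):

#include <cstdint>
#include <vector>

using sketch_token = int32_t; // stand-in for llama_token

// fraction of the request prompt already cached in this slot's token cache
static bool sketch_slot_matches(const std::vector<sketch_token> & cache_tokens,
                                const std::vector<sketch_token> & task_tokens,
                                float slot_prompt_similarity) {
    if (cache_tokens.empty() || task_tokens.empty()) {
        return false;
    }
    size_t lcp_len = 0;
    while (lcp_len < cache_tokens.size() && lcp_len < task_tokens.size() &&
           cache_tokens[lcp_len] == task_tokens[lcp_len]) {
        ++lcp_len;
    }
    const float similarity = float(lcp_len) / task_tokens.size();
    return similarity > slot_prompt_similarity;
}

Comparing token ids rather than the prompt string is what lets the check work for multimodal (mtmd) prompts, whose server_tokens are not a plain string and were previously skipped outright.
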
@@ -2363,6 +2355,7 @@ struct server_context {
res.data = json {
{"embedding", std::vector<float>(n_embd, 0.0f)},
{"tokens_evaluated", slot.n_prompt_tokens},
};
continue;
@@ -2372,6 +2365,7 @@ struct server_context {
res.data = json {
{"embedding", embd_res},
{"tokens_evaluated", slot.n_prompt_tokens},
};
}
@@ -2461,12 +2455,7 @@ struct server_context {
if (id_slot != -1) {
slot = get_slot_by_id(id_slot);
} else {
std::string prompt;
if (task.data.contains("prompt") && task.data.at("prompt").is_string()) {
prompt = json_value(task.data, "prompt", std::string());
}
slot = get_available_slot(prompt);
slot = get_available_slot(task);
}
if (slot == nullptr) {
@@ -2618,7 +2607,7 @@ struct server_context {
std::string filename = task.data.at("filename");
std::string filepath = task.data.at("filepath");
const size_t nwrite = llama_state_seq_save_file(ctx, filepath.c_str(), slot->id + 1, slot->cache_tokens.data(), token_count);
const size_t nwrite = llama_state_seq_save_file(ctx, filepath.c_str(), slot->id, slot->cache_tokens.data(), token_count);
const int64_t t_end = ggml_time_us();
const double t_save_ms = (t_end - t_start) / 1000.0;
@@ -2661,7 +2650,7 @@ struct server_context {
slot->cache_tokens.resize(slot->n_ctx);
size_t token_count = 0;
size_t nread = llama_state_seq_load_file(ctx, filepath.c_str(), slot->id + 1, slot->cache_tokens.data(), slot->cache_tokens.size(), &token_count);
size_t nread = llama_state_seq_load_file(ctx, filepath.c_str(), slot->id, slot->cache_tokens.data(), slot->cache_tokens.size(), &token_count);
if (nread == 0) {
slot->cache_tokens.resize(0);
send_error(task, "Unable to restore slot, no available space in KV cache or invalid slot save file", ERROR_TYPE_INVALID_REQUEST);
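
Sketch of the invariant the save/load fix restores: the sequence id passed to llama_state_seq_save_file / llama_state_seq_load_file must be the same id the slot decodes with (assumed here to be slot->id, with no +1 offset). Hypothetical helper:

#include <vector>
#include "llama.h"

static bool sketch_roundtrip_slot(llama_context * ctx, llama_seq_id slot_id,
                                  const llama_token * tokens, size_t n_tokens,
                                  const char * path) {
    // save the KV-cache state of the slot's own sequence
    const size_t nwrite = llama_state_seq_save_file(ctx, path, slot_id, tokens, n_tokens);
    if (nwrite == 0) {
        return false;
    }
    // restore into the same sequence id and verify the token count matches
    std::vector<llama_token> buf(n_tokens);
    size_t n_restored = 0;
    const size_t nread = llama_state_seq_load_file(ctx, path, slot_id,
                                                   buf.data(), buf.size(), &n_restored);
    return nread != 0 && n_restored == n_tokens;
}
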
@@ -2971,7 +2960,7 @@ struct server_context {
{"n_ctx", slot.n_ctx},
{"n_keep", slot.params.n_keep},
{"n_prompt_tokens", slot.n_prompt_tokens},
{"prompt_tokens", tokens_to_str(ctx, prompt_tokens.cbegin(), prompt_tokens.cend())},
{"prompt_tokens", prompt_tokens.detokenize(ctx, true)},
});
// empty prompt passed -> release the slot and send empty response
@@ -2999,13 +2988,18 @@ struct server_context {
continue;
}
} else {
// if input prompt is too big, truncate it (if group attention self-extend is disabled)
if (slot.params.n_keep < 0) {
slot.params.n_keep = slot.n_prompt_tokens;
}
slot.params.n_keep = std::min(slot.n_ctx - 4, slot.params.n_keep);
// if input prompt is too big, truncate it (if group attention self-extend is disabled)
if (slot.ga_n == 1 && slot.n_prompt_tokens >= slot.n_ctx) {
if (!params.ctx_shift) {
send_error(slot, "the request exceeds the available context size, try increasing it", ERROR_TYPE_SERVER);
slot.release();
continue;
}
const int n_left = slot.n_ctx - slot.params.n_keep;
const int n_block_size = n_left / 2;
@@ -3016,7 +3010,7 @@ struct server_context {
for (size_t i = n_keep + n_discard; i < new_tokens.size(); i++) {
new_tokens[i - n_discard] = new_tokens[i];
}
new_tokens.resize(slot.cache_tokens.size() - n_discard);
new_tokens.resize((int) prompt_tokens.size() - n_discard);
prompt_tokens.clear();
prompt_tokens.insert(new_tokens);
slot.truncated = true;
@@ -3029,8 +3023,8 @@ struct server_context {
{"n_keep", slot.params.n_keep},
{"n_left", n_left},
{"n_prompt_tokens", slot.n_prompt_tokens},
{"prompt_tokens", tokens_to_str(ctx, prompt_tokens.cbegin(), prompt_tokens.cend())},
});
{"prompt_tokens", prompt_tokens.detokenize(ctx, true)},
});
GGML_ASSERT(slot.n_prompt_tokens < slot.n_ctx);
}
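
For reference, a self-contained sketch of the middle-truncation step performed above: keep the first n_keep tokens, drop n_discard tokens right after them, and shift the tail left (n_keep and n_discard are assumed to have been derived from n_ctx as in the surrounding hunk):

#include <cstddef>
#include <vector>

static std::vector<int> sketch_truncate_prompt(std::vector<int> tokens,
                                               size_t n_keep, size_t n_discard) {
    // shift the tail left over the discarded block
    for (size_t i = n_keep + n_discard; i < tokens.size(); i++) {
        tokens[i - n_discard] = tokens[i];
    }
    // the result has n_discard fewer tokens, computed from the prompt length
    // itself (the bug fixed above sized it from the cache instead)
    tokens.resize(tokens.size() - n_discard);
    return tokens;
}
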
@@ -3118,7 +3112,8 @@ struct server_context {
&& slot.prompt_tokens[slot.n_past] == LLAMA_TOKEN_NULL) {
// process the image
int32_t new_n_past;
int32_t res = slot.prompt_tokens.process_chunk(ctx, mctx, slot.n_past, slot.id, new_n_past);
size_t new_n_tokens;
int32_t res = slot.prompt_tokens.process_chunk(ctx, mctx, slot.n_past, slot.id, new_n_past, new_n_tokens);
int32_t n_pos = new_n_past - slot.n_past;
if (res != 0) {
LLAMA_LOG_ERROR("failed to process image, res = %d\n", res);
@@ -3134,7 +3129,7 @@ struct server_context {
}
slot.n_past += n_pos;
slot.n_prompt_tokens_processed += n_pos;
slot.n_prompt_tokens_processed += new_n_tokens;
}
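
The two counters diverge on purpose: an image chunk can occupy a different number of KV-cache positions than it contains tokens (an assumption here about some multimodal setups, e.g. 2D/M-RoPE-style positions), so positions advance n_past while the chunk's token count feeds the progress counter. A sketch of the bookkeeping with hypothetical names:

#include <cstddef>

struct sketch_counts { int n_past = 0; size_t n_prompt_tokens_processed = 0; };

static void sketch_account_image_chunk(sketch_counts & c, int new_n_past, size_t new_n_tokens) {
    c.n_past = new_n_past;                       // advance by positions consumed
    c.n_prompt_tokens_processed += new_n_tokens; // count the tokens actually evaluated
}

Under the old accounting, the position delta stood in for the token count, which under-reports (or over-reports) processed tokens whenever the two differ.
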
@@ -3648,29 +3643,72 @@ static std::vector<json> format_partial_response_oaicompat(server_task_result ta
}
static json format_embeddings_response_oaicompat(const json& request, const json& embeddings) {
json data = json::array();
int i = 0;
for (auto& elem : embeddings) {
data.push_back(json{
{"embedding", json_value(elem, "embedding", json::array())},
{"index", i++},
{"object", "embedding"}
});
}
//static json format_embeddings_response_oaicompat(const json& request, const json& embeddings) {
// json data = json::array();
// int32_t n_tokens = 0;
// int i = 0;
// for (auto& elem : embeddings) {
// data.push_back(json{
// {"embedding", json_value(elem, "embedding", json::array())},
// {"index", i++},
// {"object", "embedding"}
// });
// }
//
// json res = json{
// {"model", json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))},
// {"object", "list"},
// {"usage", json {
// {"prompt_tokens", n_tokens},
// {"total_tokens", n_tokens}
// }},
// {"data", data}
// };
//
// return res;
//}
static json format_embeddings_response_oaicompat(const json& request, const json& embeddings, bool use_base64 = false) {
json data = json::array();
int32_t n_tokens = 0;
int i = 0;
for (const auto& elem : embeddings) {
json embedding_obj;
if (use_base64) {
const auto& vec = json_value(elem, "embedding", json::array()).get<std::vector<float>>();
const char* data_ptr = reinterpret_cast<const char*>(vec.data());
size_t data_size = vec.size() * sizeof(float);
embedding_obj = {
{"embedding", base64::encode(data_ptr, data_size)},
{"index", i++},
{"object", "embedding"},
{"encoding_format", "base64"}
};
}
else {
embedding_obj = {
{"embedding", json_value(elem, "embedding", json::array())},
{"index", i++},
{"object", "embedding"}
};
}
data.push_back(embedding_obj);
n_tokens += json_value(elem, "tokens_evaluated", 0);
}
json res = json{
{"model", json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))},
{"object", "list"},
{"usage", json {
{"prompt_tokens", 0},
{"total_tokens", 0}
{"prompt_tokens", n_tokens},
{"total_tokens", n_tokens}
}},
{"data", data}
};
return res;
}
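
A minimal, self-contained illustration of how the per-result tokens_evaluated counts roll up into the OpenAI-style usage block (nlohmann::json assumed; the values are made up):

#include <cstdint>
#include <nlohmann/json.hpp>
using json = nlohmann::json;

int main() {
    const json embeddings = json::array({
        { {"embedding", {0.1, 0.2}}, {"tokens_evaluated", 5} },
        { {"embedding", {0.3, 0.4}}, {"tokens_evaluated", 7} },
    });
    int32_t n_tokens = 0;
    for (const auto & elem : embeddings) {
        n_tokens += elem.value("tokens_evaluated", 0);
    }
    // usage.prompt_tokens == usage.total_tokens == 12
    const json usage = { {"prompt_tokens", n_tokens}, {"total_tokens", n_tokens} };
    return usage["prompt_tokens"] == 12 ? 0 : 1;
}
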
static void log_server_request(const httplib::Request & req, const httplib::Response & res) {
// skip GH copilot requests when using default port
if (req.path == "/v1/health" || req.path == "/v1/completions") {
@@ -3769,16 +3807,7 @@ int main(int argc, char ** argv) {
std::atomic<server_state> state{SERVER_STATE_LOADING_MODEL};
svr->set_default_headers({{"Server", "llama.cpp"}});
// CORS preflight
svr->Options(R"(.*)", [](const httplib::Request & req, httplib::Response & res) {
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
res.set_header("Access-Control-Allow-Credentials", "true");
res.set_header("Access-Control-Allow-Methods", "POST");
res.set_header("Access-Control-Allow-Headers", "*");
return res.set_content("", "application/json; charset=utf-8");
});
svr->set_default_headers({{"Server", "ik_llama.cpp"}});
svr->set_logger(log_server_request);
@@ -3931,8 +3960,6 @@ int main(int argc, char ** argv) {
}
// API key is invalid or not provided
// TODO: make another middleware for CORS related logic
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
res_error(res, format_error_response("Invalid API Key", ERROR_TYPE_AUTHENTICATION));
LOG_WARNING("Unauthorized: Invalid API Key", {});
@@ -3940,13 +3967,45 @@ int main(int argc, char ** argv) {
return false;
};
auto middleware_server_state = [&res_error, &state](const httplib::Request& req, httplib::Response& res) {
server_state current_state = state.load();
if (current_state == SERVER_STATE_LOADING_MODEL) {
auto tmp = string_split<std::string>(req.path, '.');
if (req.path == "/" || tmp.back() == "html") {
res.set_content(reinterpret_cast<const char*>(loading_html), loading_html_len, "text/html; charset=utf-8");
res.status = 503;
}
else if (req.path == "/models" || req.path == "/v1/models" || req.path == "/api/tags") {
// allow the models endpoint to be accessed during loading
return true;
}
else {
res_error(res, format_error_response("Loading model", ERROR_TYPE_UNAVAILABLE));
}
return false;
}
return true;
};
// register server middlewares
svr->set_pre_routing_handler([&middleware_validate_api_key](const httplib::Request & req, httplib::Response & res) {
svr->set_pre_routing_handler([&middleware_validate_api_key, &middleware_server_state](const httplib::Request& req, httplib::Response& res) {
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
// If this is OPTIONS request, skip validation because browsers don't include Authorization header
if (req.method == "OPTIONS") {
res.set_header("Access-Control-Allow-Credentials", "true");
res.set_header("Access-Control-Allow-Methods", "GET, POST");
res.set_header("Access-Control-Allow-Headers", "*");
res.set_content("", "text/html"); // blank response, no data
return httplib::Server::HandlerResponse::Handled; // skip further processing
}
if (!middleware_server_state(req, res)) {
return httplib::Server::HandlerResponse::Handled;
}
if (!middleware_validate_api_key(req, res)) {
return httplib::Server::HandlerResponse::Handled;
}
return httplib::Server::HandlerResponse::Unhandled;
});
});
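
Usage check for the consolidated CORS handling (a sketch; assumes cpp-httplib's Client::Options(path, headers) overload and the server's default port): an OPTIONS preflight to any route should now come back with the Access-Control-* headers set by the pre-routing handler above.

#include <cassert>
#include "httplib.h"

int main() {
    httplib::Client cli("http://localhost:8080");
    const httplib::Headers headers = {
        {"Origin", "http://localhost:5173"},
        {"Access-Control-Request-Method", "POST"},
    };
    auto res = cli.Options("/v1/chat/completions", headers);
    assert(res && res->status == 200);
    // the handler mirrors the request Origin on every response
    assert(res->get_header_value("Access-Control-Allow-Origin") == "http://localhost:5173");
    return 0;
}
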
//
// Route handlers (or controllers)
@@ -4211,8 +4270,6 @@ int main(int argc, char ** argv) {
};
const auto handle_slots_action = [&res_error, &handle_slots_save, &handle_slots_restore, &handle_slots_erase](const httplib::Request & req, httplib::Response & res) {
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
std::string id_slot_str = req.path_params.at("id_slot");
int id_slot;
@@ -4245,7 +4302,6 @@ int main(int argc, char ** argv) {
curr_tmpl = std::string(curr_tmpl_buf.data(), tlen);
}
}
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
json data = {
{ "system_prompt", ctx_server.system_prompt.c_str() },
{ "default_generation_settings", ctx_server.default_generation_settings_for_props },
@@ -4434,8 +4490,6 @@ int main(int argc, char ** argv) {
};
const auto handle_models = [&params, &model_meta](const httplib::Request & req, httplib::Response & res) {
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
json models = {
{"object", "list"},
{"data", {
@@ -4476,7 +4530,6 @@ int main(int argc, char ** argv) {
const auto handle_infill = [&ctx_server, &res_error, &handle_completions_impl](const httplib::Request & req, httplib::Response & res) {
json data = json::parse(req.body);
const int id_task = ctx_server.queue_tasks.get_new_id();
server_tokens token; // dummy tokens
ctx_server.queue_results.add_waiting_task_id(id_task);
@@ -4490,8 +4543,7 @@ int main(int argc, char ** argv) {
OAICOMPAT_TYPE_NONE); // infill is not OAI compatible
};
const auto handle_tokenize = [&ctx_server](const httplib::Request & req, httplib::Response & res) {
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
const auto handle_tokenize = [&ctx_server](const httplib::Request & req, httplib::Response & res) {
const json body = json::parse(req.body);
std::vector<llama_token> tokens;
@@ -4503,8 +4555,7 @@ int main(int argc, char ** argv) {
return res.set_content(data.dump(), "application/json; charset=utf-8");
};
const auto handle_detokenize = [&ctx_server](const httplib::Request & req, httplib::Response & res) {
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
const auto handle_detokenize = [&ctx_server](const httplib::Request & req, httplib::Response & res) {
const json body = json::parse(req.body);
std::string content;
@@ -4517,9 +4568,8 @@ int main(int argc, char ** argv) {
return res.set_content(data.dump(), "application/json; charset=utf-8");
};
const auto handle_embeddings = [&ctx_server, &res_error](const httplib::Request & req, httplib::Response & res) {
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
const auto handle_embeddings = [&ctx_server, &res_error](const httplib::Request & req, httplib::Response & res) {
const json body = json::parse(req.body);
bool is_openai = false;
@@ -4541,8 +4591,9 @@ int main(int argc, char ** argv) {
{
const int id_task = ctx_server.queue_tasks.get_new_id();
ctx_server.queue_results.add_waiting_task_id(id_task);
server_tokens token; // dummy token
ctx_server.request_completion(id_task, -1, {{"prompt", prompt}}, false, true, std::move(token));
std::vector<server_tokens> inputs;
inputs = tokenize_input_prompts(llama_get_vocab(ctx_server.ctx), ctx_server.mctx, prompt, true, true);
ctx_server.request_completion(id_task, -1, {{"prompt", prompt}}, false, true, std::move(inputs[0]));
// get the result
server_task_result result = ctx_server.queue_results.recv(id_task);
@@ -4553,7 +4604,7 @@ int main(int argc, char ** argv) {
responses = result.data.at("results");
} else {
// result for single task
responses = std::vector<json>{result.data};
responses = std::vector<json>{ result.data };
}
} else {
// error received, ignore everything else
@@ -4564,13 +4615,12 @@ int main(int argc, char ** argv) {
// write JSON response
json root = is_openai
? format_embeddings_response_oaicompat(body, responses)
? format_embeddings_response_oaicompat(body, responses, false)
: responses[0];
return res.set_content(root.dump(), "application/json; charset=utf-8");
};
const auto handle_lora_adapters_list = [&](const httplib::Request & req, httplib::Response & res) {
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
const auto handle_lora_adapters_list = [&](const httplib::Request & req, httplib::Response & res) {
json result = json::array();
for (size_t i = 0; i < ctx_server.lora_adapters.size(); ++i) {
auto & la = ctx_server.lora_adapters[i];
@@ -4584,9 +4634,7 @@ int main(int argc, char ** argv) {
res.status = 200; // HTTP OK
};
const auto handle_lora_adapters_apply = [&](const httplib::Request & req, httplib::Response & res) {
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
const auto handle_lora_adapters_apply = [&](const httplib::Request & req, httplib::Response & res) {
const std::vector<json> body = json::parse(req.body);
int max_idx = ctx_server.lora_adapters.size();
@@ -4618,8 +4666,7 @@ int main(int argc, char ** argv) {
res.status = 200; // HTTP OK
};
const auto list_saved_prompts = [&ctx_server, &params](const httplib::Request& req, httplib::Response& res) {
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
const auto list_saved_prompts = [&ctx_server, &params](const httplib::Request& req, httplib::Response& res) {
json response = json::array();
namespace fs = std::filesystem;
@@ -4679,22 +4726,20 @@ int main(int argc, char ** argv) {
res.set_content(response.dump(), "application/json; charset=utf-8");
};
const auto list_slot_prompts = [&ctx_server, &params](const httplib::Request& req, httplib::Response& res) {
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
const auto list_slot_prompts = [&ctx_server, &params](const httplib::Request& req, httplib::Response& res) {
json response = json::array();
for (server_slot & slot : ctx_server.slots) {
response.push_back({
{"slot_id", slot.id},
{"token_count", slot.cache_tokens.size()},
{"prompt", tokens_to_str(ctx_server.ctx, slot.cache_tokens.cbegin(), slot.cache_tokens.cend())}
{"prompt", slot.cache_tokens.detokenize(ctx_server.ctx, true) }
});
}
res.set_content(response.dump(), "application/json; charset=utf-8");
};
const auto delete_saved_prompt = [&ctx_server, &params](const httplib::Request& req, httplib::Response& res)-> void {
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
const auto delete_saved_prompt = [&ctx_server, &params](const httplib::Request& req, httplib::Response& res)-> void {
json response;
namespace fs = std::filesystem;
@@ -4741,8 +4786,7 @@ int main(int argc, char ** argv) {
res.set_content(response.dump(), "application/json; charset=utf-8");
};
const auto rename_saved_prompt = [&ctx_server, &params](const httplib::Request& req, httplib::Response& res)-> void {
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
const auto rename_saved_prompt = [&ctx_server, &params](const httplib::Request& req, httplib::Response& res)-> void {
json response;
namespace fs = std::filesystem;

View File

@@ -1304,11 +1304,12 @@ public:
// encode and decode the image chunk
int32_t process_chunk(
llama_context* ctx,
mtmd_context* mctx,
llama_context * ctx,
mtmd_context * mctx,
llama_pos n_past,
int32_t seq_id,
llama_pos& n_pos_out) {
llama_pos & n_pos_out,
size_t & n_tokens_out) {
char buffer[512];
auto& chunk = find_chunk(n_past);
const char* name = mtmd_input_chunk_get_type(chunk.get()) == MTMD_INPUT_CHUNK_TYPE_IMAGE
@@ -1325,21 +1326,25 @@ public:
n_batch,
true, // logits last
&new_n_past);
// get number of tokens in the image
const size_t new_n_tokens = mtmd_input_chunk_get_n_tokens(chunk.get());
snprintf(buffer, 512, "processed in %g ms", 1.*(ggml_time_ms() - t0));
LOG_INFO(buffer, {});
if (result != 0) {
snprintf(buffer, 512, "mtmd_helper_eval failed with status %d", result);
LOG_ERROR(buffer, {});
n_pos_out = n_past;
n_tokens_out = 0;
return result;
}
n_pos_out = new_n_past;
n_tokens_out = new_n_tokens;
return 0;
}
};
// Computes FNV-1a hash of the data
static std::string fnv_hash(const uint8_t* data, size_t len) {
static std::string fnv_hash(const uint8_t * data, size_t len) {
const uint64_t fnv_prime = 0x100000001b3ULL;
uint64_t hash = 0xcbf29ce484222325ULL;
@@ -1350,7 +1355,7 @@ static std::string fnv_hash(const uint8_t* data, size_t len) {
return std::to_string(hash);
}
static server_tokens process_mtmd_prompt(mtmd_context* mctx, std::string prompt, std::vector<raw_buffer> files) {
static server_tokens process_mtmd_prompt(mtmd_context * mctx, std::string prompt, std::vector<raw_buffer> files) {
mtmd::bitmaps bitmaps;
for (auto& file : files) {
mtmd::bitmap bmp(mtmd_helper_bitmap_init_from_buf(mctx, file.data(), file.size()));