From b63309a918c297e3bd3b7cd9cf6cef695502f9ca Mon Sep 17 00:00:00 2001
From: firecoperana
Date: Sun, 9 Nov 2025 12:16:03 +0000
Subject: [PATCH] Fix missing embedding results, CORS handling, and a crash
 when using verbose logging in the server (#924)

* server: fix crash when the prompt contains an image and is too long
* server: fix CORS
* server: fix empty results for embeddings
* server: change the error message emitted when the prompt must be truncated
* server: fix the slot id used for saving and loading state
* bug fix
* server: update slot similarity to handle mtmd
* server: quick hack to calculate the number of tokens processed with images
* server: fix out-of-range error when detokenizing the prompt under verbose logging
* Add back Access-Control-Allow-Origin
* Server: add prompt token counts to embedding results

---------

Co-authored-by: firecoperana
---
 common/common.cpp          |   1 -
 examples/server/server.cpp | 214 ++++++++++++++++++++++---------
 examples/server/utils.hpp  |  15 ++-
 3 files changed, 139 insertions(+), 91 deletions(-)

diff --git a/common/common.cpp b/common/common.cpp
index 252b53e8..7e81c99b 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -1741,7 +1741,6 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_params & params
         return true;
     }
     if (arg == "--no-context-shift") {
-        CHECK_ARG
         params.ctx_shift = false;
         return true;
     }
diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 4428af24..a4dbb47d 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -1467,11 +1467,11 @@ struct server_context {
         return nullptr;
     }
 
-    server_slot * get_available_slot(const std::string & prompt) {
+    server_slot * get_available_slot(const server_task & task) {
         server_slot * ret = nullptr;
 
         // find the slot that has at least n% prompt similarity
-        if (ret == nullptr && slot_prompt_similarity != 0.0f && !prompt.empty()) {
+        if (ret == nullptr && slot_prompt_similarity != 0.0f) {
             int max_lcp_len = 0;
             float similarity = 0;
 
@@ -1480,24 +1480,16 @@ struct server_context {
                 if (!slot.available()) {
                     continue;
                 }
-
+                const auto & cache_tokens = slot.cache_tokens;
                 // skip the slot if it does not contain a prompt
-                if (!slot.prompt.is_string()) {
+                if (cache_tokens.empty()) {
                     continue;
                 }
 
-                // current slot's prompt
-                std::string slot_prompt = slot.prompt.get<std::string>();
-
-                // length of the current slot's prompt
-                int slot_prompt_len = slot_prompt.size();
-
-                // length of the Longest Common Prefix between the current slot's prompt and the input prompt
-                int lcp_len = common_part(slot_prompt, prompt);
-
+                // length of the Longest Common Prefix between the slot's cached tokens and the input prompt
+                int lcp_len = cache_tokens.get_common_prefix(task.tokens);
                 // fraction of the common prefix length relative to the input prompt length
-                similarity = static_cast<float>(lcp_len) / slot_prompt_len;
-
+                const float similarity = float(lcp_len) / task.tokens.size();
                 // select the current slot if the criteria match
                 if (lcp_len > max_lcp_len && similarity > slot_prompt_similarity) {
                     max_lcp_len = lcp_len;
@@ -2363,6 +2355,7 @@ struct server_context {
 
             res.data = json {
                 {"embedding",        std::vector<float>(n_embd, 0.0f)},
+                {"tokens_evaluated", slot.n_prompt_tokens},
             };
 
             continue;
@@ -2372,6 +2365,7 @@ struct server_context {
 
         res.data = json {
             {"embedding",        embd_res},
+            {"tokens_evaluated", slot.n_prompt_tokens},
         };
     }
@@ -2461,12 +2455,7 @@ struct server_context {
                 if (id_slot != -1) {
                     slot = get_slot_by_id(id_slot);
                 } else {
-                    std::string prompt;
-                    if (task.data.contains("prompt") && task.data.at("prompt").is_string()) {
-                        prompt = json_value(task.data, "prompt", std::string());
-                    }
-
-                    slot = get_available_slot(prompt);
+                    slot = get_available_slot(task);
                 }
 
                 if (slot == nullptr) {
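Note: `get_available_slot` now matches on the task's tokens rather than a prompt string, so multimodal prompts (which have no single string form) can also reuse a cached slot, and similarity is measured against the incoming prompt instead of the slot's old prompt. A minimal standalone sketch of the heuristic, with plain `std::vector<int32_t>` standing in for the server's `server_tokens` type (names here are illustrative, not the actual server API):

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

using llama_token = int32_t;

// longest common prefix between two token sequences
static size_t common_prefix_len(const std::vector<llama_token> & a,
                                const std::vector<llama_token> & b) {
    size_t i = 0;
    while (i < a.size() && i < b.size() && a[i] == b[i]) i++;
    return i;
}

// pick the slot whose cache shares the longest prefix with the new prompt,
// but only if that prefix covers more than min_similarity of the prompt
static int pick_slot(const std::vector<std::vector<llama_token>> & caches,
                     const std::vector<llama_token> & prompt,
                     float min_similarity) {
    int    best     = -1;
    size_t best_lcp = 0;
    for (int id = 0; id < (int) caches.size(); id++) {
        const size_t lcp = common_prefix_len(caches[id], prompt);
        const float  sim = prompt.empty() ? 0.0f : float(lcp) / prompt.size();
        if (lcp > best_lcp && sim > min_similarity) {
            best_lcp = lcp;
            best     = id;
        }
    }
    return best; // -1 -> caller falls back to another selection strategy
}

int main() {
    const std::vector<std::vector<llama_token>> caches = {{1, 2, 3, 4}, {1, 2, 9}};
    const std::vector<llama_token> prompt = {1, 2, 3, 5};
    std::printf("best slot: %d\n", pick_slot(caches, prompt, 0.5f)); // best slot: 0
}
```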
@@ -2618,7 +2607,7 @@ struct server_context {
             std::string filename = task.data.at("filename");
             std::string filepath = task.data.at("filepath");
 
-            const size_t nwrite = llama_state_seq_save_file(ctx, filepath.c_str(), slot->id + 1, slot->cache_tokens.data(), token_count);
+            const size_t nwrite = llama_state_seq_save_file(ctx, filepath.c_str(), slot->id, slot->cache_tokens.data(), token_count);
 
             const int64_t t_end = ggml_time_us();
             const double t_save_ms = (t_end - t_start) / 1000.0;
@@ -2661,7 +2650,7 @@ struct server_context {
             slot->cache_tokens.resize(slot->n_ctx);
             size_t token_count = 0;
-            size_t nread = llama_state_seq_load_file(ctx, filepath.c_str(), slot->id + 1, slot->cache_tokens.data(), slot->cache_tokens.size(), &token_count);
+            size_t nread = llama_state_seq_load_file(ctx, filepath.c_str(), slot->id, slot->cache_tokens.data(), slot->cache_tokens.size(), &token_count);
             if (nread == 0) {
                 slot->cache_tokens.resize(0);
                 send_error(task, "Unable to restore slot, no available space in KV cache or invalid slot save file", ERROR_TYPE_INVALID_REQUEST);
@@ -2971,7 +2960,7 @@ struct server_context {
                     {"n_ctx",           slot.n_ctx},
                     {"n_keep",          slot.params.n_keep},
                     {"n_prompt_tokens", slot.n_prompt_tokens},
-                    {"prompt_tokens",   tokens_to_str(ctx, prompt_tokens.cbegin(), prompt_tokens.cend())},
+                    {"prompt_tokens",   prompt_tokens.detokenize(ctx, true)},
                 });
 
                 // empty prompt passed -> release the slot and send empty response
@@ -2999,13 +2988,18 @@ struct server_context {
                         continue;
                     }
                 } else {
+                    // if the input prompt is too big, truncate it (if group attention self-extend is disabled)
                     if (slot.params.n_keep < 0) {
                         slot.params.n_keep = slot.n_prompt_tokens;
                     }
                     slot.params.n_keep = std::min(slot.n_ctx - 4, slot.params.n_keep);
 
-                    // if input prompt is too big, truncate it (if group attention self-extend is disabled)
                     if (slot.ga_n == 1 && slot.n_prompt_tokens >= slot.n_ctx) {
+                        if (!params.ctx_shift) {
+                            send_error(slot, "the request exceeds the available context size, try increasing it", ERROR_TYPE_SERVER);
+                            slot.release();
+                            continue;
+                        }
                         const int n_left = slot.n_ctx - slot.params.n_keep;
 
                         const int n_block_size = n_left / 2;
@@ -3016,7 +3010,7 @@ struct server_context {
                         for (size_t i = n_keep + n_discard; i < new_tokens.size(); i++) {
                             new_tokens[i - n_discard] = new_tokens[i];
                         }
-                        new_tokens.resize(slot.cache_tokens.size() - n_discard);
+                        new_tokens.resize((int) prompt_tokens.size() - n_discard);
 
                         prompt_tokens.clear();
                         prompt_tokens.insert(new_tokens);
 
                         slot.truncated = true;
@@ -3029,8 +3023,8 @@ struct server_context {
                             {"n_keep",          slot.params.n_keep},
                             {"n_left",          n_left},
                             {"n_prompt_tokens", slot.n_prompt_tokens},
-                            {"prompt_tokens",   tokens_to_str(ctx, prompt_tokens.cbegin(), prompt_tokens.cend())},
-                        });
+                            {"prompt_tokens",   prompt_tokens.detokenize(ctx, true)},
+                        });
 
                         GGML_ASSERT(slot.n_prompt_tokens < slot.n_ctx);
                     }
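Note: with `ctx_shift` disabled, the server now reports an error instead of silently truncating. When truncation does run, it keeps the first `n_keep` tokens, drops whole blocks of half the remaining context from the middle, and keeps the tail. A standalone sketch of that arithmetic (simplified to plain `int` tokens; it mirrors the surrounding code under those assumptions rather than reproducing it):

```cpp
#include <cstdio>
#include <vector>

// keep the first n_keep tokens, drop whole half-context blocks after them,
// and keep the tail, so the result fits within n_ctx
static std::vector<int> truncate_prompt(std::vector<int> tokens, int n_ctx, int n_keep) {
    if ((int) tokens.size() < n_ctx) return tokens;

    const int n_left        = n_ctx - n_keep;
    const int n_block_size  = n_left / 2;
    const int erased_blocks = ((int) tokens.size() - n_keep - n_block_size) / n_block_size;
    const int n_discard     = erased_blocks * n_block_size;

    tokens.erase(tokens.begin() + n_keep, tokens.begin() + n_keep + n_discard);
    return tokens;
}

int main() {
    std::vector<int> prompt(100);
    for (int i = 0; i < 100; i++) prompt[i] = i;

    const auto out = truncate_prompt(prompt, /*n_ctx=*/64, /*n_keep=*/8);
    // keeps tokens 0-7 and 64-99 -> "44 tokens kept (first=0, last=99)"
    std::printf("%zu tokens kept (first=%d, last=%d)\n", out.size(), out.front(), out.back());
}
```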
@@ -3118,7 +3112,8 @@ struct server_context {
                         && slot.prompt_tokens[slot.n_past] == LLAMA_TOKEN_NULL) {
                         // process the image
                         int32_t new_n_past;
-                        int32_t res = slot.prompt_tokens.process_chunk(ctx, mctx, slot.n_past, slot.id, new_n_past);
+                        size_t new_n_tokens;
+                        int32_t res = slot.prompt_tokens.process_chunk(ctx, mctx, slot.n_past, slot.id, new_n_past, new_n_tokens);
                         int32_t n_pos = new_n_past - slot.n_past;
 
                         if (res != 0) {
                             LLAMA_LOG_ERROR("failed to process image, res = %d\n", res);
@@ -3134,7 +3129,7 @@ struct server_context {
                         }
 
                         slot.n_past += n_pos;
-                        slot.n_prompt_tokens_processed += n_pos;
+                        slot.n_prompt_tokens_processed += new_n_tokens;
                     }
@@ -3648,29 +3643,72 @@ static std::vector<json> format_partial_response_oaicompat(server_task_result task
 }
 
-static json format_embeddings_response_oaicompat(const json& request, const json& embeddings) {
-    json data = json::array();
-    int i = 0;
-    for (auto& elem : embeddings) {
-        data.push_back(json{
-            {"embedding", json_value(elem, "embedding", json::array())},
-            {"index", i++},
-            {"object", "embedding"}
-        });
-    }
-
+static json format_embeddings_response_oaicompat(const json& request, const json& embeddings, bool use_base64 = false) {
+    json data = json::array();
+    int32_t n_tokens = 0;
+    int i = 0;
+    for (const auto& elem : embeddings) {
+        json embedding_obj;
+
+        if (use_base64) {
+            const auto& vec = json_value(elem, "embedding", json::array()).get<std::vector<float>>();
+            const char* data_ptr = reinterpret_cast<const char*>(vec.data());
+            size_t data_size = vec.size() * sizeof(float);
+            embedding_obj = {
+                {"embedding", base64::encode(data_ptr, data_size)},
+                {"index", i++},
+                {"object", "embedding"},
+                {"encoding_format", "base64"}
+            };
+        }
+        else {
+            embedding_obj = {
+                {"embedding", json_value(elem, "embedding", json::array())},
+                {"index", i++},
+                {"object", "embedding"}
+            };
+        }
+        data.push_back(embedding_obj);
+        n_tokens += json_value(elem, "tokens_evaluated", 0);
+    }
     json res = json{
         {"model", json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))},
         {"object", "list"},
         {"usage", json {
-            {"prompt_tokens", 0},
-            {"total_tokens", 0}
+            {"prompt_tokens", n_tokens},
+            {"total_tokens", n_tokens}
         }},
        {"data", data}
     };
 
     return res;
 }
 
+
 static void log_server_request(const httplib::Request & req, const httplib::Response & res) {
     // skip GH copilot requests when using default port
     if (req.path == "/v1/health" || req.path == "/v1/completions") {
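Note: the `use_base64` path ships each embedding as the base64 encoding of its raw `float` buffer, matching OpenAI's `encoding_format: "base64"` convention (clients decode it as little-endian IEEE-754 floats). The server uses its bundled `base64` helper; the encoder below is hand-rolled only to make the wire format concrete:

```cpp
#include <cstdint>
#include <cstdio>
#include <string>
#include <vector>

static const char * B64 = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";

// encode an arbitrary byte buffer as base64 (with '=' padding)
static std::string b64_encode(const uint8_t * data, size_t len) {
    std::string out;
    for (size_t i = 0; i < len; i += 3) {
        uint32_t n = uint32_t(data[i]) << 16;
        if (i + 1 < len) n |= uint32_t(data[i + 1]) << 8;
        if (i + 2 < len) n |= uint32_t(data[i + 2]);
        out += B64[(n >> 18) & 63];
        out += B64[(n >> 12) & 63];
        out += i + 1 < len ? B64[(n >> 6) & 63] : '=';
        out += i + 2 < len ? B64[n & 63] : '=';
    }
    return out;
}

int main() {
    // an embedding travels as the base64 of its raw float bytes
    const std::vector<float> embd = {0.25f, -1.0f, 3.5f};
    const std::string enc = b64_encode(reinterpret_cast<const uint8_t *>(embd.data()),
                                       embd.size() * sizeof(float));
    std::printf("%s\n", enc.c_str());
}
```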
@@ -3769,16 +3807,7 @@ int main(int argc, char ** argv) {
 
     std::atomic<server_state> state{SERVER_STATE_LOADING_MODEL};
 
-    svr->set_default_headers({{"Server", "llama.cpp"}});
-
-    // CORS preflight
-    svr->Options(R"(.*)", [](const httplib::Request & req, httplib::Response & res) {
-        res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
-        res.set_header("Access-Control-Allow-Credentials", "true");
-        res.set_header("Access-Control-Allow-Methods", "POST");
-        res.set_header("Access-Control-Allow-Headers", "*");
-        return res.set_content("", "application/json; charset=utf-8");
-    });
+    svr->set_default_headers({{"Server", "ik_llama.cpp"}});
 
     svr->set_logger(log_server_request);
@@ -3931,8 +3960,6 @@ int main(int argc, char ** argv) {
         }
 
         // API key is invalid or not provided
-        // TODO: make another middleware for CORS related logic
-        res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
         res_error(res, format_error_response("Invalid API Key", ERROR_TYPE_AUTHENTICATION));
 
         LOG_WARNING("Unauthorized: Invalid API Key", {});
@@ -3940,13 +3967,45 @@ int main(int argc, char ** argv) {
         return false;
     };
 
+    auto middleware_server_state = [&res_error, &state](const httplib::Request & req, httplib::Response & res) {
+        server_state current_state = state.load();
+        if (current_state == SERVER_STATE_LOADING_MODEL) {
+            auto tmp = string_split(req.path, '.');
+            if (req.path == "/" || tmp.back() == "html") {
+                res.set_content(reinterpret_cast<const char *>(loading_html), loading_html_len, "text/html; charset=utf-8");
+                res.status = 503;
+            }
+            else if (req.path == "/models" || req.path == "/v1/models" || req.path == "/api/tags") {
+                // allow the models endpoint to be accessed during loading
+                return true;
+            }
+            else {
+                res_error(res, format_error_response("Loading model", ERROR_TYPE_UNAVAILABLE));
+            }
+            return false;
+        }
+        return true;
+    };
+
     // register server middlewares
-    svr->set_pre_routing_handler([&middleware_validate_api_key](const httplib::Request & req, httplib::Response & res) {
+    svr->set_pre_routing_handler([&middleware_validate_api_key, &middleware_server_state](const httplib::Request & req, httplib::Response & res) {
+        res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
+        // if this is an OPTIONS request, skip validation because browsers don't include the Authorization header
+        if (req.method == "OPTIONS") {
+            res.set_header("Access-Control-Allow-Credentials", "true");
+            res.set_header("Access-Control-Allow-Methods", "GET, POST");
+            res.set_header("Access-Control-Allow-Headers", "*");
+            res.set_content("", "text/html"); // blank response, no data
+            return httplib::Server::HandlerResponse::Handled; // skip further processing
+        }
+        if (!middleware_server_state(req, res)) {
+            return httplib::Server::HandlerResponse::Handled;
+        }
         if (!middleware_validate_api_key(req, res)) {
             return httplib::Server::HandlerResponse::Handled;
         }
         return httplib::Server::HandlerResponse::Unhandled;
-    });
+        });
 
     //
     // Route handlers (or controllers)
     //
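Note: answering OPTIONS inside the pre-routing hook, instead of the dedicated `svr->Options(...)` route removed above, guarantees that every endpoint gets the same `Access-Control-Allow-Origin` header and that preflights never reach API-key validation (browsers omit `Authorization` on preflight requests). A minimal standalone sketch of the same pattern with cpp-httplib (the port and route are illustrative):

```cpp
#include <httplib.h>

int main() {
    httplib::Server svr;

    // answer CORS preflights before any routing or auth runs
    svr.set_pre_routing_handler([](const httplib::Request & req, httplib::Response & res) {
        res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
        if (req.method == "OPTIONS") {
            res.set_header("Access-Control-Allow-Methods", "GET, POST");
            res.set_header("Access-Control-Allow-Headers", "*");
            res.set_content("", "text/html"); // blank response
            return httplib::Server::HandlerResponse::Handled;
        }
        return httplib::Server::HandlerResponse::Unhandled;
    });

    svr.Get("/health", [](const httplib::Request &, httplib::Response & res) {
        res.set_content("ok", "text/plain");
    });

    svr.listen("127.0.0.1", 8080);
}
```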
@@ -4211,8 +4270,6 @@ int main(int argc, char ** argv) {
     };
 
     const auto handle_slots_action = [&res_error, &handle_slots_save, &handle_slots_restore, &handle_slots_erase](const httplib::Request & req, httplib::Response & res) {
-        res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
-
         std::string id_slot_str = req.path_params.at("id_slot");
         int id_slot;
 
@@ -4245,7 +4302,6 @@ int main(int argc, char ** argv) {
                 curr_tmpl = std::string(curr_tmpl_buf.data(), tlen);
             }
         }
-        res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
         json data = {
             { "system_prompt",               ctx_server.system_prompt.c_str() },
             { "default_generation_settings", ctx_server.default_generation_settings_for_props },
@@ -4434,8 +4490,6 @@ int main(int argc, char ** argv) {
     };
 
     const auto handle_models = [&params, &model_meta](const httplib::Request & req, httplib::Response & res) {
-        res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
-
         json models = {
             {"object", "list"},
             {"data", {
@@ -4476,7 +4530,6 @@ int main(int argc, char ** argv) {
 
     const auto handle_infill = [&ctx_server, &res_error, &handle_completions_impl](const httplib::Request & req, httplib::Response & res) {
         json data = json::parse(req.body);
-
         const int id_task = ctx_server.queue_tasks.get_new_id();
         server_tokens token; // dummy tokens
         ctx_server.queue_results.add_waiting_task_id(id_task);
@@ -4490,8 +4543,7 @@ int main(int argc, char ** argv) {
             OAICOMPAT_TYPE_NONE); // infill is not OAI compatible
     };
 
-    const auto handle_tokenize = [&ctx_server](const httplib::Request & req, httplib::Response & res) {
-        res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
+    const auto handle_tokenize = [&ctx_server](const httplib::Request & req, httplib::Response & res) {
         const json body = json::parse(req.body);
 
         std::vector<llama_token> tokens;
@@ -4503,8 +4555,7 @@ int main(int argc, char ** argv) {
         return res.set_content(data.dump(), "application/json; charset=utf-8");
     };
 
-    const auto handle_detokenize = [&ctx_server](const httplib::Request & req, httplib::Response & res) {
-        res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
+    const auto handle_detokenize = [&ctx_server](const httplib::Request & req, httplib::Response & res) {
         const json body = json::parse(req.body);
 
         std::string content;
@@ -4517,9 +4568,8 @@ int main(int argc, char ** argv) {
         return res.set_content(data.dump(), "application/json; charset=utf-8");
     };
 
-    const auto handle_embeddings = [&ctx_server, &res_error](const httplib::Request & req, httplib::Response & res) {
-        res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
+    const auto handle_embeddings = [&ctx_server, &res_error](const httplib::Request & req, httplib::Response & res) {
         const json body = json::parse(req.body);
 
         bool is_openai = false;
@@ -4541,8 +4591,9 @@ int main(int argc, char ** argv) {
         {
             const int id_task = ctx_server.queue_tasks.get_new_id();
             ctx_server.queue_results.add_waiting_task_id(id_task);
-            server_tokens token; // dummy token
-            ctx_server.request_completion(id_task, -1, {{"prompt", prompt}}, false, true, std::move(token));
+            std::vector<server_tokens> inputs;
+            inputs = tokenize_input_prompts(llama_get_vocab(ctx_server.ctx), ctx_server.mctx, prompt, true, true);
+            ctx_server.request_completion(id_task, -1, {{"prompt", prompt}}, false, true, std::move(inputs[0]));
 
             // get the result
             server_task_result result = ctx_server.queue_results.recv(id_task);
@@ -4553,7 +4604,7 @@ int main(int argc, char ** argv) {
                 responses = result.data.at("results");
             } else {
                 // result for single task
-                responses = std::vector<json>{result.data};
+                responses = std::vector<json>{ result.data };
             }
         } else {
             // error received, ignore everything else
@@ -4564,13 +4615,12 @@ int main(int argc, char ** argv) {
 
         // write JSON response
         json root = is_openai
-            ? format_embeddings_response_oaicompat(body, responses)
+            ? format_embeddings_response_oaicompat(body, responses, false)
             : responses[0];
         return res.set_content(root.dump(), "application/json; charset=utf-8");
     };
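Note: with `tokens_evaluated` attached to each result, the OpenAI-compatible embeddings response now reports real `usage` counts instead of zeros. A hedged client-side sketch using cpp-httplib (the endpoint path and port are assumptions about how the server is registered and launched, not taken from this patch):

```cpp
#include <httplib.h>
#include <cstdio>

int main() {
    httplib::Client cli("127.0.0.1", 8080);

    // "usage.prompt_tokens" in the response is the sum of tokens_evaluated
    // over all returned embeddings
    const char * body = R"({"input": "hello world", "model": "default"})";
    auto res = cli.Post("/v1/embeddings", body, "application/json");

    if (res && res->status == 200) {
        std::printf("%s\n", res->body.c_str());
    } else {
        std::fprintf(stderr, "embeddings request failed\n");
    }
}
```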
-    const auto handle_lora_adapters_list = [&](const httplib::Request & req, httplib::Response & res) {
-        res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
+    const auto handle_lora_adapters_list = [&](const httplib::Request & req, httplib::Response & res) {
         json result = json::array();
         for (size_t i = 0; i < ctx_server.lora_adapters.size(); ++i) {
             auto & la = ctx_server.lora_adapters[i];
@@ -4584,9 +4634,7 @@ int main(int argc, char ** argv) {
         res.status = 200; // HTTP OK
     };
 
-    const auto handle_lora_adapters_apply = [&](const httplib::Request & req, httplib::Response & res) {
-        res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
-
+    const auto handle_lora_adapters_apply = [&](const httplib::Request & req, httplib::Response & res) {
         const std::vector<json> body = json::parse(req.body);
         int max_idx = ctx_server.lora_adapters.size();
@@ -4618,8 +4666,7 @@ int main(int argc, char ** argv) {
         res.status = 200; // HTTP OK
     };
 
-    const auto list_saved_prompts = [&ctx_server, &params](const httplib::Request& req, httplib::Response& res) {
-        res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
+    const auto list_saved_prompts = [&ctx_server, &params](const httplib::Request& req, httplib::Response& res) {
         json response = json::array();
 
         namespace fs = std::filesystem;
@@ -4679,22 +4726,20 @@ int main(int argc, char ** argv) {
         res.set_content(response.dump(), "application/json; charset=utf-8");
     };
 
-    const auto list_slot_prompts = [&ctx_server, &params](const httplib::Request& req, httplib::Response& res) {
-        res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
+    const auto list_slot_prompts = [&ctx_server, &params](const httplib::Request& req, httplib::Response& res) {
         json response = json::array();
 
         for (server_slot & slot : ctx_server.slots) {
             response.push_back({
                 {"slot_id",     slot.id},
                 {"token_count", slot.cache_tokens.size()},
-                {"prompt",      tokens_to_str(ctx_server.ctx, slot.cache_tokens.cbegin(), slot.cache_tokens.cend())}
+                {"prompt",      slot.cache_tokens.detokenize(ctx_server.ctx, true)}
             });
         }
 
         res.set_content(response.dump(), "application/json; charset=utf-8");
     };
 
-    const auto delete_saved_prompt = [&ctx_server, &params](const httplib::Request& req, httplib::Response& res) -> void {
-        res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
+    const auto delete_saved_prompt = [&ctx_server, &params](const httplib::Request& req, httplib::Response& res) -> void {
         json response;
 
         namespace fs = std::filesystem;
@@ -4741,8 +4786,7 @@ int main(int argc, char ** argv) {
         res.set_content(response.dump(), "application/json; charset=utf-8");
     };
 
-    const auto rename_saved_prompt = [&ctx_server, &params](const httplib::Request& req, httplib::Response& res) -> void {
-        res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
+    const auto rename_saved_prompt = [&ctx_server, &params](const httplib::Request& req, httplib::Response& res) -> void {
         json response;
 
         namespace fs = std::filesystem;
diff --git a/examples/server/utils.hpp b/examples/server/utils.hpp
index 860144ed..fbd67573 100644
--- a/examples/server/utils.hpp
+++ b/examples/server/utils.hpp
@@ -1304,11 +1304,12 @@ public:
 
     // encode and decode the image chunk
     int32_t process_chunk(
-        llama_context* ctx,
-        mtmd_context* mctx,
+        llama_context * ctx,
+        mtmd_context * mctx,
         llama_pos n_past,
         int32_t seq_id,
-        llama_pos& n_pos_out) {
+        llama_pos & n_pos_out,
+        size_t & n_tokens_out) {
         char buffer[512];
         auto& chunk = find_chunk(n_past);
         const char* name = mtmd_input_chunk_get_type(chunk.get()) == MTMD_INPUT_CHUNK_TYPE_IMAGE
@@ -1325,21 +1326,25 @@ public:
             n_batch,
             true, // logits last
             &new_n_past);
+        // get the number of tokens in the image chunk
+        const size_t new_n_tokens = mtmd_input_chunk_get_n_tokens(chunk.get());
         snprintf(buffer, 512, "processed in %g ms", 1.*(ggml_time_ms() - t0));
         LOG_INFO(buffer, {});
         if (result != 0) {
             snprintf(buffer, 512, "mtmd_helper_eval failed with status %d", result);
             LOG_ERROR(buffer, {});
             n_pos_out = n_past;
+            n_tokens_out = 0;
             return result;
         }
         n_pos_out = new_n_past;
+        n_tokens_out = new_n_tokens;
         return 0;
     }
 };
 
 // Computes FNV-1a hash of the data
-static std::string fnv_hash(const uint8_t* data, size_t len) {
+static std::string fnv_hash(const uint8_t * data, size_t len) {
     const uint64_t fnv_prime = 0x100000001b3ULL;
     uint64_t hash = 0xcbf29ce484222325ULL;
 
@@ -1350,7 +1355,7 @@ static std::string fnv_hash(const uint8_t * data, size_t len) {
     return std::to_string(hash);
 }
 
-static server_tokens process_mtmd_prompt(mtmd_context* mctx, std::string prompt, std::vector<raw_buffer> files) {
+static server_tokens process_mtmd_prompt(mtmd_context * mctx, std::string prompt, std::vector<raw_buffer> files) {
     mtmd::bitmaps bitmaps;
     for (auto& file : files) {
         mtmd::bitmap bmp(mtmd_helper_bitmap_init_from_buf(mctx, file.data(), file.size()));