From 18f5a6caefb252777bf900feab69c3ad6bfa84ba Mon Sep 17 00:00:00 2001
From: firecoperana
Date: Thu, 6 Nov 2025 05:10:51 +0000
Subject: [PATCH] Bug fixes for completions and prompt caching in server (#906)

* Bug fixes for completions and prompt caching in server

* Fix compiler warning about redefinition

---------

Co-authored-by: firecoperana
---
 examples/server/server.cpp | 62 +++++++++++-------------
 examples/server/utils.hpp  | 97 ++++++++++++++++++++++++++------------
 2 files changed, 96 insertions(+), 63 deletions(-)

diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index acc10a1b..4428af24 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -15,13 +15,7 @@
 // crash the server in debug mode, otherwise send an http 500 error
 #define CPPHTTPLIB_NO_EXCEPTIONS 1
 #endif
-// increase max payload length to allow use of larger context size
-#define CPPHTTPLIB_FORM_URL_ENCODED_PAYLOAD_MAX_LENGTH 1048576
-// disable Nagle's algorithm
-#define CPPHTTPLIB_TCP_NODELAY true
-#include "httplib.h"
-// Change JSON_ASSERT from assert() to GGML_ASSERT:
-#define JSON_ASSERT GGML_ASSERT
+
 #include 
 #include "index.html.gz.hpp"
 #include "index_llamacpp.html.gz.hpp"
@@ -3050,7 +3044,7 @@ struct server_context {
                 GGML_ASSERT(slot.ga_n == 1);

                 // reuse any previously computed tokens that are common with the new prompt
-                slot.n_past = common_part(slot.cache_tokens.tokens_data(), prompt_tokens.tokens_data());
+                slot.n_past = slot.cache_tokens.get_common_prefix(prompt_tokens);

                 // push the prompt into the sampling context (do not apply grammar)
                 for (int i = 0; i < slot.n_past; ++i) {
@@ -3137,7 +3131,6 @@ struct server_context {
                     {
                         const auto& chunk = slot.prompt_tokens.find_chunk(slot.n_past);
                         slot.cache_tokens.push_back(chunk.get()); // copy
-                        fprintf(stdout, slot.cache_tokens.detokenize(ctx, true).c_str());
                     }

                     slot.n_past += n_pos;
@@ -4293,14 +4286,15 @@ int main(int argc, char ** argv) {
             }

             const auto& prompt = data.at("prompt");
-            fprintf(stdout, prompt.get().c_str());

             // process prompt
             std::vector inputs;

             if (oaicompat && ctx_server.mctx != nullptr) {
                 // This is the case used by OAI compatible chat path with MTMD. TODO It can be moved to the path below.
-                printFilesInfo(files);
+#ifndef NDEBUG
+                print_files_info(files);
+#endif // !NDEBUG
                 inputs.push_back(process_mtmd_prompt(ctx_server.mctx, prompt.get(), files));
             }
             else {
@@ -4346,31 +4340,26 @@ int main(int argc, char ** argv) {
             if (!result.error) {
                 result.oaicompat = oaicompat;
                 result.oaicompat_cmpl_id = completion_id;
-                json result_array;
+                json res_json;
                 if (oaicompat) {
                     if (result.final_result) {
-                        result_array = result.to_json_final();
+                        res_json = result.to_json_final();
                     }
                     else {
-                        result_array = result.to_json_partial();
+                        res_json = result.to_json_partial();
                     }
                 }
                 else {
                     // legacy completions
-                    result_array = result.data;
+                    res_json = result.data;
                 }
-                if (result_array.is_array()) {
-                    for (auto it = result_array.begin(); it != result_array.end(); ++it) {
-                        if (!it->empty()) {
-                            const std::string str =
-                                "data: " +
-                                it->dump(-1, ' ', false, json::error_handler_t::replace) +
-                                "\n\n";
-                            LOG_VERBOSE("data stream", { {"to_send", str} });
-                            if (!sink.write(str.c_str(), str.size())) {
-                                ctx_server.queue_results.remove_waiting_task_id(id_task);
-                                return false;
-                            }
+                if (res_json.is_array()) {
+                    // chat completions and oai completions
+                    for (const auto& res : res_json) {
+                        if (!server_sent_event(sink, res)) {
+                            // sending failed (HTTP connection closed), cancel the generation
+                            ctx_server.queue_results.remove_waiting_task_id(id_task);
+                            return false;
+                        }
                     }
                     if (result.stop) {
@@ -4378,14 +4367,19 @@ int main(int argc, char ** argv) {
                         break;
                     }
                 }
+                else {
+                    // legacy completions
+                    if (!server_sent_event(sink, res_json)) {
+                        ctx_server.queue_results.remove_waiting_task_id(id_task);
+                        return false;
+                    }
+                    if (result.stop) {
+                        break;
+                    }
+                }
             }
             else {
-                const std::string str =
-                    "error: " +
-                    result.data.dump(-1, ' ', false, json::error_handler_t::replace) +
-                    "\n\n";
-                LOG_VERBOSE("data stream", { {"to_send", str} });
-                if (!sink.write(str.c_str(), str.size())) {
+                if (!server_sent_event(sink, result.data)) {
                     ctx_server.queue_results.remove_waiting_task_id(id_task);
                     return false;
                 }
@@ -4436,7 +4430,7 @@ int main(int argc, char ** argv) {
             data,
             files,
             res,
-            OAICOMPAT_TYPE_CHAT);
+            OAICOMPAT_TYPE_COMPLETION);
     };

     const auto handle_models = [&params, &model_meta](const httplib::Request & req, httplib::Response & res) {
diff --git a/examples/server/utils.hpp b/examples/server/utils.hpp
index f69b6401..3a411f50 100644
--- a/examples/server/utils.hpp
+++ b/examples/server/utils.hpp
@@ -15,6 +15,16 @@
 #include 
 #include 

+// increase max payload length to allow use of larger context size
+#define CPPHTTPLIB_FORM_URL_ENCODED_PAYLOAD_MAX_LENGTH 1048576
+// increase backlog size to avoid connection resets for >> 1 slots
+#define CPPHTTPLIB_LISTEN_BACKLOG 512
+// increase max URI length to handle longer prompts in query string
+#define CPPHTTPLIB_REQUEST_URI_MAX_LENGTH 32768
+// disable Nagle's algorithm
+#define CPPHTTPLIB_TCP_NODELAY true
+#include "httplib.h"
+
 #define DEFAULT_OAICOMPAT_MODEL "gpt-3.5-turbo-0613"

 using json = nlohmann::ordered_json;
@@ -411,6 +421,17 @@ static json probs_vector_to_json(const llama_context * ctx, const std::vector
-                    GGML_ASSERT(a_pos > 0 && "Invalid media chunk"); // should never happen
-                    i += a_pos - 1; // will be +1 by the for loop
+        if (!has_mtmd) {
             for (size_t i = 0; i < max_idx; ++i) {
                 if (tokens[i] == b.tokens[i]) { continue; }
-                else {
-                    return i;
-                }
-            }
-            else if (ai == bi) {
-                continue;
-            }
-            else { return i; }
+            return max_idx;
         }
+
+        for (size_t i = 0; i < max_idx; ++i) {
+            const llama_token ai = tokens[i];
+            const llama_token bi = b.tokens[i];
+
+            if (ai == LLAMA_TOKEN_NULL && bi == LLAMA_TOKEN_NULL) {
+                const auto& a_chunk = find_chunk(i);
+                const auto& b_chunk = b.find_chunk(i);
+
+                GGML_ASSERT(a_chunk && b_chunk);
+
+                const std::string id_ai = mtmd_input_chunk_get_id(a_chunk.get());
+                const std::string id_bi = mtmd_input_chunk_get_id(b_chunk.get());
+
+                const size_t pos_a = mtmd_input_chunk_get_n_pos(a_chunk.get());
+                const size_t pos_b = mtmd_input_chunk_get_n_pos(b_chunk.get());
+
+                if (id_ai == id_bi && pos_a == pos_b) {
+                    GGML_ASSERT(pos_a > 0 && "Invalid media chunk"); // should never happen
+                    i += pos_a - 1; // will be +1 by the for loop
+                    continue;
+                }
+
+                return i;
+            }
+
+            if (ai == bi) {
+                continue;
+            }
+
+            return i;
+        }
+
         return max_idx; // all tokens are equal
     }
+
     // make sure all text tokens are within the vocab range
     bool validate(const struct llama_context* ctx) const {
         const llama_model* model = llama_get_model(ctx);
@@ -1274,10 +1309,12 @@ public:
         llama_pos n_past,
         int32_t seq_id,
         llama_pos& n_pos_out) {
+        char buffer[512];
         auto& chunk = find_chunk(n_past);
         const char* name = mtmd_input_chunk_get_type(chunk.get()) == MTMD_INPUT_CHUNK_TYPE_IMAGE
             ? "image" : "audio";
-        LOG_INFO("processing %s...\n", name);
+        snprintf(buffer, 512, "processing : %s", name);
+        LOG_INFO(buffer, {});
         int32_t n_batch = llama_n_batch(ctx);
         int64_t t0 = ggml_time_ms();
         llama_pos new_n_past = n_past;
@@ -1288,9 +1325,11 @@ public:
             n_batch,
             true, // logits last
             &new_n_past);
-        LOG_INFO("processed in %" PRId64 " ms\n", ggml_time_ms() - t0);
+        snprintf(buffer, 512, "processed in %d ms", ggml_time_ms() - t0);
+        LOG_INFO(buffer, {});
         if (result != 0) {
-            LOG_ERROR("mtmd_helper_eval failed with status %d", result);
+            snprintf(buffer, 512, "mtmd_helper_eval failed with status %d", result);
+            LOG_ERROR(buffer, {});
             n_pos_out = n_past;
             return result;
         }
@@ -1422,7 +1461,7 @@ static std::vector tokenize_input_prompts(const llama_vocab* voca
     return result;
 }
 // Assuming raw_buffer has .data() and .size() members
-inline void printFilesInfo(const std::vector& files) {
+inline void print_files_info(const std::vector& files) {
     for (size_t i = 0; i < files.size(); ++i) {
         const auto& file = files[i];
         std::cout << "File " << i << ": Size = " << file.size() << " bytes\n";
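
The streaming hunks in server.cpp replace the hand-rolled "data: ..." framing with a server_sent_event() helper. The helper's definition is not visible in the surviving hunks of this patch, so the sketch below is only an assumption about its shape (the name comes from the call sites above; the signature and comments are guesses, not the patch's literal code). It simply reuses the framing that the removed lines used to do inline, with cpp-httplib's httplib::DataSink and the json alias from utils.hpp:

    // Hypothetical sketch of an SSE helper matching the call sites above.
    // Frames one JSON payload as a single "data: ...\n\n" server-sent event.
    static bool server_sent_event(httplib::DataSink & sink, const json & data) {
        // error_handler_t::replace keeps invalid UTF-8 from aborting dump()
        const std::string str =
            "data: " +
            data.dump(-1, ' ', false, json::error_handler_t::replace) +
            "\n\n";
        // sink.write() returns false once the client has closed the connection;
        // the callers above use that to cancel the running generation.
        return sink.write(str.c_str(), str.size());
    }

On the prompt-caching side, the new server_tokens::get_common_prefix() walks both token sequences and treats multimodal positions (LLAMA_TOKEN_NULL) as equal only when the underlying media chunks share the same id and position count, so cached image/audio chunks are reused only when they actually match the new prompt.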