mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-02-27 08:34:09 +00:00
server: stop processing the prompt when client disconnects (#1134)
Implement generator-based API for task results. Update httplib.h to 0.27.0. Fix embedding error. Stop prompt processing when disconnected. Co-authored-by: firecoperana <firecoperana>
This commit is contained in:
@@ -9,6 +9,12 @@
|
||||
#include "sampling.h"
|
||||
#include "llama.h"
|
||||
#include "llama-vocab.h"
|
||||
#include <fstream>
|
||||
|
||||
|
||||
// mime type for sending response
|
||||
#define MIMETYPE_JSON "application/json; charset=utf-8"
|
||||
|
||||
|
||||
#ifndef NDEBUG
|
||||
// crash the server in debug mode, otherwise send an http 500 error
|
||||
@@ -48,6 +54,7 @@ struct DatabaseHandle {
|
||||
|
||||
using json = nlohmann::ordered_json;
|
||||
namespace fs = std::filesystem;
|
||||
constexpr int HTTP_POLLING_SECONDS = 1;
|
||||
|
||||
bool server_verbose = false;
|
||||
bool server_log_json = true;
|
||||
@@ -299,6 +306,117 @@ static void log_server_request(const httplib::Request & req, const httplib::Resp
|
||||
});
|
||||
}
|
||||
|
||||
// generator-like API for server responses, support pooling connection state and aggregating results
|
||||
struct server_response_reader {
|
||||
std::unordered_set<int> id_tasks;
|
||||
server_context& ctx_server;
|
||||
size_t received_count = 0;
|
||||
bool cancelled = false;
|
||||
|
||||
server_response_reader(server_context& ctx_server) : ctx_server(ctx_server) {}
|
||||
~server_response_reader() {
|
||||
stop();
|
||||
}
|
||||
|
||||
void post_tasks(std::vector<server_task>&& tasks) {
|
||||
id_tasks = server_task::get_list_id(tasks);
|
||||
ctx_server.queue_results.add_waiting_tasks(tasks);
|
||||
ctx_server.queue_tasks.post(std::move(tasks));
|
||||
}
|
||||
|
||||
bool has_next() {
|
||||
return !cancelled && received_count < id_tasks.size();
|
||||
}
|
||||
|
||||
// return nullptr if should_stop() is true before receiving a result
|
||||
// note: if one error is received, it will stop further processing and return error result
|
||||
server_task_result_ptr next(const std::function<bool()>& should_stop) {
|
||||
while (true) {
|
||||
server_task_result_ptr result = ctx_server.queue_results.recv_with_timeout(id_tasks, HTTP_POLLING_SECONDS);
|
||||
if (result == nullptr) {
|
||||
// timeout, check stop condition
|
||||
if (should_stop()) {
|
||||
SRV_DBG("%s", "stopping wait for next result due to should_stop condition\n");
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
else {
|
||||
if (result->is_error()) {
|
||||
stop(); // cancel remaining tasks
|
||||
SRV_DBG("%s", "received error result, stopping further processing\n");
|
||||
return result;
|
||||
}
|
||||
if (result->is_stop()) {
|
||||
received_count++;
|
||||
}
|
||||
return result;
|
||||
}
|
||||
}
|
||||
|
||||
// should not reach here
|
||||
}
|
||||
|
||||
struct batch_response {
|
||||
bool is_terminated = false; // if true, indicates that processing was stopped before all results were received
|
||||
std::vector<server_task_result_ptr> results;
|
||||
server_task_result_ptr error; // nullptr if no error
|
||||
};
|
||||
|
||||
batch_response wait_for_all(const std::function<bool()>& should_stop) {
|
||||
batch_response batch_res;
|
||||
batch_res.results.resize(id_tasks.size());
|
||||
while (has_next()) {
|
||||
auto res = next(should_stop);
|
||||
if (res == nullptr) {
|
||||
batch_res.is_terminated = true;
|
||||
return batch_res;
|
||||
}
|
||||
if (res->error) {
|
||||
batch_res.error = std::move(res);
|
||||
return batch_res;
|
||||
}
|
||||
const size_t idx = res->get_index();
|
||||
GGML_ASSERT(idx < batch_res.results.size() && "index out of range");
|
||||
GGML_ASSERT(batch_res.results[idx] == nullptr && "duplicate result received");
|
||||
batch_res.results[idx] = std::move(res);
|
||||
}
|
||||
return batch_res;
|
||||
}
|
||||
|
||||
void stop() {
|
||||
ctx_server.queue_results.remove_waiting_task_ids(id_tasks);
|
||||
if (has_next() && !cancelled) {
|
||||
// if tasks is not finished yet, cancel them
|
||||
cancelled = true;
|
||||
std::vector<server_task> cancel_tasks;
|
||||
cancel_tasks.reserve(id_tasks.size());
|
||||
for (const auto& id_task : id_tasks) {
|
||||
SRV_WRN("cancel task, id_task = %d\n", id_task);
|
||||
server_task task(SERVER_TASK_TYPE_CANCEL);
|
||||
task.id_target = id_task;
|
||||
ctx_server.queue_results.remove_waiting_task_id(id_task);
|
||||
cancel_tasks.push_back(std::move(task));
|
||||
}
|
||||
// push to beginning of the queue, so it has highest priority
|
||||
ctx_server.queue_tasks.post(std::move(cancel_tasks), true);
|
||||
}
|
||||
else {
|
||||
SRV_DBG("%s", "all tasks already finished, no need to cancel\n");
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
auto res_err = [](httplib::Response& res, json error_data) {
|
||||
json final_response{ {"error", error_data} };
|
||||
res.set_content(safe_json_to_str(final_response), MIMETYPE_JSON);
|
||||
res.status = json_value(error_data, "code", 500);
|
||||
};
|
||||
|
||||
auto res_ok = [](httplib::Response& res, const json& data) {
|
||||
res.set_content(data.dump(), "application/json; charset=utf-8");
|
||||
res.status = 200;
|
||||
};
|
||||
|
||||
std::function<void(int)> shutdown_handler;
|
||||
std::atomic_flag is_terminating = ATOMIC_FLAG_INIT;
|
||||
|
||||
@@ -380,18 +498,9 @@ int main(int argc, char ** argv) {
|
||||
|
||||
svr->set_logger(log_server_request);
|
||||
|
||||
auto res_error = [](httplib::Response & res, json error_data) {
|
||||
json final_response {{"error", error_data}};
|
||||
res.set_content(final_response.dump(), "application/json; charset=utf-8");
|
||||
res.status = json_value(error_data, "code", 500);
|
||||
};
|
||||
|
||||
auto res_ok = [](httplib::Response& res, const json& data) {
|
||||
res.set_content(data.dump(), "application/json; charset=utf-8");
|
||||
res.status = 200;
|
||||
};
|
||||
|
||||
svr->set_exception_handler([&res_error](const httplib::Request &, httplib::Response & res, std::exception_ptr ep) {
|
||||
svr->set_exception_handler([](const httplib::Request &, httplib::Response & res, std::exception_ptr ep) {
|
||||
std::string message;
|
||||
try {
|
||||
std::rethrow_exception(std::move(ep));
|
||||
@@ -403,14 +512,14 @@ int main(int argc, char ** argv) {
|
||||
|
||||
json formatted_error = format_error_response(message, ERROR_TYPE_SERVER);
|
||||
LOG_VERBOSE("Got exception", formatted_error);
|
||||
res_error(res, formatted_error);
|
||||
res_err(res, formatted_error);
|
||||
});
|
||||
|
||||
svr->set_error_handler([&res_error](const httplib::Request &, httplib::Response & res) {
|
||||
svr->set_error_handler([](const httplib::Request &, httplib::Response & res) {
|
||||
if (res.status == 404) {
|
||||
res_error(res, format_error_response("File Not Found", ERROR_TYPE_NOT_FOUND));
|
||||
res_err(res, format_error_response("File Not Found", ERROR_TYPE_NOT_FOUND));
|
||||
}
|
||||
// for other error codes, we skip processing here because it's already done by res_error()
|
||||
// for other error codes, we skip processing here because it's already done by res_err()
|
||||
});
|
||||
|
||||
// set timeouts and change hostname and port
|
||||
@@ -492,7 +601,7 @@ int main(int argc, char ** argv) {
|
||||
// Middlewares
|
||||
//
|
||||
|
||||
auto middleware_validate_api_key = [¶ms, &res_error](const httplib::Request & req, httplib::Response & res) {
|
||||
auto middleware_validate_api_key = [¶ms](const httplib::Request & req, httplib::Response & res) {
|
||||
static const std::unordered_set<std::string> public_endpoints = {
|
||||
"/health",
|
||||
"/v1/health",
|
||||
@@ -544,7 +653,7 @@ int main(int argc, char ** argv) {
|
||||
return false;
|
||||
};
|
||||
|
||||
auto middleware_server_state = [&res_error, &state](const httplib::Request& req, httplib::Response& res) {
|
||||
auto middleware_server_state = [&state](const httplib::Request& req, httplib::Response& res) {
|
||||
server_state current_state = state.load();
|
||||
if (current_state == SERVER_STATE_LOADING_MODEL) {
|
||||
auto tmp = string_split<std::string>(req.path, '.');
|
||||
@@ -557,7 +666,7 @@ int main(int argc, char ** argv) {
|
||||
return true;
|
||||
}
|
||||
else {
|
||||
res_error(res, format_error_response("Loading model", ERROR_TYPE_UNAVAILABLE));
|
||||
res_err(res, format_error_response("Loading model", ERROR_TYPE_UNAVAILABLE));
|
||||
}
|
||||
return false;
|
||||
}
|
||||
@@ -632,18 +741,18 @@ int main(int argc, char ** argv) {
|
||||
}
|
||||
case SERVER_STATE_LOADING_MODEL:
|
||||
{
|
||||
res_error(res, format_error_response("Loading model", ERROR_TYPE_UNAVAILABLE));
|
||||
res_err(res, format_error_response("Loading model", ERROR_TYPE_UNAVAILABLE));
|
||||
} break;
|
||||
case SERVER_STATE_ERROR:
|
||||
{
|
||||
res_error(res, format_error_response("Model failed to load", ERROR_TYPE_SERVER));
|
||||
res_err(res, format_error_response("Model failed to load", ERROR_TYPE_SERVER));
|
||||
} break;
|
||||
}
|
||||
};
|
||||
|
||||
const auto handle_slots = [&](const httplib::Request &, httplib::Response & res) {
|
||||
if (!params.endpoint_slots) {
|
||||
res_error(res, format_error_response("This server does not support slots endpoint.", ERROR_TYPE_NOT_SUPPORTED));
|
||||
res_err(res, format_error_response("This server does not support slots endpoint.", ERROR_TYPE_NOT_SUPPORTED));
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -667,7 +776,7 @@ int main(int argc, char ** argv) {
|
||||
|
||||
const auto handle_metrics = [&](const httplib::Request &, httplib::Response & res) {
|
||||
if (!params.endpoint_metrics) {
|
||||
res_error(res, format_error_response("This server does not support metrics endpoint.", ERROR_TYPE_NOT_SUPPORTED));
|
||||
res_err(res, format_error_response("This server does not support metrics endpoint.", ERROR_TYPE_NOT_SUPPORTED));
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -766,11 +875,11 @@ int main(int argc, char ** argv) {
|
||||
res.status = 200; // HTTP OK
|
||||
};
|
||||
|
||||
const auto handle_slots_save = [&ctx_server, &res_error, ¶ms](const httplib::Request & req, httplib::Response & res, int id_slot) {
|
||||
const auto handle_slots_save = [&ctx_server, ¶ms](const httplib::Request & req, httplib::Response & res, int id_slot) {
|
||||
json request_data = json::parse(req.body);
|
||||
std::string filename = request_data.at("filename");
|
||||
if (!fs_validate_filename(filename)) {
|
||||
res_error(res, format_error_response("Invalid filename", ERROR_TYPE_INVALID_REQUEST));
|
||||
res_err(res, format_error_response("Invalid filename", ERROR_TYPE_INVALID_REQUEST));
|
||||
return;
|
||||
}
|
||||
std::string filepath = params.slot_save_path + filename;
|
||||
@@ -790,17 +899,17 @@ int main(int argc, char ** argv) {
|
||||
ctx_server.queue_results.remove_waiting_task_id(id_task);
|
||||
|
||||
if (result.error) {
|
||||
res_error(res, result.data);
|
||||
res_err(res, result.data);
|
||||
} else {
|
||||
res.set_content(result.data.dump(), "application/json");
|
||||
}
|
||||
};
|
||||
|
||||
const auto handle_slots_restore = [&ctx_server, &res_error, ¶ms](const httplib::Request & req, httplib::Response & res, int id_slot) {
|
||||
const auto handle_slots_restore = [&ctx_server, ¶ms](const httplib::Request & req, httplib::Response & res, int id_slot) {
|
||||
json request_data = json::parse(req.body);
|
||||
std::string filename = request_data.at("filename");
|
||||
if (!fs_validate_filename(filename)) {
|
||||
res_error(res, format_error_response("Invalid filename", ERROR_TYPE_INVALID_REQUEST));
|
||||
res_err(res, format_error_response("Invalid filename", ERROR_TYPE_INVALID_REQUEST));
|
||||
return;
|
||||
}
|
||||
std::string filepath = params.slot_save_path + filename;
|
||||
@@ -820,13 +929,13 @@ int main(int argc, char ** argv) {
|
||||
ctx_server.queue_results.remove_waiting_task_id(id_task);
|
||||
|
||||
if (result.error) {
|
||||
res_error(res, result.data);
|
||||
res_err(res, result.data);
|
||||
} else {
|
||||
res.set_content(result.data.dump(), "application/json");
|
||||
}
|
||||
};
|
||||
|
||||
const auto handle_slots_erase = [&ctx_server, &res_error](const httplib::Request & /* req */, httplib::Response & res, int id_slot) {
|
||||
const auto handle_slots_erase = [&ctx_server](const httplib::Request & /* req */, httplib::Response & res, int id_slot) {
|
||||
server_task task;
|
||||
task.type = SERVER_TASK_TYPE_SLOT_ERASE;
|
||||
task.data = {
|
||||
@@ -840,20 +949,20 @@ int main(int argc, char ** argv) {
|
||||
ctx_server.queue_results.remove_waiting_task_id(id_task);
|
||||
|
||||
if (result.error) {
|
||||
res_error(res, result.data);
|
||||
res_err(res, result.data);
|
||||
} else {
|
||||
res.set_content(result.data.dump(), "application/json");
|
||||
}
|
||||
};
|
||||
|
||||
const auto handle_slots_action = [&res_error, &handle_slots_save, &handle_slots_restore, &handle_slots_erase](const httplib::Request & req, httplib::Response & res) {
|
||||
const auto handle_slots_action = [&handle_slots_save, &handle_slots_restore, &handle_slots_erase](const httplib::Request & req, httplib::Response & res) {
|
||||
std::string id_slot_str = req.path_params.at("id_slot");
|
||||
int id_slot;
|
||||
|
||||
try {
|
||||
id_slot = std::stoi(id_slot_str);
|
||||
} catch (const std::exception &) {
|
||||
res_error(res, format_error_response("Invalid slot ID", ERROR_TYPE_INVALID_REQUEST));
|
||||
res_err(res, format_error_response("Invalid slot ID", ERROR_TYPE_INVALID_REQUEST));
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -866,7 +975,7 @@ int main(int argc, char ** argv) {
|
||||
} else if (action == "erase") {
|
||||
handle_slots_erase(req, res, id_slot);
|
||||
} else {
|
||||
res_error(res, format_error_response("Invalid action", ERROR_TYPE_INVALID_REQUEST));
|
||||
res_err(res, format_error_response("Invalid action", ERROR_TYPE_INVALID_REQUEST));
|
||||
}
|
||||
};
|
||||
|
||||
@@ -931,146 +1040,161 @@ int main(int argc, char ** argv) {
|
||||
|
||||
// handle completion-like requests (completion, chat, infill)
|
||||
// we can optionally provide a custom format for partial results and final results
|
||||
const auto handle_completions_impl = [&ctx_server, ¶ms, &res_error, &res_ok](
|
||||
const auto handle_completions_impl = [&ctx_server, ¶ms](
|
||||
server_task_type type,
|
||||
json& data,
|
||||
const std::vector<raw_buffer>& files,
|
||||
const std::function<bool()>& is_connection_closed,
|
||||
httplib::Response& res,
|
||||
oaicompat_type oaicompat) -> void {
|
||||
GGML_ASSERT(type == SERVER_TASK_TYPE_COMPLETION);
|
||||
if (ctx_server.params.embedding) {
|
||||
res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
|
||||
GGML_ASSERT(type == SERVER_TASK_TYPE_COMPLETION || type == SERVER_TASK_TYPE_INFILL);
|
||||
|
||||
const auto completion_id = gen_chatcmplid();
|
||||
// need to store the reader as a pointer, so that it won't be destroyed when the handle returns
|
||||
// use shared_ptr as it's shared between the chunked_content_provider() and on_complete()
|
||||
const auto rd = std::make_shared<server_response_reader>(ctx_server);
|
||||
|
||||
try {
|
||||
std::vector<server_task> tasks;
|
||||
|
||||
const auto& prompt = data.at("prompt");
|
||||
|
||||
// process prompt
|
||||
std::vector<server_tokens> inputs;
|
||||
|
||||
if (oaicompat && ctx_server.mctx != nullptr) {
|
||||
// This is the case used by OAI compatible chat path with MTMD. TODO It can be moved to the path below.
|
||||
inputs.push_back(process_mtmd_prompt(ctx_server.mctx, prompt.get<std::string>(), files));
|
||||
}
|
||||
else {
|
||||
// Everything else, including multimodal completions.
|
||||
inputs = tokenize_input_prompts(llama_get_vocab(ctx_server.ctx), ctx_server.mctx, prompt, true, true);
|
||||
}
|
||||
tasks.reserve(inputs.size());
|
||||
for (size_t i = 0; i < inputs.size(); i++) {
|
||||
server_task task = server_task(type);
|
||||
|
||||
task.id = ctx_server.queue_tasks.get_new_id();
|
||||
task.index = i;
|
||||
|
||||
task.tokens = std::move(inputs[i]);
|
||||
task.data = data;
|
||||
//task.params = server_task::params_from_json_cmpl(
|
||||
// ctx_server.ctx,
|
||||
// ctx_server.params,
|
||||
// data);
|
||||
task.id_slot = json_value(data, "id_slot", -1);
|
||||
|
||||
// OAI-compat
|
||||
task.params.oaicompat = oaicompat;
|
||||
task.params.oaicompat_cmpl_id = completion_id;
|
||||
tasks.push_back(std::move(task));
|
||||
}
|
||||
|
||||
rd->post_tasks(std::move(tasks));
|
||||
}
|
||||
catch (const std::exception& e) {
|
||||
res_err(res, format_error_response(e.what(), ERROR_TYPE_INVALID_REQUEST));
|
||||
return;
|
||||
}
|
||||
|
||||
const auto& prompt = data.at("prompt");
|
||||
|
||||
// process prompt
|
||||
std::vector<server_tokens> inputs;
|
||||
|
||||
if (oaicompat && ctx_server.mctx != nullptr) {
|
||||
// This is the case used by OAI compatible chat path with MTMD. TODO It can be moved to the path below.
|
||||
#ifndef NDEBUG
|
||||
print_files_info(files);
|
||||
#endif // !NDEBUG
|
||||
inputs.push_back(process_mtmd_prompt(ctx_server.mctx, prompt.get<std::string>(), files));
|
||||
}
|
||||
else {
|
||||
// Everything else, including multimodal completions.
|
||||
inputs = tokenize_input_prompts(llama_get_vocab(ctx_server.ctx), ctx_server.mctx, prompt, true, true);
|
||||
}
|
||||
const auto completion_id = gen_chatcmplid();
|
||||
const int id_task = ctx_server.queue_tasks.get_new_id();
|
||||
|
||||
ctx_server.queue_results.add_waiting_task_id(id_task);
|
||||
ctx_server.request_completion(id_task, -1, data, false, false, std::move(inputs[0]));
|
||||
bool stream = json_value(data, "stream", false);
|
||||
if (!stream) {
|
||||
server_task_result result = ctx_server.queue_results.recv(id_task);
|
||||
result.oaicompat = oaicompat;
|
||||
result.oaicompat_cmpl_id = completion_id;
|
||||
json result_oai;
|
||||
if (oaicompat) {
|
||||
if (result.final_result) {
|
||||
result_oai = result.to_json_final();
|
||||
}
|
||||
else {
|
||||
result_oai = result.to_json_partial();
|
||||
}
|
||||
// non-stream, wait for the results
|
||||
auto all_results = rd->wait_for_all(is_connection_closed);
|
||||
if (all_results.is_terminated) {
|
||||
llama_decode_stop(); // send a signal to stop decode process
|
||||
return; // connection is closed
|
||||
}
|
||||
else if (all_results.error) {
|
||||
res_err(res, all_results.error->to_json());
|
||||
return;
|
||||
}
|
||||
else {
|
||||
// legacy completions
|
||||
result_oai = result.data;
|
||||
}
|
||||
if (!result.error && result.stop) {
|
||||
res.set_content(result_oai.dump(-1, ' ', false, json::error_handler_t::replace), "application/json; charset=utf-8");
|
||||
}
|
||||
else {
|
||||
res_error(res, result_oai);
|
||||
}
|
||||
ctx_server.queue_results.remove_waiting_task_id(id_task);
|
||||
json arr = json::array();
|
||||
for (auto& res : all_results.results) {
|
||||
GGML_ASSERT(dynamic_cast<server_task_result_cmpl_final*>(res.get()) != nullptr);
|
||||
arr.push_back(res->to_json());
|
||||
}
|
||||
// if single request, return single object instead of array
|
||||
res_ok(res, arr.size() == 1 ? arr[0] : arr);
|
||||
}
|
||||
}
|
||||
else {
|
||||
const auto chunked_content_provider = [id_task, &ctx_server, completion_id, oaicompat, send_done = params.send_done](size_t, httplib::DataSink& sink) {
|
||||
bool successful_completion = false;
|
||||
const auto sse = [oaicompat, &sink](const json &res) {
|
||||
if (oaicompat == OAICOMPAT_TYPE_ANTHROPIC) {
|
||||
return server_sent_anthropic_event(sink, res);
|
||||
} else {
|
||||
return server_sent_event(sink, res);
|
||||
}
|
||||
};
|
||||
while (true) {
|
||||
server_task_result result = ctx_server.queue_results.recv(id_task);
|
||||
if (!result.error) {
|
||||
result.oaicompat = oaicompat;
|
||||
result.oaicompat_cmpl_id = completion_id;
|
||||
json res_json;
|
||||
if (oaicompat) {
|
||||
if (result.final_result) {
|
||||
res_json = result.to_json_final();
|
||||
}
|
||||
else {
|
||||
res_json = result.to_json_partial();
|
||||
}
|
||||
}
|
||||
else {
|
||||
// legacy completions
|
||||
res_json = result.data;
|
||||
}
|
||||
if (res_json.is_array()) {
|
||||
// chat completions and oai completions
|
||||
for (const auto& res : res_json) {
|
||||
if (!sse(res)) {
|
||||
// sending failed (HTTP connection closed), cancel the generation
|
||||
ctx_server.queue_results.remove_waiting_task_id(id_task);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
if (result.stop) {
|
||||
successful_completion = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
else {
|
||||
// legacy completions
|
||||
if (!sse(res_json)) {
|
||||
ctx_server.queue_results.remove_waiting_task_id(id_task);
|
||||
return false;
|
||||
}
|
||||
if (result.stop) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
else {
|
||||
if (!sse(result.data)) {
|
||||
ctx_server.queue_results.remove_waiting_task_id(id_task);
|
||||
return false;
|
||||
}
|
||||
break;
|
||||
// in streaming mode, the first error must be treated as non-stream response
|
||||
// this is to match the OAI API behavior
|
||||
// ref: https://github.com/ggml-org/llama.cpp/pull/16486#discussion_r2419657309
|
||||
server_task_result_ptr first_result = rd->next(is_connection_closed);
|
||||
if (first_result == nullptr) {
|
||||
llama_decode_stop(); // send a signal to stop decode process
|
||||
return; // connection is closed
|
||||
}
|
||||
else if (first_result->is_error()) {
|
||||
res_err(res, first_result->to_json());
|
||||
return;
|
||||
}
|
||||
else {
|
||||
GGML_ASSERT(
|
||||
dynamic_cast<server_task_result_cmpl_partial*>(first_result.get()) != nullptr
|
||||
|| dynamic_cast<server_task_result_cmpl_final*>(first_result.get()) != nullptr
|
||||
);
|
||||
}
|
||||
// next responses are streamed
|
||||
json first_result_json = first_result->to_json();
|
||||
const auto chunked_content_provider = [first_result_json, rd, oaicompat](size_t, httplib::DataSink& sink) mutable -> bool {
|
||||
// flush the first result as it's not an error
|
||||
if (!first_result_json.empty()) {
|
||||
if (!server_sent_event(sink, first_result_json)) {
|
||||
sink.done();
|
||||
return false; // sending failed, go to on_complete()
|
||||
}
|
||||
first_result_json.clear(); // mark as sent
|
||||
}
|
||||
bool ok = true;
|
||||
if (successful_completion && oaicompat != OAICOMPAT_TYPE_ANTHROPIC && oaicompat != OAICOMPAT_TYPE_NONE) {
|
||||
static const std::string done_message = "data: [DONE]\n\n";
|
||||
LOG_VERBOSE("data stream", { {"to_send", done_message} });
|
||||
if (!sink.write(done_message.c_str(), done_message.size())) {
|
||||
// If writing [DONE] fails, the stream is likely already problematic.
|
||||
ok = false;
|
||||
}
|
||||
|
||||
// receive subsequent results
|
||||
auto result = rd->next([&sink] { return !sink.is_writable(); });
|
||||
if (result == nullptr) {
|
||||
sink.done();
|
||||
return false; // connection is closed, go to on_complete()
|
||||
}
|
||||
sink.done();
|
||||
ctx_server.queue_results.remove_waiting_task_id(id_task);
|
||||
return ok;
|
||||
|
||||
// send the results
|
||||
json res_json = result->to_json();
|
||||
bool ok = false;
|
||||
if (result->is_error()) {
|
||||
ok = server_sent_event(sink, json{ { "error", result->to_json() } });
|
||||
sink.done();
|
||||
return false; // go to on_complete()
|
||||
}
|
||||
else {
|
||||
GGML_ASSERT(
|
||||
dynamic_cast<server_task_result_cmpl_partial*>(result.get()) != nullptr
|
||||
|| dynamic_cast<server_task_result_cmpl_final*>(result.get()) != nullptr
|
||||
);
|
||||
ok = server_sent_event(sink, res_json);
|
||||
}
|
||||
|
||||
if (!ok) {
|
||||
sink.done();
|
||||
return false; // sending failed, go to on_complete()
|
||||
}
|
||||
|
||||
// check if there is more data
|
||||
if (!rd->has_next()) {
|
||||
if (oaicompat != OAICOMPAT_TYPE_NONE) {
|
||||
static const std::string ev_done = "data: [DONE]\n\n";
|
||||
sink.write(ev_done.data(), ev_done.size());
|
||||
}
|
||||
sink.done();
|
||||
return false; // no more data, go to on_complete()
|
||||
}
|
||||
|
||||
// has next data, continue
|
||||
return true;
|
||||
};
|
||||
|
||||
auto on_complete = [id_task, &ctx_server](bool) {
|
||||
// cancel request
|
||||
ctx_server.request_cancel(id_task);
|
||||
ctx_server.queue_results.remove_waiting_task_id(id_task);
|
||||
auto on_complete = [rd](bool) {
|
||||
rd->stop();
|
||||
};
|
||||
|
||||
res.set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete);
|
||||
}
|
||||
};
|
||||
@@ -1082,6 +1206,7 @@ int main(int argc, char ** argv) {
|
||||
SERVER_TASK_TYPE_COMPLETION,
|
||||
data,
|
||||
files,
|
||||
req.is_connection_closed,
|
||||
res,
|
||||
OAICOMPAT_TYPE_NONE);
|
||||
};
|
||||
@@ -1094,6 +1219,7 @@ int main(int argc, char ** argv) {
|
||||
SERVER_TASK_TYPE_COMPLETION,
|
||||
data,
|
||||
files,
|
||||
req.is_connection_closed,
|
||||
res,
|
||||
OAICOMPAT_TYPE_COMPLETION);
|
||||
};
|
||||
@@ -1117,7 +1243,7 @@ int main(int argc, char ** argv) {
|
||||
|
||||
|
||||
|
||||
const auto handle_chat_completions = [&ctx_server, ¶ms, &handle_completions_impl, &res_error](const httplib::Request & req, httplib::Response & res) {
|
||||
const auto handle_chat_completions = [&ctx_server, ¶ms, &handle_completions_impl](const httplib::Request & req, httplib::Response & res) {
|
||||
auto body = json::parse(req.body);
|
||||
std::vector<raw_buffer> files;
|
||||
json data = oaicompat_chat_params_parse(ctx_server.model, body, ctx_server.oai_parser_opt, files);
|
||||
@@ -1125,6 +1251,7 @@ int main(int argc, char ** argv) {
|
||||
SERVER_TASK_TYPE_COMPLETION,
|
||||
data,
|
||||
files,
|
||||
req.is_connection_closed,
|
||||
res,
|
||||
OAICOMPAT_TYPE_CHAT);
|
||||
};
|
||||
@@ -1141,11 +1268,12 @@ int main(int argc, char ** argv) {
|
||||
SERVER_TASK_TYPE_COMPLETION,
|
||||
body_parsed,
|
||||
files,
|
||||
req.is_connection_closed,
|
||||
res,
|
||||
OAICOMPAT_TYPE_ANTHROPIC);
|
||||
};
|
||||
|
||||
const auto handle_anthropic_count_tokens = [&ctx_server, &handle_completions_impl, &res_ok](const httplib::Request & req, httplib::Response & res) {
|
||||
const auto handle_anthropic_count_tokens = [&ctx_server, &handle_completions_impl](const httplib::Request & req, httplib::Response & res) {
|
||||
std::vector<raw_buffer> files;
|
||||
json body = json::parse(req.body);
|
||||
|
||||
@@ -1164,14 +1292,14 @@ int main(int argc, char ** argv) {
|
||||
};
|
||||
|
||||
// same with handle_chat_completions, but without inference part
|
||||
const auto handle_apply_template = [&ctx_server, ¶ms, &res_ok](const httplib::Request& req, httplib::Response& res) {
|
||||
const auto handle_apply_template = [&ctx_server, ¶ms](const httplib::Request& req, httplib::Response& res) {
|
||||
auto body = json::parse(req.body);
|
||||
std::vector<raw_buffer> files; // dummy, unused
|
||||
json data = oaicompat_chat_params_parse(ctx_server.model, body,ctx_server.oai_parser_opt, files);
|
||||
res_ok(res, { { "prompt", std::move(data.at("prompt")) } });
|
||||
};
|
||||
|
||||
const auto handle_infill = [&ctx_server, &res_error, &handle_completions_impl](const httplib::Request & req, httplib::Response & res) {
|
||||
const auto handle_infill = [&ctx_server, &handle_completions_impl](const httplib::Request & req, httplib::Response & res) {
|
||||
json data = json::parse(req.body);
|
||||
const int id_task = ctx_server.queue_tasks.get_new_id();
|
||||
server_tokens token; // dummy tokens
|
||||
@@ -1182,6 +1310,7 @@ int main(int argc, char ** argv) {
|
||||
SERVER_TASK_TYPE_INFILL,
|
||||
data,
|
||||
files,
|
||||
req.is_connection_closed,
|
||||
res,
|
||||
OAICOMPAT_TYPE_NONE); // infill is not OAI compatible
|
||||
};
|
||||
@@ -1211,58 +1340,119 @@ int main(int argc, char ** argv) {
|
||||
return res.set_content(data.dump(), "application/json; charset=utf-8");
|
||||
};
|
||||
|
||||
|
||||
const auto handle_embeddings = [&ctx_server, &res_error](const httplib::Request & req, httplib::Response & res) {
|
||||
const json body = json::parse(req.body);
|
||||
bool is_openai = false;
|
||||
|
||||
// an input prompt can be a string or a list of tokens (integer)
|
||||
json prompt;
|
||||
if (body.count("input") != 0) {
|
||||
is_openai = true;
|
||||
prompt = body.at("input");
|
||||
} else if (body.count("content") != 0) {
|
||||
// with "content", we only support single prompt
|
||||
prompt = std::vector<std::string>{body.at("content")};
|
||||
} else {
|
||||
res_error(res, format_error_response("\"input\" or \"content\" must be provided", ERROR_TYPE_INVALID_REQUEST));
|
||||
const auto handle_embeddings_impl = [&ctx_server](const httplib::Request& req, httplib::Response& res, oaicompat_type oaicompat) {
|
||||
if (!ctx_server.params.embedding) {
|
||||
res_err(res, format_error_response("This server does not support embeddings. Start it with `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
|
||||
return;
|
||||
}
|
||||
|
||||
// create and queue the task
|
||||
json responses;
|
||||
{
|
||||
const int id_task = ctx_server.queue_tasks.get_new_id();
|
||||
ctx_server.queue_results.add_waiting_task_id(id_task);
|
||||
std::vector<server_tokens> inputs;
|
||||
inputs = tokenize_input_prompts(llama_get_vocab(ctx_server.ctx), ctx_server.mctx, prompt, true, true);
|
||||
ctx_server.request_completion(id_task, -1, {{"prompt", prompt}}, false, true, std::move(inputs[0]));
|
||||
if (oaicompat != OAICOMPAT_TYPE_NONE && llama_pooling_type(ctx_server.ctx) == LLAMA_POOLING_TYPE_NONE) {
|
||||
res_err(res, format_error_response("Pooling type 'none' is not OAI compatible. Please use a different pooling type", ERROR_TYPE_INVALID_REQUEST));
|
||||
return;
|
||||
}
|
||||
|
||||
// get the result
|
||||
server_task_result result = ctx_server.queue_results.recv(id_task);
|
||||
ctx_server.queue_results.remove_waiting_task_id(id_task);
|
||||
if (!result.error) {
|
||||
if (result.data.count("results")) {
|
||||
// result for multi-task
|
||||
responses = result.data.at("results");
|
||||
} else {
|
||||
// result for single task
|
||||
responses = std::vector<json>{ result.data };
|
||||
}
|
||||
} else {
|
||||
// error received, ignore everything else
|
||||
res_error(res, result.data);
|
||||
const json body = json::parse(req.body);
|
||||
|
||||
// for the shape of input/content, see tokenize_input_prompts()
|
||||
json prompt;
|
||||
if (body.count("input") != 0) {
|
||||
prompt = body.at("input");
|
||||
}
|
||||
else if (body.contains("content")) {
|
||||
oaicompat = OAICOMPAT_TYPE_NONE; // "content" field is not OAI compatible
|
||||
prompt = body.at("content");
|
||||
}
|
||||
else {
|
||||
res_err(res, format_error_response("\"input\" or \"content\" must be provided", ERROR_TYPE_INVALID_REQUEST));
|
||||
return;
|
||||
}
|
||||
|
||||
bool use_base64 = false;
|
||||
if (body.count("encoding_format") != 0) {
|
||||
const std::string& format = body.at("encoding_format");
|
||||
if (format == "base64") {
|
||||
use_base64 = true;
|
||||
}
|
||||
else if (format != "float") {
|
||||
res_err(res, format_error_response("The format to return the embeddings in. Can be either float or base64", ERROR_TYPE_INVALID_REQUEST));
|
||||
return;
|
||||
}
|
||||
}
|
||||
auto vocab = llama_get_vocab(ctx_server.ctx);
|
||||
auto tokenized_prompts = tokenize_input_prompts(vocab, ctx_server.mctx, prompt, true, true);
|
||||
for (const auto& tokens : tokenized_prompts) {
|
||||
// this check is necessary for models that do not add BOS token to the input
|
||||
if (tokens.empty()) {
|
||||
res_err(res, format_error_response("Input content cannot be empty", ERROR_TYPE_INVALID_REQUEST));
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
int embd_normalize = 2; // default to Euclidean/L2 norm
|
||||
if (body.count("embd_normalize") != 0) {
|
||||
embd_normalize = body.at("embd_normalize");
|
||||
if (llama_pooling_type(ctx_server.ctx) == LLAMA_POOLING_TYPE_NONE) {
|
||||
SRV_DBG("embd_normalize is not supported by pooling type %d, ignoring it\n", llama_pooling_type(ctx_server.ctx));
|
||||
}
|
||||
}
|
||||
|
||||
// create and queue the task
|
||||
json responses = json::array();
|
||||
server_response_reader rd(ctx_server);
|
||||
{
|
||||
std::vector<server_task> tasks;
|
||||
for (size_t i = 0; i < tokenized_prompts.size(); i++) {
|
||||
server_task task = server_task(SERVER_TASK_TYPE_EMBEDDING);
|
||||
|
||||
task.id = ctx_server.queue_tasks.get_new_id();
|
||||
task.index = i;
|
||||
task.tokens = std::move(tokenized_prompts[i]);
|
||||
|
||||
// OAI-compat
|
||||
task.params.oaicompat = oaicompat;
|
||||
task.params.embd_normalize = embd_normalize;
|
||||
task.embedding = true; // probably not needed
|
||||
tasks.push_back(std::move(task));
|
||||
}
|
||||
rd.post_tasks(std::move(tasks));
|
||||
}
|
||||
|
||||
// wait for the results
|
||||
auto all_results = rd.wait_for_all(req.is_connection_closed);
|
||||
|
||||
// collect results
|
||||
if (all_results.is_terminated) {
|
||||
llama_decode_stop();
|
||||
return; // connection is closed
|
||||
}
|
||||
else if (all_results.error) {
|
||||
res_err(res, all_results.error->to_json());
|
||||
return;
|
||||
}
|
||||
else {
|
||||
for (auto& res : all_results.results) {
|
||||
GGML_ASSERT(dynamic_cast<server_task_result_embd*>(res.get()) != nullptr);
|
||||
responses.push_back(res->to_json());
|
||||
}
|
||||
}
|
||||
|
||||
// write JSON response
|
||||
json root = is_openai
|
||||
? format_embeddings_response_oaicompat(body, responses, false)
|
||||
: responses[0];
|
||||
return res.set_content(root.dump(), "application/json; charset=utf-8");
|
||||
json root = oaicompat == OAICOMPAT_TYPE_EMBEDDING
|
||||
? format_embeddings_response_oaicompat(body, responses, use_base64)
|
||||
: json(responses);
|
||||
res_ok(res, root);
|
||||
|
||||
};
|
||||
|
||||
const auto handle_embeddings = [&handle_embeddings_impl](const httplib::Request& req, httplib::Response& res) {
|
||||
handle_embeddings_impl(req, res, OAICOMPAT_TYPE_NONE);
|
||||
};
|
||||
|
||||
const auto handle_embeddings_oai = [&handle_embeddings_impl](const httplib::Request& req, httplib::Response& res) {
|
||||
handle_embeddings_impl(req, res, OAICOMPAT_TYPE_EMBEDDING);
|
||||
};
|
||||
|
||||
|
||||
const auto handle_lora_adapters_list = [&](const httplib::Request & req, httplib::Response & res) {
|
||||
json result = json::array();
|
||||
for (size_t i = 0; i < ctx_server.lora_adapters.size(); ++i) {
|
||||
@@ -1277,6 +1467,7 @@ int main(int argc, char ** argv) {
|
||||
res.status = 200; // HTTP OK
|
||||
};
|
||||
|
||||
|
||||
const auto handle_lora_adapters_apply = [&](const httplib::Request & req, httplib::Response & res) {
|
||||
const std::vector<json> body = json::parse(req.body);
|
||||
int max_idx = ctx_server.lora_adapters.size();
|
||||
@@ -1718,7 +1909,7 @@ int main(int argc, char ** argv) {
|
||||
svr->Post("/infill", handle_infill);
|
||||
svr->Post("/embedding", handle_embeddings); // legacy
|
||||
svr->Post("/embeddings", handle_embeddings);
|
||||
svr->Post("/v1/embeddings", handle_embeddings);
|
||||
svr->Post("/v1/embeddings", handle_embeddings_oai);
|
||||
svr->Post("/tokenize", handle_tokenize);
|
||||
svr->Post("/detokenize", handle_detokenize);
|
||||
svr->Post("/apply-template", handle_apply_template);
|
||||
|
||||
Reference in New Issue
Block a user