server: stop processing the prompt when client disconnects (#1134)

Implement generator-based API for task results

Update httplib.h to 0.27.0

Fix embedding error

Stop prompt processing when disconnected

Co-authored-by: firecoperana <firecoperana>
firecoperana
2026-01-12 23:56:59 -06:00
committed by GitHub
parent d3e3ad40f9
commit 1a461525d5
24 changed files with 7654 additions and 4549 deletions

View File

@@ -421,7 +421,7 @@ int main(int argc, char ** argv) {
// int n_ctx = llama_n_ctx(ctx);
int n_layers = llama_n_layer(model);
int n_embd = llama_n_embd(model);
int n_embd = llama_model_n_embd(model);
// get model hint param (a.k.a model arch name)
char model_hint[128];
llama_model_meta_val_str(model, "general.architecture", model_hint, 128);

View File

@@ -72,7 +72,7 @@ static void batch_decode(llama_context * ctx, llama_batch & batch, float * outpu
}
float * out = output + embd_pos * n_embd;
llama_embd_normalize(embd, out, n_embd, embd_norm);
common_embd_normalize(embd, out, n_embd, embd_norm);
}
}
@@ -187,7 +187,7 @@ int main(int argc, char ** argv) {
}
// allocate output
const int n_embd = llama_n_embd(model);
const int n_embd = llama_model_n_embd(model);
std::vector<float> embeddings(n_embd_count * n_embd, 0);
float * emb = embeddings.data();
@@ -265,7 +265,7 @@ int main(int argc, char ** argv) {
fprintf(stdout, "\n");
for (int i = 0; i < n_prompts; i++) {
for (int j = 0; j < n_prompts; j++) {
float sim = llama_embd_similarity_cos(emb + i * n_embd, emb + j * n_embd, n_embd);
float sim = common_embd_similarity_cos(emb + i * n_embd, emb + j * n_embd, n_embd);
fprintf(stdout, "%6.2f ", sim);
}
fprintf(stdout, "%1.10s", prompts[i].c_str());
@@ -298,7 +298,7 @@ int main(int argc, char ** argv) {
for (int i = 0;;) { // at least two iteration (n_embd_count > 1)
fprintf(stdout, " [");
for (int j = 0;;) { // at least two iteration (n_embd_count > 1)
float sim = llama_embd_similarity_cos(emb + i * n_embd, emb + j * n_embd, n_embd);
float sim = common_embd_similarity_cos(emb + i * n_embd, emb + j * n_embd, n_embd);
fprintf(stdout, "%6.2f", sim);
j++;
if (j < n_embd_count) fprintf(stdout, ", "); else break;
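
The embd_norm argument taken by common_embd_normalize selects the norm; per the slot_params comment later in this diff, -1 = none, 0 = max absolute, 1 = taxicab, 2 = Euclidean/L2, >2 = p-norm. A rough standalone sketch of that selection (illustrative only; it simplifies the 0 case and is not the common library's implementation):

```cpp
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

// Illustrative sketch of the embd_norm selection described in this patch
// (-1 = none, 0 = max absolute, 1 = taxicab, 2 = Euclidean/L2, >2 = p-norm).
static void embd_normalize_sketch(const float * inp, float * out, int n, int embd_norm) {
    double sum = 0.0;
    if (embd_norm == -1) {
        sum = 1.0;                                   // no normalization, copy as-is
    } else if (embd_norm == 0) {
        for (int i = 0; i < n; i++) sum = std::max(sum, (double) std::fabs(inp[i]));
    } else {
        for (int i = 0; i < n; i++) sum += std::pow(std::fabs(inp[i]), embd_norm);
        sum = std::pow(sum, 1.0 / embd_norm);        // p-norm; p = 1 taxicab, p = 2 Euclidean
    }
    const double scale = sum > 0.0 ? 1.0 / sum : 0.0;
    for (int i = 0; i < n; i++) out[i] = (float) (inp[i] * scale);
}

int main() {
    std::vector<float> v = {3.0f, 4.0f}, out(2);
    embd_normalize_sketch(v.data(), out.data(), 2, 2); // L2 -> {0.6, 0.8}
    std::printf("%.2f %.2f\n", out[0], out[1]);
}
```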

View File

@@ -51,7 +51,7 @@ static std::vector<std::vector<float>> encode(llama_context * ctx, const std::ve
llama_decode(ctx, batch);
// get embedding dimensions
uint64_t n_embd = llama_n_embd(mdl);
uint64_t n_embd = llama_model_n_embd(mdl);
// allocate embedding output
std::vector<float> emb_unorm(n_embd, 0.0f);
@@ -74,7 +74,7 @@ static std::vector<std::vector<float>> encode(llama_context * ctx, const std::ve
}
std::vector<float> emb_norm(emb_unorm.size());
llama_embd_normalize(emb_unorm.data(), emb_norm.data(), n_embd);
common_embd_normalize(emb_unorm.data(), emb_norm.data(), n_embd);
result.push_back(emb_norm);
#ifdef GRIT_DEBUG
@@ -191,12 +191,12 @@ int main(int argc, char * argv[]) {
const std::vector<std::vector<float>> d_rep = encode(ctx, documents, gritlm_instruction(""));
const std::vector<std::vector<float>> q_rep = encode(ctx, queries, gritlm_instruction(instruction));
const int n_embd = llama_n_embd(mdl);
const int n_embd = llama_model_n_embd(mdl);
const float cosine_sim_q0_d0 = llama_embd_similarity_cos(q_rep[0].data(), d_rep[0].data(), n_embd);
const float cosine_sim_q0_d1 = llama_embd_similarity_cos(q_rep[0].data(), d_rep[1].data(), n_embd);
const float cosine_sim_q1_d0 = llama_embd_similarity_cos(q_rep[1].data(), d_rep[0].data(), n_embd);
const float cosine_sim_q1_d1 = llama_embd_similarity_cos(q_rep[1].data(), d_rep[1].data(), n_embd);
const float cosine_sim_q0_d0 = common_embd_similarity_cos(q_rep[0].data(), d_rep[0].data(), n_embd);
const float cosine_sim_q0_d1 = common_embd_similarity_cos(q_rep[0].data(), d_rep[1].data(), n_embd);
const float cosine_sim_q1_d0 = common_embd_similarity_cos(q_rep[1].data(), d_rep[0].data(), n_embd);
const float cosine_sim_q1_d1 = common_embd_similarity_cos(q_rep[1].data(), d_rep[1].data(), n_embd);
std::printf("Cosine similarity between \"%.50s\" and \"%.50s\" is: %.3f\n", queries[0].c_str(), documents[0].c_str(), cosine_sim_q0_d0);
std::printf("Cosine similarity between \"%.50s\" and \"%.50s\" is: %.3f\n", queries[0].c_str(), documents[1].c_str(), cosine_sim_q0_d1);

View File

@@ -106,7 +106,7 @@ static void batch_decode(llama_context * ctx, llama_batch & batch, float * outpu
}
float * out = output + batch.seq_id[i][0] * n_embd;
llama_embd_normalize(embd, out, n_embd);
common_embd_normalize(embd, out, n_embd);
}
}
@@ -215,7 +215,7 @@ int main(int argc, char ** argv) {
struct llama_batch batch = llama_batch_init(n_batch, 0, 1);
// allocate output
const int n_embd = llama_n_embd(model);
const int n_embd = llama_model_n_embd(model);
std::vector<float> embeddings(n_chunks * n_embd, 0);
float * emb = embeddings.data();
@@ -272,7 +272,7 @@ int main(int argc, char ** argv) {
{
std::vector<std::pair<int, float>> similarities;
for (int i = 0; i < n_chunks; i++) {
float sim = llama_embd_similarity_cos(chunks[i].embedding.data(), query_emb.data(), n_embd);
float sim = common_embd_similarity_cos(chunks[i].embedding.data(), query_emb.data(), n_embd);
similarities.push_back(std::make_pair(i, sim));
}

View File

@@ -12,7 +12,7 @@ endif()
set(TARGET_SRCS
server.cpp
httplib.h
# httplib.h
server-task.cpp
server-task.h
server-queue.cpp
@@ -78,7 +78,7 @@ target_include_directories(${TARGET} PRIVATE ${CMAKE_SOURCE_DIR})
target_link_libraries(${TARGET} PRIVATE common ${CMAKE_THREAD_LIBS_INIT})
target_include_directories(${TARGET} PRIVATE ../mtmd)
target_link_libraries(${TARGET} PRIVATE common mtmd ${CMAKE_THREAD_LIBS_INIT})
target_link_libraries(${TARGET} PRIVATE common mtmd cpp-httplib ${CMAKE_THREAD_LIBS_INIT})
if (LLAMA_SERVER_SSL)
find_package(OpenSSL REQUIRED)

File diff suppressed because it is too large

View File

@@ -478,15 +478,31 @@ json probs_vector_to_json(const llama_context* ctx, const std::vector<completion
return out;
}
// note: if data is a json array, it will be sent as multiple events, one per item
bool server_sent_event(httplib::DataSink& sink, const json& data) {
const std::string str =
"data: " +
data.dump(-1, ' ', false, json::error_handler_t::replace) +
"\n\n"; // required by RFC 8895 - A message is terminated by a blank line (two line terminators in a row).
static auto send_single = [](httplib::DataSink& sink, const json& data) -> bool {
const std::string str =
"data: " +
data.dump(-1, ' ', false, json::error_handler_t::replace) +
"\n\n"; // required by RFC 8895 - A message is terminated by a blank line (two line terminators in a row).
//LOG_VERBOSE("data stream, to_send: %s", str.c_str());
LOG_DBG("data stream, to_send: %s", str.c_str());
return sink.write(str.c_str(), str.size());
};
return sink.write(str.c_str(), str.size());
if (data.is_array()) {
for (const auto& item : data) {
if (!send_single(sink, item)) {
return false;
}
}
}
else {
return send_single(sink, data);
}
return true;
}
bool server_sent_anthropic_event(httplib::DataSink& sink, const json& data) {
@@ -2197,3 +2213,7 @@ bool prompt_cache_equal(llama_context* ctx, const server_tokens& cache_tokens,
bool equal = common_cache == common_prompt;
return equal;
}
std::string safe_json_to_str(const json& data) {
return data.dump(-1, ' ', false, json::error_handler_t::replace);
}
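
As the new comment above says, an array payload is now emitted as one SSE event per element, each framed as a `data: <json>` line terminated by a blank line. A self-contained sketch of just that framing, leaving out the httplib sink:

```cpp
#include <cstdio>
#include <string>
#include <nlohmann/json.hpp>

using json = nlohmann::ordered_json;

// Sketch of the framing used by server_sent_event(): every JSON value becomes a
// "data: ..." line plus a blank line; arrays yield one event per item.
static std::string to_sse_events(const json & data) {
    std::string out;
    auto frame_one = [&](const json & item) {
        out += "data: " + item.dump(-1, ' ', false, json::error_handler_t::replace) + "\n\n";
    };
    if (data.is_array()) {
        for (const auto & item : data) frame_one(item);
    } else {
        frame_one(data);
    }
    return out;
}

int main() {
    json arr = json::array();
    arr.push_back({ {"delta", "Hel"} });
    arr.push_back({ {"delta", "lo"} });
    std::printf("%s", to_sse_events(arr).c_str()); // two "data:" events, each followed by a blank line
}
```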

View File

@@ -27,15 +27,15 @@
#include <random>
#include <set>
// increase max payload length to allow use of larger context size
#define CPPHTTPLIB_FORM_URL_ENCODED_PAYLOAD_MAX_LENGTH 1048576
// increase backlog size to avoid connection resets for >> 1 slots
#define CPPHTTPLIB_LISTEN_BACKLOG 512
// increase max URI length to handle longer prompts in query string
#define CPPHTTPLIB_REQUEST_URI_MAX_LENGTH 32768
// disable Nagle's algorithm
#define CPPHTTPLIB_TCP_NODELAY true
#include "httplib.h"
//// increase max payload length to allow use of larger context size
//#define CPPHTTPLIB_FORM_URL_ENCODED_PAYLOAD_MAX_LENGTH 1048576
//// increase backlog size to avoid connection resets for >> 1 slots
//#define CPPHTTPLIB_LISTEN_BACKLOG 512
//// increase max URI length to handle longer prompts in query string
//#define CPPHTTPLIB_REQUEST_URI_MAX_LENGTH 32768
//// disable Nagle's algorithm
//#define CPPHTTPLIB_TCP_NODELAY true
#include <cpp-httplib/httplib.h>
#define DEFAULT_OAICOMPAT_MODEL "gpt-3.5-turbo-0613"
@@ -459,3 +459,6 @@ void print_files_info(const std::vector<raw_buffer>& files);
bool prompt_cache_equal(llama_context* ctx, const server_tokens& cache_tokens,
const server_tokens& prompt_tokens, size_t start, const common_prefix& prefix);
std::string safe_json_to_str(const json& data);
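
safe_json_to_str dumps with json::error_handler_t::replace, so strings containing invalid UTF-8 are serialized with U+FFFD replacement characters instead of throwing. A small usage sketch:

```cpp
#include <cstdio>
#include <string>
#include <nlohmann/json.hpp>

using json = nlohmann::ordered_json;

int main() {
    // "\xff" is not valid UTF-8; with error_handler_t::replace the dump substitutes
    // U+FFFD instead of throwing, which keeps SSE/HTTP responses well-formed.
    json j = { {"content", std::string("ok\xff")} };
    std::string s = j.dump(-1, ' ', false, json::error_handler_t::replace);
    std::printf("%s\n", s.c_str());
}
```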

View File

@@ -316,6 +316,7 @@ void server_slot::reset() {
stopped_limit = false;
stopping_word = "";
n_past = 0;
n_past_prompt = 0;
n_sent_text = 0;
drafted.clear();
@@ -402,7 +403,9 @@ void server_slot::release() {
t_token_generation = (ggml_time_us() - t_start_generation) / 1e3;
command = SLOT_COMMAND_RELEASE;
task.reset();
llama_decode_reset();
}
}
@@ -803,6 +806,8 @@ bool server_context::launch_slot_with_task(server_slot& slot, server_task& task)
slot.oaicompat = false;
slot.oaicompat_model = "";
}
slot.params.oaicompat = task.params.oaicompat;
slot.params.oaicompat_cmpl_id =task.params.oaicompat_cmpl_id;
slot.params.timings_per_token = json_value(data, "timings_per_token", false);
slot.params.stream = json_value(data, "stream", false);
auto stream_opt = json_value(data, "stream_options", json::object());
@@ -927,29 +932,6 @@ bool server_context::launch_slot_with_task(server_slot& slot, server_task& task)
// get prompt
if (!task.infill) {
// maybe not needed since prompt has been tokenized?
const auto& prompt = data.find("prompt");
if (!slot.prompt_tokens.validate(ctx)) {
send_error(task, "Prompt contains invalid tokens", ERROR_TYPE_INVALID_REQUEST);
return false;
}
if (prompt == data.end()) {
send_error(task, "\"prompt\" must be provided", ERROR_TYPE_INVALID_REQUEST);
return false;
}
if ((prompt->is_string()) ||
(prompt->is_array() && prompt->size() == 1 && prompt->at(0).is_string()) ||
(prompt->is_array() && !prompt->empty() && prompt->at(0).is_number_integer())) {
slot.prompt = *prompt;
}
else if (prompt->is_array() && prompt->size() == 1 && prompt->at(0).is_array()) {
slot.prompt = prompt->at(0);
}
else {
send_error(task, "\"prompt\" must be a string or an array of integers", ERROR_TYPE_INVALID_REQUEST);
return false;
}
slot.prompt_tokens = std::move(task.tokens);
}
@@ -1165,7 +1147,7 @@ bool server_context::launch_slot_with_task(server_slot& slot, server_task& task)
{"id_slot", slot.id},
{"id_task", slot.id_task},
});
slot.task = std::make_unique<const server_task>(std::move(task));
return true;
}
@@ -1490,66 +1472,66 @@ bool server_context::ensure_no_mtmd(const int id_task) {
}
void server_context::send_partial_response(server_slot& slot, completion_token_output tkn) {
server_task_result res;
res.final_result = false;
res.id = slot.id_task;
res.id_multi = slot.id_multi;
res.error = false;
res.stop = false;
res.stream = slot.params.stream;
res.content = tkn.text_to_send;
res.post_sampling_probs = slot.params.post_sampling_probs;
res.oaicompat = slot.params.oaicompat;
res.oaicompat_model = slot.params.oaicompat_model;
res.oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id;
res.n_decoded = slot.n_decoded;
res.n_prompt_tokens = slot.n_prompt_tokens;
res.data = json{
auto res = std::make_unique<server_task_result_cmpl_partial>();
res->final_result = false;
res->id = slot.id_task;
res->id_multi = slot.id_multi;
res->error = false;
res->stop = false;
res->stream = slot.params.stream;
res->content = tkn.text_to_send;
res->post_sampling_probs = slot.params.post_sampling_probs;
res->oaicompat = slot.params.oaicompat;
res->oaicompat_model = slot.params.oaicompat_model;
res->oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id;
res->n_decoded = slot.n_decoded;
res->n_prompt_tokens = slot.n_prompt_tokens;
res->data = json{
{"content", tkn.text_to_send},
{"stop", false},
{"id_slot", slot.id},
{"multimodal", false}
};
slot.update_chat_msg(res.oaicompat_msg_diffs);
slot.update_chat_msg(res->oaicompat_msg_diffs);
// populate res.probs_output
// populate res->probs_output
if (slot.sparams.n_probs > 0) {
res.probs_output = { tkn }; // copy the token probs
res.data["completion_probabilities"] = probs_vector_to_json(ctx, res.probs_output);
res->probs_output = { tkn }; // copy the token probs
res->data["completion_probabilities"] = probs_vector_to_json(ctx, res->probs_output);
}
if (slot.oaicompat) {
res.data["oaicompat_token_ctr"] = slot.n_decoded;
res.data["model"] = slot.oaicompat_model;
res->data["oaicompat_token_ctr"] = slot.n_decoded;
res->data["model"] = slot.oaicompat_model;
}
// populate timings if this is final response or timings_per_token is enabled
if (slot.params.timings_per_token) {
res.timings = slot.get_timings();
res->timings = slot.get_timings();
}
queue_results.send(std::move(res));
}
void server_context::send_final_response(server_slot& slot) {
server_task_result res;
res.final_result = true;
res.id = slot.id_task;
res.id_multi = slot.id_multi;
res.error = false;
res.stop = true; // to do: set value
res.stream = slot.params.stream;
res.include_usage = slot.params.include_usage;
res.content = slot.generated_text;
res.timings = slot.get_timings();
res.post_sampling_probs = slot.params.post_sampling_probs;
res.oaicompat = slot.params.oaicompat;
res.oaicompat_model = slot.params.oaicompat_model;
res.oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id;
res.oaicompat_msg = slot.update_chat_msg(res.oaicompat_msg_diffs);
res.n_decoded = slot.n_decoded;
res.n_prompt_tokens = slot.n_prompt_tokens;
res.oaicompat_model = slot.oaicompat_model;
res.data = json{
auto res = std::make_unique<server_task_result_cmpl_final>();
res->final_result = true;
res->id = slot.id_task;
res->id_multi = slot.id_multi;
res->error = false;
res->stop = true; // to do: set value
res->stream = slot.params.stream;
res->include_usage = slot.params.include_usage;
res->content = slot.generated_text;
res->timings = slot.get_timings();
res->post_sampling_probs = slot.params.post_sampling_probs;
res->oaicompat = slot.params.oaicompat;
res->oaicompat_model = slot.params.oaicompat_model;
res->oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id;
res->oaicompat_msg = slot.update_chat_msg(res->oaicompat_msg_diffs);
res->n_decoded = slot.n_decoded;
res->n_prompt_tokens = slot.n_prompt_tokens;
res->oaicompat_model = slot.oaicompat_model;
res->data = json{
{"content", !slot.params.stream ? slot.generated_text : ""},
{"generated_text", slot.generated_text}, // Always include full text for finish_reason logic
{"id_slot", slot.id},
@@ -1569,30 +1551,29 @@ void server_context::send_final_response(server_slot& slot) {
//{"oaicompat_chat_format", slot.params.oaicompat_chat_format},
};
// populate res.probs_output
// populate res->probs_output
if (slot.sparams.n_probs > 0) {
res.probs_output = std::vector<completion_token_output>(
res->probs_output = std::vector<completion_token_output>(
slot.generated_token_probs.begin(),
slot.generated_token_probs.end());
res.data["completion_probabilities"] = probs_vector_to_json(ctx, res.probs_output);
res->data["completion_probabilities"] = probs_vector_to_json(ctx, res->probs_output);
}
if (slot.oaicompat) {
res.data["oaicompat_token_ctr"] = slot.n_decoded;
res.data["model"] = slot.oaicompat_model;
res->data["oaicompat_token_ctr"] = slot.n_decoded;
res->data["model"] = slot.oaicompat_model;
}
queue_results.send(std::move(res));
}
void server_context::send_embedding(const server_slot& slot, const llama_batch& batch) {
server_task_result res;
res.id = slot.id_task;
res.id_multi = slot.id_multi;
res.error = false;
res.stop = true;
auto res = std::make_unique<server_task_result_embd>();
res->id = slot.task->id;
res->n_tokens = slot.prompt_tokens.size();
res->oaicompat = slot.task->params.oaicompat;
const int n_embd = llama_n_embd(model);
const int n_embd = llama_model_n_embd(model);
std::vector<float> embd_res(n_embd, 0.0f);
@@ -1601,34 +1582,31 @@ void server_context::send_embedding(const server_slot& slot, const llama_batch&
continue;
}
const float* embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]);
if (embd == NULL) {
const float* embd = nullptr;
if (llama_pooling_type(slot.ctx) == LLAMA_POOLING_TYPE_NONE) {
embd = llama_get_embeddings_ith(ctx, i);
}
else {
embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]);
}
if (embd == NULL) {
LOG_ERROR("failed to get embeddings", {
{"token", batch.token[i]},
{"seq_id", batch.seq_id[i][0]}
});
res.data = json{
{"embedding", std::vector<float>(n_embd, 0.0f)},
{"tokens_evaluated", slot.n_prompt_tokens},
};
if (embd == nullptr) {
SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", batch.token[i], batch.seq_id[i][0]);
res->embedding.push_back(std::vector<float>(n_embd, 0.0f));
continue;
}
llama_embd_normalize(embd, embd_res.data(), n_embd);
// normalize only when there is pooling
if (llama_pooling_type(slot.ctx) != LLAMA_POOLING_TYPE_NONE) {
common_embd_normalize(embd, embd_res.data(), n_embd, slot.task->params.embd_normalize);
res->embedding.push_back(embd_res);
break;
}
res.data = json{
{"embedding", embd_res},
{"tokens_evaluated", slot.n_prompt_tokens},
};
res->embedding.emplace_back(embd, embd + n_embd);
}
queue_results.send(res);
queue_results.send(std::move(res));
}
void server_context::request_completion(int id_task, int id_multi, json data, bool infill, bool embedding, server_tokens&& inputs) {
@@ -1708,6 +1686,9 @@ void server_context::split_multiprompt_task(int id_multi, server_task& multiprom
void server_context::process_single_task(server_task&& task) {
switch (task.type) {
case SERVER_TASK_TYPE_COMPLETION:
case SERVER_TASK_TYPE_INFILL:
case SERVER_TASK_TYPE_EMBEDDING:
case SERVER_TASK_TYPE_RERANK:
{
const int id_slot = json_value(task.data, "id_slot", -1);
@@ -2471,6 +2452,10 @@ void server_context::update_slots() {
// keep only the common part
// remove the non-common part from the cache
if (slot.n_past < 0)
{
slot.n_past = 0;
}
slot.cache_tokens.keep_first(slot.n_past);
int p0 = (int)system_tokens.size() + slot.n_past;
p0 = system_tokens.size() + slot.cache_tokens.pos_next();
@@ -2549,7 +2534,7 @@ void server_context::update_slots() {
}
int p0 = system_tokens.size() + slot.cache_tokens.pos_next();
llama_batch_add(batch, cur_tok, p0, { slot.id }, false);
llama_batch_add(batch, cur_tok, p0, { slot.id }, slot.embedding);
slot.cache_tokens.push_back(cur_tok);
@@ -2663,18 +2648,31 @@ void server_context::update_slots() {
if (ret != 0) {
if (n_batch == 1 || ret < 0) {
int user_cancel = -3;
// if you get here, it means the KV cache is full - try increasing it via the context size
LOG_ERROR("failed to decode the batch: KV cache is full - try increasing it via the context size", {
{"i", i},
{"n_batch", ret},
{"ret", ret},
});
if (ret == user_cancel) {
LOG_ERROR("Decode process is cancelled by user", {
{"i", i},
{"n_batch", ret},
{"ret", ret},
});
} else {
LOG_ERROR("failed to decode the batch: KV cache is full - try increasing it via the context size", {
{"i", i},
{"n_batch", ret},
{"ret", ret},
});
}
for (auto& slot : slots) {
slot.state = SLOT_STATE_PROCESSING;
slot.command = SLOT_COMMAND_NONE;
slot.release();
LLAMA_LOG_INFO("n_past = %d\n", (int)slot.cache_tokens.size());
send_error(slot, "Input prompt is too big compared to KV size. Please try increasing KV size.");
if (ret != user_cancel) {
LLAMA_LOG_INFO("n_past = %d\n", (int)slot.cache_tokens.size());
send_error(slot, "Input prompt is too big compared to KV size. Please try increasing KV size.");
}
}
break; // break loop of n_batch
}
@@ -2818,7 +2816,7 @@ json server_context::model_meta() const {
{"vocab_type", llama_vocab_type(model)},
{"n_vocab", llama_n_vocab(model)},
{"n_ctx_train", llama_n_ctx_train(model)},
{"n_embd", llama_n_embd(model)},
{"n_embd", llama_model_n_embd(model)},
{"n_params", llama_model_n_params(model)},
{"size", llama_model_size(model)},
};

View File

@@ -22,43 +22,6 @@ enum slot_command {
SLOT_COMMAND_RELEASE,
};
struct slot_params {
bool stream = true;
bool include_usage = false;
bool cache_prompt = true; // remember the prompt to avoid reprocessing all prompt
int32_t n_keep = 0; // number of tokens to keep from initial prompt
int32_t n_discard = 0; // number of tokens after n_keep that may be discarded when shifting context, 0 defaults to half
int32_t n_predict = -1; // new tokens to predict
thinking_tokens think_tokens;
std::vector<std::string> antiprompt;
bool timings_per_token = false;
bool post_sampling_probs = false;
json input_prefix;
json input_suffix;
// speculative decoding parameters
struct {
int n_max = 16; // max drafted tokens
int n_min = 0; // min drafted tokens to accept
float p_min = 0.75f; // min probability required to accept a token in the draft
} speculative;
// OAI-compat fields
oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE;
std::string oaicompat_model;
std::string oaicompat_cmpl_id;
common_chat_syntax oaicompat_chat_syntax;
};
struct server_slot {
int id;
int id_task = -1;

View File

@@ -28,6 +28,42 @@ int server_queue::post(server_task task) {
return task.id;
}
void server_queue::cleanup_pending_task(int id_target) {
// no need lock because this is called exclusively by post()
auto rm_func = [id_target](const server_task& task) {
return task.id == id_target;
};
queue_tasks.erase(
std::remove_if(queue_tasks.begin(), queue_tasks.end(), rm_func),
queue_tasks.end());
queue_tasks_deferred.erase(
std::remove_if(queue_tasks_deferred.begin(), queue_tasks_deferred.end(), rm_func),
queue_tasks_deferred.end());
}
// multi-task version of post()
int server_queue::post(std::vector<server_task>&& tasks, bool front) {
std::unique_lock<std::mutex> lock(mutex_tasks);
for (auto& task : tasks) {
if (task.id == -1) {
task.id = id++;
}
// if this is cancel task make sure to clean up pending tasks
if (task.type == SERVER_TASK_TYPE_CANCEL) {
cleanup_pending_task(task.id_target);
}
QUE_DBG("new task, id = %d/%d, front = %d\n", task.id, (int)tasks.size(), front);
if (front) {
queue_tasks.push_front(std::move(task));
}
else {
queue_tasks.push_back(std::move(task));
}
}
condition_tasks.notify_one();
return 0;
}
void server_queue::defer(server_task&& task) {
std::unique_lock<std::mutex> lock(mutex_tasks);
queue_tasks_deferred.push_back(std::move(task));
@@ -68,7 +104,7 @@ void server_queue::start_loop() {
break;
}
server_task task = std::move(queue_tasks.front());
queue_tasks.erase(queue_tasks.begin());
queue_tasks.pop_front();
lock.unlock();
//LOG_VERBOSE("callback_new_task", { {"id_task", task.id} });
callback_new_task(std::move(task));
@@ -134,13 +170,21 @@ void server_queue::update_multitask(int id_multi, int id_sub, server_task_result
void server_response::add_waiting_task_id(int id_task) {
//LOG_VERBOSE("waiting for task id", { {"id_task", id_task} });
QUE_DBG("waiting for task id, id = %d\n", id_task);
SRV_DBG("add task %d to waiting list. current waiting = %d (before add)\n", id_task, (int)waiting_task_ids.size());
std::unique_lock<std::mutex> lock(mutex_results);
waiting_task_ids.insert(id_task);
}
void server_response::add_waiting_tasks(const std::vector<server_task>& tasks) {
std::unique_lock<std::mutex> lock(mutex_results);
for (const auto& task : tasks) {
SRV_DBG("add task %d to waiting list. current waiting = %d (before add)\n", task.id, (int)waiting_task_ids.size());
waiting_task_ids.insert(task.id);
}
}
void server_response::remove_waiting_task_id(int id_task) {
//LOG_VERBOSE("remove waiting for task id", { {"id_task", id_task} });
QUE_DBG("remove waiting for task id, id = %d\n", id_task);
@@ -153,14 +197,14 @@ server_task_result server_response::recv(int id_task) {
while (true) {
std::unique_lock<std::mutex> lock(mutex_results);
condition_results.wait(lock, [&] {
return !queue_results.empty();
return !queue_results_legacy.empty();
});
for (int i = 0; i < (int)queue_results.size(); i++) {
if (queue_results[i].id == id_task) {
assert(queue_results[i].id_multi == -1);
server_task_result res = queue_results[i];
queue_results.erase(queue_results.begin() + i);
for (int i = 0; i < (int)queue_results_legacy.size(); i++) {
if (queue_results_legacy[i].id == id_task) {
assert(queue_results_legacy[i].id_multi == -1);
server_task_result res = queue_results_legacy[i];
queue_results_legacy.erase(queue_results_legacy.begin() + i);
return res;
}
}
@@ -169,6 +213,41 @@ server_task_result server_response::recv(int id_task) {
// should never reach here
}
// same as recv(), but have timeout in seconds
// if timeout is reached, nullptr is returned
server_task_result_ptr server_response::recv_with_timeout(const std::unordered_set<int>& id_tasks, int timeout) {
while (true) {
std::unique_lock<std::mutex> lock(mutex_results);
for (int i = 0; i < (int)queue_results.size(); i++) {
if (id_tasks.find(queue_results[i]->id) != id_tasks.end()) {
server_task_result_ptr res = std::move(queue_results[i]);
queue_results.erase(queue_results.begin() + i);
return res;
}
}
std::cv_status cr_res = condition_results.wait_for(lock, std::chrono::seconds(timeout));
if (!running) {
SRV_DBG("%s : queue result stop\n", __func__);
std::terminate(); // we cannot return here since the caller is HTTP code
}
if (cr_res == std::cv_status::timeout) {
return nullptr;
}
}
// should never reach here
}
void server_response::remove_waiting_task_ids(const std::unordered_set<int>& id_tasks) {
std::unique_lock<std::mutex> lock(mutex_results);
for (const auto& id_task : id_tasks) {
SRV_DBG("remove task %d from waiting list. current waiting = %d (before remove)\n", id_task, (int)waiting_task_ids.size());
waiting_task_ids.erase(id_task);
}
}
void server_response::send(server_task_result result) {
//LOG_VERBOSE("send new result", { {"id_task", result.id} });
QUE_DBG("send new result, id = %d\n", result.id);
@@ -184,9 +263,25 @@ void server_response::send(server_task_result result) {
}
if (result.id == id_task) {
//LOG_VERBOSE("queue_results.push_back", { {"id_task", id_task} });
//LOG_VERBOSE("queue_results_legacy.push_back", { {"id_task", id_task} });
QUE_DBG("queue_results.push_back, id = %d\n", id_task);
queue_results.push_back(result);
queue_results_legacy.push_back(std::move(result));
condition_results.notify_all();
return;
}
}
}
// Send a new result to a waiting id_task
void server_response::send(server_task_result_ptr&& result) {
SRV_DBG("sending result for task id = %d\n", result->id);
std::unique_lock<std::mutex> lock(mutex_results);
for (const auto& id_task : waiting_task_ids) {
if (result->id == id_task) {
SRV_DBG("task id = %d pushed to result queue\n", result->id);
queue_results.emplace_back(std::move(result));
condition_results.notify_all();
return;
}

View File

@@ -19,8 +19,8 @@ struct server_queue {
bool running;
// queues
std::vector<server_task> queue_tasks;
std::vector<server_task> queue_tasks_deferred;
std::deque<server_task> queue_tasks;
std::deque<server_task> queue_tasks_deferred;
std::vector<server_task_multi> queue_multitasks;
@@ -36,6 +36,10 @@ struct server_queue {
// Add a new task to the end of the queue
int post(server_task task);
int post(std::vector<server_task>&& tasks, bool front = false);
void cleanup_pending_task(int id_target);
// Add a new task, but defer until one slot is available
void defer(server_task&& task);
@@ -89,11 +93,14 @@ struct server_response {
typedef std::function<void(int, int, server_task_result&)> callback_multitask_t;
callback_multitask_t callback_update_multitask;
bool running = true;
// for keeping track of all tasks waiting for the result
std::set<int> waiting_task_ids;
// the main result queue
std::vector<server_task_result> queue_results;
// the main result queue (using ptr for polymorphism)
std::vector<server_task_result_ptr> queue_results;
std::vector<server_task_result> queue_results_legacy;
std::mutex mutex_results;
std::condition_variable condition_results;
@@ -101,12 +108,20 @@ struct server_response {
// add the id_task to the list of tasks waiting for response
void add_waiting_task_id(int id_task);
void add_waiting_tasks(const std::vector<server_task>& tasks);
// when the request is finished, we can remove task associated with it
void remove_waiting_task_id(int id_task);
void remove_waiting_task_ids(const std::unordered_set<int>& id_tasks);
// This function blocks the thread until there is a response for this id_task
server_task_result recv(int id_task);
// same as recv(), but have timeout in seconds
// if timeout is reached, nullptr is returned
server_task_result_ptr recv_with_timeout(const std::unordered_set<int>& id_tasks, int timeout);
// Register the function to update multitask
void on_multitask_update(callback_multitask_t callback) {
callback_update_multitask = std::move(callback);
@@ -114,4 +129,12 @@ struct server_response {
// Send a new result to a waiting id_task
void send(server_task_result result);
void send(server_task_result_ptr&& result);
// terminate the waiting loop
void terminate() {
running = false;
condition_results.notify_all();
};
};
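
recv_with_timeout() plus a periodic should_stop() check is what lets a handler notice a closed connection between results. A generic, self-contained sketch of that polling pattern (stand-in types, not the server's own structs):

```cpp
#include <chrono>
#include <condition_variable>
#include <cstdio>
#include <deque>
#include <functional>
#include <memory>
#include <mutex>
#include <string>

// Generic sketch of the recv_with_timeout()/next() pattern: wait up to `timeout_s`
// seconds for a result; on timeout return nullptr so the caller can re-check a
// should_stop() predicate (e.g. "is the HTTP sink still writable?").
struct result_queue_sketch {
    std::deque<std::unique_ptr<std::string>> results;
    std::mutex mtx;
    std::condition_variable cv;

    std::unique_ptr<std::string> recv_with_timeout(int timeout_s) {
        std::unique_lock<std::mutex> lock(mtx);
        if (!cv.wait_for(lock, std::chrono::seconds(timeout_s), [&] { return !results.empty(); })) {
            return nullptr; // timed out with nothing queued
        }
        auto res = std::move(results.front());
        results.pop_front();
        return res;
    }

    std::unique_ptr<std::string> next(const std::function<bool()> & should_stop, int timeout_s = 1) {
        while (true) {
            auto res = recv_with_timeout(timeout_s);
            if (res) return res;
            if (should_stop()) return nullptr; // e.g. client disconnected, stop waiting
        }
    }
};

int main() {
    result_queue_sketch q;
    const bool disconnected = true; // pretend the client already went away
    auto res = q.next([&] { return disconnected; });
    std::printf("%s\n", res ? res->c_str() : "no result: caller stopped waiting");
}
```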

View File

@@ -1,6 +1,5 @@
#include "server-task.h"
json result_timings::to_json() const {
json base = {
{"prompt_n", prompt_n},
@@ -26,83 +25,63 @@ json result_timings::to_json() const {
}
json server_task_result::to_json_final() {
switch (oaicompat) {
case OAICOMPAT_TYPE_NONE:
return to_json_non_oaicompat_final();
case OAICOMPAT_TYPE_COMPLETION:
return to_json_oaicompat_final();
case OAICOMPAT_TYPE_CHAT:
return stream ? to_json_oaicompat_chat_stream() : to_json_oaicompat_chat_final();
case OAICOMPAT_TYPE_ANTHROPIC:
return stream ? to_json_anthropic_stream() : to_json_anthropic_final();
default:
GGML_ASSERT(false && "Invalid oaicompat_type");
}
}
//json server_task_result_cmpl_partial::to_json_non_oaicompat_partial() {
// // non-OAI-compat JSON
// json res = json{
// {"index", index},
// {"content", content},
// {"tokens", tokens},
// {"stop", false},
// {"id_slot", id_multi},
// {"tokens_predicted", n_decoded},
// {"tokens_evaluated", n_prompt_tokens},
// };
// // populate the timings object when needed (usually for the last response or with timings_per_token enabled)
// if (timings.prompt_n > 0) {
// res.push_back({ "timings", timings.to_json() });
// }
// if (!probs_output.empty()) {
// res["completion_probabilities"] = completion_token_output::probs_vector_to_json(probs_output, post_sampling_probs);
// }
// return res;
//}
json server_task_result::to_json_partial() {
switch (oaicompat) {
case OAICOMPAT_TYPE_NONE:
return to_json_non_oaicompat_partial();
case OAICOMPAT_TYPE_COMPLETION:
return to_json_oaicompat_partial();
case OAICOMPAT_TYPE_CHAT:
return to_json_oaicompat_chat_partial();
case OAICOMPAT_TYPE_ANTHROPIC:
return to_json_anthropic_partial();
default:
GGML_ASSERT(false && "Invalid oaicompat_type");
}
}
//json server_task_result_cmpl_final::to_json_non_oaicompat_final() {
// json res = json{
// {"index", index},
// {"content", stream ? "" : content}, // in stream mode, content is already in last partial chunk
// {"tokens", stream ? std::vector<llama_token> {} : tokens},
// {"id_slot", id_multi},
// {"stop", true},
// {"model", oaicompat_model},
// {"tokens_predicted", n_decoded},
// {"tokens_evaluated", n_prompt_tokens},
// //{"generation_settings", default_generation_settings_for_props.to_json()},
// {"prompt", prompt},
// {"has_new_line", has_new_line},
// {"truncated", truncated},
// //{"stop_type", stop_type_to_str(STOP_TYPE_EOS)},
// {"stopping_word", stopping_word},
// {"tokens_cached", n_tokens_cached},
// {"timings", timings.to_json()},
// };
// if (!stream && !probs_output.empty()) {
// res["completion_probabilities"] = completion_token_output::probs_vector_to_json(probs_output, post_sampling_probs);
// }
// return response_fields.empty() ? res : json_get_nested_values(response_fields, res);
//}
json server_task_result::to_json_non_oaicompat_partial() {
json server_task_result_cmpl_partial::to_json_non_oaicompat_partial() {
// non-OAI-compat JSON
json res = json{
{"index", index},
{"content", content},
{"tokens", tokens},
{"stop", false},
{"id_slot", id_multi},
{"tokens_predicted", n_decoded},
{"tokens_evaluated", n_prompt_tokens},
};
// populate the timings object when needed (usually for the last response or with timings_per_token enabled)
if (timings.prompt_n > 0) {
res.push_back({ "timings", timings.to_json() });
}
if (!probs_output.empty()) {
res["completion_probabilities"] = completion_token_output::probs_vector_to_json(probs_output, post_sampling_probs);
}
return res;
return data;
}
json server_task_result::to_json_non_oaicompat_final() {
json res = json{
{"index", index},
{"content", stream ? "" : content}, // in stream mode, content is already in last partial chunk
{"tokens", stream ? std::vector<llama_token> {} : tokens},
{"id_slot", id_multi},
{"stop", true},
{"model", oaicompat_model},
{"tokens_predicted", n_decoded},
{"tokens_evaluated", n_prompt_tokens},
//{"generation_settings", default_generation_settings_for_props.to_json()},
{"prompt", prompt},
{"has_new_line", has_new_line},
{"truncated", truncated},
//{"stop_type", stop_type_to_str(STOP_TYPE_EOS)},
{"stopping_word", stopping_word},
{"tokens_cached", n_tokens_cached},
{"timings", timings.to_json()},
};
if (!stream && !probs_output.empty()) {
res["completion_probabilities"] = completion_token_output::probs_vector_to_json(probs_output, post_sampling_probs);
}
return response_fields.empty() ? res : json_get_nested_values(response_fields, res);
json server_task_result_cmpl_final::to_json_non_oaicompat_final() {
// non-OAI-compat JSON
return data;
}
json server_task_result::to_json_oaicompat_partial() {
json server_task_result_cmpl_partial::to_json_oaicompat_partial() {
std::time_t t = std::time(0);
json logprobs = json(nullptr); // OAI default to null
if (probs_output.size() > 0) {
@@ -141,7 +120,7 @@ json server_task_result::to_json_oaicompat_partial() {
return res;
}
json server_task_result::to_json_oaicompat_final() {
json server_task_result_cmpl_final::to_json_oaicompat_final() {
std::time_t t = std::time(0);
json logprobs = json(nullptr); // OAI default to null
if (!stream && probs_output.size() > 0) {
@@ -184,7 +163,7 @@ json server_task_result::to_json_oaicompat_final() {
return res;
}
json server_task_result::to_json_oaicompat_chat_partial() {
json server_task_result_cmpl_partial::to_json_oaicompat_chat_partial() {
bool first = n_decoded == 1;
std::time_t t = std::time(0);
json choices;
@@ -239,7 +218,7 @@ json server_task_result::to_json_oaicompat_chat_partial() {
return deltas;
}
json server_task_result::to_json_oaicompat_chat_final() {
json server_task_result_cmpl_final::to_json_oaicompat_chat_final() {
std::string finish_reason = "length";
common_chat_msg msg;
if (!oaicompat_msg.empty()) {
@@ -292,7 +271,7 @@ json server_task_result::to_json_oaicompat_chat_final() {
return res;
}
json server_task_result::to_json_oaicompat_chat_stream() {
json server_task_result_cmpl_final::to_json_oaicompat_chat_stream() {
std::time_t t = std::time(0);
std::string finish_reason = "length";
if (stop) {
@@ -357,7 +336,7 @@ json server_task_result::to_json_oaicompat_chat_stream() {
return deltas;
}
json server_task_result::to_json_anthropic_final() {
json server_task_result_cmpl_final::to_json_anthropic_final() {
std::string stop_reason = "max_tokens";
if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) {
stop_reason = oaicompat_msg.tool_calls.empty() ? "end_turn" : "tool_use";
@@ -416,7 +395,7 @@ json server_task_result::to_json_anthropic_final() {
return res;
}
json server_task_result::to_json_anthropic_stream() {
json server_task_result_cmpl_final::to_json_anthropic_stream() {
json events = json::array();
std::string stop_reason = "max_tokens";
@@ -552,7 +531,7 @@ json server_task_result::to_json_anthropic_stream() {
return events;
}
json server_task_result::to_json_anthropic_partial() {
json server_task_result_cmpl_partial::to_json_anthropic_partial() {
json events = json::array();
bool first = n_decoded == 1;
static bool text_block_started = false;

View File

@@ -42,13 +42,53 @@ enum oaicompat_type {
};
struct slot_params {
bool stream = true;
bool include_usage = false;
bool cache_prompt = true; // remember the prompt to avoid reprocessing all prompt
int32_t n_keep = 0; // number of tokens to keep from initial prompt
int32_t n_discard = 0; // number of tokens after n_keep that may be discarded when shifting context, 0 defaults to half
int32_t n_predict = -1; // new tokens to predict
thinking_tokens think_tokens;
std::vector<std::string> antiprompt;
bool timings_per_token = false;
bool post_sampling_probs = false;
json input_prefix;
json input_suffix;
// speculative decoding parameters
struct {
int n_max = 16; // max drafted tokens
int n_min = 0; // min drafted tokens to accept
float p_min = 0.75f; // min probability required to accept a token in the draft
} speculative;
// OAI-compat fields
oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE;
std::string oaicompat_model;
std::string oaicompat_cmpl_id;
common_chat_syntax oaicompat_chat_syntax;
// Embeddings
int32_t embd_normalize = 2; // (-1=none, 0=max absolute int16, 1=taxicab, 2=Euclidean/L2, >2=p-norm)
};
struct server_task {
int id = -1; // to be filled by server_queue
int id_multi = -1;
int index = -1; // used when there are multiple prompts (batch request)
// used by SERVER_TASK_TYPE_CANCEL
int id_target = -1;
//int id_slot = -1;
int id_slot = -1;
// used by SERVER_TASK_TYPE_INFERENCE
struct slot_params params;
server_tokens tokens;
server_task_type type;
@@ -60,6 +100,18 @@ struct server_task {
server_task() = default;
server_task(server_task_type type) : type(type) {}
int32_t n_tokens() const {
return tokens.size();
}
// utility function
static std::unordered_set<int> get_list_id(const std::vector<server_task>& tasks) {
std::unordered_set<int> ids(tasks.size());
for (size_t i = 0; i < tasks.size(); i++) {
ids.insert(tasks[i].id);
}
return ids;
}
};
struct result_timings {
@@ -126,40 +178,142 @@ struct server_task_result {
bool verbose = false;
virtual bool is_error() {
// only used by server_task_result_error
return false;
}
virtual bool is_stop() {
// only used by server_task_result_cmpl_*
// in stream mode, final responses are considered stop
return true;
}
virtual json to_json() {
return {};
};
int get_index() {
return index;
}
bool is_stop() {
return true; // in stream mode, final responses are considered stop
};
struct server_task_result_cmpl_partial : server_task_result {
virtual bool is_stop() override {
return false; // in stream mode, partial responses are not considered stop
}
json to_json_final();
json to_json_partial();
json to_json_non_oaicompat_partial();
json to_json_non_oaicompat_final();
json to_json_oaicompat_partial();
json to_json_oaicompat_final();
json to_json_anthropic_partial();
json to_json_oaicompat_chat_partial();
json to_json_oaicompat_chat_final();
virtual json to_json() override {
switch (oaicompat) {
case OAICOMPAT_TYPE_NONE:
return to_json_non_oaicompat_partial();
case OAICOMPAT_TYPE_COMPLETION:
return to_json_oaicompat_partial();
case OAICOMPAT_TYPE_CHAT:
return to_json_oaicompat_chat_partial();
case OAICOMPAT_TYPE_ANTHROPIC:
return to_json_anthropic_partial();
default:
GGML_ASSERT(false && "Invalid oaicompat_type");
};
}
};
json to_json_oaicompat_chat_stream();
struct server_task_result_cmpl_final : server_task_result {
virtual bool is_stop() override {
return true;
}
json to_json_non_oaicompat_final();
json to_json_oaicompat_final();
json to_json_oaicompat_chat_final();
json to_json_anthropic_final();
json to_json_anthropic_stream();
json to_json_anthropic_partial();
json to_json_oaicompat_chat_stream();
virtual json to_json() override {
switch (oaicompat) {
case OAICOMPAT_TYPE_NONE:
return to_json_non_oaicompat_final();
case OAICOMPAT_TYPE_COMPLETION:
return to_json_oaicompat_final();
case OAICOMPAT_TYPE_CHAT:
return stream ? to_json_oaicompat_chat_stream() : to_json_oaicompat_chat_final();
case OAICOMPAT_TYPE_ANTHROPIC:
return stream ? to_json_anthropic_stream() : to_json_anthropic_final();
default:
GGML_ASSERT(false && "Invalid oaicompat_type");
}
}
};
struct server_task_result_error : server_task_result {
int index = 0;
error_type err_type = ERROR_TYPE_SERVER;
std::string err_msg;
// for ERROR_TYPE_EXCEED_CONTEXT_SIZE
int32_t n_prompt_tokens = 0;
int32_t n_ctx = 0;
virtual bool is_error() override {
return true;
}
virtual json to_json() override {
json res = format_error_response(err_msg, err_type);
return res;
}
};
struct server_task_result_embd : server_task_result {
int index = 0;
std::vector<std::vector<float>> embedding;
int32_t n_tokens;
// OAI-compat fields
oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE;
virtual json to_json() override {
return oaicompat == OAICOMPAT_TYPE_EMBEDDING
? to_json_oaicompat()
: to_json_non_oaicompat();
}
json to_json_non_oaicompat() {
return json{
{"index", index},
{"embedding", embedding},
};
}
json to_json_oaicompat() {
return json{
{"index", index},
{"embedding", embedding[0]},
{"tokens_evaluated", n_tokens},
};
}
};
// using shared_ptr for polymorphism of server_task_result
using server_task_result_ptr = std::unique_ptr<server_task_result>;
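
Storing results as unique_ptr to the base type is what lets one queue hold partial, final, error, and embedding results and dispatch to_json() virtually. A reduced sketch of the pattern with hypothetical names:

```cpp
#include <cstdio>
#include <memory>
#include <string>
#include <vector>
#include <nlohmann/json.hpp>

using json = nlohmann::ordered_json;

// Reduced sketch of the polymorphic result pattern: a base with virtual
// is_stop()/to_json(), derived partial/final types, and a queue of unique_ptr.
struct result_base {
    virtual ~result_base() = default;
    virtual bool is_stop() const { return true; }
    virtual json to_json() const = 0;
};

struct result_partial : result_base {
    std::string content;
    bool is_stop() const override { return false; }        // more chunks will follow
    json to_json() const override { return { {"content", content}, {"stop", false} }; }
};

struct result_final : result_base {
    std::string content;
    json to_json() const override { return { {"content", content}, {"stop", true} }; }
};

int main() {
    std::vector<std::unique_ptr<result_base>> queue;
    queue.push_back(std::make_unique<result_partial>());
    queue.push_back(std::make_unique<result_final>());
    for (const auto & r : queue) {
        std::printf("%s stop=%d\n", r->to_json().dump().c_str(), r->is_stop() ? 1 : 0);
    }
}
```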
struct server_prompt_checkpoint {
llama_pos pos_min;

View File

@@ -9,6 +9,12 @@
#include "sampling.h"
#include "llama.h"
#include "llama-vocab.h"
#include <fstream>
// mime type for sending response
#define MIMETYPE_JSON "application/json; charset=utf-8"
#ifndef NDEBUG
// crash the server in debug mode, otherwise send an http 500 error
@@ -48,6 +54,7 @@ struct DatabaseHandle {
using json = nlohmann::ordered_json;
namespace fs = std::filesystem;
constexpr int HTTP_POLLING_SECONDS = 1;
bool server_verbose = false;
bool server_log_json = true;
@@ -299,6 +306,117 @@ static void log_server_request(const httplib::Request & req, const httplib::Resp
});
}
// generator-like API for server responses, support pooling connection state and aggregating results
struct server_response_reader {
std::unordered_set<int> id_tasks;
server_context& ctx_server;
size_t received_count = 0;
bool cancelled = false;
server_response_reader(server_context& ctx_server) : ctx_server(ctx_server) {}
~server_response_reader() {
stop();
}
void post_tasks(std::vector<server_task>&& tasks) {
id_tasks = server_task::get_list_id(tasks);
ctx_server.queue_results.add_waiting_tasks(tasks);
ctx_server.queue_tasks.post(std::move(tasks));
}
bool has_next() {
return !cancelled && received_count < id_tasks.size();
}
// return nullptr if should_stop() is true before receiving a result
// note: if one error is received, it will stop further processing and return error result
server_task_result_ptr next(const std::function<bool()>& should_stop) {
while (true) {
server_task_result_ptr result = ctx_server.queue_results.recv_with_timeout(id_tasks, HTTP_POLLING_SECONDS);
if (result == nullptr) {
// timeout, check stop condition
if (should_stop()) {
SRV_DBG("%s", "stopping wait for next result due to should_stop condition\n");
return nullptr;
}
}
else {
if (result->is_error()) {
stop(); // cancel remaining tasks
SRV_DBG("%s", "received error result, stopping further processing\n");
return result;
}
if (result->is_stop()) {
received_count++;
}
return result;
}
}
// should not reach here
}
struct batch_response {
bool is_terminated = false; // if true, indicates that processing was stopped before all results were received
std::vector<server_task_result_ptr> results;
server_task_result_ptr error; // nullptr if no error
};
batch_response wait_for_all(const std::function<bool()>& should_stop) {
batch_response batch_res;
batch_res.results.resize(id_tasks.size());
while (has_next()) {
auto res = next(should_stop);
if (res == nullptr) {
batch_res.is_terminated = true;
return batch_res;
}
if (res->error) {
batch_res.error = std::move(res);
return batch_res;
}
const size_t idx = res->get_index();
GGML_ASSERT(idx < batch_res.results.size() && "index out of range");
GGML_ASSERT(batch_res.results[idx] == nullptr && "duplicate result received");
batch_res.results[idx] = std::move(res);
}
return batch_res;
}
void stop() {
ctx_server.queue_results.remove_waiting_task_ids(id_tasks);
if (has_next() && !cancelled) {
// if tasks is not finished yet, cancel them
cancelled = true;
std::vector<server_task> cancel_tasks;
cancel_tasks.reserve(id_tasks.size());
for (const auto& id_task : id_tasks) {
SRV_WRN("cancel task, id_task = %d\n", id_task);
server_task task(SERVER_TASK_TYPE_CANCEL);
task.id_target = id_task;
ctx_server.queue_results.remove_waiting_task_id(id_task);
cancel_tasks.push_back(std::move(task));
}
// push to beginning of the queue, so it has highest priority
ctx_server.queue_tasks.post(std::move(cancel_tasks), true);
}
else {
SRV_DBG("%s", "all tasks already finished, no need to cancel\n");
}
}
};
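
The reader is held through a shared_ptr so the chunked_content_provider keeps it alive; once the last owner lets go (normal completion or client disconnect), the destructor's stop() cancels anything still pending. A rough RAII sketch of that idea with stand-in types:

```cpp
#include <cstdio>
#include <functional>
#include <memory>
#include <set>

// RAII sketch of the cancellation idea: whatever is still pending when the last
// shared_ptr owner releases the reader gets a cancel request from the destructor.
struct reader_sketch {
    std::set<int> pending;                 // ids of tasks that have not finished yet
    std::function<void(int)> cancel_task;  // stand-in for posting a cancel task

    ~reader_sketch() {
        for (int id : pending) cancel_task(id);
    }
};

int main() {
    auto rd = std::make_shared<reader_sketch>();
    rd->pending = {1, 2, 3};
    rd->cancel_task = [](int id) { std::printf("cancel task %d\n", id); };

    // the streaming lambda shares ownership, like the chunked_content_provider does
    auto stream_chunk = [rd](bool client_connected) {
        if (!client_connected) return false;   // provider gives up, its copy of rd is released later
        rd->pending.erase(rd->pending.begin());
        return true;
    };
    stream_chunk(true);   // one task finished normally
    stream_chunk(false);  // simulated disconnect; tasks 2 and 3 are cancelled when rd is destroyed
}
```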
auto res_err = [](httplib::Response& res, json error_data) {
json final_response{ {"error", error_data} };
res.set_content(safe_json_to_str(final_response), MIMETYPE_JSON);
res.status = json_value(error_data, "code", 500);
};
auto res_ok = [](httplib::Response& res, const json& data) {
res.set_content(data.dump(), "application/json; charset=utf-8");
res.status = 200;
};
std::function<void(int)> shutdown_handler;
std::atomic_flag is_terminating = ATOMIC_FLAG_INIT;
@@ -380,18 +498,9 @@ int main(int argc, char ** argv) {
svr->set_logger(log_server_request);
auto res_error = [](httplib::Response & res, json error_data) {
json final_response {{"error", error_data}};
res.set_content(final_response.dump(), "application/json; charset=utf-8");
res.status = json_value(error_data, "code", 500);
};
auto res_ok = [](httplib::Response& res, const json& data) {
res.set_content(data.dump(), "application/json; charset=utf-8");
res.status = 200;
};
svr->set_exception_handler([&res_error](const httplib::Request &, httplib::Response & res, std::exception_ptr ep) {
svr->set_exception_handler([](const httplib::Request &, httplib::Response & res, std::exception_ptr ep) {
std::string message;
try {
std::rethrow_exception(std::move(ep));
@@ -403,14 +512,14 @@ int main(int argc, char ** argv) {
json formatted_error = format_error_response(message, ERROR_TYPE_SERVER);
LOG_VERBOSE("Got exception", formatted_error);
res_error(res, formatted_error);
res_err(res, formatted_error);
});
svr->set_error_handler([&res_error](const httplib::Request &, httplib::Response & res) {
svr->set_error_handler([](const httplib::Request &, httplib::Response & res) {
if (res.status == 404) {
res_error(res, format_error_response("File Not Found", ERROR_TYPE_NOT_FOUND));
res_err(res, format_error_response("File Not Found", ERROR_TYPE_NOT_FOUND));
}
// for other error codes, we skip processing here because it's already done by res_error()
// for other error codes, we skip processing here because it's already done by res_err()
});
// set timeouts and change hostname and port
@@ -492,7 +601,7 @@ int main(int argc, char ** argv) {
// Middlewares
//
auto middleware_validate_api_key = [&params, &res_error](const httplib::Request & req, httplib::Response & res) {
auto middleware_validate_api_key = [&params](const httplib::Request & req, httplib::Response & res) {
static const std::unordered_set<std::string> public_endpoints = {
"/health",
"/v1/health",
@@ -544,7 +653,7 @@ int main(int argc, char ** argv) {
return false;
};
auto middleware_server_state = [&res_error, &state](const httplib::Request& req, httplib::Response& res) {
auto middleware_server_state = [&state](const httplib::Request& req, httplib::Response& res) {
server_state current_state = state.load();
if (current_state == SERVER_STATE_LOADING_MODEL) {
auto tmp = string_split<std::string>(req.path, '.');
@@ -557,7 +666,7 @@ int main(int argc, char ** argv) {
return true;
}
else {
res_error(res, format_error_response("Loading model", ERROR_TYPE_UNAVAILABLE));
res_err(res, format_error_response("Loading model", ERROR_TYPE_UNAVAILABLE));
}
return false;
}
@@ -632,18 +741,18 @@ int main(int argc, char ** argv) {
}
case SERVER_STATE_LOADING_MODEL:
{
res_error(res, format_error_response("Loading model", ERROR_TYPE_UNAVAILABLE));
res_err(res, format_error_response("Loading model", ERROR_TYPE_UNAVAILABLE));
} break;
case SERVER_STATE_ERROR:
{
res_error(res, format_error_response("Model failed to load", ERROR_TYPE_SERVER));
res_err(res, format_error_response("Model failed to load", ERROR_TYPE_SERVER));
} break;
}
};
const auto handle_slots = [&](const httplib::Request &, httplib::Response & res) {
if (!params.endpoint_slots) {
res_error(res, format_error_response("This server does not support slots endpoint.", ERROR_TYPE_NOT_SUPPORTED));
res_err(res, format_error_response("This server does not support slots endpoint.", ERROR_TYPE_NOT_SUPPORTED));
return;
}
@@ -667,7 +776,7 @@ int main(int argc, char ** argv) {
const auto handle_metrics = [&](const httplib::Request &, httplib::Response & res) {
if (!params.endpoint_metrics) {
res_error(res, format_error_response("This server does not support metrics endpoint.", ERROR_TYPE_NOT_SUPPORTED));
res_err(res, format_error_response("This server does not support metrics endpoint.", ERROR_TYPE_NOT_SUPPORTED));
return;
}
@@ -766,11 +875,11 @@ int main(int argc, char ** argv) {
res.status = 200; // HTTP OK
};
const auto handle_slots_save = [&ctx_server, &res_error, &params](const httplib::Request & req, httplib::Response & res, int id_slot) {
const auto handle_slots_save = [&ctx_server, &params](const httplib::Request & req, httplib::Response & res, int id_slot) {
json request_data = json::parse(req.body);
std::string filename = request_data.at("filename");
if (!fs_validate_filename(filename)) {
res_error(res, format_error_response("Invalid filename", ERROR_TYPE_INVALID_REQUEST));
res_err(res, format_error_response("Invalid filename", ERROR_TYPE_INVALID_REQUEST));
return;
}
std::string filepath = params.slot_save_path + filename;
@@ -790,17 +899,17 @@ int main(int argc, char ** argv) {
ctx_server.queue_results.remove_waiting_task_id(id_task);
if (result.error) {
res_error(res, result.data);
res_err(res, result.data);
} else {
res.set_content(result.data.dump(), "application/json");
}
};
const auto handle_slots_restore = [&ctx_server, &res_error, &params](const httplib::Request & req, httplib::Response & res, int id_slot) {
const auto handle_slots_restore = [&ctx_server, &params](const httplib::Request & req, httplib::Response & res, int id_slot) {
json request_data = json::parse(req.body);
std::string filename = request_data.at("filename");
if (!fs_validate_filename(filename)) {
res_error(res, format_error_response("Invalid filename", ERROR_TYPE_INVALID_REQUEST));
res_err(res, format_error_response("Invalid filename", ERROR_TYPE_INVALID_REQUEST));
return;
}
std::string filepath = params.slot_save_path + filename;
@@ -820,13 +929,13 @@ int main(int argc, char ** argv) {
ctx_server.queue_results.remove_waiting_task_id(id_task);
if (result.error) {
res_error(res, result.data);
res_err(res, result.data);
} else {
res.set_content(result.data.dump(), "application/json");
}
};
const auto handle_slots_erase = [&ctx_server, &res_error](const httplib::Request & /* req */, httplib::Response & res, int id_slot) {
const auto handle_slots_erase = [&ctx_server](const httplib::Request & /* req */, httplib::Response & res, int id_slot) {
server_task task;
task.type = SERVER_TASK_TYPE_SLOT_ERASE;
task.data = {
@@ -840,20 +949,20 @@ int main(int argc, char ** argv) {
ctx_server.queue_results.remove_waiting_task_id(id_task);
if (result.error) {
res_error(res, result.data);
res_err(res, result.data);
} else {
res.set_content(result.data.dump(), "application/json");
}
};
const auto handle_slots_action = [&res_error, &handle_slots_save, &handle_slots_restore, &handle_slots_erase](const httplib::Request & req, httplib::Response & res) {
const auto handle_slots_action = [&handle_slots_save, &handle_slots_restore, &handle_slots_erase](const httplib::Request & req, httplib::Response & res) {
std::string id_slot_str = req.path_params.at("id_slot");
int id_slot;
try {
id_slot = std::stoi(id_slot_str);
} catch (const std::exception &) {
res_error(res, format_error_response("Invalid slot ID", ERROR_TYPE_INVALID_REQUEST));
res_err(res, format_error_response("Invalid slot ID", ERROR_TYPE_INVALID_REQUEST));
return;
}
@@ -866,7 +975,7 @@ int main(int argc, char ** argv) {
} else if (action == "erase") {
handle_slots_erase(req, res, id_slot);
} else {
res_error(res, format_error_response("Invalid action", ERROR_TYPE_INVALID_REQUEST));
res_err(res, format_error_response("Invalid action", ERROR_TYPE_INVALID_REQUEST));
}
};
@@ -931,146 +1040,161 @@ int main(int argc, char ** argv) {
// handle completion-like requests (completion, chat, infill)
// we can optionally provide a custom format for partial results and final results
const auto handle_completions_impl = [&ctx_server, &params, &res_error, &res_ok](
const auto handle_completions_impl = [&ctx_server, &params](
server_task_type type,
json& data,
const std::vector<raw_buffer>& files,
const std::function<bool()>& is_connection_closed,
httplib::Response& res,
oaicompat_type oaicompat) -> void {
GGML_ASSERT(type == SERVER_TASK_TYPE_COMPLETION);
if (ctx_server.params.embedding) {
res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
GGML_ASSERT(type == SERVER_TASK_TYPE_COMPLETION || type == SERVER_TASK_TYPE_INFILL);
const auto completion_id = gen_chatcmplid();
// need to store the reader as a pointer, so that it won't be destroyed when the handle returns
// use shared_ptr as it's shared between the chunked_content_provider() and on_complete()
const auto rd = std::make_shared<server_response_reader>(ctx_server);
try {
std::vector<server_task> tasks;
const auto& prompt = data.at("prompt");
// process prompt
std::vector<server_tokens> inputs;
if (oaicompat && ctx_server.mctx != nullptr) {
// This is the case used by OAI compatible chat path with MTMD. TODO It can be moved to the path below.
inputs.push_back(process_mtmd_prompt(ctx_server.mctx, prompt.get<std::string>(), files));
}
else {
// Everything else, including multimodal completions.
inputs = tokenize_input_prompts(llama_get_vocab(ctx_server.ctx), ctx_server.mctx, prompt, true, true);
}
tasks.reserve(inputs.size());
for (size_t i = 0; i < inputs.size(); i++) {
server_task task = server_task(type);
task.id = ctx_server.queue_tasks.get_new_id();
task.index = i;
task.tokens = std::move(inputs[i]);
task.data = data;
//task.params = server_task::params_from_json_cmpl(
// ctx_server.ctx,
// ctx_server.params,
// data);
task.id_slot = json_value(data, "id_slot", -1);
// OAI-compat
task.params.oaicompat = oaicompat;
task.params.oaicompat_cmpl_id = completion_id;
tasks.push_back(std::move(task));
}
rd->post_tasks(std::move(tasks));
}
catch (const std::exception& e) {
res_err(res, format_error_response(e.what(), ERROR_TYPE_INVALID_REQUEST));
return;
}
const auto& prompt = data.at("prompt");
// process prompt
std::vector<server_tokens> inputs;
if (oaicompat && ctx_server.mctx != nullptr) {
// This is the case used by OAI compatible chat path with MTMD. TODO It can be moved to the path below.
#ifndef NDEBUG
print_files_info(files);
#endif // !NDEBUG
inputs.push_back(process_mtmd_prompt(ctx_server.mctx, prompt.get<std::string>(), files));
}
else {
// Everything else, including multimodal completions.
inputs = tokenize_input_prompts(llama_get_vocab(ctx_server.ctx), ctx_server.mctx, prompt, true, true);
}
const auto completion_id = gen_chatcmplid();
const int id_task = ctx_server.queue_tasks.get_new_id();
ctx_server.queue_results.add_waiting_task_id(id_task);
ctx_server.request_completion(id_task, -1, data, false, false, std::move(inputs[0]));
bool stream = json_value(data, "stream", false);
if (!stream) {
server_task_result result = ctx_server.queue_results.recv(id_task);
result.oaicompat = oaicompat;
result.oaicompat_cmpl_id = completion_id;
json result_oai;
if (oaicompat) {
if (result.final_result) {
result_oai = result.to_json_final();
}
else {
result_oai = result.to_json_partial();
}
// non-stream, wait for the results
auto all_results = rd->wait_for_all(is_connection_closed);
if (all_results.is_terminated) {
llama_decode_stop(); // send a signal to stop decode process
return; // connection is closed
}
else if (all_results.error) {
res_err(res, all_results.error->to_json());
return;
}
else {
// legacy completions
result_oai = result.data;
}
if (!result.error && result.stop) {
res.set_content(result_oai.dump(-1, ' ', false, json::error_handler_t::replace), "application/json; charset=utf-8");
}
else {
res_error(res, result_oai);
}
ctx_server.queue_results.remove_waiting_task_id(id_task);
json arr = json::array();
for (auto& res : all_results.results) {
GGML_ASSERT(dynamic_cast<server_task_result_cmpl_final*>(res.get()) != nullptr);
arr.push_back(res->to_json());
}
// if single request, return single object instead of array
res_ok(res, arr.size() == 1 ? arr[0] : arr);
}
}
else {
const auto chunked_content_provider = [id_task, &ctx_server, completion_id, oaicompat, send_done = params.send_done](size_t, httplib::DataSink& sink) {
bool successful_completion = false;
const auto sse = [oaicompat, &sink](const json &res) {
if (oaicompat == OAICOMPAT_TYPE_ANTHROPIC) {
return server_sent_anthropic_event(sink, res);
} else {
return server_sent_event(sink, res);
}
};
while (true) {
server_task_result result = ctx_server.queue_results.recv(id_task);
if (!result.error) {
result.oaicompat = oaicompat;
result.oaicompat_cmpl_id = completion_id;
json res_json;
if (oaicompat) {
if (result.final_result) {
res_json = result.to_json_final();
}
else {
res_json = result.to_json_partial();
}
}
else {
// legacy completions
res_json = result.data;
}
if (res_json.is_array()) {
// chat completions and oai completions
for (const auto& res : res_json) {
if (!sse(res)) {
// sending failed (HTTP connection closed), cancel the generation
ctx_server.queue_results.remove_waiting_task_id(id_task);
return false;
}
}
if (result.stop) {
successful_completion = true;
break;
}
}
else {
// legacy completions
if (!sse(res_json)) {
ctx_server.queue_results.remove_waiting_task_id(id_task);
return false;
}
if (result.stop) {
break;
}
}
}
else {
if (!sse(result.data)) {
ctx_server.queue_results.remove_waiting_task_id(id_task);
return false;
}
break;
}
}
bool ok = true;
if (successful_completion && oaicompat != OAICOMPAT_TYPE_ANTHROPIC && oaicompat != OAICOMPAT_TYPE_NONE) {
static const std::string done_message = "data: [DONE]\n\n";
LOG_VERBOSE("data stream", { {"to_send", done_message} });
if (!sink.write(done_message.c_str(), done_message.size())) {
// If writing [DONE] fails, the stream is likely already problematic.
ok = false;
}
}
sink.done();
ctx_server.queue_results.remove_waiting_task_id(id_task);
return ok;
};
auto on_complete = [id_task, &ctx_server](bool) {
// cancel request
ctx_server.request_cancel(id_task);
ctx_server.queue_results.remove_waiting_task_id(id_task);
};

// in streaming mode, the first error must be treated as non-stream response
// this is to match the OAI API behavior
// ref: https://github.com/ggml-org/llama.cpp/pull/16486#discussion_r2419657309
server_task_result_ptr first_result = rd->next(is_connection_closed);
if (first_result == nullptr) {
llama_decode_stop(); // send a signal to stop decode process
return; // connection is closed
}
else if (first_result->is_error()) {
res_err(res, first_result->to_json());
return;
}
else {
GGML_ASSERT(
dynamic_cast<server_task_result_cmpl_partial*>(first_result.get()) != nullptr
|| dynamic_cast<server_task_result_cmpl_final*>(first_result.get()) != nullptr
);
}
// next responses are streamed
json first_result_json = first_result->to_json();
const auto chunked_content_provider = [first_result_json, rd, oaicompat](size_t, httplib::DataSink& sink) mutable -> bool {
// flush the first result as it's not an error
if (!first_result_json.empty()) {
if (!server_sent_event(sink, first_result_json)) {
sink.done();
return false; // sending failed, go to on_complete()
}
first_result_json.clear(); // mark as sent
}
// receive subsequent results
auto result = rd->next([&sink] { return !sink.is_writable(); });
if (result == nullptr) {
sink.done();
return false; // connection is closed, go to on_complete()
}
// send the results
json res_json = result->to_json();
bool ok = false;
if (result->is_error()) {
ok = server_sent_event(sink, json{ { "error", result->to_json() } });
sink.done();
return false; // go to on_complete()
}
else {
GGML_ASSERT(
dynamic_cast<server_task_result_cmpl_partial*>(result.get()) != nullptr
|| dynamic_cast<server_task_result_cmpl_final*>(result.get()) != nullptr
);
ok = server_sent_event(sink, res_json);
}
if (!ok) {
sink.done();
return false; // sending failed, go to on_complete()
}
// check if there is more data
if (!rd->has_next()) {
if (oaicompat != OAICOMPAT_TYPE_NONE) {
static const std::string ev_done = "data: [DONE]\n\n";
sink.write(ev_done.data(), ev_done.size());
}
sink.done();
return false; // no more data, go to on_complete()
}
// has next data, continue
return true;
};
auto on_complete = [rd](bool) {
rd->stop();
};
res.set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete);
}
};
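// For reference, an OAI-compatible stream produced by the provider above is a sequence of
// SSE frames, e.g.
//   data: {"choices":[{"delta":{"content":"Hello"}}], ...}
// followed by a terminating
//   data: [DONE]
// (illustrative only; the exact payload comes from to_json_partial()/to_json_final() and
// depends on the oaicompat mode)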
@@ -1082,6 +1206,7 @@ int main(int argc, char ** argv) {
SERVER_TASK_TYPE_COMPLETION,
data,
files,
req.is_connection_closed,
res,
OAICOMPAT_TYPE_NONE);
};
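// req.is_connection_closed is now threaded into every completion handler so that prompt
// processing and generation can be abandoned as soon as httplib reports that the client has
// gone away; this is what "stop processing the prompt when client disconnects" refers to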
@@ -1094,6 +1219,7 @@ int main(int argc, char ** argv) {
SERVER_TASK_TYPE_COMPLETION,
data,
files,
req.is_connection_closed,
res,
OAICOMPAT_TYPE_COMPLETION);
};
@@ -1117,7 +1243,7 @@ int main(int argc, char ** argv) {
const auto handle_chat_completions = [&ctx_server, &params, &handle_completions_impl, &res_error](const httplib::Request & req, httplib::Response & res) {
const auto handle_chat_completions = [&ctx_server, &params, &handle_completions_impl](const httplib::Request & req, httplib::Response & res) {
auto body = json::parse(req.body);
std::vector<raw_buffer> files;
json data = oaicompat_chat_params_parse(ctx_server.model, body, ctx_server.oai_parser_opt, files);
@@ -1125,6 +1251,7 @@ int main(int argc, char ** argv) {
SERVER_TASK_TYPE_COMPLETION,
data,
files,
req.is_connection_closed,
res,
OAICOMPAT_TYPE_CHAT);
};
@@ -1141,11 +1268,12 @@ int main(int argc, char ** argv) {
SERVER_TASK_TYPE_COMPLETION,
body_parsed,
files,
req.is_connection_closed,
res,
OAICOMPAT_TYPE_ANTHROPIC);
};
const auto handle_anthropic_count_tokens = [&ctx_server, &handle_completions_impl, &res_ok](const httplib::Request & req, httplib::Response & res) {
const auto handle_anthropic_count_tokens = [&ctx_server, &handle_completions_impl](const httplib::Request & req, httplib::Response & res) {
std::vector<raw_buffer> files;
json body = json::parse(req.body);
@@ -1164,14 +1292,14 @@ int main(int argc, char ** argv) {
};
// same with handle_chat_completions, but without inference part
const auto handle_apply_template = [&ctx_server, &params, &res_ok](const httplib::Request& req, httplib::Response& res) {
const auto handle_apply_template = [&ctx_server, &params](const httplib::Request& req, httplib::Response& res) {
auto body = json::parse(req.body);
std::vector<raw_buffer> files; // dummy, unused
json data = oaicompat_chat_params_parse(ctx_server.model, body, ctx_server.oai_parser_opt, files);
res_ok(res, { { "prompt", std::move(data.at("prompt")) } });
};
const auto handle_infill = [&ctx_server, &res_error, &handle_completions_impl](const httplib::Request & req, httplib::Response & res) {
const auto handle_infill = [&ctx_server, &handle_completions_impl](const httplib::Request & req, httplib::Response & res) {
json data = json::parse(req.body);
const int id_task = ctx_server.queue_tasks.get_new_id();
server_tokens token; // dummy tokens
@@ -1182,6 +1310,7 @@ int main(int argc, char ** argv) {
SERVER_TASK_TYPE_INFILL,
data,
files,
req.is_connection_closed,
res,
OAICOMPAT_TYPE_NONE); // infill is not OAI compatible
};
@@ -1211,58 +1340,119 @@ int main(int argc, char ** argv) {
return res.set_content(data.dump(), "application/json; charset=utf-8");
};
const auto handle_embeddings = [&ctx_server, &res_error](const httplib::Request & req, httplib::Response & res) {
const json body = json::parse(req.body);
bool is_openai = false;
// an input prompt can be a string or a list of tokens (integer)
json prompt;
if (body.count("input") != 0) {
is_openai = true;
prompt = body.at("input");
} else if (body.count("content") != 0) {
// with "content", we only support single prompt
prompt = std::vector<std::string>{body.at("content")};
} else {
res_error(res, format_error_response("\"input\" or \"content\" must be provided", ERROR_TYPE_INVALID_REQUEST));
return;
}
// create and queue the task
json responses;
{
const int id_task = ctx_server.queue_tasks.get_new_id();
ctx_server.queue_results.add_waiting_task_id(id_task);
std::vector<server_tokens> inputs;
inputs = tokenize_input_prompts(llama_get_vocab(ctx_server.ctx), ctx_server.mctx, prompt, true, true);
ctx_server.request_completion(id_task, -1, {{"prompt", prompt}}, false, true, std::move(inputs[0]));
// get the result
server_task_result result = ctx_server.queue_results.recv(id_task);
ctx_server.queue_results.remove_waiting_task_id(id_task);
if (!result.error) {
if (result.data.count("results")) {
// result for multi-task
responses = result.data.at("results");
} else {
// result for single task
responses = std::vector<json>{ result.data };
}
} else {
// error received, ignore everything else
res_error(res, result.data);
return;
}
}
// write JSON response
json root = is_openai
? format_embeddings_response_oaicompat(body, responses, false)
: responses[0];
return res.set_content(root.dump(), "application/json; charset=utf-8");
};

const auto handle_embeddings_impl = [&ctx_server](const httplib::Request& req, httplib::Response& res, oaicompat_type oaicompat) {
if (!ctx_server.params.embedding) {
res_err(res, format_error_response("This server does not support embeddings. Start it with `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
return;
}
if (oaicompat != OAICOMPAT_TYPE_NONE && llama_pooling_type(ctx_server.ctx) == LLAMA_POOLING_TYPE_NONE) {
res_err(res, format_error_response("Pooling type 'none' is not OAI compatible. Please use a different pooling type", ERROR_TYPE_INVALID_REQUEST));
return;
}
const json body = json::parse(req.body);
// for the shape of input/content, see tokenize_input_prompts()
json prompt;
if (body.count("input") != 0) {
prompt = body.at("input");
}
else if (body.contains("content")) {
oaicompat = OAICOMPAT_TYPE_NONE; // "content" field is not OAI compatible
prompt = body.at("content");
}
else {
res_err(res, format_error_response("\"input\" or \"content\" must be provided", ERROR_TYPE_INVALID_REQUEST));
return;
}
bool use_base64 = false;
if (body.count("encoding_format") != 0) {
const std::string& format = body.at("encoding_format");
if (format == "base64") {
use_base64 = true;
}
else if (format != "float") {
res_err(res, format_error_response("The format to return the embeddings in. Can be either float or base64", ERROR_TYPE_INVALID_REQUEST));
return;
}
}
auto vocab = llama_get_vocab(ctx_server.ctx);
auto tokenized_prompts = tokenize_input_prompts(vocab, ctx_server.mctx, prompt, true, true);
for (const auto& tokens : tokenized_prompts) {
// this check is necessary for models that do not add BOS token to the input
if (tokens.empty()) {
res_err(res, format_error_response("Input content cannot be empty", ERROR_TYPE_INVALID_REQUEST));
return;
}
}
int embd_normalize = 2; // default to Euclidean/L2 norm
if (body.count("embd_normalize") != 0) {
embd_normalize = body.at("embd_normalize");
if (llama_pooling_type(ctx_server.ctx) == LLAMA_POOLING_TYPE_NONE) {
SRV_DBG("embd_normalize is not supported by pooling type %d, ignoring it\n", llama_pooling_type(ctx_server.ctx));
}
}
// create and queue the task
json responses = json::array();
server_response_reader rd(ctx_server);
{
std::vector<server_task> tasks;
for (size_t i = 0; i < tokenized_prompts.size(); i++) {
server_task task = server_task(SERVER_TASK_TYPE_EMBEDDING);
task.id = ctx_server.queue_tasks.get_new_id();
task.index = i;
task.tokens = std::move(tokenized_prompts[i]);
// OAI-compat
task.params.oaicompat = oaicompat;
task.params.embd_normalize = embd_normalize;
task.embedding = true; // probably not needed
tasks.push_back(std::move(task));
}
rd.post_tasks(std::move(tasks));
}
// wait for the results
auto all_results = rd.wait_for_all(req.is_connection_closed);
// collect results
if (all_results.is_terminated) {
llama_decode_stop();
return; // connection is closed
}
else if (all_results.error) {
res_err(res, all_results.error->to_json());
return;
}
else {
for (auto& res : all_results.results) {
GGML_ASSERT(dynamic_cast<server_task_result_embd*>(res.get()) != nullptr);
responses.push_back(res->to_json());
}
}
// write JSON response
json root = oaicompat == OAICOMPAT_TYPE_EMBEDDING
? format_embeddings_response_oaicompat(body, responses, use_base64)
: json(responses);
res_ok(res, root);
};
const auto handle_embeddings = [&handle_embeddings_impl](const httplib::Request& req, httplib::Response& res) {
handle_embeddings_impl(req, res, OAICOMPAT_TYPE_NONE);
};
const auto handle_embeddings_oai = [&handle_embeddings_impl](const httplib::Request& req, httplib::Response& res) {
handle_embeddings_impl(req, res, OAICOMPAT_TYPE_EMBEDDING);
};
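// Illustrative request bodies accepted by the handler above (field names taken from the
// parsing code; values are examples only):
//   POST /v1/embeddings  {"input": ["hello", "world"], "encoding_format": "float"}
//   POST /embeddings     {"content": "hello", "embd_normalize": 2}
// "input" follows the OpenAI schema (optionally returning base64-encoded vectors), while
// "content" keeps the legacy non-OAI response shape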
const auto handle_lora_adapters_list = [&](const httplib::Request & req, httplib::Response & res) {
json result = json::array();
for (size_t i = 0; i < ctx_server.lora_adapters.size(); ++i) {
@@ -1277,6 +1467,7 @@ int main(int argc, char ** argv) {
res.status = 200; // HTTP OK
};
const auto handle_lora_adapters_apply = [&](const httplib::Request & req, httplib::Response & res) {
const std::vector<json> body = json::parse(req.body);
int max_idx = ctx_server.lora_adapters.size();
@@ -1718,7 +1909,7 @@ int main(int argc, char ** argv) {
svr->Post("/infill", handle_infill);
svr->Post("/embedding", handle_embeddings); // legacy
svr->Post("/embeddings", handle_embeddings);
svr->Post("/v1/embeddings", handle_embeddings);
svr->Post("/v1/embeddings", handle_embeddings_oai);
svr->Post("/tokenize", handle_tokenize);
svr->Post("/detokenize", handle_detokenize);
svr->Post("/apply-template", handle_apply_template);