mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-04-29 19:01:47 +00:00
server: stop processing the prompt when client disconnects (#1134)
implement generator-based API for task results Update httplib.h to 0.27.0 Fix embedding error Stop prompt processing when disconnected Co-authored-by: firecoperana <firecoperana>
This commit is contained in:
@@ -421,7 +421,7 @@ int main(int argc, char ** argv) {
|
||||
|
||||
// int n_ctx = llama_n_ctx(ctx);
|
||||
int n_layers = llama_n_layer(model);
|
||||
int n_embd = llama_n_embd(model);
|
||||
int n_embd = llama_model_n_embd(model);
|
||||
// get model hint param (a.k.a model arch name)
|
||||
char model_hint[128];
|
||||
llama_model_meta_val_str(model, "general.architecture", model_hint, 128);
|
||||
|
||||
@@ -72,7 +72,7 @@ static void batch_decode(llama_context * ctx, llama_batch & batch, float * outpu
|
||||
}
|
||||
|
||||
float * out = output + embd_pos * n_embd;
|
||||
llama_embd_normalize(embd, out, n_embd, embd_norm);
|
||||
common_embd_normalize(embd, out, n_embd, embd_norm);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -187,7 +187,7 @@ int main(int argc, char ** argv) {
|
||||
}
|
||||
|
||||
// allocate output
|
||||
const int n_embd = llama_n_embd(model);
|
||||
const int n_embd = llama_model_n_embd(model);
|
||||
std::vector<float> embeddings(n_embd_count * n_embd, 0);
|
||||
float * emb = embeddings.data();
|
||||
|
||||
@@ -265,7 +265,7 @@ int main(int argc, char ** argv) {
|
||||
fprintf(stdout, "\n");
|
||||
for (int i = 0; i < n_prompts; i++) {
|
||||
for (int j = 0; j < n_prompts; j++) {
|
||||
float sim = llama_embd_similarity_cos(emb + i * n_embd, emb + j * n_embd, n_embd);
|
||||
float sim = common_embd_similarity_cos(emb + i * n_embd, emb + j * n_embd, n_embd);
|
||||
fprintf(stdout, "%6.2f ", sim);
|
||||
}
|
||||
fprintf(stdout, "%1.10s", prompts[i].c_str());
|
||||
@@ -298,7 +298,7 @@ int main(int argc, char ** argv) {
|
||||
for (int i = 0;;) { // at least two iteration (n_embd_count > 1)
|
||||
fprintf(stdout, " [");
|
||||
for (int j = 0;;) { // at least two iteration (n_embd_count > 1)
|
||||
float sim = llama_embd_similarity_cos(emb + i * n_embd, emb + j * n_embd, n_embd);
|
||||
float sim = common_embd_similarity_cos(emb + i * n_embd, emb + j * n_embd, n_embd);
|
||||
fprintf(stdout, "%6.2f", sim);
|
||||
j++;
|
||||
if (j < n_embd_count) fprintf(stdout, ", "); else break;
|
||||
|
||||
@@ -51,7 +51,7 @@ static std::vector<std::vector<float>> encode(llama_context * ctx, const std::ve
|
||||
llama_decode(ctx, batch);
|
||||
|
||||
// get embedding dimensions
|
||||
uint64_t n_embd = llama_n_embd(mdl);
|
||||
uint64_t n_embd = llama_model_n_embd(mdl);
|
||||
|
||||
// allocate embedding output
|
||||
std::vector<float> emb_unorm(n_embd, 0.0f);
|
||||
@@ -74,7 +74,7 @@ static std::vector<std::vector<float>> encode(llama_context * ctx, const std::ve
|
||||
}
|
||||
|
||||
std::vector<float> emb_norm(emb_unorm.size());
|
||||
llama_embd_normalize(emb_unorm.data(), emb_norm.data(), n_embd);
|
||||
common_embd_normalize(emb_unorm.data(), emb_norm.data(), n_embd);
|
||||
result.push_back(emb_norm);
|
||||
|
||||
#ifdef GRIT_DEBUG
|
||||
@@ -191,12 +191,12 @@ int main(int argc, char * argv[]) {
|
||||
const std::vector<std::vector<float>> d_rep = encode(ctx, documents, gritlm_instruction(""));
|
||||
const std::vector<std::vector<float>> q_rep = encode(ctx, queries, gritlm_instruction(instruction));
|
||||
|
||||
const int n_embd = llama_n_embd(mdl);
|
||||
const int n_embd = llama_model_n_embd(mdl);
|
||||
|
||||
const float cosine_sim_q0_d0 = llama_embd_similarity_cos(q_rep[0].data(), d_rep[0].data(), n_embd);
|
||||
const float cosine_sim_q0_d1 = llama_embd_similarity_cos(q_rep[0].data(), d_rep[1].data(), n_embd);
|
||||
const float cosine_sim_q1_d0 = llama_embd_similarity_cos(q_rep[1].data(), d_rep[0].data(), n_embd);
|
||||
const float cosine_sim_q1_d1 = llama_embd_similarity_cos(q_rep[1].data(), d_rep[1].data(), n_embd);
|
||||
const float cosine_sim_q0_d0 = common_embd_similarity_cos(q_rep[0].data(), d_rep[0].data(), n_embd);
|
||||
const float cosine_sim_q0_d1 = common_embd_similarity_cos(q_rep[0].data(), d_rep[1].data(), n_embd);
|
||||
const float cosine_sim_q1_d0 = common_embd_similarity_cos(q_rep[1].data(), d_rep[0].data(), n_embd);
|
||||
const float cosine_sim_q1_d1 = common_embd_similarity_cos(q_rep[1].data(), d_rep[1].data(), n_embd);
|
||||
|
||||
std::printf("Cosine similarity between \"%.50s\" and \"%.50s\" is: %.3f\n", queries[0].c_str(), documents[0].c_str(), cosine_sim_q0_d0);
|
||||
std::printf("Cosine similarity between \"%.50s\" and \"%.50s\" is: %.3f\n", queries[0].c_str(), documents[1].c_str(), cosine_sim_q0_d1);
|
||||
|
||||
@@ -106,7 +106,7 @@ static void batch_decode(llama_context * ctx, llama_batch & batch, float * outpu
|
||||
}
|
||||
|
||||
float * out = output + batch.seq_id[i][0] * n_embd;
|
||||
llama_embd_normalize(embd, out, n_embd);
|
||||
common_embd_normalize(embd, out, n_embd);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -215,7 +215,7 @@ int main(int argc, char ** argv) {
|
||||
struct llama_batch batch = llama_batch_init(n_batch, 0, 1);
|
||||
|
||||
// allocate output
|
||||
const int n_embd = llama_n_embd(model);
|
||||
const int n_embd = llama_model_n_embd(model);
|
||||
std::vector<float> embeddings(n_chunks * n_embd, 0);
|
||||
float * emb = embeddings.data();
|
||||
|
||||
@@ -272,7 +272,7 @@ int main(int argc, char ** argv) {
|
||||
{
|
||||
std::vector<std::pair<int, float>> similarities;
|
||||
for (int i = 0; i < n_chunks; i++) {
|
||||
float sim = llama_embd_similarity_cos(chunks[i].embedding.data(), query_emb.data(), n_embd);
|
||||
float sim = common_embd_similarity_cos(chunks[i].embedding.data(), query_emb.data(), n_embd);
|
||||
similarities.push_back(std::make_pair(i, sim));
|
||||
}
|
||||
|
||||
|
||||
@@ -12,7 +12,7 @@ endif()
|
||||
|
||||
set(TARGET_SRCS
|
||||
server.cpp
|
||||
httplib.h
|
||||
# httplib.h
|
||||
server-task.cpp
|
||||
server-task.h
|
||||
server-queue.cpp
|
||||
@@ -78,7 +78,7 @@ target_include_directories(${TARGET} PRIVATE ${CMAKE_SOURCE_DIR})
|
||||
target_link_libraries(${TARGET} PRIVATE common ${CMAKE_THREAD_LIBS_INIT})
|
||||
|
||||
target_include_directories(${TARGET} PRIVATE ../mtmd)
|
||||
target_link_libraries(${TARGET} PRIVATE common mtmd ${CMAKE_THREAD_LIBS_INIT})
|
||||
target_link_libraries(${TARGET} PRIVATE common mtmd cpp-httplib ${CMAKE_THREAD_LIBS_INIT})
|
||||
|
||||
if (LLAMA_SERVER_SSL)
|
||||
find_package(OpenSSL REQUIRED)
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -478,15 +478,31 @@ json probs_vector_to_json(const llama_context* ctx, const std::vector<completion
|
||||
return out;
|
||||
}
|
||||
|
||||
|
||||
// note: if data is a json array, it will be sent as multiple events, one per item
|
||||
bool server_sent_event(httplib::DataSink& sink, const json& data) {
|
||||
const std::string str =
|
||||
"data: " +
|
||||
data.dump(-1, ' ', false, json::error_handler_t::replace) +
|
||||
"\n\n"; // required by RFC 8895 - A message is terminated by a blank line (two line terminators in a row).
|
||||
static auto send_single = [](httplib::DataSink& sink, const json& data) -> bool {
|
||||
const std::string str =
|
||||
"data: " +
|
||||
data.dump(-1, ' ', false, json::error_handler_t::replace) +
|
||||
"\n\n"; // required by RFC 8895 - A message is terminated by a blank line (two line terminators in a row).
|
||||
|
||||
//LOG_VERBOSE("data stream, to_send: %s", str.c_str());
|
||||
LOG_DBG("data stream, to_send: %s", str.c_str());
|
||||
return sink.write(str.c_str(), str.size());
|
||||
};
|
||||
|
||||
return sink.write(str.c_str(), str.size());
|
||||
if (data.is_array()) {
|
||||
for (const auto& item : data) {
|
||||
if (!send_single(sink, item)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
else {
|
||||
return send_single(sink, data);
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool server_sent_anthropic_event(httplib::DataSink& sink, const json& data) {
|
||||
@@ -2197,3 +2213,7 @@ bool prompt_cache_equal(llama_context* ctx, const server_tokens& cache_tokens,
|
||||
bool equal = common_cache == common_prompt;
|
||||
return equal;
|
||||
}
|
||||
|
||||
std::string safe_json_to_str(const json& data) {
|
||||
return data.dump(-1, ' ', false, json::error_handler_t::replace);
|
||||
}
|
||||
|
||||
@@ -27,15 +27,15 @@
|
||||
#include <random>
|
||||
#include <set>
|
||||
|
||||
// increase max payload length to allow use of larger context size
|
||||
#define CPPHTTPLIB_FORM_URL_ENCODED_PAYLOAD_MAX_LENGTH 1048576
|
||||
// increase backlog size to avoid connection resets for >> 1 slots
|
||||
#define CPPHTTPLIB_LISTEN_BACKLOG 512
|
||||
// increase max URI length to handle longer prompts in query string
|
||||
#define CPPHTTPLIB_REQUEST_URI_MAX_LENGTH 32768
|
||||
// disable Nagle's algorithm
|
||||
#define CPPHTTPLIB_TCP_NODELAY true
|
||||
#include "httplib.h"
|
||||
//// increase max payload length to allow use of larger context size
|
||||
//#define CPPHTTPLIB_FORM_URL_ENCODED_PAYLOAD_MAX_LENGTH 1048576
|
||||
//// increase backlog size to avoid connection resets for >> 1 slots
|
||||
//#define CPPHTTPLIB_LISTEN_BACKLOG 512
|
||||
//// increase max URI length to handle longer prompts in query string
|
||||
//#define CPPHTTPLIB_REQUEST_URI_MAX_LENGTH 32768
|
||||
//// disable Nagle's algorithm
|
||||
//#define CPPHTTPLIB_TCP_NODELAY true
|
||||
#include <cpp-httplib/httplib.h>
|
||||
|
||||
#define DEFAULT_OAICOMPAT_MODEL "gpt-3.5-turbo-0613"
|
||||
|
||||
@@ -459,3 +459,6 @@ void print_files_info(const std::vector<raw_buffer>& files);
|
||||
|
||||
bool prompt_cache_equal(llama_context* ctx, const server_tokens& cache_tokens,
|
||||
const server_tokens& prompt_tokens, size_t start, const common_prefix& prefix);
|
||||
|
||||
std::string safe_json_to_str(const json& data);
|
||||
|
||||
|
||||
@@ -316,6 +316,7 @@ void server_slot::reset() {
|
||||
stopped_limit = false;
|
||||
stopping_word = "";
|
||||
n_past = 0;
|
||||
n_past_prompt = 0;
|
||||
n_sent_text = 0;
|
||||
|
||||
drafted.clear();
|
||||
@@ -402,7 +403,9 @@ void server_slot::release() {
|
||||
t_token_generation = (ggml_time_us() - t_start_generation) / 1e3;
|
||||
command = SLOT_COMMAND_RELEASE;
|
||||
task.reset();
|
||||
llama_decode_reset();
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
|
||||
@@ -803,6 +806,8 @@ bool server_context::launch_slot_with_task(server_slot& slot, server_task& task)
|
||||
slot.oaicompat = false;
|
||||
slot.oaicompat_model = "";
|
||||
}
|
||||
slot.params.oaicompat = task.params.oaicompat;
|
||||
slot.params.oaicompat_cmpl_id =task.params.oaicompat_cmpl_id;
|
||||
slot.params.timings_per_token = json_value(data, "timings_per_token", false);
|
||||
slot.params.stream = json_value(data, "stream", false);
|
||||
auto stream_opt = json_value(data, "stream_options", json::object());
|
||||
@@ -927,29 +932,6 @@ bool server_context::launch_slot_with_task(server_slot& slot, server_task& task)
|
||||
|
||||
// get prompt
|
||||
if (!task.infill) {
|
||||
// maybe not needed since prompt has been tokenized?
|
||||
const auto& prompt = data.find("prompt");
|
||||
if (!slot.prompt_tokens.validate(ctx)) {
|
||||
send_error(task, "Prompt contains invalid tokens", ERROR_TYPE_INVALID_REQUEST);
|
||||
return false;
|
||||
}
|
||||
if (prompt == data.end()) {
|
||||
send_error(task, "\"prompt\" must be provided", ERROR_TYPE_INVALID_REQUEST);
|
||||
return false;
|
||||
}
|
||||
|
||||
if ((prompt->is_string()) ||
|
||||
(prompt->is_array() && prompt->size() == 1 && prompt->at(0).is_string()) ||
|
||||
(prompt->is_array() && !prompt->empty() && prompt->at(0).is_number_integer())) {
|
||||
slot.prompt = *prompt;
|
||||
}
|
||||
else if (prompt->is_array() && prompt->size() == 1 && prompt->at(0).is_array()) {
|
||||
slot.prompt = prompt->at(0);
|
||||
}
|
||||
else {
|
||||
send_error(task, "\"prompt\" must be a string or an array of integers", ERROR_TYPE_INVALID_REQUEST);
|
||||
return false;
|
||||
}
|
||||
slot.prompt_tokens = std::move(task.tokens);
|
||||
}
|
||||
|
||||
@@ -1165,7 +1147,7 @@ bool server_context::launch_slot_with_task(server_slot& slot, server_task& task)
|
||||
{"id_slot", slot.id},
|
||||
{"id_task", slot.id_task},
|
||||
});
|
||||
|
||||
slot.task = std::make_unique<const server_task>(std::move(task));
|
||||
return true;
|
||||
}
|
||||
|
||||
@@ -1490,66 +1472,66 @@ bool server_context::ensure_no_mtmd(const int id_task) {
|
||||
}
|
||||
|
||||
void server_context::send_partial_response(server_slot& slot, completion_token_output tkn) {
|
||||
server_task_result res;
|
||||
res.final_result = false;
|
||||
res.id = slot.id_task;
|
||||
res.id_multi = slot.id_multi;
|
||||
res.error = false;
|
||||
res.stop = false;
|
||||
res.stream = slot.params.stream;
|
||||
res.content = tkn.text_to_send;
|
||||
res.post_sampling_probs = slot.params.post_sampling_probs;
|
||||
res.oaicompat = slot.params.oaicompat;
|
||||
res.oaicompat_model = slot.params.oaicompat_model;
|
||||
res.oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id;
|
||||
res.n_decoded = slot.n_decoded;
|
||||
res.n_prompt_tokens = slot.n_prompt_tokens;
|
||||
res.data = json{
|
||||
auto res = std::make_unique<server_task_result_cmpl_partial>();
|
||||
res->final_result = false;
|
||||
res->id = slot.id_task;
|
||||
res->id_multi = slot.id_multi;
|
||||
res->error = false;
|
||||
res->stop = false;
|
||||
res->stream = slot.params.stream;
|
||||
res->content = tkn.text_to_send;
|
||||
res->post_sampling_probs = slot.params.post_sampling_probs;
|
||||
res->oaicompat = slot.params.oaicompat;
|
||||
res->oaicompat_model = slot.params.oaicompat_model;
|
||||
res->oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id;
|
||||
res->n_decoded = slot.n_decoded;
|
||||
res->n_prompt_tokens = slot.n_prompt_tokens;
|
||||
res->data = json{
|
||||
{"content", tkn.text_to_send},
|
||||
{"stop", false},
|
||||
{"id_slot", slot.id},
|
||||
{"multimodal", false}
|
||||
};
|
||||
slot.update_chat_msg(res.oaicompat_msg_diffs);
|
||||
slot.update_chat_msg(res->oaicompat_msg_diffs);
|
||||
|
||||
// populate res.probs_output
|
||||
// populate res->probs_output
|
||||
if (slot.sparams.n_probs > 0) {
|
||||
res.probs_output = { tkn }; // copy the token probs
|
||||
res.data["completion_probabilities"] = probs_vector_to_json(ctx, res.probs_output);
|
||||
res->probs_output = { tkn }; // copy the token probs
|
||||
res->data["completion_probabilities"] = probs_vector_to_json(ctx, res->probs_output);
|
||||
}
|
||||
|
||||
if (slot.oaicompat) {
|
||||
res.data["oaicompat_token_ctr"] = slot.n_decoded;
|
||||
res.data["model"] = slot.oaicompat_model;
|
||||
res->data["oaicompat_token_ctr"] = slot.n_decoded;
|
||||
res->data["model"] = slot.oaicompat_model;
|
||||
}
|
||||
|
||||
// populate timings if this is final response or timings_per_token is enabled
|
||||
if (slot.params.timings_per_token) {
|
||||
res.timings = slot.get_timings();
|
||||
res->timings = slot.get_timings();
|
||||
}
|
||||
queue_results.send(std::move(res));
|
||||
}
|
||||
|
||||
void server_context::send_final_response(server_slot& slot) {
|
||||
server_task_result res;
|
||||
res.final_result = true;
|
||||
res.id = slot.id_task;
|
||||
res.id_multi = slot.id_multi;
|
||||
res.error = false;
|
||||
res.stop = true; // to do: set value
|
||||
res.stream = slot.params.stream;
|
||||
res.include_usage = slot.params.include_usage;
|
||||
res.content = slot.generated_text;
|
||||
res.timings = slot.get_timings();
|
||||
res.post_sampling_probs = slot.params.post_sampling_probs;
|
||||
res.oaicompat = slot.params.oaicompat;
|
||||
res.oaicompat_model = slot.params.oaicompat_model;
|
||||
res.oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id;
|
||||
res.oaicompat_msg = slot.update_chat_msg(res.oaicompat_msg_diffs);
|
||||
res.n_decoded = slot.n_decoded;
|
||||
res.n_prompt_tokens = slot.n_prompt_tokens;
|
||||
res.oaicompat_model = slot.oaicompat_model;
|
||||
res.data = json{
|
||||
auto res = std::make_unique<server_task_result_cmpl_final>();
|
||||
res->final_result = true;
|
||||
res->id = slot.id_task;
|
||||
res->id_multi = slot.id_multi;
|
||||
res->error = false;
|
||||
res->stop = true; // to do: set value
|
||||
res->stream = slot.params.stream;
|
||||
res->include_usage = slot.params.include_usage;
|
||||
res->content = slot.generated_text;
|
||||
res->timings = slot.get_timings();
|
||||
res->post_sampling_probs = slot.params.post_sampling_probs;
|
||||
res->oaicompat = slot.params.oaicompat;
|
||||
res->oaicompat_model = slot.params.oaicompat_model;
|
||||
res->oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id;
|
||||
res->oaicompat_msg = slot.update_chat_msg(res->oaicompat_msg_diffs);
|
||||
res->n_decoded = slot.n_decoded;
|
||||
res->n_prompt_tokens = slot.n_prompt_tokens;
|
||||
res->oaicompat_model = slot.oaicompat_model;
|
||||
res->data = json{
|
||||
{"content", !slot.params.stream ? slot.generated_text : ""},
|
||||
{"generated_text", slot.generated_text}, // Always include full text for finish_reason logic
|
||||
{"id_slot", slot.id},
|
||||
@@ -1569,30 +1551,29 @@ void server_context::send_final_response(server_slot& slot) {
|
||||
//{"oaicompat_chat_format", slot.params.oaicompat_chat_format},
|
||||
};
|
||||
|
||||
// populate res.probs_output
|
||||
// populate res->probs_output
|
||||
if (slot.sparams.n_probs > 0) {
|
||||
res.probs_output = std::vector<completion_token_output>(
|
||||
res->probs_output = std::vector<completion_token_output>(
|
||||
slot.generated_token_probs.begin(),
|
||||
slot.generated_token_probs.end());
|
||||
res.data["completion_probabilities"] = probs_vector_to_json(ctx, res.probs_output);
|
||||
res->data["completion_probabilities"] = probs_vector_to_json(ctx, res->probs_output);
|
||||
}
|
||||
|
||||
if (slot.oaicompat) {
|
||||
res.data["oaicompat_token_ctr"] = slot.n_decoded;
|
||||
res.data["model"] = slot.oaicompat_model;
|
||||
res->data["oaicompat_token_ctr"] = slot.n_decoded;
|
||||
res->data["model"] = slot.oaicompat_model;
|
||||
}
|
||||
|
||||
queue_results.send(std::move(res));
|
||||
}
|
||||
|
||||
void server_context::send_embedding(const server_slot& slot, const llama_batch& batch) {
|
||||
server_task_result res;
|
||||
res.id = slot.id_task;
|
||||
res.id_multi = slot.id_multi;
|
||||
res.error = false;
|
||||
res.stop = true;
|
||||
auto res = std::make_unique<server_task_result_embd>();
|
||||
res->id = slot.task->id;
|
||||
res->n_tokens = slot.prompt_tokens.size();
|
||||
res->oaicompat = slot.task->params.oaicompat;
|
||||
|
||||
const int n_embd = llama_n_embd(model);
|
||||
const int n_embd = llama_model_n_embd(model);
|
||||
|
||||
std::vector<float> embd_res(n_embd, 0.0f);
|
||||
|
||||
@@ -1601,34 +1582,31 @@ void server_context::send_embedding(const server_slot& slot, const llama_batch&
|
||||
continue;
|
||||
}
|
||||
|
||||
const float* embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]);
|
||||
if (embd == NULL) {
|
||||
const float* embd = nullptr;
|
||||
if (llama_pooling_type(slot.ctx) == LLAMA_POOLING_TYPE_NONE) {
|
||||
embd = llama_get_embeddings_ith(ctx, i);
|
||||
}
|
||||
else {
|
||||
embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]);
|
||||
}
|
||||
|
||||
if (embd == NULL) {
|
||||
LOG_ERROR("failed to get embeddings", {
|
||||
{"token", batch.token[i]},
|
||||
{"seq_id", batch.seq_id[i][0]}
|
||||
});
|
||||
|
||||
res.data = json{
|
||||
{"embedding", std::vector<float>(n_embd, 0.0f)},
|
||||
{"tokens_evaluated", slot.n_prompt_tokens},
|
||||
};
|
||||
if (embd == nullptr) {
|
||||
SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", batch.token[i], batch.seq_id[i][0]);
|
||||
|
||||
res->embedding.push_back(std::vector<float>(n_embd, 0.0f));
|
||||
continue;
|
||||
}
|
||||
|
||||
llama_embd_normalize(embd, embd_res.data(), n_embd);
|
||||
// normalize only when there is pooling
|
||||
if (llama_pooling_type(slot.ctx) != LLAMA_POOLING_TYPE_NONE) {
|
||||
common_embd_normalize(embd, embd_res.data(), n_embd, slot.task->params.embd_normalize);
|
||||
res->embedding.push_back(embd_res);
|
||||
break;
|
||||
}
|
||||
|
||||
res.data = json{
|
||||
{"embedding", embd_res},
|
||||
{"tokens_evaluated", slot.n_prompt_tokens},
|
||||
};
|
||||
res->embedding.emplace_back(embd, embd + n_embd);
|
||||
}
|
||||
|
||||
queue_results.send(res);
|
||||
queue_results.send(std::move(res));
|
||||
}
|
||||
|
||||
void server_context::request_completion(int id_task, int id_multi, json data, bool infill, bool embedding, server_tokens&& inputs) {
|
||||
@@ -1708,6 +1686,9 @@ void server_context::split_multiprompt_task(int id_multi, server_task& multiprom
|
||||
void server_context::process_single_task(server_task&& task) {
|
||||
switch (task.type) {
|
||||
case SERVER_TASK_TYPE_COMPLETION:
|
||||
case SERVER_TASK_TYPE_INFILL:
|
||||
case SERVER_TASK_TYPE_EMBEDDING:
|
||||
case SERVER_TASK_TYPE_RERANK:
|
||||
{
|
||||
const int id_slot = json_value(task.data, "id_slot", -1);
|
||||
|
||||
@@ -2471,6 +2452,10 @@ void server_context::update_slots() {
|
||||
|
||||
// keep only the common part
|
||||
// remove the non-common part from the cache
|
||||
if (slot.n_past < 0)
|
||||
{
|
||||
slot.n_past = 0;
|
||||
}
|
||||
slot.cache_tokens.keep_first(slot.n_past);
|
||||
int p0 = (int)system_tokens.size() + slot.n_past;
|
||||
p0 = system_tokens.size() + slot.cache_tokens.pos_next();
|
||||
@@ -2549,7 +2534,7 @@ void server_context::update_slots() {
|
||||
}
|
||||
|
||||
int p0 = system_tokens.size() + slot.cache_tokens.pos_next();
|
||||
llama_batch_add(batch, cur_tok, p0, { slot.id }, false);
|
||||
llama_batch_add(batch, cur_tok, p0, { slot.id }, slot.embedding);
|
||||
|
||||
slot.cache_tokens.push_back(cur_tok);
|
||||
|
||||
@@ -2663,18 +2648,31 @@ void server_context::update_slots() {
|
||||
|
||||
if (ret != 0) {
|
||||
if (n_batch == 1 || ret < 0) {
|
||||
int user_cancel = -3;
|
||||
// if you get here, it means the KV cache is full - try increasing it via the context size
|
||||
LOG_ERROR("failed to decode the batch: KV cache is full - try increasing it via the context size", {
|
||||
{"i", i},
|
||||
{"n_batch", ret},
|
||||
{"ret", ret},
|
||||
});
|
||||
if (ret == user_cancel) {
|
||||
LOG_ERROR("Decode process is cancelled by user", {
|
||||
{"i", i},
|
||||
{"n_batch", ret},
|
||||
{"ret", ret},
|
||||
});
|
||||
} else {
|
||||
LOG_ERROR("failed to decode the batch: KV cache is full - try increasing it via the context size", {
|
||||
{"i", i},
|
||||
{"n_batch", ret},
|
||||
{"ret", ret},
|
||||
});
|
||||
}
|
||||
|
||||
for (auto& slot : slots) {
|
||||
slot.state = SLOT_STATE_PROCESSING;
|
||||
slot.command = SLOT_COMMAND_NONE;
|
||||
slot.release();
|
||||
LLAMA_LOG_INFO("n_past = %d\n", (int)slot.cache_tokens.size());
|
||||
send_error(slot, "Input prompt is too big compared to KV size. Please try increasing KV size.");
|
||||
if (ret != user_cancel) {
|
||||
LLAMA_LOG_INFO("n_past = %d\n", (int)slot.cache_tokens.size());
|
||||
send_error(slot, "Input prompt is too big compared to KV size. Please try increasing KV size.");
|
||||
}
|
||||
|
||||
}
|
||||
break; // break loop of n_batch
|
||||
}
|
||||
@@ -2818,7 +2816,7 @@ json server_context::model_meta() const {
|
||||
{"vocab_type", llama_vocab_type(model)},
|
||||
{"n_vocab", llama_n_vocab(model)},
|
||||
{"n_ctx_train", llama_n_ctx_train(model)},
|
||||
{"n_embd", llama_n_embd(model)},
|
||||
{"n_embd", llama_model_n_embd(model)},
|
||||
{"n_params", llama_model_n_params(model)},
|
||||
{"size", llama_model_size(model)},
|
||||
};
|
||||
|
||||
@@ -22,43 +22,6 @@ enum slot_command {
|
||||
SLOT_COMMAND_RELEASE,
|
||||
};
|
||||
|
||||
|
||||
|
||||
|
||||
struct slot_params {
|
||||
bool stream = true;
|
||||
bool include_usage = false;
|
||||
bool cache_prompt = true; // remember the prompt to avoid reprocessing all prompt
|
||||
|
||||
int32_t n_keep = 0; // number of tokens to keep from initial prompt
|
||||
int32_t n_discard = 0; // number of tokens after n_keep that may be discarded when shifting context, 0 defaults to half
|
||||
int32_t n_predict = -1; // new tokens to predict
|
||||
|
||||
thinking_tokens think_tokens;
|
||||
|
||||
std::vector<std::string> antiprompt;
|
||||
|
||||
bool timings_per_token = false;
|
||||
bool post_sampling_probs = false;
|
||||
json input_prefix;
|
||||
json input_suffix;
|
||||
|
||||
// speculative decoding parameters
|
||||
struct {
|
||||
int n_max = 16; // max drafted tokens
|
||||
int n_min = 0; // min drafted tokens to accept
|
||||
float p_min = 0.75f; // min probability required to accept a token in the draft
|
||||
} speculative;
|
||||
|
||||
// OAI-compat fields
|
||||
oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE;
|
||||
std::string oaicompat_model;
|
||||
std::string oaicompat_cmpl_id;
|
||||
common_chat_syntax oaicompat_chat_syntax;
|
||||
|
||||
};
|
||||
|
||||
|
||||
struct server_slot {
|
||||
int id;
|
||||
int id_task = -1;
|
||||
|
||||
@@ -28,6 +28,42 @@ int server_queue::post(server_task task) {
|
||||
return task.id;
|
||||
}
|
||||
|
||||
void server_queue::cleanup_pending_task(int id_target) {
|
||||
// no need lock because this is called exclusively by post()
|
||||
auto rm_func = [id_target](const server_task& task) {
|
||||
return task.id == id_target;
|
||||
};
|
||||
queue_tasks.erase(
|
||||
std::remove_if(queue_tasks.begin(), queue_tasks.end(), rm_func),
|
||||
queue_tasks.end());
|
||||
queue_tasks_deferred.erase(
|
||||
std::remove_if(queue_tasks_deferred.begin(), queue_tasks_deferred.end(), rm_func),
|
||||
queue_tasks_deferred.end());
|
||||
}
|
||||
|
||||
// multi-task version of post()
|
||||
int server_queue::post(std::vector<server_task>&& tasks, bool front) {
|
||||
std::unique_lock<std::mutex> lock(mutex_tasks);
|
||||
for (auto& task : tasks) {
|
||||
if (task.id == -1) {
|
||||
task.id = id++;
|
||||
}
|
||||
// if this is cancel task make sure to clean up pending tasks
|
||||
if (task.type == SERVER_TASK_TYPE_CANCEL) {
|
||||
cleanup_pending_task(task.id_target);
|
||||
}
|
||||
QUE_DBG("new task, id = %d/%d, front = %d\n", task.id, (int)tasks.size(), front);
|
||||
if (front) {
|
||||
queue_tasks.push_front(std::move(task));
|
||||
}
|
||||
else {
|
||||
queue_tasks.push_back(std::move(task));
|
||||
}
|
||||
}
|
||||
condition_tasks.notify_one();
|
||||
return 0;
|
||||
}
|
||||
|
||||
void server_queue::defer(server_task&& task) {
|
||||
std::unique_lock<std::mutex> lock(mutex_tasks);
|
||||
queue_tasks_deferred.push_back(std::move(task));
|
||||
@@ -68,7 +104,7 @@ void server_queue::start_loop() {
|
||||
break;
|
||||
}
|
||||
server_task task = std::move(queue_tasks.front());
|
||||
queue_tasks.erase(queue_tasks.begin());
|
||||
queue_tasks.pop_front();
|
||||
lock.unlock();
|
||||
//LOG_VERBOSE("callback_new_task", { {"id_task", task.id} });
|
||||
callback_new_task(std::move(task));
|
||||
@@ -134,13 +170,21 @@ void server_queue::update_multitask(int id_multi, int id_sub, server_task_result
|
||||
|
||||
|
||||
void server_response::add_waiting_task_id(int id_task) {
|
||||
//LOG_VERBOSE("waiting for task id", { {"id_task", id_task} });
|
||||
QUE_DBG("waiting for task id, id = %d\n", id_task);
|
||||
SRV_DBG("add task %d to waiting list. current waiting = %d (before add)\n", id_task, (int)waiting_task_ids.size());
|
||||
|
||||
std::unique_lock<std::mutex> lock(mutex_results);
|
||||
waiting_task_ids.insert(id_task);
|
||||
}
|
||||
|
||||
void server_response::add_waiting_tasks(const std::vector<server_task>& tasks) {
|
||||
std::unique_lock<std::mutex> lock(mutex_results);
|
||||
|
||||
for (const auto& task : tasks) {
|
||||
SRV_DBG("add task %d to waiting list. current waiting = %d (before add)\n", task.id, (int)waiting_task_ids.size());
|
||||
waiting_task_ids.insert(task.id);
|
||||
}
|
||||
}
|
||||
|
||||
void server_response::remove_waiting_task_id(int id_task) {
|
||||
//LOG_VERBOSE("remove waiting for task id", { {"id_task", id_task} });
|
||||
QUE_DBG("remove waiting for task id, id = %d\n", id_task);
|
||||
@@ -153,14 +197,14 @@ server_task_result server_response::recv(int id_task) {
|
||||
while (true) {
|
||||
std::unique_lock<std::mutex> lock(mutex_results);
|
||||
condition_results.wait(lock, [&] {
|
||||
return !queue_results.empty();
|
||||
return !queue_results_legacy.empty();
|
||||
});
|
||||
|
||||
for (int i = 0; i < (int)queue_results.size(); i++) {
|
||||
if (queue_results[i].id == id_task) {
|
||||
assert(queue_results[i].id_multi == -1);
|
||||
server_task_result res = queue_results[i];
|
||||
queue_results.erase(queue_results.begin() + i);
|
||||
for (int i = 0; i < (int)queue_results_legacy.size(); i++) {
|
||||
if (queue_results_legacy[i].id == id_task) {
|
||||
assert(queue_results_legacy[i].id_multi == -1);
|
||||
server_task_result res = queue_results_legacy[i];
|
||||
queue_results_legacy.erase(queue_results_legacy.begin() + i);
|
||||
return res;
|
||||
}
|
||||
}
|
||||
@@ -169,6 +213,41 @@ server_task_result server_response::recv(int id_task) {
|
||||
// should never reach here
|
||||
}
|
||||
|
||||
// same as recv(), but have timeout in seconds
|
||||
// if timeout is reached, nullptr is returned
|
||||
server_task_result_ptr server_response::recv_with_timeout(const std::unordered_set<int>& id_tasks, int timeout) {
|
||||
while (true) {
|
||||
std::unique_lock<std::mutex> lock(mutex_results);
|
||||
|
||||
for (int i = 0; i < (int)queue_results.size(); i++) {
|
||||
if (id_tasks.find(queue_results[i]->id) != id_tasks.end()) {
|
||||
server_task_result_ptr res = std::move(queue_results[i]);
|
||||
queue_results.erase(queue_results.begin() + i);
|
||||
return res;
|
||||
}
|
||||
}
|
||||
|
||||
std::cv_status cr_res = condition_results.wait_for(lock, std::chrono::seconds(timeout));
|
||||
if (!running) {
|
||||
SRV_DBG("%s : queue result stop\n", __func__);
|
||||
std::terminate(); // we cannot return here since the caller is HTTP code
|
||||
}
|
||||
if (cr_res == std::cv_status::timeout) {
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
// should never reach here
|
||||
}
|
||||
void server_response::remove_waiting_task_ids(const std::unordered_set<int>& id_tasks) {
|
||||
std::unique_lock<std::mutex> lock(mutex_results);
|
||||
|
||||
for (const auto& id_task : id_tasks) {
|
||||
SRV_DBG("remove task %d from waiting list. current waiting = %d (before remove)\n", id_task, (int)waiting_task_ids.size());
|
||||
waiting_task_ids.erase(id_task);
|
||||
}
|
||||
}
|
||||
|
||||
void server_response::send(server_task_result result) {
|
||||
//LOG_VERBOSE("send new result", { {"id_task", result.id} });
|
||||
QUE_DBG("send new result, id = %d\n", result.id);
|
||||
@@ -184,9 +263,25 @@ void server_response::send(server_task_result result) {
|
||||
}
|
||||
|
||||
if (result.id == id_task) {
|
||||
//LOG_VERBOSE("queue_results.push_back", { {"id_task", id_task} });
|
||||
//LOG_VERBOSE("queue_results_legacy.push_back", { {"id_task", id_task} });
|
||||
QUE_DBG("queue_results.push_back, id = %d\n", id_task);
|
||||
queue_results.push_back(result);
|
||||
queue_results_legacy.push_back(std::move(result));
|
||||
condition_results.notify_all();
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Send a new result to a waiting id_task
|
||||
void server_response::send(server_task_result_ptr&& result) {
|
||||
SRV_DBG("sending result for task id = %d\n", result->id);
|
||||
|
||||
std::unique_lock<std::mutex> lock(mutex_results);
|
||||
for (const auto& id_task : waiting_task_ids) {
|
||||
if (result->id == id_task) {
|
||||
SRV_DBG("task id = %d pushed to result queue\n", result->id);
|
||||
|
||||
queue_results.emplace_back(std::move(result));
|
||||
condition_results.notify_all();
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -19,8 +19,8 @@ struct server_queue {
|
||||
bool running;
|
||||
|
||||
// queues
|
||||
std::vector<server_task> queue_tasks;
|
||||
std::vector<server_task> queue_tasks_deferred;
|
||||
std::deque<server_task> queue_tasks;
|
||||
std::deque<server_task> queue_tasks_deferred;
|
||||
|
||||
std::vector<server_task_multi> queue_multitasks;
|
||||
|
||||
@@ -36,6 +36,10 @@ struct server_queue {
|
||||
// Add a new task to the end of the queue
|
||||
int post(server_task task);
|
||||
|
||||
int post(std::vector<server_task>&& tasks, bool front = false);
|
||||
|
||||
void cleanup_pending_task(int id_target);
|
||||
|
||||
// Add a new task, but defer until one slot is available
|
||||
void defer(server_task&& task);
|
||||
|
||||
@@ -89,11 +93,14 @@ struct server_response {
|
||||
typedef std::function<void(int, int, server_task_result&)> callback_multitask_t;
|
||||
callback_multitask_t callback_update_multitask;
|
||||
|
||||
bool running = true;
|
||||
// for keeping track of all tasks waiting for the result
|
||||
std::set<int> waiting_task_ids;
|
||||
|
||||
// the main result queue
|
||||
std::vector<server_task_result> queue_results;
|
||||
// the main result queue (using ptr for polymorphism)
|
||||
std::vector<server_task_result_ptr> queue_results;
|
||||
|
||||
std::vector<server_task_result> queue_results_legacy;
|
||||
|
||||
std::mutex mutex_results;
|
||||
std::condition_variable condition_results;
|
||||
@@ -101,12 +108,20 @@ struct server_response {
|
||||
// add the id_task to the list of tasks waiting for response
|
||||
void add_waiting_task_id(int id_task);
|
||||
|
||||
void add_waiting_tasks(const std::vector<server_task>& tasks);
|
||||
|
||||
// when the request is finished, we can remove task associated with it
|
||||
void remove_waiting_task_id(int id_task);
|
||||
|
||||
void remove_waiting_task_ids(const std::unordered_set<int>& id_tasks);
|
||||
|
||||
// This function blocks the thread until there is a response for this id_task
|
||||
server_task_result recv(int id_task);
|
||||
|
||||
// same as recv(), but have timeout in seconds
|
||||
// if timeout is reached, nullptr is returned
|
||||
server_task_result_ptr recv_with_timeout(const std::unordered_set<int>& id_tasks, int timeout);
|
||||
|
||||
// Register the function to update multitask
|
||||
void on_multitask_update(callback_multitask_t callback) {
|
||||
callback_update_multitask = std::move(callback);
|
||||
@@ -114,4 +129,12 @@ struct server_response {
|
||||
|
||||
// Send a new result to a waiting id_task
|
||||
void send(server_task_result result);
|
||||
|
||||
void send(server_task_result_ptr&& result);
|
||||
|
||||
// terminate the waiting loop
|
||||
void terminate() {
|
||||
running = false;
|
||||
condition_results.notify_all();
|
||||
};
|
||||
};
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
#include "server-task.h"
|
||||
|
||||
|
||||
json result_timings::to_json() const {
|
||||
json base = {
|
||||
{"prompt_n", prompt_n},
|
||||
@@ -26,83 +25,63 @@ json result_timings::to_json() const {
|
||||
}
|
||||
|
||||
|
||||
json server_task_result::to_json_final() {
|
||||
switch (oaicompat) {
|
||||
case OAICOMPAT_TYPE_NONE:
|
||||
return to_json_non_oaicompat_final();
|
||||
case OAICOMPAT_TYPE_COMPLETION:
|
||||
return to_json_oaicompat_final();
|
||||
case OAICOMPAT_TYPE_CHAT:
|
||||
return stream ? to_json_oaicompat_chat_stream() : to_json_oaicompat_chat_final();
|
||||
case OAICOMPAT_TYPE_ANTHROPIC:
|
||||
return stream ? to_json_anthropic_stream() : to_json_anthropic_final();
|
||||
default:
|
||||
GGML_ASSERT(false && "Invalid oaicompat_type");
|
||||
}
|
||||
}
|
||||
//json server_task_result_cmpl_partial::to_json_non_oaicompat_partial() {
|
||||
// // non-OAI-compat JSON
|
||||
// json res = json{
|
||||
// {"index", index},
|
||||
// {"content", content},
|
||||
// {"tokens", tokens},
|
||||
// {"stop", false},
|
||||
// {"id_slot", id_multi},
|
||||
// {"tokens_predicted", n_decoded},
|
||||
// {"tokens_evaluated", n_prompt_tokens},
|
||||
// };
|
||||
// // populate the timings object when needed (usually for the last response or with timings_per_token enabled)
|
||||
// if (timings.prompt_n > 0) {
|
||||
// res.push_back({ "timings", timings.to_json() });
|
||||
// }
|
||||
// if (!probs_output.empty()) {
|
||||
// res["completion_probabilities"] = completion_token_output::probs_vector_to_json(probs_output, post_sampling_probs);
|
||||
// }
|
||||
// return res;
|
||||
//}
|
||||
|
||||
json server_task_result::to_json_partial() {
|
||||
switch (oaicompat) {
|
||||
case OAICOMPAT_TYPE_NONE:
|
||||
return to_json_non_oaicompat_partial();
|
||||
case OAICOMPAT_TYPE_COMPLETION:
|
||||
return to_json_oaicompat_partial();
|
||||
case OAICOMPAT_TYPE_CHAT:
|
||||
return to_json_oaicompat_chat_partial();
|
||||
case OAICOMPAT_TYPE_ANTHROPIC:
|
||||
return to_json_anthropic_partial();
|
||||
default:
|
||||
GGML_ASSERT(false && "Invalid oaicompat_type");
|
||||
}
|
||||
}
|
||||
//json server_task_result_cmpl_final::to_json_non_oaicompat_final() {
|
||||
// json res = json{
|
||||
// {"index", index},
|
||||
// {"content", stream ? "" : content}, // in stream mode, content is already in last partial chunk
|
||||
// {"tokens", stream ? std::vector<llama_token> {} : tokens},
|
||||
// {"id_slot", id_multi},
|
||||
// {"stop", true},
|
||||
// {"model", oaicompat_model},
|
||||
// {"tokens_predicted", n_decoded},
|
||||
// {"tokens_evaluated", n_prompt_tokens},
|
||||
// //{"generation_settings", default_generation_settings_for_props.to_json()},
|
||||
// {"prompt", prompt},
|
||||
// {"has_new_line", has_new_line},
|
||||
// {"truncated", truncated},
|
||||
// //{"stop_type", stop_type_to_str(STOP_TYPE_EOS)},
|
||||
// {"stopping_word", stopping_word},
|
||||
// {"tokens_cached", n_tokens_cached},
|
||||
// {"timings", timings.to_json()},
|
||||
// };
|
||||
// if (!stream && !probs_output.empty()) {
|
||||
// res["completion_probabilities"] = completion_token_output::probs_vector_to_json(probs_output, post_sampling_probs);
|
||||
// }
|
||||
// return response_fields.empty() ? res : json_get_nested_values(response_fields, res);
|
||||
//}
|
||||
|
||||
json server_task_result::to_json_non_oaicompat_partial() {
|
||||
json server_task_result_cmpl_partial::to_json_non_oaicompat_partial() {
|
||||
// non-OAI-compat JSON
|
||||
json res = json{
|
||||
{"index", index},
|
||||
{"content", content},
|
||||
{"tokens", tokens},
|
||||
{"stop", false},
|
||||
{"id_slot", id_multi},
|
||||
{"tokens_predicted", n_decoded},
|
||||
{"tokens_evaluated", n_prompt_tokens},
|
||||
};
|
||||
// populate the timings object when needed (usually for the last response or with timings_per_token enabled)
|
||||
if (timings.prompt_n > 0) {
|
||||
res.push_back({ "timings", timings.to_json() });
|
||||
}
|
||||
if (!probs_output.empty()) {
|
||||
res["completion_probabilities"] = completion_token_output::probs_vector_to_json(probs_output, post_sampling_probs);
|
||||
}
|
||||
return res;
|
||||
return data;
|
||||
}
|
||||
|
||||
json server_task_result::to_json_non_oaicompat_final() {
|
||||
json res = json{
|
||||
{"index", index},
|
||||
{"content", stream ? "" : content}, // in stream mode, content is already in last partial chunk
|
||||
{"tokens", stream ? std::vector<llama_token> {} : tokens},
|
||||
{"id_slot", id_multi},
|
||||
{"stop", true},
|
||||
{"model", oaicompat_model},
|
||||
{"tokens_predicted", n_decoded},
|
||||
{"tokens_evaluated", n_prompt_tokens},
|
||||
//{"generation_settings", default_generation_settings_for_props.to_json()},
|
||||
{"prompt", prompt},
|
||||
{"has_new_line", has_new_line},
|
||||
{"truncated", truncated},
|
||||
//{"stop_type", stop_type_to_str(STOP_TYPE_EOS)},
|
||||
{"stopping_word", stopping_word},
|
||||
{"tokens_cached", n_tokens_cached},
|
||||
{"timings", timings.to_json()},
|
||||
};
|
||||
if (!stream && !probs_output.empty()) {
|
||||
res["completion_probabilities"] = completion_token_output::probs_vector_to_json(probs_output, post_sampling_probs);
|
||||
}
|
||||
return response_fields.empty() ? res : json_get_nested_values(response_fields, res);
|
||||
json server_task_result_cmpl_final::to_json_non_oaicompat_final() {
|
||||
// non-OAI-compat JSON
|
||||
return data;
|
||||
}
|
||||
|
||||
json server_task_result::to_json_oaicompat_partial() {
|
||||
json server_task_result_cmpl_partial::to_json_oaicompat_partial() {
|
||||
std::time_t t = std::time(0);
|
||||
json logprobs = json(nullptr); // OAI default to null
|
||||
if (probs_output.size() > 0) {
|
||||
@@ -141,7 +120,7 @@ json server_task_result::to_json_oaicompat_partial() {
|
||||
return res;
|
||||
}
|
||||
|
||||
json server_task_result::to_json_oaicompat_final() {
|
||||
json server_task_result_cmpl_final::to_json_oaicompat_final() {
|
||||
std::time_t t = std::time(0);
|
||||
json logprobs = json(nullptr); // OAI default to null
|
||||
if (!stream && probs_output.size() > 0) {
|
||||
@@ -184,7 +163,7 @@ json server_task_result::to_json_oaicompat_final() {
|
||||
return res;
|
||||
}
|
||||
|
||||
json server_task_result::to_json_oaicompat_chat_partial() {
|
||||
json server_task_result_cmpl_partial::to_json_oaicompat_chat_partial() {
|
||||
bool first = n_decoded == 1;
|
||||
std::time_t t = std::time(0);
|
||||
json choices;
|
||||
@@ -239,7 +218,7 @@ json server_task_result::to_json_oaicompat_chat_partial() {
|
||||
return deltas;
|
||||
}
|
||||
|
||||
json server_task_result::to_json_oaicompat_chat_final() {
|
||||
json server_task_result_cmpl_final::to_json_oaicompat_chat_final() {
|
||||
std::string finish_reason = "length";
|
||||
common_chat_msg msg;
|
||||
if (!oaicompat_msg.empty()) {
|
||||
@@ -292,7 +271,7 @@ json server_task_result::to_json_oaicompat_chat_final() {
|
||||
return res;
|
||||
}
|
||||
|
||||
json server_task_result::to_json_oaicompat_chat_stream() {
|
||||
json server_task_result_cmpl_final::to_json_oaicompat_chat_stream() {
|
||||
std::time_t t = std::time(0);
|
||||
std::string finish_reason = "length";
|
||||
if (stop) {
|
||||
@@ -357,7 +336,7 @@ json server_task_result::to_json_oaicompat_chat_stream() {
|
||||
return deltas;
|
||||
}
|
||||
|
||||
json server_task_result::to_json_anthropic_final() {
|
||||
json server_task_result_cmpl_final::to_json_anthropic_final() {
|
||||
std::string stop_reason = "max_tokens";
|
||||
if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) {
|
||||
stop_reason = oaicompat_msg.tool_calls.empty() ? "end_turn" : "tool_use";
|
||||
@@ -416,7 +395,7 @@ json server_task_result::to_json_anthropic_final() {
|
||||
return res;
|
||||
}
|
||||
|
||||
json server_task_result::to_json_anthropic_stream() {
|
||||
json server_task_result_cmpl_final::to_json_anthropic_stream() {
|
||||
json events = json::array();
|
||||
|
||||
std::string stop_reason = "max_tokens";
|
||||
@@ -552,7 +531,7 @@ json server_task_result::to_json_anthropic_stream() {
|
||||
return events;
|
||||
}
|
||||
|
||||
json server_task_result::to_json_anthropic_partial() {
|
||||
json server_task_result_cmpl_partial::to_json_anthropic_partial() {
|
||||
json events = json::array();
|
||||
bool first = n_decoded == 1;
|
||||
static bool text_block_started = false;
|
||||
|
||||
@@ -42,13 +42,53 @@ enum oaicompat_type {
|
||||
};
|
||||
|
||||
|
||||
struct slot_params {
|
||||
bool stream = true;
|
||||
bool include_usage = false;
|
||||
bool cache_prompt = true; // remember the prompt to avoid reprocessing all prompt
|
||||
|
||||
int32_t n_keep = 0; // number of tokens to keep from initial prompt
|
||||
int32_t n_discard = 0; // number of tokens after n_keep that may be discarded when shifting context, 0 defaults to half
|
||||
int32_t n_predict = -1; // new tokens to predict
|
||||
|
||||
thinking_tokens think_tokens;
|
||||
|
||||
std::vector<std::string> antiprompt;
|
||||
|
||||
bool timings_per_token = false;
|
||||
bool post_sampling_probs = false;
|
||||
json input_prefix;
|
||||
json input_suffix;
|
||||
|
||||
// speculative decoding parameters
|
||||
struct {
|
||||
int n_max = 16; // max drafted tokens
|
||||
int n_min = 0; // min drafted tokens to accept
|
||||
float p_min = 0.75f; // min probability required to accept a token in the draft
|
||||
} speculative;
|
||||
|
||||
// OAI-compat fields
|
||||
oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE;
|
||||
std::string oaicompat_model;
|
||||
std::string oaicompat_cmpl_id;
|
||||
common_chat_syntax oaicompat_chat_syntax;
|
||||
|
||||
// Embeddings
|
||||
int32_t embd_normalize = 2; // (-1=none, 0=max absolute int16, 1=taxicab, 2=Euclidean/L2, >2=p-norm)
|
||||
};
|
||||
|
||||
struct server_task {
|
||||
int id = -1; // to be filled by server_queue
|
||||
int id_multi = -1;
|
||||
|
||||
int index = -1; // used when there are multiple prompts (batch request)
|
||||
|
||||
// used by SERVER_TASK_TYPE_CANCEL
|
||||
int id_target = -1;
|
||||
//int id_slot = -1;
|
||||
int id_slot = -1;
|
||||
|
||||
// used by SERVER_TASK_TYPE_INFERENCE
|
||||
struct slot_params params;
|
||||
server_tokens tokens;
|
||||
|
||||
server_task_type type;
|
||||
@@ -60,6 +100,18 @@ struct server_task {
|
||||
server_task() = default;
|
||||
server_task(server_task_type type) : type(type) {}
|
||||
|
||||
int32_t n_tokens() const {
|
||||
return tokens.size();
|
||||
}
|
||||
|
||||
// utility function
|
||||
static std::unordered_set<int> get_list_id(const std::vector<server_task>& tasks) {
|
||||
std::unordered_set<int> ids(tasks.size());
|
||||
for (size_t i = 0; i < tasks.size(); i++) {
|
||||
ids.insert(tasks[i].id);
|
||||
}
|
||||
return ids;
|
||||
}
|
||||
};
|
||||
|
||||
struct result_timings {
|
||||
@@ -126,40 +178,142 @@ struct server_task_result {
|
||||
|
||||
bool verbose = false;
|
||||
|
||||
virtual bool is_error() {
|
||||
// only used by server_task_result_error
|
||||
return false;
|
||||
}
|
||||
|
||||
virtual bool is_stop() {
|
||||
// only used by server_task_result_cmpl_*
|
||||
// in stream mode, final responses are considered stop
|
||||
return true;
|
||||
}
|
||||
|
||||
virtual json to_json() {
|
||||
return {};
|
||||
};
|
||||
|
||||
int get_index() {
|
||||
return index;
|
||||
}
|
||||
|
||||
bool is_stop() {
|
||||
return true; // in stream mode, final responses are considered stop
|
||||
};
|
||||
|
||||
struct server_task_result_cmpl_partial : server_task_result {
|
||||
virtual bool is_stop() override {
|
||||
return false; // in stream mode, partial responses are not considered stop
|
||||
}
|
||||
|
||||
json to_json_final();
|
||||
|
||||
json to_json_partial();
|
||||
|
||||
json to_json_non_oaicompat_partial();
|
||||
|
||||
json to_json_non_oaicompat_final();
|
||||
|
||||
json to_json_oaicompat_partial();
|
||||
|
||||
json to_json_oaicompat_final();
|
||||
json to_json_anthropic_partial();
|
||||
|
||||
json to_json_oaicompat_chat_partial();
|
||||
|
||||
json to_json_oaicompat_chat_final();
|
||||
virtual json to_json() override {
|
||||
switch (oaicompat) {
|
||||
case OAICOMPAT_TYPE_NONE:
|
||||
return to_json_non_oaicompat_partial();
|
||||
case OAICOMPAT_TYPE_COMPLETION:
|
||||
return to_json_oaicompat_partial();
|
||||
case OAICOMPAT_TYPE_CHAT:
|
||||
return to_json_oaicompat_chat_partial();
|
||||
case OAICOMPAT_TYPE_ANTHROPIC:
|
||||
return to_json_anthropic_partial();
|
||||
default:
|
||||
GGML_ASSERT(false && "Invalid oaicompat_type");
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
json to_json_oaicompat_chat_stream();
|
||||
struct server_task_result_cmpl_final : server_task_result {
|
||||
virtual bool is_stop() override {
|
||||
return true;
|
||||
}
|
||||
|
||||
json to_json_non_oaicompat_final();
|
||||
|
||||
json to_json_oaicompat_final();
|
||||
|
||||
json to_json_oaicompat_chat_final();
|
||||
|
||||
json to_json_anthropic_final();
|
||||
|
||||
json to_json_anthropic_stream();
|
||||
|
||||
json to_json_anthropic_partial();
|
||||
json to_json_oaicompat_chat_stream();
|
||||
|
||||
virtual json to_json() override {
|
||||
switch (oaicompat) {
|
||||
case OAICOMPAT_TYPE_NONE:
|
||||
return to_json_non_oaicompat_final();
|
||||
case OAICOMPAT_TYPE_COMPLETION:
|
||||
return to_json_oaicompat_final();
|
||||
case OAICOMPAT_TYPE_CHAT:
|
||||
return stream ? to_json_oaicompat_chat_stream() : to_json_oaicompat_chat_final();
|
||||
case OAICOMPAT_TYPE_ANTHROPIC:
|
||||
return stream ? to_json_anthropic_stream() : to_json_anthropic_final();
|
||||
default:
|
||||
GGML_ASSERT(false && "Invalid oaicompat_type");
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct server_task_result_error : server_task_result {
|
||||
int index = 0;
|
||||
error_type err_type = ERROR_TYPE_SERVER;
|
||||
std::string err_msg;
|
||||
|
||||
// for ERROR_TYPE_EXCEED_CONTEXT_SIZE
|
||||
int32_t n_prompt_tokens = 0;
|
||||
int32_t n_ctx = 0;
|
||||
|
||||
virtual bool is_error() override {
|
||||
return true;
|
||||
}
|
||||
|
||||
virtual json to_json() override {
|
||||
json res = format_error_response(err_msg, err_type);
|
||||
return res;
|
||||
}
|
||||
};
|
||||
|
||||
struct server_task_result_embd : server_task_result {
|
||||
int index = 0;
|
||||
std::vector<std::vector<float>> embedding;
|
||||
|
||||
int32_t n_tokens;
|
||||
|
||||
// OAI-compat fields
|
||||
oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE;
|
||||
|
||||
virtual json to_json() override {
|
||||
return oaicompat == OAICOMPAT_TYPE_EMBEDDING
|
||||
? to_json_oaicompat()
|
||||
: to_json_non_oaicompat();
|
||||
}
|
||||
|
||||
json to_json_non_oaicompat() {
|
||||
return json{
|
||||
{"index", index},
|
||||
{"embedding", embedding},
|
||||
};
|
||||
}
|
||||
|
||||
json to_json_oaicompat() {
|
||||
return json{
|
||||
{"index", index},
|
||||
{"embedding", embedding[0]},
|
||||
{"tokens_evaluated", n_tokens},
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
// using shared_ptr for polymorphism of server_task_result
|
||||
using server_task_result_ptr = std::unique_ptr<server_task_result>;
|
||||
|
||||
struct server_prompt_checkpoint {
|
||||
llama_pos pos_min;
|
||||
|
||||
@@ -9,6 +9,12 @@
|
||||
#include "sampling.h"
|
||||
#include "llama.h"
|
||||
#include "llama-vocab.h"
|
||||
#include <fstream>
|
||||
|
||||
|
||||
// mime type for sending response
|
||||
#define MIMETYPE_JSON "application/json; charset=utf-8"
|
||||
|
||||
|
||||
#ifndef NDEBUG
|
||||
// crash the server in debug mode, otherwise send an http 500 error
|
||||
@@ -48,6 +54,7 @@ struct DatabaseHandle {
|
||||
|
||||
using json = nlohmann::ordered_json;
|
||||
namespace fs = std::filesystem;
|
||||
constexpr int HTTP_POLLING_SECONDS = 1;
|
||||
|
||||
bool server_verbose = false;
|
||||
bool server_log_json = true;
|
||||
@@ -299,6 +306,117 @@ static void log_server_request(const httplib::Request & req, const httplib::Resp
|
||||
});
|
||||
}
|
||||
|
||||
// generator-like API for server responses, support pooling connection state and aggregating results
|
||||
struct server_response_reader {
|
||||
std::unordered_set<int> id_tasks;
|
||||
server_context& ctx_server;
|
||||
size_t received_count = 0;
|
||||
bool cancelled = false;
|
||||
|
||||
server_response_reader(server_context& ctx_server) : ctx_server(ctx_server) {}
|
||||
~server_response_reader() {
|
||||
stop();
|
||||
}
|
||||
|
||||
void post_tasks(std::vector<server_task>&& tasks) {
|
||||
id_tasks = server_task::get_list_id(tasks);
|
||||
ctx_server.queue_results.add_waiting_tasks(tasks);
|
||||
ctx_server.queue_tasks.post(std::move(tasks));
|
||||
}
|
||||
|
||||
bool has_next() {
|
||||
return !cancelled && received_count < id_tasks.size();
|
||||
}
|
||||
|
||||
// return nullptr if should_stop() is true before receiving a result
|
||||
// note: if one error is received, it will stop further processing and return error result
|
||||
server_task_result_ptr next(const std::function<bool()>& should_stop) {
|
||||
while (true) {
|
||||
server_task_result_ptr result = ctx_server.queue_results.recv_with_timeout(id_tasks, HTTP_POLLING_SECONDS);
|
||||
if (result == nullptr) {
|
||||
// timeout, check stop condition
|
||||
if (should_stop()) {
|
||||
SRV_DBG("%s", "stopping wait for next result due to should_stop condition\n");
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
else {
|
||||
if (result->is_error()) {
|
||||
stop(); // cancel remaining tasks
|
||||
SRV_DBG("%s", "received error result, stopping further processing\n");
|
||||
return result;
|
||||
}
|
||||
if (result->is_stop()) {
|
||||
received_count++;
|
||||
}
|
||||
return result;
|
||||
}
|
||||
}
|
||||
|
||||
// should not reach here
|
||||
}
|
||||
|
||||
struct batch_response {
|
||||
bool is_terminated = false; // if true, indicates that processing was stopped before all results were received
|
||||
std::vector<server_task_result_ptr> results;
|
||||
server_task_result_ptr error; // nullptr if no error
|
||||
};
|
||||
|
||||
batch_response wait_for_all(const std::function<bool()>& should_stop) {
|
||||
batch_response batch_res;
|
||||
batch_res.results.resize(id_tasks.size());
|
||||
while (has_next()) {
|
||||
auto res = next(should_stop);
|
||||
if (res == nullptr) {
|
||||
batch_res.is_terminated = true;
|
||||
return batch_res;
|
||||
}
|
||||
if (res->error) {
|
||||
batch_res.error = std::move(res);
|
||||
return batch_res;
|
||||
}
|
||||
const size_t idx = res->get_index();
|
||||
GGML_ASSERT(idx < batch_res.results.size() && "index out of range");
|
||||
GGML_ASSERT(batch_res.results[idx] == nullptr && "duplicate result received");
|
||||
batch_res.results[idx] = std::move(res);
|
||||
}
|
||||
return batch_res;
|
||||
}
|
||||
|
||||
void stop() {
|
||||
ctx_server.queue_results.remove_waiting_task_ids(id_tasks);
|
||||
if (has_next() && !cancelled) {
|
||||
// if tasks is not finished yet, cancel them
|
||||
cancelled = true;
|
||||
std::vector<server_task> cancel_tasks;
|
||||
cancel_tasks.reserve(id_tasks.size());
|
||||
for (const auto& id_task : id_tasks) {
|
||||
SRV_WRN("cancel task, id_task = %d\n", id_task);
|
||||
server_task task(SERVER_TASK_TYPE_CANCEL);
|
||||
task.id_target = id_task;
|
||||
ctx_server.queue_results.remove_waiting_task_id(id_task);
|
||||
cancel_tasks.push_back(std::move(task));
|
||||
}
|
||||
// push to beginning of the queue, so it has highest priority
|
||||
ctx_server.queue_tasks.post(std::move(cancel_tasks), true);
|
||||
}
|
||||
else {
|
||||
SRV_DBG("%s", "all tasks already finished, no need to cancel\n");
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
auto res_err = [](httplib::Response& res, json error_data) {
|
||||
json final_response{ {"error", error_data} };
|
||||
res.set_content(safe_json_to_str(final_response), MIMETYPE_JSON);
|
||||
res.status = json_value(error_data, "code", 500);
|
||||
};
|
||||
|
||||
auto res_ok = [](httplib::Response& res, const json& data) {
|
||||
res.set_content(data.dump(), "application/json; charset=utf-8");
|
||||
res.status = 200;
|
||||
};
|
||||
|
||||
std::function<void(int)> shutdown_handler;
|
||||
std::atomic_flag is_terminating = ATOMIC_FLAG_INIT;
|
||||
|
||||
@@ -380,18 +498,9 @@ int main(int argc, char ** argv) {
|
||||
|
||||
svr->set_logger(log_server_request);
|
||||
|
||||
auto res_error = [](httplib::Response & res, json error_data) {
|
||||
json final_response {{"error", error_data}};
|
||||
res.set_content(final_response.dump(), "application/json; charset=utf-8");
|
||||
res.status = json_value(error_data, "code", 500);
|
||||
};
|
||||
|
||||
auto res_ok = [](httplib::Response& res, const json& data) {
|
||||
res.set_content(data.dump(), "application/json; charset=utf-8");
|
||||
res.status = 200;
|
||||
};
|
||||
|
||||
svr->set_exception_handler([&res_error](const httplib::Request &, httplib::Response & res, std::exception_ptr ep) {
|
||||
svr->set_exception_handler([](const httplib::Request &, httplib::Response & res, std::exception_ptr ep) {
|
||||
std::string message;
|
||||
try {
|
||||
std::rethrow_exception(std::move(ep));
|
||||
@@ -403,14 +512,14 @@ int main(int argc, char ** argv) {
|
||||
|
||||
json formatted_error = format_error_response(message, ERROR_TYPE_SERVER);
|
||||
LOG_VERBOSE("Got exception", formatted_error);
|
||||
res_error(res, formatted_error);
|
||||
res_err(res, formatted_error);
|
||||
});
|
||||
|
||||
svr->set_error_handler([&res_error](const httplib::Request &, httplib::Response & res) {
|
||||
svr->set_error_handler([](const httplib::Request &, httplib::Response & res) {
|
||||
if (res.status == 404) {
|
||||
res_error(res, format_error_response("File Not Found", ERROR_TYPE_NOT_FOUND));
|
||||
res_err(res, format_error_response("File Not Found", ERROR_TYPE_NOT_FOUND));
|
||||
}
|
||||
// for other error codes, we skip processing here because it's already done by res_error()
|
||||
// for other error codes, we skip processing here because it's already done by res_err()
|
||||
});
|
||||
|
||||
// set timeouts and change hostname and port
|
||||
@@ -492,7 +601,7 @@ int main(int argc, char ** argv) {
|
||||
// Middlewares
|
||||
//
|
||||
|
||||
auto middleware_validate_api_key = [¶ms, &res_error](const httplib::Request & req, httplib::Response & res) {
|
||||
auto middleware_validate_api_key = [¶ms](const httplib::Request & req, httplib::Response & res) {
|
||||
static const std::unordered_set<std::string> public_endpoints = {
|
||||
"/health",
|
||||
"/v1/health",
|
||||
@@ -544,7 +653,7 @@ int main(int argc, char ** argv) {
|
||||
return false;
|
||||
};
|
||||
|
||||
auto middleware_server_state = [&res_error, &state](const httplib::Request& req, httplib::Response& res) {
|
||||
auto middleware_server_state = [&state](const httplib::Request& req, httplib::Response& res) {
|
||||
server_state current_state = state.load();
|
||||
if (current_state == SERVER_STATE_LOADING_MODEL) {
|
||||
auto tmp = string_split<std::string>(req.path, '.');
|
||||
@@ -557,7 +666,7 @@ int main(int argc, char ** argv) {
|
||||
return true;
|
||||
}
|
||||
else {
|
||||
res_error(res, format_error_response("Loading model", ERROR_TYPE_UNAVAILABLE));
|
||||
res_err(res, format_error_response("Loading model", ERROR_TYPE_UNAVAILABLE));
|
||||
}
|
||||
return false;
|
||||
}
|
||||
@@ -632,18 +741,18 @@ int main(int argc, char ** argv) {
|
||||
}
|
||||
case SERVER_STATE_LOADING_MODEL:
|
||||
{
|
||||
res_error(res, format_error_response("Loading model", ERROR_TYPE_UNAVAILABLE));
|
||||
res_err(res, format_error_response("Loading model", ERROR_TYPE_UNAVAILABLE));
|
||||
} break;
|
||||
case SERVER_STATE_ERROR:
|
||||
{
|
||||
res_error(res, format_error_response("Model failed to load", ERROR_TYPE_SERVER));
|
||||
res_err(res, format_error_response("Model failed to load", ERROR_TYPE_SERVER));
|
||||
} break;
|
||||
}
|
||||
};
|
||||
|
||||
const auto handle_slots = [&](const httplib::Request &, httplib::Response & res) {
|
||||
if (!params.endpoint_slots) {
|
||||
res_error(res, format_error_response("This server does not support slots endpoint.", ERROR_TYPE_NOT_SUPPORTED));
|
||||
res_err(res, format_error_response("This server does not support slots endpoint.", ERROR_TYPE_NOT_SUPPORTED));
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -667,7 +776,7 @@ int main(int argc, char ** argv) {
|
||||
|
||||
const auto handle_metrics = [&](const httplib::Request &, httplib::Response & res) {
|
||||
if (!params.endpoint_metrics) {
|
||||
res_error(res, format_error_response("This server does not support metrics endpoint.", ERROR_TYPE_NOT_SUPPORTED));
|
||||
res_err(res, format_error_response("This server does not support metrics endpoint.", ERROR_TYPE_NOT_SUPPORTED));
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -766,11 +875,11 @@ int main(int argc, char ** argv) {
|
||||
res.status = 200; // HTTP OK
|
||||
};
|
||||
|
||||
const auto handle_slots_save = [&ctx_server, &res_error, ¶ms](const httplib::Request & req, httplib::Response & res, int id_slot) {
|
||||
const auto handle_slots_save = [&ctx_server, ¶ms](const httplib::Request & req, httplib::Response & res, int id_slot) {
|
||||
json request_data = json::parse(req.body);
|
||||
std::string filename = request_data.at("filename");
|
||||
if (!fs_validate_filename(filename)) {
|
||||
res_error(res, format_error_response("Invalid filename", ERROR_TYPE_INVALID_REQUEST));
|
||||
res_err(res, format_error_response("Invalid filename", ERROR_TYPE_INVALID_REQUEST));
|
||||
return;
|
||||
}
|
||||
std::string filepath = params.slot_save_path + filename;
|
||||
@@ -790,17 +899,17 @@ int main(int argc, char ** argv) {
|
||||
ctx_server.queue_results.remove_waiting_task_id(id_task);
|
||||
|
||||
if (result.error) {
|
||||
res_error(res, result.data);
|
||||
res_err(res, result.data);
|
||||
} else {
|
||||
res.set_content(result.data.dump(), "application/json");
|
||||
}
|
||||
};
|
||||
|
||||
const auto handle_slots_restore = [&ctx_server, &res_error, ¶ms](const httplib::Request & req, httplib::Response & res, int id_slot) {
|
||||
const auto handle_slots_restore = [&ctx_server, ¶ms](const httplib::Request & req, httplib::Response & res, int id_slot) {
|
||||
json request_data = json::parse(req.body);
|
||||
std::string filename = request_data.at("filename");
|
||||
if (!fs_validate_filename(filename)) {
|
||||
res_error(res, format_error_response("Invalid filename", ERROR_TYPE_INVALID_REQUEST));
|
||||
res_err(res, format_error_response("Invalid filename", ERROR_TYPE_INVALID_REQUEST));
|
||||
return;
|
||||
}
|
||||
std::string filepath = params.slot_save_path + filename;
|
||||
@@ -820,13 +929,13 @@ int main(int argc, char ** argv) {
|
||||
ctx_server.queue_results.remove_waiting_task_id(id_task);
|
||||
|
||||
if (result.error) {
|
||||
res_error(res, result.data);
|
||||
res_err(res, result.data);
|
||||
} else {
|
||||
res.set_content(result.data.dump(), "application/json");
|
||||
}
|
||||
};
|
||||
|
||||
const auto handle_slots_erase = [&ctx_server, &res_error](const httplib::Request & /* req */, httplib::Response & res, int id_slot) {
|
||||
const auto handle_slots_erase = [&ctx_server](const httplib::Request & /* req */, httplib::Response & res, int id_slot) {
|
||||
server_task task;
|
||||
task.type = SERVER_TASK_TYPE_SLOT_ERASE;
|
||||
task.data = {
|
||||
@@ -840,20 +949,20 @@ int main(int argc, char ** argv) {
|
||||
ctx_server.queue_results.remove_waiting_task_id(id_task);
|
||||
|
||||
if (result.error) {
|
||||
res_error(res, result.data);
|
||||
res_err(res, result.data);
|
||||
} else {
|
||||
res.set_content(result.data.dump(), "application/json");
|
||||
}
|
||||
};
|
||||
|
||||
const auto handle_slots_action = [&res_error, &handle_slots_save, &handle_slots_restore, &handle_slots_erase](const httplib::Request & req, httplib::Response & res) {
|
||||
const auto handle_slots_action = [&handle_slots_save, &handle_slots_restore, &handle_slots_erase](const httplib::Request & req, httplib::Response & res) {
|
||||
std::string id_slot_str = req.path_params.at("id_slot");
|
||||
int id_slot;
|
||||
|
||||
try {
|
||||
id_slot = std::stoi(id_slot_str);
|
||||
} catch (const std::exception &) {
|
||||
res_error(res, format_error_response("Invalid slot ID", ERROR_TYPE_INVALID_REQUEST));
|
||||
res_err(res, format_error_response("Invalid slot ID", ERROR_TYPE_INVALID_REQUEST));
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -866,7 +975,7 @@ int main(int argc, char ** argv) {
|
||||
} else if (action == "erase") {
|
||||
handle_slots_erase(req, res, id_slot);
|
||||
} else {
|
||||
res_error(res, format_error_response("Invalid action", ERROR_TYPE_INVALID_REQUEST));
|
||||
res_err(res, format_error_response("Invalid action", ERROR_TYPE_INVALID_REQUEST));
|
||||
}
|
||||
};
|
||||
|
||||
@@ -931,146 +1040,161 @@ int main(int argc, char ** argv) {
|
||||
|
||||
// handle completion-like requests (completion, chat, infill)
|
||||
// we can optionally provide a custom format for partial results and final results
|
||||
const auto handle_completions_impl = [&ctx_server, ¶ms, &res_error, &res_ok](
|
||||
const auto handle_completions_impl = [&ctx_server, ¶ms](
|
||||
server_task_type type,
|
||||
json& data,
|
||||
const std::vector<raw_buffer>& files,
|
||||
const std::function<bool()>& is_connection_closed,
|
||||
httplib::Response& res,
|
||||
oaicompat_type oaicompat) -> void {
|
||||
GGML_ASSERT(type == SERVER_TASK_TYPE_COMPLETION);
|
||||
if (ctx_server.params.embedding) {
|
||||
res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
|
||||
GGML_ASSERT(type == SERVER_TASK_TYPE_COMPLETION || type == SERVER_TASK_TYPE_INFILL);
|
||||
|
||||
const auto completion_id = gen_chatcmplid();
|
||||
// need to store the reader as a pointer, so that it won't be destroyed when the handle returns
|
||||
// use shared_ptr as it's shared between the chunked_content_provider() and on_complete()
|
||||
const auto rd = std::make_shared<server_response_reader>(ctx_server);
|
||||
|
||||
try {
|
||||
std::vector<server_task> tasks;
|
||||
|
||||
const auto& prompt = data.at("prompt");
|
||||
|
||||
// process prompt
|
||||
std::vector<server_tokens> inputs;
|
||||
|
||||
if (oaicompat && ctx_server.mctx != nullptr) {
|
||||
// This is the case used by OAI compatible chat path with MTMD. TODO It can be moved to the path below.
|
||||
inputs.push_back(process_mtmd_prompt(ctx_server.mctx, prompt.get<std::string>(), files));
|
||||
}
|
||||
else {
|
||||
// Everything else, including multimodal completions.
|
||||
inputs = tokenize_input_prompts(llama_get_vocab(ctx_server.ctx), ctx_server.mctx, prompt, true, true);
|
||||
}
|
||||
tasks.reserve(inputs.size());
|
||||
for (size_t i = 0; i < inputs.size(); i++) {
|
||||
server_task task = server_task(type);
|
||||
|
||||
task.id = ctx_server.queue_tasks.get_new_id();
|
||||
task.index = i;
|
||||
|
||||
task.tokens = std::move(inputs[i]);
|
||||
task.data = data;
|
||||
//task.params = server_task::params_from_json_cmpl(
|
||||
// ctx_server.ctx,
|
||||
// ctx_server.params,
|
||||
// data);
|
||||
task.id_slot = json_value(data, "id_slot", -1);
|
||||
|
||||
// OAI-compat
|
||||
task.params.oaicompat = oaicompat;
|
||||
task.params.oaicompat_cmpl_id = completion_id;
|
||||
tasks.push_back(std::move(task));
|
||||
}
|
||||
|
||||
rd->post_tasks(std::move(tasks));
|
||||
}
|
||||
catch (const std::exception& e) {
|
||||
res_err(res, format_error_response(e.what(), ERROR_TYPE_INVALID_REQUEST));
|
||||
return;
|
||||
}
|
||||
|
||||
const auto& prompt = data.at("prompt");
|
||||
|
||||
// process prompt
|
||||
std::vector<server_tokens> inputs;
|
||||
|
||||
if (oaicompat && ctx_server.mctx != nullptr) {
|
||||
// This is the case used by OAI compatible chat path with MTMD. TODO It can be moved to the path below.
|
||||
#ifndef NDEBUG
|
||||
print_files_info(files);
|
||||
#endif // !NDEBUG
|
||||
inputs.push_back(process_mtmd_prompt(ctx_server.mctx, prompt.get<std::string>(), files));
|
||||
}
|
||||
else {
|
||||
// Everything else, including multimodal completions.
|
||||
inputs = tokenize_input_prompts(llama_get_vocab(ctx_server.ctx), ctx_server.mctx, prompt, true, true);
|
||||
}
|
||||
const auto completion_id = gen_chatcmplid();
|
||||
const int id_task = ctx_server.queue_tasks.get_new_id();
|
||||
|
||||
ctx_server.queue_results.add_waiting_task_id(id_task);
|
||||
ctx_server.request_completion(id_task, -1, data, false, false, std::move(inputs[0]));
|
||||
bool stream = json_value(data, "stream", false);
|
||||
if (!stream) {
|
||||
server_task_result result = ctx_server.queue_results.recv(id_task);
|
||||
result.oaicompat = oaicompat;
|
||||
result.oaicompat_cmpl_id = completion_id;
|
||||
json result_oai;
|
||||
if (oaicompat) {
|
||||
if (result.final_result) {
|
||||
result_oai = result.to_json_final();
|
||||
}
|
||||
else {
|
||||
result_oai = result.to_json_partial();
|
||||
}
|
||||
// non-stream, wait for the results
|
||||
auto all_results = rd->wait_for_all(is_connection_closed);
|
||||
if (all_results.is_terminated) {
|
||||
llama_decode_stop(); // send a signal to stop decode process
|
||||
return; // connection is closed
|
||||
}
|
||||
else if (all_results.error) {
|
||||
res_err(res, all_results.error->to_json());
|
||||
return;
|
||||
}
|
||||
else {
|
||||
// legacy completions
|
||||
result_oai = result.data;
|
||||
}
|
||||
if (!result.error && result.stop) {
|
||||
res.set_content(result_oai.dump(-1, ' ', false, json::error_handler_t::replace), "application/json; charset=utf-8");
|
||||
}
|
||||
else {
|
||||
res_error(res, result_oai);
|
||||
}
|
||||
ctx_server.queue_results.remove_waiting_task_id(id_task);
|
||||
json arr = json::array();
|
||||
for (auto& res : all_results.results) {
|
||||
GGML_ASSERT(dynamic_cast<server_task_result_cmpl_final*>(res.get()) != nullptr);
|
||||
arr.push_back(res->to_json());
|
||||
}
|
||||
// if single request, return single object instead of array
|
||||
res_ok(res, arr.size() == 1 ? arr[0] : arr);
|
||||
}
|
||||
}
|
||||
else {
|
||||
const auto chunked_content_provider = [id_task, &ctx_server, completion_id, oaicompat, send_done = params.send_done](size_t, httplib::DataSink& sink) {
|
||||
bool successful_completion = false;
|
||||
const auto sse = [oaicompat, &sink](const json &res) {
|
||||
if (oaicompat == OAICOMPAT_TYPE_ANTHROPIC) {
|
||||
return server_sent_anthropic_event(sink, res);
|
||||
} else {
|
||||
return server_sent_event(sink, res);
|
||||
}
|
||||
};
|
||||
while (true) {
|
||||
server_task_result result = ctx_server.queue_results.recv(id_task);
|
||||
if (!result.error) {
|
||||
result.oaicompat = oaicompat;
|
||||
result.oaicompat_cmpl_id = completion_id;
|
||||
json res_json;
|
||||
if (oaicompat) {
|
||||
if (result.final_result) {
|
||||
res_json = result.to_json_final();
|
||||
}
|
||||
else {
|
||||
res_json = result.to_json_partial();
|
||||
}
|
||||
}
|
||||
else {
|
||||
// legacy completions
|
||||
res_json = result.data;
|
||||
}
|
||||
if (res_json.is_array()) {
|
||||
// chat completions and oai completions
|
||||
for (const auto& res : res_json) {
|
||||
if (!sse(res)) {
|
||||
// sending failed (HTTP connection closed), cancel the generation
|
||||
ctx_server.queue_results.remove_waiting_task_id(id_task);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
if (result.stop) {
|
||||
successful_completion = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
else {
|
||||
// legacy completions
|
||||
if (!sse(res_json)) {
|
||||
ctx_server.queue_results.remove_waiting_task_id(id_task);
|
||||
return false;
|
||||
}
|
||||
if (result.stop) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
else {
|
||||
if (!sse(result.data)) {
|
||||
ctx_server.queue_results.remove_waiting_task_id(id_task);
|
||||
return false;
|
||||
}
|
||||
break;
|
||||
// in streaming mode, the first error must be treated as non-stream response
|
||||
// this is to match the OAI API behavior
|
||||
// ref: https://github.com/ggml-org/llama.cpp/pull/16486#discussion_r2419657309
|
||||
server_task_result_ptr first_result = rd->next(is_connection_closed);
|
||||
if (first_result == nullptr) {
|
||||
llama_decode_stop(); // send a signal to stop decode process
|
||||
return; // connection is closed
|
||||
}
|
||||
else if (first_result->is_error()) {
|
||||
res_err(res, first_result->to_json());
|
||||
return;
|
||||
}
|
||||
else {
|
||||
GGML_ASSERT(
|
||||
dynamic_cast<server_task_result_cmpl_partial*>(first_result.get()) != nullptr
|
||||
|| dynamic_cast<server_task_result_cmpl_final*>(first_result.get()) != nullptr
|
||||
);
|
||||
}
|
||||
// next responses are streamed
|
||||
json first_result_json = first_result->to_json();
|
||||
const auto chunked_content_provider = [first_result_json, rd, oaicompat](size_t, httplib::DataSink& sink) mutable -> bool {
|
||||
// flush the first result as it's not an error
|
||||
if (!first_result_json.empty()) {
|
||||
if (!server_sent_event(sink, first_result_json)) {
|
||||
sink.done();
|
||||
return false; // sending failed, go to on_complete()
|
||||
}
|
||||
first_result_json.clear(); // mark as sent
|
||||
}
|
||||
bool ok = true;
|
||||
if (successful_completion && oaicompat != OAICOMPAT_TYPE_ANTHROPIC && oaicompat != OAICOMPAT_TYPE_NONE) {
|
||||
static const std::string done_message = "data: [DONE]\n\n";
|
||||
LOG_VERBOSE("data stream", { {"to_send", done_message} });
|
||||
if (!sink.write(done_message.c_str(), done_message.size())) {
|
||||
// If writing [DONE] fails, the stream is likely already problematic.
|
||||
ok = false;
|
||||
}
|
||||
|
||||
// receive subsequent results
|
||||
auto result = rd->next([&sink] { return !sink.is_writable(); });
|
||||
if (result == nullptr) {
|
||||
sink.done();
|
||||
return false; // connection is closed, go to on_complete()
|
||||
}
|
||||
sink.done();
|
||||
ctx_server.queue_results.remove_waiting_task_id(id_task);
|
||||
return ok;
|
||||
|
||||
// send the results
|
||||
json res_json = result->to_json();
|
||||
bool ok = false;
|
||||
if (result->is_error()) {
|
||||
ok = server_sent_event(sink, json{ { "error", result->to_json() } });
|
||||
sink.done();
|
||||
return false; // go to on_complete()
|
||||
}
|
||||
else {
|
||||
GGML_ASSERT(
|
||||
dynamic_cast<server_task_result_cmpl_partial*>(result.get()) != nullptr
|
||||
|| dynamic_cast<server_task_result_cmpl_final*>(result.get()) != nullptr
|
||||
);
|
||||
ok = server_sent_event(sink, res_json);
|
||||
}
|
||||
|
||||
if (!ok) {
|
||||
sink.done();
|
||||
return false; // sending failed, go to on_complete()
|
||||
}
|
||||
|
||||
// check if there is more data
|
||||
if (!rd->has_next()) {
|
||||
if (oaicompat != OAICOMPAT_TYPE_NONE) {
|
||||
static const std::string ev_done = "data: [DONE]\n\n";
|
||||
sink.write(ev_done.data(), ev_done.size());
|
||||
}
|
||||
sink.done();
|
||||
return false; // no more data, go to on_complete()
|
||||
}
|
||||
|
||||
// has next data, continue
|
||||
return true;
|
||||
};
|
||||
|
||||
auto on_complete = [id_task, &ctx_server](bool) {
|
||||
// cancel request
|
||||
ctx_server.request_cancel(id_task);
|
||||
ctx_server.queue_results.remove_waiting_task_id(id_task);
|
||||
auto on_complete = [rd](bool) {
|
||||
rd->stop();
|
||||
};
|
||||
|
||||
res.set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete);
|
||||
}
|
||||
};
|
||||
@@ -1082,6 +1206,7 @@ int main(int argc, char ** argv) {
|
||||
SERVER_TASK_TYPE_COMPLETION,
|
||||
data,
|
||||
files,
|
||||
req.is_connection_closed,
|
||||
res,
|
||||
OAICOMPAT_TYPE_NONE);
|
||||
};
|
||||
@@ -1094,6 +1219,7 @@ int main(int argc, char ** argv) {
|
||||
SERVER_TASK_TYPE_COMPLETION,
|
||||
data,
|
||||
files,
|
||||
req.is_connection_closed,
|
||||
res,
|
||||
OAICOMPAT_TYPE_COMPLETION);
|
||||
};
|
||||
@@ -1117,7 +1243,7 @@ int main(int argc, char ** argv) {
|
||||
|
||||
|
||||
|
||||
const auto handle_chat_completions = [&ctx_server, ¶ms, &handle_completions_impl, &res_error](const httplib::Request & req, httplib::Response & res) {
|
||||
const auto handle_chat_completions = [&ctx_server, ¶ms, &handle_completions_impl](const httplib::Request & req, httplib::Response & res) {
|
||||
auto body = json::parse(req.body);
|
||||
std::vector<raw_buffer> files;
|
||||
json data = oaicompat_chat_params_parse(ctx_server.model, body, ctx_server.oai_parser_opt, files);
|
||||
@@ -1125,6 +1251,7 @@ int main(int argc, char ** argv) {
|
||||
SERVER_TASK_TYPE_COMPLETION,
|
||||
data,
|
||||
files,
|
||||
req.is_connection_closed,
|
||||
res,
|
||||
OAICOMPAT_TYPE_CHAT);
|
||||
};
|
||||
@@ -1141,11 +1268,12 @@ int main(int argc, char ** argv) {
|
||||
SERVER_TASK_TYPE_COMPLETION,
|
||||
body_parsed,
|
||||
files,
|
||||
req.is_connection_closed,
|
||||
res,
|
||||
OAICOMPAT_TYPE_ANTHROPIC);
|
||||
};
|
||||
|
||||
const auto handle_anthropic_count_tokens = [&ctx_server, &handle_completions_impl, &res_ok](const httplib::Request & req, httplib::Response & res) {
|
||||
const auto handle_anthropic_count_tokens = [&ctx_server, &handle_completions_impl](const httplib::Request & req, httplib::Response & res) {
|
||||
std::vector<raw_buffer> files;
|
||||
json body = json::parse(req.body);
|
||||
|
||||
@@ -1164,14 +1292,14 @@ int main(int argc, char ** argv) {
|
||||
};
|
||||
|
||||
// same with handle_chat_completions, but without inference part
|
||||
const auto handle_apply_template = [&ctx_server, ¶ms, &res_ok](const httplib::Request& req, httplib::Response& res) {
|
||||
const auto handle_apply_template = [&ctx_server, ¶ms](const httplib::Request& req, httplib::Response& res) {
|
||||
auto body = json::parse(req.body);
|
||||
std::vector<raw_buffer> files; // dummy, unused
|
||||
json data = oaicompat_chat_params_parse(ctx_server.model, body,ctx_server.oai_parser_opt, files);
|
||||
res_ok(res, { { "prompt", std::move(data.at("prompt")) } });
|
||||
};
|
||||
|
||||
const auto handle_infill = [&ctx_server, &res_error, &handle_completions_impl](const httplib::Request & req, httplib::Response & res) {
|
||||
const auto handle_infill = [&ctx_server, &handle_completions_impl](const httplib::Request & req, httplib::Response & res) {
|
||||
json data = json::parse(req.body);
|
||||
const int id_task = ctx_server.queue_tasks.get_new_id();
|
||||
server_tokens token; // dummy tokens
|
||||
@@ -1182,6 +1310,7 @@ int main(int argc, char ** argv) {
|
||||
SERVER_TASK_TYPE_INFILL,
|
||||
data,
|
||||
files,
|
||||
req.is_connection_closed,
|
||||
res,
|
||||
OAICOMPAT_TYPE_NONE); // infill is not OAI compatible
|
||||
};
|
||||
@@ -1211,58 +1340,119 @@ int main(int argc, char ** argv) {
|
||||
return res.set_content(data.dump(), "application/json; charset=utf-8");
|
||||
};
|
||||
|
||||
|
||||
const auto handle_embeddings = [&ctx_server, &res_error](const httplib::Request & req, httplib::Response & res) {
|
||||
const json body = json::parse(req.body);
|
||||
bool is_openai = false;
|
||||
|
||||
// an input prompt can be a string or a list of tokens (integer)
|
||||
json prompt;
|
||||
if (body.count("input") != 0) {
|
||||
is_openai = true;
|
||||
prompt = body.at("input");
|
||||
} else if (body.count("content") != 0) {
|
||||
// with "content", we only support single prompt
|
||||
prompt = std::vector<std::string>{body.at("content")};
|
||||
} else {
|
||||
res_error(res, format_error_response("\"input\" or \"content\" must be provided", ERROR_TYPE_INVALID_REQUEST));
|
||||
const auto handle_embeddings_impl = [&ctx_server](const httplib::Request& req, httplib::Response& res, oaicompat_type oaicompat) {
|
||||
if (!ctx_server.params.embedding) {
|
||||
res_err(res, format_error_response("This server does not support embeddings. Start it with `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
|
||||
return;
|
||||
}
|
||||
|
||||
// create and queue the task
|
||||
json responses;
|
||||
{
|
||||
const int id_task = ctx_server.queue_tasks.get_new_id();
|
||||
ctx_server.queue_results.add_waiting_task_id(id_task);
|
||||
std::vector<server_tokens> inputs;
|
||||
inputs = tokenize_input_prompts(llama_get_vocab(ctx_server.ctx), ctx_server.mctx, prompt, true, true);
|
||||
ctx_server.request_completion(id_task, -1, {{"prompt", prompt}}, false, true, std::move(inputs[0]));
|
||||
if (oaicompat != OAICOMPAT_TYPE_NONE && llama_pooling_type(ctx_server.ctx) == LLAMA_POOLING_TYPE_NONE) {
|
||||
res_err(res, format_error_response("Pooling type 'none' is not OAI compatible. Please use a different pooling type", ERROR_TYPE_INVALID_REQUEST));
|
||||
return;
|
||||
}
|
||||
|
||||
// get the result
|
||||
server_task_result result = ctx_server.queue_results.recv(id_task);
|
||||
ctx_server.queue_results.remove_waiting_task_id(id_task);
|
||||
if (!result.error) {
|
||||
if (result.data.count("results")) {
|
||||
// result for multi-task
|
||||
responses = result.data.at("results");
|
||||
} else {
|
||||
// result for single task
|
||||
responses = std::vector<json>{ result.data };
|
||||
}
|
||||
} else {
|
||||
// error received, ignore everything else
|
||||
res_error(res, result.data);
|
||||
const json body = json::parse(req.body);
|
||||
|
||||
// for the shape of input/content, see tokenize_input_prompts()
|
||||
json prompt;
|
||||
if (body.count("input") != 0) {
|
||||
prompt = body.at("input");
|
||||
}
|
||||
else if (body.contains("content")) {
|
||||
oaicompat = OAICOMPAT_TYPE_NONE; // "content" field is not OAI compatible
|
||||
prompt = body.at("content");
|
||||
}
|
||||
else {
|
||||
res_err(res, format_error_response("\"input\" or \"content\" must be provided", ERROR_TYPE_INVALID_REQUEST));
|
||||
return;
|
||||
}
|
||||
|
||||
bool use_base64 = false;
|
||||
if (body.count("encoding_format") != 0) {
|
||||
const std::string& format = body.at("encoding_format");
|
||||
if (format == "base64") {
|
||||
use_base64 = true;
|
||||
}
|
||||
else if (format != "float") {
|
||||
res_err(res, format_error_response("The format to return the embeddings in. Can be either float or base64", ERROR_TYPE_INVALID_REQUEST));
|
||||
return;
|
||||
}
|
||||
}
|
||||
auto vocab = llama_get_vocab(ctx_server.ctx);
|
||||
auto tokenized_prompts = tokenize_input_prompts(vocab, ctx_server.mctx, prompt, true, true);
|
||||
for (const auto& tokens : tokenized_prompts) {
|
||||
// this check is necessary for models that do not add BOS token to the input
|
||||
if (tokens.empty()) {
|
||||
res_err(res, format_error_response("Input content cannot be empty", ERROR_TYPE_INVALID_REQUEST));
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
int embd_normalize = 2; // default to Euclidean/L2 norm
|
||||
if (body.count("embd_normalize") != 0) {
|
||||
embd_normalize = body.at("embd_normalize");
|
||||
if (llama_pooling_type(ctx_server.ctx) == LLAMA_POOLING_TYPE_NONE) {
|
||||
SRV_DBG("embd_normalize is not supported by pooling type %d, ignoring it\n", llama_pooling_type(ctx_server.ctx));
|
||||
}
|
||||
}
|
||||
|
||||
// create and queue the task
|
||||
json responses = json::array();
|
||||
server_response_reader rd(ctx_server);
|
||||
{
|
||||
std::vector<server_task> tasks;
|
||||
for (size_t i = 0; i < tokenized_prompts.size(); i++) {
|
||||
server_task task = server_task(SERVER_TASK_TYPE_EMBEDDING);
|
||||
|
||||
task.id = ctx_server.queue_tasks.get_new_id();
|
||||
task.index = i;
|
||||
task.tokens = std::move(tokenized_prompts[i]);
|
||||
|
||||
// OAI-compat
|
||||
task.params.oaicompat = oaicompat;
|
||||
task.params.embd_normalize = embd_normalize;
|
||||
task.embedding = true; // probably not needed
|
||||
tasks.push_back(std::move(task));
|
||||
}
|
||||
rd.post_tasks(std::move(tasks));
|
||||
}
|
||||
|
||||
// wait for the results
|
||||
auto all_results = rd.wait_for_all(req.is_connection_closed);
|
||||
|
||||
// collect results
|
||||
if (all_results.is_terminated) {
|
||||
llama_decode_stop();
|
||||
return; // connection is closed
|
||||
}
|
||||
else if (all_results.error) {
|
||||
res_err(res, all_results.error->to_json());
|
||||
return;
|
||||
}
|
||||
else {
|
||||
for (auto& res : all_results.results) {
|
||||
GGML_ASSERT(dynamic_cast<server_task_result_embd*>(res.get()) != nullptr);
|
||||
responses.push_back(res->to_json());
|
||||
}
|
||||
}
|
||||
|
||||
// write JSON response
|
||||
json root = is_openai
|
||||
? format_embeddings_response_oaicompat(body, responses, false)
|
||||
: responses[0];
|
||||
return res.set_content(root.dump(), "application/json; charset=utf-8");
|
||||
json root = oaicompat == OAICOMPAT_TYPE_EMBEDDING
|
||||
? format_embeddings_response_oaicompat(body, responses, use_base64)
|
||||
: json(responses);
|
||||
res_ok(res, root);
|
||||
|
||||
};
|
||||
|
||||
const auto handle_embeddings = [&handle_embeddings_impl](const httplib::Request& req, httplib::Response& res) {
|
||||
handle_embeddings_impl(req, res, OAICOMPAT_TYPE_NONE);
|
||||
};
|
||||
|
||||
const auto handle_embeddings_oai = [&handle_embeddings_impl](const httplib::Request& req, httplib::Response& res) {
|
||||
handle_embeddings_impl(req, res, OAICOMPAT_TYPE_EMBEDDING);
|
||||
};
|
||||
|
||||
|
||||
const auto handle_lora_adapters_list = [&](const httplib::Request & req, httplib::Response & res) {
|
||||
json result = json::array();
|
||||
for (size_t i = 0; i < ctx_server.lora_adapters.size(); ++i) {
|
||||
@@ -1277,6 +1467,7 @@ int main(int argc, char ** argv) {
|
||||
res.status = 200; // HTTP OK
|
||||
};
|
||||
|
||||
|
||||
const auto handle_lora_adapters_apply = [&](const httplib::Request & req, httplib::Response & res) {
|
||||
const std::vector<json> body = json::parse(req.body);
|
||||
int max_idx = ctx_server.lora_adapters.size();
|
||||
@@ -1718,7 +1909,7 @@ int main(int argc, char ** argv) {
|
||||
svr->Post("/infill", handle_infill);
|
||||
svr->Post("/embedding", handle_embeddings); // legacy
|
||||
svr->Post("/embeddings", handle_embeddings);
|
||||
svr->Post("/v1/embeddings", handle_embeddings);
|
||||
svr->Post("/v1/embeddings", handle_embeddings_oai);
|
||||
svr->Post("/tokenize", handle_tokenize);
|
||||
svr->Post("/detokenize", handle_detokenize);
|
||||
svr->Post("/apply-template", handle_apply_template);
|
||||
|
||||
Reference in New Issue
Block a user