mirror of https://github.com/ikawrakow/ik_llama.cpp.git (synced 2026-02-25 23:54:10 +00:00)
server: stop processing the prompt when client disconnects (#1134)
implement generator-based API for task results
Update httplib.h to 0.27.0
Fix embedding error
Stop prompt processing when disconnected

Co-authored-by: firecoperana <firecoperana>
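The change in one picture: task results become heap-allocated, typed objects that flow through the result queue, and the HTTP layer consumes them generator-style, checking the connection between chunks so a disconnected client stops prompt processing instead of burning decode time. Below is a minimal sketch of that consumption pattern; `result_queue`, `task_result`, `is_connection_closed`, and `write_chunk` are illustrative names, not the server's actual API:

```cpp
#include <condition_variable>
#include <deque>
#include <functional>
#include <memory>
#include <mutex>
#include <string>

// Illustrative result type: one streamed chunk, plus a flag that marks
// the final chunk of a generation.
struct task_result {
    std::string data;
    bool        is_final = false;
};

// Minimal thread-safe queue standing in for the server's result queue.
class result_queue {
    std::mutex                               mtx;
    std::condition_variable                  cv;
    std::deque<std::unique_ptr<task_result>> results;
public:
    void send(std::unique_ptr<task_result> res) {
        {
            std::lock_guard<std::mutex> lock(mtx);
            results.push_back(std::move(res));
        }
        cv.notify_one();
    }
    // Generator-style pull: blocks until the next chunk is available.
    std::unique_ptr<task_result> next() {
        std::unique_lock<std::mutex> lock(mtx);
        cv.wait(lock, [&] { return !results.empty(); });
        auto res = std::move(results.front());
        results.pop_front();
        return res;
    }
};

// The HTTP handler drains the queue chunk by chunk; once the client has
// gone away it stops pulling, which lets the server cancel the task
// instead of decoding the rest of the prompt for nobody.
void stream_results(result_queue& q,
                    const std::function<bool()>& is_connection_closed,
                    const std::function<void(const std::string&)>& write_chunk) {
    while (!is_connection_closed()) {
        auto res = q.next();
        write_chunk(res->data);
        if (res->is_final) {
            break;
        }
    }
}
```

The real server also threads the disconnect signal down into the decode loop, which is where the -3 cancellation return code further below comes from.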
@@ -316,6 +316,7 @@ void server_slot::reset() {
     stopped_limit = false;
     stopping_word = "";
     n_past = 0;
+    n_past_prompt = 0;
     n_sent_text = 0;
 
     drafted.clear();
@@ -402,7 +403,9 @@ void server_slot::release() {
         t_token_generation = (ggml_time_us() - t_start_generation) / 1e3;
         command = SLOT_COMMAND_RELEASE;
+        task.reset();
+        llama_decode_reset();
     }
 
 }
 
 
@@ -803,6 +806,8 @@ bool server_context::launch_slot_with_task(server_slot& slot, server_task& task)
         slot.oaicompat = false;
         slot.oaicompat_model = "";
     }
+    slot.params.oaicompat = task.params.oaicompat;
+    slot.params.oaicompat_cmpl_id = task.params.oaicompat_cmpl_id;
     slot.params.timings_per_token = json_value(data, "timings_per_token", false);
     slot.params.stream = json_value(data, "stream", false);
     auto stream_opt = json_value(data, "stream_options", json::object());
@@ -927,29 +932,6 @@ bool server_context::launch_slot_with_task(server_slot& slot, server_task& task)
-
-    // get prompt
-    if (!task.infill) {
-        // maybe not needed since prompt has been tokenized?
-        const auto& prompt = data.find("prompt");
     if (!slot.prompt_tokens.validate(ctx)) {
         send_error(task, "Prompt contains invalid tokens", ERROR_TYPE_INVALID_REQUEST);
         return false;
     }
-        if (prompt == data.end()) {
-            send_error(task, "\"prompt\" must be provided", ERROR_TYPE_INVALID_REQUEST);
-            return false;
-        }
-
-        if ((prompt->is_string()) ||
-            (prompt->is_array() && prompt->size() == 1 && prompt->at(0).is_string()) ||
-            (prompt->is_array() && !prompt->empty() && prompt->at(0).is_number_integer())) {
-            slot.prompt = *prompt;
-        }
-        else if (prompt->is_array() && prompt->size() == 1 && prompt->at(0).is_array()) {
-            slot.prompt = prompt->at(0);
-        }
-        else {
-            send_error(task, "\"prompt\" must be a string or an array of integers", ERROR_TYPE_INVALID_REQUEST);
-            return false;
-        }
     slot.prompt_tokens = std::move(task.tokens);
-    }
 
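This hunk drops the JSON prompt parsing from `launch_slot_with_task`: the prompt now arrives pre-tokenized on the task, so the slot only validates it. A stand-alone sketch of what such a validation amounts to; the real `server_tokens::validate(ctx)` checks against the model bound to the context, while `validate_tokens` and its parameters here are hypothetical:

```cpp
#include <cstdint>
#include <vector>

using llama_token = int32_t;

// Simplified stand-in for server_tokens::validate(): every token id must
// lie inside the model's vocabulary, otherwise the request is rejected
// before any prompt processing starts.
bool validate_tokens(const std::vector<llama_token>& tokens, int32_t n_vocab) {
    for (const llama_token t : tokens) {
        if (t < 0 || t >= n_vocab) {
            return false;
        }
    }
    return true;
}
```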
@@ -1165,7 +1147,7 @@ bool server_context::launch_slot_with_task(server_slot& slot, server_task& task)
         {"id_slot", slot.id},
         {"id_task", slot.id_task},
     });
 
+    slot.task = std::make_unique<const server_task>(std::move(task));
     return true;
 }
 
@@ -1490,66 +1472,66 @@ bool server_context::ensure_no_mtmd(const int id_task) {
 }
 
 void server_context::send_partial_response(server_slot& slot, completion_token_output tkn) {
-    server_task_result res;
-    res.final_result = false;
-    res.id = slot.id_task;
-    res.id_multi = slot.id_multi;
-    res.error = false;
-    res.stop = false;
-    res.stream = slot.params.stream;
-    res.content = tkn.text_to_send;
-    res.post_sampling_probs = slot.params.post_sampling_probs;
-    res.oaicompat = slot.params.oaicompat;
-    res.oaicompat_model = slot.params.oaicompat_model;
-    res.oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id;
-    res.n_decoded = slot.n_decoded;
-    res.n_prompt_tokens = slot.n_prompt_tokens;
-    res.data = json{
+    auto res = std::make_unique<server_task_result_cmpl_partial>();
+    res->final_result = false;
+    res->id = slot.id_task;
+    res->id_multi = slot.id_multi;
+    res->error = false;
+    res->stop = false;
+    res->stream = slot.params.stream;
+    res->content = tkn.text_to_send;
+    res->post_sampling_probs = slot.params.post_sampling_probs;
+    res->oaicompat = slot.params.oaicompat;
+    res->oaicompat_model = slot.params.oaicompat_model;
+    res->oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id;
+    res->n_decoded = slot.n_decoded;
+    res->n_prompt_tokens = slot.n_prompt_tokens;
+    res->data = json{
         {"content", tkn.text_to_send},
         {"stop", false},
         {"id_slot", slot.id},
         {"multimodal", false}
     };
-    slot.update_chat_msg(res.oaicompat_msg_diffs);
+    slot.update_chat_msg(res->oaicompat_msg_diffs);
 
-    // populate res.probs_output
+    // populate res->probs_output
     if (slot.sparams.n_probs > 0) {
-        res.probs_output = { tkn }; // copy the token probs
-        res.data["completion_probabilities"] = probs_vector_to_json(ctx, res.probs_output);
+        res->probs_output = { tkn }; // copy the token probs
+        res->data["completion_probabilities"] = probs_vector_to_json(ctx, res->probs_output);
     }
 
     if (slot.oaicompat) {
-        res.data["oaicompat_token_ctr"] = slot.n_decoded;
-        res.data["model"] = slot.oaicompat_model;
+        res->data["oaicompat_token_ctr"] = slot.n_decoded;
+        res->data["model"] = slot.oaicompat_model;
     }
 
     // populate timings if this is final response or timings_per_token is enabled
     if (slot.params.timings_per_token) {
-        res.timings = slot.get_timings();
+        res->timings = slot.get_timings();
     }
     queue_results.send(std::move(res));
 }
 
 void server_context::send_final_response(server_slot& slot) {
-    server_task_result res;
-    res.final_result = true;
-    res.id = slot.id_task;
-    res.id_multi = slot.id_multi;
-    res.error = false;
-    res.stop = true; // to do: set value
-    res.stream = slot.params.stream;
-    res.include_usage = slot.params.include_usage;
-    res.content = slot.generated_text;
-    res.timings = slot.get_timings();
-    res.post_sampling_probs = slot.params.post_sampling_probs;
-    res.oaicompat = slot.params.oaicompat;
-    res.oaicompat_model = slot.params.oaicompat_model;
-    res.oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id;
-    res.oaicompat_msg = slot.update_chat_msg(res.oaicompat_msg_diffs);
-    res.n_decoded = slot.n_decoded;
-    res.n_prompt_tokens = slot.n_prompt_tokens;
-    res.oaicompat_model = slot.oaicompat_model;
-    res.data = json{
+    auto res = std::make_unique<server_task_result_cmpl_final>();
+    res->final_result = true;
+    res->id = slot.id_task;
+    res->id_multi = slot.id_multi;
+    res->error = false;
+    res->stop = true; // to do: set value
+    res->stream = slot.params.stream;
+    res->include_usage = slot.params.include_usage;
+    res->content = slot.generated_text;
+    res->timings = slot.get_timings();
+    res->post_sampling_probs = slot.params.post_sampling_probs;
+    res->oaicompat = slot.params.oaicompat;
+    res->oaicompat_model = slot.params.oaicompat_model;
+    res->oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id;
+    res->oaicompat_msg = slot.update_chat_msg(res->oaicompat_msg_diffs);
+    res->n_decoded = slot.n_decoded;
+    res->n_prompt_tokens = slot.n_prompt_tokens;
+    res->oaicompat_model = slot.oaicompat_model;
+    res->data = json{
         {"content", !slot.params.stream ? slot.generated_text : ""},
         {"generated_text", slot.generated_text}, // Always include full text for finish_reason logic
         {"id_slot", slot.id},
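Both senders above now follow the same pattern: build a typed result on the heap, fill it, and move ownership into the queue. A reduced sketch of the hierarchy that makes `std::make_unique` plus `std::move` the natural shape; the class names follow the diff, but the real structs carry many more fields:

```cpp
#include <memory>
#include <string>

// Reduced sketch of the result hierarchy; the real structs carry many
// more fields (timings, token probabilities, OAI-compat metadata, ...).
struct server_task_result {
    int  id           = -1;
    bool final_result = false;
    virtual ~server_task_result() = default;
};

struct server_task_result_cmpl_partial : server_task_result {
    std::string content; // one streamed chunk
};

struct server_task_result_cmpl_final : server_task_result {
    std::string content; // the full generated text
};

// Building a result mirrors the diff: allocate the concrete type, fill
// it, and return ownership as a base-class pointer so the queue can
// hold partial, final, and embedding results uniformly.
std::unique_ptr<server_task_result> make_partial(int id, std::string chunk) {
    auto res = std::make_unique<server_task_result_cmpl_partial>();
    res->id           = id;
    res->final_result = false;
    res->content      = std::move(chunk);
    return res; // implicit upcast on return
}
```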
@@ -1569,30 +1551,29 @@ void server_context::send_final_response(server_slot& slot) {
         //{"oaicompat_chat_format", slot.params.oaicompat_chat_format},
     };
 
-    // populate res.probs_output
+    // populate res->probs_output
     if (slot.sparams.n_probs > 0) {
-        res.probs_output = std::vector<completion_token_output>(
+        res->probs_output = std::vector<completion_token_output>(
             slot.generated_token_probs.begin(),
             slot.generated_token_probs.end());
-        res.data["completion_probabilities"] = probs_vector_to_json(ctx, res.probs_output);
+        res->data["completion_probabilities"] = probs_vector_to_json(ctx, res->probs_output);
     }
 
     if (slot.oaicompat) {
-        res.data["oaicompat_token_ctr"] = slot.n_decoded;
-        res.data["model"] = slot.oaicompat_model;
+        res->data["oaicompat_token_ctr"] = slot.n_decoded;
+        res->data["model"] = slot.oaicompat_model;
     }
 
     queue_results.send(std::move(res));
 }
 
 void server_context::send_embedding(const server_slot& slot, const llama_batch& batch) {
-    server_task_result res;
-    res.id = slot.id_task;
-    res.id_multi = slot.id_multi;
-    res.error = false;
-    res.stop = true;
+    auto res = std::make_unique<server_task_result_embd>();
+    res->id = slot.task->id;
+    res->n_tokens = slot.prompt_tokens.size();
+    res->oaicompat = slot.task->params.oaicompat;
 
-    const int n_embd = llama_n_embd(model);
+    const int n_embd = llama_model_n_embd(model);
 
     std::vector<float> embd_res(n_embd, 0.0f);
 
@@ -1601,34 +1582,31 @@ void server_context::send_embedding(const server_slot& slot, const llama_batch& batch)
             continue;
         }
 
-        const float* embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]);
-        if (embd == NULL) {
-            LOG_ERROR("failed to get embeddings", {
-                {"token", batch.token[i]},
-                {"seq_id", batch.seq_id[i][0]}
-            });
-
-            res.data = json{
-                {"embedding", std::vector<float>(n_embd, 0.0f)},
-                {"tokens_evaluated", slot.n_prompt_tokens},
-            };
+        const float* embd = nullptr;
+        if (llama_pooling_type(slot.ctx) == LLAMA_POOLING_TYPE_NONE) {
+            embd = llama_get_embeddings_ith(ctx, i);
+        }
+        else {
+            embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]);
+        }
+
+        if (embd == nullptr) {
+            SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", batch.token[i], batch.seq_id[i][0]);
+
+            res->embedding.push_back(std::vector<float>(n_embd, 0.0f));
             continue;
         }
 
-        llama_embd_normalize(embd, embd_res.data(), n_embd);
-
-        res.data = json{
-            {"embedding", embd_res},
-            {"tokens_evaluated", slot.n_prompt_tokens},
-        };
+        // normalize only when there is pooling
+        if (llama_pooling_type(slot.ctx) != LLAMA_POOLING_TYPE_NONE) {
+            common_embd_normalize(embd, embd_res.data(), n_embd, slot.task->params.embd_normalize);
+            res->embedding.push_back(embd_res);
+            break;
+        }
+
+        res->embedding.emplace_back(embd, embd + n_embd);
     }
 
-    queue_results.send(res);
+    queue_results.send(std::move(res));
 }
 
 void server_context::request_completion(int id_task, int id_multi, json data, bool infill, bool embedding, server_tokens&& inputs) {
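In the new embedding path, `common_embd_normalize` scales the pooled vector before it is returned; the trailing argument selects which norm to apply, with the Euclidean (L2) norm as the usual case. A stand-alone equivalent of that L2 case, for illustration only:

```cpp
#include <cmath>
#include <vector>

// Stand-alone L2 normalization: divide the embedding by its Euclidean
// norm, guarding against an all-zero vector.
std::vector<float> l2_normalize(const float* embd, int n_embd) {
    double sum = 0.0;
    for (int i = 0; i < n_embd; ++i) {
        sum += (double)embd[i] * (double)embd[i];
    }
    const float norm = sum > 0.0 ? (float)std::sqrt(sum) : 1.0f;
    std::vector<float> out(n_embd);
    for (int i = 0; i < n_embd; ++i) {
        out[i] = embd[i] / norm;
    }
    return out;
}
```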
@@ -1708,6 +1686,9 @@ void server_context::split_multiprompt_task(int id_multi, server_task& multiprompt_task)
 void server_context::process_single_task(server_task&& task) {
     switch (task.type) {
         case SERVER_TASK_TYPE_COMPLETION:
+        case SERVER_TASK_TYPE_INFILL:
+        case SERVER_TASK_TYPE_EMBEDDING:
+        case SERVER_TASK_TYPE_RERANK:
         {
             const int id_slot = json_value(task.data, "id_slot", -1);
 
@@ -2471,6 +2452,10 @@ void server_context::update_slots() {
 
             // keep only the common part
             // remove the non-common part from the cache
+            if (slot.n_past < 0)
+            {
+                slot.n_past = 0;
+            }
             slot.cache_tokens.keep_first(slot.n_past);
-            int p0 = (int)system_tokens.size() + slot.n_past;
+            int p0 = system_tokens.size() + slot.cache_tokens.pos_next();
@@ -2549,7 +2534,7 @@
                 }
 
                 int p0 = system_tokens.size() + slot.cache_tokens.pos_next();
-                llama_batch_add(batch, cur_tok, p0, { slot.id }, false);
+                llama_batch_add(batch, cur_tok, p0, { slot.id }, slot.embedding);
 
                 slot.cache_tokens.push_back(cur_tok);
 
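The change here passes `slot.embedding` instead of a hard-coded `false` as the last argument of `llama_batch_add`, which marks whether output (logits, or embeddings in embedding mode) is wanted at that position. A usage sketch assuming the common-helper signature the diff itself calls; `add_prompt_to_batch` is a hypothetical wrapper:

```cpp
#include <vector>

#include "common.h" // llama_batch_add (helper used by the diff above)
#include "llama.h"

// Hypothetical wrapper: append a prompt to a batch. The last argument of
// llama_batch_add marks positions whose output is needed; completions
// only need the last token's logits, while embedding requests flag every
// token so the pooled embedding can be read back after decoding.
static void add_prompt_to_batch(llama_batch& batch,
                                const std::vector<llama_token>& tokens,
                                llama_seq_id seq_id, bool embedding) {
    for (size_t i = 0; i < tokens.size(); ++i) {
        const bool need_output = embedding || i + 1 == tokens.size();
        llama_batch_add(batch, tokens[i], (llama_pos)i, { seq_id }, need_output);
    }
}
```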
@@ -2663,18 +2648,31 @@
 
         if (ret != 0) {
             if (n_batch == 1 || ret < 0) {
+                int user_cancel = -3;
                 // if you get here, it means the KV cache is full - try increasing it via the context size
-                LOG_ERROR("failed to decode the batch: KV cache is full - try increasing it via the context size", {
-                    {"i", i},
-                    {"n_batch", ret},
-                    {"ret", ret},
-                });
+                if (ret == user_cancel) {
+                    LOG_ERROR("Decode process is cancelled by user", {
+                        {"i", i},
+                        {"n_batch", ret},
+                        {"ret", ret},
+                    });
+                } else {
+                    LOG_ERROR("failed to decode the batch: KV cache is full - try increasing it via the context size", {
+                        {"i", i},
+                        {"n_batch", ret},
+                        {"ret", ret},
+                    });
+                }
 
                 for (auto& slot : slots) {
                     slot.state = SLOT_STATE_PROCESSING;
                     slot.command = SLOT_COMMAND_NONE;
                     slot.release();
-                    LLAMA_LOG_INFO("n_past = %d\n", (int)slot.cache_tokens.size());
-                    send_error(slot, "Input prompt is too big compared to KV size. Please try increasing KV size.");
+                    if (ret != user_cancel) {
+                        LLAMA_LOG_INFO("n_past = %d\n", (int)slot.cache_tokens.size());
+                        send_error(slot, "Input prompt is too big compared to KV size. Please try increasing KV size.");
+                    }
 
                 }
                 break; // break loop of n_batch
             }
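Context for the new branch: `llama_decode` returns 0 on success and a positive value when the batch could not be placed (typically a full KV cache, retryable with a smaller batch); negative values are hard errors, and this patch reserves -3 for a decode cancelled by a disconnected client. A small classifier capturing that contract; `decode_status` and `classify_decode` are hypothetical names:

```cpp
// Interpreting the llama_decode return value the way the new branch does.
// 0 means success; a positive value means the batch could not be placed
// (typically a full KV cache, retryable with a smaller batch); negative
// values are hard errors, with -3 reserved by this patch for a decode
// cancelled because the client disconnected.
enum class decode_status { ok, retry_smaller_batch, cancelled, fatal };

decode_status classify_decode(int ret) {
    constexpr int user_cancel = -3; // convention introduced by this patch
    if (ret == 0)           return decode_status::ok;
    if (ret >  0)           return decode_status::retry_smaller_batch;
    if (ret == user_cancel) return decode_status::cancelled;
    return decode_status::fatal;
}
```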
@@ -2818,7 +2816,7 @@ json server_context::model_meta() const {
         {"vocab_type", llama_vocab_type(model)},
         {"n_vocab", llama_n_vocab(model)},
         {"n_ctx_train", llama_n_ctx_train(model)},
-        {"n_embd", llama_n_embd(model)},
+        {"n_embd", llama_model_n_embd(model)},
         {"n_params", llama_model_n_params(model)},
         {"size", llama_model_size(model)},
     };