server: stop processing the prompt when client disconnects (#1134)

Implement generator-based API for task results

Update httplib.h to 0.27.0

Fix embedding error

Stop prompt processing when disconnected

Co-authored-by: firecoperana <firecoperana>
firecoperana
2026-01-12 23:56:59 -06:00
committed by GitHub
parent d3e3ad40f9
commit 1a461525d5
24 changed files with 7654 additions and 4549 deletions
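
The headline change: the HTTP layer now pulls task results through a generator-style reader and polls the client connection between pulls, so a dropped client cancels prompt processing instead of burning compute on an answer nobody will read. A minimal sketch of the idea, assuming cpp-httplib's req.is_connection_closed() callback (present in recent httplib, hence the 0.27.0 bump); submit_task, next_result, cancel_task, and result_chunk are hypothetical stand-ins for this server's task-queue API, not its real names:

    #include "httplib.h"
    #include <string>

    struct result_chunk { bool is_final; std::string text; };  // illustrative only
    int submit_task(const std::string& body);                  // hypothetical
    result_chunk next_result(int id_task);                     // hypothetical: blocks on the result queue
    void cancel_task(int id_task);                             // hypothetical: releases the slot

    void register_completion_route(httplib::Server& svr) {
        svr.Post("/completion", [](const httplib::Request& req, httplib::Response& res) {
            const int id_task = submit_task(req.body);
            for (;;) {
                // generator-style pull: one result object at a time from the worker
                const result_chunk chunk = next_result(id_task);
                if (req.is_connection_closed()) {
                    cancel_task(id_task); // stop prompt processing for a dead client
                    return;
                }
                if (chunk.is_final) {
                    res.set_content(chunk.text, "application/json");
                    return;
                }
            }
        });
    }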


@@ -316,6 +316,7 @@ void server_slot::reset() {
     stopped_limit = false;
     stopping_word = "";
     n_past = 0;
+    n_past_prompt = 0;
     n_sent_text = 0;
     drafted.clear();
@@ -402,7 +403,9 @@ void server_slot::release() {
         t_token_generation = (ggml_time_us() - t_start_generation) / 1e3;
         command = SLOT_COMMAND_RELEASE;
+        task.reset();
+        llama_decode_reset();
     }
 }
@@ -803,6 +806,8 @@ bool server_context::launch_slot_with_task(server_slot& slot, server_task& task)
         slot.oaicompat = false;
         slot.oaicompat_model = "";
     }
+    slot.params.oaicompat = task.params.oaicompat;
+    slot.params.oaicompat_cmpl_id = task.params.oaicompat_cmpl_id;
     slot.params.timings_per_token = json_value(data, "timings_per_token", false);
     slot.params.stream = json_value(data, "stream", false);
     auto stream_opt = json_value(data, "stream_options", json::object());
@@ -927,29 +932,6 @@ bool server_context::launch_slot_with_task(server_slot& slot, server_task& task)
     // get prompt
     if (!task.infill) {
-        const auto& prompt = data.find("prompt");
-        if (prompt == data.end()) {
-            send_error(task, "\"prompt\" must be provided", ERROR_TYPE_INVALID_REQUEST);
-            return false;
-        }
-        if ((prompt->is_string()) ||
-            (prompt->is_array() && prompt->size() == 1 && prompt->at(0).is_string()) ||
-            (prompt->is_array() && !prompt->empty() && prompt->at(0).is_number_integer())) {
-            slot.prompt = *prompt;
-        }
-        else if (prompt->is_array() && prompt->size() == 1 && prompt->at(0).is_array()) {
-            slot.prompt = prompt->at(0);
-        }
-        else {
-            send_error(task, "\"prompt\" must be a string or an array of integers", ERROR_TYPE_INVALID_REQUEST);
-            return false;
-        }
+        // maybe not needed since prompt has been tokenized?
+        if (!slot.prompt_tokens.validate(ctx)) {
+            send_error(task, "Prompt contains invalid tokens", ERROR_TYPE_INVALID_REQUEST);
+            return false;
+        }
+        slot.prompt_tokens = std::move(task.tokens);
     }
@@ -1165,7 +1147,7 @@ bool server_context::launch_slot_with_task(server_slot& slot, server_task& task)
         {"id_slot", slot.id},
         {"id_task", slot.id_task},
     });
+    slot.task = std::make_unique<const server_task>(std::move(task));
     return true;
 }
@@ -1490,66 +1472,66 @@ bool server_context::ensure_no_mtmd(const int id_task) {
 }
 void server_context::send_partial_response(server_slot& slot, completion_token_output tkn) {
-    server_task_result res;
-    res.final_result = false;
-    res.id = slot.id_task;
-    res.id_multi = slot.id_multi;
-    res.error = false;
-    res.stop = false;
-    res.stream = slot.params.stream;
-    res.content = tkn.text_to_send;
-    res.post_sampling_probs = slot.params.post_sampling_probs;
-    res.oaicompat = slot.params.oaicompat;
-    res.oaicompat_model = slot.params.oaicompat_model;
-    res.oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id;
-    res.n_decoded = slot.n_decoded;
-    res.n_prompt_tokens = slot.n_prompt_tokens;
-    res.data = json{
+    auto res = std::make_unique<server_task_result_cmpl_partial>();
+    res->final_result = false;
+    res->id = slot.id_task;
+    res->id_multi = slot.id_multi;
+    res->error = false;
+    res->stop = false;
+    res->stream = slot.params.stream;
+    res->content = tkn.text_to_send;
+    res->post_sampling_probs = slot.params.post_sampling_probs;
+    res->oaicompat = slot.params.oaicompat;
+    res->oaicompat_model = slot.params.oaicompat_model;
+    res->oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id;
+    res->n_decoded = slot.n_decoded;
+    res->n_prompt_tokens = slot.n_prompt_tokens;
+    res->data = json{
         {"content", tkn.text_to_send},
         {"stop", false},
         {"id_slot", slot.id},
         {"multimodal", false}
     };
-    slot.update_chat_msg(res.oaicompat_msg_diffs);
-    // populate res.probs_output
+    slot.update_chat_msg(res->oaicompat_msg_diffs);
+    // populate res->probs_output
     if (slot.sparams.n_probs > 0) {
-        res.probs_output = { tkn }; // copy the token probs
-        res.data["completion_probabilities"] = probs_vector_to_json(ctx, res.probs_output);
+        res->probs_output = { tkn }; // copy the token probs
+        res->data["completion_probabilities"] = probs_vector_to_json(ctx, res->probs_output);
     }
     if (slot.oaicompat) {
-        res.data["oaicompat_token_ctr"] = slot.n_decoded;
-        res.data["model"] = slot.oaicompat_model;
+        res->data["oaicompat_token_ctr"] = slot.n_decoded;
+        res->data["model"] = slot.oaicompat_model;
     }
     // populate timings if this is final response or timings_per_token is enabled
     if (slot.params.timings_per_token) {
-        res.timings = slot.get_timings();
+        res->timings = slot.get_timings();
     }
     queue_results.send(std::move(res));
 }
 void server_context::send_final_response(server_slot& slot) {
-    server_task_result res;
-    res.final_result = true;
-    res.id = slot.id_task;
-    res.id_multi = slot.id_multi;
-    res.error = false;
-    res.stop = true; // to do: set value
-    res.stream = slot.params.stream;
-    res.include_usage = slot.params.include_usage;
-    res.content = slot.generated_text;
-    res.timings = slot.get_timings();
-    res.post_sampling_probs = slot.params.post_sampling_probs;
-    res.oaicompat = slot.params.oaicompat;
-    res.oaicompat_model = slot.params.oaicompat_model;
-    res.oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id;
-    res.oaicompat_msg = slot.update_chat_msg(res.oaicompat_msg_diffs);
-    res.n_decoded = slot.n_decoded;
-    res.n_prompt_tokens = slot.n_prompt_tokens;
-    res.oaicompat_model = slot.oaicompat_model;
-    res.data = json{
+    auto res = std::make_unique<server_task_result_cmpl_final>();
+    res->final_result = true;
+    res->id = slot.id_task;
+    res->id_multi = slot.id_multi;
+    res->error = false;
+    res->stop = true; // to do: set value
+    res->stream = slot.params.stream;
+    res->include_usage = slot.params.include_usage;
+    res->content = slot.generated_text;
+    res->timings = slot.get_timings();
+    res->post_sampling_probs = slot.params.post_sampling_probs;
+    res->oaicompat = slot.params.oaicompat;
+    res->oaicompat_model = slot.params.oaicompat_model;
+    res->oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id;
+    res->oaicompat_msg = slot.update_chat_msg(res->oaicompat_msg_diffs);
+    res->n_decoded = slot.n_decoded;
+    res->n_prompt_tokens = slot.n_prompt_tokens;
+    res->oaicompat_model = slot.oaicompat_model;
+    res->data = json{
         {"content", !slot.params.stream ? slot.generated_text : ""},
         {"generated_text", slot.generated_text}, // Always include full text for finish_reason logic
         {"id_slot", slot.id},
@@ -1569,30 +1551,29 @@ void server_context::send_final_response(server_slot& slot) {
         //{"oaicompat_chat_format", slot.params.oaicompat_chat_format},
     };
-    // populate res.probs_output
+    // populate res->probs_output
     if (slot.sparams.n_probs > 0) {
-        res.probs_output = std::vector<completion_token_output>(
+        res->probs_output = std::vector<completion_token_output>(
             slot.generated_token_probs.begin(),
             slot.generated_token_probs.end());
-        res.data["completion_probabilities"] = probs_vector_to_json(ctx, res.probs_output);
+        res->data["completion_probabilities"] = probs_vector_to_json(ctx, res->probs_output);
     }
     if (slot.oaicompat) {
-        res.data["oaicompat_token_ctr"] = slot.n_decoded;
-        res.data["model"] = slot.oaicompat_model;
+        res->data["oaicompat_token_ctr"] = slot.n_decoded;
+        res->data["model"] = slot.oaicompat_model;
     }
     queue_results.send(std::move(res));
 }
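
The res. → res-> churn above is the "generator-based API for task results" from the commit message: each response becomes a heap-allocated, type-specific object handed to the result queue as an owning pointer, so the HTTP thread can pull results one at a time and stop pulling when the client goes away. Roughly, with names taken from this diff and member lists abbreviated:

    // Sketch inferred from the diff; the real structs carry many more fields.
    struct server_task_result {
        int  id           = -1;
        bool final_result = false;
        virtual ~server_task_result() = default;
    };
    struct server_task_result_cmpl_partial : server_task_result { /* streamed chunk  */ };
    struct server_task_result_cmpl_final   : server_task_result { /* full completion */ };
    struct server_task_result_embd         : server_task_result { /* embeddings      */ };

    // queue_results.send(std::move(res)) now transfers ownership of a
    // std::unique_ptr<server_task_result> instead of copying a flat struct.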
 void server_context::send_embedding(const server_slot& slot, const llama_batch& batch) {
-    server_task_result res;
-    res.id = slot.id_task;
-    res.id_multi = slot.id_multi;
-    res.error = false;
-    res.stop = true;
+    auto res = std::make_unique<server_task_result_embd>();
+    res->id = slot.task->id;
+    res->n_tokens = slot.prompt_tokens.size();
+    res->oaicompat = slot.task->params.oaicompat;
-    const int n_embd = llama_n_embd(model);
+    const int n_embd = llama_model_n_embd(model);
     std::vector<float> embd_res(n_embd, 0.0f);
@@ -1601,34 +1582,31 @@ void server_context::send_embedding(const server_slot& slot, const llama_batch&
             continue;
         }
-        const float* embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]);
-        if (embd == NULL) {
-            LOG_ERROR("failed to get embeddings", {
-                {"token", batch.token[i]},
-                {"seq_id", batch.seq_id[i][0]}
-            });
-            res.data = json{
-                {"embedding", std::vector<float>(n_embd, 0.0f)},
-                {"tokens_evaluated", slot.n_prompt_tokens},
-            };
+        const float* embd = nullptr;
+        if (llama_pooling_type(slot.ctx) == LLAMA_POOLING_TYPE_NONE) {
+            embd = llama_get_embeddings_ith(ctx, i);
+        }
+        else {
+            embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]);
+        }
+        if (embd == nullptr) {
+            SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", batch.token[i], batch.seq_id[i][0]);
+            res->embedding.push_back(std::vector<float>(n_embd, 0.0f));
             continue;
         }
-        llama_embd_normalize(embd, embd_res.data(), n_embd);
-        res.data = json{
-            {"embedding", embd_res},
-            {"tokens_evaluated", slot.n_prompt_tokens},
-        };
+        // normalize only when there is pooling
+        if (llama_pooling_type(slot.ctx) != LLAMA_POOLING_TYPE_NONE) {
+            common_embd_normalize(embd, embd_res.data(), n_embd, slot.task->params.embd_normalize);
+            res->embedding.push_back(embd_res);
+            break;
+        }
+        res->embedding.emplace_back(embd, embd + n_embd);
     }
-    queue_results.send(res);
+    queue_results.send(std::move(res));
 }
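
Two things changed here: embeddings are fetched per token (llama_get_embeddings_ith) when the model has no pooling, and normalization now happens only for pooled embeddings, via common_embd_normalize with a caller-chosen norm; per-token embeddings are pushed raw. Assuming the same semantics as upstream llama.cpp's common_embd_normalize, the default embd_normalize == 2 is plain Euclidean normalization, roughly:

    #include <cmath>

    // Minimal sketch of the embd_normalize == 2 (L2) case: scale to unit length.
    static void l2_normalize(const float* src, float* dst, int n) {
        double sum = 0.0;
        for (int k = 0; k < n; ++k) {
            sum += (double)src[k] * (double)src[k];
        }
        const float scale = sum > 0.0 ? (float)(1.0 / std::sqrt(sum)) : 0.0f;
        for (int k = 0; k < n; ++k) {
            dst[k] = src[k] * scale;
        }
    }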
 void server_context::request_completion(int id_task, int id_multi, json data, bool infill, bool embedding, server_tokens&& inputs) {
@@ -1708,6 +1686,9 @@ void server_context::split_multiprompt_task(int id_multi, server_task& multiprom
 void server_context::process_single_task(server_task&& task) {
     switch (task.type) {
     case SERVER_TASK_TYPE_COMPLETION:
+    case SERVER_TASK_TYPE_INFILL:
+    case SERVER_TASK_TYPE_EMBEDDING:
+    case SERVER_TASK_TYPE_RERANK:
     {
         const int id_slot = json_value(task.data, "id_slot", -1);
@@ -2471,6 +2452,10 @@ void server_context::update_slots() {
     // keep only the common part
     // remove the non-common part from the cache
+    if (slot.n_past < 0)
+    {
+        slot.n_past = 0;
+    }
     slot.cache_tokens.keep_first(slot.n_past);
-    int p0 = (int)system_tokens.size() + slot.n_past;
+    int p0 = system_tokens.size() + slot.cache_tokens.pos_next();
@@ -2549,7 +2534,7 @@ void server_context::update_slots() {
     }
     int p0 = system_tokens.size() + slot.cache_tokens.pos_next();
-    llama_batch_add(batch, cur_tok, p0, { slot.id }, false);
+    llama_batch_add(batch, cur_tok, p0, { slot.id }, slot.embedding);
     slot.cache_tokens.push_back(cur_tok);
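
The last argument of llama_batch_add marks whether that position's outputs (logits, and with them embeddings) are kept, so embedding slots must pass slot.embedding rather than a hard-coded false. For reference, the common-library helper does roughly this (paraphrased from llama.cpp's common code):

    #include <vector>
    #include "llama.h"

    // The final flag gates whether outputs are stored for this position.
    static void batch_add(llama_batch& batch, llama_token id, llama_pos pos,
                          const std::vector<llama_seq_id>& seq_ids, bool logits) {
        batch.token   [batch.n_tokens] = id;
        batch.pos     [batch.n_tokens] = pos;
        batch.n_seq_id[batch.n_tokens] = (int32_t)seq_ids.size();
        for (size_t j = 0; j < seq_ids.size(); ++j) {
            batch.seq_id[batch.n_tokens][j] = seq_ids[j];
        }
        batch.logits  [batch.n_tokens] = logits;
        batch.n_tokens++;
    }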
@@ -2663,18 +2648,31 @@ void server_context::update_slots() {
         if (ret != 0) {
             if (n_batch == 1 || ret < 0) {
+                int user_cancel = -3;
-                // if you get here, it means the KV cache is full - try increasing it via the context size
-                LOG_ERROR("failed to decode the batch: KV cache is full - try increasing it via the context size", {
-                    {"i", i},
-                    {"n_batch", ret},
-                    {"ret", ret},
-                });
+                if (ret == user_cancel) {
+                    LOG_ERROR("Decode process is cancelled by user", {
+                        {"i", i},
+                        {"n_batch", ret},
+                        {"ret", ret},
+                    });
+                } else {
+                    LOG_ERROR("failed to decode the batch: KV cache is full - try increasing it via the context size", {
+                        {"i", i},
+                        {"n_batch", ret},
+                        {"ret", ret},
+                    });
+                }
                 for (auto& slot : slots) {
                     slot.state = SLOT_STATE_PROCESSING;
                     slot.command = SLOT_COMMAND_NONE;
                     slot.release();
-                    LLAMA_LOG_INFO("n_past = %d\n", (int)slot.cache_tokens.size());
-                    send_error(slot, "Input prompt is too big compared to KV size. Please try increasing KV size.");
+                    if (ret != user_cancel) {
+                        LLAMA_LOG_INFO("n_past = %d\n", (int)slot.cache_tokens.size());
+                        send_error(slot, "Input prompt is too big compared to KV size. Please try increasing KV size.");
+                    }
                 }
                 break; // break loop of n_batch
             }
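
The magic number here: user_cancel (-3) appears to be this fork's llama_decode status for a compute aborted on behalf of a disconnected client, which is why the slots are released quietly instead of reporting a KV-cache-full error. Upstream llama.h exposes llama_set_abort_callback for this kind of cooperative cancellation; a hedged sketch of how a disconnect flag could feed it (the atomic flag is illustrative, and the -3 mapping is fork-specific):

    #include <atomic>
    #include "llama.h"

    static std::atomic<bool> g_cancel_requested{false}; // set when the client disconnects

    // ggml_abort_callback: returning true aborts the in-flight graph compute,
    // which llama_decode then surfaces as a nonzero return value.
    static bool abort_cb(void* /*data*/) {
        return g_cancel_requested.load(std::memory_order_relaxed);
    }

    // Wired up once per context:
    // llama_set_abort_callback(ctx, abort_cb, nullptr);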
@@ -2818,7 +2816,7 @@ json server_context::model_meta() const {
         {"vocab_type", llama_vocab_type(model)},
         {"n_vocab", llama_n_vocab(model)},
         {"n_ctx_train", llama_n_ctx_train(model)},
-        {"n_embd", llama_n_embd(model)},
+        {"n_embd", llama_model_n_embd(model)},
         {"n_params", llama_model_n_params(model)},
         {"size", llama_model_size(model)},
     };