Fix slot prompt updating. (#1285)

Co-authored-by: Rkozuch <you@example.com>
This commit is contained in:
rkozuch
2026-02-19 18:15:49 +11:00
committed by GitHub
parent d81cde5cea
commit b855bf92de

View File

@@ -415,7 +415,7 @@ void server_slot::release() {
task.reset();
llama_decode_reset();
}
}
@@ -457,7 +457,7 @@ result_timings server_slot::get_timings() const {
timings.draft_n = n_draft_total;
timings.draft_n_accepted = n_draft_accepted;
}
return timings;
}
@@ -953,6 +953,18 @@ bool server_context::launch_slot_with_task(server_slot& slot, server_task& task)
// get prompt
if (!task.infill) {
slot.prompt_tokens = std::move(task.tokens);
const auto & prompt = data.find("prompt");
if (prompt != data.end()) {
if (prompt->is_string() ||
(prompt->is_array() && !prompt->empty() && prompt->at(0).is_number_integer())) {
slot.prompt = *prompt;
} else if (prompt->is_array() && prompt->size() == 1 && prompt->at(0).is_string()) {
slot.prompt = *prompt;
} else if (prompt->is_array() && prompt->size() == 1 && prompt->at(0).is_array()) {
slot.prompt = prompt->at(0);
}
}
}
// penalize user-provided tokens
@@ -1084,7 +1096,7 @@ bool server_context::launch_slot_with_task(server_slot& slot, server_task& task)
if (logit_bias != data.end() && (logit_bias->is_object() || logit_bias->is_array())) {
slot.sparams.logit_bias.clear(); // only clear if user sets it
}
if (logit_bias != data.end() && logit_bias->is_array()) {
if (logit_bias != data.end() && logit_bias->is_array()) {
const int n_vocab = llama_n_vocab(model);
for (const auto& el : *logit_bias) {
// TODO: we may want to throw errors here, in case "el" is incorrect
@@ -1156,7 +1168,7 @@ bool server_context::launch_slot_with_task(server_slot& slot, server_task& task)
// ban string
const auto& banned_strings = data.find("banned_strings");
if (banned_strings != data.end() && banned_strings->is_array()) {
slot.ban_phrases.clear();
slot.ban_phrases.clear();
for (const auto& val : data["banned_strings"]) {
if (val.is_string()) {
std::string s = val.get<std::string>();
@@ -1189,7 +1201,7 @@ bool server_context::launch_slot_with_task(server_slot& slot, server_task& task)
}
slot.ban_phrases.push_back(val);
}
}
}
slot.n_buffer = slot.n_buffer + 3; // extra buffer in case
params_base.n_buffer = slot.n_buffer;
} else {
@@ -3145,7 +3157,7 @@ void server_context::buffer_and_check_string_ban(server_slot & slot, completion_
void server_context::process_batch_tokens(int32_t & n_batch) {
for (int32_t i = 0; i < batch.n_tokens; i += n_batch) {
const int32_t n_tokens = std::min(n_batch, batch.n_tokens - i);
extend_context(n_tokens);
extend_context(n_tokens);
llama_batch batch_view = {
n_tokens,