Server: refactor and rename functions (#1151)

* Server: rename functions and refactor code

rename functions

refactor update slots

rename params_base

rename timings

* change

* Revert kv cache name changes

* Revert 2

* fix test build error

---------

Co-authored-by: firecoperana <firecoperana>
This commit is contained in:
firecoperana
2026-01-18 00:16:57 -06:00
committed by GitHub
parent 7024fdbc72
commit d71a3ec315
38 changed files with 532 additions and 528 deletions

View File

@@ -37,7 +37,7 @@ server_context::~server_context() {
// Clear any sampling context
for (server_slot& slot : slots) {
if (slot.ctx_sampling != nullptr) {
llama_sampling_free(slot.ctx_sampling);
common_sampler_free(slot.ctx_sampling);
}
if (slot.ctx_dft) {
llama_free(slot.ctx_dft);
@@ -52,16 +52,16 @@ server_context::~server_context() {
}
bool server_context::load_model(const gpt_params& params_) {
params = params_;
params_base = params_;
llama_init_result llama_init = llama_init_from_gpt_params(params);
llama_init_result llama_init = llama_init_from_gpt_params(params_base);
model = llama_init.model;
ctx = llama_init.context;
lora_adapters = llama_init.lora_adapters;
if (model == nullptr) {
LOG_ERROR("unable to load model", { {"model", params.model} });
LOG_ERROR("unable to load model", { {"model", params_base.model} });
return false;
}
@@ -70,26 +70,26 @@ bool server_context::load_model(const gpt_params& params_) {
add_bos_token = llama_should_add_bos_token(model);
has_eos_token = llama_add_eos_token(model) != 1;
chat_templates = common_chat_templates_init(model, params.chat_template);
chat_templates = common_chat_templates_init(model, params_base.chat_template);
try {
common_chat_format_example(chat_templates.get(), params.use_jinja, {});
common_chat_format_example(chat_templates.get(), params_base.use_jinja, {});
}
catch (const std::exception& e) {
LOG_WARNING("%s: The chat template that comes with this model is not yet supported, falling back to chatml. This may cause the model to output suboptimal responses\n", __func__);
chat_templates = common_chat_templates_init(model, "chatml");
}
bool has_draft_model = !params.model_draft.empty() || !params.draft_params.empty();
std::string& mmproj_path = params.mmproj.path;
bool has_draft_model = !params_base.model_draft.empty() || !params_base.draft_params.empty();
std::string& mmproj_path = params_base.mmproj.path;
if (!mmproj_path.empty()) {
mtmd_context_params mparams = mtmd_context_params_default();
mparams.use_gpu = params.mmproj_use_gpu;
mparams.use_gpu = params_base.mmproj_use_gpu;
mparams.print_timings = false;
mparams.n_threads = params.n_threads;
mparams.flash_attn_type = params.flash_attn ? LLAMA_FLASH_ATTN_TYPE_ENABLED : LLAMA_FLASH_ATTN_TYPE_DISABLED;
mparams.verbosity = params.verbosity > 0 ? GGML_LOG_LEVEL_DEBUG : GGML_LOG_LEVEL_INFO;
mparams.image_min_tokens = params.image_min_tokens;
mparams.image_max_tokens = params.image_max_tokens;
mparams.n_threads = params_base.n_threads;
mparams.flash_attn_type = params_base.flash_attn ? LLAMA_FLASH_ATTN_TYPE_ENABLED : LLAMA_FLASH_ATTN_TYPE_DISABLED;
mparams.verbosity = params_base.verbosity > 0 ? GGML_LOG_LEVEL_DEBUG : GGML_LOG_LEVEL_INFO;
mparams.image_min_tokens = params_base.image_min_tokens;
mparams.image_max_tokens = params_base.image_max_tokens;
mctx = mtmd_init_from_file(mmproj_path.c_str(), model, mparams);
if (mctx == nullptr) {
LOG_ERROR("failed to load multimodal model, '%s'\n", mmproj_path.c_str());
@@ -97,8 +97,8 @@ bool server_context::load_model(const gpt_params& params_) {
}
LOG_INFO("loaded multimodal model, '%s'\n", mmproj_path.c_str());
if (params.ctx_shift) {
params.ctx_shift = false;
if (params_base.ctx_shift) {
params_base.ctx_shift = false;
LOG_WARNING("%s\n", "ctx_shift is not supported by multimodal, it will be disabled");
}
@@ -117,15 +117,15 @@ bool server_context::load_model(const gpt_params& params_) {
LLAMA_LOG_INFO("\n\n==================================loading DRAFT model==================================\n\n");
gpt_params params_dft;
params_dft.devices = params.devices_draft;
params_dft.model = params.model_draft;
params_dft.n_gpu_layers = params.n_gpu_layers_draft;
params_dft.rpc_servers = params.rpc_servers;
params_dft.cache_type_k = params.cache_type_k_draft.empty() ? params.cache_type_k : params.cache_type_k_draft;
params_dft.cache_type_v = params.cache_type_v_draft.empty() ? params.cache_type_v : params.cache_type_v_draft;
params_dft.flash_attn = params.flash_attn;
if (!params.draft_params.empty()) {
auto [argc, argv] = parse_command_line("llama-server " + params.draft_params);
params_dft.devices = params_base.devices_draft;
params_dft.model = params_base.model_draft;
params_dft.n_gpu_layers = params_base.n_gpu_layers_draft;
params_dft.rpc_servers = params_base.rpc_servers;
params_dft.cache_type_k = params_base.cache_type_k_draft.empty() ? params_base.cache_type_k : params_base.cache_type_k_draft;
params_dft.cache_type_v = params_base.cache_type_v_draft.empty() ? params_base.cache_type_v : params_base.cache_type_v_draft;
params_dft.flash_attn = params_base.flash_attn;
if (!params_base.draft_params.empty()) {
auto [argc, argv] = parse_command_line("llama-server " + params_base.draft_params);
if (!gpt_params_parse(argc, argv, params_dft)) {
gpt_params_print_usage(argc, argv, params_dft);
free_command_line(argc, argv);
@@ -135,16 +135,16 @@ bool server_context::load_model(const gpt_params& params_) {
}
LOG_INFO("", { {"model", params_dft.model} });
if (params_dft.n_ctx == 0) {
params_dft.n_ctx = params.n_ctx_draft;
params_dft.n_ctx = params_base.n_ctx_draft;
}
params_dft.n_ctx = params_dft.n_ctx == 0 ? params.n_ctx / params.n_parallel : params_dft.n_ctx;
params_dft.n_ctx = params_dft.n_ctx == 0 ? params_base.n_ctx / params_base.n_parallel : params_dft.n_ctx;
params_dft.n_parallel = 1;
params_dft.n_batch = params_dft.n_ctx;
llama_init_result llama_init_dft = llama_init_from_gpt_params(params_dft);
llama_model* model_dft = llama_init_dft.model;
if (model_dft == nullptr) {
LOG_ERROR("failed to load draft model", { {"model", params.model_draft} });
LOG_ERROR("failed to load draft model", { {"model", params_base.model_draft} });
return false;
}
@@ -163,22 +163,22 @@ bool server_context::load_model(const gpt_params& params_) {
}
void server_context::init() {
const int32_t n_ctx_slot = n_ctx / params.n_parallel;
const int32_t n_ctx_slot = n_ctx / params_base.n_parallel;
LOG_INFO("initializing slots", { {"n_slots", params.n_parallel} });
LOG_INFO("initializing slots", { {"n_slots", params_base.n_parallel} });
for (int i = 0; i < params.n_parallel; i++) {
for (int i = 0; i < params_base.n_parallel; i++) {
server_slot slot;
slot.id = i;
slot.ctx = ctx;
slot.n_ctx = n_ctx_slot;
slot.n_predict = params.n_predict;
slot.n_predict = params_base.n_predict;
slot.mctx = mctx;
slot.cache_tokens.has_mtmd = mctx != nullptr;
slot.params.think_tokens = params.think_tokens;
if (params.think_tokens.exclude) {
SRV_WRN("Exclude reasoning tokens when selecting slot based on similarity: start: %s, end: %s\nuse `--reasoning-tokens none` to disable.\n", params.think_tokens.begin.c_str(), params.think_tokens.end.c_str() );
slot.params.think_tokens = params_base.think_tokens;
if (params_base.think_tokens.exclude) {
SRV_WRN("Exclude reasoning tokens when selecting slot based on similarity: start: %s, end: %s\nuse `--reasoning-tokens none` to disable.\n", params_base.think_tokens.begin.c_str(), params_base.think_tokens.end.c_str() );
}
else {
SRV_WRN("%s", "Include reasoning tokens when selecting slot based on similarity\nuse `--reasoning-tokens auto` to exclude reasoning tokens.\n");
@@ -188,8 +188,8 @@ void server_context::init() {
{"n_ctx_slot", slot.n_ctx}
});
const int ga_n = params.grp_attn_n;
const int ga_w = params.grp_attn_w;
const int ga_n = params_base.grp_attn_n;
const int ga_w = params_base.grp_attn_w;
if (ga_n != 1) {
GGML_ASSERT(ga_n > 0 && "ga_n must be positive"); // NOLINT
@@ -208,7 +208,7 @@ void server_context::init() {
slot.ga_n = ga_n;
slot.ga_w = ga_w;
slot.sparams = params.sparams;
slot.sparams = params_base.sparams;
// Initialize speculative decoding if a draft model is loaded
if (ctx_draft) {
@@ -225,7 +225,7 @@ void server_context::init() {
LOG_ERROR("failed to create speculator", {});
return;
}
for (auto& pair : params.replacements_draft) {
for (auto& pair : params_base.replacements_draft) {
llama_speculative_add_replacement_tgt_dft(slot.spec, pair.first.c_str(), pair.second.c_str());
}
@@ -245,21 +245,21 @@ void server_context::init() {
const int32_t n_batch = llama_n_batch(ctx);
// only a single seq_id per token is needed
batch = llama_batch_init(std::max(n_batch, params.n_parallel), 0, 1);
batch = llama_batch_init(std::max(n_batch, params_base.n_parallel), 0, 1);
}
metrics.init();
if (params.cache_ram_mib != 0) {
if (params.cache_ram_mib < 0) {
if (params_base.cache_ram_mib != 0) {
if (params_base.cache_ram_mib < 0) {
LLAMA_LOG_INFO("prompt cache is enabled, size limit: %s\n", "no limit");
}
else {
LLAMA_LOG_INFO("prompt cache is enabled, size limit: %d MiB\n", params.cache_ram_mib);
LLAMA_LOG_INFO("prompt cache is enabled, size limit: %d MiB\n", params_base.cache_ram_mib);
}
LLAMA_LOG_INFO("%s", "use `--cache-ram 0` to disable the prompt cache\n");
// only apply ram size limit. No token limit for now.
prompt_cache = std::make_unique<server_prompt_cache>(ctx, params.cache_ram_mib, 0);
prompt_cache = std::make_unique<server_prompt_cache>(ctx, params_base.cache_ram_mib, 0);
}
else {
LLAMA_LOG_INFO("%s", "prompt cache is disabled - use `--cache-ram N` to enable it\n");
@@ -268,14 +268,14 @@ void server_context::init() {
// thinking is enabled if:
// 1. It's not explicitly disabled (reasoning_budget == 0)
// 2. The chat template supports it
const bool enable_thinking = params.use_jinja && params.reasoning_budget != 0 && common_chat_templates_support_enable_thinking(chat_templates.get());
const bool enable_thinking = params_base.use_jinja && params_base.reasoning_budget != 0 && common_chat_templates_support_enable_thinking(chat_templates.get());
//LLAMA_LOG_INFO("Enable thinking? %d\n", enable_thinking);
oai_parser_opt = {
/* use_jinja */ params.use_jinja,
/* prefill_assistant */ params.prefill_assistant,
/* reasoning_format */ params.reasoning_format,
/* chat_template_kwargs */ params.default_template_kwargs,
/* use_jinja */ params_base.use_jinja,
/* prefill_assistant */ params_base.prefill_assistant,
/* reasoning_format */ params_base.reasoning_format,
/* chat_template_kwargs */ params_base.default_template_kwargs,
/* common_chat_templates */ chat_templates.get(),
/* allow_image */ mctx ? mtmd_support_vision(mctx) : false,
/* allow_audio */ mctx ? mtmd_support_audio(mctx) : false,
@@ -500,34 +500,19 @@ size_t server_slot::find_stopping_strings(const std::string& text, const size_t
void server_slot::print_timings() const {
char buffer[512];
double t_token = t_prompt_processing / n_prompt_tokens_processed;
double n_tokens_second = 1e3 / t_prompt_processing * n_prompt_tokens_processed;
double t_prompt = t_prompt_processing / n_prompt_tokens_processed;
double n_prompt_second = 1e3 / t_prompt_processing * n_prompt_tokens_processed;
//snprintf(buffer, 512, "prompt eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)",
// t_prompt_processing, n_prompt_tokens_processed,
// t_token, n_tokens_second);
double t_gen = t_token_generation / n_decoded;
double n_gen_second = 1e3 / t_token_generation * n_decoded;
//LOG_INFO(buffer, {});
double t_token_gen = t_token_generation / n_decoded;
double n_tokens_second_gen = 1e3 / t_token_generation * n_decoded;
//snprintf(buffer, 512, "generation eval time = %10.2f ms / %5d runs (%8.2f ms per token, %8.2f tokens per second)",
// t_token_generation, n_decoded,
// t_token, n_tokens_second);
//LOG_INFO(buffer, {});
//snprintf(buffer, 512, " total time = %10.2f ms", t_prompt_processing + t_token_generation);
//LOG_INFO(buffer, {});
SLT_INF(*this,
"\n"
"prompt eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)\n"
" eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)\n"
" total time = %10.2f ms / %5d tokens\n",
t_prompt_processing, n_prompt_tokens_processed, t_token, n_tokens_second,
t_token_generation, n_decoded, t_token_gen, n_tokens_second_gen,
t_prompt_processing, n_prompt_tokens_processed, t_prompt, n_prompt_second,
t_token_generation, n_decoded, t_gen, n_gen_second,
t_prompt_processing + t_token_generation, n_prompt_tokens_processed + n_decoded);
if (n_draft_total > 0) {
@@ -795,7 +780,7 @@ server_slot* server_context::get_available_slot(const server_task& task) {
bool server_context::launch_slot_with_task(server_slot& slot, server_task& task) {
slot_params default_params;
// Sampling parameter defaults are loaded from the global server context (but individual requests can still override them)
llama_sampling_params default_sparams = params.sparams;
llama_sampling_params default_sparams = params_base.sparams;
auto& data = task.data;
if (data.count("__oaicompat") != 0) {
@@ -848,9 +833,9 @@ bool server_context::launch_slot_with_task(server_slot& slot, server_task& task)
slot.params.post_sampling_probs = json_value(data, "post_sampling_probs", default_params.post_sampling_probs);
// speculative decoding parameters
slot.params.speculative.n_max = json_value(data, "speculative.n_max", params.n_draft);
slot.params.speculative.n_min = json_value(data, "speculative.n_min", params.n_draft_min);
slot.params.speculative.p_min = json_value(data, "speculative.p_min", params.p_draft_min);
slot.params.speculative.n_max = json_value(data, "speculative.n_max", params_base.n_draft);
slot.params.speculative.n_min = json_value(data, "speculative.n_min", params_base.n_draft_min);
slot.params.speculative.p_min = json_value(data, "speculative.p_min", params_base.p_draft_min);
// Clamp speculative parameters
slot.params.speculative.n_min = std::min(slot.params.speculative.n_max, slot.params.speculative.n_min);
@@ -945,7 +930,7 @@ bool server_context::launch_slot_with_task(server_slot& slot, server_task& task)
if (penalty_prompt != data.end()) {
if (penalty_prompt->is_string()) {
const auto penalty_prompt_string = penalty_prompt->get<std::string>();
slot.sparams.penalty_prompt_tokens = llama_tokenize(model, penalty_prompt_string, false);
slot.sparams.penalty_prompt_tokens = common_tokenize(model, penalty_prompt_string, false);
if (slot.params.n_predict > 0) {
slot.sparams.penalty_prompt_tokens.reserve(slot.sparams.penalty_prompt_tokens.size() + slot.params.n_predict);
@@ -988,7 +973,7 @@ bool server_context::launch_slot_with_task(server_slot& slot, server_task& task)
else {
slot.params.oaicompat_chat_syntax.format = default_params.oaicompat_chat_syntax.format;
}
common_reasoning_format reasoning_format = params.reasoning_format;
common_reasoning_format reasoning_format = params_base.reasoning_format;
if (data.contains("reasoning_format")) {
reasoning_format = common_reasoning_format_from_name(data.at("reasoning_format").get<std::string>());
}
@@ -1003,7 +988,7 @@ bool server_context::launch_slot_with_task(server_slot& slot, server_task& task)
const auto preserved_tokens = data.find("preserved_tokens");
if (preserved_tokens != data.end()) {
for (const auto& t : *preserved_tokens) {
auto ids = llama_tokenize(model, t.get<std::string>(), /* add_special= */ false, /* parse_special= */ true);
auto ids = common_tokenize(model, t.get<std::string>(), /* add_special= */ false, /* parse_special= */ true);
if (ids.size() == 1) {
LOG("Preserved token: %d\n", ids[0]);
slot.sparams.preserved_tokens.insert(ids[0]);
@@ -1020,7 +1005,7 @@ bool server_context::launch_slot_with_task(server_slot& slot, server_task& task)
server_grammar_trigger ct(t);
if (ct.value.type == COMMON_GRAMMAR_TRIGGER_TYPE_WORD) {
const auto& word = ct.value.value;
auto ids = llama_tokenize(model, word, /* add_special= */ false, /* parse_special= */ true);
auto ids = common_tokenize(model, word, /* add_special= */ false, /* parse_special= */ true);
if (ids.size() == 1) {
auto token = ids[0];
if (std::find(slot.sparams.preserved_tokens.begin(), slot.sparams.preserved_tokens.end(), (llama_token)token) == slot.sparams.preserved_tokens.end()) {
@@ -1085,7 +1070,7 @@ bool server_context::launch_slot_with_task(server_slot& slot, server_task& task)
}
}
else if (el[0].is_string()) {
auto toks = llama_tokenize(model, el[0].get<std::string>(), false);
auto toks = common_tokenize(model, el[0].get<std::string>(), false);
for (auto tok : toks) {
slot.sparams.logit_bias[tok] = bias;
}
@@ -1128,9 +1113,9 @@ bool server_context::launch_slot_with_task(server_slot& slot, server_task& task)
{
if (slot.ctx_sampling != nullptr) {
llama_sampling_free(slot.ctx_sampling);
common_sampler_free(slot.ctx_sampling);
}
slot.ctx_sampling = llama_sampling_init(llama_get_model_vocab(model), slot.sparams);
slot.ctx_sampling = common_sampler_init(llama_get_model_vocab(model), slot.sparams);
if (slot.ctx_sampling == nullptr) {
// for now, the only error that may happen here is invalid grammar
send_error(task, "Failed to parse grammar", ERROR_TYPE_INVALID_REQUEST);
@@ -1174,10 +1159,10 @@ void server_context::system_prompt_update() {
for (int32_t i = 0; i < n_tokens_prompt; i += n_batch) {
const int32_t n_tokens = std::min(n_batch, n_tokens_prompt - i);
llama_batch_clear(batch);
common_batch_clear(batch);
for (int32_t j = 0; j < n_tokens; ++j) {
llama_batch_add(batch, system_tokens[i + j], i + j, { 0 }, false);
common_batch_add(batch, system_tokens[i + j], i + j, { 0 }, false);
}
if (llama_decode(ctx, batch) != 0) {
@@ -1187,7 +1172,7 @@ void server_context::system_prompt_update() {
}
// assign the system KV cache to all parallel sequences
for (int32_t i = 1; i <= params.n_parallel; ++i) {
for (int32_t i = 1; i <= params_base.n_parallel; ++i) {
llama_kv_cache_seq_cp(ctx, 0, i, -1, -1);
}
}
@@ -1268,7 +1253,7 @@ bool server_context::process_token(completion_token_output& result, server_slot&
}
// check the limits
if (slot.n_decoded > 0 && slot.has_next_token && !slot.has_budget(params)) {
if (slot.n_decoded > 0 && slot.has_next_token && !slot.has_budget(params_base)) {
slot.stopped_limit = true;
slot.has_next_token = false;
@@ -1297,7 +1282,7 @@ bool server_context::process_token(completion_token_output& result, server_slot&
{ "slot.n_prompt_tokens", slot.n_prompt_tokens },
{ "slot.n_decoded", slot.n_decoded },
{ "slot.n_predict", slot.n_predict },
{ "n_slots", params.n_parallel },
{ "n_slots", params_base.n_parallel },
{ "slot.n_ctx", slot.n_ctx },
{ "n_ctx", n_ctx },
{ "n_ctx_train", n_ctx_train },
@@ -1330,7 +1315,7 @@ void server_context::populate_token_probs(const server_slot& slot, completion_to
size_t n_vocab = llama_n_vocab(llama_get_model(ctx));
if (post_sampling) {
const auto* cur_p = llama_sampling_get_candidates(slot.ctx_sampling);
const auto* cur_p = common_sampler_get_candidates(slot.ctx_sampling);
const size_t max_probs = cur_p->size;
// set probability for sampled token
@@ -1346,7 +1331,7 @@ void server_context::populate_token_probs(const server_slot& slot, completion_to
for (size_t i = 0; i < std::min(max_probs, n_probs); i++) {
result.probs.push_back({
cur_p->data[i].id,
llama_detokenize(ctx, {cur_p->data[i].id}, special),
common_token_to_piece(ctx, {cur_p->data[i].id}, special),
cur_p->data[i].p
});
}
@@ -1362,7 +1347,7 @@ void server_context::populate_token_probs(const server_slot& slot, completion_to
for (size_t i = 0; i < std::min(n_vocab, n_probs); i++) {
result.probs.push_back({
cur[i].id,
llama_detokenize(ctx, {cur[i].id}, special),
common_token_to_piece(ctx, {cur[i].id}, special),
cur[i].p
});
}
@@ -1387,7 +1372,7 @@ json server_context::get_formated_generation(const server_slot& slot) const {
return json{
{"n_ctx", slot.n_ctx},
{"n_predict", slot.n_predict}, // Server configured n_predict
{"model", params.model_alias},
{"model", params_base.model_alias},
{"seed", slot.sparams.seed},
{"temperature", slot.sparams.temp},
{"dynatemp_range", slot.sparams.dynatemp_range},
@@ -1548,7 +1533,7 @@ void server_context::send_final_response(server_slot& slot) {
{"generated_text", slot.generated_text}, // Always include full text for finish_reason logic
{"id_slot", slot.id},
{"stop", true},
{"model", params.model_alias},
{"model", params_base.model_alias},
{"tokens_predicted", slot.n_decoded},
{"tokens_evaluated", slot.n_prompt_tokens},
{"generation_settings", get_formated_generation(slot)},
@@ -2067,12 +2052,8 @@ void server_context::context_shift_prompt(llama_context* ctx, server_slot& slot,
slot.n_prompt_tokens = slot.prompt_tokens.size();
}
void server_context::update_slots() {
if (system_need_update) {
system_prompt_update();
}
// release slots
void server_context::release_slots()
{
for (auto& slot : slots) {
if (slot.command == SLOT_COMMAND_RELEASE) {
slot.state = SLOT_STATE_IDLE;
@@ -2092,11 +2073,10 @@ void server_context::update_slots() {
queue_tasks.notify_slot_changed();
}
}
}
// check if all slots are idle
{
bool server_context::slots_idle(){
bool all_idle = true;
for (auto& slot : slots) {
if (slot.state != SLOT_STATE_IDLE || slot.command != SLOT_COMMAND_NONE) {
all_idle = false;
@@ -2109,27 +2089,16 @@ void server_context::update_slots() {
if (system_prompt.empty() && clean_kv_cache) {
kv_cache_clear();
}
return;
all_idle = true;
}
}
return all_idle;
}
{
LOG_VERBOSE("posting NEXT_RESPONSE", {});
server_task task;
task.type = SERVER_TASK_TYPE_NEXT_RESPONSE;
task.id_target = -1;
queue_tasks.post(std::move(task));
}
// apply context-shift if needed
// TODO: simplify and improve
void server_context::context_shift() {
for (server_slot& slot : slots) {
if (slot.ga_n == 1) {
if (slot.is_processing() && (int)system_tokens.size() + slot.n_past >= slot.n_ctx - 1) {
if (!params.ctx_shift) {
if (!params_base.ctx_shift) {
// this check is redundant (for good)
// we should never get here, because generation should already stopped in process_token()
send_error(slot, "context shift is disabled", ERROR_TYPE_SERVER);
@@ -2176,15 +2145,9 @@ void server_context::update_slots() {
}
}
}
}
// start populating the batch for this iteration
llama_batch_clear(batch);
auto accept_special_token = [&](server_slot& slot, llama_token token) {
return params.special || slot.sparams.preserved_tokens.find(token) != slot.sparams.preserved_tokens.end();
};
// frist, add sampled tokens from any ongoing sequences
void server_context::add_sampled_tokens() {
for (auto& slot : slots) {
if (slot.state == SLOT_STATE_IDLE) {
continue;
@@ -2209,7 +2172,7 @@ void server_context::update_slots() {
// add the sampled token to the batch
slot.i_batch_dft.push_back(batch.n_tokens);
llama_batch_add(batch, slot.sampled, slot.cache_tokens.pos_next(), { slot.id }, true);
common_batch_add(batch, slot.sampled, slot.cache_tokens.pos_next(), { slot.id }, true);
slot.cache_tokens.push_back(slot.sampled);
if (slot.params.speculative.n_min > (int)draft.size()) {
@@ -2226,7 +2189,7 @@ void server_context::update_slots() {
// add all drafted tokens to the batch
for (size_t i = 0; i < draft.size(); i++) {
slot.i_batch_dft.push_back(batch.n_tokens);
llama_batch_add(batch, draft[i], slot.cache_tokens.pos_next(), { slot.id }, true);
common_batch_add(batch, draft[i], slot.cache_tokens.pos_next(), { slot.id }, true);
slot.cache_tokens.push_back(draft[i]);
}
slot.drafted = std::move(draft);
@@ -2236,7 +2199,7 @@ void server_context::update_slots() {
// no speculative decoding
slot.i_batch = batch.n_tokens;
llama_batch_add(batch, slot.sampled, slot.cache_tokens.pos_next(), { slot.id }, true);
common_batch_add(batch, slot.sampled, slot.cache_tokens.pos_next(), { slot.id }, true);
slot.cache_tokens.push_back(slot.sampled);
@@ -2245,18 +2208,10 @@ void server_context::update_slots() {
}
slot.n_past = slot.cache_tokens.n_tokens();
}
}
// process in chunks of params.n_batch
int32_t n_batch = llama_n_batch(ctx);
int32_t n_ubatch = llama_n_ubatch(ctx);
// track if this is an embedding or non-embedding batch
// if we've added sampled tokens above, we are in non-embedding mode
// -1: none, 0: non-embedding, 1: embedding
int32_t batch_type = batch.n_tokens > 0 ? 0 : -1;
// next, batch any pending prompts without exceeding n_batch
if (params.cont_batching || batch.n_tokens == 0) {
void server_context::batch_pending_prompt(const int32_t n_ubatch, const int32_t n_batch, int32_t & batch_type) {
if (params_base.cont_batching || batch.n_tokens == 0) {
for (auto& slot : slots) {
// this slot still has a prompt to be processed
if (slot.state == SLOT_STATE_IDLE && slot.command == SLOT_COMMAND_LOAD_PROMPT) {
@@ -2275,8 +2230,8 @@ void server_context::update_slots() {
if (slot.infill) {
const bool add_bos = llama_should_add_bos_token(model);
bool suff_rm_leading_spc = true;
if (params.input_suffix.find_first_of(' ') == 0 && params.input_suffix.size() > 1) {
params.input_suffix.erase(0, 1);
if (params_base.input_suffix.find_first_of(' ') == 0 && params_base.input_suffix.size() > 1) {
params_base.input_suffix.erase(0, 1);
suff_rm_leading_spc = false;
}
@@ -2291,8 +2246,8 @@ void server_context::update_slots() {
prefix_tokens.insert(prefix_tokens.begin(), llama_token_prefix(model));
suffix_tokens.insert(suffix_tokens.begin(), llama_token_suffix(model));
auto embd_inp = params.spm_infill ? suffix_tokens : prefix_tokens;
auto embd_end = params.spm_infill ? prefix_tokens : suffix_tokens;
auto embd_inp = params_base.spm_infill ? suffix_tokens : prefix_tokens;
auto embd_end = params_base.spm_infill ? prefix_tokens : suffix_tokens;
if (add_bos) {
embd_inp.insert(embd_inp.begin(), llama_token_bos(model));
}
@@ -2350,7 +2305,7 @@ void server_context::update_slots() {
// if input prompt is too big, truncate it (if group attention self-extend is disabled)
// context shift for prompt processing
if (slot.ga_n == 1 && slot.n_prompt_tokens >= slot.n_ctx) {
if (!params.ctx_shift) {
if (!params_base.ctx_shift) {
send_error(slot, "the request exceeds the available context size, try increasing it", ERROR_TYPE_SERVER);
slot.release();
continue;
@@ -2389,7 +2344,7 @@ void server_context::update_slots() {
else {
slot.n_discarded_prompt = 0;
}
llama_sampling_reset(llama_get_model_vocab(model), slot.ctx_sampling);
common_sampler_reset(llama_get_model_vocab(model), slot.ctx_sampling);
if (!slot.params.cache_prompt) {
slot.n_past_se = 0;
@@ -2424,7 +2379,7 @@ void server_context::update_slots() {
// push the prompt into the sampling context (do not apply grammar)
for (int i = 0; i < slot.n_past; ++i) {
llama_sampling_accept(slot.ctx_sampling, ctx, slot.cache_tokens[i], false);
common_sampler_accept(slot.ctx_sampling, ctx, slot.cache_tokens[i], false);
}
}
}
@@ -2486,7 +2441,7 @@ void server_context::update_slots() {
slot.n_past_se = 0;
slot.ga_i = 0;
// TODO: is the system prompt ever in the sampling context?
llama_sampling_reset(llama_get_model_vocab(model), slot.ctx_sampling);
common_sampler_reset(llama_get_model_vocab(model), slot.ctx_sampling);
}
LOG_INFO("kv cache rm [p0, end)", {
@@ -2546,7 +2501,7 @@ void server_context::update_slots() {
}
int p0 = system_tokens.size() + slot.cache_tokens.pos_next();
llama_batch_add(batch, cur_tok, p0, { slot.id }, slot.embedding);
common_batch_add(batch, cur_tok, p0, { slot.id }, slot.embedding);
slot.cache_tokens.push_back(cur_tok);
@@ -2571,11 +2526,11 @@ void server_context::update_slots() {
GGML_ASSERT(batch.n_tokens > 0);
GGML_ASSERT((size_t)slot.n_prompt_tokens == slot.prompt_tokens.size());
llama_sampling_reset(llama_get_model_vocab(model), slot.ctx_sampling);
common_sampler_reset(llama_get_model_vocab(model), slot.ctx_sampling);
for (int i = 0; i < slot.n_prompt_tokens; ++i) {
llama_token id = slot.prompt_tokens[i];
if (id != LLAMA_TOKEN_NULL) {
llama_sampling_accept(slot.ctx_sampling, ctx, id, false);
common_sampler_accept(slot.ctx_sampling, ctx, id, false);
}
}
@@ -2599,51 +2554,107 @@ void server_context::update_slots() {
}
}
}
}
if (batch.n_tokens == 0) {
LOG_VERBOSE("no tokens to decode", {});
return;
void server_context::extend_context(const int32_t n_tokens) {
for (auto& slot : slots) {
if (slot.ga_n != 1) {
// context extension via Self-Extend
// TODO: simplify and/or abstract this
while (slot.n_past_se >= slot.ga_i + slot.ga_w) {
const int ib = (slot.ga_n * slot.ga_i) / slot.ga_w;
const int bd = (slot.ga_w / slot.ga_n) * (slot.ga_n - 1);
const int dd = (slot.ga_w / slot.ga_n) - ib * bd - slot.ga_w;
LOG_TEE("\n");
LOG_TEE("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", slot.ga_i, slot.n_past_se, ib * bd, slot.ga_i + ib * bd, slot.n_past_se + ib * bd);
LOG_TEE("div: [%6d, %6d] / %6d -> [%6d, %6d]\n", slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w, slot.ga_n, (slot.ga_i + ib * bd) / slot.ga_n, (slot.ga_i + ib * bd + slot.ga_w) / slot.ga_n);
LOG_TEE("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", slot.ga_i + ib * bd + slot.ga_w, slot.n_past_se + ib * bd, dd, slot.ga_i + ib * bd + slot.ga_w + dd, slot.n_past_se + ib * bd + dd);
llama_kv_cache_seq_add(ctx, slot.id, slot.ga_i, slot.n_past_se, ib * bd);
llama_kv_cache_seq_div(ctx, slot.id, slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w, slot.ga_n);
llama_kv_cache_seq_add(ctx, slot.id, slot.ga_i + ib * bd + slot.ga_w, slot.n_past_se + ib * bd, dd);
slot.n_past_se -= bd;
slot.ga_i += slot.ga_w / slot.ga_n;
LOG_TEE("\nn_past_old = %d, n_past = %d, ga_i = %d\n\n", slot.n_past_se + bd, slot.n_past_se, slot.ga_i);
}
slot.n_past_se += n_tokens;
}
}
}
LOG_VERBOSE("decoding batch", {
{"n_tokens", batch.n_tokens},
});
void server_context::speculative_decoding_accept() {
for (auto& slot : slots) {
if (slot.state != SLOT_STATE_PROCESSING || slot.i_batch_dft.empty()) {
continue;
}
// make sure we're in the right embedding mode
llama_set_embeddings(ctx, batch_type == 1);
size_t n_draft = slot.drafted.size();
// process the created batch of tokens
for (int32_t i = 0; i < batch.n_tokens; i += n_batch) {
const int32_t n_tokens = std::min(n_batch, batch.n_tokens - i);
// the accepted tokens from the speculation
const auto ids = common_sampler_sample_and_accept_n(slot.ctx_sampling, ctx, slot.i_batch_dft, slot.drafted);
slot.i_batch_dft.clear();
slot.drafted.clear();
for (auto& slot : slots) {
if (slot.ga_n != 1) {
// context extension via Self-Extend
// TODO: simplify and/or abstract this
while (slot.n_past_se >= slot.ga_i + slot.ga_w) {
const int ib = (slot.ga_n * slot.ga_i) / slot.ga_w;
const int bd = (slot.ga_w / slot.ga_n) * (slot.ga_n - 1);
const int dd = (slot.ga_w / slot.ga_n) - ib * bd - slot.ga_w;
slot.n_past += ids.size();
slot.n_decoded += ids.size();
const int64_t t_current = ggml_time_us();
slot.t_token_generation = std::max<int64_t>(1, t_current - slot.t_start_generation) / 1e3;
LOG_TEE("\n");
LOG_TEE("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", slot.ga_i, slot.n_past_se, ib * bd, slot.ga_i + ib * bd, slot.n_past_se + ib * bd);
LOG_TEE("div: [%6d, %6d] / %6d -> [%6d, %6d]\n", slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w, slot.ga_n, (slot.ga_i + ib * bd) / slot.ga_n, (slot.ga_i + ib * bd + slot.ga_w) / slot.ga_n);
LOG_TEE("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", slot.ga_i + ib * bd + slot.ga_w, slot.n_past_se + ib * bd, dd, slot.ga_i + ib * bd + slot.ga_w + dd, slot.n_past_se + ib * bd + dd);
// update how many tokens out of those tested were accepted
slot.n_draft_accepted += ids.size() - 1;
llama_kv_cache_seq_add(ctx, slot.id, slot.ga_i, slot.n_past_se, ib * bd);
llama_kv_cache_seq_div(ctx, slot.id, slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w, slot.ga_n);
llama_kv_cache_seq_add(ctx, slot.id, slot.ga_i + ib * bd + slot.ga_w, slot.n_past_se + ib * bd, dd);
// rollback to the state before sampling the draft tokens
slot.cache_tokens.keep_first(slot.cache_tokens.n_tokens() - n_draft);
// add accepted tokens to the prompt
slot.cache_tokens.insert({ ids.begin(), ids.end() - 1 });
slot.sampled = ids.back(); // last accepted token
slot.n_past = slot.cache_tokens.n_tokens();
llama_kv_cache_seq_rm(ctx, slot.id, slot.n_past, -1);
slot.n_past_se -= bd;
for (size_t i = 0; i < ids.size(); ++i) {
completion_token_output result;
slot.ga_i += slot.ga_w / slot.ga_n;
result.tok = ids[i];
result.text_to_send = common_token_to_piece(ctx, result.tok, accept_special_token(slot, result.tok));
result.prob = 1.0f; // set later
LOG_TEE("\nn_past_old = %d, n_past = %d, ga_i = %d\n\n", slot.n_past_se + bd, slot.n_past_se, slot.ga_i);
}
if (slot.sparams.n_probs > 0) {
populate_token_probs(slot, result, slot.params.post_sampling_probs, params_base.special, i);
}
slot.n_past_se += n_tokens;
if (!process_token(result, slot)) {
// release slot because of stop condition
slot.release();
slot.print_timings();
send_final_response(slot);
metrics.on_prediction(slot);
break;
}
}
SLT_DBG(slot, "accepted %d/%d draft tokens, new n_tokens = %d\n", (int)ids.size() - 1, (int)slot.drafted.size(), slot.n_past);
LOG_VERBOSE("speculative decoding result", {
{"id_slot", slot.id},
{"accepted", (int)ids.size() - 1},
{"total", (int)slot.drafted.size()},
{"new_n_past", slot.n_past}
});
}
}
bool server_context::accept_special_token(const server_slot& slot, const llama_token token) {
return params_base.special || slot.sparams.preserved_tokens.find(token) != slot.sparams.preserved_tokens.end();
};
void server_context::process_batch_tokens(int32_t & n_batch) {
for (int32_t i = 0; i < batch.n_tokens; i += n_batch) {
const int32_t n_tokens = std::min(n_batch, batch.n_tokens - i);
extend_context(n_tokens);
llama_batch batch_view = {
n_tokens,
@@ -2661,14 +2672,11 @@ void server_context::update_slots() {
if (ret != 0) {
if (n_batch == 1 || ret < 0) {
int user_cancel = -3;
// if you get here, it means the KV cache is full - try increasing it via the context size
if (ret == user_cancel) {
LOG_ERROR("Decode process is cancelled by user", {
{"i", i},
{"n_batch", ret},
{"ret", ret},
});
} else {
LLAMA_LOG_INFO("Decode process is cancelled by user.\n");
}
else {
// if you get here, it means the KV cache is full - try increasing it via the context size
LOG_ERROR("failed to decode the batch: KV cache is full - try increasing it via the context size", {
{"i", i},
{"n_batch", ret},
@@ -2684,12 +2692,9 @@ void server_context::update_slots() {
LLAMA_LOG_INFO("n_past = %d\n", (int)slot.cache_tokens.size());
send_error(slot, "Input prompt is too big compared to KV size. Please try increasing KV size.");
}
}
break; // break loop of n_batch
}
// retry with half the batch size to try to find a free slot in the KV cache
n_batch /= 2;
i -= n_batch;
@@ -2703,10 +2708,6 @@ void server_context::update_slots() {
continue; // continue loop of n_batch
}
// technically, measuring the time here excludes the sampling time for the last batch
// but on the other hand, we don't want to do too many system calls to measure the time, so it's ok
const int64_t t_current = ggml_time_us();
for (auto& slot : slots) {
if (slot.state != SLOT_STATE_PROCESSING || slot.i_batch < (int)i || slot.i_batch >= (int)(i + n_tokens)) {
continue; // continue loop of slots
@@ -2725,9 +2726,9 @@ void server_context::update_slots() {
continue; // sample using speculative decoding
}
const int tok_idx = slot.i_batch - i;
const llama_token id = llama_sampling_sample(slot.ctx_sampling, ctx, NULL, tok_idx);
const llama_token id = common_sampler_sample(slot.ctx_sampling, ctx, NULL, tok_idx);
llama_sampling_accept(slot.ctx_sampling, ctx, id, true);
common_sampler_accept(slot.ctx_sampling, ctx, id, true);
slot.n_decoded += 1;
@@ -2739,15 +2740,14 @@ void server_context::update_slots() {
metrics.on_prompt_eval(slot);
}
//slot.t_token_generation = (t_current - slot.t_start_generation) / 1e3;
slot.t_token_generation = std::max<int64_t>(1, t_current - slot.t_start_generation) / 1e3;
result.tok = id;
result.prob = 1.0f; // TODO: set it here instead of doing inside populate_token_probs
result.text_to_send = llama_token_to_piece(ctx, result.tok, accept_special_token(slot, result.tok));
result.text_to_send = common_token_to_piece(ctx, result.tok, accept_special_token(slot, result.tok));
if (slot.sparams.n_probs > 0) {
populate_token_probs(slot, result, slot.params.post_sampling_probs, params.special, tok_idx);
populate_token_probs(slot, result, slot.params.post_sampling_probs, params_base.special, tok_idx);
}
if (!process_token(result, slot)) {
@@ -2761,64 +2761,67 @@ void server_context::update_slots() {
}
// speculative decoding - main model sample and accept
for (auto& slot : slots) {
if (slot.state != SLOT_STATE_PROCESSING || slot.i_batch_dft.empty()) {
continue;
}
size_t n_draft = slot.drafted.size();
// the accepted tokens from the speculation
const auto ids = llama_sampling_sample_and_accept_n(slot.ctx_sampling, ctx, slot.i_batch_dft, slot.drafted);
slot.i_batch_dft.clear();
slot.drafted.clear();
slot.n_past += ids.size();
slot.n_decoded += ids.size();
slot.t_token_generation = std::max<int64_t>(1, t_current - slot.t_start_generation) / 1e3;
// update how many tokens out of those tested were accepted
slot.n_draft_accepted += ids.size() - 1;
// rollback to the state before sampling the draft tokens
slot.cache_tokens.keep_first(slot.cache_tokens.n_tokens() - n_draft);
// slot.n_past -= n_draft;
// add accepted tokens to the prompt
slot.cache_tokens.insert({ ids.begin(), ids.end() - 1 });
slot.sampled = ids.back(); // last accepted token
slot.n_past = slot.cache_tokens.n_tokens();
llama_kv_cache_seq_rm(ctx, slot.id, slot.n_past, -1);
for (size_t i = 0; i < ids.size(); ++i) {
completion_token_output result;
result.tok = ids[i];
result.text_to_send = llama_token_to_piece(ctx, result.tok, accept_special_token(slot, result.tok));
result.prob = 1.0f; // set later
if (slot.sparams.n_probs > 0) {
populate_token_probs(slot, result, slot.params.post_sampling_probs, params.special, i);
}
if (!process_token(result, slot)) {
// release slot because of stop condition
slot.release();
slot.print_timings();
send_final_response(slot);
metrics.on_prediction(slot);
break;
}
}
SLT_DBG(slot, "accepted %d/%d draft tokens, new n_tokens = %d\n", (int)ids.size() - 1, (int)slot.drafted.size(), slot.n_past);
LOG_VERBOSE("speculative decoding result", {
{"id_slot", slot.id},
{"accepted", (int)ids.size() - 1},
{"total", (int)slot.drafted.size()},
{"new_n_past", slot.n_past}
});
}
speculative_decoding_accept();
}
}
void server_context::update_slots() {
if (system_need_update) {
system_prompt_update();
}
// release slots
release_slots();
// check if all slots are idle
if (slots_idle()) {
return;
}
{
LOG_VERBOSE("posting NEXT_RESPONSE", {});
server_task task;
task.type = SERVER_TASK_TYPE_NEXT_RESPONSE;
task.id_target = -1;
queue_tasks.post(std::move(task));
}
// apply context-shift if needed
// TODO: simplify and improve
context_shift();
// start populating the batch for this iteration
common_batch_clear(batch);
// frist, add sampled tokens from any ongoing sequences
add_sampled_tokens();
// process in chunks of params.n_batch
int32_t n_batch = llama_n_batch(ctx);
int32_t n_ubatch = llama_n_ubatch(ctx);
// track if this is an embedding or non-embedding batch
// if we've added sampled tokens above, we are in non-embedding mode
// -1: none, 0: non-embedding, 1: embedding
int32_t batch_type = batch.n_tokens > 0 ? 0 : -1;
// next, batch any pending prompts without exceeding n_batch
batch_pending_prompt(n_ubatch, n_batch, batch_type);
if (batch.n_tokens == 0) {
LOG_VERBOSE("no tokens to decode", {});
return;
}
LOG_VERBOSE("decoding batch", {
{"n_tokens", batch.n_tokens},
});
// make sure we're in the right embedding mode
llama_set_embeddings(ctx, batch_type == 1);
// process the created batch of tokens
process_batch_tokens(n_batch);
LOG_VERBOSE("run slots completed", {});
}