Mirror of https://github.com/ikawrakow/ik_llama.cpp.git (synced 2026-02-03 13:04:59 +00:00)
Server: refactor and rename functions (#1151)
* Server: rename functions and refactor code

  rename functions
  refactor update slots
  rename params_base
  rename timings

* change
* Revert kv cache name changes
* Revert 2
* fix test build error

---------

Co-authored-by: firecoperana <firecoperana>
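Overview (editor's note): the commit splits the old monolithic server_context::update_slots() into named helpers (release_slots, slots_idle, context_shift, add_sampled_tokens, batch_pending_prompt, extend_context, speculative_decoding_accept, accept_special_token, process_batch_tokens), renames the server's copy of the CLI parameters from params to params_base, and switches to the common_* helper names (common_sampler_*, common_tokenize, common_batch_*, common_token_to_piece). The following is a condensed sketch of the refactored control flow, assembled only from the hunks below; it is not a compilable excerpt and omits arguments and state handled by the real code.

    // Condensed sketch of the new update_slots() orchestration (comments added
    // here for orientation; the actual helper definitions appear in the diff).
    void server_context::update_slots() {
        if (system_need_update) {
            system_prompt_update();
        }

        release_slots();        // was: inline "release slots" loop
        if (slots_idle()) {     // was: inline "check if all slots are idle" block
            return;
        }

        // post a NEXT_RESPONSE task (unchanged), then:
        context_shift();        // was: inline context-shift loop

        common_batch_clear(batch);   // renamed from llama_batch_clear
        add_sampled_tokens();        // was: inline "add sampled tokens" loop

        int32_t n_batch    = llama_n_batch(ctx);
        int32_t n_ubatch   = llama_n_ubatch(ctx);
        int32_t batch_type = batch.n_tokens > 0 ? 0 : -1;

        batch_pending_prompt(n_ubatch, n_batch, batch_type);
        if (batch.n_tokens == 0) {
            return;
        }

        llama_set_embeddings(ctx, batch_type == 1);
        process_batch_tokens(n_batch);   // decoding, sampling, speculative accept
    }
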
@@ -37,7 +37,7 @@ server_context::~server_context() {
// Clear any sampling context
for (server_slot& slot : slots) {
if (slot.ctx_sampling != nullptr) {
llama_sampling_free(slot.ctx_sampling);
common_sampler_free(slot.ctx_sampling);
}
if (slot.ctx_dft) {
llama_free(slot.ctx_dft);
@@ -52,16 +52,16 @@ server_context::~server_context() {
}

bool server_context::load_model(const gpt_params& params_) {
params = params_;
params_base = params_;

llama_init_result llama_init = llama_init_from_gpt_params(params);
llama_init_result llama_init = llama_init_from_gpt_params(params_base);

model = llama_init.model;
ctx = llama_init.context;
lora_adapters = llama_init.lora_adapters;

if (model == nullptr) {
LOG_ERROR("unable to load model", { {"model", params.model} });
LOG_ERROR("unable to load model", { {"model", params_base.model} });
return false;
}

@@ -70,26 +70,26 @@ bool server_context::load_model(const gpt_params& params_) {
add_bos_token = llama_should_add_bos_token(model);
has_eos_token = llama_add_eos_token(model) != 1;

chat_templates = common_chat_templates_init(model, params.chat_template);
chat_templates = common_chat_templates_init(model, params_base.chat_template);
try {
common_chat_format_example(chat_templates.get(), params.use_jinja, {});
common_chat_format_example(chat_templates.get(), params_base.use_jinja, {});
}
catch (const std::exception& e) {
LOG_WARNING("%s: The chat template that comes with this model is not yet supported, falling back to chatml. This may cause the model to output suboptimal responses\n", __func__);
chat_templates = common_chat_templates_init(model, "chatml");
}

bool has_draft_model = !params.model_draft.empty() || !params.draft_params.empty();
std::string& mmproj_path = params.mmproj.path;
bool has_draft_model = !params_base.model_draft.empty() || !params_base.draft_params.empty();
std::string& mmproj_path = params_base.mmproj.path;
if (!mmproj_path.empty()) {
mtmd_context_params mparams = mtmd_context_params_default();
mparams.use_gpu = params.mmproj_use_gpu;
mparams.use_gpu = params_base.mmproj_use_gpu;
mparams.print_timings = false;
mparams.n_threads = params.n_threads;
mparams.flash_attn_type = params.flash_attn ? LLAMA_FLASH_ATTN_TYPE_ENABLED : LLAMA_FLASH_ATTN_TYPE_DISABLED;
mparams.verbosity = params.verbosity > 0 ? GGML_LOG_LEVEL_DEBUG : GGML_LOG_LEVEL_INFO;
mparams.image_min_tokens = params.image_min_tokens;
mparams.image_max_tokens = params.image_max_tokens;
mparams.n_threads = params_base.n_threads;
mparams.flash_attn_type = params_base.flash_attn ? LLAMA_FLASH_ATTN_TYPE_ENABLED : LLAMA_FLASH_ATTN_TYPE_DISABLED;
mparams.verbosity = params_base.verbosity > 0 ? GGML_LOG_LEVEL_DEBUG : GGML_LOG_LEVEL_INFO;
mparams.image_min_tokens = params_base.image_min_tokens;
mparams.image_max_tokens = params_base.image_max_tokens;
mctx = mtmd_init_from_file(mmproj_path.c_str(), model, mparams);
if (mctx == nullptr) {
LOG_ERROR("failed to load multimodal model, '%s'\n", mmproj_path.c_str());
@@ -97,8 +97,8 @@ bool server_context::load_model(const gpt_params& params_) {
}
LOG_INFO("loaded multimodal model, '%s'\n", mmproj_path.c_str());

if (params.ctx_shift) {
params.ctx_shift = false;
if (params_base.ctx_shift) {
params_base.ctx_shift = false;
LOG_WARNING("%s\n", "ctx_shift is not supported by multimodal, it will be disabled");
}

@@ -117,15 +117,15 @@ bool server_context::load_model(const gpt_params& params_) {
LLAMA_LOG_INFO("\n\n==================================loading DRAFT model==================================\n\n");

gpt_params params_dft;
params_dft.devices = params.devices_draft;
params_dft.model = params.model_draft;
params_dft.n_gpu_layers = params.n_gpu_layers_draft;
params_dft.rpc_servers = params.rpc_servers;
params_dft.cache_type_k = params.cache_type_k_draft.empty() ? params.cache_type_k : params.cache_type_k_draft;
params_dft.cache_type_v = params.cache_type_v_draft.empty() ? params.cache_type_v : params.cache_type_v_draft;
params_dft.flash_attn = params.flash_attn;
if (!params.draft_params.empty()) {
auto [argc, argv] = parse_command_line("llama-server " + params.draft_params);
params_dft.devices = params_base.devices_draft;
params_dft.model = params_base.model_draft;
params_dft.n_gpu_layers = params_base.n_gpu_layers_draft;
params_dft.rpc_servers = params_base.rpc_servers;
params_dft.cache_type_k = params_base.cache_type_k_draft.empty() ? params_base.cache_type_k : params_base.cache_type_k_draft;
params_dft.cache_type_v = params_base.cache_type_v_draft.empty() ? params_base.cache_type_v : params_base.cache_type_v_draft;
params_dft.flash_attn = params_base.flash_attn;
if (!params_base.draft_params.empty()) {
auto [argc, argv] = parse_command_line("llama-server " + params_base.draft_params);
if (!gpt_params_parse(argc, argv, params_dft)) {
gpt_params_print_usage(argc, argv, params_dft);
free_command_line(argc, argv);
@@ -135,16 +135,16 @@ bool server_context::load_model(const gpt_params& params_) {
}
LOG_INFO("", { {"model", params_dft.model} });
if (params_dft.n_ctx == 0) {
params_dft.n_ctx = params.n_ctx_draft;
params_dft.n_ctx = params_base.n_ctx_draft;
}
params_dft.n_ctx = params_dft.n_ctx == 0 ? params.n_ctx / params.n_parallel : params_dft.n_ctx;
params_dft.n_ctx = params_dft.n_ctx == 0 ? params_base.n_ctx / params_base.n_parallel : params_dft.n_ctx;
params_dft.n_parallel = 1;
params_dft.n_batch = params_dft.n_ctx;
llama_init_result llama_init_dft = llama_init_from_gpt_params(params_dft);

llama_model* model_dft = llama_init_dft.model;
if (model_dft == nullptr) {
LOG_ERROR("failed to load draft model", { {"model", params.model_draft} });
LOG_ERROR("failed to load draft model", { {"model", params_base.model_draft} });
return false;
}

@@ -163,22 +163,22 @@ bool server_context::load_model(const gpt_params& params_) {
}

void server_context::init() {
const int32_t n_ctx_slot = n_ctx / params.n_parallel;
const int32_t n_ctx_slot = n_ctx / params_base.n_parallel;

LOG_INFO("initializing slots", { {"n_slots", params.n_parallel} });
LOG_INFO("initializing slots", { {"n_slots", params_base.n_parallel} });

for (int i = 0; i < params.n_parallel; i++) {
for (int i = 0; i < params_base.n_parallel; i++) {
server_slot slot;

slot.id = i;
slot.ctx = ctx;
slot.n_ctx = n_ctx_slot;
slot.n_predict = params.n_predict;
slot.n_predict = params_base.n_predict;
slot.mctx = mctx;
slot.cache_tokens.has_mtmd = mctx != nullptr;
slot.params.think_tokens = params.think_tokens;
if (params.think_tokens.exclude) {
SRV_WRN("Exclude reasoning tokens when selecting slot based on similarity: start: %s, end: %s\nuse `--reasoning-tokens none` to disable.\n", params.think_tokens.begin.c_str(), params.think_tokens.end.c_str() );
slot.params.think_tokens = params_base.think_tokens;
if (params_base.think_tokens.exclude) {
SRV_WRN("Exclude reasoning tokens when selecting slot based on similarity: start: %s, end: %s\nuse `--reasoning-tokens none` to disable.\n", params_base.think_tokens.begin.c_str(), params_base.think_tokens.end.c_str() );
}
else {
SRV_WRN("%s", "Include reasoning tokens when selecting slot based on similarity\nuse `--reasoning-tokens auto` to exclude reasoning tokens.\n");
@@ -188,8 +188,8 @@ void server_context::init() {
{"n_ctx_slot", slot.n_ctx}
});

const int ga_n = params.grp_attn_n;
const int ga_w = params.grp_attn_w;
const int ga_n = params_base.grp_attn_n;
const int ga_w = params_base.grp_attn_w;

if (ga_n != 1) {
GGML_ASSERT(ga_n > 0 && "ga_n must be positive"); // NOLINT
@@ -208,7 +208,7 @@ void server_context::init() {
slot.ga_n = ga_n;
slot.ga_w = ga_w;

slot.sparams = params.sparams;
slot.sparams = params_base.sparams;

// Initialize speculative decoding if a draft model is loaded
if (ctx_draft) {
@@ -225,7 +225,7 @@ void server_context::init() {
LOG_ERROR("failed to create speculator", {});
return;
}
for (auto& pair : params.replacements_draft) {
for (auto& pair : params_base.replacements_draft) {
llama_speculative_add_replacement_tgt_dft(slot.spec, pair.first.c_str(), pair.second.c_str());
}

@@ -245,21 +245,21 @@ void server_context::init() {
const int32_t n_batch = llama_n_batch(ctx);

// only a single seq_id per token is needed
batch = llama_batch_init(std::max(n_batch, params.n_parallel), 0, 1);
batch = llama_batch_init(std::max(n_batch, params_base.n_parallel), 0, 1);
}

metrics.init();

if (params.cache_ram_mib != 0) {
if (params.cache_ram_mib < 0) {
if (params_base.cache_ram_mib != 0) {
if (params_base.cache_ram_mib < 0) {
LLAMA_LOG_INFO("prompt cache is enabled, size limit: %s\n", "no limit");
}
else {
LLAMA_LOG_INFO("prompt cache is enabled, size limit: %d MiB\n", params.cache_ram_mib);
LLAMA_LOG_INFO("prompt cache is enabled, size limit: %d MiB\n", params_base.cache_ram_mib);
}
LLAMA_LOG_INFO("%s", "use `--cache-ram 0` to disable the prompt cache\n");
// only apply ram size limit. No token limit for now.
prompt_cache = std::make_unique<server_prompt_cache>(ctx, params.cache_ram_mib, 0);
prompt_cache = std::make_unique<server_prompt_cache>(ctx, params_base.cache_ram_mib, 0);
}
else {
LLAMA_LOG_INFO("%s", "prompt cache is disabled - use `--cache-ram N` to enable it\n");
@@ -268,14 +268,14 @@ void server_context::init() {
// thinking is enabled if:
// 1. It's not explicitly disabled (reasoning_budget == 0)
// 2. The chat template supports it
const bool enable_thinking = params.use_jinja && params.reasoning_budget != 0 && common_chat_templates_support_enable_thinking(chat_templates.get());
const bool enable_thinking = params_base.use_jinja && params_base.reasoning_budget != 0 && common_chat_templates_support_enable_thinking(chat_templates.get());
//LLAMA_LOG_INFO("Enable thinking? %d\n", enable_thinking);

oai_parser_opt = {
/* use_jinja */ params.use_jinja,
/* prefill_assistant */ params.prefill_assistant,
/* reasoning_format */ params.reasoning_format,
/* chat_template_kwargs */ params.default_template_kwargs,
/* use_jinja */ params_base.use_jinja,
/* prefill_assistant */ params_base.prefill_assistant,
/* reasoning_format */ params_base.reasoning_format,
/* chat_template_kwargs */ params_base.default_template_kwargs,
/* common_chat_templates */ chat_templates.get(),
/* allow_image */ mctx ? mtmd_support_vision(mctx) : false,
/* allow_audio */ mctx ? mtmd_support_audio(mctx) : false,
@@ -500,34 +500,19 @@ size_t server_slot::find_stopping_strings(const std::string& text, const size_t

void server_slot::print_timings() const {
char buffer[512];
double t_token = t_prompt_processing / n_prompt_tokens_processed;
double n_tokens_second = 1e3 / t_prompt_processing * n_prompt_tokens_processed;
double t_prompt = t_prompt_processing / n_prompt_tokens_processed;
double n_prompt_second = 1e3 / t_prompt_processing * n_prompt_tokens_processed;

//snprintf(buffer, 512, "prompt eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)",
// t_prompt_processing, n_prompt_tokens_processed,
// t_token, n_tokens_second);
double t_gen = t_token_generation / n_decoded;
double n_gen_second = 1e3 / t_token_generation * n_decoded;

//LOG_INFO(buffer, {});

double t_token_gen = t_token_generation / n_decoded;
double n_tokens_second_gen = 1e3 / t_token_generation * n_decoded;

//snprintf(buffer, 512, "generation eval time = %10.2f ms / %5d runs (%8.2f ms per token, %8.2f tokens per second)",
// t_token_generation, n_decoded,
// t_token, n_tokens_second);

//LOG_INFO(buffer, {});

//snprintf(buffer, 512, " total time = %10.2f ms", t_prompt_processing + t_token_generation);

//LOG_INFO(buffer, {});
SLT_INF(*this,
"\n"
"prompt eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)\n"
" eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)\n"
" total time = %10.2f ms / %5d tokens\n",
t_prompt_processing, n_prompt_tokens_processed, t_token, n_tokens_second,
t_token_generation, n_decoded, t_token_gen, n_tokens_second_gen,
t_prompt_processing, n_prompt_tokens_processed, t_prompt, n_prompt_second,
t_token_generation, n_decoded, t_gen, n_gen_second,
t_prompt_processing + t_token_generation, n_prompt_tokens_processed + n_decoded);

if (n_draft_total > 0) {
@@ -795,7 +780,7 @@ server_slot* server_context::get_available_slot(const server_task& task) {
bool server_context::launch_slot_with_task(server_slot& slot, server_task& task) {
slot_params default_params;
// Sampling parameter defaults are loaded from the global server context (but individual requests can still override them)
llama_sampling_params default_sparams = params.sparams;
llama_sampling_params default_sparams = params_base.sparams;
auto& data = task.data;

if (data.count("__oaicompat") != 0) {
@@ -848,9 +833,9 @@ bool server_context::launch_slot_with_task(server_slot& slot, server_task& task)
slot.params.post_sampling_probs = json_value(data, "post_sampling_probs", default_params.post_sampling_probs);

// speculative decoding parameters
slot.params.speculative.n_max = json_value(data, "speculative.n_max", params.n_draft);
slot.params.speculative.n_min = json_value(data, "speculative.n_min", params.n_draft_min);
slot.params.speculative.p_min = json_value(data, "speculative.p_min", params.p_draft_min);
slot.params.speculative.n_max = json_value(data, "speculative.n_max", params_base.n_draft);
slot.params.speculative.n_min = json_value(data, "speculative.n_min", params_base.n_draft_min);
slot.params.speculative.p_min = json_value(data, "speculative.p_min", params_base.p_draft_min);

// Clamp speculative parameters
slot.params.speculative.n_min = std::min(slot.params.speculative.n_max, slot.params.speculative.n_min);
@@ -945,7 +930,7 @@ bool server_context::launch_slot_with_task(server_slot& slot, server_task& task)
if (penalty_prompt != data.end()) {
if (penalty_prompt->is_string()) {
const auto penalty_prompt_string = penalty_prompt->get<std::string>();
slot.sparams.penalty_prompt_tokens = llama_tokenize(model, penalty_prompt_string, false);
slot.sparams.penalty_prompt_tokens = common_tokenize(model, penalty_prompt_string, false);

if (slot.params.n_predict > 0) {
slot.sparams.penalty_prompt_tokens.reserve(slot.sparams.penalty_prompt_tokens.size() + slot.params.n_predict);
@@ -988,7 +973,7 @@ bool server_context::launch_slot_with_task(server_slot& slot, server_task& task)
else {
slot.params.oaicompat_chat_syntax.format = default_params.oaicompat_chat_syntax.format;
}
common_reasoning_format reasoning_format = params.reasoning_format;
common_reasoning_format reasoning_format = params_base.reasoning_format;
if (data.contains("reasoning_format")) {
reasoning_format = common_reasoning_format_from_name(data.at("reasoning_format").get<std::string>());
}
@@ -1003,7 +988,7 @@ bool server_context::launch_slot_with_task(server_slot& slot, server_task& task)
const auto preserved_tokens = data.find("preserved_tokens");
if (preserved_tokens != data.end()) {
for (const auto& t : *preserved_tokens) {
auto ids = llama_tokenize(model, t.get<std::string>(), /* add_special= */ false, /* parse_special= */ true);
auto ids = common_tokenize(model, t.get<std::string>(), /* add_special= */ false, /* parse_special= */ true);
if (ids.size() == 1) {
LOG("Preserved token: %d\n", ids[0]);
slot.sparams.preserved_tokens.insert(ids[0]);
@@ -1020,7 +1005,7 @@ bool server_context::launch_slot_with_task(server_slot& slot, server_task& task)
server_grammar_trigger ct(t);
if (ct.value.type == COMMON_GRAMMAR_TRIGGER_TYPE_WORD) {
const auto& word = ct.value.value;
auto ids = llama_tokenize(model, word, /* add_special= */ false, /* parse_special= */ true);
auto ids = common_tokenize(model, word, /* add_special= */ false, /* parse_special= */ true);
if (ids.size() == 1) {
auto token = ids[0];
if (std::find(slot.sparams.preserved_tokens.begin(), slot.sparams.preserved_tokens.end(), (llama_token)token) == slot.sparams.preserved_tokens.end()) {
@@ -1085,7 +1070,7 @@ bool server_context::launch_slot_with_task(server_slot& slot, server_task& task)
}
}
else if (el[0].is_string()) {
auto toks = llama_tokenize(model, el[0].get<std::string>(), false);
auto toks = common_tokenize(model, el[0].get<std::string>(), false);
for (auto tok : toks) {
slot.sparams.logit_bias[tok] = bias;
}
@@ -1128,9 +1113,9 @@ bool server_context::launch_slot_with_task(server_slot& slot, server_task& task)

{
if (slot.ctx_sampling != nullptr) {
llama_sampling_free(slot.ctx_sampling);
common_sampler_free(slot.ctx_sampling);
}
slot.ctx_sampling = llama_sampling_init(llama_get_model_vocab(model), slot.sparams);
slot.ctx_sampling = common_sampler_init(llama_get_model_vocab(model), slot.sparams);
if (slot.ctx_sampling == nullptr) {
// for now, the only error that may happen here is invalid grammar
send_error(task, "Failed to parse grammar", ERROR_TYPE_INVALID_REQUEST);
@@ -1174,10 +1159,10 @@ void server_context::system_prompt_update() {
for (int32_t i = 0; i < n_tokens_prompt; i += n_batch) {
const int32_t n_tokens = std::min(n_batch, n_tokens_prompt - i);

llama_batch_clear(batch);
common_batch_clear(batch);

for (int32_t j = 0; j < n_tokens; ++j) {
llama_batch_add(batch, system_tokens[i + j], i + j, { 0 }, false);
common_batch_add(batch, system_tokens[i + j], i + j, { 0 }, false);
}

if (llama_decode(ctx, batch) != 0) {
@@ -1187,7 +1172,7 @@ void server_context::system_prompt_update() {
}

// assign the system KV cache to all parallel sequences
for (int32_t i = 1; i <= params.n_parallel; ++i) {
for (int32_t i = 1; i <= params_base.n_parallel; ++i) {
llama_kv_cache_seq_cp(ctx, 0, i, -1, -1);
}
}
@@ -1268,7 +1253,7 @@ bool server_context::process_token(completion_token_output& result, server_slot&
}

// check the limits
if (slot.n_decoded > 0 && slot.has_next_token && !slot.has_budget(params)) {
if (slot.n_decoded > 0 && slot.has_next_token && !slot.has_budget(params_base)) {
slot.stopped_limit = true;
slot.has_next_token = false;

@@ -1297,7 +1282,7 @@ bool server_context::process_token(completion_token_output& result, server_slot&
{ "slot.n_prompt_tokens", slot.n_prompt_tokens },
{ "slot.n_decoded", slot.n_decoded },
{ "slot.n_predict", slot.n_predict },
{ "n_slots", params.n_parallel },
{ "n_slots", params_base.n_parallel },
{ "slot.n_ctx", slot.n_ctx },
{ "n_ctx", n_ctx },
{ "n_ctx_train", n_ctx_train },
@@ -1330,7 +1315,7 @@ void server_context::populate_token_probs(const server_slot& slot, completion_to
size_t n_vocab = llama_n_vocab(llama_get_model(ctx));

if (post_sampling) {
const auto* cur_p = llama_sampling_get_candidates(slot.ctx_sampling);
const auto* cur_p = common_sampler_get_candidates(slot.ctx_sampling);
const size_t max_probs = cur_p->size;

// set probability for sampled token
@@ -1346,7 +1331,7 @@ void server_context::populate_token_probs(const server_slot& slot, completion_to
for (size_t i = 0; i < std::min(max_probs, n_probs); i++) {
result.probs.push_back({
cur_p->data[i].id,
llama_detokenize(ctx, {cur_p->data[i].id}, special),
common_token_to_piece(ctx, {cur_p->data[i].id}, special),
cur_p->data[i].p
});
}
@@ -1362,7 +1347,7 @@ void server_context::populate_token_probs(const server_slot& slot, completion_to
for (size_t i = 0; i < std::min(n_vocab, n_probs); i++) {
result.probs.push_back({
cur[i].id,
llama_detokenize(ctx, {cur[i].id}, special),
common_token_to_piece(ctx, {cur[i].id}, special),
cur[i].p
});
}
@@ -1387,7 +1372,7 @@ json server_context::get_formated_generation(const server_slot& slot) const {
return json{
{"n_ctx", slot.n_ctx},
{"n_predict", slot.n_predict}, // Server configured n_predict
{"model", params.model_alias},
{"model", params_base.model_alias},
{"seed", slot.sparams.seed},
{"temperature", slot.sparams.temp},
{"dynatemp_range", slot.sparams.dynatemp_range},
@@ -1548,7 +1533,7 @@ void server_context::send_final_response(server_slot& slot) {
{"generated_text", slot.generated_text}, // Always include full text for finish_reason logic
{"id_slot", slot.id},
{"stop", true},
{"model", params.model_alias},
{"model", params_base.model_alias},
{"tokens_predicted", slot.n_decoded},
{"tokens_evaluated", slot.n_prompt_tokens},
{"generation_settings", get_formated_generation(slot)},
@@ -2067,12 +2052,8 @@ void server_context::context_shift_prompt(llama_context* ctx, server_slot& slot,
slot.n_prompt_tokens = slot.prompt_tokens.size();
}

void server_context::update_slots() {
if (system_need_update) {
system_prompt_update();
}

// release slots
void server_context::release_slots()
{
for (auto& slot : slots) {
if (slot.command == SLOT_COMMAND_RELEASE) {
slot.state = SLOT_STATE_IDLE;
@@ -2092,11 +2073,10 @@ void server_context::update_slots() {
queue_tasks.notify_slot_changed();
}
}
}

// check if all slots are idle
{
bool server_context::slots_idle(){
bool all_idle = true;

for (auto& slot : slots) {
if (slot.state != SLOT_STATE_IDLE || slot.command != SLOT_COMMAND_NONE) {
all_idle = false;
@@ -2109,27 +2089,16 @@ void server_context::update_slots() {
if (system_prompt.empty() && clean_kv_cache) {
kv_cache_clear();
}

return;
all_idle = true;
}
}
return all_idle;
}

{
LOG_VERBOSE("posting NEXT_RESPONSE", {});

server_task task;
task.type = SERVER_TASK_TYPE_NEXT_RESPONSE;
task.id_target = -1;

queue_tasks.post(std::move(task));
}

// apply context-shift if needed
// TODO: simplify and improve
void server_context::context_shift() {
for (server_slot& slot : slots) {
if (slot.ga_n == 1) {
if (slot.is_processing() && (int)system_tokens.size() + slot.n_past >= slot.n_ctx - 1) {
if (!params.ctx_shift) {
if (!params_base.ctx_shift) {
// this check is redundant (for good)
// we should never get here, because generation should already stopped in process_token()
send_error(slot, "context shift is disabled", ERROR_TYPE_SERVER);
@@ -2176,15 +2145,9 @@ void server_context::update_slots() {
}
}
}
}

// start populating the batch for this iteration
llama_batch_clear(batch);

auto accept_special_token = [&](server_slot& slot, llama_token token) {
return params.special || slot.sparams.preserved_tokens.find(token) != slot.sparams.preserved_tokens.end();
};

// frist, add sampled tokens from any ongoing sequences
void server_context::add_sampled_tokens() {
for (auto& slot : slots) {
if (slot.state == SLOT_STATE_IDLE) {
continue;
@@ -2209,7 +2172,7 @@ void server_context::update_slots() {

// add the sampled token to the batch
slot.i_batch_dft.push_back(batch.n_tokens);
llama_batch_add(batch, slot.sampled, slot.cache_tokens.pos_next(), { slot.id }, true);
common_batch_add(batch, slot.sampled, slot.cache_tokens.pos_next(), { slot.id }, true);
slot.cache_tokens.push_back(slot.sampled);

if (slot.params.speculative.n_min > (int)draft.size()) {
@@ -2226,7 +2189,7 @@ void server_context::update_slots() {
// add all drafted tokens to the batch
for (size_t i = 0; i < draft.size(); i++) {
slot.i_batch_dft.push_back(batch.n_tokens);
llama_batch_add(batch, draft[i], slot.cache_tokens.pos_next(), { slot.id }, true);
common_batch_add(batch, draft[i], slot.cache_tokens.pos_next(), { slot.id }, true);
slot.cache_tokens.push_back(draft[i]);
}
slot.drafted = std::move(draft);
@@ -2236,7 +2199,7 @@ void server_context::update_slots() {
// no speculative decoding
slot.i_batch = batch.n_tokens;

llama_batch_add(batch, slot.sampled, slot.cache_tokens.pos_next(), { slot.id }, true);
common_batch_add(batch, slot.sampled, slot.cache_tokens.pos_next(), { slot.id }, true);

slot.cache_tokens.push_back(slot.sampled);

@@ -2245,18 +2208,10 @@ void server_context::update_slots() {
}
slot.n_past = slot.cache_tokens.n_tokens();
}
}

// process in chunks of params.n_batch
int32_t n_batch = llama_n_batch(ctx);
int32_t n_ubatch = llama_n_ubatch(ctx);

// track if this is an embedding or non-embedding batch
// if we've added sampled tokens above, we are in non-embedding mode
// -1: none, 0: non-embedding, 1: embedding
int32_t batch_type = batch.n_tokens > 0 ? 0 : -1;

// next, batch any pending prompts without exceeding n_batch
if (params.cont_batching || batch.n_tokens == 0) {
void server_context::batch_pending_prompt(const int32_t n_ubatch, const int32_t n_batch, int32_t & batch_type) {
if (params_base.cont_batching || batch.n_tokens == 0) {
for (auto& slot : slots) {
// this slot still has a prompt to be processed
if (slot.state == SLOT_STATE_IDLE && slot.command == SLOT_COMMAND_LOAD_PROMPT) {
@@ -2275,8 +2230,8 @@ void server_context::update_slots() {
if (slot.infill) {
const bool add_bos = llama_should_add_bos_token(model);
bool suff_rm_leading_spc = true;
if (params.input_suffix.find_first_of(' ') == 0 && params.input_suffix.size() > 1) {
params.input_suffix.erase(0, 1);
if (params_base.input_suffix.find_first_of(' ') == 0 && params_base.input_suffix.size() > 1) {
params_base.input_suffix.erase(0, 1);
suff_rm_leading_spc = false;
}

@@ -2291,8 +2246,8 @@ void server_context::update_slots() {
prefix_tokens.insert(prefix_tokens.begin(), llama_token_prefix(model));
suffix_tokens.insert(suffix_tokens.begin(), llama_token_suffix(model));

auto embd_inp = params.spm_infill ? suffix_tokens : prefix_tokens;
auto embd_end = params.spm_infill ? prefix_tokens : suffix_tokens;
auto embd_inp = params_base.spm_infill ? suffix_tokens : prefix_tokens;
auto embd_end = params_base.spm_infill ? prefix_tokens : suffix_tokens;
if (add_bos) {
embd_inp.insert(embd_inp.begin(), llama_token_bos(model));
}
@@ -2350,7 +2305,7 @@ void server_context::update_slots() {
// if input prompt is too big, truncate it (if group attention self-extend is disabled)
// context shift for prompt processing
if (slot.ga_n == 1 && slot.n_prompt_tokens >= slot.n_ctx) {
if (!params.ctx_shift) {
if (!params_base.ctx_shift) {
send_error(slot, "the request exceeds the available context size, try increasing it", ERROR_TYPE_SERVER);
slot.release();
continue;
@@ -2389,7 +2344,7 @@ void server_context::update_slots() {
else {
slot.n_discarded_prompt = 0;
}
llama_sampling_reset(llama_get_model_vocab(model), slot.ctx_sampling);
common_sampler_reset(llama_get_model_vocab(model), slot.ctx_sampling);

if (!slot.params.cache_prompt) {
slot.n_past_se = 0;
@@ -2424,7 +2379,7 @@ void server_context::update_slots() {

// push the prompt into the sampling context (do not apply grammar)
for (int i = 0; i < slot.n_past; ++i) {
llama_sampling_accept(slot.ctx_sampling, ctx, slot.cache_tokens[i], false);
common_sampler_accept(slot.ctx_sampling, ctx, slot.cache_tokens[i], false);
}
}
}
@@ -2486,7 +2441,7 @@ void server_context::update_slots() {
slot.n_past_se = 0;
slot.ga_i = 0;
// TODO: is the system prompt ever in the sampling context?
llama_sampling_reset(llama_get_model_vocab(model), slot.ctx_sampling);
common_sampler_reset(llama_get_model_vocab(model), slot.ctx_sampling);
}

LOG_INFO("kv cache rm [p0, end)", {
@@ -2546,7 +2501,7 @@ void server_context::update_slots() {
}

int p0 = system_tokens.size() + slot.cache_tokens.pos_next();
llama_batch_add(batch, cur_tok, p0, { slot.id }, slot.embedding);
common_batch_add(batch, cur_tok, p0, { slot.id }, slot.embedding);

slot.cache_tokens.push_back(cur_tok);

@@ -2571,11 +2526,11 @@ void server_context::update_slots() {

GGML_ASSERT(batch.n_tokens > 0);
GGML_ASSERT((size_t)slot.n_prompt_tokens == slot.prompt_tokens.size());
llama_sampling_reset(llama_get_model_vocab(model), slot.ctx_sampling);
common_sampler_reset(llama_get_model_vocab(model), slot.ctx_sampling);
for (int i = 0; i < slot.n_prompt_tokens; ++i) {
llama_token id = slot.prompt_tokens[i];
if (id != LLAMA_TOKEN_NULL) {
llama_sampling_accept(slot.ctx_sampling, ctx, id, false);
common_sampler_accept(slot.ctx_sampling, ctx, id, false);
}
}

@@ -2599,51 +2554,107 @@ void server_context::update_slots() {
}
}
}
}

if (batch.n_tokens == 0) {
LOG_VERBOSE("no tokens to decode", {});
return;
void server_context::extend_context(const int32_t n_tokens) {
for (auto& slot : slots) {
if (slot.ga_n != 1) {
// context extension via Self-Extend
// TODO: simplify and/or abstract this
while (slot.n_past_se >= slot.ga_i + slot.ga_w) {
const int ib = (slot.ga_n * slot.ga_i) / slot.ga_w;
const int bd = (slot.ga_w / slot.ga_n) * (slot.ga_n - 1);
const int dd = (slot.ga_w / slot.ga_n) - ib * bd - slot.ga_w;

LOG_TEE("\n");
LOG_TEE("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", slot.ga_i, slot.n_past_se, ib * bd, slot.ga_i + ib * bd, slot.n_past_se + ib * bd);
LOG_TEE("div: [%6d, %6d] / %6d -> [%6d, %6d]\n", slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w, slot.ga_n, (slot.ga_i + ib * bd) / slot.ga_n, (slot.ga_i + ib * bd + slot.ga_w) / slot.ga_n);
LOG_TEE("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", slot.ga_i + ib * bd + slot.ga_w, slot.n_past_se + ib * bd, dd, slot.ga_i + ib * bd + slot.ga_w + dd, slot.n_past_se + ib * bd + dd);

llama_kv_cache_seq_add(ctx, slot.id, slot.ga_i, slot.n_past_se, ib * bd);
llama_kv_cache_seq_div(ctx, slot.id, slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w, slot.ga_n);
llama_kv_cache_seq_add(ctx, slot.id, slot.ga_i + ib * bd + slot.ga_w, slot.n_past_se + ib * bd, dd);

slot.n_past_se -= bd;

slot.ga_i += slot.ga_w / slot.ga_n;

LOG_TEE("\nn_past_old = %d, n_past = %d, ga_i = %d\n\n", slot.n_past_se + bd, slot.n_past_se, slot.ga_i);
}

slot.n_past_se += n_tokens;
}
}
}

LOG_VERBOSE("decoding batch", {
{"n_tokens", batch.n_tokens},
});
void server_context::speculative_decoding_accept() {
for (auto& slot : slots) {
if (slot.state != SLOT_STATE_PROCESSING || slot.i_batch_dft.empty()) {
continue;
}

// make sure we're in the right embedding mode
llama_set_embeddings(ctx, batch_type == 1);
size_t n_draft = slot.drafted.size();

// process the created batch of tokens
for (int32_t i = 0; i < batch.n_tokens; i += n_batch) {
const int32_t n_tokens = std::min(n_batch, batch.n_tokens - i);
// the accepted tokens from the speculation
const auto ids = common_sampler_sample_and_accept_n(slot.ctx_sampling, ctx, slot.i_batch_dft, slot.drafted);
slot.i_batch_dft.clear();
slot.drafted.clear();

for (auto& slot : slots) {
if (slot.ga_n != 1) {
// context extension via Self-Extend
// TODO: simplify and/or abstract this
while (slot.n_past_se >= slot.ga_i + slot.ga_w) {
const int ib = (slot.ga_n * slot.ga_i) / slot.ga_w;
const int bd = (slot.ga_w / slot.ga_n) * (slot.ga_n - 1);
const int dd = (slot.ga_w / slot.ga_n) - ib * bd - slot.ga_w;
slot.n_past += ids.size();
slot.n_decoded += ids.size();
const int64_t t_current = ggml_time_us();
slot.t_token_generation = std::max<int64_t>(1, t_current - slot.t_start_generation) / 1e3;

LOG_TEE("\n");
LOG_TEE("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", slot.ga_i, slot.n_past_se, ib * bd, slot.ga_i + ib * bd, slot.n_past_se + ib * bd);
LOG_TEE("div: [%6d, %6d] / %6d -> [%6d, %6d]\n", slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w, slot.ga_n, (slot.ga_i + ib * bd) / slot.ga_n, (slot.ga_i + ib * bd + slot.ga_w) / slot.ga_n);
LOG_TEE("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", slot.ga_i + ib * bd + slot.ga_w, slot.n_past_se + ib * bd, dd, slot.ga_i + ib * bd + slot.ga_w + dd, slot.n_past_se + ib * bd + dd);
// update how many tokens out of those tested were accepted
slot.n_draft_accepted += ids.size() - 1;

llama_kv_cache_seq_add(ctx, slot.id, slot.ga_i, slot.n_past_se, ib * bd);
llama_kv_cache_seq_div(ctx, slot.id, slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w, slot.ga_n);
llama_kv_cache_seq_add(ctx, slot.id, slot.ga_i + ib * bd + slot.ga_w, slot.n_past_se + ib * bd, dd);
// rollback to the state before sampling the draft tokens
slot.cache_tokens.keep_first(slot.cache_tokens.n_tokens() - n_draft);
// add accepted tokens to the prompt
slot.cache_tokens.insert({ ids.begin(), ids.end() - 1 });
slot.sampled = ids.back(); // last accepted token
slot.n_past = slot.cache_tokens.n_tokens();
llama_kv_cache_seq_rm(ctx, slot.id, slot.n_past, -1);

slot.n_past_se -= bd;
for (size_t i = 0; i < ids.size(); ++i) {
completion_token_output result;

slot.ga_i += slot.ga_w / slot.ga_n;
result.tok = ids[i];
result.text_to_send = common_token_to_piece(ctx, result.tok, accept_special_token(slot, result.tok));
result.prob = 1.0f; // set later

LOG_TEE("\nn_past_old = %d, n_past = %d, ga_i = %d\n\n", slot.n_past_se + bd, slot.n_past_se, slot.ga_i);
}
if (slot.sparams.n_probs > 0) {
populate_token_probs(slot, result, slot.params.post_sampling_probs, params_base.special, i);
}

slot.n_past_se += n_tokens;
if (!process_token(result, slot)) {
// release slot because of stop condition
slot.release();
slot.print_timings();
send_final_response(slot);
metrics.on_prediction(slot);
break;
}
}
SLT_DBG(slot, "accepted %d/%d draft tokens, new n_tokens = %d\n", (int)ids.size() - 1, (int)slot.drafted.size(), slot.n_past);
LOG_VERBOSE("speculative decoding result", {
{"id_slot", slot.id},
{"accepted", (int)ids.size() - 1},
{"total", (int)slot.drafted.size()},
{"new_n_past", slot.n_past}
});
}
}

bool server_context::accept_special_token(const server_slot& slot, const llama_token token) {
return params_base.special || slot.sparams.preserved_tokens.find(token) != slot.sparams.preserved_tokens.end();
};

void server_context::process_batch_tokens(int32_t & n_batch) {
for (int32_t i = 0; i < batch.n_tokens; i += n_batch) {
const int32_t n_tokens = std::min(n_batch, batch.n_tokens - i);
extend_context(n_tokens);

llama_batch batch_view = {
n_tokens,
@@ -2661,14 +2672,11 @@ void server_context::update_slots() {
if (ret != 0) {
if (n_batch == 1 || ret < 0) {
int user_cancel = -3;
// if you get here, it means the KV cache is full - try increasing it via the context size
if (ret == user_cancel) {
LOG_ERROR("Decode process is cancelled by user", {
{"i", i},
{"n_batch", ret},
{"ret", ret},
});
} else {
LLAMA_LOG_INFO("Decode process is cancelled by user.\n");
}
else {
// if you get here, it means the KV cache is full - try increasing it via the context size
LOG_ERROR("failed to decode the batch: KV cache is full - try increasing it via the context size", {
{"i", i},
{"n_batch", ret},
@@ -2684,12 +2692,9 @@ void server_context::update_slots() {
LLAMA_LOG_INFO("n_past = %d\n", (int)slot.cache_tokens.size());
send_error(slot, "Input prompt is too big compared to KV size. Please try increasing KV size.");
}

}
break; // break loop of n_batch
}

// retry with half the batch size to try to find a free slot in the KV cache
n_batch /= 2;
i -= n_batch;
@@ -2703,10 +2708,6 @@ void server_context::update_slots() {
continue; // continue loop of n_batch
}

// technically, measuring the time here excludes the sampling time for the last batch
// but on the other hand, we don't want to do too many system calls to measure the time, so it's ok
const int64_t t_current = ggml_time_us();

for (auto& slot : slots) {
if (slot.state != SLOT_STATE_PROCESSING || slot.i_batch < (int)i || slot.i_batch >= (int)(i + n_tokens)) {
continue; // continue loop of slots
@@ -2725,9 +2726,9 @@ void server_context::update_slots() {
continue; // sample using speculative decoding
}
const int tok_idx = slot.i_batch - i;
const llama_token id = llama_sampling_sample(slot.ctx_sampling, ctx, NULL, tok_idx);
const llama_token id = common_sampler_sample(slot.ctx_sampling, ctx, NULL, tok_idx);

llama_sampling_accept(slot.ctx_sampling, ctx, id, true);
common_sampler_accept(slot.ctx_sampling, ctx, id, true);

slot.n_decoded += 1;

@@ -2739,15 +2740,14 @@ void server_context::update_slots() {
metrics.on_prompt_eval(slot);
}

//slot.t_token_generation = (t_current - slot.t_start_generation) / 1e3;
slot.t_token_generation = std::max<int64_t>(1, t_current - slot.t_start_generation) / 1e3;

result.tok = id;
result.prob = 1.0f; // TODO: set it here instead of doing inside populate_token_probs
result.text_to_send = llama_token_to_piece(ctx, result.tok, accept_special_token(slot, result.tok));
result.text_to_send = common_token_to_piece(ctx, result.tok, accept_special_token(slot, result.tok));

if (slot.sparams.n_probs > 0) {
populate_token_probs(slot, result, slot.params.post_sampling_probs, params.special, tok_idx);
populate_token_probs(slot, result, slot.params.post_sampling_probs, params_base.special, tok_idx);
}

if (!process_token(result, slot)) {
@@ -2761,64 +2761,67 @@ void server_context::update_slots() {
}

// speculative decoding - main model sample and accept
for (auto& slot : slots) {
if (slot.state != SLOT_STATE_PROCESSING || slot.i_batch_dft.empty()) {
continue;
}

size_t n_draft = slot.drafted.size();

// the accepted tokens from the speculation
const auto ids = llama_sampling_sample_and_accept_n(slot.ctx_sampling, ctx, slot.i_batch_dft, slot.drafted);
slot.i_batch_dft.clear();
slot.drafted.clear();

slot.n_past += ids.size();
slot.n_decoded += ids.size();

slot.t_token_generation = std::max<int64_t>(1, t_current - slot.t_start_generation) / 1e3;

// update how many tokens out of those tested were accepted
slot.n_draft_accepted += ids.size() - 1;

// rollback to the state before sampling the draft tokens
slot.cache_tokens.keep_first(slot.cache_tokens.n_tokens() - n_draft);
// slot.n_past -= n_draft;
// add accepted tokens to the prompt
slot.cache_tokens.insert({ ids.begin(), ids.end() - 1 });
slot.sampled = ids.back(); // last accepted token
slot.n_past = slot.cache_tokens.n_tokens();
llama_kv_cache_seq_rm(ctx, slot.id, slot.n_past, -1);

for (size_t i = 0; i < ids.size(); ++i) {
completion_token_output result;

result.tok = ids[i];
result.text_to_send = llama_token_to_piece(ctx, result.tok, accept_special_token(slot, result.tok));
result.prob = 1.0f; // set later

if (slot.sparams.n_probs > 0) {
populate_token_probs(slot, result, slot.params.post_sampling_probs, params.special, i);
}

if (!process_token(result, slot)) {
// release slot because of stop condition
slot.release();
slot.print_timings();
send_final_response(slot);
metrics.on_prediction(slot);
break;
}
}
SLT_DBG(slot, "accepted %d/%d draft tokens, new n_tokens = %d\n", (int)ids.size() - 1, (int)slot.drafted.size(), slot.n_past);
LOG_VERBOSE("speculative decoding result", {
{"id_slot", slot.id},
{"accepted", (int)ids.size() - 1},
{"total", (int)slot.drafted.size()},
{"new_n_past", slot.n_past}
});
}
speculative_decoding_accept();
}
}

void server_context::update_slots() {
if (system_need_update) {
system_prompt_update();
}
// release slots
release_slots();

// check if all slots are idle
if (slots_idle()) {
return;
}

{
LOG_VERBOSE("posting NEXT_RESPONSE", {});
server_task task;
task.type = SERVER_TASK_TYPE_NEXT_RESPONSE;
task.id_target = -1;

queue_tasks.post(std::move(task));
}

// apply context-shift if needed
// TODO: simplify and improve
context_shift();

// start populating the batch for this iteration
common_batch_clear(batch);

// frist, add sampled tokens from any ongoing sequences
add_sampled_tokens();

// process in chunks of params.n_batch
int32_t n_batch = llama_n_batch(ctx);
int32_t n_ubatch = llama_n_ubatch(ctx);

// track if this is an embedding or non-embedding batch
// if we've added sampled tokens above, we are in non-embedding mode
// -1: none, 0: non-embedding, 1: embedding
int32_t batch_type = batch.n_tokens > 0 ? 0 : -1;

// next, batch any pending prompts without exceeding n_batch
batch_pending_prompt(n_ubatch, n_batch, batch_type);

if (batch.n_tokens == 0) {
LOG_VERBOSE("no tokens to decode", {});
return;
}

LOG_VERBOSE("decoding batch", {
{"n_tokens", batch.n_tokens},
});

// make sure we're in the right embedding mode
llama_set_embeddings(ctx, batch_type == 1);

// process the created batch of tokens
process_batch_tokens(n_batch);

LOG_VERBOSE("run slots completed", {});
}